diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..dd298b0da9755afc1882a9787a0dcca7f67af6ae --- /dev/null +++ b/.gitattributes @@ -0,0 +1,95 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu126-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu126-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu128-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu129-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-x86_64-linux/_sage_attention_fd27ff6.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu126-x86_64-linux/_sage_attention_6e51d70.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_6e51d70.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_6e51d70.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-x86_64-linux/_sage_attention_6e51d70.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-x86_64-linux/_sage_attention_6e51d70.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-x86_64-linux/_sage_attention_6e51d70.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cu128-x86_64-windows/sage_attention/_sage_attention_ac695bf.pyd filter=lfs diff=lfs merge=lfs -text 
+build/torch210-cxx11-cu126-aarch64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-aarch64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu126-x86_64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-x86_64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4eabbf5.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cu128-x86_64-windows/_sage_attention_cuda_554dbc8.pyd filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu126-aarch64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-aarch64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-aarch64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-aarch64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-aarch64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-aarch64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu126-x86_64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-x86_64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-x86_64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-x86_64-linux/_sage_attention_cuda_5568690.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cu128-x86_64-windows/_sage_attention_cuda_a8f8348.pyd filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text 
+build/torch210-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4523ce2.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch211-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch211-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu129-aarch64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch211-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch211-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu129-x86_64-linux/_sage_attention_cuda_4597889.abi3.so filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cab37c4e8ea93ee03c97821d53342e399cb89757 --- /dev/null +++ b/README.md @@ -0,0 +1,53 @@ +--- +library_name: kernels +license: apache-2.0 +--- + + + + +This is the repository card of {repo_id} that has been pushed on the Hub. It was built to be used with the [`kernels` library](https://github.com/huggingface/kernels). This card was automatically generated. + + +## How to use + +```python +# make sure `kernels` is installed: `pip install -U kernels` +from kernels import get_kernel + +kernel_module = get_kernel("kernels-community/sage-attention") # <- change the ID if needed +per_block_int8 = kernel_module.per_block_int8 + +per_block_int8(...) 
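+
+# A hedged sketch of calling the high-level `sageattn` entry point; the tensor
+# shapes and sizes below are illustrative assumptions (HND layout:
+# [batch, heads, seq_len, head_dim], fp16/bf16 CUDA tensors, head_dim <= 128):
+#
+#   import torch
+#   q = torch.randn(1, 16, 1024, 128, dtype=torch.float16, device="cuda")
+#   k = torch.randn(1, 16, 1024, 128, dtype=torch.float16, device="cuda")
+#   v = torch.randn(1, 16, 1024, 128, dtype=torch.float16, device="cuda")
+#   out = kernel_module.sageattn(q, k, v, tensor_layout="HND", is_causal=True)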
+``` + +## Available functions + +- `per_block_int8` +- `per_warp_int8` +- `sub_mean` +- `per_channel_fp8` +- `sageattn` + +## Supported backends + +- cuda + +## CUDA Capabilities + +- 8.0 +- 8.9 +- 9.0a + +## Benchmarks + +[TODO: provide benchmarks if available] + +## Source code + +[TODO: provide original source code and other relevant citations if available] + +## Notes + +[TODO: provide additional notes about this kernel if needed] \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-aarch64-linux/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..959423da909d8e23a8bef3f0a18ef50484dd30bb --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py @@ -0,0 +1,17 @@ +from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8 +from .core import sageattn + +try: + from .sm100_compile import sageattn3_blackwell + SM100_ENABLED = True +except Exception: + SM100_ENABLED = False + +__all__ = [ + "per_block_int8", + "per_warp_int8", + "sub_mean", + "per_channel_fp8", + "sageattn", + "sageattn3_blackwell", +] \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_ops.py b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0c899b0f46abb46ddf06458b980c7a7ba69539d3 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _sage_attention_cuda_4597889 +ops = torch.ops._sage_attention_cuda_4597889 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_sage_attention_cuda_4597889::{op_name}" diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4597889.abi3.so b/build/torch210-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4597889.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..b6832d8439d3f899d38c74f289fd8af40308f6ac --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4597889.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c6b02e8658941d4c5a1008993fd41fbf31d6a38300a18077661e34f03fb30fe +size 33330136 diff --git a/build/torch210-cxx11-cu128-aarch64-linux/core.py b/build/torch210-cxx11-cu128-aarch64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..13eff23d691beee1788df58e91c8470b02fc2b6b --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/core.py @@ -0,0 +1,1013 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import torch.nn.functional as F +import warnings + +from ._ops import ops + + +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant import per_channel_fp8 +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + +try: + from .sm80_compile import ( + qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn, + qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf, + ) + SM80_ENABLED = True +except Exception as e: + SM80_ENABLED = False + warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}") + +try: + from .sm89_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + ) + SM89_ENABLED = True +except Exception as e: + SM89_ENABLED = False + warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}") + +try: + from .sm90_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90, + ) + SM90_ENABLED = True +except Exception as e: + SM90_ENABLED = False + warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}") + +from typing import Any, List, Literal, Optional, Tuple, Union + +import subprocess +import re + + +def get_cuda_version(): + try: + output = subprocess.check_output(["nvcc", "--version"]).decode() + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. 
Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + arch = get_cuda_arch_versions()[q.device.index] + if arch == "sm80": + if not SM80_ENABLED: + raise RuntimeError( + "SM80 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM80 (Ampere)." + ) + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "sm89": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM89 (Ada Lovelace)." + ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) + elif arch == "sm90": + if not SM90_ENABLED: + raise RuntimeError( + "SM90 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM90 (Hopper)." + ) + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp32", + ) + elif arch == "sm120": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled." + ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + qk_quant_gran="per_warp", + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. + else: + raise ValueError(f"Unsupported CUDA architecture: {arch}") + +def sageattn_qk_int8_pv_fp16_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. 
Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32". + - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b). + - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead. + - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. 
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = 
sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. 
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # cuda_major_version, cuda_minor_version = get_cuda_version() + # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16': + # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'") + # pv_accum_dtype = 'fp32+fp32' + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." 
+ ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64 + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v + ) + if pv_accum_dtype == "fp32": + if smooth_v: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + else: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp32": + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp16": + lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + o = o[..., :head_dim_og] + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. 
Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." 
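+    # Same distributed-inference workaround as in the other entry points above:
+    # pin the current CUDA device to the one holding the inputs before any
+    # quantization or attention kernels are launched.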
+ + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + if pv_accum_dtype == "fp32": + raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.") + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/build/torch210-cxx11-cu128-aarch64-linux/metadata.json 
b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6530d3a7dee7b03bc5bb1f69c33aa615bd9bfef --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json @@ -0,0 +1,14 @@ +{ + "version": 2, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0a", + "8.0", + "8.9", + "9.0a" + ] + } +} diff --git a/build/torch210-cxx11-cu128-aarch64-linux/quant.py b/build/torch210-cxx11-cu128-aarch64-linux/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7a32c6a1e502bee12d0e0564ff2b90f6b00462 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/quant.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +from ._ops import ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. 
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout + ) + return v_fp8, v_scale, None diff --git a/build/torch210-cxx11-cu128-aarch64-linux/quant_per_thread.py b/build/torch210-cxx11-cu128-aarch64-linux/quant_per_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f57c3e6dd9df89946ba54c4e2c3844c94d34 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/quant_per_thread.py @@ -0,0 +1,204 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + # offs_k = tl.arange(0, C) + + # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + # x = x.to(tl.float32) + # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + # x_int8 = x / scale + # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + # x_int8 = x_int8.to(tl.int8) + # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + # tl.store(scale_ptrs, scale) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. 
+ 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. 
+ 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + C=head_dim, BLK=WARPQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + C=head_dim, BLK=WARPK + ) + + return q_int8, q_scale, k_int8, k_scale \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-aarch64-linux/sage_attention/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/sage_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/sage_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu128-aarch64-linux/sm100_compile.py b/build/torch210-cxx11-cu128-aarch64-linux/sm100_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4aa99649cca6673a91467ee522257d645d9e72 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/sm100_compile.py @@ -0,0 +1,327 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import List, Optional, Tuple + +from ._ops import ops, add_op_namespace_prefix +from torch.nn.functional import scaled_dot_product_attention as sdpa + + +# --------------------------------------------------------------------------- +# Low-level ops with torch.compile support (custom_op + register_fake) +# --------------------------------------------------------------------------- + +@torch.library.custom_op( + add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda" +) +def mha_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + return ops.mha_fwd( + q, k, v, sfq, sfk, sfv, delta_s, + unpadded_k, out, softmax_scale, is_causal, + per_block_mean, is_bf16, + ) + + +@torch.library.register_fake(add_op_namespace_prefix("mha_fwd")) +def mha_fwd_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + batch_size = q.size(0) + num_heads = q.size(1) + seqlen_q = q.size(2) + head_size_packed = q.size(3) + unpacked_head_size = head_size_packed * 2 + dtype = torch.bfloat16 if is_bf16 else torch.float16 + fake_out = torch.empty( + (batch_size, num_heads, seqlen_q, unpacked_head_size), + dtype=dtype, device=q.device, + ) + fake_lse = torch.empty( + (batch_size, num_heads, seqlen_q), + dtype=torch.float32, device=q.device, + ) + return [fake_out, fake_lse] + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant"), + mutates_args=("output", "output_sf"), + 
device_types="cuda", +) +def scaled_fp4_quant( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant")) +def scaled_fp4_quant_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_permute"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_permute( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute")) +def scaled_fp4_quant_permute_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_trans"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_trans( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans")) +def scaled_fp4_quant_trans_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +# --------------------------------------------------------------------------- +# Triton kernel for grouped mean subtraction +# --------------------------------------------------------------------------- + +@triton.jit +def _group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = ( + pid_b * stride_qb + + pid_h * stride_qh + + offsets[:, None] * stride_ql + + tl.arange(0, D)[None, :] * stride_qd + ) + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = ( + pid_b * stride_qmb + + pid_h * stride_qmh + + pid_group * stride_qml + + tl.arange(0, D) * stride_qmd + ) + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = 128 + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + _group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE, + ) + return q_out, qm + + +# --------------------------------------------------------------------------- +# High-level Python API (ported from sageattn3/api.py) +# --------------------------------------------------------------------------- + +def preprocess_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + per_block_mean: bool = True, +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + def pad_128(x): + L = x.size(2) + pad_len = (128 - L % 128) % 128 + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + k = k - k.mean(dim=-2, keepdim=True) + q, k, v = map(pad_128, [q, k, v]) + if per_block_mean: + q, qm = triton_group_mean(q) + else: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + return q, k, v, delta_s + + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_transpose( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def blockscaled_fp4_attn( + qlist: Tuple[torch.Tensor, torch.Tensor], + klist: Tuple[torch.Tensor, torch.Tensor], + vlist: Tuple[torch.Tensor, torch.Tensor], + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, +) -> List[torch.Tensor]: + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return mha_fwd( + qlist[0], klist[0], vlist[0], + qlist[1], klist[1], vlist[1], + delta_s, KL, None, + softmax_scale, is_causal, per_block_mean, is_bf16, + ) + + +def sageattn3_blackwell( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + per_block_mean: bool = True, + **kwargs, +) -> torch.Tensor: + if q.size(-1) >= 256: + return sdpa(q, k, v, is_causal=is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q) + klist = scale_and_quant_fp4_permute(k) + vlist = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist, klist, vlist, delta_s, + KL, is_causal, per_block_mean, is_bf16, + )[0][:, :, :QL, :].contiguous() + return o_fp4 diff --git a/build/torch210-cxx11-cu128-aarch64-linux/sm80_compile.py b/build/torch210-cxx11-cu128-aarch64-linux/sm80_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..99e1c8d87c278d2935922309bb58eedd2369bfbf --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/sm80_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = 
query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn")) +def qk_int8_sv_f16_accum_f16_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")) +def qk_int8_sv_f16_accum_f32_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf")) +def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")) +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn +qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn +qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf +qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn diff --git a/build/torch210-cxx11-cu128-aarch64-linux/sm89_compile.py b/build/torch210-cxx11-cu128-aarch64-linux/sm89_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..714d471f7d780fb53cb96f08dd5cc27f20581c8c --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/sm89_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, 
qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn diff --git a/build/torch210-cxx11-cu128-aarch64-linux/sm90_compile.py b/build/torch210-cxx11-cu128-aarch64-linux/sm90_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..05898751d86f648aa7218bda29288421bb0308c3 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/sm90_compile.py @@ -0,0 +1,36 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..959423da909d8e23a8bef3f0a18ef50484dd30bb --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,17 @@ +from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8 +from .core import sageattn + +try: + from .sm100_compile import sageattn3_blackwell + SM100_ENABLED = True +except Exception: + SM100_ENABLED = False + +__all__ = [ + "per_block_int8", + "per_warp_int8", + "sub_mean", + "per_channel_fp8", + "sageattn", + "sageattn3_blackwell", +] \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0c899b0f46abb46ddf06458b980c7a7ba69539d3 --- /dev/null 
+++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _sage_attention_cuda_4597889 +ops = torch.ops._sage_attention_cuda_4597889 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_sage_attention_cuda_4597889::{op_name}" diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4597889.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4597889.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..6838b1fb1de844b62c47de28e963a7af53004f19 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4597889.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ee6c9f8117e0b7f9e51165bbb83a4b7da8a94924e3ab171d6858a733725adb2 +size 33431488 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/core.py b/build/torch210-cxx11-cu128-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..13eff23d691beee1788df58e91c8470b02fc2b6b --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/core.py @@ -0,0 +1,1013 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import warnings + +from ._ops import ops + + +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant import per_channel_fp8 +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + +try: + from .sm80_compile import ( + qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn, + qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf, + ) + SM80_ENABLED = True +except Exception as e: + SM80_ENABLED = False + warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}") + +try: + from .sm89_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + ) + SM89_ENABLED = True +except Exception as e: + SM89_ENABLED = False + warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}") + +try: + from .sm90_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90, + ) + SM90_ENABLED = True +except Exception as e: + SM90_ENABLED = False + warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}") + +from typing import Any, List, Literal, Optional, Tuple, Union + +import subprocess +import re + + +def 
get_cuda_version(): + try: + output = subprocess.check_output(["nvcc", "--version"]).decode() + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + arch = get_cuda_arch_versions()[q.device.index] + if arch == "sm80": + if not SM80_ENABLED: + raise RuntimeError( + "SM80 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM80 (Ampere)." + ) + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "sm89": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM89 (Ada Lovelace)." 
+ ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) + elif arch == "sm90": + if not SM90_ENABLED: + raise RuntimeError( + "SM90 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM90 (Hopper)." + ) + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp32", + ) + elif arch == "sm120": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled." + ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + qk_quant_gran="per_warp", + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. + else: + raise ValueError(f"Unsupported CUDA architecture: {arch}") + +def sageattn_qk_int8_pv_fp16_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32". + - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b). + - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead. + - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32". 
+ + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." 
+ ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. 
Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # cuda_major_version, cuda_minor_version = get_cuda_version() + # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16': + # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'") + # pv_accum_dtype = 'fp32+fp32' + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. 
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64 + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v + ) + if pv_accum_dtype == "fp32": + if smooth_v: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + else: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp32": + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + 
q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp16": + lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + o = o[..., :head_dim_og] + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. 
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + if pv_accum_dtype == "fp32": + raise NotImplementedError("Please use 
pv_accum_dtype='fp32+fp32' for sm90.") + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6530d3a7dee7b03bc5bb1f69c33aa615bd9bfef --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,14 @@ +{ + "version": 2, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0a", + "8.0", + "8.9", + "9.0a" + ] + } +} diff --git a/build/torch210-cxx11-cu128-x86_64-linux/quant.py b/build/torch210-cxx11-cu128-x86_64-linux/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7a32c6a1e502bee12d0e0564ff2b90f6b00462 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/quant.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +from ._ops import ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. 
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
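+
+    Example
+    -------
+    A minimal illustrative call (hypothetical shapes; assumes a CUDA device and the
+    compiled extension). With the defaults ``BLKQ=128``, ``WARPQ=32`` and ``BLKK=64``,
+    the scale tensors below have ``(1024 // 128) * (128 // 32) = 32`` and
+    ``1024 // 64 = 16`` entries per head:
+
+    >>> q = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
+    >>> k = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
+    >>> q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k)
+    >>> q_scale.shape, k_scale.shape
+    (torch.Size([2, 8, 32]), torch.Size([2, 8, 16]))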
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. 
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout + ) + return v_fp8, v_scale, None diff --git a/build/torch210-cxx11-cu128-x86_64-linux/quant_per_thread.py b/build/torch210-cxx11-cu128-x86_64-linux/quant_per_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f57c3e6dd9df89946ba54c4e2c3844c94d34 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/quant_per_thread.py @@ -0,0 +1,204 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + # offs_k = tl.arange(0, C) + + # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + # x = x.to(tl.float32) + # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + # x_int8 = x / scale + # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + # x_int8 = x_int8.to(tl.int8) + # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + # tl.store(scale_ptrs, scale) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. 
+ 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. 
+ 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + C=head_dim, BLK=WARPQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + C=head_dim, BLK=WARPK + ) + + return q_int8, q_scale, k_int8, k_scale \ No newline at end of file diff --git a/build/torch210-cxx11-cu128-x86_64-linux/sage_attention/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/sage_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/sage_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
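+    # Illustrative note: ``hash(file_path.absolute())`` can be negative, so it is
+    # wrapped through ``ctypes.c_size_t`` below to obtain a non-negative unsigned
+    # value before being hex-formatted into the synthetic module name.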
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/sm100_compile.py b/build/torch210-cxx11-cu128-x86_64-linux/sm100_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4aa99649cca6673a91467ee522257d645d9e72 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/sm100_compile.py @@ -0,0 +1,327 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import List, Optional, Tuple + +from ._ops import ops, add_op_namespace_prefix +from torch.nn.functional import scaled_dot_product_attention as sdpa + + +# --------------------------------------------------------------------------- +# Low-level ops with torch.compile support (custom_op + register_fake) +# --------------------------------------------------------------------------- + +@torch.library.custom_op( + add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda" +) +def mha_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + return ops.mha_fwd( + q, k, v, sfq, sfk, sfv, delta_s, + unpadded_k, out, softmax_scale, is_causal, + per_block_mean, is_bf16, + ) + + +@torch.library.register_fake(add_op_namespace_prefix("mha_fwd")) +def mha_fwd_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + batch_size = q.size(0) + num_heads = q.size(1) + seqlen_q = q.size(2) + head_size_packed = q.size(3) + unpacked_head_size = head_size_packed * 2 + dtype = torch.bfloat16 if is_bf16 else torch.float16 + fake_out = torch.empty( + (batch_size, num_heads, seqlen_q, unpacked_head_size), + dtype=dtype, device=q.device, + ) + fake_lse = torch.empty( + (batch_size, num_heads, seqlen_q), + dtype=torch.float32, device=q.device, + ) + return [fake_out, fake_lse] + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant"), + mutates_args=("output", "output_sf"), + 
device_types="cuda", +) +def scaled_fp4_quant( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant")) +def scaled_fp4_quant_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_permute"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_permute( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute")) +def scaled_fp4_quant_permute_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_trans"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_trans( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans")) +def scaled_fp4_quant_trans_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +# --------------------------------------------------------------------------- +# Triton kernel for grouped mean subtraction +# --------------------------------------------------------------------------- + +@triton.jit +def _group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = ( + pid_b * stride_qb + + pid_h * stride_qh + + offsets[:, None] * stride_ql + + tl.arange(0, D)[None, :] * stride_qd + ) + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = ( + pid_b * stride_qmb + + pid_h * stride_qmh + + pid_group * stride_qml + + tl.arange(0, D) * stride_qmd + ) + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = 128 + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + _group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE, + ) + return q_out, qm + + +# --------------------------------------------------------------------------- +# High-level Python API (ported from sageattn3/api.py) +# --------------------------------------------------------------------------- + +def preprocess_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + per_block_mean: bool = True, +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + def pad_128(x): + L = x.size(2) + pad_len = (128 - L % 128) % 128 + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + k = k - k.mean(dim=-2, keepdim=True) + q, k, v = map(pad_128, [q, k, v]) + if per_block_mean: + q, qm = triton_group_mean(q) + else: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + return q, k, v, delta_s + + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_transpose( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def blockscaled_fp4_attn( + qlist: Tuple[torch.Tensor, torch.Tensor], + klist: Tuple[torch.Tensor, torch.Tensor], + vlist: Tuple[torch.Tensor, torch.Tensor], + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, +) -> List[torch.Tensor]: + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return mha_fwd( + qlist[0], klist[0], vlist[0], + qlist[1], klist[1], vlist[1], + delta_s, KL, None, + softmax_scale, is_causal, per_block_mean, is_bf16, + ) + + +def sageattn3_blackwell( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + per_block_mean: bool = True, + **kwargs, +) -> torch.Tensor: + if q.size(-1) >= 256: + return sdpa(q, k, v, is_causal=is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q) + klist = scale_and_quant_fp4_permute(k) + vlist = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist, klist, vlist, delta_s, + KL, is_causal, per_block_mean, is_bf16, + )[0][:, :, :QL, :].contiguous() + return o_fp4 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/sm80_compile.py b/build/torch210-cxx11-cu128-x86_64-linux/sm80_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..99e1c8d87c278d2935922309bb58eedd2369bfbf --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/sm80_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = 
query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn")) +def qk_int8_sv_f16_accum_f16_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")) +def qk_int8_sv_f16_accum_f32_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf")) +def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")) +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn +qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn +qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf +qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn diff --git a/build/torch210-cxx11-cu128-x86_64-linux/sm89_compile.py b/build/torch210-cxx11-cu128-x86_64-linux/sm89_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..714d471f7d780fb53cb96f08dd5cc27f20581c8c --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/sm89_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, 
qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn diff --git a/build/torch210-cxx11-cu128-x86_64-linux/sm90_compile.py b/build/torch210-cxx11-cu128-x86_64-linux/sm90_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..05898751d86f648aa7218bda29288421bb0308c3 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/sm90_compile.py @@ -0,0 +1,36 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 diff --git a/build/torch210-cxx11-cu130-aarch64-linux/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..959423da909d8e23a8bef3f0a18ef50484dd30bb --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,17 @@ +from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8 +from .core import sageattn + +try: + from .sm100_compile import sageattn3_blackwell + SM100_ENABLED = True +except Exception: + SM100_ENABLED = False + +__all__ = [ + "per_block_int8", + "per_warp_int8", + "sub_mean", + "per_channel_fp8", + "sageattn", + "sageattn3_blackwell", +] \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_ops.py b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0c899b0f46abb46ddf06458b980c7a7ba69539d3 --- 
/dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _sage_attention_cuda_4597889 +ops = torch.ops._sage_attention_cuda_4597889 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_sage_attention_cuda_4597889::{op_name}" diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4597889.abi3.so b/build/torch210-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4597889.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..2dc8004aaf946e344db56c2ffb2a3927f581cb0d --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4597889.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6593ed6bd82b9e0d24ce7b69d0899e1854f04a0d4af9c33fae5de94e1cbf4239 +size 33875224 diff --git a/build/torch210-cxx11-cu130-aarch64-linux/core.py b/build/torch210-cxx11-cu130-aarch64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..13eff23d691beee1788df58e91c8470b02fc2b6b --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/core.py @@ -0,0 +1,1013 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import warnings + +from ._ops import ops + + +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant import per_channel_fp8 +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + +try: + from .sm80_compile import ( + qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn, + qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf, + ) + SM80_ENABLED = True +except Exception as e: + SM80_ENABLED = False + warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}") + +try: + from .sm89_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + ) + SM89_ENABLED = True +except Exception as e: + SM89_ENABLED = False + warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}") + +try: + from .sm90_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90, + ) + SM90_ENABLED = True +except Exception as e: + SM90_ENABLED = False + warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}") + +from typing import Any, List, Literal, Optional, Tuple, Union + +import subprocess 
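+# ``subprocess`` and ``re`` are used by ``get_cuda_version`` below to run
+# ``nvcc --version`` and parse the CUDA release number from its output.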
+import re + + +def get_cuda_version(): + try: + output = subprocess.check_output(["nvcc", "--version"]).decode() + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + arch = get_cuda_arch_versions()[q.device.index] + if arch == "sm80": + if not SM80_ENABLED: + raise RuntimeError( + "SM80 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM80 (Ampere)." + ) + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "sm89": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM89 (Ada Lovelace)." 
+ ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) + elif arch == "sm90": + if not SM90_ENABLED: + raise RuntimeError( + "SM90 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM90 (Hopper)." + ) + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp32", + ) + elif arch == "sm120": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled." + ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + qk_quant_gran="per_warp", + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. + else: + raise ValueError(f"Unsupported CUDA architecture: {arch}") + +def sageattn_qk_int8_pv_fp16_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32". + - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b). + - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead. + - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32". 
+ + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." 
+ ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. 
Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # cuda_major_version, cuda_minor_version = get_cuda_version() + # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16': + # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'") + # pv_accum_dtype = 'fp32+fp32' + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. 
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64 + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v + ) + if pv_accum_dtype == "fp32": + if smooth_v: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + else: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp32": + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + 
q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp16": + lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + o = o[..., :head_dim_og] + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. 
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + if pv_accum_dtype == "fp32": + raise NotImplementedError("Please use 
pv_accum_dtype='fp32+fp32' for sm90.") + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/build/torch210-cxx11-cu130-aarch64-linux/metadata.json b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6530d3a7dee7b03bc5bb1f69c33aa615bd9bfef --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,14 @@ +{ + "version": 2, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0a", + "8.0", + "8.9", + "9.0a" + ] + } +} diff --git a/build/torch210-cxx11-cu130-aarch64-linux/quant.py b/build/torch210-cxx11-cu130-aarch64-linux/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7a32c6a1e502bee12d0e0564ff2b90f6b00462 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/quant.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +from ._ops import ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. 
Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
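+
+    Example
+    -------
+    A minimal usage sketch; the tensors and shapes below are illustrative only
+    and assume the default block sizes (``BLKQ=128``, ``WARPQ=32``, ``BLKK=64``):
+
+    >>> q = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
+    >>> k = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
+    >>> q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k, tensor_layout="HND")
+    >>> q_scale.shape   # (qo_len // BLKQ) * (BLKQ // WARPQ) = 8 * 4 scales per head
+    torch.Size([1, 8, 32])
+    >>> k_scale.shape   # kv_len // BLKK = 16 scales per head
+    torch.Size([1, 8, 16])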
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. 
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout + ) + return v_fp8, v_scale, None diff --git a/build/torch210-cxx11-cu130-aarch64-linux/quant_per_thread.py b/build/torch210-cxx11-cu130-aarch64-linux/quant_per_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f57c3e6dd9df89946ba54c4e2c3844c94d34 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/quant_per_thread.py @@ -0,0 +1,204 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + # offs_k = tl.arange(0, C) + + # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + # x = x.to(tl.float32) + # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + # x_int8 = x / scale + # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + # x_int8 = x_int8.to(tl.int8) + # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + # tl.store(scale_ptrs, scale) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. 
+ 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. 
+ 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + C=head_dim, BLK=WARPQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + C=head_dim, BLK=WARPK + ) + + return q_int8, q_scale, k_int8, k_scale \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-aarch64-linux/sage_attention/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/sage_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/sage_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
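+    # ctypes.c_size_t(...) reinterprets the (possibly negative) Python hash as an
+    # unsigned value, so the hex-encoded module name is always well formed.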
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu130-aarch64-linux/sm100_compile.py b/build/torch210-cxx11-cu130-aarch64-linux/sm100_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4aa99649cca6673a91467ee522257d645d9e72 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/sm100_compile.py @@ -0,0 +1,327 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import List, Optional, Tuple + +from ._ops import ops, add_op_namespace_prefix +from torch.nn.functional import scaled_dot_product_attention as sdpa + + +# --------------------------------------------------------------------------- +# Low-level ops with torch.compile support (custom_op + register_fake) +# --------------------------------------------------------------------------- + +@torch.library.custom_op( + add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda" +) +def mha_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + return ops.mha_fwd( + q, k, v, sfq, sfk, sfv, delta_s, + unpadded_k, out, softmax_scale, is_causal, + per_block_mean, is_bf16, + ) + + +@torch.library.register_fake(add_op_namespace_prefix("mha_fwd")) +def mha_fwd_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + batch_size = q.size(0) + num_heads = q.size(1) + seqlen_q = q.size(2) + head_size_packed = q.size(3) + unpacked_head_size = head_size_packed * 2 + dtype = torch.bfloat16 if is_bf16 else torch.float16 + fake_out = torch.empty( + (batch_size, num_heads, seqlen_q, unpacked_head_size), + dtype=dtype, device=q.device, + ) + fake_lse = torch.empty( + (batch_size, num_heads, seqlen_q), + dtype=torch.float32, device=q.device, + ) + return [fake_out, fake_lse] + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant"), + mutates_args=("output", "output_sf"), + 
device_types="cuda", +) +def scaled_fp4_quant( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant")) +def scaled_fp4_quant_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_permute"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_permute( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute")) +def scaled_fp4_quant_permute_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_trans"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_trans( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans")) +def scaled_fp4_quant_trans_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +# --------------------------------------------------------------------------- +# Triton kernel for grouped mean subtraction +# --------------------------------------------------------------------------- + +@triton.jit +def _group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = ( + pid_b * stride_qb + + pid_h * stride_qh + + offsets[:, None] * stride_ql + + tl.arange(0, D)[None, :] * stride_qd + ) + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = ( + pid_b * stride_qmb + + pid_h * stride_qmh + + pid_group * stride_qml + + tl.arange(0, D) * stride_qmd + ) + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = 128 + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + _group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE, + ) + return q_out, qm + + +# --------------------------------------------------------------------------- +# High-level Python API (ported from sageattn3/api.py) +# --------------------------------------------------------------------------- + +def preprocess_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + per_block_mean: bool = True, +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + def pad_128(x): + L = x.size(2) + pad_len = (128 - L % 128) % 128 + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + k = k - k.mean(dim=-2, keepdim=True) + q, k, v = map(pad_128, [q, k, v]) + if per_block_mean: + q, qm = triton_group_mean(q) + else: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + return q, k, v, delta_s + + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_transpose( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def blockscaled_fp4_attn( + qlist: Tuple[torch.Tensor, torch.Tensor], + klist: Tuple[torch.Tensor, torch.Tensor], + vlist: Tuple[torch.Tensor, torch.Tensor], + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, +) -> List[torch.Tensor]: + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return mha_fwd( + qlist[0], klist[0], vlist[0], + qlist[1], klist[1], vlist[1], + delta_s, KL, None, + softmax_scale, is_causal, per_block_mean, is_bf16, + ) + + +def sageattn3_blackwell( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + per_block_mean: bool = True, + **kwargs, +) -> torch.Tensor: + if q.size(-1) >= 256: + return sdpa(q, k, v, is_causal=is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q) + klist = scale_and_quant_fp4_permute(k) + vlist = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist, klist, vlist, delta_s, + KL, is_causal, per_block_mean, is_bf16, + )[0][:, :, :QL, :].contiguous() + return o_fp4 diff --git a/build/torch210-cxx11-cu130-aarch64-linux/sm80_compile.py b/build/torch210-cxx11-cu130-aarch64-linux/sm80_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..99e1c8d87c278d2935922309bb58eedd2369bfbf --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/sm80_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = 
query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn")) +def qk_int8_sv_f16_accum_f16_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")) +def qk_int8_sv_f16_accum_f32_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf")) +def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")) +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn +qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn +qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf +qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn diff --git a/build/torch210-cxx11-cu130-aarch64-linux/sm89_compile.py b/build/torch210-cxx11-cu130-aarch64-linux/sm89_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..714d471f7d780fb53cb96f08dd5cc27f20581c8c --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/sm89_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, 
qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn diff --git a/build/torch210-cxx11-cu130-aarch64-linux/sm90_compile.py b/build/torch210-cxx11-cu130-aarch64-linux/sm90_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..05898751d86f648aa7218bda29288421bb0308c3 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/sm90_compile.py @@ -0,0 +1,36 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..959423da909d8e23a8bef3f0a18ef50484dd30bb --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,17 @@ +from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8 +from .core import sageattn + +try: + from .sm100_compile import sageattn3_blackwell + SM100_ENABLED = True +except Exception: + SM100_ENABLED = False + +__all__ = [ + "per_block_int8", + "per_warp_int8", + "sub_mean", + "per_channel_fp8", + "sageattn", + "sageattn3_blackwell", +] \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0c899b0f46abb46ddf06458b980c7a7ba69539d3 --- /dev/null 
+++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _sage_attention_cuda_4597889 +ops = torch.ops._sage_attention_cuda_4597889 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_sage_attention_cuda_4597889::{op_name}" diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4597889.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4597889.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..07b7022bd8d78574aa271a970e35ce0eedbea81f --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4597889.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee3c2c8be2231bfbe36c760a5417d741d249a807e8e3f0ec1216efce94167c00 +size 34165352 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/core.py b/build/torch210-cxx11-cu130-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..13eff23d691beee1788df58e91c8470b02fc2b6b --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/core.py @@ -0,0 +1,1013 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import warnings + +from ._ops import ops + + +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant import per_channel_fp8 +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + +try: + from .sm80_compile import ( + qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn, + qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf, + ) + SM80_ENABLED = True +except Exception as e: + SM80_ENABLED = False + warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}") + +try: + from .sm89_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + ) + SM89_ENABLED = True +except Exception as e: + SM89_ENABLED = False + warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}") + +try: + from .sm90_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90, + ) + SM90_ENABLED = True +except Exception as e: + SM90_ENABLED = False + warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}") + +from typing import Any, List, Literal, Optional, Tuple, Union + +import subprocess +import re + + +def 
get_cuda_version(): + try: + output = subprocess.check_output(["nvcc", "--version"]).decode() + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + arch = get_cuda_arch_versions()[q.device.index] + if arch == "sm80": + if not SM80_ENABLED: + raise RuntimeError( + "SM80 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM80 (Ampere)." + ) + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "sm89": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM89 (Ada Lovelace)." 
+ ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) + elif arch == "sm90": + if not SM90_ENABLED: + raise RuntimeError( + "SM90 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM90 (Hopper)." + ) + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp32", + ) + elif arch == "sm120": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled." + ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + qk_quant_gran="per_warp", + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. + else: + raise ValueError(f"Unsupported CUDA architecture: {arch}") + +def sageattn_qk_int8_pv_fp16_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32". + - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b). + - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead. + - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32". 
+ + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." 
+ ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. 
Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # cuda_major_version, cuda_minor_version = get_cuda_version() + # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16': + # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'") + # pv_accum_dtype = 'fp32+fp32' + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. 
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64 + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v + ) + if pv_accum_dtype == "fp32": + if smooth_v: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + else: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp32": + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + 
q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp16": + lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + o = o[..., :head_dim_og] + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. 
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + if pv_accum_dtype == "fp32": + raise NotImplementedError("Please use 
pv_accum_dtype='fp32+fp32' for sm90.") + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6530d3a7dee7b03bc5bb1f69c33aa615bd9bfef --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,14 @@ +{ + "version": 2, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0a", + "8.0", + "8.9", + "9.0a" + ] + } +} diff --git a/build/torch210-cxx11-cu130-x86_64-linux/quant.py b/build/torch210-cxx11-cu130-x86_64-linux/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7a32c6a1e502bee12d0e0564ff2b90f6b00462 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/quant.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +from ._ops import ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. 
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
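+
+    Example
+    -------
+    A minimal usage sketch (shapes are illustrative only, not required)::
+
+        q = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
+        k = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
+        q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k, tensor_layout="HND")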
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. 
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout + ) + return v_fp8, v_scale, None diff --git a/build/torch210-cxx11-cu130-x86_64-linux/quant_per_thread.py b/build/torch210-cxx11-cu130-x86_64-linux/quant_per_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f57c3e6dd9df89946ba54c4e2c3844c94d34 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/quant_per_thread.py @@ -0,0 +1,204 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + # offs_k = tl.arange(0, C) + + # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + # x = x.to(tl.float32) + # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + # x_int8 = x / scale + # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + # x_int8 = x_int8.to(tl.int8) + # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + # tl.store(scale_ptrs, scale) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. 
+ 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. 
+ 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + C=head_dim, BLK=WARPQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + C=head_dim, BLK=WARPK + ) + + return q_int8, q_scale, k_int8, k_scale \ No newline at end of file diff --git a/build/torch210-cxx11-cu130-x86_64-linux/sage_attention/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/sage_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/sage_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
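+    # hash() of a Path is derived from its string form, which is salted per
+    # process (PYTHONHASHSEED), so the generated module name is only stable
+    # within the current interpreter -- that is sufficient here because the
+    # entry is only registered in this process's sys.modules.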
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/sm100_compile.py b/build/torch210-cxx11-cu130-x86_64-linux/sm100_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4aa99649cca6673a91467ee522257d645d9e72 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/sm100_compile.py @@ -0,0 +1,327 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import List, Optional, Tuple + +from ._ops import ops, add_op_namespace_prefix +from torch.nn.functional import scaled_dot_product_attention as sdpa + + +# --------------------------------------------------------------------------- +# Low-level ops with torch.compile support (custom_op + register_fake) +# --------------------------------------------------------------------------- + +@torch.library.custom_op( + add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda" +) +def mha_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + return ops.mha_fwd( + q, k, v, sfq, sfk, sfv, delta_s, + unpadded_k, out, softmax_scale, is_causal, + per_block_mean, is_bf16, + ) + + +@torch.library.register_fake(add_op_namespace_prefix("mha_fwd")) +def mha_fwd_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + batch_size = q.size(0) + num_heads = q.size(1) + seqlen_q = q.size(2) + head_size_packed = q.size(3) + unpacked_head_size = head_size_packed * 2 + dtype = torch.bfloat16 if is_bf16 else torch.float16 + fake_out = torch.empty( + (batch_size, num_heads, seqlen_q, unpacked_head_size), + dtype=dtype, device=q.device, + ) + fake_lse = torch.empty( + (batch_size, num_heads, seqlen_q), + dtype=torch.float32, device=q.device, + ) + return [fake_out, fake_lse] + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant"), + mutates_args=("output", "output_sf"), + 
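+    # output and output_sf are filled in place by the CUDA kernel, so they are
+    # declared as mutated; torch.compile/functionalization needs this to
+    # handle the in-place writes correctly.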
device_types="cuda", +) +def scaled_fp4_quant( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant")) +def scaled_fp4_quant_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_permute"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_permute( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute")) +def scaled_fp4_quant_permute_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_trans"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_trans( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans")) +def scaled_fp4_quant_trans_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +# --------------------------------------------------------------------------- +# Triton kernel for grouped mean subtraction +# --------------------------------------------------------------------------- + +@triton.jit +def _group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = ( + pid_b * stride_qb + + pid_h * stride_qh + + offsets[:, None] * stride_ql + + tl.arange(0, D)[None, :] * stride_qd + ) + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = ( + pid_b * stride_qmb + + pid_h * stride_qmh + + pid_group * stride_qml + + tl.arange(0, D) * stride_qmd + ) + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = 128 + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + _group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE, + ) + return q_out, qm + + +# --------------------------------------------------------------------------- +# High-level Python API (ported from sageattn3/api.py) +# --------------------------------------------------------------------------- + +def preprocess_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + per_block_mean: bool = True, +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + def pad_128(x): + L = x.size(2) + pad_len = (128 - L % 128) % 128 + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + k = k - k.mean(dim=-2, keepdim=True) + q, k, v = map(pad_128, [q, k, v]) + if per_block_mean: + q, qm = triton_group_mean(q) + else: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + return q, k, v, delta_s + + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_transpose( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def blockscaled_fp4_attn( + qlist: Tuple[torch.Tensor, torch.Tensor], + klist: Tuple[torch.Tensor, torch.Tensor], + vlist: Tuple[torch.Tensor, torch.Tensor], + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, +) -> List[torch.Tensor]: + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return mha_fwd( + qlist[0], klist[0], vlist[0], + qlist[1], klist[1], vlist[1], + delta_s, KL, None, + softmax_scale, is_causal, per_block_mean, is_bf16, + ) + + +def sageattn3_blackwell( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + per_block_mean: bool = True, + **kwargs, +) -> torch.Tensor: + if q.size(-1) >= 256: + return sdpa(q, k, v, is_causal=is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q) + klist = scale_and_quant_fp4_permute(k) + vlist = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist, klist, vlist, delta_s, + KL, is_causal, per_block_mean, is_bf16, + )[0][:, :, :QL, :].contiguous() + return o_fp4 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/sm80_compile.py b/build/torch210-cxx11-cu130-x86_64-linux/sm80_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..99e1c8d87c278d2935922309bb58eedd2369bfbf --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/sm80_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = 
query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn")) +def qk_int8_sv_f16_accum_f16_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")) +def qk_int8_sv_f16_accum_f32_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf")) +def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")) +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn +qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn +qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf +qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn diff --git a/build/torch210-cxx11-cu130-x86_64-linux/sm89_compile.py b/build/torch210-cxx11-cu130-x86_64-linux/sm89_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..714d471f7d780fb53cb96f08dd5cc27f20581c8c --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/sm89_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, 
qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn diff --git a/build/torch210-cxx11-cu130-x86_64-linux/sm90_compile.py b/build/torch210-cxx11-cu130-x86_64-linux/sm90_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..05898751d86f648aa7218bda29288421bb0308c3 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/sm90_compile.py @@ -0,0 +1,36 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 diff --git a/build/torch211-cxx11-cu128-aarch64-linux/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..959423da909d8e23a8bef3f0a18ef50484dd30bb --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py @@ -0,0 +1,17 @@ +from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8 +from .core import sageattn + +try: + from .sm100_compile import sageattn3_blackwell + SM100_ENABLED = True +except Exception: + SM100_ENABLED = False + +__all__ = [ + "per_block_int8", + "per_warp_int8", + "sub_mean", + "per_channel_fp8", + "sageattn", + "sageattn3_blackwell", +] \ No newline at end of file diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_ops.py b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0c899b0f46abb46ddf06458b980c7a7ba69539d3 --- 
/dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _sage_attention_cuda_4597889 +ops = torch.ops._sage_attention_cuda_4597889 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_sage_attention_cuda_4597889::{op_name}" diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4597889.abi3.so b/build/torch211-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4597889.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..13827d2bccbcc42774623d734f1c6bf2732bbfd5 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4597889.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:736ae53e476341b23da7a539d758dd21ec3b52e78117d8ac6bc371f4831703d6 +size 33326264 diff --git a/build/torch211-cxx11-cu128-aarch64-linux/core.py b/build/torch211-cxx11-cu128-aarch64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..13eff23d691beee1788df58e91c8470b02fc2b6b --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/core.py @@ -0,0 +1,1013 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import warnings + +from ._ops import ops + + +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant import per_channel_fp8 +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + +try: + from .sm80_compile import ( + qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn, + qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf, + ) + SM80_ENABLED = True +except Exception as e: + SM80_ENABLED = False + warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}") + +try: + from .sm89_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + ) + SM89_ENABLED = True +except Exception as e: + SM89_ENABLED = False + warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}") + +try: + from .sm90_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90, + ) + SM90_ENABLED = True +except Exception as e: + SM90_ENABLED = False + warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}") + +from typing import Any, List, Literal, Optional, Tuple, Union + +import subprocess 
+import re + + +def get_cuda_version(): + try: + output = subprocess.check_output(["nvcc", "--version"]).decode() + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + arch = get_cuda_arch_versions()[q.device.index] + if arch == "sm80": + if not SM80_ENABLED: + raise RuntimeError( + "SM80 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM80 (Ampere)." + ) + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "sm89": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM89 (Ada Lovelace)." 
+ ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) + elif arch == "sm90": + if not SM90_ENABLED: + raise RuntimeError( + "SM90 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM90 (Hopper)." + ) + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp32", + ) + elif arch == "sm120": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled." + ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + qk_quant_gran="per_warp", + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. + else: + raise ValueError(f"Unsupported CUDA architecture: {arch}") + +def sageattn_qk_int8_pv_fp16_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32". + - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b). + - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead. + - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32". 
+ + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." 
+ ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. 
Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # cuda_major_version, cuda_minor_version = get_cuda_version() + # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16': + # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'") + # pv_accum_dtype = 'fp32+fp32' + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. 
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64 + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v + ) + if pv_accum_dtype == "fp32": + if smooth_v: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + else: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp32": + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + 
q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp16": + lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + o = o[..., :head_dim_og] + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. 
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + if pv_accum_dtype == "fp32": + raise NotImplementedError("Please use 
pv_accum_dtype='fp32+fp32' for sm90.") + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/build/torch211-cxx11-cu128-aarch64-linux/metadata.json b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6530d3a7dee7b03bc5bb1f69c33aa615bd9bfef --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json @@ -0,0 +1,14 @@ +{ + "version": 2, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0a", + "8.0", + "8.9", + "9.0a" + ] + } +} diff --git a/build/torch211-cxx11-cu128-aarch64-linux/quant.py b/build/torch211-cxx11-cu128-aarch64-linux/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7a32c6a1e502bee12d0e0564ff2b90f6b00462 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/quant.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +from ._ops import ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. 
Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
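+
+    Example
+    -------
+    A minimal usage sketch (hypothetical shapes; assumes fp16 inputs already on a
+    CUDA device and that the bound quantization ops are available)::
+
+        q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
+        k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
+        km = k.mean(dim=-2, keepdim=True)  # optional key mean used for smoothing
+        q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k, km, tensor_layout="HND")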
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. 
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout + ) + return v_fp8, v_scale, None diff --git a/build/torch211-cxx11-cu128-aarch64-linux/quant_per_thread.py b/build/torch211-cxx11-cu128-aarch64-linux/quant_per_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f57c3e6dd9df89946ba54c4e2c3844c94d34 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/quant_per_thread.py @@ -0,0 +1,204 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + # offs_k = tl.arange(0, C) + + # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + # x = x.to(tl.float32) + # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + # x_int8 = x / scale + # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + # x_int8 = x_int8.to(tl.int8) + # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + # tl.store(scale_ptrs, scale) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. 
+ 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. 
+ 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + C=head_dim, BLK=WARPQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + C=head_dim, BLK=WARPK + ) + + return q_int8, q_scale, k_int8, k_scale \ No newline at end of file diff --git a/build/torch211-cxx11-cu128-aarch64-linux/sage_attention/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/sage_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/sage_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu128-aarch64-linux/sm100_compile.py b/build/torch211-cxx11-cu128-aarch64-linux/sm100_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4aa99649cca6673a91467ee522257d645d9e72 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/sm100_compile.py @@ -0,0 +1,327 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import List, Optional, Tuple + +from ._ops import ops, add_op_namespace_prefix +from torch.nn.functional import scaled_dot_product_attention as sdpa + + +# --------------------------------------------------------------------------- +# Low-level ops with torch.compile support (custom_op + register_fake) +# --------------------------------------------------------------------------- + +@torch.library.custom_op( + add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda" +) +def mha_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + return ops.mha_fwd( + q, k, v, sfq, sfk, sfv, delta_s, + unpadded_k, out, softmax_scale, is_causal, + per_block_mean, is_bf16, + ) + + +@torch.library.register_fake(add_op_namespace_prefix("mha_fwd")) +def mha_fwd_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + batch_size = q.size(0) + num_heads = q.size(1) + seqlen_q = q.size(2) + head_size_packed = q.size(3) + unpacked_head_size = head_size_packed * 2 + dtype = torch.bfloat16 if is_bf16 else torch.float16 + fake_out = torch.empty( + (batch_size, num_heads, seqlen_q, unpacked_head_size), + dtype=dtype, device=q.device, + ) + fake_lse = torch.empty( + (batch_size, num_heads, seqlen_q), + dtype=torch.float32, device=q.device, + ) + return [fake_out, fake_lse] + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant"), + mutates_args=("output", "output_sf"), + 
device_types="cuda", +) +def scaled_fp4_quant( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant")) +def scaled_fp4_quant_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_permute"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_permute( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute")) +def scaled_fp4_quant_permute_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_trans"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_trans( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans")) +def scaled_fp4_quant_trans_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +# --------------------------------------------------------------------------- +# Triton kernel for grouped mean subtraction +# --------------------------------------------------------------------------- + +@triton.jit +def _group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = ( + pid_b * stride_qb + + pid_h * stride_qh + + offsets[:, None] * stride_ql + + tl.arange(0, D)[None, :] * stride_qd + ) + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = ( + pid_b * stride_qmb + + pid_h * stride_qmh + + pid_group * stride_qml + + tl.arange(0, D) * stride_qmd + ) + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = 128 + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + _group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE, + ) + return q_out, qm + + +# --------------------------------------------------------------------------- +# High-level Python API (ported from sageattn3/api.py) +# --------------------------------------------------------------------------- + +def preprocess_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + per_block_mean: bool = True, +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + def pad_128(x): + L = x.size(2) + pad_len = (128 - L % 128) % 128 + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + k = k - k.mean(dim=-2, keepdim=True) + q, k, v = map(pad_128, [q, k, v]) + if per_block_mean: + q, qm = triton_group_mean(q) + else: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + return q, k, v, delta_s + + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_transpose( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def blockscaled_fp4_attn( + qlist: Tuple[torch.Tensor, torch.Tensor], + klist: Tuple[torch.Tensor, torch.Tensor], + vlist: Tuple[torch.Tensor, torch.Tensor], + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, +) -> List[torch.Tensor]: + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return mha_fwd( + qlist[0], klist[0], vlist[0], + qlist[1], klist[1], vlist[1], + delta_s, KL, None, + softmax_scale, is_causal, per_block_mean, is_bf16, + ) + + +def sageattn3_blackwell( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + per_block_mean: bool = True, + **kwargs, +) -> torch.Tensor: + if q.size(-1) >= 256: + return sdpa(q, k, v, is_causal=is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q) + klist = scale_and_quant_fp4_permute(k) + vlist = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist, klist, vlist, delta_s, + KL, is_causal, per_block_mean, is_bf16, + )[0][:, :, :QL, :].contiguous() + return o_fp4 diff --git a/build/torch211-cxx11-cu128-aarch64-linux/sm80_compile.py b/build/torch211-cxx11-cu128-aarch64-linux/sm80_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..99e1c8d87c278d2935922309bb58eedd2369bfbf --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/sm80_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = 
query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn")) +def qk_int8_sv_f16_accum_f16_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")) +def qk_int8_sv_f16_accum_f32_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf")) +def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")) +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn +qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn +qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf +qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn diff --git a/build/torch211-cxx11-cu128-aarch64-linux/sm89_compile.py b/build/torch211-cxx11-cu128-aarch64-linux/sm89_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..714d471f7d780fb53cb96f08dd5cc27f20581c8c --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/sm89_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, 
qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn diff --git a/build/torch211-cxx11-cu128-aarch64-linux/sm90_compile.py b/build/torch211-cxx11-cu128-aarch64-linux/sm90_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..05898751d86f648aa7218bda29288421bb0308c3 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/sm90_compile.py @@ -0,0 +1,36 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..959423da909d8e23a8bef3f0a18ef50484dd30bb --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,17 @@ +from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8 +from .core import sageattn + +try: + from .sm100_compile import sageattn3_blackwell + SM100_ENABLED = True +except Exception: + SM100_ENABLED = False + +__all__ = [ + "per_block_int8", + "per_warp_int8", + "sub_mean", + "per_channel_fp8", + "sageattn", + "sageattn3_blackwell", +] \ No newline at end of file diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0c899b0f46abb46ddf06458b980c7a7ba69539d3 --- /dev/null 
+++ b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _sage_attention_cuda_4597889 +ops = torch.ops._sage_attention_cuda_4597889 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_sage_attention_cuda_4597889::{op_name}" diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4597889.abi3.so b/build/torch211-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4597889.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..53e09f03509d10e814129606333713d8e043c1e2 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4597889.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8802a4b35928241c511fc38d81f01f1dfc53d1efb894021f35f5443a1dcfdc36 +size 33420488 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/core.py b/build/torch211-cxx11-cu128-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..13eff23d691beee1788df58e91c8470b02fc2b6b --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/core.py @@ -0,0 +1,1013 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import warnings + +from ._ops import ops + + +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant import per_channel_fp8 +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + +try: + from .sm80_compile import ( + qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn, + qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf, + ) + SM80_ENABLED = True +except Exception as e: + SM80_ENABLED = False + warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}") + +try: + from .sm89_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + ) + SM89_ENABLED = True +except Exception as e: + SM89_ENABLED = False + warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}") + +try: + from .sm90_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90, + ) + SM90_ENABLED = True +except Exception as e: + SM90_ENABLED = False + warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}") + +from typing import Any, List, Literal, Optional, Tuple, Union + +import subprocess +import re + + +def 
get_cuda_version(): + try: + output = subprocess.check_output(["nvcc", "--version"]).decode() + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + arch = get_cuda_arch_versions()[q.device.index] + if arch == "sm80": + if not SM80_ENABLED: + raise RuntimeError( + "SM80 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM80 (Ampere)." + ) + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "sm89": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM89 (Ada Lovelace)." 
+ ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) + elif arch == "sm90": + if not SM90_ENABLED: + raise RuntimeError( + "SM90 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM90 (Hopper)." + ) + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp32", + ) + elif arch == "sm120": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled." + ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + qk_quant_gran="per_warp", + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. + else: + raise ValueError(f"Unsupported CUDA architecture: {arch}") + +def sageattn_qk_int8_pv_fp16_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32". + - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b). + - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead. + - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32". 
+ + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." 
+ ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. 
Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # cuda_major_version, cuda_minor_version = get_cuda_version() + # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16': + # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'") + # pv_accum_dtype = 'fp32+fp32' + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. 
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64 + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v + ) + if pv_accum_dtype == "fp32": + if smooth_v: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + else: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp32": + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + 
q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp16": + lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + o = o[..., :head_dim_og] + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. 
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + if pv_accum_dtype == "fp32": + raise NotImplementedError("Please use 
pv_accum_dtype='fp32+fp32' for sm90.") + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/build/torch211-cxx11-cu128-x86_64-linux/metadata.json b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6530d3a7dee7b03bc5bb1f69c33aa615bd9bfef --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,14 @@ +{ + "version": 2, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0a", + "8.0", + "8.9", + "9.0a" + ] + } +} diff --git a/build/torch211-cxx11-cu128-x86_64-linux/quant.py b/build/torch211-cxx11-cu128-x86_64-linux/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7a32c6a1e502bee12d0e0564ff2b90f6b00462 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/quant.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +from ._ops import ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. 
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
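+
+    Example
+    -------
+    Illustrative usage (shapes are arbitrary, assuming the default "HND" layout
+    and half-precision tensors on a CUDA device):
+
+    >>> q = torch.randn(1, 8, 4096, 128, dtype=torch.float16, device="cuda")
+    >>> k = torch.randn(1, 8, 4096, 128, dtype=torch.float16, device="cuda")
+    >>> q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k, BLKQ=128, WARPQ=32, BLKK=64)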
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. 
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout + ) + return v_fp8, v_scale, None diff --git a/build/torch211-cxx11-cu128-x86_64-linux/quant_per_thread.py b/build/torch211-cxx11-cu128-x86_64-linux/quant_per_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f57c3e6dd9df89946ba54c4e2c3844c94d34 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/quant_per_thread.py @@ -0,0 +1,204 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + # offs_k = tl.arange(0, C) + + # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + # x = x.to(tl.float32) + # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + # x_int8 = x / scale + # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + # x_int8 = x_int8.to(tl.int8) + # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + # tl.store(scale_ptrs, scale) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. 
+ 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. 
+ 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + C=head_dim, BLK=WARPQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + C=head_dim, BLK=WARPK + ) + + return q_int8, q_scale, k_int8, k_scale \ No newline at end of file diff --git a/build/torch211-cxx11-cu128-x86_64-linux/sage_attention/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/sage_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/sage_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
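    # Illustration only (hypothetical value): a file_path such as
    # ".../torch211-cxx11-cu128-x86_64-linux/__init__.py" might hash to a hex
    # string like "7b9c0d2a4e1f3c55", and that string is what ends up as the
    # unique module name registered in sys.modules below.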
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/sm100_compile.py b/build/torch211-cxx11-cu128-x86_64-linux/sm100_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4aa99649cca6673a91467ee522257d645d9e72 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/sm100_compile.py @@ -0,0 +1,327 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import List, Optional, Tuple + +from ._ops import ops, add_op_namespace_prefix +from torch.nn.functional import scaled_dot_product_attention as sdpa + + +# --------------------------------------------------------------------------- +# Low-level ops with torch.compile support (custom_op + register_fake) +# --------------------------------------------------------------------------- + +@torch.library.custom_op( + add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda" +) +def mha_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + return ops.mha_fwd( + q, k, v, sfq, sfk, sfv, delta_s, + unpadded_k, out, softmax_scale, is_causal, + per_block_mean, is_bf16, + ) + + +@torch.library.register_fake(add_op_namespace_prefix("mha_fwd")) +def mha_fwd_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + batch_size = q.size(0) + num_heads = q.size(1) + seqlen_q = q.size(2) + head_size_packed = q.size(3) + unpacked_head_size = head_size_packed * 2 + dtype = torch.bfloat16 if is_bf16 else torch.float16 + fake_out = torch.empty( + (batch_size, num_heads, seqlen_q, unpacked_head_size), + dtype=dtype, device=q.device, + ) + fake_lse = torch.empty( + (batch_size, num_heads, seqlen_q), + dtype=torch.float32, device=q.device, + ) + return [fake_out, fake_lse] + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant"), + mutates_args=("output", "output_sf"), + 
device_types="cuda", +) +def scaled_fp4_quant( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant")) +def scaled_fp4_quant_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_permute"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_permute( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute")) +def scaled_fp4_quant_permute_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_trans"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_trans( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans")) +def scaled_fp4_quant_trans_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +# --------------------------------------------------------------------------- +# Triton kernel for grouped mean subtraction +# --------------------------------------------------------------------------- + +@triton.jit +def _group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = ( + pid_b * stride_qb + + pid_h * stride_qh + + offsets[:, None] * stride_ql + + tl.arange(0, D)[None, :] * stride_qd + ) + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = ( + pid_b * stride_qmb + + pid_h * stride_qmh + + pid_group * stride_qml + + tl.arange(0, D) * stride_qmd + ) + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = 128 + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + _group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE, + ) + return q_out, qm + + +# --------------------------------------------------------------------------- +# High-level Python API (ported from sageattn3/api.py) +# --------------------------------------------------------------------------- + +def preprocess_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + per_block_mean: bool = True, +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + def pad_128(x): + L = x.size(2) + pad_len = (128 - L % 128) % 128 + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + k = k - k.mean(dim=-2, keepdim=True) + q, k, v = map(pad_128, [q, k, v]) + if per_block_mean: + q, qm = triton_group_mean(q) + else: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + return q, k, v, delta_s + + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_transpose( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def blockscaled_fp4_attn( + qlist: Tuple[torch.Tensor, torch.Tensor], + klist: Tuple[torch.Tensor, torch.Tensor], + vlist: Tuple[torch.Tensor, torch.Tensor], + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, +) -> List[torch.Tensor]: + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return mha_fwd( + qlist[0], klist[0], vlist[0], + qlist[1], klist[1], vlist[1], + delta_s, KL, None, + softmax_scale, is_causal, per_block_mean, is_bf16, + ) + + +def sageattn3_blackwell( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + per_block_mean: bool = True, + **kwargs, +) -> torch.Tensor: + if q.size(-1) >= 256: + return sdpa(q, k, v, is_causal=is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q) + klist = scale_and_quant_fp4_permute(k) + vlist = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist, klist, vlist, delta_s, + KL, is_causal, per_block_mean, is_bf16, + )[0][:, :, :QL, :].contiguous() + return o_fp4 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/sm80_compile.py b/build/torch211-cxx11-cu128-x86_64-linux/sm80_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..99e1c8d87c278d2935922309bb58eedd2369bfbf --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/sm80_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = 
query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn")) +def qk_int8_sv_f16_accum_f16_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")) +def qk_int8_sv_f16_accum_f32_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf")) +def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")) +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn +qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn +qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf +qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn diff --git a/build/torch211-cxx11-cu128-x86_64-linux/sm89_compile.py b/build/torch211-cxx11-cu128-x86_64-linux/sm89_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..714d471f7d780fb53cb96f08dd5cc27f20581c8c --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/sm89_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, 
qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn diff --git a/build/torch211-cxx11-cu128-x86_64-linux/sm90_compile.py b/build/torch211-cxx11-cu128-x86_64-linux/sm90_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..05898751d86f648aa7218bda29288421bb0308c3 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/sm90_compile.py @@ -0,0 +1,36 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 diff --git a/build/torch211-cxx11-cu130-aarch64-linux/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..959423da909d8e23a8bef3f0a18ef50484dd30bb --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,17 @@ +from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8 +from .core import sageattn + +try: + from .sm100_compile import sageattn3_blackwell + SM100_ENABLED = True +except Exception: + SM100_ENABLED = False + +__all__ = [ + "per_block_int8", + "per_warp_int8", + "sub_mean", + "per_channel_fp8", + "sageattn", + "sageattn3_blackwell", +] \ No newline at end of file diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_ops.py b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0c899b0f46abb46ddf06458b980c7a7ba69539d3 --- 
/dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _sage_attention_cuda_4597889 +ops = torch.ops._sage_attention_cuda_4597889 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_sage_attention_cuda_4597889::{op_name}" diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4597889.abi3.so b/build/torch211-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4597889.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..472eab4dc2ec3ada3e911879d55d01d79ce4fc0f --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4597889.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:979c9a4ac436b2308df78e2d3c303a306a6c5c951c5991ed471059d30343c90d +size 33871360 diff --git a/build/torch211-cxx11-cu130-aarch64-linux/core.py b/build/torch211-cxx11-cu130-aarch64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..13eff23d691beee1788df58e91c8470b02fc2b6b --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/core.py @@ -0,0 +1,1013 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import warnings + +from ._ops import ops + + +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant import per_channel_fp8 +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + +try: + from .sm80_compile import ( + qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn, + qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf, + ) + SM80_ENABLED = True +except Exception as e: + SM80_ENABLED = False + warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}") + +try: + from .sm89_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + ) + SM89_ENABLED = True +except Exception as e: + SM89_ENABLED = False + warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}") + +try: + from .sm90_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90, + ) + SM90_ENABLED = True +except Exception as e: + SM90_ENABLED = False + warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}") + +from typing import Any, List, Literal, Optional, Tuple, Union + +import subprocess 
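# Note: `subprocess` and `re` are only needed by get_cuda_version() below, which
# shells out to `nvcc --version` and parses the "release X.Y" string from its output.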
+import re + + +def get_cuda_version(): + try: + output = subprocess.check_output(["nvcc", "--version"]).decode() + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + arch = get_cuda_arch_versions()[q.device.index] + if arch == "sm80": + if not SM80_ENABLED: + raise RuntimeError( + "SM80 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM80 (Ampere)." + ) + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "sm89": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM89 (Ada Lovelace)." 
+ ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) + elif arch == "sm90": + if not SM90_ENABLED: + raise RuntimeError( + "SM90 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM90 (Hopper)." + ) + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp32", + ) + elif arch == "sm120": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled." + ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + qk_quant_gran="per_warp", + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. + else: + raise ValueError(f"Unsupported CUDA architecture: {arch}") + +def sageattn_qk_int8_pv_fp16_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32". + - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b). + - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead. + - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32". 
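        As an illustrative (assumed) call, a speed-oriented configuration could use
        ``sageattn_qk_int8_pv_fp16_cuda(q, k, v, pv_accum_dtype="fp16", smooth_v=True)``,
        whereas the default ``pv_accum_dtype="fp32"`` favors accuracy.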
+ + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." 
+ ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. 
Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # cuda_major_version, cuda_minor_version = get_cuda_version() + # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16': + # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'") + # pv_accum_dtype = 'fp32+fp32' + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. 
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64 + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v + ) + if pv_accum_dtype == "fp32": + if smooth_v: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + else: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp32": + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + 
q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp16": + lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + o = o[..., :head_dim_og] + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype used to accumulate the product of the attention weights and the value tensor, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done entirely in FP32. However, due to a hardware limitation, only 22 bits of the FP32 accumulator are valid. + - "fp32+fp32": PV accumulation is done in FP32 (effectively FP22), but accumulated into an FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log-sum-exp of the attention scores. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + if pv_accum_dtype == "fp32": + raise NotImplementedError("Please use 
pv_accum_dtype='fp32+fp32' for sm90.") + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/build/torch211-cxx11-cu130-aarch64-linux/metadata.json b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6530d3a7dee7b03bc5bb1f69c33aa615bd9bfef --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,14 @@ +{ + "version": 2, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0a", + "8.0", + "8.9", + "9.0a" + ] + } +} diff --git a/build/torch211-cxx11-cu130-aarch64-linux/quant.py b/build/torch211-cxx11-cu130-aarch64-linux/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7a32c6a1e502bee12d0e0564ff2b90f6b00462 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/quant.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +from ._ops import ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. 
Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
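+ + Example + ------- + A minimal usage sketch (illustrative only, not from the upstream docs; it assumes a CUDA device and fp16 inputs with head_dim 128): + + >>> q = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda") + >>> k = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda") + >>> q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k, tensor_layout="HND")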
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. 
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout + ) + return v_fp8, v_scale, None diff --git a/build/torch211-cxx11-cu130-aarch64-linux/quant_per_thread.py b/build/torch211-cxx11-cu130-aarch64-linux/quant_per_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f57c3e6dd9df89946ba54c4e2c3844c94d34 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/quant_per_thread.py @@ -0,0 +1,204 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + # offs_k = tl.arange(0, C) + + # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + # x = x.to(tl.float32) + # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + # x_int8 = x / scale + # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + # x_int8 = x_int8.to(tl.int8) + # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + # tl.store(scale_ptrs, scale) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. 
+ 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. 
+ 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + C=head_dim, BLK=WARPQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + C=head_dim, BLK=WARPK + ) + + return q_int8, q_scale, k_int8, k_scale \ No newline at end of file diff --git a/build/torch211-cxx11-cu130-aarch64-linux/sage_attention/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/sage_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/sage_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
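+ # (hash() of a Path hashes its string form, which is salted per interpreter + # run via PYTHONHASHSEED, so the generated module name is unique per path + # but not stable across processes; it is only used as a sys.modules key.)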
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu130-aarch64-linux/sm100_compile.py b/build/torch211-cxx11-cu130-aarch64-linux/sm100_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4aa99649cca6673a91467ee522257d645d9e72 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/sm100_compile.py @@ -0,0 +1,327 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import List, Optional, Tuple + +from ._ops import ops, add_op_namespace_prefix +from torch.nn.functional import scaled_dot_product_attention as sdpa + + +# --------------------------------------------------------------------------- +# Low-level ops with torch.compile support (custom_op + register_fake) +# --------------------------------------------------------------------------- + +@torch.library.custom_op( + add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda" +) +def mha_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + return ops.mha_fwd( + q, k, v, sfq, sfk, sfv, delta_s, + unpadded_k, out, softmax_scale, is_causal, + per_block_mean, is_bf16, + ) + + +@torch.library.register_fake(add_op_namespace_prefix("mha_fwd")) +def mha_fwd_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + batch_size = q.size(0) + num_heads = q.size(1) + seqlen_q = q.size(2) + head_size_packed = q.size(3) + unpacked_head_size = head_size_packed * 2 + dtype = torch.bfloat16 if is_bf16 else torch.float16 + fake_out = torch.empty( + (batch_size, num_heads, seqlen_q, unpacked_head_size), + dtype=dtype, device=q.device, + ) + fake_lse = torch.empty( + (batch_size, num_heads, seqlen_q), + dtype=torch.float32, device=q.device, + ) + return [fake_out, fake_lse] + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant"), + mutates_args=("output", "output_sf"), + 
device_types="cuda", +) +def scaled_fp4_quant( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant")) +def scaled_fp4_quant_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_permute"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_permute( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute")) +def scaled_fp4_quant_permute_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_trans"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_trans( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans")) +def scaled_fp4_quant_trans_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +# --------------------------------------------------------------------------- +# Triton kernel for grouped mean subtraction +# --------------------------------------------------------------------------- + +@triton.jit +def _group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = ( + pid_b * stride_qb + + pid_h * stride_qh + + offsets[:, None] * stride_ql + + tl.arange(0, D)[None, :] * stride_qd + ) + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = ( + pid_b * stride_qmb + + pid_h * stride_qmh + + pid_group * stride_qml + + tl.arange(0, D) * stride_qmd + ) + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = 128 + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + _group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE, + ) + return q_out, qm + + +# --------------------------------------------------------------------------- +# High-level Python API (ported from sageattn3/api.py) +# --------------------------------------------------------------------------- + +def preprocess_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + per_block_mean: bool = True, +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + def pad_128(x): + L = x.size(2) + pad_len = (128 - L % 128) % 128 + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + k = k - k.mean(dim=-2, keepdim=True) + q, k, v = map(pad_128, [q, k, v]) + if per_block_mean: + q, qm = triton_group_mean(q) + else: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + return q, k, v, delta_s + + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_transpose( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def blockscaled_fp4_attn( + qlist: Tuple[torch.Tensor, torch.Tensor], + klist: Tuple[torch.Tensor, torch.Tensor], + vlist: Tuple[torch.Tensor, torch.Tensor], + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, +) -> List[torch.Tensor]: + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return mha_fwd( + qlist[0], klist[0], vlist[0], + qlist[1], klist[1], vlist[1], + delta_s, KL, None, + softmax_scale, is_causal, per_block_mean, is_bf16, + ) + + +def sageattn3_blackwell( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + per_block_mean: bool = True, + **kwargs, +) -> torch.Tensor: + if q.size(-1) >= 256: + return sdpa(q, k, v, is_causal=is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q) + klist = scale_and_quant_fp4_permute(k) + vlist = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist, klist, vlist, delta_s, + KL, is_causal, per_block_mean, is_bf16, + )[0][:, :, :QL, :].contiguous() + return o_fp4 diff --git a/build/torch211-cxx11-cu130-aarch64-linux/sm80_compile.py b/build/torch211-cxx11-cu130-aarch64-linux/sm80_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..99e1c8d87c278d2935922309bb58eedd2369bfbf --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/sm80_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = 
query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn")) +def qk_int8_sv_f16_accum_f16_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")) +def qk_int8_sv_f16_accum_f32_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf")) +def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")) +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn +qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn +qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf +qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn diff --git a/build/torch211-cxx11-cu130-aarch64-linux/sm89_compile.py b/build/torch211-cxx11-cu130-aarch64-linux/sm89_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..714d471f7d780fb53cb96f08dd5cc27f20581c8c --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/sm89_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, 
qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn diff --git a/build/torch211-cxx11-cu130-aarch64-linux/sm90_compile.py b/build/torch211-cxx11-cu130-aarch64-linux/sm90_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..05898751d86f648aa7218bda29288421bb0308c3 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/sm90_compile.py @@ -0,0 +1,36 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..959423da909d8e23a8bef3f0a18ef50484dd30bb --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,17 @@ +from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8 +from .core import sageattn + +try: + from .sm100_compile import sageattn3_blackwell + SM100_ENABLED = True +except Exception: + SM100_ENABLED = False + +__all__ = [ + "per_block_int8", + "per_warp_int8", + "sub_mean", + "per_channel_fp8", + "sageattn", + "sageattn3_blackwell", +] \ No newline at end of file diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0c899b0f46abb46ddf06458b980c7a7ba69539d3 --- /dev/null 
+++ b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _sage_attention_cuda_4597889 +ops = torch.ops._sage_attention_cuda_4597889 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_sage_attention_cuda_4597889::{op_name}" diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4597889.abi3.so b/build/torch211-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4597889.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..f9f6098e9c483e937c2689b22f9f3c47f9e9a263 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4597889.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:38c640d80eaf40fe38d71e107b43ebf3943306d453dcb4c755941fcd176475d0 +size 34154352 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/core.py b/build/torch211-cxx11-cu130-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..13eff23d691beee1788df58e91c8470b02fc2b6b --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/core.py @@ -0,0 +1,1013 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import warnings + +from ._ops import ops + + +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant import per_channel_fp8 +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + +try: + from .sm80_compile import ( + qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn, + qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf, + ) + SM80_ENABLED = True +except Exception as e: + SM80_ENABLED = False + warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}") + +try: + from .sm89_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + ) + SM89_ENABLED = True +except Exception as e: + SM89_ENABLED = False + warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}") + +try: + from .sm90_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90, + ) + SM90_ENABLED = True +except Exception as e: + SM90_ENABLED = False + warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}") + +from typing import Any, List, Literal, Optional, Tuple, Union + +import subprocess +import re + + +def 
get_cuda_version(): + try: + output = subprocess.check_output(["nvcc", "--version"]).decode() + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + arch = get_cuda_arch_versions()[q.device.index] + if arch == "sm80": + if not SM80_ENABLED: + raise RuntimeError( + "SM80 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM80 (Ampere)." + ) + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "sm89": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM89 (Ada Lovelace)." 
+ ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) + elif arch == "sm90": + if not SM90_ENABLED: + raise RuntimeError( + "SM90 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM90 (Hopper)." + ) + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp32", + ) + elif arch == "sm120": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled." + ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + qk_quant_gran="per_warp", + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) # sm120 has an accurate fp32 accumulator for fp8 MMA, and the Triton kernel is currently not usable on sm120. + else: + raise ValueError(f"Unsupported CUDA architecture: {arch}") + +def sageattn_qk_int8_pv_fp16_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype used to accumulate the product of the attention weights and the value tensor, either "fp16", "fp16+fp32" or "fp32". + - "fp16": PV accumulation is done entirely in FP16. This is the fastest option but may lead to numerical instability. The `smooth_v` option increases accuracy when the value tensor has a large bias (like in CogVideoX-2b). + - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead. + - "fp16+fp32": PV accumulation is done in FP16, but accumulated into an FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32".
+ + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." 
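+    # Only the last (head_dim) stride needs to be 1; callers holding sliced or
+    # permuted views can satisfy this by calling .contiguous() on q, k and v first.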
+ ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. 
Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # cuda_major_version, cuda_minor_version = get_cuda_version() + # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16': + # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'") + # pv_accum_dtype = 'fp32+fp32' + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. 
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64 + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v + ) + if pv_accum_dtype == "fp32": + if smooth_v: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + else: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp32": + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + 
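+            # `o` is the pre-allocated output buffer, written in place by the CUDA
+            # kernel; the call returns the logsumexp tensor, which is only meaningful
+            # when return_lse is requested.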
q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp16": + lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + o = o[..., :head_dim_og] + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. 
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + if pv_accum_dtype == "fp32": + raise NotImplementedError("Please use 
pv_accum_dtype='fp32+fp32' for sm90.") + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/build/torch211-cxx11-cu130-x86_64-linux/metadata.json b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6530d3a7dee7b03bc5bb1f69c33aa615bd9bfef --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,14 @@ +{ + "version": 2, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0a", + "8.0", + "8.9", + "9.0a" + ] + } +} diff --git a/build/torch211-cxx11-cu130-x86_64-linux/quant.py b/build/torch211-cxx11-cu130-x86_64-linux/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7a32c6a1e502bee12d0e0564ff2b90f6b00462 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/quant.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +from ._ops import ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. 
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
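+
+    Example
+    -------
+    A minimal sketch (shapes are illustrative; any fp16/bf16 CUDA tensors work)::
+
+        q = torch.randn(2, 8, 1024, 64, dtype=torch.float16, device="cuda")
+        k = torch.randn(2, 8, 1024, 64, dtype=torch.float16, device="cuda")
+        q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, tensor_layout="HND")
+        # q_int8 / k_int8 keep the input shapes; with BLKQ=128 and BLKK=64 the scale
+        # tensors have shapes [2, 8, 8] and [2, 8, 16] respectively.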
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. 
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout + ) + return v_fp8, v_scale, None diff --git a/build/torch211-cxx11-cu130-x86_64-linux/quant_per_thread.py b/build/torch211-cxx11-cu130-x86_64-linux/quant_per_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f57c3e6dd9df89946ba54c4e2c3844c94d34 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/quant_per_thread.py @@ -0,0 +1,204 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + # offs_k = tl.arange(0, C) + + # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + # x = x.to(tl.float32) + # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + # x_int8 = x / scale + # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + # x_int8 = x_int8.to(tl.int8) + # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + # tl.store(scale_ptrs, scale) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. 
+ 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. 
+ 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + C=head_dim, BLK=WARPQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + C=head_dim, BLK=WARPK + ) + + return q_int8, q_scale, k_int8, k_scale \ No newline at end of file diff --git a/build/torch211-cxx11-cu130-x86_64-linux/sage_attention/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/sage_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/sage_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/sm100_compile.py b/build/torch211-cxx11-cu130-x86_64-linux/sm100_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4aa99649cca6673a91467ee522257d645d9e72 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/sm100_compile.py @@ -0,0 +1,327 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import List, Optional, Tuple + +from ._ops import ops, add_op_namespace_prefix +from torch.nn.functional import scaled_dot_product_attention as sdpa + + +# --------------------------------------------------------------------------- +# Low-level ops with torch.compile support (custom_op + register_fake) +# --------------------------------------------------------------------------- + +@torch.library.custom_op( + add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda" +) +def mha_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + return ops.mha_fwd( + q, k, v, sfq, sfk, sfv, delta_s, + unpadded_k, out, softmax_scale, is_causal, + per_block_mean, is_bf16, + ) + + +@torch.library.register_fake(add_op_namespace_prefix("mha_fwd")) +def mha_fwd_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + batch_size = q.size(0) + num_heads = q.size(1) + seqlen_q = q.size(2) + head_size_packed = q.size(3) + unpacked_head_size = head_size_packed * 2 + dtype = torch.bfloat16 if is_bf16 else torch.float16 + fake_out = torch.empty( + (batch_size, num_heads, seqlen_q, unpacked_head_size), + dtype=dtype, device=q.device, + ) + fake_lse = torch.empty( + (batch_size, num_heads, seqlen_q), + dtype=torch.float32, device=q.device, + ) + return [fake_out, fake_lse] + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant"), + mutates_args=("output", "output_sf"), + 
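+    # `output` and `output_sf` are filled in place (the op returns None), so they are
+    # declared as mutated args; this keeps fake-tensor tracing under torch.compile correct.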
device_types="cuda", +) +def scaled_fp4_quant( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant")) +def scaled_fp4_quant_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_permute"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_permute( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute")) +def scaled_fp4_quant_permute_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_trans"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_trans( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans")) +def scaled_fp4_quant_trans_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +# --------------------------------------------------------------------------- +# Triton kernel for grouped mean subtraction +# --------------------------------------------------------------------------- + +@triton.jit +def _group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = ( + pid_b * stride_qb + + pid_h * stride_qh + + offsets[:, None] * stride_ql + + tl.arange(0, D)[None, :] * stride_qd + ) + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = ( + pid_b * stride_qmb + + pid_h * stride_qmh + + pid_group * stride_qml + + tl.arange(0, D) * stride_qmd + ) + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = 128 + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + _group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE, + ) + return q_out, qm + + +# --------------------------------------------------------------------------- +# High-level Python API (ported from sageattn3/api.py) +# --------------------------------------------------------------------------- + +def preprocess_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + per_block_mean: bool = True, +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + def pad_128(x): + L = x.size(2) + pad_len = (128 - L % 128) % 128 + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + k = k - k.mean(dim=-2, keepdim=True) + q, k, v = map(pad_128, [q, k, v]) + if per_block_mean: + q, qm = triton_group_mean(q) + else: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + return q, k, v, delta_s + + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_transpose( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def blockscaled_fp4_attn( + qlist: Tuple[torch.Tensor, torch.Tensor], + klist: Tuple[torch.Tensor, torch.Tensor], + vlist: Tuple[torch.Tensor, torch.Tensor], + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, +) -> List[torch.Tensor]: + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return mha_fwd( + qlist[0], klist[0], vlist[0], + qlist[1], klist[1], vlist[1], + delta_s, KL, None, + softmax_scale, is_causal, per_block_mean, is_bf16, + ) + + +def sageattn3_blackwell( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + per_block_mean: bool = True, + **kwargs, +) -> torch.Tensor: + if q.size(-1) >= 256: + return sdpa(q, k, v, is_causal=is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q) + klist = scale_and_quant_fp4_permute(k) + vlist = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist, klist, vlist, delta_s, + KL, is_causal, per_block_mean, is_bf16, + )[0][:, :, :QL, :].contiguous() + return o_fp4 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/sm80_compile.py b/build/torch211-cxx11-cu130-x86_64-linux/sm80_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..99e1c8d87c278d2935922309bb58eedd2369bfbf --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/sm80_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = 
query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn")) +def qk_int8_sv_f16_accum_f16_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")) +def qk_int8_sv_f16_accum_f32_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf")) +def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")) +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn +qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn +qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf +qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn diff --git a/build/torch211-cxx11-cu130-x86_64-linux/sm89_compile.py b/build/torch211-cxx11-cu130-x86_64-linux/sm89_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..714d471f7d780fb53cb96f08dd5cc27f20581c8c --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/sm89_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, 
qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn diff --git a/build/torch211-cxx11-cu130-x86_64-linux/sm90_compile.py b/build/torch211-cxx11-cu130-x86_64-linux/sm90_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..05898751d86f648aa7218bda29288421bb0308c3 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/sm90_compile.py @@ -0,0 +1,36 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 diff --git a/build/torch29-cxx11-cu128-aarch64-linux/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..959423da909d8e23a8bef3f0a18ef50484dd30bb --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/__init__.py @@ -0,0 +1,17 @@ +from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8 +from .core import sageattn + +try: + from .sm100_compile import sageattn3_blackwell + SM100_ENABLED = True +except Exception: + SM100_ENABLED = False + +__all__ = [ + "per_block_int8", + "per_warp_int8", + "sub_mean", + "per_channel_fp8", + "sageattn", + "sageattn3_blackwell", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_ops.py b/build/torch29-cxx11-cu128-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..33bcf153a986fe6f54330cbb3e7cc9b01b880783 --- /dev/null +++ 
b/build/torch29-cxx11-cu128-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _sage_attention_cuda_4523ce2 +ops = torch.ops._sage_attention_cuda_4523ce2 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_sage_attention_cuda_4523ce2::{op_name}" diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4523ce2.abi3.so b/build/torch29-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4523ce2.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..948b8895dbebcce05ed580b89d00c6c35ea70951 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/_sage_attention_cuda_4523ce2.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2cefef9f2849779311190f1308d0f9aaf91c37bbba5848e149d55ddf8d4b75f0 +size 33328496 diff --git a/build/torch29-cxx11-cu128-aarch64-linux/core.py b/build/torch29-cxx11-cu128-aarch64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..13eff23d691beee1788df58e91c8470b02fc2b6b --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/core.py @@ -0,0 +1,1013 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import warnings + +from ._ops import ops + + +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant import per_channel_fp8 +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + +try: + from .sm80_compile import ( + qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn, + qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf, + ) + SM80_ENABLED = True +except Exception as e: + SM80_ENABLED = False + warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}") + +try: + from .sm89_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + ) + SM89_ENABLED = True +except Exception as e: + SM89_ENABLED = False + warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}") + +try: + from .sm90_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90, + ) + SM90_ENABLED = True +except Exception as e: + SM90_ENABLED = False + warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}") + +from typing import Any, List, Literal, Optional, Tuple, Union + +import subprocess +import re + + +def 
get_cuda_version(): + try: + output = subprocess.check_output(["nvcc", "--version"]).decode() + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + arch = get_cuda_arch_versions()[q.device.index] + if arch == "sm80": + if not SM80_ENABLED: + raise RuntimeError( + "SM80 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM80 (Ampere)." + ) + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "sm89": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM89 (Ada Lovelace)." 
+            )
+        return sageattn_qk_int8_pv_fp8_cuda(
+            q,
+            k,
+            v,
+            tensor_layout=tensor_layout,
+            is_causal=is_causal,
+            sm_scale=sm_scale,
+            return_lse=return_lse,
+            pv_accum_dtype="fp32+fp16",
+        )
+    elif arch == "sm90":
+        if not SM90_ENABLED:
+            raise RuntimeError(
+                "SM90 SageAttention kernels failed to load. "
+                "Ensure the kernel was compiled for SM90 (Hopper)."
+            )
+        return sageattn_qk_int8_pv_fp8_cuda_sm90(
+            q,
+            k,
+            v,
+            tensor_layout=tensor_layout,
+            is_causal=is_causal,
+            sm_scale=sm_scale,
+            return_lse=return_lse,
+            pv_accum_dtype="fp32+fp32",
+        )
+    elif arch == "sm120":
+        if not SM89_ENABLED:
+            raise RuntimeError(
+                "SM89 SageAttention kernels failed to load. "
+                "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled."
+            )
+        # sm120 has an accurate FP32 accumulator for FP8 MMA, and the Triton
+        # kernel is currently not usable on sm120.
+        return sageattn_qk_int8_pv_fp8_cuda(
+            q,
+            k,
+            v,
+            tensor_layout=tensor_layout,
+            is_causal=is_causal,
+            qk_quant_gran="per_warp",
+            sm_scale=sm_scale,
+            return_lse=return_lse,
+            pv_accum_dtype="fp32+fp16",
+        )
+    else:
+        raise ValueError(f"Unsupported CUDA architecture: {arch}")
+
+def sageattn_qk_int8_pv_fp16_cuda(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    tensor_layout: str = "HND",
+    is_causal: bool = False,
+    qk_quant_gran: str = "per_thread",
+    sm_scale: Optional[float] = None,
+    pv_accum_dtype: str = "fp32",
+    smooth_k: bool = True,
+    smooth_v: bool = False,
+    return_lse: bool = False,
+    **kwargs: Any,
+) -> torch.Tensor:
+    """
+    SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA.
+
+    Parameters
+    ----------
+    q : torch.Tensor
+        The query tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
+
+    k : torch.Tensor
+        The key tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+    v : torch.Tensor
+        The value tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+    tensor_layout : str
+        The tensor layout, either "HND" or "NHD".
+        Default: "HND".
+
+    is_causal : bool
+        Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len.
+        Default: False.
+
+    qk_quant_gran : str
+        The granularity of quantization for Q and K, either "per_warp" or "per_thread".
+        Default: "per_thread".
+
+    sm_scale : Optional[float]
+        The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``.
+
+    pv_accum_dtype : str
+        The dtype used to accumulate the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32".
+        - "fp16": PV accumulation is done fully in FP16. This is the fastest option but may lead to numerical instability. The `smooth_v` option improves accuracy when the value tensor has a large bias (as in CogVideoX-2b).
+        - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA-core overhead.
+        - "fp16+fp32": PV accumulation is done in FP16 but added to an FP32 buffer every few iterations. This offers a balance between speed and accuracy.
+        Default: "fp32".
+
+    smooth_k : bool
+        Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
+        Default: True.
+
+    smooth_v : bool
+        Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
+        smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32".
+        Default: False.
+
+    return_lse : bool
+        Whether to return the log-sum-exp of the attention scores. Used for cases like Ring Attention.
+        Default: False.
+
+    Returns
+    -------
+    torch.Tensor
+        The output tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
+
+    torch.Tensor
+        The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
+        Shape: ``[batch_size, num_qo_heads, qo_len]``.
+        Only returned if `return_lse` is True.
+
+    Note
+    ----
+    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
+    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
+    - All tensors must be on the same CUDA device.
+    - `smooth_k` introduces slight overhead but improves accuracy under most circumstances.
+    """
+
+    dtype = q.dtype
+    assert q.is_cuda, "Input tensors must be on cuda."
+    assert dtype in [torch.float16, torch.bfloat16], (
+        "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
+    )
+    assert qk_quant_gran in ["per_warp", "per_thread"], (
+        "qk_quant_gran must be either 'per_warp' or 'per_thread'."
+    )
+    assert q.device == k.device == v.device, "All tensors must be on the same device."
+    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
+
+    # FIXME(DefTruth): make SageAttention compatible with distributed
+    # environments, e.g. xDiT, which is launched via torchrun. Without this
+    # workaround, SageAttention runs into an illegal memory access error after
+    # the first inference step of multi-GPU distributed inference. This small
+    # workaround also makes SageAttention compatible with torch.compile in
+    # non-fullgraph compile mode.
+    torch.cuda.set_device(v.device)
+
+    _tensor_layout = 0 if tensor_layout == "NHD" else 1
+    _is_caual = 1 if is_causal else 0
+    _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
+    _return_lse = 1 if return_lse else 0
+
+    head_dim_og = q.size(-1)
+
+    if head_dim_og < 64:
+        q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
+        k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
+        v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
+    elif head_dim_og > 64 and head_dim_og < 128:
+        q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
+        k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
+        v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
+    elif head_dim_og > 128:
+        raise ValueError(f"Unsupported head_dim: {head_dim_og}")
+
+    # assert last dim is contiguous
+    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
+        "Last dim of qkv must be contiguous."
+ ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. 
Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # cuda_major_version, cuda_minor_version = get_cuda_version() + # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16': + # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'") + # pv_accum_dtype = 'fp32+fp32' + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. 
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64 + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v + ) + if pv_accum_dtype == "fp32": + if smooth_v: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + else: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp32": + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + 
q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp16": + lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + o = o[..., :head_dim_og] + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. 
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + if pv_accum_dtype == "fp32": + raise NotImplementedError("Please use 
pv_accum_dtype='fp32+fp32' for sm90.") + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/build/torch29-cxx11-cu128-aarch64-linux/metadata.json b/build/torch29-cxx11-cu128-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6530d3a7dee7b03bc5bb1f69c33aa615bd9bfef --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/metadata.json @@ -0,0 +1,14 @@ +{ + "version": 2, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0a", + "8.0", + "8.9", + "9.0a" + ] + } +} diff --git a/build/torch29-cxx11-cu128-aarch64-linux/quant.py b/build/torch29-cxx11-cu128-aarch64-linux/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7a32c6a1e502bee12d0e0564ff2b90f6b00462 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/quant.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +from ._ops import ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. 
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
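+
+    Example
+    -------
+    Illustrative only: this sketch assumes the default "HND" layout, a CUDA
+    device with the compiled quantization kernels available, and made-up
+    input shapes; it is not taken from the original docstring.
+
+    >>> q = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
+    >>> k = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
+    >>> q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k, BLKQ=128, WARPQ=32, BLKK=64)
+    >>> q_scale.shape  # (qo_len // BLKQ) * (BLKQ // WARPQ) = 8 * 4 scales per head
+    torch.Size([1, 8, 32])
+    >>> k_scale.shape  # one scale per 64-token key block
+    torch.Size([1, 8, 16])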
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. 
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout + ) + return v_fp8, v_scale, None diff --git a/build/torch29-cxx11-cu128-aarch64-linux/quant_per_thread.py b/build/torch29-cxx11-cu128-aarch64-linux/quant_per_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f57c3e6dd9df89946ba54c4e2c3844c94d34 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/quant_per_thread.py @@ -0,0 +1,204 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + # offs_k = tl.arange(0, C) + + # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + # x = x.to(tl.float32) + # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + # x_int8 = x / scale + # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + # x_int8 = x_int8.to(tl.int8) + # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + # tl.store(scale_ptrs, scale) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. 
+ 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. 
+ 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + C=head_dim, BLK=WARPQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + C=head_dim, BLK=WARPK + ) + + return q_int8, q_scale, k_int8, k_scale \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-aarch64-linux/sage_attention/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/sage_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/sage_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
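+    # `hash()` of a Path is salted per interpreter process (PYTHONHASHSEED),
+    # which is fine here: the name only needs to be unique within this
+    # process's `sys.modules`. `ctypes.c_size_t` reinterprets a possibly
+    # negative hash as an unsigned machine word, so the "{:x}" format below
+    # never produces a leading minus sign.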
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu128-aarch64-linux/sm100_compile.py b/build/torch29-cxx11-cu128-aarch64-linux/sm100_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4aa99649cca6673a91467ee522257d645d9e72 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/sm100_compile.py @@ -0,0 +1,327 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import List, Optional, Tuple + +from ._ops import ops, add_op_namespace_prefix +from torch.nn.functional import scaled_dot_product_attention as sdpa + + +# --------------------------------------------------------------------------- +# Low-level ops with torch.compile support (custom_op + register_fake) +# --------------------------------------------------------------------------- + +@torch.library.custom_op( + add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda" +) +def mha_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + return ops.mha_fwd( + q, k, v, sfq, sfk, sfv, delta_s, + unpadded_k, out, softmax_scale, is_causal, + per_block_mean, is_bf16, + ) + + +@torch.library.register_fake(add_op_namespace_prefix("mha_fwd")) +def mha_fwd_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + batch_size = q.size(0) + num_heads = q.size(1) + seqlen_q = q.size(2) + head_size_packed = q.size(3) + unpacked_head_size = head_size_packed * 2 + dtype = torch.bfloat16 if is_bf16 else torch.float16 + fake_out = torch.empty( + (batch_size, num_heads, seqlen_q, unpacked_head_size), + dtype=dtype, device=q.device, + ) + fake_lse = torch.empty( + (batch_size, num_heads, seqlen_q), + dtype=torch.float32, device=q.device, + ) + return [fake_out, fake_lse] + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant"), + mutates_args=("output", "output_sf"), + 
device_types="cuda", +) +def scaled_fp4_quant( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant")) +def scaled_fp4_quant_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_permute"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_permute( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute")) +def scaled_fp4_quant_permute_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_trans"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_trans( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans")) +def scaled_fp4_quant_trans_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +# --------------------------------------------------------------------------- +# Triton kernel for grouped mean subtraction +# --------------------------------------------------------------------------- + +@triton.jit +def _group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = ( + pid_b * stride_qb + + pid_h * stride_qh + + offsets[:, None] * stride_ql + + tl.arange(0, D)[None, :] * stride_qd + ) + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = ( + pid_b * stride_qmb + + pid_h * stride_qmh + + pid_group * stride_qml + + tl.arange(0, D) * stride_qmd + ) + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = 128 + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + _group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE, + ) + return q_out, qm + + +# --------------------------------------------------------------------------- +# High-level Python API (ported from sageattn3/api.py) +# --------------------------------------------------------------------------- + +def preprocess_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + per_block_mean: bool = True, +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + def pad_128(x): + L = x.size(2) + pad_len = (128 - L % 128) % 128 + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + k = k - k.mean(dim=-2, keepdim=True) + q, k, v = map(pad_128, [q, k, v]) + if per_block_mean: + q, qm = triton_group_mean(q) + else: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + return q, k, v, delta_s + + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_transpose( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def blockscaled_fp4_attn( + qlist: Tuple[torch.Tensor, torch.Tensor], + klist: Tuple[torch.Tensor, torch.Tensor], + vlist: Tuple[torch.Tensor, torch.Tensor], + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, +) -> List[torch.Tensor]: + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return mha_fwd( + qlist[0], klist[0], vlist[0], + qlist[1], klist[1], vlist[1], + delta_s, KL, None, + softmax_scale, is_causal, per_block_mean, is_bf16, + ) + + +def sageattn3_blackwell( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + per_block_mean: bool = True, + **kwargs, +) -> torch.Tensor: + if q.size(-1) >= 256: + return sdpa(q, k, v, is_causal=is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q) + klist = scale_and_quant_fp4_permute(k) + vlist = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist, klist, vlist, delta_s, + KL, is_causal, per_block_mean, is_bf16, + )[0][:, :, :QL, :].contiguous() + return o_fp4 diff --git a/build/torch29-cxx11-cu128-aarch64-linux/sm80_compile.py b/build/torch29-cxx11-cu128-aarch64-linux/sm80_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..99e1c8d87c278d2935922309bb58eedd2369bfbf --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/sm80_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = 
query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn")) +def qk_int8_sv_f16_accum_f16_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")) +def qk_int8_sv_f16_accum_f32_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf")) +def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")) +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn +qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn +qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf +qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn diff --git a/build/torch29-cxx11-cu128-aarch64-linux/sm89_compile.py b/build/torch29-cxx11-cu128-aarch64-linux/sm89_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..714d471f7d780fb53cb96f08dd5cc27f20581c8c --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/sm89_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, 
qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn diff --git a/build/torch29-cxx11-cu128-aarch64-linux/sm90_compile.py b/build/torch29-cxx11-cu128-aarch64-linux/sm90_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..05898751d86f648aa7218bda29288421bb0308c3 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/sm90_compile.py @@ -0,0 +1,36 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..959423da909d8e23a8bef3f0a18ef50484dd30bb --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,17 @@ +from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8 +from .core import sageattn + +try: + from .sm100_compile import sageattn3_blackwell + SM100_ENABLED = True +except Exception: + SM100_ENABLED = False + +__all__ = [ + "per_block_int8", + "per_warp_int8", + "sub_mean", + "per_channel_fp8", + "sageattn", + "sageattn3_blackwell", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..33bcf153a986fe6f54330cbb3e7cc9b01b880783 --- /dev/null +++ 
b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _sage_attention_cuda_4523ce2 +ops = torch.ops._sage_attention_cuda_4523ce2 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_sage_attention_cuda_4523ce2::{op_name}" diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4523ce2.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4523ce2.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..53a487ff7d853e7a08fce350d2fd2e285ca97e98 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_sage_attention_cuda_4523ce2.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f20ab4c1f2b6b244906554ed7e5d842c271ff55018a9e45d0e71648f5929a13d +size 33405504 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/core.py b/build/torch29-cxx11-cu128-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..13eff23d691beee1788df58e91c8470b02fc2b6b --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/core.py @@ -0,0 +1,1013 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import warnings + +from ._ops import ops + + +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant import per_channel_fp8 +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + +try: + from .sm80_compile import ( + qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn, + qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf, + ) + SM80_ENABLED = True +except Exception as e: + SM80_ENABLED = False + warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}") + +try: + from .sm89_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + ) + SM89_ENABLED = True +except Exception as e: + SM89_ENABLED = False + warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}") + +try: + from .sm90_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90, + ) + SM90_ENABLED = True +except Exception as e: + SM90_ENABLED = False + warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}") + +from typing import Any, List, Literal, Optional, Tuple, Union + +import subprocess +import re + + +def 
get_cuda_version(): + try: + output = subprocess.check_output(["nvcc", "--version"]).decode() + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + arch = get_cuda_arch_versions()[q.device.index] + if arch == "sm80": + if not SM80_ENABLED: + raise RuntimeError( + "SM80 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM80 (Ampere)." + ) + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "sm89": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM89 (Ada Lovelace)." 
+ ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) + elif arch == "sm90": + if not SM90_ENABLED: + raise RuntimeError( + "SM90 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM90 (Hopper)." + ) + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp32", + ) + elif arch == "sm120": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled." + ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + qk_quant_gran="per_warp", + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. + else: + raise ValueError(f"Unsupported CUDA architecture: {arch}") + +def sageattn_qk_int8_pv_fp16_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32". + - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b). + - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead. + - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32". 
+ + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." 
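+        # Note: only the head_dim stride has to be 1 here; if this assert
+        # fires, calling .contiguous() on q, k and v before the call is
+        # enough to satisfy it.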
+ ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. 
Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # cuda_major_version, cuda_minor_version = get_cuda_version() + # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16': + # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'") + # pv_accum_dtype = 'fp32+fp32' + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. 
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64 + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v + ) + if pv_accum_dtype == "fp32": + if smooth_v: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + else: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp32": + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + 
q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp16": + lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + o = o[..., :head_dim_og] + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. 
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + if pv_accum_dtype == "fp32": + raise NotImplementedError("Please use 
pv_accum_dtype='fp32+fp32' for sm90.") + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6530d3a7dee7b03bc5bb1f69c33aa615bd9bfef --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,14 @@ +{ + "version": 2, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0a", + "8.0", + "8.9", + "9.0a" + ] + } +} diff --git a/build/torch29-cxx11-cu128-x86_64-linux/quant.py b/build/torch29-cxx11-cu128-x86_64-linux/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7a32c6a1e502bee12d0e0564ff2b90f6b00462 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/quant.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +from ._ops import ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. 
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
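+
+    Examples
+    --------
+    A minimal illustrative call; the shapes below are assumptions for the
+    example only (batch 2, 8 heads, sequence length 1024, head_dim 128,
+    "HND" layout)::
+
+        >>> q = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
+        >>> k = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
+        >>> q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k)
+        >>> q_scale.shape   # (qo_len // BLKQ) * (BLKQ // WARPQ) = 8 * 4 scales per head
+        torch.Size([2, 8, 32])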
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. 
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout + ) + return v_fp8, v_scale, None diff --git a/build/torch29-cxx11-cu128-x86_64-linux/quant_per_thread.py b/build/torch29-cxx11-cu128-x86_64-linux/quant_per_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f57c3e6dd9df89946ba54c4e2c3844c94d34 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/quant_per_thread.py @@ -0,0 +1,204 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + # offs_k = tl.arange(0, C) + + # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + # x = x.to(tl.float32) + # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + # x_int8 = x / scale + # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + # x_int8 = x_int8.to(tl.int8) + # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + # tl.store(scale_ptrs, scale) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. 
+ 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. 
+ 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + C=head_dim, BLK=WARPQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + C=head_dim, BLK=WARPK + ) + + return q_int8, q_scale, k_int8, k_scale \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/sage_attention/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/sage_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/sage_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/sm100_compile.py b/build/torch29-cxx11-cu128-x86_64-linux/sm100_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4aa99649cca6673a91467ee522257d645d9e72 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/sm100_compile.py @@ -0,0 +1,327 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import List, Optional, Tuple + +from ._ops import ops, add_op_namespace_prefix +from torch.nn.functional import scaled_dot_product_attention as sdpa + + +# --------------------------------------------------------------------------- +# Low-level ops with torch.compile support (custom_op + register_fake) +# --------------------------------------------------------------------------- + +@torch.library.custom_op( + add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda" +) +def mha_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + return ops.mha_fwd( + q, k, v, sfq, sfk, sfv, delta_s, + unpadded_k, out, softmax_scale, is_causal, + per_block_mean, is_bf16, + ) + + +@torch.library.register_fake(add_op_namespace_prefix("mha_fwd")) +def mha_fwd_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + batch_size = q.size(0) + num_heads = q.size(1) + seqlen_q = q.size(2) + head_size_packed = q.size(3) + unpacked_head_size = head_size_packed * 2 + dtype = torch.bfloat16 if is_bf16 else torch.float16 + fake_out = torch.empty( + (batch_size, num_heads, seqlen_q, unpacked_head_size), + dtype=dtype, device=q.device, + ) + fake_lse = torch.empty( + (batch_size, num_heads, seqlen_q), + dtype=torch.float32, device=q.device, + ) + return [fake_out, fake_lse] + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant"), + mutates_args=("output", "output_sf"), + 
device_types="cuda", +) +def scaled_fp4_quant( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant")) +def scaled_fp4_quant_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_permute"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_permute( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute")) +def scaled_fp4_quant_permute_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_trans"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_trans( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans")) +def scaled_fp4_quant_trans_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +# --------------------------------------------------------------------------- +# Triton kernel for grouped mean subtraction +# --------------------------------------------------------------------------- + +@triton.jit +def _group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = ( + pid_b * stride_qb + + pid_h * stride_qh + + offsets[:, None] * stride_ql + + tl.arange(0, D)[None, :] * stride_qd + ) + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = ( + pid_b * stride_qmb + + pid_h * stride_qmh + + pid_group * stride_qml + + tl.arange(0, D) * stride_qmd + ) + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = 128 + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + _group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE, + ) + return q_out, qm + + +# --------------------------------------------------------------------------- +# High-level Python API (ported from sageattn3/api.py) +# --------------------------------------------------------------------------- + +def preprocess_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + per_block_mean: bool = True, +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + def pad_128(x): + L = x.size(2) + pad_len = (128 - L % 128) % 128 + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + k = k - k.mean(dim=-2, keepdim=True) + q, k, v = map(pad_128, [q, k, v]) + if per_block_mean: + q, qm = triton_group_mean(q) + else: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + return q, k, v, delta_s + + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_transpose( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def blockscaled_fp4_attn( + qlist: Tuple[torch.Tensor, torch.Tensor], + klist: Tuple[torch.Tensor, torch.Tensor], + vlist: Tuple[torch.Tensor, torch.Tensor], + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, +) -> List[torch.Tensor]: + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return mha_fwd( + qlist[0], klist[0], vlist[0], + qlist[1], klist[1], vlist[1], + delta_s, KL, None, + softmax_scale, is_causal, per_block_mean, is_bf16, + ) + + +def sageattn3_blackwell( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + per_block_mean: bool = True, + **kwargs, +) -> torch.Tensor: + if q.size(-1) >= 256: + return sdpa(q, k, v, is_causal=is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q) + klist = scale_and_quant_fp4_permute(k) + vlist = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist, klist, vlist, delta_s, + KL, is_causal, per_block_mean, is_bf16, + )[0][:, :, :QL, :].contiguous() + return o_fp4 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/sm80_compile.py b/build/torch29-cxx11-cu128-x86_64-linux/sm80_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..99e1c8d87c278d2935922309bb58eedd2369bfbf --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/sm80_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) 
+ if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn")) +def qk_int8_sv_f16_accum_f16_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")) +def qk_int8_sv_f16_accum_f32_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf")) +def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")) +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn +qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn +qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf +qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn diff --git a/build/torch29-cxx11-cu128-x86_64-linux/sm89_compile.py b/build/torch29-cxx11-cu128-x86_64-linux/sm89_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..714d471f7d780fb53cb96f08dd5cc27f20581c8c --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/sm89_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, 
sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn diff --git a/build/torch29-cxx11-cu128-x86_64-linux/sm90_compile.py b/build/torch29-cxx11-cu128-x86_64-linux/sm90_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..05898751d86f648aa7218bda29288421bb0308c3 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/sm90_compile.py @@ -0,0 +1,36 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 diff --git a/build/torch29-cxx11-cu129-aarch64-linux/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..959423da909d8e23a8bef3f0a18ef50484dd30bb --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/__init__.py @@ -0,0 +1,17 @@ +from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8 +from .core import sageattn + +try: + from .sm100_compile import sageattn3_blackwell + SM100_ENABLED = True +except Exception: + SM100_ENABLED = False + +__all__ = [ + "per_block_int8", + "per_warp_int8", + "sub_mean", + "per_channel_fp8", + "sageattn", + "sageattn3_blackwell", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_ops.py b/build/torch29-cxx11-cu129-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0c899b0f46abb46ddf06458b980c7a7ba69539d3 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _sage_attention_cuda_4597889 +ops = torch.ops._sage_attention_cuda_4597889 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_sage_attention_cuda_4597889::{op_name}" diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_sage_attention_cuda_4597889.abi3.so b/build/torch29-cxx11-cu129-aarch64-linux/_sage_attention_cuda_4597889.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..19851e7a60583e504f915a61aec78b751a98e11d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/_sage_attention_cuda_4597889.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f485093ba5a7d1ba2e7a2ce108e9bbe731019e6943defde110823972f71adb95 +size 33657728 diff --git a/build/torch29-cxx11-cu129-aarch64-linux/core.py b/build/torch29-cxx11-cu129-aarch64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..13eff23d691beee1788df58e91c8470b02fc2b6b --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/core.py @@ -0,0 +1,1013 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import warnings + +from ._ops import ops + + +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant import per_channel_fp8 +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + +try: + from .sm80_compile import ( + qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn, + qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf, + ) + SM80_ENABLED = True +except Exception as e: + SM80_ENABLED = False + warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}") + +try: + from .sm89_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + ) + SM89_ENABLED = True +except Exception as e: + SM89_ENABLED = False + warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}") + +try: + from .sm90_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90, + ) + SM90_ENABLED = True +except Exception as e: + SM90_ENABLED = False + warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}") + +from typing import Any, List, Literal, Optional, Tuple, Union + +import subprocess +import re + + +def 
get_cuda_version(): + try: + output = subprocess.check_output(["nvcc", "--version"]).decode() + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + arch = get_cuda_arch_versions()[q.device.index] + if arch == "sm80": + if not SM80_ENABLED: + raise RuntimeError( + "SM80 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM80 (Ampere)." + ) + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "sm89": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM89 (Ada Lovelace)." 
+ ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) + elif arch == "sm90": + if not SM90_ENABLED: + raise RuntimeError( + "SM90 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM90 (Hopper)." + ) + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp32", + ) + elif arch == "sm120": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled." + ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + qk_quant_gran="per_warp", + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. + else: + raise ValueError(f"Unsupported CUDA architecture: {arch}") + +def sageattn_qk_int8_pv_fp16_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32". + - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b). + - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead. + - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32". 
+ + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." 
+ ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. 
Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # cuda_major_version, cuda_minor_version = get_cuda_version() + # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16': + # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'") + # pv_accum_dtype = 'fp32+fp32' + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. 
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64 + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v + ) + if pv_accum_dtype == "fp32": + if smooth_v: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + else: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp32": + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + 
q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp16": + lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + o = o[..., :head_dim_og] + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. 
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + if pv_accum_dtype == "fp32": + raise NotImplementedError("Please use 
pv_accum_dtype='fp32+fp32' for sm90.") + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/build/torch29-cxx11-cu129-aarch64-linux/metadata.json b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6530d3a7dee7b03bc5bb1f69c33aa615bd9bfef --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json @@ -0,0 +1,14 @@ +{ + "version": 2, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0a", + "8.0", + "8.9", + "9.0a" + ] + } +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/quant.py b/build/torch29-cxx11-cu129-aarch64-linux/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7a32c6a1e502bee12d0e0564ff2b90f6b00462 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/quant.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +from ._ops import ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. 
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
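+
+    Example
+    -------
+    A minimal illustrative sketch (not part of the original source): it assumes a
+    CUDA device, fp16 inputs in the default "HND" layout, and that the compiled
+    quantization ops are available::
+
+        q = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
+        k = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
+        q_int8, q_scale, k_int8, k_scale = per_warp_int8(
+            q, k, BLKQ=128, WARPQ=32, BLKK=64, tensor_layout="HND"
+        )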
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. 
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout + ) + return v_fp8, v_scale, None diff --git a/build/torch29-cxx11-cu129-aarch64-linux/quant_per_thread.py b/build/torch29-cxx11-cu129-aarch64-linux/quant_per_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f57c3e6dd9df89946ba54c4e2c3844c94d34 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/quant_per_thread.py @@ -0,0 +1,204 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + # offs_k = tl.arange(0, C) + + # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + # x = x.to(tl.float32) + # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + # x_int8 = x / scale + # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + # x_int8 = x_int8.to(tl.int8) + # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + # tl.store(scale_ptrs, scale) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. 
+ 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. 
+ 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + C=head_dim, BLK=WARPQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + C=head_dim, BLK=WARPK + ) + + return q_int8, q_scale, k_int8, k_scale \ No newline at end of file diff --git a/build/torch29-cxx11-cu129-aarch64-linux/sage_attention/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/sage_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/sage_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
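+    # ctypes.c_size_t(...) maps the possibly negative result of hash() onto its
+    # unsigned (two's-complement) equivalent, so the "{:x}" formatting below
+    # always yields a clean hex token with no leading minus sign.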
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/sm100_compile.py b/build/torch29-cxx11-cu129-aarch64-linux/sm100_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4aa99649cca6673a91467ee522257d645d9e72 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/sm100_compile.py @@ -0,0 +1,327 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import List, Optional, Tuple + +from ._ops import ops, add_op_namespace_prefix +from torch.nn.functional import scaled_dot_product_attention as sdpa + + +# --------------------------------------------------------------------------- +# Low-level ops with torch.compile support (custom_op + register_fake) +# --------------------------------------------------------------------------- + +@torch.library.custom_op( + add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda" +) +def mha_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + return ops.mha_fwd( + q, k, v, sfq, sfk, sfv, delta_s, + unpadded_k, out, softmax_scale, is_causal, + per_block_mean, is_bf16, + ) + + +@torch.library.register_fake(add_op_namespace_prefix("mha_fwd")) +def mha_fwd_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + batch_size = q.size(0) + num_heads = q.size(1) + seqlen_q = q.size(2) + head_size_packed = q.size(3) + unpacked_head_size = head_size_packed * 2 + dtype = torch.bfloat16 if is_bf16 else torch.float16 + fake_out = torch.empty( + (batch_size, num_heads, seqlen_q, unpacked_head_size), + dtype=dtype, device=q.device, + ) + fake_lse = torch.empty( + (batch_size, num_heads, seqlen_q), + dtype=torch.float32, device=q.device, + ) + return [fake_out, fake_lse] + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant"), + mutates_args=("output", "output_sf"), + 
device_types="cuda", +) +def scaled_fp4_quant( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant")) +def scaled_fp4_quant_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_permute"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_permute( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute")) +def scaled_fp4_quant_permute_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_trans"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_trans( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans")) +def scaled_fp4_quant_trans_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +# --------------------------------------------------------------------------- +# Triton kernel for grouped mean subtraction +# --------------------------------------------------------------------------- + +@triton.jit +def _group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = ( + pid_b * stride_qb + + pid_h * stride_qh + + offsets[:, None] * stride_ql + + tl.arange(0, D)[None, :] * stride_qd + ) + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = ( + pid_b * stride_qmb + + pid_h * stride_qmh + + pid_group * stride_qml + + tl.arange(0, D) * stride_qmd + ) + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = 128 + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + _group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE, + ) + return q_out, qm + + +# --------------------------------------------------------------------------- +# High-level Python API (ported from sageattn3/api.py) +# --------------------------------------------------------------------------- + +def preprocess_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + per_block_mean: bool = True, +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + def pad_128(x): + L = x.size(2) + pad_len = (128 - L % 128) % 128 + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + k = k - k.mean(dim=-2, keepdim=True) + q, k, v = map(pad_128, [q, k, v]) + if per_block_mean: + q, qm = triton_group_mean(q) + else: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + return q, k, v, delta_s + + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_transpose( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def blockscaled_fp4_attn( + qlist: Tuple[torch.Tensor, torch.Tensor], + klist: Tuple[torch.Tensor, torch.Tensor], + vlist: Tuple[torch.Tensor, torch.Tensor], + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, +) -> List[torch.Tensor]: + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return mha_fwd( + qlist[0], klist[0], vlist[0], + qlist[1], klist[1], vlist[1], + delta_s, KL, None, + softmax_scale, is_causal, per_block_mean, is_bf16, + ) + + +def sageattn3_blackwell( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + per_block_mean: bool = True, + **kwargs, +) -> torch.Tensor: + if q.size(-1) >= 256: + return sdpa(q, k, v, is_causal=is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q) + klist = scale_and_quant_fp4_permute(k) + vlist = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist, klist, vlist, delta_s, + KL, is_causal, per_block_mean, is_bf16, + )[0][:, :, :QL, :].contiguous() + return o_fp4 diff --git a/build/torch29-cxx11-cu129-aarch64-linux/sm80_compile.py b/build/torch29-cxx11-cu129-aarch64-linux/sm80_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..99e1c8d87c278d2935922309bb58eedd2369bfbf --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/sm80_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = 
query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn")) +def qk_int8_sv_f16_accum_f16_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")) +def qk_int8_sv_f16_accum_f32_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf")) +def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")) +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn +qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn +qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf +qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn diff --git a/build/torch29-cxx11-cu129-aarch64-linux/sm89_compile.py b/build/torch29-cxx11-cu129-aarch64-linux/sm89_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..714d471f7d780fb53cb96f08dd5cc27f20581c8c --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/sm89_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, 
qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn diff --git a/build/torch29-cxx11-cu129-aarch64-linux/sm90_compile.py b/build/torch29-cxx11-cu129-aarch64-linux/sm90_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..05898751d86f648aa7218bda29288421bb0308c3 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/sm90_compile.py @@ -0,0 +1,36 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 diff --git a/build/torch29-cxx11-cu129-x86_64-linux/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..959423da909d8e23a8bef3f0a18ef50484dd30bb --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py @@ -0,0 +1,17 @@ +from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8 +from .core import sageattn + +try: + from .sm100_compile import sageattn3_blackwell + SM100_ENABLED = True +except Exception: + SM100_ENABLED = False + +__all__ = [ + "per_block_int8", + "per_warp_int8", + "sub_mean", + "per_channel_fp8", + "sageattn", + "sageattn3_blackwell", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_ops.py b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0c899b0f46abb46ddf06458b980c7a7ba69539d3 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _sage_attention_cuda_4597889 +ops = torch.ops._sage_attention_cuda_4597889 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_sage_attention_cuda_4597889::{op_name}" diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_sage_attention_cuda_4597889.abi3.so b/build/torch29-cxx11-cu129-x86_64-linux/_sage_attention_cuda_4597889.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..a41aca85b7327b0efa88c450de168769a6afe342 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/_sage_attention_cuda_4597889.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cfdc25d37d1d7e969fb8be305e4fa8d496d77592b97a7a4fa77e481bdbff2a8c +size 33746328 diff --git a/build/torch29-cxx11-cu129-x86_64-linux/core.py b/build/torch29-cxx11-cu129-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..13eff23d691beee1788df58e91c8470b02fc2b6b --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/core.py @@ -0,0 +1,1013 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import warnings + +from ._ops import ops + + +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant import per_channel_fp8 +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + +try: + from .sm80_compile import ( + qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn, + qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf, + ) + SM80_ENABLED = True +except Exception as e: + SM80_ENABLED = False + warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}") + +try: + from .sm89_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + ) + SM89_ENABLED = True +except Exception as e: + SM89_ENABLED = False + warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}") + +try: + from .sm90_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90, + ) + SM90_ENABLED = True +except Exception as e: + SM90_ENABLED = False + warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}") + +from typing import Any, List, Literal, Optional, Tuple, Union + +import subprocess +import re + + +def 
get_cuda_version(): + try: + output = subprocess.check_output(["nvcc", "--version"]).decode() + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + arch = get_cuda_arch_versions()[q.device.index] + if arch == "sm80": + if not SM80_ENABLED: + raise RuntimeError( + "SM80 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM80 (Ampere)." + ) + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "sm89": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM89 (Ada Lovelace)." 
+ ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) + elif arch == "sm90": + if not SM90_ENABLED: + raise RuntimeError( + "SM90 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM90 (Hopper)." + ) + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp32", + ) + elif arch == "sm120": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled." + ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + qk_quant_gran="per_warp", + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. + else: + raise ValueError(f"Unsupported CUDA architecture: {arch}") + +def sageattn_qk_int8_pv_fp16_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32". + - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b). + - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead. + - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32". 
+ + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." 
+ ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. 
Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+    v : torch.Tensor
+        The value tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+    tensor_layout : str
+        The tensor layout, either "HND" or "NHD".
+        Default: "HND".
+
+    is_causal : bool
+        Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
+        Default: False.
+
+    qk_quant_gran : str
+        The granularity of quantization for Q and K, either "per_warp" or "per_thread".
+        Default: "per_thread".
+
+    sm_scale : Optional[float]
+        The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
+
+    pv_accum_dtype : str
+        The dtype of the accumulation of the product of the value tensor and the attention weights, one of "fp32", "fp32+fp32" or "fp32+fp16".
+        - "fp32": PV accumulation is done entirely in FP32. However, due to a hardware limitation, there are only 22 valid bits in the FP32 accumulator.
+        - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to an FP32 buffer every few iterations. This offers a balance between speed and accuracy.
+        - "fp32+fp16": PV accumulation is done in FP32 (actually FP22), but added to an FP16 buffer every few iterations, trading a small amount of accuracy for additional speed.
+        Default: "fp32+fp16".
+
+    smooth_k : bool
+        Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
+        Default: True.
+
+    smooth_v : bool
+        Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
+        smooth_v will be ignored if pv_accum_dtype is "fp32+fp32" or "fp32+fp16".
+        Default: False.
+
+    return_lse : bool
+        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
+        Default: False.
+
+    Returns
+    -------
+    torch.Tensor
+        The output tensor. Shape:
+        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
+        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
+
+    torch.Tensor
+        The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
+        Shape: ``[batch_size, num_qo_heads, qo_len]``.
+        Only returned if `return_lse` is True.
+
+    Note
+    ----
+    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
+    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
+    - All tensors must be on the same cuda device.
+    - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
+    """
+
+    dtype = q.dtype
+    assert q.is_cuda, "Input tensors must be on cuda."
+    assert dtype in [torch.float16, torch.bfloat16], (
+        "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
+    )
+    assert qk_quant_gran in ["per_warp", "per_thread"], (
+        "qk_quant_gran must be either 'per_warp' or 'per_thread'."
+    )
+    assert q.device == k.device == v.device, "All tensors must be on the same device."
+    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
+
+    # cuda_major_version, cuda_minor_version = get_cuda_version()
+    # if (cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16':
+    #     warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'")
+    #     pv_accum_dtype = 'fp32+fp32'
+
+    # FIXME(DefTruth): make sage attention work compatible with distributed
+    # env, for example, xDiT which is launched by torchrun.
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64 + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v + ) + if pv_accum_dtype == "fp32": + if smooth_v: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + else: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp32": + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + 
q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp16": + lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + o = o[..., :head_dim_og] + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. 
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + if pv_accum_dtype == "fp32": + raise NotImplementedError("Please use 
pv_accum_dtype='fp32+fp32' for sm90.") + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/build/torch29-cxx11-cu129-x86_64-linux/metadata.json b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6530d3a7dee7b03bc5bb1f69c33aa615bd9bfef --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json @@ -0,0 +1,14 @@ +{ + "version": 2, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0a", + "8.0", + "8.9", + "9.0a" + ] + } +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/quant.py b/build/torch29-cxx11-cu129-x86_64-linux/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7a32c6a1e502bee12d0e0564ff2b90f6b00462 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/quant.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +from ._ops import ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. 
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. 
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout + ) + return v_fp8, v_scale, None diff --git a/build/torch29-cxx11-cu129-x86_64-linux/quant_per_thread.py b/build/torch29-cxx11-cu129-x86_64-linux/quant_per_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f57c3e6dd9df89946ba54c4e2c3844c94d34 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/quant_per_thread.py @@ -0,0 +1,204 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + # offs_k = tl.arange(0, C) + + # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + # x = x.to(tl.float32) + # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + # x_int8 = x / scale + # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + # x_int8 = x_int8.to(tl.int8) + # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + # tl.store(scale_ptrs, scale) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. 
+ 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. 
+ 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + C=head_dim, BLK=WARPQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + C=head_dim, BLK=WARPK + ) + + return q_int8, q_scale, k_int8, k_scale \ No newline at end of file diff --git a/build/torch29-cxx11-cu129-x86_64-linux/sage_attention/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/sage_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/sage_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
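+    # Note: hash() of a Path goes through Python's (per-process randomized)
+    # string hashing, so the generated module name is not stable across runs;
+    # that is fine here, since the name only needs to be unique within this
+    # interpreter's sys.modules.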
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/sm100_compile.py b/build/torch29-cxx11-cu129-x86_64-linux/sm100_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4aa99649cca6673a91467ee522257d645d9e72 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/sm100_compile.py @@ -0,0 +1,327 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import List, Optional, Tuple + +from ._ops import ops, add_op_namespace_prefix +from torch.nn.functional import scaled_dot_product_attention as sdpa + + +# --------------------------------------------------------------------------- +# Low-level ops with torch.compile support (custom_op + register_fake) +# --------------------------------------------------------------------------- + +@torch.library.custom_op( + add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda" +) +def mha_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + return ops.mha_fwd( + q, k, v, sfq, sfk, sfv, delta_s, + unpadded_k, out, softmax_scale, is_causal, + per_block_mean, is_bf16, + ) + + +@torch.library.register_fake(add_op_namespace_prefix("mha_fwd")) +def mha_fwd_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + batch_size = q.size(0) + num_heads = q.size(1) + seqlen_q = q.size(2) + head_size_packed = q.size(3) + unpacked_head_size = head_size_packed * 2 + dtype = torch.bfloat16 if is_bf16 else torch.float16 + fake_out = torch.empty( + (batch_size, num_heads, seqlen_q, unpacked_head_size), + dtype=dtype, device=q.device, + ) + fake_lse = torch.empty( + (batch_size, num_heads, seqlen_q), + dtype=torch.float32, device=q.device, + ) + return [fake_out, fake_lse] + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant"), + mutates_args=("output", "output_sf"), + 
device_types="cuda", +) +def scaled_fp4_quant( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant")) +def scaled_fp4_quant_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_permute"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_permute( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute")) +def scaled_fp4_quant_permute_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_trans"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_trans( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans")) +def scaled_fp4_quant_trans_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +# --------------------------------------------------------------------------- +# Triton kernel for grouped mean subtraction +# --------------------------------------------------------------------------- + +@triton.jit +def _group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = ( + pid_b * stride_qb + + pid_h * stride_qh + + offsets[:, None] * stride_ql + + tl.arange(0, D)[None, :] * stride_qd + ) + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = ( + pid_b * stride_qmb + + pid_h * stride_qmh + + pid_group * stride_qml + + tl.arange(0, D) * stride_qmd + ) + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = 128 + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + _group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE, + ) + return q_out, qm + + +# --------------------------------------------------------------------------- +# High-level Python API (ported from sageattn3/api.py) +# --------------------------------------------------------------------------- + +def preprocess_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + per_block_mean: bool = True, +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + def pad_128(x): + L = x.size(2) + pad_len = (128 - L % 128) % 128 + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + k = k - k.mean(dim=-2, keepdim=True) + q, k, v = map(pad_128, [q, k, v]) + if per_block_mean: + q, qm = triton_group_mean(q) + else: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + return q, k, v, delta_s + + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_transpose( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def blockscaled_fp4_attn( + qlist: Tuple[torch.Tensor, torch.Tensor], + klist: Tuple[torch.Tensor, torch.Tensor], + vlist: Tuple[torch.Tensor, torch.Tensor], + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, +) -> List[torch.Tensor]: + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return mha_fwd( + qlist[0], klist[0], vlist[0], + qlist[1], klist[1], vlist[1], + delta_s, KL, None, + softmax_scale, is_causal, per_block_mean, is_bf16, + ) + + +def sageattn3_blackwell( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + per_block_mean: bool = True, + **kwargs, +) -> torch.Tensor: + if q.size(-1) >= 256: + return sdpa(q, k, v, is_causal=is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q) + klist = scale_and_quant_fp4_permute(k) + vlist = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist, klist, vlist, delta_s, + KL, is_causal, per_block_mean, is_bf16, + )[0][:, :, :QL, :].contiguous() + return o_fp4 diff --git a/build/torch29-cxx11-cu129-x86_64-linux/sm80_compile.py b/build/torch29-cxx11-cu129-x86_64-linux/sm80_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..99e1c8d87c278d2935922309bb58eedd2369bfbf --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/sm80_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) 
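+    # Fake (meta) implementation detail: torch.compile only needs shapes and
+    # dtypes here. When LSE is requested, the real kernels return a float32
+    # tensor of shape [batch_size, num_qo_heads, qo_len]; otherwise a
+    # 0-element placeholder is returned.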
+ if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn")) +def qk_int8_sv_f16_accum_f16_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")) +def qk_int8_sv_f16_accum_f32_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf")) +def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")) +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn +qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn +qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf +qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn diff --git a/build/torch29-cxx11-cu129-x86_64-linux/sm89_compile.py b/build/torch29-cxx11-cu129-x86_64-linux/sm89_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..714d471f7d780fb53cb96f08dd5cc27f20581c8c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/sm89_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, 
sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn diff --git a/build/torch29-cxx11-cu129-x86_64-linux/sm90_compile.py b/build/torch29-cxx11-cu129-x86_64-linux/sm90_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..05898751d86f648aa7218bda29288421bb0308c3 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/sm90_compile.py @@ -0,0 +1,36 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 diff --git a/build/torch29-cxx11-cu130-aarch64-linux/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..959423da909d8e23a8bef3f0a18ef50484dd30bb --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,17 @@ +from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8 +from .core import sageattn + +try: + from .sm100_compile import sageattn3_blackwell + SM100_ENABLED = True +except Exception: + SM100_ENABLED = False + +__all__ = [ + "per_block_int8", + "per_warp_int8", + "sub_mean", + "per_channel_fp8", + "sageattn", + "sageattn3_blackwell", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_ops.py b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..33bcf153a986fe6f54330cbb3e7cc9b01b880783 --- /dev/null +++ 
b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _sage_attention_cuda_4523ce2 +ops = torch.ops._sage_attention_cuda_4523ce2 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_sage_attention_cuda_4523ce2::{op_name}" diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4523ce2.abi3.so b/build/torch29-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4523ce2.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..5df2176ae60aa0fae8db49b084601d77e3a58a9f --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/_sage_attention_cuda_4523ce2.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f57e0fec934a85f56ed39582a7bbf541cf610dfc74a70f24e5f0a2d80666882f +size 33808080 diff --git a/build/torch29-cxx11-cu130-aarch64-linux/core.py b/build/torch29-cxx11-cu130-aarch64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..13eff23d691beee1788df58e91c8470b02fc2b6b --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/core.py @@ -0,0 +1,1013 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import warnings + +from ._ops import ops + + +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant import per_channel_fp8 +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + +try: + from .sm80_compile import ( + qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn, + qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf, + ) + SM80_ENABLED = True +except Exception as e: + SM80_ENABLED = False + warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}") + +try: + from .sm89_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + ) + SM89_ENABLED = True +except Exception as e: + SM89_ENABLED = False + warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}") + +try: + from .sm90_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90, + ) + SM90_ENABLED = True +except Exception as e: + SM90_ENABLED = False + warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}") + +from typing import Any, List, Literal, Optional, Tuple, Union + +import subprocess +import re + + +def 
get_cuda_version(): + try: + output = subprocess.check_output(["nvcc", "--version"]).decode() + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + arch = get_cuda_arch_versions()[q.device.index] + if arch == "sm80": + if not SM80_ENABLED: + raise RuntimeError( + "SM80 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM80 (Ampere)." + ) + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "sm89": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM89 (Ada Lovelace)." 
+ ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) + elif arch == "sm90": + if not SM90_ENABLED: + raise RuntimeError( + "SM90 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM90 (Hopper)." + ) + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp32", + ) + elif arch == "sm120": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled." + ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + qk_quant_gran="per_warp", + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. + else: + raise ValueError(f"Unsupported CUDA architecture: {arch}") + +def sageattn_qk_int8_pv_fp16_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32". + - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b). + - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead. + - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32". 
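To make the "fp16+fp32" option above concrete, the following is a rough plain-PyTorch sketch of the idea, not the CUDA kernel: partial P·V products are accumulated in FP16 for speed and periodically folded into an FP32 buffer so rounding error does not build up over long sequences. The function name and the flush interval are illustrative choices only; in the real kernel this happens per tile in registers.

```python
import torch

def pv_accumulate_fp16_fp32(p_blocks, v_blocks, flush_every=4):
    """Accumulate sum_i P_i @ V_i in FP16, flushing to an FP32 buffer every few blocks."""
    acc32 = torch.zeros(p_blocks[0].shape[0], v_blocks[0].shape[1], dtype=torch.float32)
    acc16 = torch.zeros_like(acc32, dtype=torch.float16)
    for i, (p, v) in enumerate(zip(p_blocks, v_blocks)):
        # The product is computed in FP32 here only so the sketch runs on CPU;
        # the point being illustrated is the low-precision running sum.
        acc16 = acc16 + (p.float() @ v.float()).half()
        if (i + 1) % flush_every == 0:
            acc32 += acc16.float()   # fold the FP16 partial sum into the FP32 buffer
            acc16.zero_()
    return acc32 + acc16.float()     # flush whatever is left

```

Compared with keeping the whole running sum in FP16 (the "fp16" option), the FP32 buffer bounds how many FP16 additions happen between flushes, which is where the accuracy gain over pure FP16 accumulation comes from.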
+ + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." 
+ ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. 
Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # cuda_major_version, cuda_minor_version = get_cuda_version() + # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16': + # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'") + # pv_accum_dtype = 'fp32+fp32' + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. 
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64 + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v + ) + if pv_accum_dtype == "fp32": + if smooth_v: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + else: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp32": + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + 
q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp16": + lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + o = o[..., :head_dim_og] + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. 
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + if pv_accum_dtype == "fp32": + raise NotImplementedError("Please use 
pv_accum_dtype='fp32+fp32' for sm90.") + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/build/torch29-cxx11-cu130-aarch64-linux/metadata.json b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6530d3a7dee7b03bc5bb1f69c33aa615bd9bfef --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,14 @@ +{ + "version": 2, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0a", + "8.0", + "8.9", + "9.0a" + ] + } +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/quant.py b/build/torch29-cxx11-cu130-aarch64-linux/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7a32c6a1e502bee12d0e0564ff2b90f6b00462 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/quant.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +from ._ops import ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. 
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
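As a sanity check on the scale shapes listed above, here is a rough pure-PyTorch reference of the per-warp INT8 scheme for `q`, assuming the "HND" layout and a sequence length that is a multiple of `BLKQ`. The function name is illustrative; the actual quantization is performed by the fused CUDA op, whose rounding details may differ slightly from `torch.round`.

```python
import torch

def per_warp_int8_q_reference(q, BLKQ=128, WARPQ=32):
    b, h, qo_len, head_dim = q.shape                # "HND" layout
    assert qo_len % BLKQ == 0, "sketch only handles full blocks"
    n_warps = (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)
    q_warps = q.float().reshape(b, h, n_warps, WARPQ, head_dim)
    scale = q_warps.abs().amax(dim=(-2, -1)) / 127.0 + 1e-7    # one scale per warp
    q_int8 = torch.round(q_warps / scale[..., None, None]).to(torch.int8)
    return q_int8.reshape(b, h, qo_len, head_dim), scale       # scale: [b, h, n_warps]

q = torch.randn(2, 8, 1024, 64, dtype=torch.float16)
q_int8, q_scale = per_warp_int8_q_reference(q)
assert q_scale.shape == (2, 8, (1024 + 127) // 128 * (128 // 32))   # [2, 8, 32]
```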
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. 
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout + ) + return v_fp8, v_scale, None diff --git a/build/torch29-cxx11-cu130-aarch64-linux/quant_per_thread.py b/build/torch29-cxx11-cu130-aarch64-linux/quant_per_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f57c3e6dd9df89946ba54c4e2c3844c94d34 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/quant_per_thread.py @@ -0,0 +1,204 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + # offs_k = tl.arange(0, C) + + # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + # x = x.to(tl.float32) + # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + # x_int8 = x / scale + # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + # x_int8 = x_int8.to(tl.int8) + # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + # tl.store(scale_ptrs, scale) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. 
+ 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. 
+ 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + C=head_dim, BLK=WARPQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + C=head_dim, BLK=WARPK + ) + + return q_int8, q_scale, k_int8, k_scale \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-aarch64-linux/sage_attention/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/sage_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/sage_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
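+    # Note: hash() of a Path is salted per interpreter process (PYTHONHASHSEED),
+    # so the derived module name differs between runs; that is fine here because
+    # it is only used as a unique sys.modules key within the current process.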
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/sm100_compile.py b/build/torch29-cxx11-cu130-aarch64-linux/sm100_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4aa99649cca6673a91467ee522257d645d9e72 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/sm100_compile.py @@ -0,0 +1,327 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import List, Optional, Tuple + +from ._ops import ops, add_op_namespace_prefix +from torch.nn.functional import scaled_dot_product_attention as sdpa + + +# --------------------------------------------------------------------------- +# Low-level ops with torch.compile support (custom_op + register_fake) +# --------------------------------------------------------------------------- + +@torch.library.custom_op( + add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda" +) +def mha_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + return ops.mha_fwd( + q, k, v, sfq, sfk, sfv, delta_s, + unpadded_k, out, softmax_scale, is_causal, + per_block_mean, is_bf16, + ) + + +@torch.library.register_fake(add_op_namespace_prefix("mha_fwd")) +def mha_fwd_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + batch_size = q.size(0) + num_heads = q.size(1) + seqlen_q = q.size(2) + head_size_packed = q.size(3) + unpacked_head_size = head_size_packed * 2 + dtype = torch.bfloat16 if is_bf16 else torch.float16 + fake_out = torch.empty( + (batch_size, num_heads, seqlen_q, unpacked_head_size), + dtype=dtype, device=q.device, + ) + fake_lse = torch.empty( + (batch_size, num_heads, seqlen_q), + dtype=torch.float32, device=q.device, + ) + return [fake_out, fake_lse] + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant"), + mutates_args=("output", "output_sf"), + 
device_types="cuda", +) +def scaled_fp4_quant( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant")) +def scaled_fp4_quant_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_permute"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_permute( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute")) +def scaled_fp4_quant_permute_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_trans"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_trans( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans")) +def scaled_fp4_quant_trans_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +# --------------------------------------------------------------------------- +# Triton kernel for grouped mean subtraction +# --------------------------------------------------------------------------- + +@triton.jit +def _group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = ( + pid_b * stride_qb + + pid_h * stride_qh + + offsets[:, None] * stride_ql + + tl.arange(0, D)[None, :] * stride_qd + ) + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = ( + pid_b * stride_qmb + + pid_h * stride_qmh + + pid_group * stride_qml + + tl.arange(0, D) * stride_qmd + ) + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = 128 + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + _group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE, + ) + return q_out, qm + + +# --------------------------------------------------------------------------- +# High-level Python API (ported from sageattn3/api.py) +# --------------------------------------------------------------------------- + +def preprocess_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + per_block_mean: bool = True, +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + def pad_128(x): + L = x.size(2) + pad_len = (128 - L % 128) % 128 + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + k = k - k.mean(dim=-2, keepdim=True) + q, k, v = map(pad_128, [q, k, v]) + if per_block_mean: + q, qm = triton_group_mean(q) + else: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + return q, k, v, delta_s + + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_transpose( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def blockscaled_fp4_attn( + qlist: Tuple[torch.Tensor, torch.Tensor], + klist: Tuple[torch.Tensor, torch.Tensor], + vlist: Tuple[torch.Tensor, torch.Tensor], + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, +) -> List[torch.Tensor]: + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return mha_fwd( + qlist[0], klist[0], vlist[0], + qlist[1], klist[1], vlist[1], + delta_s, KL, None, + softmax_scale, is_causal, per_block_mean, is_bf16, + ) + + +def sageattn3_blackwell( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + per_block_mean: bool = True, + **kwargs, +) -> torch.Tensor: + if q.size(-1) >= 256: + return sdpa(q, k, v, is_causal=is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q) + klist = scale_and_quant_fp4_permute(k) + vlist = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist, klist, vlist, delta_s, + KL, is_causal, per_block_mean, is_bf16, + )[0][:, :, :QL, :].contiguous() + return o_fp4 diff --git a/build/torch29-cxx11-cu130-aarch64-linux/sm80_compile.py b/build/torch29-cxx11-cu130-aarch64-linux/sm80_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..99e1c8d87c278d2935922309bb58eedd2369bfbf --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/sm80_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = 
query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn")) +def qk_int8_sv_f16_accum_f16_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")) +def qk_int8_sv_f16_accum_f32_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf")) +def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")) +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn +qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn +qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf +qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn diff --git a/build/torch29-cxx11-cu130-aarch64-linux/sm89_compile.py b/build/torch29-cxx11-cu130-aarch64-linux/sm89_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..714d471f7d780fb53cb96f08dd5cc27f20581c8c --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/sm89_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, 
qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn diff --git a/build/torch29-cxx11-cu130-aarch64-linux/sm90_compile.py b/build/torch29-cxx11-cu130-aarch64-linux/sm90_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..05898751d86f648aa7218bda29288421bb0308c3 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/sm90_compile.py @@ -0,0 +1,36 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..959423da909d8e23a8bef3f0a18ef50484dd30bb --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,17 @@ +from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8 +from .core import sageattn + +try: + from .sm100_compile import sageattn3_blackwell + SM100_ENABLED = True +except Exception: + SM100_ENABLED = False + +__all__ = [ + "per_block_int8", + "per_warp_int8", + "sub_mean", + "per_channel_fp8", + "sageattn", + "sageattn3_blackwell", +] \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..33bcf153a986fe6f54330cbb3e7cc9b01b880783 --- /dev/null +++ 
b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _sage_attention_cuda_4523ce2 +ops = torch.ops._sage_attention_cuda_4523ce2 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_sage_attention_cuda_4523ce2::{op_name}" diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4523ce2.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4523ce2.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..97814c3be9e9811096dda1c5fdb64983b2d37082 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_sage_attention_cuda_4523ce2.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f06aa0efe6b4a7e5d473ce229ac29403b855f611aaca789334fa6a4bdb4aedbf +size 34139864 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/core.py b/build/torch29-cxx11-cu130-x86_64-linux/core.py new file mode 100644 index 0000000000000000000000000000000000000000..13eff23d691beee1788df58e91c8470b02fc2b6b --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/core.py @@ -0,0 +1,1013 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import warnings + +from ._ops import ops + + +from .quant import per_warp_int8 as per_warp_int8_cuda +from .quant import sub_mean +from .quant import per_channel_fp8 +from .quant_per_thread import per_thread_int8 as per_thread_int8_triton + +try: + from .sm80_compile import ( + qk_int8_sv_f16_accum_f32_attn as sm80_qk_int8_sv_f16_accum_f32_attn, + qk_int8_sv_f16_accum_f16_fuse_v_mean_attn as sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn, + qk_int8_sv_f16_accum_f16_attn as sm80_qk_int8_sv_f16_accum_f16_attn, + qk_int8_sv_f16_accum_f16_attn_inst_buf as sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf, + ) + SM80_ENABLED = True +except Exception as e: + SM80_ENABLED = False + warnings.warn(f"Failed to load SM80 SageAttention kernels: {e}") + +try: + from .sm89_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn, + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf, + qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf as sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf, + ) + SM89_ENABLED = True +except Exception as e: + SM89_ENABLED = False + warnings.warn(f"Failed to load SM89 SageAttention kernels: {e}") + +try: + from .sm90_compile import ( + qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 as sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90, + ) + SM90_ENABLED = True +except Exception as e: + SM90_ENABLED = False + warnings.warn(f"Failed to load SM90 SageAttention kernels: {e}") + +from typing import Any, List, Literal, Optional, Tuple, Union + +import subprocess +import re + + +def 
get_cuda_version(): + try: + output = subprocess.check_output(["nvcc", "--version"]).decode() + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + return_lse: bool = False, + **kwargs: Any, +): + """ + Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + arch = get_cuda_arch_versions()[q.device.index] + if arch == "sm80": + if not SM80_ENABLED: + raise RuntimeError( + "SM80 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM80 (Ampere)." + ) + return sageattn_qk_int8_pv_fp16_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32", + ) + elif arch == "sm89": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM89 (Ada Lovelace)." 
+ ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) + elif arch == "sm90": + if not SM90_ENABLED: + raise RuntimeError( + "SM90 SageAttention kernels failed to load. " + "Ensure the kernel was compiled for SM90 (Hopper)." + ) + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp32", + ) + elif arch == "sm120": + if not SM89_ENABLED: + raise RuntimeError( + "SM89 SageAttention kernels failed to load. " + "SM120 (Blackwell) uses SM89 kernels; ensure they were compiled." + ) + return sageattn_qk_int8_pv_fp8_cuda( + q, + k, + v, + tensor_layout=tensor_layout, + is_causal=is_causal, + qk_quant_gran="per_warp", + sm_scale=sm_scale, + return_lse=return_lse, + pv_accum_dtype="fp32+fp16", + ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. + else: + raise ValueError(f"Unsupported CUDA architecture: {arch}") + +def sageattn_qk_int8_pv_fp16_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32". + - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b). + - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead. + - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32". 
+ + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." 
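+        # if this fires, pass q.contiguous(), k.contiguous(), v.contiguous() before calling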
+ ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=128, + WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, + WARPK=64, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f32_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16": + if smooth_v: + smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) + lse = sm80_qk_int8_sv_f16_accum_f16_fuse_v_mean_attn( + q_int8, + k_int8, + smoothed_v, + o, + q_scale, + k_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp16+fp32": + v = v.to(torch.float16) + lse = sm80_qk_int8_sv_f16_accum_f16_attn_inst_buf( + q_int8, + k_int8, + v, + o, + q_scale, + k_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + else: + raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + +def sageattn_qk_int8_pv_fp8_cuda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp16", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. 
Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + smooth_v : bool + Whether to smooth the value tensor by subtracting the mean along the sequence dimension. + smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". + Default: False. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + # cuda_major_version, cuda_minor_version = get_cuda_version() + # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16': + # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'") + # pv_accum_dtype = 'fp32+fp32' + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. 
Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64 + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + if pv_accum_dtype == "fp32+fp32" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.") + smooth_v = False + + if pv_accum_dtype == "fp32+fp16" and smooth_v: + warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") + smooth_v = False + + quant_v_scale_max = 448.0 + if pv_accum_dtype == "fp32+fp16": + quant_v_scale_max = 2.25 + + v_fp8, v_scale, vm = per_channel_fp8( + v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v + ) + if pv_accum_dtype == "fp32": + if smooth_v: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + vm, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + else: + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp32": + lse = sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + 
q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + elif pv_accum_dtype == "fp32+fp16": + lse = sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + torch.cuda.synchronize() + o = o[..., :head_dim_og] + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o + + +def sageattn_qk_int8_pv_fp8_cuda_sm90( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + qk_quant_gran: str = "per_thread", + sm_scale: Optional[float] = None, + pv_accum_dtype: str = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + + qk_quant_gran : str + The granularity of quantization for Q and K, either "per_warp" or "per_thread". + Default: "per_thread". + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + pv_accum_dtype : str + The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". + - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. + - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. + Default: "fp32+fp32". + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + return_lse : bool + Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. + Default: False. + + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + torch.Tensor + The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). + Shape: ``[batch_size, num_qo_heads, qo_len]``. + Only returned if `return_lse` is True. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. 
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [torch.float16, torch.bfloat16], ( + "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + ) + assert qk_quant_gran in ["per_warp", "per_thread"], ( + "qk_quant_gran must be either 'per_warp' or 'per_thread'." + ) + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + torch.cuda.set_device(v.device) + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + _is_caual = 1 if is_causal else 0 + _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 + _return_lse = 1 if return_lse else 0 + + head_dim_og = q.size(-1) + + if head_dim_og < 64: + q = torch.nn.functional.pad(q, (0, 64 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 64 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 64 - head_dim_og)) + elif head_dim_og > 64 and head_dim_og < 128: + q = torch.nn.functional.pad(q, (0, 128 - head_dim_og)) + k = torch.nn.functional.pad(k, (0, 128 - head_dim_og)) + v = torch.nn.functional.pad(v, (0, 128 - head_dim_og)) + elif head_dim_og > 128: + raise ValueError(f"Unsupported head_dim: {head_dim_og}") + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, ( + "Last dim of qkv must be contiguous." + ) + + if sm_scale is None: + sm_scale = head_dim_og**-0.5 + + seq_dim = 1 if _tensor_layout == 0 else 2 + nh_dim = 2 if _tensor_layout == 0 else 1 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + nqheads = q.size(2) + nkheads = k.size(2) + q_per_kv_heads = nqheads // nkheads + if q_per_kv_heads > 1: + # nheads_k => nheads_q + km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim) + else: + km_broadcast = km + if return_lse: + if tensor_layout == "NHD": + lse_correction = ( + torch.matmul( + q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3) + ) + .squeeze(-1) + .to(torch.float32) + ) + else: + lse_correction = ( + torch.matmul(q, km_broadcast.transpose(2, 3)) + .squeeze(-1) + .to(torch.float32) + ) + else: + km = None + + if qk_quant_gran == "per_warp": + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128 + ) + elif qk_quant_gran == "per_thread": + q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton( + q, + k, + km, + tensor_layout=tensor_layout, + BLKQ=64, + WARPQ=16, + BLKK=128, + WARPK=128, + ) + + o = torch.empty(q.size(), dtype=dtype, device=q.device) + + # pad v to multiple of 128 + # TODO: modify per_channel_fp8 kernel to handle this + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v.size(1), + v_pad_len, + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=2, + ) + else: + v = torch.cat( + [ + v, + torch.zeros( + v.size(0), + v_pad_len, + v.size(2), + v.size(3), + dtype=v.dtype, + device=v.device, + ), + ], + dim=1, + ) + + v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False) + + if pv_accum_dtype == "fp32": + raise NotImplementedError("Please use 
pv_accum_dtype='fp32+fp32' for sm90.") + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + elif pv_accum_dtype == "fp32+fp32": + lse = sm90_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90( + q_int8, + k_int8, + v_fp8, + o, + q_scale, + k_scale, + v_scale, + _tensor_layout, + _is_caual, + _qk_quant_gran, + sm_scale, + _return_lse, + ) + + o = o[..., :head_dim_og] + + if return_lse: + return ( + o, + lse / 1.44269504 + lse_correction * sm_scale + if smooth_k + else lse / 1.44269504, + ) + else: + return o diff --git a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..d6530d3a7dee7b03bc5bb1f69c33aa615bd9bfef --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,14 @@ +{ + "version": 2, + "license": "Apache-2.0", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0a", + "8.0", + "8.9", + "9.0a" + ] + } +} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/quant.py b/build/torch29-cxx11-cu130-x86_64-linux/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7a32c6a1e502bee12d0e0564ff2b90f6b00462 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/quant.py @@ -0,0 +1,326 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +from typing import Optional + +from ._ops import ops + + +def per_block_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + BLKK: int = 64, + sm_scale: Optional[float] = None, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` and the key tensor `k` with per block quantization. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + sm_scale : Optional[float] + The scale factor for the softmax operation. Default is ``head_dim**-0.5``. + It will be multiplied by ``1.44269504`` to work together with the triton attention kernel. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. 
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. + + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + sm_scale *= 1.44269504 + + ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout) + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def per_warp_int8( + q: torch.Tensor, + k: torch.Tensor, + km: Optional[torch.Tensor] = None, + BLKQ: int = 128, + WARPQ: int = 32, + BLKK: int = 64, + tensor_layout: str = "HND", +): + """ + Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization. + Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128. + Block size of quantizing `k` is 64 or 128. + + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + km : Optional[torch.Tensor] + The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``. + Should be of the same dtype as `k` if provided. Default is None. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + A tuple containing: + - The quantized query tensor. Shape: Same as `q` but with `int8` dtype. + - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype. + - The quantized key tensor. Shape: Same as `k` but with `int8` dtype. + - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype. 
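+
+    Example
+    -------
+    A minimal illustrative call (shapes are arbitrary; assumes fp16 CUDA
+    tensors in the default "HND" layout)::
+
+        q = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
+        k = torch.randn(2, 8, 1024, 128, dtype=torch.float16, device="cuda")
+        q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k)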
+ + Note + ---- + - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + """ + + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + q_scale = torch.empty( + (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)), + device=q.device, + dtype=torch.float32, + ) + k_scale = torch.empty( + (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32 + ) + + ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout) + + if km is not None: + km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2) + ops.quant_per_block_int8_fuse_sub_mean_cuda( + k, km, k_int8, k_scale, BLKK, _tensor_layout + ) + else: + # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling + ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout) + + return q_int8, q_scale, k_int8, k_scale + + +def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"): + """ + Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype. + - The returned mean tensor will have the same dtype as the input tensor. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=1 if _tensor_layout == 0 else 2) + + v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device) + + # subtract mean and store the result as fp16 + ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout) + + return v_smoothed, vm + + +def per_channel_fp8( + v: torch.Tensor, + tensor_layout: str = "HND", + scale_max: float = 448.0, + smooth_v: bool = True, +): + """ + Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization. + `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64. + After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``. + The quantization is done per channel, with the scale value and smooth factor calculated per channel. + + Parameters + ---------- + v : torch.Tensor + The input tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. 
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + + scale_max : float + The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format). + + smooth_v : bool + Whether to smooth the quantized tensor. Default is True. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + A tuple containing: + - The quantized tensor `v_fp8`. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype. + - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype. + + Note + ---- + - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``. + """ + + _tensor_layout = 0 if tensor_layout == "NHD" else 1 + + if tensor_layout == "HND": + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device + ) + + elif tensor_layout == "NHD": + b, kv_len, h_kv, head_dim = v.shape + padded_len = (kv_len + 63) // 64 * 64 + v_transposed_permutted = torch.empty( + (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device + ) + + ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout) + + v_fp8 = torch.empty( + v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device + ) + + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + + if smooth_v: + ops.mean_scale_fuse_quant_cuda( + v_transposed_permutted, + v_fp8, + vm, + v_scale, + kv_len, + scale_max, + _tensor_layout, + ) + return v_fp8, v_scale, vm + else: + ops.scale_fuse_quant_cuda( + v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout + ) + return v_fp8, v_scale, None diff --git a/build/torch29-cxx11-cu130-x86_64-linux/quant_per_thread.py b/build/torch29-cxx11-cu130-x86_64-linux/quant_per_thread.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f57c3e6dd9df89946ba54c4e2c3844c94d34 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/quant_per_thread.py @@ -0,0 +1,204 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import triton +import triton.language as tl + +@triton.jit +def quant_query_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + # offs_k = tl.arange(0, C) + + # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + # x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + # x = x.to(tl.float32) + # scale = tl.max(tl.abs(x)) / 127. + 0.0000001 + # x_int8 = x / scale + # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + # x_int8 = x_int8.to(tl.int8) + # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + # tl.store(scale_ptrs, scale) + + offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1 + offs_k = tl.arange(0, C) + + input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :] + input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :] + output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :] + output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L) + x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L) + x0 = x0.to(tl.float32) + x1 = x1.to(tl.float32) + scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. 
+ 0.0000001 + x0_int8 = x0 / scale + x1_int8 = x1 / scale + x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1) + x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1) + x0_int8 = x0_int8.to(tl.int8) + x1_int8 = x1_int8.to(tl.int8) + tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L) + tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_query_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 8 + off_tld = tl.program_id(0) % 8 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. + 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +@triton.jit +def quant_key_per_thread_int4_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) // 4 + off_tld = tl.program_id(0) % 4 + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2 + offs_k = tl.arange(0, C) + + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + scale = tl.max(tl.abs(x)) / 7. 
+ 0.0000001 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + if km is not None: + k = k - km + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b) + quant_query_per_thread_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + C=head_dim, BLK=WARPQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b) + quant_key_per_thread_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + C=head_dim, BLK=WARPK + ) + + return q_int8, q_scale, k_int8, k_scale \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/sage_attention/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/sage_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/sage_attention/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. 
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/sm100_compile.py b/build/torch29-cxx11-cu130-x86_64-linux/sm100_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4aa99649cca6673a91467ee522257d645d9e72 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/sm100_compile.py @@ -0,0 +1,327 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from typing import List, Optional, Tuple + +from ._ops import ops, add_op_namespace_prefix +from torch.nn.functional import scaled_dot_product_attention as sdpa + + +# --------------------------------------------------------------------------- +# Low-level ops with torch.compile support (custom_op + register_fake) +# --------------------------------------------------------------------------- + +@torch.library.custom_op( + add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda" +) +def mha_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + return ops.mha_fwd( + q, k, v, sfq, sfk, sfv, delta_s, + unpadded_k, out, softmax_scale, is_causal, + per_block_mean, is_bf16, + ) + + +@torch.library.register_fake(add_op_namespace_prefix("mha_fwd")) +def mha_fwd_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sfq: torch.Tensor, + sfk: torch.Tensor, + sfv: torch.Tensor, + delta_s: torch.Tensor, + unpadded_k: int, + out: Optional[torch.Tensor], + softmax_scale: float, + is_causal: bool, + per_block_mean: bool, + is_bf16: bool, +) -> List[torch.Tensor]: + batch_size = q.size(0) + num_heads = q.size(1) + seqlen_q = q.size(2) + head_size_packed = q.size(3) + unpacked_head_size = head_size_packed * 2 + dtype = torch.bfloat16 if is_bf16 else torch.float16 + fake_out = torch.empty( + (batch_size, num_heads, seqlen_q, unpacked_head_size), + dtype=dtype, device=q.device, + ) + fake_lse = torch.empty( + (batch_size, num_heads, seqlen_q), + dtype=torch.float32, device=q.device, + ) + return [fake_out, fake_lse] + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant"), + mutates_args=("output", "output_sf"), + 
device_types="cuda", +) +def scaled_fp4_quant( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant")) +def scaled_fp4_quant_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_permute"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_permute( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute")) +def scaled_fp4_quant_permute_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +@torch.library.custom_op( + add_op_namespace_prefix("scaled_fp4_quant_trans"), + mutates_args=("output", "output_sf"), + device_types="cuda", +) +def scaled_fp4_quant_trans( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout) + + +@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans")) +def scaled_fp4_quant_trans_fake( + input: torch.Tensor, + output: torch.Tensor, + output_sf: torch.Tensor, + tensor_layout: int, +) -> None: + pass + + +# --------------------------------------------------------------------------- +# Triton kernel for grouped mean subtraction +# --------------------------------------------------------------------------- + +@triton.jit +def _group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = ( + pid_b * stride_qb + + pid_h * stride_qh + + offsets[:, None] * stride_ql + + tl.arange(0, D)[None, :] * stride_qd + ) + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = ( + pid_b * stride_qmb + + pid_h * stride_qmh + + pid_group * stride_qml + + tl.arange(0, D) * stride_qmd + ) + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = 128 + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + _group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE, + ) + return q_out, qm + + +# --------------------------------------------------------------------------- +# High-level Python API (ported from sageattn3/api.py) +# --------------------------------------------------------------------------- + +def preprocess_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + per_block_mean: bool = True, +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + def pad_128(x): + L = x.size(2) + pad_len = (128 - L % 128) % 128 + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + k = k - k.mean(dim=-2, keepdim=True) + q, k, v = map(pad_128, [q, k, v]) + if per_block_mean: + q, qm = triton_group_mean(q) + else: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + return q, k, v, delta_s + + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def scale_and_quant_fp4_transpose( + x: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty( + (B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn + ) + scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + + +def blockscaled_fp4_attn( + qlist: Tuple[torch.Tensor, torch.Tensor], + klist: Tuple[torch.Tensor, torch.Tensor], + vlist: Tuple[torch.Tensor, torch.Tensor], + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, +) -> List[torch.Tensor]: + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return mha_fwd( + qlist[0], klist[0], vlist[0], + qlist[1], klist[1], vlist[1], + delta_s, KL, None, + softmax_scale, is_causal, per_block_mean, is_bf16, + ) + + +def sageattn3_blackwell( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + per_block_mean: bool = True, + **kwargs, +) -> torch.Tensor: + if q.size(-1) >= 256: + return sdpa(q, k, v, is_causal=is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q) + klist = scale_and_quant_fp4_permute(k) + vlist = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist, klist, vlist, delta_s, + KL, is_causal, per_block_mean, is_bf16, + )[0][:, :, :QL, :].contiguous() + return o_fp4 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/sm80_compile.py b/build/torch29-cxx11-cu130-x86_64-linux/sm80_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..99e1c8d87c278d2935922309bb58eedd2369bfbf --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/sm80_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) 
+ if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn")) +def qk_int8_sv_f16_accum_f16_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f32_attn")) +def qk_int8_sv_f16_accum_f32_attn_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_attn_inst_buf")) +def qk_int8_sv_f16_accum_f16_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")) +def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f16_accum_f16_attn = ops.qk_int8_sv_f16_accum_f16_attn +qk_int8_sv_f16_accum_f32_attn = ops.qk_int8_sv_f16_accum_f32_attn +qk_int8_sv_f16_accum_f16_attn_inst_buf = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf +qk_int8_sv_f16_accum_f16_fuse_v_mean_attn = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn diff --git a/build/torch29-cxx11-cu130-x86_64-linux/sm89_compile.py b/build/torch29-cxx11-cu130-x86_64-linux/sm89_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..714d471f7d780fb53cb96f08dd5cc27f20581c8c --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/sm89_compile.py @@ -0,0 +1,54 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")) +def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, 
sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_fake( + query, key, value, output, query_scale, key_scale, value_scale, value_mean, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn diff --git a/build/torch29-cxx11-cu130-x86_64-linux/sm90_compile.py b/build/torch29-cxx11-cu130-x86_64-linux/sm90_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..05898751d86f648aa7218bda29288421bb0308c3 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/sm90_compile.py @@ -0,0 +1,36 @@ +from ._ops import ops +import torch +from ._ops import add_op_namespace_prefix + + +def _lse_fake_impl(query, tensor_layout, return_lse): + batch_size = query.size(0) + if tensor_layout == 0: + num_qo_heads = query.size(2) + qo_len = query.size(1) + else: + num_qo_heads = query.size(1) + qo_len = query.size(2) + if return_lse: + return torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device=query.device) + return torch.empty((0)) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_attn_inst_buf")) +def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake( + query, key, value, output, query_scale, key_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +@torch.library.register_fake(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90")) +def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_fake( + query, key, value, output, query_scale, key_scale, value_scale, + tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse, +): + return _lse_fake_impl(query, tensor_layout, return_lse) + + +qk_int8_sv_f8_accum_f32_attn_inst_buf = ops.qk_int8_sv_f8_accum_f32_attn_inst_buf +qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90 = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90
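
The sm80/sm89/sm90 modules above only attach fake (meta) implementations to the prebuilt CUDA ops via `torch.library.register_fake`, so `torch.compile` can trace the int8 attention calls without launching kernels; the real kernels still come from `ops`. A minimal, self-contained sketch of that same pattern with a hypothetical `demo::scaled_copy` op (not part of this package) looks like this, assuming PyTorch 2.4+ where `torch.library.custom_op` is available:

    import torch

    # Hypothetical op: writes x * scale into `output` in place, returns nothing,
    # mirroring the "mutates output, optionally returns LSE" shape of the ops above.
    @torch.library.custom_op("demo::scaled_copy", mutates_args=("output",))
    def scaled_copy(x: torch.Tensor, output: torch.Tensor, scale: float) -> None:
        output.copy_(x * scale)

    # Fake implementation: under FakeTensor tracing no computation is needed,
    # because the mutated tensor is pre-allocated by the caller.
    @torch.library.register_fake("demo::scaled_copy")
    def _(x: torch.Tensor, output: torch.Tensor, scale: float) -> None:
        pass

    @torch.compile(fullgraph=True)
    def f(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        scaled_copy(x, out, 2.0)
        return out

    print(f(torch.randn(4)))  # traces through the custom op via its fake impl

The compiled build files follow the same split: the `*_fake` functions only allocate the log-sum-exp tensor with the correct `(batch, heads, qo_len)` shape when `return_lse` is set, while the module-level assignments re-export the real compiled kernels.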