medmekk (HF Staff) committed
Commit af2d0c0 · 1 parent: 44b112f

add some builds

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. CMakeLists.txt +1 -0
  2. build.toml +10 -6
  3. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__init__.py +12 -0
  4. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc +0 -0
  5. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc +0 -0
  6. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc +0 -0
  7. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc +0 -0
  8. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc +0 -0
  9. {torch-ext → build/torch27-cxx11-cu126-x86_64-linux}/sage_attention/_ops.py +3 -3
  10. torch-ext/sage_attention/_sage_attention_57cb7ec_dirty.abi3.so → build/torch27-cxx11-cu126-x86_64-linux/sage_attention/_sage_attention_44b112f_dirty.abi3.so +2 -2
  11. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/core.py +983 -0
  12. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/layers.py +0 -0
  13. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/quant.py +326 -0
  14. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/quant_per_thread.py +204 -0
  15. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__init__.py +12 -0
  16. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc +0 -0
  17. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc +0 -0
  18. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc +0 -0
  19. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc +0 -0
  20. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc +0 -0
  21. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/_ops.py +9 -0
  22. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/_sage_attention_44b112f_dirty.abi3.so +3 -0
  23. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/core.py +983 -0
  24. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/layers.py +0 -0
  25. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/quant.py +326 -0
  26. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/quant_per_thread.py +204 -0
  27. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__init__.py +12 -0
  28. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc +0 -0
  29. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc +0 -0
  30. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc +0 -0
  31. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc +0 -0
  32. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc +0 -0
  33. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/_ops.py +9 -0
  34. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/_sage_attention_44b112f_dirty.abi3.so +3 -0
  35. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/core.py +983 -0
  36. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/layers.py +0 -0
  37. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/quant.py +326 -0
  38. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/quant_per_thread.py +204 -0
  39. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__init__.py +12 -0
  40. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc +0 -0
  41. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc +0 -0
  42. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc +0 -0
  43. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc +0 -0
  44. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc +0 -0
  45. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/_ops.py +9 -0
  46. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/_sage_attention_44b112f_dirty.abi3.so +3 -0
  47. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/core.py +983 -0
  48. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/layers.py +0 -0
  49. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/quant.py +326 -0
  50. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/quant_per_thread.py +204 -0
CMakeLists.txt CHANGED
@@ -142,6 +142,7 @@ set(_qattn_sm90_SRC
   "sage_attention/qattn/qk_int_sv_f8_cuda_sm90.cu"
   "sage_attention/qattn/attn_cuda_sm90.h"
   "sage_attention/qattn/attn_utils.cuh"
+  "sage_attention/cuda_tensormap_shim.cuh"
 )
 
 # TODO: check if CLion support this:
build.toml CHANGED
@@ -1,21 +1,20 @@
 [general]
 name = "sage_attention"
 universal = false
+cuda-minver = "12.4"
 
 [torch]
 src = [
   "torch-ext/torch_binding.cpp",
   "torch-ext/torch_binding.h",
 ]
-cuda-capabilities = [
-  "8.0", "9.0"
-]
 
 [kernel._qattn]
 depends = ["torch"]
 backend = "cuda"
+cuda-minver = "12.4"
 cuda-capabilities = [
-  "9.0"
+  "8.0", "8.9", "9.0a"
 ]
 src = [
   "sage_attention/cp_async.cuh",
@@ -27,6 +26,7 @@ src = [
   "sage_attention/reduction_utils.cuh",
   "sage_attention/wgmma.cuh",
   "sage_attention/utils.cuh",
+  "sage_attention/cuda_tensormap_shim.cuh",
 ]
 cxx-flags = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
 cuda-flags = [
@@ -43,6 +43,7 @@ cuda-flags = [
 [kernel._qattn_sm80]
 depends = ["torch"]
 backend = "cuda"
+cuda-minver = "12.4"
 cuda-capabilities = [
   "8.0"
 ]
@@ -68,6 +69,7 @@ cuda-flags = [
 [kernel._qattn_sm89]
 depends = ["torch"]
 backend = "cuda"
+cuda-minver = "12.4"
 cuda-capabilities = [
   "8.9",
 ]
@@ -100,8 +102,9 @@ cuda-flags = [
 [kernel._qattn_sm90]
 depends = ["torch"]
 backend = "cuda"
+cuda-minver = "12.4"
 cuda-capabilities = [
-  "9.0",
+  "9.0a",
 ]
 include = ["."]
 src = [
@@ -124,8 +127,9 @@ cuda-flags = [
 [kernel._fused]
 depends = ["torch"]
 backend = "cuda"
+cuda-minver = "12.4"
 cuda-capabilities = [
-  "9.0",
+  "8.0", "8.9", "9.0a",
 ]
 include = ["."]
 src = [
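
Note (not part of the commit): the capability strings above name the GPU architectures each kernel variant is compiled for, and "9.0a" selects Hopper's architecture-specific target (sm_90a), which the wgmma/TMA code paths require. As a rough, hedged illustration of what these entries usually expand to for nvcc (the actual kernel-builder behavior may differ in detail):

def gencode_flags(capabilities):
    # "8.0" -> "80", "9.0a" -> "90a"; one -gencode flag per requested architecture
    flags = []
    for cap in capabilities:
        tag = cap.replace(".", "")
        flags.append(f"-gencode=arch=compute_{tag},code=sm_{tag}")
    return flags

print(gencode_flags(["8.0", "8.9", "9.0a"]))
# ['-gencode=arch=compute_80,code=sm_80',
#  '-gencode=arch=compute_89,code=sm_89',
#  '-gencode=arch=compute_90a,code=sm_90a']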
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
+ from .core import sageattn, sageattn_qk_int8_pv_fp8_cuda
3
+
4
+
5
+ __all__ = [
6
+ "per_block_int8",
7
+ "per_warp_int8",
8
+ "sub_mean",
9
+ "per_channel_fp8",
10
+ "sageattn",
11
+ "sageattn_qk_int8_pv_fp8_cuda",
12
+ ]
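
Note (not part of the commit): a minimal usage sketch of the API exported above, assuming a CUDA device with one of the supported architectures; shapes and the "HND" layout follow the docstrings in core.py below.

import torch
from sage_attention import sageattn

# Q/K/V in "HND" layout: [batch, num_heads, seq_len, head_dim]
q = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")

out = sageattn(q, k, v, tensor_layout="HND", is_causal=True)
print(out.shape)  # torch.Size([1, 8, 1024, 128])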
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (433 Bytes)
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (550 Bytes)
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc ADDED
Binary file (33.4 kB)
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc ADDED
Binary file (13.4 kB)
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc ADDED
Binary file (13 kB)
{torch-ext → build/torch27-cxx11-cu126-x86_64-linux}/sage_attention/_ops.py RENAMED
@@ -1,9 +1,9 @@
 import torch
-from . import _sage_attention_57cb7ec_dirty
-ops = torch.ops._sage_attention_57cb7ec_dirty
+from . import _sage_attention_44b112f_dirty
+ops = torch.ops._sage_attention_44b112f_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_sage_attention_57cb7ec_dirty::{op_name}"
+    return f"_sage_attention_44b112f_dirty::{op_name}"
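
Note (not part of the commit): a hedged sketch of how this module is consumed. The bound kernels are reached through the ops table (for example ops.quant_per_warp_int8_cuda, used in quant.py below), and the prefix helper is handy whenever an op must be referred to by its fully qualified name, e.g. when registering meta/fake implementations.

from sage_attention._ops import ops, add_op_namespace_prefix

print(add_op_namespace_prefix("quant_per_warp_int8_cuda"))
# "_sage_attention_44b112f_dirty::quant_per_warp_int8_cuda"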
torch-ext/sage_attention/_sage_attention_57cb7ec_dirty.abi3.so → build/torch27-cxx11-cu126-x86_64-linux/sage_attention/_sage_attention_44b112f_dirty.abi3.so RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:afa4831d0d218167c818a3871cf9fc01f154a6fc3c4671efdfede77a83e3b083
-size 26036368
+oid sha256:b577da1986b76b2571e8dd55412621e6fc85fe1a2f847bc0a5af9851bf388cf2
+size 26037568
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/core.py ADDED
@@ -0,0 +1,983 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+
20
+ from ._ops import ops
21
+
22
+
23
+ from .quant import per_warp_int8 as per_warp_int8_cuda
24
+ from .quant import sub_mean
25
+ from .quant import per_channel_fp8
26
+ from .quant_per_thread import per_thread_int8 as per_thread_int8_triton
27
+
28
+ from typing import Any, List, Literal, Optional, Tuple, Union
29
+ import warnings
30
+
31
+
32
+ import subprocess
33
+ import re
34
+
35
+
36
+ def get_cuda_version():
37
+ try:
38
+ output = subprocess.check_output(["nvcc", "--version"]).decode()
39
+ match = re.search(r"release (\d+)\.(\d+)", output)
40
+ if match:
41
+ major, minor = int(match.group(1)), int(match.group(2))
42
+ return major, minor
43
+ except Exception as e:
44
+ print("Failed to get CUDA version:", e)
45
+ return None, None
46
+
47
+
48
+ def get_cuda_arch_versions():
49
+ cuda_archs = []
50
+ for i in range(torch.cuda.device_count()):
51
+ major, minor = torch.cuda.get_device_capability(i)
52
+ cuda_archs.append(f"sm{major}{minor}")
53
+ return cuda_archs
54
+
55
+
56
+ def sageattn(
57
+ q: torch.Tensor,
58
+ k: torch.Tensor,
59
+ v: torch.Tensor,
60
+ tensor_layout: str = "HND",
61
+ is_causal: bool = False,
62
+ sm_scale: Optional[float] = None,
63
+ return_lse: bool = False,
64
+ **kwargs: Any,
65
+ ):
66
+ """
67
+ Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability.
68
+
69
+ Parameters
70
+ ----------
71
+ q : torch.Tensor
72
+ The query tensor. Shape:
73
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
74
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
75
+
76
+ k : torch.Tensor
77
+ The key tensor. Shape:
78
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
79
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
80
+
81
+ v : torch.Tensor
82
+ The value tensor. Shape:
83
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
84
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
85
+
86
+ tensor_layout : str
87
+ The tensor layout, either "HND" or "NHD".
88
+ Default: "HND".
89
+
90
+ is_causal : bool
91
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
92
+ Default: False.
93
+
94
+ sm_scale : Optional[float]
95
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
96
+
97
+ return_lse : bool
98
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
99
+ Default: False.
100
+
101
+ Returns
102
+ -------
103
+ torch.Tensor
104
+ The output tensor. Shape:
105
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
106
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
107
+
108
+ torch.Tensor
109
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
110
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
111
+ Only returned if `return_lse` is True.
112
+
113
+ Note
114
+ ----
115
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
116
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
117
+ - All tensors must be on the same cuda device.
118
+ """
119
+
120
+ arch = get_cuda_arch_versions()[q.device.index]
121
+ if arch == "sm80":
122
+ return sageattn_qk_int8_pv_fp16_cuda(
123
+ q,
124
+ k,
125
+ v,
126
+ tensor_layout=tensor_layout,
127
+ is_causal=is_causal,
128
+ sm_scale=sm_scale,
129
+ return_lse=return_lse,
130
+ pv_accum_dtype="fp32",
131
+ )
132
+ elif arch == "sm89":
133
+ return sageattn_qk_int8_pv_fp8_cuda(
134
+ q,
135
+ k,
136
+ v,
137
+ tensor_layout=tensor_layout,
138
+ is_causal=is_causal,
139
+ sm_scale=sm_scale,
140
+ return_lse=return_lse,
141
+ pv_accum_dtype="fp32+fp16",
142
+ )
143
+ elif arch == "sm90":
144
+ return sageattn_qk_int8_pv_fp8_cuda_sm90(
145
+ q,
146
+ k,
147
+ v,
148
+ tensor_layout=tensor_layout,
149
+ is_causal=is_causal,
150
+ sm_scale=sm_scale,
151
+ return_lse=return_lse,
152
+ pv_accum_dtype="fp32+fp32",
153
+ )
154
+ elif arch == "sm120":
155
+ return sageattn_qk_int8_pv_fp8_cuda(
156
+ q,
157
+ k,
158
+ v,
159
+ tensor_layout=tensor_layout,
160
+ is_causal=is_causal,
161
+ qk_quant_gran="per_warp",
162
+ sm_scale=sm_scale,
163
+ return_lse=return_lse,
164
+ pv_accum_dtype="fp32+fp16",
165
+ ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
166
+ else:
167
+ raise ValueError(f"Unsupported CUDA architecture: {arch}")
168
+
169
+
170
+ @torch.compiler.disable
171
+ def sageattn_qk_int8_pv_fp16_cuda(
172
+ q: torch.Tensor,
173
+ k: torch.Tensor,
174
+ v: torch.Tensor,
175
+ tensor_layout: str = "HND",
176
+ is_causal: bool = False,
177
+ qk_quant_gran: str = "per_thread",
178
+ sm_scale: Optional[float] = None,
179
+ pv_accum_dtype: str = "fp32",
180
+ smooth_k: bool = True,
181
+ smooth_v: bool = False,
182
+ return_lse: bool = False,
183
+ **kwargs: Any,
184
+ ) -> torch.Tensor:
185
+ """
186
+ SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA.
187
+
188
+ Parameters
189
+ ----------
190
+ q : torch.Tensor
191
+ The query tensor. Shape:
192
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
193
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
194
+
195
+ k : torch.Tensor
196
+ The key tensor. Shape:
197
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
198
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
199
+
200
+ v : torch.Tensor
201
+ The value tensor. Shape:
202
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
203
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
204
+
205
+ tensor_layout : str
206
+ The tensor layout, either "HND" or "NHD".
207
+ Default: "HND".
208
+
209
+ is_causal : bool
210
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
211
+ Default: False.
212
+
213
+ qk_quant_gran : str
214
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
215
+ Default: "per_thread".
216
+
217
+ sm_scale : Optional[float]
218
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
219
+
220
+ pv_accum_dtype : str
221
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32".
222
+ - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b).
223
+ - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead.
224
+ - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
225
+ Default: "fp32".
226
+
227
+ smooth_k : bool
228
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
229
+ Default: True.
230
+
231
+ smooth_v : bool
232
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
233
+ smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32".
234
+ Default: False.
235
+
236
+ return_lse : bool
237
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
238
+ Default: False.
239
+
240
+ Returns
241
+ -------
242
+ torch.Tensor
243
+ The output tensor. Shape:
244
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
245
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
246
+
247
+ torch.Tensor
248
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
249
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
250
+ Only returned if `return_lse` is True.
251
+
252
+ Note
253
+ ----
254
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
255
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
256
+ - All tensors must be on the same cuda device.
257
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
258
+ """
259
+
260
+ dtype = q.dtype
261
+ assert q.is_cuda, "Input tensors must be on cuda."
262
+ assert dtype in [torch.float16, torch.bfloat16], (
263
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
264
+ )
265
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
266
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
267
+ )
268
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
269
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
270
+
271
+ # FIXME(DefTruth): make sage attention work compatible with distributed
272
+ # env, for example, xDiT which launch by torchrun. Without this workaround,
273
+ # sage attention will run into illegal memory access error after first
274
+ # inference step in distributed env for multi gpus inference. This small
275
+ # workaround also make sage attention work compatible with torch.compile
276
+ # through non-fullgraph compile mode.
277
+ torch.cuda.set_device(v.device)
278
+
279
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
280
+ _is_caual = 1 if is_causal else 0
281
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
282
+ _return_lse = 1 if return_lse else 0
283
+
284
+ head_dim_og = q.size(-1)
285
+
286
+ if head_dim_og < 64:
287
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
288
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
289
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
290
+ elif head_dim_og > 64 and head_dim_og < 128:
291
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
292
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
293
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
294
+ elif head_dim_og > 128:
295
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
296
+
297
+ # assert last dim is contiguous
298
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
299
+ "Last dim of qkv must be contiguous."
300
+ )
301
+
302
+ if sm_scale is None:
303
+ sm_scale = head_dim_og**-0.5
304
+
305
+ seq_dim = 1 if _tensor_layout == 0 else 2
306
+ nh_dim = 2 if _tensor_layout == 0 else 1
307
+
308
+ if smooth_k:
309
+ km = k.mean(dim=seq_dim, keepdim=True)
310
+ nqheads = q.size(2)
311
+ nkheads = k.size(2)
312
+ q_per_kv_heads = nqheads // nkheads
313
+ if q_per_kv_heads > 1:
314
+ # nheads_k => nheads_q
315
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
316
+ else:
317
+ km_broadcast = km
318
+ if return_lse:
319
+ if tensor_layout == "NHD":
320
+ lse_correction = (
321
+ torch.matmul(
322
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
323
+ )
324
+ .squeeze(-1)
325
+ .to(torch.float32)
326
+ )
327
+ else:
328
+ lse_correction = (
329
+ torch.matmul(q, km_broadcast.transpose(2, 3))
330
+ .squeeze(-1)
331
+ .to(torch.float32)
332
+ )
333
+ else:
334
+ km = None
335
+
336
+ if qk_quant_gran == "per_warp":
337
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
338
+ q,
339
+ k,
340
+ km,
341
+ tensor_layout=tensor_layout,
342
+ BLKQ=128,
343
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
344
+ BLKK=64,
345
+ )
346
+ elif qk_quant_gran == "per_thread":
347
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
348
+ q,
349
+ k,
350
+ km,
351
+ tensor_layout=tensor_layout,
352
+ BLKQ=128,
353
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
354
+ BLKK=64,
355
+ WARPK=64,
356
+ )
357
+
358
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
359
+
360
+ if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v:
361
+ warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.")
362
+ smooth_v = False
363
+
364
+ if pv_accum_dtype == "fp32":
365
+ v = v.to(torch.float16)
366
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(
367
+ q_int8,
368
+ k_int8,
369
+ v,
370
+ o,
371
+ q_scale,
372
+ k_scale,
373
+ _tensor_layout,
374
+ _is_caual,
375
+ _qk_quant_gran,
376
+ sm_scale,
377
+ _return_lse,
378
+ )
379
+ elif pv_accum_dtype == "fp16":
380
+ if smooth_v:
381
+ smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
382
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
383
+ q_int8,
384
+ k_int8,
385
+ smoothed_v,
386
+ o,
387
+ q_scale,
388
+ k_scale,
389
+ vm,
390
+ _tensor_layout,
391
+ _is_caual,
392
+ _qk_quant_gran,
393
+ sm_scale,
394
+ _return_lse,
395
+ )
396
+ else:
397
+ v = v.to(torch.float16)
398
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(
399
+ q_int8,
400
+ k_int8,
401
+ v,
402
+ o,
403
+ q_scale,
404
+ k_scale,
405
+ _tensor_layout,
406
+ _is_caual,
407
+ _qk_quant_gran,
408
+ sm_scale,
409
+ _return_lse,
410
+ )
411
+ elif pv_accum_dtype == "fp16+fp32":
412
+ v = v.to(torch.float16)
413
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(
414
+ q_int8,
415
+ k_int8,
416
+ v,
417
+ o,
418
+ q_scale,
419
+ k_scale,
420
+ _tensor_layout,
421
+ _is_caual,
422
+ _qk_quant_gran,
423
+ sm_scale,
424
+ _return_lse,
425
+ )
426
+ else:
427
+ raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")
428
+
429
+ o = o[..., :head_dim_og]
430
+
431
+ if return_lse:
432
+ return (
433
+ o,
434
+ lse / 1.44269504 + lse_correction * sm_scale
435
+ if smooth_k
436
+ else lse / 1.44269504,
437
+ )
438
+ else:
439
+ return o
440
+
441
+
442
+ @torch.compiler.disable
443
+ def sageattn_qk_int8_pv_fp8_cuda(
444
+ q: torch.Tensor,
445
+ k: torch.Tensor,
446
+ v: torch.Tensor,
447
+ tensor_layout: str = "HND",
448
+ is_causal: bool = False,
449
+ qk_quant_gran: str = "per_thread",
450
+ sm_scale: Optional[float] = None,
451
+ pv_accum_dtype: str = "fp32+fp16",
452
+ smooth_k: bool = True,
453
+ smooth_v: bool = False,
454
+ return_lse: bool = False,
455
+ **kwargs: Any,
456
+ ) -> torch.Tensor:
457
+ """
458
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
459
+
460
+ Parameters
461
+ ----------
462
+ q : torch.Tensor
463
+ The query tensor. Shape:
464
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
465
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
466
+
467
+ k : torch.Tensor
468
+ The key tensor. Shape:
469
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
470
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
471
+
472
+ v : torch.Tensor
473
+ The value tensor. Shape:
474
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
475
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
476
+
477
+ tensor_layout : str
478
+ The tensor layout, either "HND" or "NHD".
479
+ Default: "HND".
480
+
481
+ is_causal : bool
482
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
483
+ Default: False.
484
+
485
+ qk_quant_gran : str
486
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
487
+ Default: "per_thread".
488
+
489
+ sm_scale : Optional[float]
490
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
491
+
492
+ pv_accum_dtype : str
493
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
494
+ - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator.
495
+ - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
496
+ Default: "fp32+fp32".
497
+
498
+ smooth_k : bool
499
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
500
+ Default: True.
501
+
502
+ smooth_v : bool
503
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
504
+ smooth_v will be ignored if pv_accum_dtype is "fp32+fp32".
505
+ Default: False.
506
+
507
+ return_lse : bool
508
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
509
+ Default: False.
510
+
511
+ Returns
512
+ -------
513
+ torch.Tensor
514
+ The output tensor. Shape:
515
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
516
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
517
+
518
+ torch.Tensor
519
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
520
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
521
+ Only returned if `return_lse` is True.
522
+
523
+ Note
524
+ ----
525
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
526
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
527
+ - All tensors must be on the same cuda device.
528
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
529
+ """
530
+
531
+ dtype = q.dtype
532
+ assert q.is_cuda, "Input tensors must be on cuda."
533
+ assert dtype in [torch.float16, torch.bfloat16], (
534
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
535
+ )
536
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
537
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
538
+ )
539
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
540
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
541
+
542
+ # cuda_major_version, cuda_minor_version = get_cuda_version()
543
+ # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16':
544
+ # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'")
545
+ # pv_accum_dtype = 'fp32+fp32'
546
+
547
+ # FIXME(DefTruth): make sage attention work compatible with distributed
548
+ # env, for example, xDiT which launch by torchrun. Without this workaround,
549
+ # sage attention will run into illegal memory access error after first
550
+ # inference step in distributed env for multi gpus inference. This small
551
+ # workaround also make sage attention work compatible with torch.compile
552
+ # through non-fullgraph compile mode.
553
+ torch.cuda.set_device(v.device)
554
+
555
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
556
+ _is_caual = 1 if is_causal else 0
557
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
558
+ _return_lse = 1 if return_lse else 0
559
+
560
+ head_dim_og = q.size(-1)
561
+
562
+ if head_dim_og < 64:
563
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
564
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
565
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
566
+ elif head_dim_og > 64 and head_dim_og < 128:
567
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
568
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
569
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
570
+ elif head_dim_og > 128:
571
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
572
+
573
+ # assert last dim is contiguous
574
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
575
+ "Last dim of qkv must be contiguous."
576
+ )
577
+
578
+ if sm_scale is None:
579
+ sm_scale = head_dim_og**-0.5
580
+
581
+ seq_dim = 1 if _tensor_layout == 0 else 2
582
+ nh_dim = 2 if _tensor_layout == 0 else 1
583
+
584
+ if smooth_k:
585
+ km = k.mean(dim=seq_dim, keepdim=True)
586
+ nqheads = q.size(2)
587
+ nkheads = k.size(2)
588
+ q_per_kv_heads = nqheads // nkheads
589
+ if q_per_kv_heads > 1:
590
+ # nheads_k => nheads_q
591
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
592
+ else:
593
+ km_broadcast = km
594
+ if return_lse:
595
+ if tensor_layout == "NHD":
596
+ lse_correction = (
597
+ torch.matmul(
598
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
599
+ )
600
+ .squeeze(-1)
601
+ .to(torch.float32)
602
+ )
603
+ else:
604
+ lse_correction = (
605
+ torch.matmul(q, km_broadcast.transpose(2, 3))
606
+ .squeeze(-1)
607
+ .to(torch.float32)
608
+ )
609
+ else:
610
+ km = None
611
+
612
+ if qk_quant_gran == "per_warp":
613
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
614
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64
615
+ )
616
+ elif qk_quant_gran == "per_thread":
617
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
618
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64
619
+ )
620
+
621
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
622
+
623
+ if pv_accum_dtype == "fp32+fp32" and smooth_v:
624
+ warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
625
+ smooth_v = False
626
+
627
+ if pv_accum_dtype == "fp32+fp16" and smooth_v:
628
+ warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.")
629
+ smooth_v = False
630
+
631
+ quant_v_scale_max = 448.0
632
+ if pv_accum_dtype == "fp32+fp16":
633
+ quant_v_scale_max = 2.25
634
+
635
+ v_fp8, v_scale, vm = per_channel_fp8(
636
+ v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v
637
+ )
638
+ print("before kernel call")
639
+ if pv_accum_dtype == "fp32":
640
+ if smooth_v:
641
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(
642
+ q_int8,
643
+ k_int8,
644
+ v_fp8,
645
+ o,
646
+ q_scale,
647
+ k_scale,
648
+ v_scale,
649
+ vm,
650
+ _tensor_layout,
651
+ _is_caual,
652
+ _qk_quant_gran,
653
+ sm_scale,
654
+ _return_lse,
655
+ )
656
+ torch.cuda.synchronize()
657
+ else:
658
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
659
+ q_int8,
660
+ k_int8,
661
+ v_fp8,
662
+ o,
663
+ q_scale,
664
+ k_scale,
665
+ v_scale,
666
+ _tensor_layout,
667
+ _is_caual,
668
+ _qk_quant_gran,
669
+ sm_scale,
670
+ _return_lse,
671
+ )
672
+ torch.cuda.synchronize()
673
+ elif pv_accum_dtype == "fp32+fp32":
674
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
675
+ q_int8,
676
+ k_int8,
677
+ v_fp8,
678
+ o,
679
+ q_scale,
680
+ k_scale,
681
+ v_scale,
682
+ _tensor_layout,
683
+ _is_caual,
684
+ _qk_quant_gran,
685
+ sm_scale,
686
+ _return_lse,
687
+ )
688
+ torch.cuda.synchronize()
689
+ elif pv_accum_dtype == "fp32+fp16":
690
+ lse = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(
691
+ q_int8,
692
+ k_int8,
693
+ v_fp8,
694
+ o,
695
+ q_scale,
696
+ k_scale,
697
+ v_scale,
698
+ _tensor_layout,
699
+ _is_caual,
700
+ _qk_quant_gran,
701
+ sm_scale,
702
+ _return_lse,
703
+ )
704
+ torch.cuda.synchronize()
705
+ o = o[..., :head_dim_og]
706
+ print("after kernel call")
707
+ if return_lse:
708
+ return (
709
+ o,
710
+ lse / 1.44269504 + lse_correction * sm_scale
711
+ if smooth_k
712
+ else lse / 1.44269504,
713
+ )
714
+ else:
715
+ return o
716
+
717
+
718
+ @torch.compiler.disable
719
+ def sageattn_qk_int8_pv_fp8_cuda_sm90(
720
+ q: torch.Tensor,
721
+ k: torch.Tensor,
722
+ v: torch.Tensor,
723
+ tensor_layout: str = "HND",
724
+ is_causal: bool = False,
725
+ qk_quant_gran: str = "per_thread",
726
+ sm_scale: Optional[float] = None,
727
+ pv_accum_dtype: str = "fp32+fp32",
728
+ smooth_k: bool = True,
729
+ return_lse: bool = False,
730
+ **kwargs: Any,
731
+ ) -> torch.Tensor:
732
+ """
733
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
734
+
735
+ Parameters
736
+ ----------
737
+ q : torch.Tensor
738
+ The query tensor. Shape:
739
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
740
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
741
+
742
+ k : torch.Tensor
743
+ The key tensor. Shape:
744
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
745
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
746
+
747
+ v : torch.Tensor
748
+ The value tensor. Shape:
749
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
750
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
751
+
752
+ tensor_layout : str
753
+ The tensor layout, either "HND" or "NHD".
754
+ Default: "HND".
755
+
756
+ is_causal : bool
757
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
758
+ Default: False.
759
+
760
+ qk_quant_gran : str
761
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
762
+ Default: "per_thread".
763
+
764
+ sm_scale : Optional[float]
765
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
766
+
767
+ pv_accum_dtype : str
768
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
769
+ - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator.
770
+ - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
771
+ Default: "fp32+fp32".
772
+
773
+ smooth_k : bool
774
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
775
+ Default: True.
776
+
777
+ return_lse : bool
778
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
779
+ Default: False.
780
+
781
+ Returns
782
+ -------
783
+ torch.Tensor
784
+ The output tensor. Shape:
785
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
786
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
787
+
788
+ torch.Tensor
789
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
790
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
791
+ Only returned if `return_lse` is True.
792
+
793
+ Note
794
+ ----
795
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
796
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
797
+ - All tensors must be on the same cuda device.
798
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
799
+ """
800
+
801
+ dtype = q.dtype
802
+ assert q.is_cuda, "Input tensors must be on cuda."
803
+ assert dtype in [torch.float16, torch.bfloat16], (
804
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
805
+ )
806
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
807
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
808
+ )
809
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
810
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
811
+
812
+ torch.cuda.set_device(v.device)
813
+
814
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
815
+ _is_caual = 1 if is_causal else 0
816
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
817
+ _return_lse = 1 if return_lse else 0
818
+
819
+ head_dim_og = q.size(-1)
820
+
821
+ if head_dim_og < 64:
822
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
823
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
824
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
825
+ elif head_dim_og > 64 and head_dim_og < 128:
826
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
827
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
828
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
829
+ elif head_dim_og > 128:
830
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
831
+
832
+ # assert last dim is contiguous
833
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
834
+ "Last dim of qkv must be contiguous."
835
+ )
836
+
837
+ if sm_scale is None:
838
+ sm_scale = head_dim_og**-0.5
839
+
840
+ seq_dim = 1 if _tensor_layout == 0 else 2
841
+ nh_dim = 2 if _tensor_layout == 0 else 1
842
+
843
+ if smooth_k:
844
+ km = k.mean(dim=seq_dim, keepdim=True)
845
+ nqheads = q.size(2)
846
+ nkheads = k.size(2)
847
+ q_per_kv_heads = nqheads // nkheads
848
+ if q_per_kv_heads > 1:
849
+ # nheads_k => nheads_q
850
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
851
+ else:
852
+ km_broadcast = km
853
+ if return_lse:
854
+ if tensor_layout == "NHD":
855
+ lse_correction = (
856
+ torch.matmul(
857
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
858
+ )
859
+ .squeeze(-1)
860
+ .to(torch.float32)
861
+ )
862
+ else:
863
+ lse_correction = (
864
+ torch.matmul(q, km_broadcast.transpose(2, 3))
865
+ .squeeze(-1)
866
+ .to(torch.float32)
867
+ )
868
+ else:
869
+ km = None
870
+
871
+ if qk_quant_gran == "per_warp":
872
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
873
+ q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128
874
+ )
875
+ elif qk_quant_gran == "per_thread":
876
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
877
+ q,
878
+ k,
879
+ km,
880
+ tensor_layout=tensor_layout,
881
+ BLKQ=64,
882
+ WARPQ=16,
883
+ BLKK=128,
884
+ WARPK=128,
885
+ )
886
+
887
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
888
+
889
+ # pad v to multiple of 128
890
+ # TODO: modify per_channel_fp8 kernel to handle this
891
+ kv_len = k.size(seq_dim)
892
+ v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
893
+ if v_pad_len > 0:
894
+ if tensor_layout == "HND":
895
+ v = torch.cat(
896
+ [
897
+ v,
898
+ torch.zeros(
899
+ v.size(0),
900
+ v.size(1),
901
+ v_pad_len,
902
+ v.size(3),
903
+ dtype=v.dtype,
904
+ device=v.device,
905
+ ),
906
+ ],
907
+ dim=2,
908
+ )
909
+ else:
910
+ v = torch.cat(
911
+ [
912
+ v,
913
+ torch.zeros(
914
+ v.size(0),
915
+ v_pad_len,
916
+ v.size(2),
917
+ v.size(3),
918
+ dtype=v.dtype,
919
+ device=v.device,
920
+ ),
921
+ ],
922
+ dim=1,
923
+ )
924
+
925
+ v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False)
926
+
927
+ if pv_accum_dtype == "fp32":
928
+ raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.")
929
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
930
+ q_int8,
931
+ k_int8,
932
+ v_fp8,
933
+ o,
934
+ q_scale,
935
+ k_scale,
936
+ v_scale,
937
+ _tensor_layout,
938
+ _is_caual,
939
+ _qk_quant_gran,
940
+ sm_scale,
941
+ _return_lse,
942
+ )
943
+ elif pv_accum_dtype == "fp32+fp32":
944
+ print(
945
+ "qint8",
946
+ q_int8.shape,
947
+ "qscale",
948
+ q_scale.shape,
949
+ "kint8",
950
+ k_int8.shape,
951
+ "kscale",
952
+ k_scale.shape,
953
+ "vfp8",
954
+ v_fp8.shape,
955
+ "vscale",
956
+ v_scale.shape,
957
+ )
958
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
959
+ q_int8,
960
+ k_int8,
961
+ v_fp8,
962
+ o,
963
+ q_scale,
964
+ k_scale,
965
+ v_scale,
966
+ _tensor_layout,
967
+ _is_caual,
968
+ _qk_quant_gran,
969
+ sm_scale,
970
+ _return_lse,
971
+ )
972
+
973
+ o = o[..., :head_dim_og]
974
+
975
+ if return_lse:
976
+ return (
977
+ o,
978
+ lse / 1.44269504 + lse_correction * sm_scale
979
+ if smooth_k
980
+ else lse / 1.44269504,
981
+ )
982
+ else:
983
+ return o
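
Note (not part of the commit): the smooth_k bookkeeping in the functions above relies on softmax(QK^T) being invariant to subtracting the per-head key mean, with the logsumexp shifted by (q · km) * sm_scale; the kernels also return lse in base 2, hence the 1.44269504 (log2 e) divisor. A small pure-PyTorch sanity sketch of that identity, for illustration only:

import torch

q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 6, 8)
sm_scale = q.size(-1) ** -0.5

km = k.mean(dim=-2, keepdim=True)                       # per-head key mean, [1, 1, 1, 8]
logits = (q @ k.transpose(-1, -2)) * sm_scale
logits_smoothed = (q @ (k - km).transpose(-1, -2)) * sm_scale

# attention weights are unchanged by the mean subtraction
assert torch.allclose(logits.softmax(-1), logits_smoothed.softmax(-1), atol=1e-5)

# the logsumexp only shifts by (q . km) * sm_scale, which is the lse_correction term
lse = torch.logsumexp(logits, dim=-1)
lse_smoothed = torch.logsumexp(logits_smoothed, dim=-1)
correction = (q @ km.transpose(-1, -2)).squeeze(-1) * sm_scale
assert torch.allclose(lse, lse_smoothed + correction, atol=1e-5)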
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/layers.py ADDED
File without changes
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/quant.py ADDED
@@ -0,0 +1,326 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ from typing import Optional
19
+
20
+ from ._ops import ops
21
+
22
+
23
+ def per_block_int8(
24
+ q: torch.Tensor,
25
+ k: torch.Tensor,
26
+ km: Optional[torch.Tensor] = None,
27
+ BLKQ: int = 128,
28
+ BLKK: int = 64,
29
+ sm_scale: Optional[float] = None,
30
+ tensor_layout: str = "HND",
31
+ ):
32
+ """
33
+ Quantize the query tensor `q` and the key tensor `k` with per block quantization.
34
+
35
+ Parameters
36
+ ----------
37
+ q : torch.Tensor
38
+ The query tensor. Shape:
39
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
40
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
41
+
42
+ k : torch.Tensor
43
+ The key tensor. Shape:
44
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
45
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
46
+
47
+ km : Optional[torch.Tensor]
48
+ The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
49
+ Should be of the same dtype as `k` if provided. Default is None.
50
+
51
+ sm_scale : Optional[float]
52
+ The scale factor for the softmax operation. Default is ``head_dim**-0.5``.
53
+ It will be multiplied by ``1.44269504`` to work together with the triton attention kernel.
54
+
55
+ tensor_layout : str
56
+ The tensor layout, either "HND" or "NHD".
57
+ Default: "HND".
58
+
59
+ Returns
60
+ -------
61
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
62
+ A tuple containing:
63
+ - The quantized query tensor. Shape: Same as `q` but with `int8` dtype.
64
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype.
65
+ - The quantized key tensor. Shape: Same as `k` but with `int8` dtype.
66
+ - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
67
+
68
+ Note
69
+ ----
70
+ - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``
71
+ """
72
+
73
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
74
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
75
+
76
+ if tensor_layout == "HND":
77
+ b, h_qo, qo_len, head_dim = q.shape
78
+ _, h_kv, kv_len, _ = k.shape
79
+
80
+ elif tensor_layout == "NHD":
81
+ b, qo_len, h_qo, head_dim = q.shape
82
+ _, kv_len, h_kv, _ = k.shape
83
+
84
+ else:
85
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
86
+
87
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
88
+
89
+ q_scale = torch.empty(
90
+ (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32
91
+ )
92
+ k_scale = torch.empty(
93
+ (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
94
+ )
95
+
96
+ if sm_scale is None:
97
+ sm_scale = head_dim**-0.5
98
+
99
+ sm_scale *= 1.44269504
100
+
101
+ ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout)
102
+ if km is not None:
103
+ km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
104
+ ops.quant_per_block_int8_fuse_sub_mean_cuda(
105
+ k, km, k_int8, k_scale, BLKK, _tensor_layout
106
+ )
107
+ else:
108
+ # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
109
+ ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
110
+
111
+ return q_int8, q_scale, k_int8, k_scale
112
+
113
+
114
+ def per_warp_int8(
115
+ q: torch.Tensor,
116
+ k: torch.Tensor,
117
+ km: Optional[torch.Tensor] = None,
118
+ BLKQ: int = 128,
119
+ WARPQ: int = 32,
120
+ BLKK: int = 64,
121
+ tensor_layout: str = "HND",
122
+ ):
123
+ """
124
+ Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization.
125
+ Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128.
126
+ Block size of quantizing `k` is 64 or 128.
127
+
128
+ Parameters
129
+ ----------
130
+ q : torch.Tensor
131
+ The query tensor. Shape:
132
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
133
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
134
+
135
+ k : torch.Tensor
136
+ The key tensor. Shape:
137
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
138
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
139
+
140
+ km : Optional[torch.Tensor]
141
+ The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
142
+ Should be of the same dtype as `k` if provided. Default is None.
143
+
144
+ tensor_layout : str
145
+ The tensor layout, either "HND" or "NHD".
146
+ Default: "HND".
147
+
148
+ Returns
149
+ -------
150
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
151
+ A tuple containing:
152
+ - The quantized query tensor. Shape: Same as `q` but with `int8` dtype.
153
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype.
154
+ - The quantized key tensor. Shape: Same as `k` but with `int8` dtype.
155
+ - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
156
+
157
+ Note
158
+ ----
159
+ - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``
160
+ """
161
+
162
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
163
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
164
+
165
+ if tensor_layout == "HND":
166
+ b, h_qo, qo_len, head_dim = q.shape
167
+ _, h_kv, kv_len, _ = k.shape
168
+
169
+ elif tensor_layout == "NHD":
170
+ b, qo_len, h_qo, head_dim = q.shape
171
+ _, kv_len, h_kv, _ = k.shape
172
+
173
+ else:
174
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
175
+
176
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
177
+
178
+ q_scale = torch.empty(
179
+ (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)),
180
+ device=q.device,
181
+ dtype=torch.float32,
182
+ )
183
+ k_scale = torch.empty(
184
+ (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
185
+ )
186
+
187
+ ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout)
188
+
189
+ if km is not None:
190
+ km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
191
+ ops.quant_per_block_int8_fuse_sub_mean_cuda(
192
+ k, km, k_int8, k_scale, BLKK, _tensor_layout
193
+ )
194
+ else:
195
+ # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
196
+ ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
197
+
198
+ return q_int8, q_scale, k_int8, k_scale
199
+
200
+
201
+ def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"):
202
+ """
203
+ Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16.
204
+
205
+ Parameters
206
+ ----------
207
+ v : torch.Tensor
208
+ The input tensor. Shape:
209
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
210
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
211
+
212
+ tensor_layout : str
213
+ The tensor layout, either "HND" or "NHD".
214
+ Default: "HND".
215
+
216
+ Returns
217
+ -------
218
+ Tuple[torch.Tensor, torch.Tensor]
219
+ A tuple containing:
220
+ - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype.
221
+ - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`.
222
+
223
+ Note
224
+ ----
225
+ - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
226
+ - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype.
227
+ - The returned mean tensor will have the same dtype as the input tensor.
228
+ """
229
+
230
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
231
+ vm = v.mean(dim=1 if _tensor_layout == 0 else 2)
232
+
233
+ v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device)
234
+
235
+ # subtract mean and store the result as fp16
236
+ ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout)
237
+
238
+ return v_smoothed, vm
239
+
240
+
241
+ def per_channel_fp8(
242
+ v: torch.Tensor,
243
+ tensor_layout: str = "HND",
244
+ scale_max: float = 448.0,
245
+ smooth_v: bool = True,
246
+ ):
247
+ """
248
+ Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization.
249
+ `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64.
250
+ After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``.
251
+ The quantization is done per channel, with the scale value and smooth factor calculated per channel.
252
+
253
+ Parameters
254
+ ----------
255
+ v : torch.Tensor
256
+ The input tensor. Shape:
257
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
258
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
259
+
260
+ tensor_layout : str
261
+ The tensor layout, either "HND" or "NHD".
262
+ Default: "HND".
263
+
264
+ scale_max : float
265
+ The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format).
266
+
267
+ smooth_v : bool
268
+ Whether to smooth the quantized tensor. Default is True.
269
+
270
+ Returns
271
+ -------
272
+ Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
273
+ A tuple containing:
274
+ - The quantized tensor `v_fp8`. Shape:
275
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
276
+ - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
277
+ - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
278
+ - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
279
+
280
+ Note
281
+ ----
282
+ - The tensors `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
283
+ - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``.
284
+ """
285
+
286
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
287
+
288
+ if tensor_layout == "HND":
289
+ b, h_kv, kv_len, head_dim = v.shape
290
+ padded_len = (kv_len + 63) // 64 * 64
291
+ v_transposed_permutted = torch.empty(
292
+ (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device
293
+ )
294
+
295
+ elif tensor_layout == "NHD":
296
+ b, kv_len, h_kv, head_dim = v.shape
297
+ padded_len = (kv_len + 63) // 64 * 64
298
+ v_transposed_permutted = torch.empty(
299
+ (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device
300
+ )
301
+
302
+ ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout)
303
+
304
+ v_fp8 = torch.empty(
305
+ v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device
306
+ )
307
+
308
+ v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
309
+ vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
310
+
311
+ if smooth_v:
312
+ ops.mean_scale_fuse_quant_cuda(
313
+ v_transposed_permutted,
314
+ v_fp8,
315
+ vm,
316
+ v_scale,
317
+ kv_len,
318
+ scale_max,
319
+ _tensor_layout,
320
+ )
321
+ return v_fp8, v_scale, vm
322
+ else:
323
+ ops.scale_fuse_quant_cuda(
324
+ v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout
325
+ )
326
+ return v_fp8, v_scale, None
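A minimal usage sketch for `per_channel_fp8` above, with hypothetical sizes; it needs a CUDA device and the bundled `_sage_attention_*` extension:

```python
import torch
from sage_attention.quant import per_channel_fp8  # assumes this build is importable as `sage_attention`

b, h_kv, kv_len, head_dim = 1, 8, 1000, 128       # hypothetical sizes, "HND" layout
v = torch.randn(b, h_kv, kv_len, head_dim, dtype=torch.float16, device="cuda")

v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout="HND", smooth_v=True)
print(v_fp8.shape)    # [1, 8, 128, 1024]: kv_len padded to a multiple of 64, seq and head_dim transposed
print(v_scale.shape)  # [1, 8, 128]: one fp32 scale per (batch, kv head, channel)
print(vm.shape)       # [1, 8, 128]: per-channel mean, returned only when smooth_v=True
```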
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/quant_per_thread.py ADDED
@@ -0,0 +1,204 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import triton
19
+ import triton.language as tl
20
+
21
+ @triton.jit
22
+ def quant_query_per_thread_int8_kernel(Input, Output, Scale, L,
23
+ stride_iz, stride_ih, stride_in,
24
+ stride_oz, stride_oh, stride_on,
25
+ stride_sz, stride_sh,
26
+ C: tl.constexpr, BLK: tl.constexpr):
27
+ off_blk = tl.program_id(0) // 8
28
+ off_tld = tl.program_id(0) % 8
29
+ off_h = tl.program_id(1)
30
+ off_b = tl.program_id(2)
31
+
32
+ offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
33
+ offs_k = tl.arange(0, C)
34
+
35
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
36
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
37
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
38
+
39
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
40
+ x = x.to(tl.float32)
41
+ scale = tl.max(tl.abs(x)) / 127. + 0.0000001
42
+ x_int8 = x / scale
43
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
44
+ x_int8 = x_int8.to(tl.int8)
45
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
46
+ tl.store(scale_ptrs, scale)
47
+
48
+ @triton.jit
49
+ def quant_key_per_thread_int8_kernel(Input, Output, Scale, L,
50
+ stride_iz, stride_ih, stride_in,
51
+ stride_oz, stride_oh, stride_on,
52
+ stride_sz, stride_sh,
53
+ C: tl.constexpr, BLK: tl.constexpr):
54
+ off_blk = tl.program_id(0) // 4
55
+ off_tld = tl.program_id(0) % 4
56
+ off_h = tl.program_id(1)
57
+ off_b = tl.program_id(2)
58
+
59
+ # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
60
+ # offs_k = tl.arange(0, C)
61
+
62
+ # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
63
+ # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
64
+ # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
65
+
66
+ # x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
67
+ # x = x.to(tl.float32)
68
+ # scale = tl.max(tl.abs(x)) / 127. + 0.0000001
69
+ # x_int8 = x / scale
70
+ # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
71
+ # x_int8 = x_int8.to(tl.int8)
72
+ # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
73
+ # tl.store(scale_ptrs, scale)
74
+
75
+ offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2
76
+ offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1
77
+ offs_k = tl.arange(0, C)
78
+
79
+ input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :]
80
+ input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :]
81
+ output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :]
82
+ output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :]
83
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
84
+
85
+ x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L)
86
+ x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L)
87
+ x0 = x0.to(tl.float32)
88
+ x1 = x1.to(tl.float32)
89
+ scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. + 0.0000001
90
+ x0_int8 = x0 / scale
91
+ x1_int8 = x1 / scale
92
+ x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1)
93
+ x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1)
94
+ x0_int8 = x0_int8.to(tl.int8)
95
+ x1_int8 = x1_int8.to(tl.int8)
96
+ tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L)
97
+ tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L)
98
+ tl.store(scale_ptrs, scale)
99
+
100
+ @triton.jit
101
+ def quant_query_per_thread_int4_kernel(Input, Output, Scale, L,
102
+ stride_iz, stride_ih, stride_in,
103
+ stride_oz, stride_oh, stride_on,
104
+ stride_sz, stride_sh,
105
+ C: tl.constexpr, BLK: tl.constexpr):
106
+ off_blk = tl.program_id(0) // 8
107
+ off_tld = tl.program_id(0) % 8
108
+ off_h = tl.program_id(1)
109
+ off_b = tl.program_id(2)
110
+
111
+ offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
112
+ offs_k = tl.arange(0, C)
113
+
114
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
115
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
116
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
117
+
118
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
119
+ x = x.to(tl.float32)
120
+ scale = tl.max(tl.abs(x)) / 7. + 0.0000001
121
+ x_int8 = x / scale
122
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
123
+ x_int8 = x_int8.to(tl.int8)
124
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
125
+ tl.store(scale_ptrs, scale)
126
+
127
+ @triton.jit
128
+ def quant_key_per_thread_int4_kernel(Input, Output, Scale, L,
129
+ stride_iz, stride_ih, stride_in,
130
+ stride_oz, stride_oh, stride_on,
131
+ stride_sz, stride_sh,
132
+ C: tl.constexpr, BLK: tl.constexpr):
133
+ off_blk = tl.program_id(0) // 4
134
+ off_tld = tl.program_id(0) % 4
135
+ off_h = tl.program_id(1)
136
+ off_b = tl.program_id(2)
137
+
138
+ offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
139
+ offs_k = tl.arange(0, C)
140
+
141
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
142
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
143
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
144
+
145
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
146
+ x = x.to(tl.float32)
147
+ scale = tl.max(tl.abs(x)) / 7. + 0.0000001
148
+ x_int8 = x / scale
149
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
150
+ x_int8 = x_int8.to(tl.int8)
151
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
152
+ tl.store(scale_ptrs, scale)
153
+
154
+ def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"):
155
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
156
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
157
+
158
+ if km is not None:
159
+ k = k - km
160
+
161
+ if tensor_layout == "HND":
162
+ b, h_qo, qo_len, head_dim = q.shape
163
+ _, h_kv, kv_len, _ = k.shape
164
+
165
+ stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
166
+ stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
167
+ stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
168
+ stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
169
+ elif tensor_layout == "NHD":
170
+ b, qo_len, h_qo, head_dim = q.shape
171
+ _, kv_len, h_kv, _ = k.shape
172
+
173
+ stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
174
+ stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
175
+ stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
176
+ stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
177
+ else:
178
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
179
+
180
+ q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32)
181
+ k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32)
182
+
183
+ if sm_scale is None:
184
+ sm_scale = head_dim**-0.5
185
+
186
+ grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b)
187
+ quant_query_per_thread_int8_kernel[grid](
188
+ q, q_int8, q_scale, qo_len,
189
+ stride_bz_q, stride_h_q, stride_seq_q,
190
+ stride_bz_qo, stride_h_qo, stride_seq_qo,
191
+ q_scale.stride(0), q_scale.stride(1),
192
+ C=head_dim, BLK=WARPQ
193
+ )
194
+
195
+ grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b)
196
+ quant_key_per_thread_int8_kernel[grid](
197
+ k, k_int8, k_scale, kv_len,
198
+ stride_bz_k, stride_h_k, stride_seq_k,
199
+ stride_bz_ko, stride_h_ko, stride_seq_ko,
200
+ k_scale.stride(0), k_scale.stride(1),
201
+ C=head_dim, BLK=WARPK
202
+ )
203
+
204
+ return q_int8, q_scale, k_int8, k_scale
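All four kernels above share the same quantization recipe: divide by the group absmax (scaled to the integer range), round half away from zero, and cast to int8. A per-tensor PyTorch restatement of that recipe, for illustration only:

```python
import torch

def quantize_int8_reference(x: torch.Tensor, qmax: float = 127.0):
    # Per-tensor variant of the per-thread scheme used in the Triton kernels above.
    x = x.float()
    scale = x.abs().max() / qmax + 1e-7                       # epsilon avoids division by zero
    x_scaled = x / scale
    # Round half away from zero, mirroring `x += 0.5 * tl.where(x >= 0, 1, -1)` followed
    # by the truncating int8 cast.
    half = 0.5 * torch.where(x_scaled >= 0, torch.ones_like(x_scaled), -torch.ones_like(x_scaled))
    x_int8 = torch.trunc(x_scaled + half).to(torch.int8)
    return x_int8, scale
```

In the kernels themselves the absmax is taken per thread group (8 groups per query warp tile, 4 per key warp tile), which is why `per_thread_int8` allocates `q_scale` with `* 8` and `k_scale` with `* 4` entries per block.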
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
+ from .core import sageattn, sageattn_qk_int8_pv_fp8_cuda
3
+
4
+
5
+ __all__ = [
6
+ "per_block_int8",
7
+ "per_warp_int8",
8
+ "sub_mean",
9
+ "per_channel_fp8",
10
+ "sageattn",
11
+ "sageattn_qk_int8_pv_fp8_cuda",
12
+ ]
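A minimal end-to-end sketch of the public API re-exported here (hypothetical shapes; assumes the built package is importable as `sage_attention` and a supported CUDA GPU is present):

```python
import torch
import sage_attention

b, h, seq, d = 2, 16, 4096, 128                  # hypothetical sizes, "HND" layout
q = torch.randn(b, h, seq, d, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# sageattn dispatches to the kernel matching the detected compute capability (sm80/sm89/sm90/sm120).
o = sage_attention.sageattn(q, k, v, tensor_layout="HND", is_causal=True)
print(o.shape)  # [2, 16, 4096, 128]
```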
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (433 Bytes)

build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (550 Bytes)

build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc ADDED
Binary file (33.4 kB)

build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc ADDED
Binary file (13.4 kB)

build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc ADDED
Binary file (13 kB)
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _sage_attention_44b112f_dirty
3
+ ops = torch.ops._sage_attention_44b112f_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_sage_attention_44b112f_dirty::{op_name}"
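For reference, `add_op_namespace_prefix` simply builds the fully qualified op name for this specific build, which is the format expected by `torch.library`-style utilities:

```python
from sage_attention._ops import add_op_namespace_prefix

print(add_op_namespace_prefix("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn"))
# "_sage_attention_44b112f_dirty::qk_int8_sv_f8_accum_f32_fuse_v_scale_attn"
```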
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/_sage_attention_44b112f_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d47c952dd9781283ff0dcbd533779de33b0bfa1966dcc0cc8accd0412217c1c5
3
+ size 26553840
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/core.py ADDED
@@ -0,0 +1,983 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+
20
+ from ._ops import ops
21
+
22
+
23
+ from .quant import per_warp_int8 as per_warp_int8_cuda
24
+ from .quant import sub_mean
25
+ from .quant import per_channel_fp8
26
+ from .quant_per_thread import per_thread_int8 as per_thread_int8_triton
27
+
28
+ from typing import Any, List, Literal, Optional, Tuple, Union
29
+ import warnings
30
+
31
+
32
+ import subprocess
33
+ import re
34
+
35
+
36
+ def get_cuda_version():
37
+ try:
38
+ output = subprocess.check_output(["nvcc", "--version"]).decode()
39
+ match = re.search(r"release (\d+)\.(\d+)", output)
40
+ if match:
41
+ major, minor = int(match.group(1)), int(match.group(2))
42
+ return major, minor
43
+ except Exception as e:
44
+ print("Failed to get CUDA version:", e)
45
+ return None, None
46
+
47
+
48
+ def get_cuda_arch_versions():
49
+ cuda_archs = []
50
+ for i in range(torch.cuda.device_count()):
51
+ major, minor = torch.cuda.get_device_capability(i)
52
+ cuda_archs.append(f"sm{major}{minor}")
53
+ return cuda_archs
54
+
55
+
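A quick way to check which branch of the dispatcher below will be taken on the current machine (sketch; assumes at least one CUDA device):

```python
import torch
from sage_attention.core import get_cuda_arch_versions

archs = get_cuda_arch_versions()            # e.g. ["sm89"] on an RTX 4090
print(archs[torch.cuda.current_device()])   # "sm80"/"sm89"/"sm90"/"sm120" map to the branches below
```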
56
+ def sageattn(
57
+ q: torch.Tensor,
58
+ k: torch.Tensor,
59
+ v: torch.Tensor,
60
+ tensor_layout: str = "HND",
61
+ is_causal: bool = False,
62
+ sm_scale: Optional[float] = None,
63
+ return_lse: bool = False,
64
+ **kwargs: Any,
65
+ ):
66
+ """
67
+ Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability.
68
+
69
+ Parameters
70
+ ----------
71
+ q : torch.Tensor
72
+ The query tensor. Shape:
73
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
74
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
75
+
76
+ k : torch.Tensor
77
+ The key tensor. Shape:
78
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
79
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
80
+
81
+ v : torch.Tensor
82
+ The value tensor. Shape:
83
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
84
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
85
+
86
+ tensor_layout : str
87
+ The tensor layout, either "HND" or "NHD".
88
+ Default: "HND".
89
+
90
+ is_causal : bool
91
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
92
+ Default: False.
93
+
94
+ sm_scale : Optional[float]
95
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
96
+
97
+ return_lse : bool
98
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
99
+ Default: False.
100
+
101
+ Returns
102
+ -------
103
+ torch.Tensor
104
+ The output tensor. Shape:
105
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
106
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
107
+
108
+ torch.Tensor
109
+ The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
110
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
111
+ Only returned if `return_lse` is True.
112
+
113
+ Note
114
+ ----
115
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
116
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
117
+ - All tensors must be on the same cuda device.
118
+ """
119
+
120
+ arch = get_cuda_arch_versions()[q.device.index]
121
+ if arch == "sm80":
122
+ return sageattn_qk_int8_pv_fp16_cuda(
123
+ q,
124
+ k,
125
+ v,
126
+ tensor_layout=tensor_layout,
127
+ is_causal=is_causal,
128
+ sm_scale=sm_scale,
129
+ return_lse=return_lse,
130
+ pv_accum_dtype="fp32",
131
+ )
132
+ elif arch == "sm89":
133
+ return sageattn_qk_int8_pv_fp8_cuda(
134
+ q,
135
+ k,
136
+ v,
137
+ tensor_layout=tensor_layout,
138
+ is_causal=is_causal,
139
+ sm_scale=sm_scale,
140
+ return_lse=return_lse,
141
+ pv_accum_dtype="fp32+fp16",
142
+ )
143
+ elif arch == "sm90":
144
+ return sageattn_qk_int8_pv_fp8_cuda_sm90(
145
+ q,
146
+ k,
147
+ v,
148
+ tensor_layout=tensor_layout,
149
+ is_causal=is_causal,
150
+ sm_scale=sm_scale,
151
+ return_lse=return_lse,
152
+ pv_accum_dtype="fp32+fp32",
153
+ )
154
+ elif arch == "sm120":
155
+ return sageattn_qk_int8_pv_fp8_cuda(
156
+ q,
157
+ k,
158
+ v,
159
+ tensor_layout=tensor_layout,
160
+ is_causal=is_causal,
161
+ qk_quant_gran="per_warp",
162
+ sm_scale=sm_scale,
163
+ return_lse=return_lse,
164
+ pv_accum_dtype="fp32+fp16",
165
+ ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
166
+ else:
167
+ raise ValueError(f"Unsupported CUDA architecture: {arch}")
168
+
169
+
170
+ @torch.compiler.disable
171
+ def sageattn_qk_int8_pv_fp16_cuda(
172
+ q: torch.Tensor,
173
+ k: torch.Tensor,
174
+ v: torch.Tensor,
175
+ tensor_layout: str = "HND",
176
+ is_causal: bool = False,
177
+ qk_quant_gran: str = "per_thread",
178
+ sm_scale: Optional[float] = None,
179
+ pv_accum_dtype: str = "fp32",
180
+ smooth_k: bool = True,
181
+ smooth_v: bool = False,
182
+ return_lse: bool = False,
183
+ **kwargs: Any,
184
+ ) -> torch.Tensor:
185
+ """
186
+ SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA.
187
+
188
+ Parameters
189
+ ----------
190
+ q : torch.Tensor
191
+ The query tensor. Shape:
192
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
193
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
194
+
195
+ k : torch.Tensor
196
+ The key tensor. Shape:
197
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
198
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
199
+
200
+ v : torch.Tensor
201
+ The value tensor. Shape:
202
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
203
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
204
+
205
+ tensor_layout : str
206
+ The tensor layout, either "HND" or "NHD".
207
+ Default: "HND".
208
+
209
+ is_causal : bool
210
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
211
+ Default: False.
212
+
213
+ qk_quant_gran : str
214
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
215
+ Default: "per_thread".
216
+
217
+ sm_scale : Optional[float]
218
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
219
+
220
+ pv_accum_dtype : str
221
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32".
222
+ - "fp16": PV accumulation is done fully in FP16. This is the fastest option but may lead to numerical instability. The `smooth_v` option increases accuracy in cases where the value tensor has a large bias (as in CogVideoX-2b).
223
+ - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead.
224
+ - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
225
+ Default: "fp32".
226
+
227
+ smooth_k : bool
228
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
229
+ Default: True.
230
+
231
+ smooth_v : bool
232
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
233
+ smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32".
234
+ Default: False.
235
+
236
+ return_lse : bool
237
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
238
+ Default: False.
239
+
240
+ Returns
241
+ -------
242
+ torch.Tensor
243
+ The output tensor. Shape:
244
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
245
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
246
+
247
+ torch.Tensor
248
+ The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
249
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
250
+ Only returned if `return_lse` is True.
251
+
252
+ Note
253
+ ----
254
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
255
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
256
+ - All tensors must be on the same cuda device.
257
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
258
+ """
259
+
260
+ dtype = q.dtype
261
+ assert q.is_cuda, "Input tensors must be on cuda."
262
+ assert dtype in [torch.float16, torch.bfloat16], (
263
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
264
+ )
265
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
266
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
267
+ )
268
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
269
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
270
+
271
+ # FIXME(DefTruth): make sage attention work compatible with distributed
272
+ # env, for example, xDiT which launch by torchrun. Without this workaround,
273
+ # sage attention will run into illegal memory access error after first
274
+ # inference step in distributed env for multi gpus inference. This small
275
+ # workaround also make sage attention work compatible with torch.compile
276
+ # through non-fullgraph compile mode.
277
+ torch.cuda.set_device(v.device)
278
+
279
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
280
+ _is_caual = 1 if is_causal else 0
281
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
282
+ _return_lse = 1 if return_lse else 0
283
+
284
+ head_dim_og = q.size(-1)
285
+
286
+ if head_dim_og < 64:
287
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
288
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
289
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
290
+ elif head_dim_og > 64 and head_dim_og < 128:
291
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
292
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
293
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
294
+ elif head_dim_og > 128:
295
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
296
+
297
+ # assert last dim is contiguous
298
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
299
+ "Last dim of qkv must be contiguous."
300
+ )
301
+
302
+ if sm_scale is None:
303
+ sm_scale = head_dim_og**-0.5
304
+
305
+ seq_dim = 1 if _tensor_layout == 0 else 2
306
+ nh_dim = 2 if _tensor_layout == 0 else 1
307
+
308
+ if smooth_k:
309
+ km = k.mean(dim=seq_dim, keepdim=True)
310
+ nqheads = q.size(2)
311
+ nkheads = k.size(2)
312
+ q_per_kv_heads = nqheads // nkheads
313
+ if q_per_kv_heads > 1:
314
+ # nheads_k => nheads_q
315
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
316
+ else:
317
+ km_broadcast = km
318
+ if return_lse:
319
+ if tensor_layout == "NHD":
320
+ lse_correction = (
321
+ torch.matmul(
322
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
323
+ )
324
+ .squeeze(-1)
325
+ .to(torch.float32)
326
+ )
327
+ else:
328
+ lse_correction = (
329
+ torch.matmul(q, km_broadcast.transpose(2, 3))
330
+ .squeeze(-1)
331
+ .to(torch.float32)
332
+ )
333
+ else:
334
+ km = None
335
+
336
+ if qk_quant_gran == "per_warp":
337
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
338
+ q,
339
+ k,
340
+ km,
341
+ tensor_layout=tensor_layout,
342
+ BLKQ=128,
343
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
344
+ BLKK=64,
345
+ )
346
+ elif qk_quant_gran == "per_thread":
347
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
348
+ q,
349
+ k,
350
+ km,
351
+ tensor_layout=tensor_layout,
352
+ BLKQ=128,
353
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
354
+ BLKK=64,
355
+ WARPK=64,
356
+ )
357
+
358
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
359
+
360
+ if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v:
361
+ warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.")
362
+ smooth_v = False
363
+
364
+ if pv_accum_dtype == "fp32":
365
+ v = v.to(torch.float16)
366
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(
367
+ q_int8,
368
+ k_int8,
369
+ v,
370
+ o,
371
+ q_scale,
372
+ k_scale,
373
+ _tensor_layout,
374
+ _is_caual,
375
+ _qk_quant_gran,
376
+ sm_scale,
377
+ _return_lse,
378
+ )
379
+ elif pv_accum_dtype == "fp16":
380
+ if smooth_v:
381
+ smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
382
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
383
+ q_int8,
384
+ k_int8,
385
+ smoothed_v,
386
+ o,
387
+ q_scale,
388
+ k_scale,
389
+ vm,
390
+ _tensor_layout,
391
+ _is_caual,
392
+ _qk_quant_gran,
393
+ sm_scale,
394
+ _return_lse,
395
+ )
396
+ else:
397
+ v = v.to(torch.float16)
398
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(
399
+ q_int8,
400
+ k_int8,
401
+ v,
402
+ o,
403
+ q_scale,
404
+ k_scale,
405
+ _tensor_layout,
406
+ _is_caual,
407
+ _qk_quant_gran,
408
+ sm_scale,
409
+ _return_lse,
410
+ )
411
+ elif pv_accum_dtype == "fp16+fp32":
412
+ v = v.to(torch.float16)
413
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(
414
+ q_int8,
415
+ k_int8,
416
+ v,
417
+ o,
418
+ q_scale,
419
+ k_scale,
420
+ _tensor_layout,
421
+ _is_caual,
422
+ _qk_quant_gran,
423
+ sm_scale,
424
+ _return_lse,
425
+ )
426
+ else:
427
+ raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")
428
+
429
+ o = o[..., :head_dim_og]
430
+
431
+ if return_lse:
432
+ return (
433
+ o,
434
+ lse / 1.44269504 + lse_correction * sm_scale
435
+ if smooth_k
436
+ else lse / 1.44269504,
437
+ )
438
+ else:
439
+ return o
440
+
441
+
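A note on the `return_lse` branch above: the kernel accumulates the logsumexp in base 2, and `smooth_k` shifts every key by `km`, so the returned expression restores the natural-log LSE of the original scores:

```python
# lse_kernel = log2( sum_j exp2( q · (k_j - km) * sm_scale * log2(e) ) )
# Dividing by log2(e) = 1.44269504 converts it back to natural log:
#     lse_smoothed = lse_kernel / 1.44269504
# Subtracting km shifts every score in a row by the constant (q · km) * sm_scale, so
#     lse_original = lse_smoothed + (q @ km.transpose(-2, -1)).squeeze(-1) * sm_scale
# which is exactly `lse / 1.44269504 + lse_correction * sm_scale` in the return above.
```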
442
+ @torch.compiler.disable
443
+ def sageattn_qk_int8_pv_fp8_cuda(
444
+ q: torch.Tensor,
445
+ k: torch.Tensor,
446
+ v: torch.Tensor,
447
+ tensor_layout: str = "HND",
448
+ is_causal: bool = False,
449
+ qk_quant_gran: str = "per_thread",
450
+ sm_scale: Optional[float] = None,
451
+ pv_accum_dtype: str = "fp32+fp16",
452
+ smooth_k: bool = True,
453
+ smooth_v: bool = False,
454
+ return_lse: bool = False,
455
+ **kwargs: Any,
456
+ ) -> torch.Tensor:
457
+ """
458
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
459
+
460
+ Parameters
461
+ ----------
462
+ q : torch.Tensor
463
+ The query tensor. Shape:
464
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
465
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
466
+
467
+ k : torch.Tensor
468
+ The key tensor. Shape:
469
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
470
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
471
+
472
+ v : torch.Tensor
473
+ The value tensor. Shape:
474
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
475
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
476
+
477
+ tensor_layout : str
478
+ The tensor layout, either "HND" or "NHD".
479
+ Default: "HND".
480
+
481
+ is_causal : bool
482
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
483
+ Default: False.
484
+
485
+ qk_quant_gran : str
486
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
487
+ Default: "per_thread".
488
+
489
+ sm_scale : Optional[float]
490
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
491
+
492
+ pv_accum_dtype : str
493
+ The dtype of the accumulation of the product of the value tensor and the attention weights, one of "fp32", "fp32+fp32" or "fp32+fp16".
494
+ - "fp32": PV accumulation is done fully in FP32. However, due to a hardware limitation, only 22 bits of the FP32 accumulator are valid.
495
+ - "fp32+fp32": PV accumulation is done in FP32 (effectively FP22), but added to an FP32 buffer every few iterations. This offers a balance between speed and accuracy.
+ - "fp32+fp16": PV accumulation uses the `qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf` kernel with a narrowed value quantization range (``scale_max = 2.25``).
496
+ Default: "fp32+fp16".
497
+
498
+ smooth_k : bool
499
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
500
+ Default: True.
501
+
502
+ smooth_v : bool
503
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
504
+ smooth_v will be ignored if pv_accum_dtype is "fp32+fp32" or "fp32+fp16".
505
+ Default: False.
506
+
507
+ return_lse : bool
508
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
509
+ Default: False.
510
+
511
+ Returns
512
+ -------
513
+ torch.Tensor
514
+ The output tensor. Shape:
515
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
516
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
517
+
518
+ torch.Tensor
519
+ The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
520
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
521
+ Only returned if `return_lse` is True.
522
+
523
+ Note
524
+ ----
525
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
526
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
527
+ - All tensors must be on the same cuda device.
528
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
529
+ """
530
+
531
+ dtype = q.dtype
532
+ assert q.is_cuda, "Input tensors must be on cuda."
533
+ assert dtype in [torch.float16, torch.bfloat16], (
534
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
535
+ )
536
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
537
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
538
+ )
539
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
540
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
541
+
542
+ # cuda_major_version, cuda_minor_version = get_cuda_version()
543
+ # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16':
544
+ # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'")
545
+ # pv_accum_dtype = 'fp32+fp32'
546
+
547
+ # FIXME(DefTruth): make SageAttention work in distributed environments,
548
+ # e.g. xDiT launched via torchrun. Without this workaround, SageAttention
549
+ # runs into an illegal memory access error after the first inference step
550
+ # in multi-GPU distributed inference. This small workaround also makes
551
+ # SageAttention compatible with torch.compile in non-fullgraph compile
552
+ # mode.
553
+ torch.cuda.set_device(v.device)
554
+
555
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
556
+ _is_caual = 1 if is_causal else 0
557
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
558
+ _return_lse = 1 if return_lse else 0
559
+
560
+ head_dim_og = q.size(-1)
561
+
562
+ if head_dim_og < 64:
563
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
564
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
565
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
566
+ elif head_dim_og > 64 and head_dim_og < 128:
567
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
568
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
569
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
570
+ elif head_dim_og > 128:
571
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
572
+
573
+ # assert last dim is contiguous
574
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
575
+ "Last dim of qkv must be contiguous."
576
+ )
577
+
578
+ if sm_scale is None:
579
+ sm_scale = head_dim_og**-0.5
580
+
581
+ seq_dim = 1 if _tensor_layout == 0 else 2
582
+ nh_dim = 2 if _tensor_layout == 0 else 1
583
+
584
+ if smooth_k:
585
+ km = k.mean(dim=seq_dim, keepdim=True)
586
+ nqheads = q.size(2)
587
+ nkheads = k.size(2)
588
+ q_per_kv_heads = nqheads // nkheads
589
+ if q_per_kv_heads > 1:
590
+ # nheads_k => nheads_q
591
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
592
+ else:
593
+ km_broadcast = km
594
+ if return_lse:
595
+ if tensor_layout == "NHD":
596
+ lse_correction = (
597
+ torch.matmul(
598
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
599
+ )
600
+ .squeeze(-1)
601
+ .to(torch.float32)
602
+ )
603
+ else:
604
+ lse_correction = (
605
+ torch.matmul(q, km_broadcast.transpose(2, 3))
606
+ .squeeze(-1)
607
+ .to(torch.float32)
608
+ )
609
+ else:
610
+ km = None
611
+
612
+ if qk_quant_gran == "per_warp":
613
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
614
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64
615
+ )
616
+ elif qk_quant_gran == "per_thread":
617
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
618
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64
619
+ )
620
+
621
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
622
+
623
+ if pv_accum_dtype == "fp32+fp32" and smooth_v:
624
+ warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
625
+ smooth_v = False
626
+
627
+ if pv_accum_dtype == "fp32+fp16" and smooth_v:
628
+ warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.")
629
+ smooth_v = False
630
+
631
+ quant_v_scale_max = 448.0
632
+ if pv_accum_dtype == "fp32+fp16":
633
+ quant_v_scale_max = 2.25
634
+
635
+ v_fp8, v_scale, vm = per_channel_fp8(
636
+ v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v
637
+ )
639
+ if pv_accum_dtype == "fp32":
640
+ if smooth_v:
641
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(
642
+ q_int8,
643
+ k_int8,
644
+ v_fp8,
645
+ o,
646
+ q_scale,
647
+ k_scale,
648
+ v_scale,
649
+ vm,
650
+ _tensor_layout,
651
+ _is_caual,
652
+ _qk_quant_gran,
653
+ sm_scale,
654
+ _return_lse,
655
+ )
656
+ torch.cuda.synchronize()
657
+ else:
658
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
659
+ q_int8,
660
+ k_int8,
661
+ v_fp8,
662
+ o,
663
+ q_scale,
664
+ k_scale,
665
+ v_scale,
666
+ _tensor_layout,
667
+ _is_caual,
668
+ _qk_quant_gran,
669
+ sm_scale,
670
+ _return_lse,
671
+ )
672
+ torch.cuda.synchronize()
673
+ elif pv_accum_dtype == "fp32+fp32":
674
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
675
+ q_int8,
676
+ k_int8,
677
+ v_fp8,
678
+ o,
679
+ q_scale,
680
+ k_scale,
681
+ v_scale,
682
+ _tensor_layout,
683
+ _is_caual,
684
+ _qk_quant_gran,
685
+ sm_scale,
686
+ _return_lse,
687
+ )
688
+ torch.cuda.synchronize()
689
+ elif pv_accum_dtype == "fp32+fp16":
690
+ lse = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(
691
+ q_int8,
692
+ k_int8,
693
+ v_fp8,
694
+ o,
695
+ q_scale,
696
+ k_scale,
697
+ v_scale,
698
+ _tensor_layout,
699
+ _is_caual,
700
+ _qk_quant_gran,
701
+ sm_scale,
702
+ _return_lse,
703
+ )
704
+ torch.cuda.synchronize()
705
+ o = o[..., :head_dim_og]
707
+ if return_lse:
708
+ return (
709
+ o,
710
+ lse / 1.44269504 + lse_correction * sm_scale
711
+ if smooth_k
712
+ else lse / 1.44269504,
713
+ )
714
+ else:
715
+ return o
716
+
717
+
718
+ @torch.compiler.disable
719
+ def sageattn_qk_int8_pv_fp8_cuda_sm90(
720
+ q: torch.Tensor,
721
+ k: torch.Tensor,
722
+ v: torch.Tensor,
723
+ tensor_layout: str = "HND",
724
+ is_causal: bool = False,
725
+ qk_quant_gran: str = "per_thread",
726
+ sm_scale: Optional[float] = None,
727
+ pv_accum_dtype: str = "fp32+fp32",
728
+ smooth_k: bool = True,
729
+ return_lse: bool = False,
730
+ **kwargs: Any,
731
+ ) -> torch.Tensor:
732
+ """
733
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
734
+
735
+ Parameters
736
+ ----------
737
+ q : torch.Tensor
738
+ The query tensor. Shape:
739
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
740
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
741
+
742
+ k : torch.Tensor
743
+ The key tensor. Shape:
744
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
745
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
746
+
747
+ v : torch.Tensor
748
+ The value tensor. Shape:
749
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
750
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
751
+
752
+ tensor_layout : str
753
+ The tensor layout, either "HND" or "NHD".
754
+ Default: "HND".
755
+
756
+ is_causal : bool
757
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
758
+ Default: False.
759
+
760
+ qk_quant_gran : str
761
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
762
+ Default: "per_thread".
763
+
764
+ sm_scale : Optional[float]
765
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
766
+
767
+ pv_accum_dtype : str
768
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
769
+ - "fp32": PV accumulation is done fully in FP32. However, due to a hardware limitation, only 22 bits of the FP32 accumulator are valid. (Not implemented on sm90; use "fp32+fp32".)
770
+ - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
771
+ Default: "fp32+fp32".
772
+
773
+ smooth_k : bool
774
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
775
+ Default: True.
776
+
777
+ return_lse : bool
778
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
779
+ Default: False.
780
+
781
+ Returns
782
+ -------
783
+ torch.Tensor
784
+ The output tensor. Shape:
785
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
786
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
787
+
788
+ torch.Tensor
789
+ The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
790
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
791
+ Only returned if `return_lse` is True.
792
+
793
+ Note
794
+ ----
795
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
796
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
797
+ - All tensors must be on the same cuda device.
798
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
799
+ """
800
+
801
+ dtype = q.dtype
802
+ assert q.is_cuda, "Input tensors must be on cuda."
803
+ assert dtype in [torch.float16, torch.bfloat16], (
804
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
805
+ )
806
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
807
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
808
+ )
809
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
810
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
811
+
812
+ torch.cuda.set_device(v.device)
813
+
814
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
815
+ _is_caual = 1 if is_causal else 0
816
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
817
+ _return_lse = 1 if return_lse else 0
818
+
819
+ head_dim_og = q.size(-1)
820
+
821
+ if head_dim_og < 64:
822
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
823
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
824
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
825
+ elif head_dim_og > 64 and head_dim_og < 128:
826
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
827
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
828
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
829
+ elif head_dim_og > 128:
830
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
831
+
832
+ # assert last dim is contiguous
833
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
834
+ "Last dim of qkv must be contiguous."
835
+ )
836
+
837
+ if sm_scale is None:
838
+ sm_scale = head_dim_og**-0.5
839
+
840
+ seq_dim = 1 if _tensor_layout == 0 else 2
841
+ nh_dim = 2 if _tensor_layout == 0 else 1
842
+
843
+ if smooth_k:
844
+ km = k.mean(dim=seq_dim, keepdim=True)
845
+ nqheads = q.size(2)
846
+ nkheads = k.size(2)
847
+ q_per_kv_heads = nqheads // nkheads
848
+ if q_per_kv_heads > 1:
849
+ # nheads_k => nheads_q
850
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
851
+ else:
852
+ km_broadcast = km
853
+ if return_lse:
854
+ if tensor_layout == "NHD":
855
+ lse_correction = (
856
+ torch.matmul(
857
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
858
+ )
859
+ .squeeze(-1)
860
+ .to(torch.float32)
861
+ )
862
+ else:
863
+ lse_correction = (
864
+ torch.matmul(q, km_broadcast.transpose(2, 3))
865
+ .squeeze(-1)
866
+ .to(torch.float32)
867
+ )
868
+ else:
869
+ km = None
870
+
871
+ if qk_quant_gran == "per_warp":
872
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
873
+ q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128
874
+ )
875
+ elif qk_quant_gran == "per_thread":
876
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
877
+ q,
878
+ k,
879
+ km,
880
+ tensor_layout=tensor_layout,
881
+ BLKQ=64,
882
+ WARPQ=16,
883
+ BLKK=128,
884
+ WARPK=128,
885
+ )
886
+
887
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
888
+
889
+ # pad v to multiple of 128
890
+ # TODO: modify per_channel_fp8 kernel to handle this
891
+ kv_len = k.size(seq_dim)
892
+ v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
893
+ if v_pad_len > 0:
894
+ if tensor_layout == "HND":
895
+ v = torch.cat(
896
+ [
897
+ v,
898
+ torch.zeros(
899
+ v.size(0),
900
+ v.size(1),
901
+ v_pad_len,
902
+ v.size(3),
903
+ dtype=v.dtype,
904
+ device=v.device,
905
+ ),
906
+ ],
907
+ dim=2,
908
+ )
909
+ else:
910
+ v = torch.cat(
911
+ [
912
+ v,
913
+ torch.zeros(
914
+ v.size(0),
915
+ v_pad_len,
916
+ v.size(2),
917
+ v.size(3),
918
+ dtype=v.dtype,
919
+ device=v.device,
920
+ ),
921
+ ],
922
+ dim=1,
923
+ )
924
+
925
+ v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False)
926
+
927
+ if pv_accum_dtype == "fp32":
928
+ raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.")
929
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
930
+ q_int8,
931
+ k_int8,
932
+ v_fp8,
933
+ o,
934
+ q_scale,
935
+ k_scale,
936
+ v_scale,
937
+ _tensor_layout,
938
+ _is_caual,
939
+ _qk_quant_gran,
940
+ sm_scale,
941
+ _return_lse,
942
+ )
943
+ elif pv_accum_dtype == "fp32+fp32":
958
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
959
+ q_int8,
960
+ k_int8,
961
+ v_fp8,
962
+ o,
963
+ q_scale,
964
+ k_scale,
965
+ v_scale,
966
+ _tensor_layout,
967
+ _is_caual,
968
+ _qk_quant_gran,
969
+ sm_scale,
970
+ _return_lse,
971
+ )
972
+
973
+ o = o[..., :head_dim_og]
974
+
975
+ if return_lse:
976
+ return (
977
+ o,
978
+ lse / 1.44269504 + lse_correction * sm_scale
979
+ if smooth_k
980
+ else lse / 1.44269504,
981
+ )
982
+ else:
983
+ return o
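A direct-call sketch for the sm89/sm120 path defined above, bypassing the auto-dispatcher (hypothetical shapes; needs an FP8-capable GPU and the bundled extension):

```python
import torch
from sage_attention import sageattn_qk_int8_pv_fp8_cuda

b, h, seq, d = 1, 8, 2048, 128
q = torch.randn(b, seq, h, d, dtype=torch.bfloat16, device="cuda")   # "NHD" layout this time
k = torch.randn_like(q)
v = torch.randn_like(q)

o, lse = sageattn_qk_int8_pv_fp8_cuda(
    q, k, v,
    tensor_layout="NHD",
    is_causal=True,
    qk_quant_gran="per_thread",     # INT8 quantization granularity for Q/K
    pv_accum_dtype="fp32+fp16",     # the function default
    return_lse=True,
)
print(o.shape, lse.shape)           # [1, 2048, 8, 128] and [1, 8, 2048]
```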
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/layers.py ADDED
File without changes
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/quant.py ADDED
@@ -0,0 +1,326 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ from typing import Optional
19
+
20
+ from ._ops import ops
21
+
22
+
23
+ def per_block_int8(
24
+ q: torch.Tensor,
25
+ k: torch.Tensor,
26
+ km: Optional[torch.Tensor] = None,
27
+ BLKQ: int = 128,
28
+ BLKK: int = 64,
29
+ sm_scale: Optional[float] = None,
30
+ tensor_layout: str = "HND",
31
+ ):
32
+ """
33
+ Quantize the query tensor `q` and the key tensor `k` with per block quantization.
34
+
35
+ Parameters
36
+ ----------
37
+ q : torch.Tensor
38
+ The query tensor. Shape:
39
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
40
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
41
+
42
+ k : torch.Tensor
43
+ The key tensor. Shape:
44
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
45
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
46
+
47
+ km : Optional[torch.Tensor]
48
+ The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
49
+ Should be of the same dtype as `k` if provided. Default is None.
50
+
51
+ sm_scale : Optional[float]
52
+ The scale factor for the softmax operation. Default is ``head_dim**-0.5``.
53
+ It will be multiplied by ``1.44269504`` to work together with the triton attention kernel.
54
+
55
+ tensor_layout : str
56
+ The tensor layout, either "HND" or "NHD".
57
+ Default: "HND".
58
+
59
+ Returns
60
+ -------
61
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
62
+ A tuple containing:
63
+ - The quantized query tensor. Shape: Same as `q` but with `int8` dtype.
64
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype.
65
+ - The quantized key tensor. Shape: Same as `k` but with `int8` dtype.
66
+ - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
67
+
68
+ Note
69
+ ----
70
+ - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``
71
+ """
72
+
73
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
74
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
75
+
76
+ if tensor_layout == "HND":
77
+ b, h_qo, qo_len, head_dim = q.shape
78
+ _, h_kv, kv_len, _ = k.shape
79
+
80
+ elif tensor_layout == "NHD":
81
+ b, qo_len, h_qo, head_dim = q.shape
82
+ _, kv_len, h_kv, _ = k.shape
83
+
84
+ else:
85
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
86
+
87
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
88
+
89
+ q_scale = torch.empty(
90
+ (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32
91
+ )
92
+ k_scale = torch.empty(
93
+ (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
94
+ )
95
+
96
+ if sm_scale is None:
97
+ sm_scale = head_dim**-0.5
98
+
99
+ sm_scale *= 1.44269504
100
+
101
+ ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout)
102
+ if km is not None:
103
+ km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
104
+ ops.quant_per_block_int8_fuse_sub_mean_cuda(
105
+ k, km, k_int8, k_scale, BLKK, _tensor_layout
106
+ )
107
+ else:
108
+ # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
109
+ ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
110
+
111
+ return q_int8, q_scale, k_int8, k_scale
112
+
113
+
114
+ def per_warp_int8(
115
+ q: torch.Tensor,
116
+ k: torch.Tensor,
117
+ km: Optional[torch.Tensor] = None,
118
+ BLKQ: int = 128,
119
+ WARPQ: int = 32,
120
+ BLKK: int = 64,
121
+ tensor_layout: str = "HND",
122
+ ):
123
+ """
124
+ Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization.
125
+ Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128.
126
+ Block size of quantizing `k` is 64 or 128.
127
+
128
+ Parameters
129
+ ----------
130
+ q : torch.Tensor
131
+ The query tensor. Shape:
132
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
133
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
134
+
135
+ k : torch.Tensor
136
+ The key tensor. Shape:
137
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
138
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
139
+
140
+ km : Optional[torch.Tensor]
141
+ The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
142
+ Should be of the same dtype as `k` if provided. Default is None.
143
+
144
+ tensor_layout : str
145
+ The tensor layout, either "HND" or "NHD".
146
+ Default: "HND".
147
+
148
+ Returns
149
+ -------
150
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
151
+ A tuple containing:
152
+ - The quantized query tensor. Shape: Same as `q` but with `int8` dtype.
153
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype.
154
+ - The quantized key tensor. Shape: Same as `k` but with `int8` dtype.
155
+ - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
156
+
157
+ Note
158
+ ----
159
+ - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``
160
+ """
161
+
162
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
163
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
164
+
165
+ if tensor_layout == "HND":
166
+ b, h_qo, qo_len, head_dim = q.shape
167
+ _, h_kv, kv_len, _ = k.shape
168
+
169
+ elif tensor_layout == "NHD":
170
+ b, qo_len, h_qo, head_dim = q.shape
171
+ _, kv_len, h_kv, _ = k.shape
172
+
173
+ else:
174
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
175
+
176
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
177
+
178
+ q_scale = torch.empty(
179
+ (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)),
180
+ device=q.device,
181
+ dtype=torch.float32,
182
+ )
183
+ k_scale = torch.empty(
184
+ (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
185
+ )
186
+
187
+ ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout)
188
+
189
+ if km is not None:
190
+ km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
191
+ ops.quant_per_block_int8_fuse_sub_mean_cuda(
192
+ k, km, k_int8, k_scale, BLKK, _tensor_layout
193
+ )
194
+ else:
195
+ # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
196
+ ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
197
+
198
+ return q_int8, q_scale, k_int8, k_scale
199
+
200
+
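To make the scale layouts above concrete, a small shape sketch contrasting `per_block_int8` and `per_warp_int8` (hypothetical sizes; both require the CUDA extension):

```python
import torch
from sage_attention.quant import per_block_int8, per_warp_int8

b, h, seq, d = 1, 4, 4096, 128
q = torch.randn(b, h, seq, d, dtype=torch.float16, device="cuda")    # "HND" layout
k = torch.randn_like(q)

_, q_scale_blk, _, k_scale_blk = per_block_int8(q, k, BLKQ=128, BLKK=64)
_, q_scale_wrp, _, k_scale_wrp = per_warp_int8(q, k, BLKQ=128, WARPQ=32, BLKK=64)

print(q_scale_blk.shape)  # [1, 4, 32]:  one scale per 128-row Q block
print(q_scale_wrp.shape)  # [1, 4, 128]: one scale per 32-row warp slice (4 per block)
print(k_scale_blk.shape)  # [1, 4, 64]:  K is quantized per 64-row block in both cases
print(k_scale_wrp.shape)  # [1, 4, 64]
```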
201
+ def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"):
202
+ """
203
+ Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16.
204
+
205
+ Parameters
206
+ ----------
207
+ v : torch.Tensor
208
+ The input tensor. Shape:
209
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
210
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
211
+
212
+ tensor_layout : str
213
+ The tensor layout, either "HND" or "NHD".
214
+ Default: "HND".
215
+
216
+ Returns
217
+ -------
218
+ Tuple[torch.Tensor, torch.Tensor]
219
+ A tuple containing:
220
+ - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype.
221
+ - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`.
222
+
223
+ Note
224
+ ----
225
+ - The tensor `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
226
+ - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype.
227
+ - The returned mean tensor will have the same dtype as the input tensor.
228
+ """
229
+
230
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
231
+ vm = v.mean(dim=1 if _tensor_layout == 0 else 2)
232
+
233
+ v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device)
234
+
235
+ # subtract mean and store the result as fp16
236
+ ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout)
237
+
238
+ return v_smoothed, vm
239
+
240
+
241
+ def per_channel_fp8(
242
+ v: torch.Tensor,
243
+ tensor_layout: str = "HND",
244
+ scale_max: float = 448.0,
245
+ smooth_v: bool = True,
246
+ ):
247
+ """
248
+ Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization.
249
+ `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64.
250
+ After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``.
251
+ The quantization is done per channel, with the scale value and smooth factor calculated per channel.
252
+
253
+ Parameters
254
+ ----------
255
+ v : torch.Tensor
256
+ The input tensor. Shape:
257
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
258
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
259
+
260
+ tensor_layout : str
261
+ The tensor layout, either "HND" or "NHD".
262
+ Default: "HND".
263
+
264
+ scale_max : float
265
+ The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format).
266
+
267
+ smooth_v : bool
268
+ Whether to smooth the quantized tensor. Default is True.
269
+
270
+ Returns
271
+ -------
272
+ Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
273
+ A tuple containing:
274
+ - The quantized tensor `v_fp8`. Shape:
275
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
276
+ - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
277
+ - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
278
+ - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
279
+
280
+ Note
281
+ ----
282
+ - The tensor `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
283
+ - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``.
284
+ """
285
+
286
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
287
+
288
+ if tensor_layout == "HND":
289
+ b, h_kv, kv_len, head_dim = v.shape
290
+ padded_len = (kv_len + 63) // 64 * 64
291
+ v_transposed_permutted = torch.empty(
292
+ (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device
293
+ )
294
+
295
+ elif tensor_layout == "NHD":
296
+ b, kv_len, h_kv, head_dim = v.shape
297
+ padded_len = (kv_len + 63) // 64 * 64
298
+ v_transposed_permutted = torch.empty(
299
+ (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device
300
+ )
301
+
302
+ ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout)
303
+
304
+ v_fp8 = torch.empty(
305
+ v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device
306
+ )
307
+
308
+ v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
309
+ vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
310
+
311
+ if smooth_v:
312
+ ops.mean_scale_fuse_quant_cuda(
313
+ v_transposed_permutted,
314
+ v_fp8,
315
+ vm,
316
+ v_scale,
317
+ kv_len,
318
+ scale_max,
319
+ _tensor_layout,
320
+ )
321
+ return v_fp8, v_scale, vm
322
+ else:
323
+ ops.scale_fuse_quant_cuda(
324
+ v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout
325
+ )
326
+ return v_fp8, v_scale, None
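The two quantization paths in this file (per-warp/per-block INT8 for Q and K, per-channel FP8 for V) are meant to be called on the raw attention inputs before the fused kernel. A minimal usage sketch follows; the `sage_attention` import path and the shapes are assumptions based on this build's package layout, and a CUDA device is required.

import torch
from sage_attention import per_warp_int8, per_channel_fp8  # assumed import path

b, h, seq, d = 1, 8, 1024, 128
q = torch.randn(b, h, seq, d, dtype=torch.float16, device="cuda")
k = torch.randn(b, h, seq, d, dtype=torch.float16, device="cuda")
v = torch.randn(b, h, seq, d, dtype=torch.float16, device="cuda")

# Smooth K with its per-head mean ("HND": sequence dim is 2), then quantize
# Q per warp and K per block.
km = k.mean(dim=2, keepdim=True)
q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k, km, tensor_layout="HND")

# Quantize V per channel to FP8 (E4M3); vm is the per-channel mean when smooth_v=True.
v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout="HND", smooth_v=True)

print(q_scale.shape)                # [1, 8, (1024 // 128) * (128 // 32)] == [1, 8, 32]
print(v_fp8.dtype, v_scale.shape)   # torch.float8_e4m3fn, [1, 8, 128]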
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/quant_per_thread.py ADDED
@@ -0,0 +1,204 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import triton
19
+ import triton.language as tl
20
+
21
+ @triton.jit
22
+ def quant_query_per_thread_int8_kernel(Input, Output, Scale, L,
23
+ stride_iz, stride_ih, stride_in,
24
+ stride_oz, stride_oh, stride_on,
25
+ stride_sz, stride_sh,
26
+ C: tl.constexpr, BLK: tl.constexpr):
27
+ off_blk = tl.program_id(0) // 8
28
+ off_tld = tl.program_id(0) % 8
29
+ off_h = tl.program_id(1)
30
+ off_b = tl.program_id(2)
31
+
32
+ offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
33
+ offs_k = tl.arange(0, C)
34
+
35
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
36
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
37
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
38
+
39
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
40
+ x = x.to(tl.float32)
41
+ scale = tl.max(tl.abs(x)) / 127. + 0.0000001
42
+ x_int8 = x / scale
43
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
44
+ x_int8 = x_int8.to(tl.int8)
45
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
46
+ tl.store(scale_ptrs, scale)
47
+
48
+ @triton.jit
49
+ def quant_key_per_thread_int8_kernel(Input, Output, Scale, L,
50
+ stride_iz, stride_ih, stride_in,
51
+ stride_oz, stride_oh, stride_on,
52
+ stride_sz, stride_sh,
53
+ C: tl.constexpr, BLK: tl.constexpr):
54
+ off_blk = tl.program_id(0) // 4
55
+ off_tld = tl.program_id(0) % 4
56
+ off_h = tl.program_id(1)
57
+ off_b = tl.program_id(2)
58
+
59
+ # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
60
+ # offs_k = tl.arange(0, C)
61
+
62
+ # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
63
+ # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
64
+ # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
65
+
66
+ # x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
67
+ # x = x.to(tl.float32)
68
+ # scale = tl.max(tl.abs(x)) / 127. + 0.0000001
69
+ # x_int8 = x / scale
70
+ # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
71
+ # x_int8 = x_int8.to(tl.int8)
72
+ # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
73
+ # tl.store(scale_ptrs, scale)
74
+
75
+ offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2
76
+ offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1
77
+ offs_k = tl.arange(0, C)
78
+
79
+ input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :]
80
+ input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :]
81
+ output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :]
82
+ output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :]
83
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
84
+
85
+ x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L)
86
+ x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L)
87
+ x0 = x0.to(tl.float32)
88
+ x1 = x1.to(tl.float32)
89
+ scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. + 0.0000001
90
+ x0_int8 = x0 / scale
91
+ x1_int8 = x1 / scale
92
+ x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1)
93
+ x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1)
94
+ x0_int8 = x0_int8.to(tl.int8)
95
+ x1_int8 = x1_int8.to(tl.int8)
96
+ tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L)
97
+ tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L)
98
+ tl.store(scale_ptrs, scale)
99
+
100
+ @triton.jit
101
+ def quant_query_per_thread_int4_kernel(Input, Output, Scale, L,
102
+ stride_iz, stride_ih, stride_in,
103
+ stride_oz, stride_oh, stride_on,
104
+ stride_sz, stride_sh,
105
+ C: tl.constexpr, BLK: tl.constexpr):
106
+ off_blk = tl.program_id(0) // 8
107
+ off_tld = tl.program_id(0) % 8
108
+ off_h = tl.program_id(1)
109
+ off_b = tl.program_id(2)
110
+
111
+ offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
112
+ offs_k = tl.arange(0, C)
113
+
114
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
115
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
116
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
117
+
118
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
119
+ x = x.to(tl.float32)
120
+ scale = tl.max(tl.abs(x)) / 7. + 0.0000001
121
+ x_int8 = x / scale
122
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
123
+ x_int8 = x_int8.to(tl.int8)
124
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
125
+ tl.store(scale_ptrs, scale)
126
+
127
+ @triton.jit
128
+ def quant_key_per_thread_int4_kernel(Input, Output, Scale, L,
129
+ stride_iz, stride_ih, stride_in,
130
+ stride_oz, stride_oh, stride_on,
131
+ stride_sz, stride_sh,
132
+ C: tl.constexpr, BLK: tl.constexpr):
133
+ off_blk = tl.program_id(0) // 4
134
+ off_tld = tl.program_id(0) % 4
135
+ off_h = tl.program_id(1)
136
+ off_b = tl.program_id(2)
137
+
138
+ offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
139
+ offs_k = tl.arange(0, C)
140
+
141
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
142
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
143
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
144
+
145
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
146
+ x = x.to(tl.float32)
147
+ scale = tl.max(tl.abs(x)) / 7. + 0.0000001
148
+ x_int8 = x / scale
149
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
150
+ x_int8 = x_int8.to(tl.int8)
151
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
152
+ tl.store(scale_ptrs, scale)
153
+
154
+ def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"):
155
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
156
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
157
+
158
+ if km is not None:
159
+ k = k - km
160
+
161
+ if tensor_layout == "HND":
162
+ b, h_qo, qo_len, head_dim = q.shape
163
+ _, h_kv, kv_len, _ = k.shape
164
+
165
+ stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
166
+ stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
167
+ stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
168
+ stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
169
+ elif tensor_layout == "NHD":
170
+ b, qo_len, h_qo, head_dim = q.shape
171
+ _, kv_len, h_kv, _ = k.shape
172
+
173
+ stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
174
+ stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
175
+ stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
176
+ stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
177
+ else:
178
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
179
+
180
+ q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32)
181
+ k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32)
182
+
183
+ if sm_scale is None:
184
+ sm_scale = head_dim**-0.5
185
+
186
+ grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b)
187
+ quant_query_per_thread_int8_kernel[grid](
188
+ q, q_int8, q_scale, qo_len,
189
+ stride_bz_q, stride_h_q, stride_seq_q,
190
+ stride_bz_qo, stride_h_qo, stride_seq_qo,
191
+ q_scale.stride(0), q_scale.stride(1),
192
+ C=head_dim, BLK=WARPQ
193
+ )
194
+
195
+ grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b)
196
+ quant_key_per_thread_int8_kernel[grid](
197
+ k, k_int8, k_scale, kv_len,
198
+ stride_bz_k, stride_h_k, stride_seq_k,
199
+ stride_bz_ko, stride_h_ko, stride_seq_ko,
200
+ k_scale.stride(0), k_scale.stride(1),
201
+ C=head_dim, BLK=WARPK
202
+ )
203
+
204
+ return q_int8, q_scale, k_int8, k_scale
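For reference, a minimal sketch of calling the Triton per-thread quantizer above (assumed `sage_attention.quant_per_thread` import path; requires CUDA and Triton). It mainly illustrates the scale layout: 8 scales per WARPQ query tile and 4 per WARPK key tile, matching the kernel grids.

import torch
from sage_attention.quant_per_thread import per_thread_int8  # assumed path

b, h, seq, d = 1, 4, 512, 128
q = torch.randn(b, h, seq, d, dtype=torch.float16, device="cuda")
k = torch.randn(b, h, seq, d, dtype=torch.float16, device="cuda")

q_int8, q_scale, k_int8, k_scale = per_thread_int8(
    q, k, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, tensor_layout="HND"
)

# (512 // 128) * (128 // 32) * 8 = 128 query scales per head,
# (512 // 64) * (64 // 64) * 4  = 32 key scales per head.
assert q_scale.shape == (b, h, 128)
assert k_scale.shape == (b, h, 32)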
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
+ from .core import sageattn, sageattn_qk_int8_pv_fp8_cuda
3
+
4
+
5
+ __all__ = [
6
+ "per_block_int8",
7
+ "per_warp_int8",
8
+ "sub_mean",
9
+ "per_channel_fp8",
10
+ "sageattn",
11
+ "sageattn_qk_int8_pv_fp8_cuda",
12
+ ]
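The package surface re-exported here is intentionally small. A quick check of what downstream code is expected to import (assuming the wheel is installed as `sage_attention`):

import sage_attention

# The six public entry points exposed by this __init__:
print(sage_attention.__all__)
# ['per_block_int8', 'per_warp_int8', 'sub_mean', 'per_channel_fp8',
#  'sageattn', 'sageattn_qk_int8_pv_fp8_cuda']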
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (433 Bytes).
 
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (550 Bytes).
 
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc ADDED
Binary file (33.4 kB).
 
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc ADDED
Binary file (13.4 kB).
 
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc ADDED
Binary file (13 kB).
 
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _sage_attention_44b112f_dirty
3
+ ops = torch.ops._sage_attention_44b112f_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_sage_attention_44b112f_dirty::{op_name}"
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/_sage_attention_44b112f_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28e181de0c6388653fb4b8b2d7347f1f547fc84fe7dc45bc66db9b1431d141bc
3
+ size 26037392
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/core.py ADDED
@@ -0,0 +1,983 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+
20
+ from ._ops import ops
21
+
22
+
23
+ from .quant import per_warp_int8 as per_warp_int8_cuda
24
+ from .quant import sub_mean
25
+ from .quant import per_channel_fp8
26
+ from .quant_per_thread import per_thread_int8 as per_thread_int8_triton
27
+
28
+ from typing import Any, List, Literal, Optional, Tuple, Union
29
+ import warnings
30
+
31
+
32
+ import subprocess
33
+ import re
34
+
35
+
36
+ def get_cuda_version():
37
+ try:
38
+ output = subprocess.check_output(["nvcc", "--version"]).decode()
39
+ match = re.search(r"release (\d+)\.(\d+)", output)
40
+ if match:
41
+ major, minor = int(match.group(1)), int(match.group(2))
42
+ return major, minor
43
+ except Exception as e:
44
+ print("Failed to get CUDA version:", e)
45
+ return None, None
46
+
47
+
48
+ def get_cuda_arch_versions():
49
+ cuda_archs = []
50
+ for i in range(torch.cuda.device_count()):
51
+ major, minor = torch.cuda.get_device_capability(i)
52
+ cuda_archs.append(f"sm{major}{minor}")
53
+ return cuda_archs
54
+
55
+
56
+ def sageattn(
57
+ q: torch.Tensor,
58
+ k: torch.Tensor,
59
+ v: torch.Tensor,
60
+ tensor_layout: str = "HND",
61
+ is_causal: bool = False,
62
+ sm_scale: Optional[float] = None,
63
+ return_lse: bool = False,
64
+ **kwargs: Any,
65
+ ):
66
+ """
67
+ Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability.
68
+
69
+ Parameters
70
+ ----------
71
+ q : torch.Tensor
72
+ The query tensor. Shape:
73
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
74
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
75
+
76
+ k : torch.Tensor
77
+ The key tensor. Shape:
78
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
79
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
80
+
81
+ v : torch.Tensor
82
+ The value tensor. Shape:
83
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
84
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
85
+
86
+ tensor_layout : str
87
+ The tensor layout, either "HND" or "NHD".
88
+ Default: "HND".
89
+
90
+ is_causal : bool
91
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
92
+ Default: False.
93
+
94
+ sm_scale : Optional[float]
95
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
96
+
97
+ return_lse : bool
98
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
99
+ Default: False.
100
+
101
+ Returns
102
+ -------
103
+ torch.Tensor
104
+ The output tensor. Shape:
105
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
106
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
107
+
108
+ torch.Tensor
109
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
110
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
111
+ Only returned if `return_lse` is True.
112
+
113
+ Note
114
+ ----
115
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
116
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
117
+ - All tensors must be on the same cuda device.
118
+ """
119
+
120
+ arch = get_cuda_arch_versions()[q.device.index]
121
+ if arch == "sm80":
122
+ return sageattn_qk_int8_pv_fp16_cuda(
123
+ q,
124
+ k,
125
+ v,
126
+ tensor_layout=tensor_layout,
127
+ is_causal=is_causal,
128
+ sm_scale=sm_scale,
129
+ return_lse=return_lse,
130
+ pv_accum_dtype="fp32",
131
+ )
132
+ elif arch == "sm89":
133
+ return sageattn_qk_int8_pv_fp8_cuda(
134
+ q,
135
+ k,
136
+ v,
137
+ tensor_layout=tensor_layout,
138
+ is_causal=is_causal,
139
+ sm_scale=sm_scale,
140
+ return_lse=return_lse,
141
+ pv_accum_dtype="fp32+fp16",
142
+ )
143
+ elif arch == "sm90":
144
+ return sageattn_qk_int8_pv_fp8_cuda_sm90(
145
+ q,
146
+ k,
147
+ v,
148
+ tensor_layout=tensor_layout,
149
+ is_causal=is_causal,
150
+ sm_scale=sm_scale,
151
+ return_lse=return_lse,
152
+ pv_accum_dtype="fp32+fp32",
153
+ )
154
+ elif arch == "sm120":
155
+ return sageattn_qk_int8_pv_fp8_cuda(
156
+ q,
157
+ k,
158
+ v,
159
+ tensor_layout=tensor_layout,
160
+ is_causal=is_causal,
161
+ qk_quant_gran="per_warp",
162
+ sm_scale=sm_scale,
163
+ return_lse=return_lse,
164
+ pv_accum_dtype="fp32+fp16",
165
+ )  # sm120 has an accurate FP32 accumulator for FP8 MMA, and the Triton kernel is currently not usable on sm120.
166
+ else:
167
+ raise ValueError(f"Unsupported CUDA architecture: {arch}")
168
+
169
+
170
+ @torch.compiler.disable
171
+ def sageattn_qk_int8_pv_fp16_cuda(
172
+ q: torch.Tensor,
173
+ k: torch.Tensor,
174
+ v: torch.Tensor,
175
+ tensor_layout: str = "HND",
176
+ is_causal: bool = False,
177
+ qk_quant_gran: str = "per_thread",
178
+ sm_scale: Optional[float] = None,
179
+ pv_accum_dtype: str = "fp32",
180
+ smooth_k: bool = True,
181
+ smooth_v: bool = False,
182
+ return_lse: bool = False,
183
+ **kwargs: Any,
184
+ ) -> torch.Tensor:
185
+ """
186
+ SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA.
187
+
188
+ Parameters
189
+ ----------
190
+ q : torch.Tensor
191
+ The query tensor. Shape:
192
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
193
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
194
+
195
+ k : torch.Tensor
196
+ The key tensor. Shape:
197
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
198
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
199
+
200
+ v : torch.Tensor
201
+ The value tensor. Shape:
202
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
203
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
204
+
205
+ tensor_layout : str
206
+ The tensor layout, either "HND" or "NHD".
207
+ Default: "HND".
208
+
209
+ is_causal : bool
210
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
211
+ Default: False.
212
+
213
+ qk_quant_gran : str
214
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
215
+ Default: "per_thread".
216
+
217
+ sm_scale : Optional[float]
218
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
219
+
220
+ pv_accum_dtype : str
221
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32".
222
+ - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b).
223
+ - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead.
224
+ - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
225
+ Default: "fp32".
226
+
227
+ smooth_k : bool
228
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
229
+ Default: True.
230
+
231
+ smooth_v : bool
232
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
233
+ smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32".
234
+ Default: False.
235
+
236
+ return_lse : bool
237
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
238
+ Default: False.
239
+
240
+ Returns
241
+ -------
242
+ torch.Tensor
243
+ The output tensor. Shape:
244
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
245
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
246
+
247
+ torch.Tensor
248
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
249
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
250
+ Only returned if `return_lse` is True.
251
+
252
+ Note
253
+ ----
254
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
255
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
256
+ - All tensors must be on the same cuda device.
257
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
258
+ """
259
+
260
+ dtype = q.dtype
261
+ assert q.is_cuda, "Input tensors must be on cuda."
262
+ assert dtype in [torch.float16, torch.bfloat16], (
263
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
264
+ )
265
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
266
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
267
+ )
268
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
269
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
270
+
271
+ # FIXME(DefTruth): make sage attention work compatible with distributed
272
+ # env, for example xDiT, which launches via torchrun. Without this workaround,
273
+ # sage attention will run into illegal memory access error after first
274
+ # inference step in distributed env for multi gpus inference. This small
275
+ # workaround also makes sage attention work with torch.compile
276
+ # through non-fullgraph compile mode.
277
+ torch.cuda.set_device(v.device)
278
+
279
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
280
+ _is_caual = 1 if is_causal else 0
281
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
282
+ _return_lse = 1 if return_lse else 0
283
+
284
+ head_dim_og = q.size(-1)
285
+
286
+ if head_dim_og < 64:
287
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
288
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
289
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
290
+ elif head_dim_og > 64 and head_dim_og < 128:
291
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
292
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
293
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
294
+ elif head_dim_og > 128:
295
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
296
+
297
+ # assert last dim is contiguous
298
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
299
+ "Last dim of qkv must be contiguous."
300
+ )
301
+
302
+ if sm_scale is None:
303
+ sm_scale = head_dim_og**-0.5
304
+
305
+ seq_dim = 1 if _tensor_layout == 0 else 2
306
+ nh_dim = 2 if _tensor_layout == 0 else 1
307
+
308
+ if smooth_k:
309
+ km = k.mean(dim=seq_dim, keepdim=True)
310
+ nqheads = q.size(2)
311
+ nkheads = k.size(2)
312
+ q_per_kv_heads = nqheads // nkheads
313
+ if q_per_kv_heads > 1:
314
+ # nheads_k => nheads_q
315
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
316
+ else:
317
+ km_broadcast = km
318
+ if return_lse:
319
+ if tensor_layout == "NHD":
320
+ lse_correction = (
321
+ torch.matmul(
322
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
323
+ )
324
+ .squeeze(-1)
325
+ .to(torch.float32)
326
+ )
327
+ else:
328
+ lse_correction = (
329
+ torch.matmul(q, km_broadcast.transpose(2, 3))
330
+ .squeeze(-1)
331
+ .to(torch.float32)
332
+ )
333
+ else:
334
+ km = None
335
+
336
+ if qk_quant_gran == "per_warp":
337
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
338
+ q,
339
+ k,
340
+ km,
341
+ tensor_layout=tensor_layout,
342
+ BLKQ=128,
343
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
344
+ BLKK=64,
345
+ )
346
+ elif qk_quant_gran == "per_thread":
347
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
348
+ q,
349
+ k,
350
+ km,
351
+ tensor_layout=tensor_layout,
352
+ BLKQ=128,
353
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
354
+ BLKK=64,
355
+ WARPK=64,
356
+ )
357
+
358
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
359
+
360
+ if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v:
361
+ warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.")
362
+ smooth_v = False
363
+
364
+ if pv_accum_dtype == "fp32":
365
+ v = v.to(torch.float16)
366
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(
367
+ q_int8,
368
+ k_int8,
369
+ v,
370
+ o,
371
+ q_scale,
372
+ k_scale,
373
+ _tensor_layout,
374
+ _is_caual,
375
+ _qk_quant_gran,
376
+ sm_scale,
377
+ _return_lse,
378
+ )
379
+ elif pv_accum_dtype == "fp16":
380
+ if smooth_v:
381
+ smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
382
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
383
+ q_int8,
384
+ k_int8,
385
+ smoothed_v,
386
+ o,
387
+ q_scale,
388
+ k_scale,
389
+ vm,
390
+ _tensor_layout,
391
+ _is_caual,
392
+ _qk_quant_gran,
393
+ sm_scale,
394
+ _return_lse,
395
+ )
396
+ else:
397
+ v = v.to(torch.float16)
398
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(
399
+ q_int8,
400
+ k_int8,
401
+ v,
402
+ o,
403
+ q_scale,
404
+ k_scale,
405
+ _tensor_layout,
406
+ _is_caual,
407
+ _qk_quant_gran,
408
+ sm_scale,
409
+ _return_lse,
410
+ )
411
+ elif pv_accum_dtype == "fp16+fp32":
412
+ v = v.to(torch.float16)
413
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(
414
+ q_int8,
415
+ k_int8,
416
+ v,
417
+ o,
418
+ q_scale,
419
+ k_scale,
420
+ _tensor_layout,
421
+ _is_caual,
422
+ _qk_quant_gran,
423
+ sm_scale,
424
+ _return_lse,
425
+ )
426
+ else:
427
+ raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")
428
+
429
+ o = o[..., :head_dim_og]
430
+
431
+ if return_lse:
432
+ return (
433
+ o,
434
+ lse / 1.44269504 + lse_correction * sm_scale
435
+ if smooth_k
436
+ else lse / 1.44269504,
437
+ )
438
+ else:
439
+ return o
440
+
441
+
442
+ @torch.compiler.disable
443
+ def sageattn_qk_int8_pv_fp8_cuda(
444
+ q: torch.Tensor,
445
+ k: torch.Tensor,
446
+ v: torch.Tensor,
447
+ tensor_layout: str = "HND",
448
+ is_causal: bool = False,
449
+ qk_quant_gran: str = "per_thread",
450
+ sm_scale: Optional[float] = None,
451
+ pv_accum_dtype: str = "fp32+fp16",
452
+ smooth_k: bool = True,
453
+ smooth_v: bool = False,
454
+ return_lse: bool = False,
455
+ **kwargs: Any,
456
+ ) -> torch.Tensor:
457
+ """
458
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
459
+
460
+ Parameters
461
+ ----------
462
+ q : torch.Tensor
463
+ The query tensor. Shape:
464
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
465
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
466
+
467
+ k : torch.Tensor
468
+ The key tensor. Shape:
469
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
470
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
471
+
472
+ v : torch.Tensor
473
+ The value tensor. Shape:
474
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
475
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
476
+
477
+ tensor_layout : str
478
+ The tensor layout, either "HND" or "NHD".
479
+ Default: "HND".
480
+
481
+ is_causal : bool
482
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
483
+ Default: False.
484
+
485
+ qk_quant_gran : str
486
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
487
+ Default: "per_thread".
488
+
489
+ sm_scale : Optional[float]
490
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
491
+
492
+ pv_accum_dtype : str
493
+ The dtype of the accumulation of the product of the value tensor and the attention weights, one of "fp32", "fp32+fp32", or "fp32+fp16".
494
+ - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator.
495
+ - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
496
+ Default: "fp32+fp32".
497
+
498
+ smooth_k : bool
499
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
500
+ Default: True.
501
+
502
+ smooth_v : bool
503
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
504
+ smooth_v will be ignored if pv_accum_dtype is "fp32+fp32".
505
+ Default: False.
506
+
507
+ return_lse : bool
508
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
509
+ Default: False.
510
+
511
+ Returns
512
+ -------
513
+ torch.Tensor
514
+ The output tensor. Shape:
515
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
516
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
517
+
518
+ torch.Tensor
519
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
520
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
521
+ Only returned if `return_lse` is True.
522
+
523
+ Note
524
+ ----
525
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
526
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
527
+ - All tensors must be on the same cuda device.
528
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
529
+ """
530
+
531
+ dtype = q.dtype
532
+ assert q.is_cuda, "Input tensors must be on cuda."
533
+ assert dtype in [torch.float16, torch.bfloat16], (
534
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
535
+ )
536
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
537
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
538
+ )
539
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
540
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
541
+
542
+ # cuda_major_version, cuda_minor_version = get_cuda_version()
543
+ # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16':
544
+ # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'")
545
+ # pv_accum_dtype = 'fp32+fp32'
546
+
547
+ # FIXME(DefTruth): make sage attention work compatible with distributed
548
+ # env, for example xDiT, which launches via torchrun. Without this workaround,
549
+ # sage attention will run into illegal memory access error after first
550
+ # inference step in distributed env for multi gpus inference. This small
551
+ # workaround also makes sage attention work with torch.compile
552
+ # through non-fullgraph compile mode.
553
+ torch.cuda.set_device(v.device)
554
+
555
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
556
+ _is_caual = 1 if is_causal else 0
557
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
558
+ _return_lse = 1 if return_lse else 0
559
+
560
+ head_dim_og = q.size(-1)
561
+
562
+ if head_dim_og < 64:
563
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
564
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
565
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
566
+ elif head_dim_og > 64 and head_dim_og < 128:
567
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
568
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
569
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
570
+ elif head_dim_og > 128:
571
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
572
+
573
+ # assert last dim is contiguous
574
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
575
+ "Last dim of qkv must be contiguous."
576
+ )
577
+
578
+ if sm_scale is None:
579
+ sm_scale = head_dim_og**-0.5
580
+
581
+ seq_dim = 1 if _tensor_layout == 0 else 2
582
+ nh_dim = 2 if _tensor_layout == 0 else 1
583
+
584
+ if smooth_k:
585
+ km = k.mean(dim=seq_dim, keepdim=True)
586
+ nqheads = q.size(2)
587
+ nkheads = k.size(2)
588
+ q_per_kv_heads = nqheads // nkheads
589
+ if q_per_kv_heads > 1:
590
+ # nheads_k => nheads_q
591
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
592
+ else:
593
+ km_broadcast = km
594
+ if return_lse:
595
+ if tensor_layout == "NHD":
596
+ lse_correction = (
597
+ torch.matmul(
598
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
599
+ )
600
+ .squeeze(-1)
601
+ .to(torch.float32)
602
+ )
603
+ else:
604
+ lse_correction = (
605
+ torch.matmul(q, km_broadcast.transpose(2, 3))
606
+ .squeeze(-1)
607
+ .to(torch.float32)
608
+ )
609
+ else:
610
+ km = None
611
+
612
+ if qk_quant_gran == "per_warp":
613
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
614
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64
615
+ )
616
+ elif qk_quant_gran == "per_thread":
617
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
618
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64
619
+ )
620
+
621
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
622
+
623
+ if pv_accum_dtype == "fp32+fp32" and smooth_v:
624
+ warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
625
+ smooth_v = False
626
+
627
+ if pv_accum_dtype == "fp32+fp16" and smooth_v:
628
+ warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.")
629
+ smooth_v = False
630
+
631
+ quant_v_scale_max = 448.0
632
+ if pv_accum_dtype == "fp32+fp16":
633
+ quant_v_scale_max = 2.25
634
+
635
+ v_fp8, v_scale, vm = per_channel_fp8(
636
+ v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v
637
+ )
638
+ print("before kernel call")
639
+ if pv_accum_dtype == "fp32":
640
+ if smooth_v:
641
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(
642
+ q_int8,
643
+ k_int8,
644
+ v_fp8,
645
+ o,
646
+ q_scale,
647
+ k_scale,
648
+ v_scale,
649
+ vm,
650
+ _tensor_layout,
651
+ _is_caual,
652
+ _qk_quant_gran,
653
+ sm_scale,
654
+ _return_lse,
655
+ )
656
+ torch.cuda.synchronize()
657
+ else:
658
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
659
+ q_int8,
660
+ k_int8,
661
+ v_fp8,
662
+ o,
663
+ q_scale,
664
+ k_scale,
665
+ v_scale,
666
+ _tensor_layout,
667
+ _is_caual,
668
+ _qk_quant_gran,
669
+ sm_scale,
670
+ _return_lse,
671
+ )
672
+ torch.cuda.synchronize()
673
+ elif pv_accum_dtype == "fp32+fp32":
674
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
675
+ q_int8,
676
+ k_int8,
677
+ v_fp8,
678
+ o,
679
+ q_scale,
680
+ k_scale,
681
+ v_scale,
682
+ _tensor_layout,
683
+ _is_caual,
684
+ _qk_quant_gran,
685
+ sm_scale,
686
+ _return_lse,
687
+ )
688
+ torch.cuda.synchronize()
689
+ elif pv_accum_dtype == "fp32+fp16":
690
+ lse = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(
691
+ q_int8,
692
+ k_int8,
693
+ v_fp8,
694
+ o,
695
+ q_scale,
696
+ k_scale,
697
+ v_scale,
698
+ _tensor_layout,
699
+ _is_caual,
700
+ _qk_quant_gran,
701
+ sm_scale,
702
+ _return_lse,
703
+ )
704
+ torch.cuda.synchronize()
705
+ o = o[..., :head_dim_og]
706
+ print("after kernel call")
707
+ if return_lse:
708
+ return (
709
+ o,
710
+ lse / 1.44269504 + lse_correction * sm_scale
711
+ if smooth_k
712
+ else lse / 1.44269504,
713
+ )
714
+ else:
715
+ return o
716
+
717
+
718
+ @torch.compiler.disable
719
+ def sageattn_qk_int8_pv_fp8_cuda_sm90(
720
+ q: torch.Tensor,
721
+ k: torch.Tensor,
722
+ v: torch.Tensor,
723
+ tensor_layout: str = "HND",
724
+ is_causal: bool = False,
725
+ qk_quant_gran: str = "per_thread",
726
+ sm_scale: Optional[float] = None,
727
+ pv_accum_dtype: str = "fp32+fp32",
728
+ smooth_k: bool = True,
729
+ return_lse: bool = False,
730
+ **kwargs: Any,
731
+ ) -> torch.Tensor:
732
+ """
733
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
734
+
735
+ Parameters
736
+ ----------
737
+ q : torch.Tensor
738
+ The query tensor. Shape:
739
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
740
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
741
+
742
+ k : torch.Tensor
743
+ The key tensor. Shape:
744
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
745
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
746
+
747
+ v : torch.Tensor
748
+ The value tensor. Shape:
749
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
750
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
751
+
752
+ tensor_layout : str
753
+ The tensor layout, either "HND" or "NHD".
754
+ Default: "HND".
755
+
756
+ is_causal : bool
757
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
758
+ Default: False.
759
+
760
+ qk_quant_gran : str
761
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
762
+ Default: "per_thread".
763
+
764
+ sm_scale : Optional[float]
765
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
766
+
767
+ pv_accum_dtype : str
768
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32". Only "fp32+fp32" is currently implemented on sm90.
769
+ - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator.
770
+ - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
771
+ Default: "fp32+fp32".
772
+
773
+ smooth_k : bool
774
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
775
+ Default: True.
776
+
777
+ return_lse : bool
778
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
779
+ Default: False.
780
+
781
+ Returns
782
+ -------
783
+ torch.Tensor
784
+ The output tensor. Shape:
785
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
786
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
787
+
788
+ torch.Tensor
789
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
790
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
791
+ Only returned if `return_lse` is True.
792
+
793
+ Note
794
+ ----
795
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
796
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
797
+ - All tensors must be on the same cuda device.
798
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
799
+ """
800
+
801
+ dtype = q.dtype
802
+ assert q.is_cuda, "Input tensors must be on cuda."
803
+ assert dtype in [torch.float16, torch.bfloat16], (
804
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
805
+ )
806
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
807
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
808
+ )
809
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
810
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
811
+
812
+ torch.cuda.set_device(v.device)
813
+
814
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
815
+ _is_caual = 1 if is_causal else 0
816
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
817
+ _return_lse = 1 if return_lse else 0
818
+
819
+ head_dim_og = q.size(-1)
820
+
821
+ if head_dim_og < 64:
822
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
823
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
824
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
825
+ elif head_dim_og > 64 and head_dim_og < 128:
826
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
827
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
828
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
829
+ elif head_dim_og > 128:
830
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
831
+
832
+ # assert last dim is contiguous
833
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
834
+ "Last dim of qkv must be contiguous."
835
+ )
836
+
837
+ if sm_scale is None:
838
+ sm_scale = head_dim_og**-0.5
839
+
840
+ seq_dim = 1 if _tensor_layout == 0 else 2
841
+ nh_dim = 2 if _tensor_layout == 0 else 1
842
+
843
+ if smooth_k:
844
+ km = k.mean(dim=seq_dim, keepdim=True)
845
+ nqheads = q.size(2)
846
+ nkheads = k.size(2)
847
+ q_per_kv_heads = nqheads // nkheads
848
+ if q_per_kv_heads > 1:
849
+ # nheads_k => nheads_q
850
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
851
+ else:
852
+ km_broadcast = km
853
+ if return_lse:
854
+ if tensor_layout == "NHD":
855
+ lse_correction = (
856
+ torch.matmul(
857
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
858
+ )
859
+ .squeeze(-1)
860
+ .to(torch.float32)
861
+ )
862
+ else:
863
+ lse_correction = (
864
+ torch.matmul(q, km_broadcast.transpose(2, 3))
865
+ .squeeze(-1)
866
+ .to(torch.float32)
867
+ )
868
+ else:
869
+ km = None
870
+
871
+ if qk_quant_gran == "per_warp":
872
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
873
+ q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128
874
+ )
875
+ elif qk_quant_gran == "per_thread":
876
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
877
+ q,
878
+ k,
879
+ km,
880
+ tensor_layout=tensor_layout,
881
+ BLKQ=64,
882
+ WARPQ=16,
883
+ BLKK=128,
884
+ WARPK=128,
885
+ )
886
+
887
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
888
+
889
+ # pad v to multiple of 128
890
+ # TODO: modify per_channel_fp8 kernel to handle this
891
+ kv_len = k.size(seq_dim)
892
+ v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
893
+ if v_pad_len > 0:
894
+ if tensor_layout == "HND":
895
+ v = torch.cat(
896
+ [
897
+ v,
898
+ torch.zeros(
899
+ v.size(0),
900
+ v.size(1),
901
+ v_pad_len,
902
+ v.size(3),
903
+ dtype=v.dtype,
904
+ device=v.device,
905
+ ),
906
+ ],
907
+ dim=2,
908
+ )
909
+ else:
910
+ v = torch.cat(
911
+ [
912
+ v,
913
+ torch.zeros(
914
+ v.size(0),
915
+ v_pad_len,
916
+ v.size(2),
917
+ v.size(3),
918
+ dtype=v.dtype,
919
+ device=v.device,
920
+ ),
921
+ ],
922
+ dim=1,
923
+ )
924
+
925
+ v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False)
926
+
927
+ if pv_accum_dtype == "fp32":
928
+ raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.")
929
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
930
+ q_int8,
931
+ k_int8,
932
+ v_fp8,
933
+ o,
934
+ q_scale,
935
+ k_scale,
936
+ v_scale,
937
+ _tensor_layout,
938
+ _is_caual,
939
+ _qk_quant_gran,
940
+ sm_scale,
941
+ _return_lse,
942
+ )
943
+ elif pv_accum_dtype == "fp32+fp32":
944
+ print(
945
+ "qint8",
946
+ q_int8.shape,
947
+ "qscale",
948
+ q_scale.shape,
949
+ "kint8",
950
+ k_int8.shape,
951
+ "kscale",
952
+ k_scale.shape,
953
+ "vfp8",
954
+ v_fp8.shape,
955
+ "vscale",
956
+ v_scale.shape,
957
+ )
958
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
959
+ q_int8,
960
+ k_int8,
961
+ v_fp8,
962
+ o,
963
+ q_scale,
964
+ k_scale,
965
+ v_scale,
966
+ _tensor_layout,
967
+ _is_caual,
968
+ _qk_quant_gran,
969
+ sm_scale,
970
+ _return_lse,
971
+ )
972
+
973
+ o = o[..., :head_dim_og]
974
+
975
+ if return_lse:
976
+ return (
977
+ o,
978
+ lse / 1.44269504 + lse_correction * sm_scale
979
+ if smooth_k
980
+ else lse / 1.44269504,
981
+ )
982
+ else:
983
+ return o
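End of `core.py`. As a reference point, a minimal call into the dispatcher defined above (assumed `sage_attention` import path; requires a GPU whose SM version `sageattn` handles):

import torch
from sage_attention import sageattn

b, h, seq, d = 2, 16, 2048, 128
q = torch.randn(b, h, seq, d, dtype=torch.float16, device="cuda")
k = torch.randn(b, h, seq, d, dtype=torch.float16, device="cuda")
v = torch.randn(b, h, seq, d, dtype=torch.float16, device="cuda")

# Drop-in attention call with causal masking; layout matches the "HND" default.
out = sageattn(q, k, v, tensor_layout="HND", is_causal=True)

# With return_lse=True the per-row logsumexp is also returned (e.g. for Ring Attention).
out, lse = sageattn(q, k, v, tensor_layout="HND", is_causal=True, return_lse=True)
print(out.shape, lse.shape)  # torch.Size([2, 16, 2048, 128]) torch.Size([2, 16, 2048])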
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/layers.py ADDED
File without changes
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/quant.py ADDED
@@ -0,0 +1,326 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ from typing import Optional
19
+
20
+ from ._ops import ops
21
+
22
+
23
+ def per_block_int8(
24
+ q: torch.Tensor,
25
+ k: torch.Tensor,
26
+ km: Optional[torch.Tensor] = None,
27
+ BLKQ: int = 128,
28
+ BLKK: int = 64,
29
+ sm_scale: Optional[float] = None,
30
+ tensor_layout: str = "HND",
31
+ ):
32
+ """
33
+ Quantize the query tensor `q` and the key tensor `k` with per block quantization.
34
+
35
+ Parameters
36
+ ----------
37
+ q : torch.Tensor
38
+ The query tensor. Shape:
39
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
40
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
41
+
42
+ k : torch.Tensor
43
+ The key tensor. Shape:
44
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
45
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
46
+
47
+ km : Optional[torch.Tensor]
48
+ The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
49
+ Should be of the same dtype as `k` if provided. Default is None.
50
+
51
+ sm_scale : Optional[float]
52
+ The scale factor for the softmax operation. Default is ``head_dim**-0.5``.
53
+ It will be multiplied by ``1.44269504`` to work together with the triton attention kernel.
54
+
55
+ tensor_layout : str
56
+ The tensor layout, either "HND" or "NHD".
57
+ Default: "HND".
58
+
59
+ Returns
60
+ -------
61
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
62
+ A tuple containing:
63
+ - The quantized query tensor. Shape: Same as `q` but with `int8` dtype.
64
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype.
65
+ - The quantized key tensor. Shape: Same as `k` but with `int8` dtype.
66
+ - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
67
+
68
+ Note
69
+ ----
70
+ - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``
71
+ """
72
+
73
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
74
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
75
+
76
+ if tensor_layout == "HND":
77
+ b, h_qo, qo_len, head_dim = q.shape
78
+ _, h_kv, kv_len, _ = k.shape
79
+
80
+ elif tensor_layout == "NHD":
81
+ b, qo_len, h_qo, head_dim = q.shape
82
+ _, kv_len, h_kv, _ = k.shape
83
+
84
+ else:
85
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
86
+
87
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
88
+
89
+ q_scale = torch.empty(
90
+ (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32
91
+ )
92
+ k_scale = torch.empty(
93
+ (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
94
+ )
95
+
96
+ if sm_scale is None:
97
+ sm_scale = head_dim**-0.5
98
+
99
+ sm_scale *= 1.44269504
100
+
101
+ ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout)
102
+ if km is not None:
103
+ km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
104
+ ops.quant_per_block_int8_fuse_sub_mean_cuda(
105
+ k, km, k_int8, k_scale, BLKK, _tensor_layout
106
+ )
107
+ else:
108
+ # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
109
+ ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
110
+
111
+ return q_int8, q_scale, k_int8, k_scale
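The docstring above notes that `sm_scale` is multiplied by 1.44269504 before being folded into the quantization: that constant is log2(e), so a kernel that exponentiates with base 2 reproduces exp(sm_scale * x). A quick numeric check, illustrative only and not part of the file:

import math

x, sm_scale = 0.73, 128 ** -0.5
lhs = math.exp(sm_scale * x)            # what softmax needs
rhs = 2 ** (sm_scale * 1.44269504 * x)  # what a base-2 kernel computes
assert abs(lhs - rhs) < 1e-9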
112
+
113
+
114
+ def per_warp_int8(
115
+ q: torch.Tensor,
116
+ k: torch.Tensor,
117
+ km: Optional[torch.Tensor] = None,
118
+ BLKQ: int = 128,
119
+ WARPQ: int = 32,
120
+ BLKK: int = 64,
121
+ tensor_layout: str = "HND",
122
+ ):
123
+ """
124
+ Quantize the query tensor `q` with per warp quantization and the key tensor `k` with per block quantization.
125
+ Warp size of quantizing `q` is 16 or 32, with a block size of 64 or 128.
126
+ Block size of quantizing `k` is 64 or 128.
127
+
128
+ Parameters
129
+ ----------
130
+ q : torch.Tensor
131
+ The query tensor. Shape:
132
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
133
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
134
+
135
+ k : torch.Tensor
136
+ The key tensor. Shape:
137
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
138
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
139
+
140
+ km : Optional[torch.Tensor]
141
+ The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
142
+ Should be of the same dtype as `k` if provided. Default is None.
143
+
144
+ tensor_layout : str
145
+ The tensor layout, either "HND" or "NHD".
146
+ Default: "HND".
147
+
148
+ Returns
149
+ -------
150
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
151
+ A tuple containing:
152
+ - The quantized query tensor. Shape: Same as `q` but with `int8` dtype.
153
+ - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ)]`` with `float32` dtype.
154
+ - The quantized key tensor. Shape: Same as `k` but with `int8` dtype.
155
+ - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
156
+
157
+ Note
158
+ ----
159
+ - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``
160
+ """
161
+
162
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
163
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
164
+
165
+ if tensor_layout == "HND":
166
+ b, h_qo, qo_len, head_dim = q.shape
167
+ _, h_kv, kv_len, _ = k.shape
168
+
169
+ elif tensor_layout == "NHD":
170
+ b, qo_len, h_qo, head_dim = q.shape
171
+ _, kv_len, h_kv, _ = k.shape
172
+
173
+ else:
174
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
175
+
176
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
177
+
178
+ q_scale = torch.empty(
179
+ (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)),
180
+ device=q.device,
181
+ dtype=torch.float32,
182
+ )
183
+ k_scale = torch.empty(
184
+ (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
185
+ )
186
+
187
+ ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout)
188
+
189
+ if km is not None:
190
+ km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
191
+ ops.quant_per_block_int8_fuse_sub_mean_cuda(
192
+ k, km, k_int8, k_scale, BLKK, _tensor_layout
193
+ )
194
+ else:
195
+ # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
196
+ ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
197
+
198
+ return q_int8, q_scale, k_int8, k_scale
199
+
200
+
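A small sketch of the scale shapes produced by ``per_warp_int8``, under the same assumptions as the previous sketch (CUDA device, fp16 inputs, ``sage_attention`` import path).

```python
import torch
from sage_attention.quant import per_warp_int8  # assumed import path for this build

q = torch.randn(2, 4, 512, 64, dtype=torch.float16, device="cuda")  # HND, qo_len = 512
k = torch.randn(2, 4, 512, 64, dtype=torch.float16, device="cuda")

q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k, BLKQ=128, WARPQ=32, BLKK=64)

# (512 // 128) blocks * (128 // 32) warp slices per block = 4 * 4 = 16 scales per head
assert q_scale.shape == (2, 4, 16)
assert k_scale.shape == (2, 4, (512 + 63) // 64)  # (2, 4, 8)
```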
201
+ def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"):
202
+ """
203
+ Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`. Result is stored as fp16.
204
+
205
+ Parameters
206
+ ----------
207
+ v : torch.Tensor
208
+ The input tensor. Shape:
209
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
210
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
211
+
212
+ tensor_layout : str
213
+ The tensor layout, either "HND" or "NHD".
214
+ Default: "HND".
215
+
216
+ Returns
217
+ -------
218
+ Tuple[torch.Tensor, torch.Tensor]
219
+ A tuple containing:
220
+ - The tensor `v_smoothed` with the mean subtracted and stored as fp16. Shape: Same as `v` with `float16` dtype.
221
+ - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with dtype same as `v`.
222
+
223
+ Note
224
+ ----
225
+ - The tensor `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
226
+ - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype.
227
+ - The returned mean tensor will have the same dtype as the input tensor.
228
+ """
229
+
230
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
231
+ vm = v.mean(dim=1 if _tensor_layout == 0 else 2)
232
+
233
+ v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device)
234
+
235
+ # subtract mean and store the result as fp16
236
+ ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout)
237
+
238
+ return v_smoothed, vm
239
+
240
+
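A sketch of the round trip for ``sub_mean``: adding the per-head mean back should approximately reconstruct ``v`` (fp16 rounding aside). Same import-path and CUDA assumptions as above.

```python
import torch
from sage_attention.quant import sub_mean  # assumed import path for this build

v = torch.randn(1, 8, 256, 64, dtype=torch.float16, device="cuda")  # HND layout

v_smoothed, vm = sub_mean(v, tensor_layout="HND")  # vm: [batch, heads, head_dim]

reconstructed = v_smoothed + vm.unsqueeze(2)  # broadcast the mean over the kv_len dimension
assert v_smoothed.dtype == torch.float16
assert torch.allclose(reconstructed.float(), v.float(), atol=1e-2)
```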
241
+ def per_channel_fp8(
242
+ v: torch.Tensor,
243
+ tensor_layout: str = "HND",
244
+ scale_max: float = 448.0,
245
+ smooth_v: bool = True,
246
+ ):
247
+ """
248
+ Transpose, pad and permute the tensor `v` and quantize it to fp8 with per channel quantization.
249
+ `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64.
250
+ After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``.
251
+ The quantization is done per channel, with the scale value and smooth factor calculated per channel.
252
+
253
+ Parameters
254
+ ----------
255
+ v : torch.Tensor
256
+ The input tensor. Shape:
257
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
258
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
259
+
260
+ tensor_layout : str
261
+ The tensor layout, either "HND" or "NHD".
262
+ Default: "HND".
263
+
264
+ scale_max : float
265
+ The maximum scale value for the quantization. Default is 448.0 (upper bound of E4M3 data format).
266
+
267
+ smooth_v : bool
268
+ Whether to smooth the quantized tensor. Default is True.
269
+
270
+ Returns
271
+ -------
272
+ Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
273
+ A tuple containing:
274
+ - The quantized tensor `v_fp8`. Shape:
275
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
276
+ - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
277
+ - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
278
+ - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
279
+
280
+ Note
281
+ ----
282
+ - The tensor `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
283
+ - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``.
284
+ """
285
+
286
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
287
+
288
+ if tensor_layout == "HND":
289
+ b, h_kv, kv_len, head_dim = v.shape
290
+ padded_len = (kv_len + 63) // 64 * 64
291
+ v_transposed_permutted = torch.empty(
292
+ (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device
293
+ )
294
+
295
+ elif tensor_layout == "NHD":
296
+ b, kv_len, h_kv, head_dim = v.shape
297
+ padded_len = (kv_len + 63) // 64 * 64
298
+ v_transposed_permutted = torch.empty(
299
+ (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device
300
+ )
301
+ else:
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
+
302
+ ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout)
303
+
304
+ v_fp8 = torch.empty(
305
+ v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device
306
+ )
307
+
308
+ v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
309
+ vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
310
+
311
+ if smooth_v:
312
+ ops.mean_scale_fuse_quant_cuda(
313
+ v_transposed_permutted,
314
+ v_fp8,
315
+ vm,
316
+ v_scale,
317
+ kv_len,
318
+ scale_max,
319
+ _tensor_layout,
320
+ )
321
+ return v_fp8, v_scale, vm
322
+ else:
323
+ ops.scale_fuse_quant_cuda(
324
+ v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout
325
+ )
326
+ return v_fp8, v_scale, None
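A sketch of the output shapes of ``per_channel_fp8``: the head and sequence dimensions are swapped and the sequence length is padded to a multiple of 64. Same assumptions as the earlier sketches.

```python
import torch
from sage_attention.quant import per_channel_fp8  # assumed import path for this build

v = torch.randn(1, 8, 1000, 64, dtype=torch.float16, device="cuda")  # HND, kv_len = 1000

v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout="HND", smooth_v=True)

padded_len = (1000 + 63) // 64 * 64  # 1024
assert v_fp8.shape == (1, 8, 64, padded_len) and v_fp8.dtype == torch.float8_e4m3fn
assert v_scale.shape == (1, 8, 64)
assert vm.shape == (1, 8, 64)  # vm is None instead when smooth_v=False
```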
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/quant_per_thread.py ADDED
@@ -0,0 +1,204 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import triton
19
+ import triton.language as tl
20
+
21
+ @triton.jit
22
+ def quant_query_per_thread_int8_kernel(Input, Output, Scale, L,
23
+ stride_iz, stride_ih, stride_in,
24
+ stride_oz, stride_oh, stride_on,
25
+ stride_sz, stride_sh,
26
+ C: tl.constexpr, BLK: tl.constexpr):
27
+ off_blk = tl.program_id(0) // 8
28
+ off_tld = tl.program_id(0) % 8
29
+ off_h = tl.program_id(1)
30
+ off_b = tl.program_id(2)
31
+
32
+ offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
33
+ offs_k = tl.arange(0, C)
34
+
35
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
36
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
37
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
38
+
39
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
40
+ x = x.to(tl.float32)
41
+ scale = tl.max(tl.abs(x)) / 127. + 0.0000001
42
+ x_int8 = x / scale
43
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
44
+ x_int8 = x_int8.to(tl.int8)
45
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
46
+ tl.store(scale_ptrs, scale)
47
+
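The kernels in this file scale by ``max|x| / 127`` and round half away from zero by adding ``0.5 * sign(x)`` before the truncating int8 cast. Below is a plain-PyTorch reference of the same quantize/dequantize arithmetic, offered only as a sketch of the scheme, not as the Triton code itself.

```python
import torch

def quantize_int8_reference(x: torch.Tensor):
    """Reference for the scale + round-half-away-from-zero scheme used in the kernels above."""
    x = x.float()
    scale = x.abs().max() / 127.0 + 1e-7  # matches `tl.max(tl.abs(x)) / 127. + 0.0000001`
    q = x / scale
    q = q + 0.5 * torch.where(q >= 0, torch.ones_like(q), -torch.ones_like(q))
    return q.to(torch.int8), scale        # .to(torch.int8) truncates toward zero

x = torch.randn(32, 64)
x_int8, scale = quantize_int8_reference(x)
x_hat = x_int8.float() * scale            # dequantize
assert (x - x_hat).abs().max() <= scale / 2 + 1e-6  # error bounded by half a quantization step
```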
48
+ @triton.jit
49
+ def quant_key_per_thread_int8_kernel(Input, Output, Scale, L,
50
+ stride_iz, stride_ih, stride_in,
51
+ stride_oz, stride_oh, stride_on,
52
+ stride_sz, stride_sh,
53
+ C: tl.constexpr, BLK: tl.constexpr):
54
+ off_blk = tl.program_id(0) // 4
55
+ off_tld = tl.program_id(0) % 4
56
+ off_h = tl.program_id(1)
57
+ off_b = tl.program_id(2)
58
+
59
+ # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
60
+ # offs_k = tl.arange(0, C)
61
+
62
+ # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
63
+ # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
64
+ # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
65
+
66
+ # x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
67
+ # x = x.to(tl.float32)
68
+ # scale = tl.max(tl.abs(x)) / 127. + 0.0000001
69
+ # x_int8 = x / scale
70
+ # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
71
+ # x_int8 = x_int8.to(tl.int8)
72
+ # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
73
+ # tl.store(scale_ptrs, scale)
74
+
75
+ offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2
76
+ offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1
77
+ offs_k = tl.arange(0, C)
78
+
79
+ input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :]
80
+ input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :]
81
+ output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :]
82
+ output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :]
83
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
84
+
85
+ x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L)
86
+ x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L)
87
+ x0 = x0.to(tl.float32)
88
+ x1 = x1.to(tl.float32)
89
+ scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. + 0.0000001
90
+ x0_int8 = x0 / scale
91
+ x1_int8 = x1 / scale
92
+ x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1)
93
+ x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1)
94
+ x0_int8 = x0_int8.to(tl.int8)
95
+ x1_int8 = x1_int8.to(tl.int8)
96
+ tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L)
97
+ tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L)
98
+ tl.store(scale_ptrs, scale)
99
+
100
+ @triton.jit
101
+ def quant_query_per_thread_int4_kernel(Input, Output, Scale, L,
102
+ stride_iz, stride_ih, stride_in,
103
+ stride_oz, stride_oh, stride_on,
104
+ stride_sz, stride_sh,
105
+ C: tl.constexpr, BLK: tl.constexpr):
106
+ off_blk = tl.program_id(0) // 8
107
+ off_tld = tl.program_id(0) % 8
108
+ off_h = tl.program_id(1)
109
+ off_b = tl.program_id(2)
110
+
111
+ offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
112
+ offs_k = tl.arange(0, C)
113
+
114
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
115
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
116
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
117
+
118
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
119
+ x = x.to(tl.float32)
120
+ scale = tl.max(tl.abs(x)) / 7. + 0.0000001
121
+ x_int8 = x / scale
122
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
123
+ x_int8 = x_int8.to(tl.int8)
124
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
125
+ tl.store(scale_ptrs, scale)
126
+
127
+ @triton.jit
128
+ def quant_key_per_thread_int4_kernel(Input, Output, Scale, L,
129
+ stride_iz, stride_ih, stride_in,
130
+ stride_oz, stride_oh, stride_on,
131
+ stride_sz, stride_sh,
132
+ C: tl.constexpr, BLK: tl.constexpr):
133
+ off_blk = tl.program_id(0) // 4
134
+ off_tld = tl.program_id(0) % 4
135
+ off_h = tl.program_id(1)
136
+ off_b = tl.program_id(2)
137
+
138
+ offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
139
+ offs_k = tl.arange(0, C)
140
+
141
+ input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
142
+ output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
143
+ scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
144
+
145
+ x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
146
+ x = x.to(tl.float32)
147
+ scale = tl.max(tl.abs(x)) / 7. + 0.0000001
148
+ x_int8 = x / scale
149
+ x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
150
+ x_int8 = x_int8.to(tl.int8)
151
+ tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
152
+ tl.store(scale_ptrs, scale)
153
+
154
+ def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"):
155
+ q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
156
+ k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
157
+
158
+ if km is not None:
159
+ k = k - km
160
+
161
+ if tensor_layout == "HND":
162
+ b, h_qo, qo_len, head_dim = q.shape
163
+ _, h_kv, kv_len, _ = k.shape
164
+
165
+ stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
166
+ stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
167
+ stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
168
+ stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
169
+ elif tensor_layout == "NHD":
170
+ b, qo_len, h_qo, head_dim = q.shape
171
+ _, kv_len, h_kv, _ = k.shape
172
+
173
+ stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
174
+ stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
175
+ stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
176
+ stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
177
+ else:
178
+ raise ValueError(f"Unknown tensor layout: {tensor_layout}")
179
+
180
+ q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32)
181
+ k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32)
182
+
183
+ if sm_scale is None:
184
+ sm_scale = head_dim**-0.5
185
+
186
+ grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b)
187
+ quant_query_per_thread_int8_kernel[grid](
188
+ q, q_int8, q_scale, qo_len,
189
+ stride_bz_q, stride_h_q, stride_seq_q,
190
+ stride_bz_qo, stride_h_qo, stride_seq_qo,
191
+ q_scale.stride(0), q_scale.stride(1),
192
+ C=head_dim, BLK=WARPQ
193
+ )
194
+
195
+ grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b)
196
+ quant_key_per_thread_int8_kernel[grid](
197
+ k, k_int8, k_scale, kv_len,
198
+ stride_bz_k, stride_h_k, stride_seq_k,
199
+ stride_bz_ko, stride_h_ko, stride_seq_ko,
200
+ k_scale.stride(0), k_scale.stride(1),
201
+ C=head_dim, BLK=WARPK
202
+ )
203
+
204
+ return q_int8, q_scale, k_int8, k_scale
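A usage sketch for this Triton path (requires a CUDA device and Triton; the ``sage_attention`` import path is again an assumption). Each WARPQ slice of ``q`` gets 8 scales and each WARPK slice of ``k`` gets 4, one per thread group.

```python
import torch
from sage_attention.quant_per_thread import per_thread_int8  # assumed import path

q = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")  # HND layout
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")

q_int8, q_scale, k_int8, k_scale = per_thread_int8(q, k, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64)

assert q_scale.shape == (1, 8, (1024 // 128) * (128 // 32) * 8)  # (1, 8, 256)
assert k_scale.shape == (1, 8, (1024 // 64) * (64 // 64) * 4)    # (1, 8, 64)
```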
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
+ from .core import sageattn, sageattn_qk_int8_pv_fp8_cuda
3
+
4
+
5
+ __all__ = [
6
+ "per_block_int8",
7
+ "per_warp_int8",
8
+ "sub_mean",
9
+ "per_channel_fp8",
10
+ "sageattn",
11
+ "sageattn_qk_int8_pv_fp8_cuda",
12
+ ]
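An end-to-end sketch of the public API exported here, assuming a supported GPU (sm80/sm89/sm90/sm120 per the dispatcher in ``core.py``) and that this build directory is importable as ``sage_attention``.

```python
import torch
from sage_attention import sageattn  # assumed: this build directory is on the Python path

# NHD layout: [batch, seq_len, heads, head_dim]
q = torch.randn(1, 2048, 16, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 2048, 16, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 2048, 16, 128, dtype=torch.float16, device="cuda")

out = sageattn(q, k, v, tensor_layout="NHD", is_causal=True)
print(out.shape)  # torch.Size([1, 2048, 16, 128])
```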
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (433 Bytes).
 
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (550 Bytes).
 
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc ADDED
Binary file (33.4 kB).
 
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc ADDED
Binary file (13.4 kB).
 
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc ADDED
Binary file (13 kB).
 
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _sage_attention_44b112f_dirty
3
+ ops = torch.ops._sage_attention_44b112f_dirty
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_sage_attention_44b112f_dirty::{op_name}"
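A small sketch of how this shim is used: ``ops`` exposes the compiled kernels, and ``add_op_namespace_prefix`` builds the fully qualified op name for this particular build (import path assumed, as in the earlier sketches).

```python
from sage_attention._ops import add_op_namespace_prefix, ops  # assumed import path

print(add_op_namespace_prefix("quant_per_block_int8_cuda"))
# -> "_sage_attention_44b112f_dirty::quant_per_block_int8_cuda"

# The same kernels are reachable as attributes of `ops`, e.g. ops.quant_per_block_int8_cuda(...)
print(hasattr(ops, "quant_per_block_int8_cuda"))
```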
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/_sage_attention_44b112f_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:826ab66e6c33b3b2b17c30371934a55e972d560197c5492f4dedf6fcc29f1a1e
3
+ size 26553920
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/core.py ADDED
@@ -0,0 +1,983 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+
20
+ from ._ops import ops
21
+
22
+
23
+ from .quant import per_warp_int8 as per_warp_int8_cuda
24
+ from .quant import sub_mean
25
+ from .quant import per_channel_fp8
26
+ from .quant_per_thread import per_thread_int8 as per_thread_int8_triton
27
+
28
+ from typing import Any, List, Literal, Optional, Tuple, Union
29
+ import warnings
30
+
31
+
32
+ import subprocess
33
+ import re
34
+
35
+
36
+ def get_cuda_version():
37
+ try:
38
+ output = subprocess.check_output(["nvcc", "--version"]).decode()
39
+ match = re.search(r"release (\d+)\.(\d+)", output)
40
+ if match:
41
+ major, minor = int(match.group(1)), int(match.group(2))
42
+ return major, minor
43
+ except Exception as e:
44
+ print("Failed to get CUDA version:", e)
45
+ return None, None
46
+
47
+
48
+ def get_cuda_arch_versions():
49
+ cuda_archs = []
50
+ for i in range(torch.cuda.device_count()):
51
+ major, minor = torch.cuda.get_device_capability(i)
52
+ cuda_archs.append(f"sm{major}{minor}")
53
+ return cuda_archs
54
+
55
+
56
+ def sageattn(
57
+ q: torch.Tensor,
58
+ k: torch.Tensor,
59
+ v: torch.Tensor,
60
+ tensor_layout: str = "HND",
61
+ is_causal: bool = False,
62
+ sm_scale: Optional[float] = None,
63
+ return_lse: bool = False,
64
+ **kwargs: Any,
65
+ ):
66
+ """
67
+ Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability.
68
+
69
+ Parameters
70
+ ----------
71
+ q : torch.Tensor
72
+ The query tensor. Shape:
73
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
74
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
75
+
76
+ k : torch.Tensor
77
+ The key tensor. Shape:
78
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
79
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
80
+
81
+ v : torch.Tensor
82
+ The value tensor. Shape:
83
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
84
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
85
+
86
+ tensor_layout : str
87
+ The tensor layout, either "HND" or "NHD".
88
+ Default: "HND".
89
+
90
+ is_causal : bool
91
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
92
+ Default: False.
93
+
94
+ sm_scale : Optional[float]
95
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
96
+
97
+ return_lse : bool
98
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
99
+ Default: False.
100
+
101
+ Returns
102
+ -------
103
+ torch.Tensor
104
+ The output tensor. Shape:
105
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
106
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
107
+
108
+ torch.Tensor
109
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
110
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
111
+ Only returned if `return_lse` is True.
112
+
113
+ Note
114
+ ----
115
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
116
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
117
+ - All tensors must be on the same cuda device.
118
+ """
119
+
120
+ arch = get_cuda_arch_versions()[q.device.index]
121
+ if arch == "sm80":
122
+ return sageattn_qk_int8_pv_fp16_cuda(
123
+ q,
124
+ k,
125
+ v,
126
+ tensor_layout=tensor_layout,
127
+ is_causal=is_causal,
128
+ sm_scale=sm_scale,
129
+ return_lse=return_lse,
130
+ pv_accum_dtype="fp32",
131
+ )
132
+ elif arch == "sm89":
133
+ return sageattn_qk_int8_pv_fp8_cuda(
134
+ q,
135
+ k,
136
+ v,
137
+ tensor_layout=tensor_layout,
138
+ is_causal=is_causal,
139
+ sm_scale=sm_scale,
140
+ return_lse=return_lse,
141
+ pv_accum_dtype="fp32+fp16",
142
+ )
143
+ elif arch == "sm90":
144
+ return sageattn_qk_int8_pv_fp8_cuda_sm90(
145
+ q,
146
+ k,
147
+ v,
148
+ tensor_layout=tensor_layout,
149
+ is_causal=is_causal,
150
+ sm_scale=sm_scale,
151
+ return_lse=return_lse,
152
+ pv_accum_dtype="fp32+fp32",
153
+ )
154
+ elif arch == "sm120":
155
+ return sageattn_qk_int8_pv_fp8_cuda(
156
+ q,
157
+ k,
158
+ v,
159
+ tensor_layout=tensor_layout,
160
+ is_causal=is_causal,
161
+ qk_quant_gran="per_warp",
162
+ sm_scale=sm_scale,
163
+ return_lse=return_lse,
164
+ pv_accum_dtype="fp32+fp16",
165
+ ) # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
166
+ else:
167
+ raise ValueError(f"Unsupported CUDA architecture: {arch}")
168
+
169
+
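The note above only requires ``num_qo_heads`` to be divisible by ``num_kv_heads``, so grouped-query attention works directly through the dispatcher. A hedged sketch, assuming a supported GPU and the ``sage_attention`` import path:

```python
import torch
from sage_attention import sageattn  # assumed importable as `sage_attention`

# Grouped-query attention: 32 query heads sharing 8 kv heads (32 % 8 == 0), HND layout
q = torch.randn(1, 32, 4096, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 4096, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 4096, 128, dtype=torch.float16, device="cuda")

out, lse = sageattn(q, k, v, tensor_layout="HND", is_causal=True, return_lse=True)
print(out.shape, lse.shape)  # [1, 32, 4096, 128], [1, 32, 4096]
```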
170
+ @torch.compiler.disable
171
+ def sageattn_qk_int8_pv_fp16_cuda(
172
+ q: torch.Tensor,
173
+ k: torch.Tensor,
174
+ v: torch.Tensor,
175
+ tensor_layout: str = "HND",
176
+ is_causal: bool = False,
177
+ qk_quant_gran: str = "per_thread",
178
+ sm_scale: Optional[float] = None,
179
+ pv_accum_dtype: str = "fp32",
180
+ smooth_k: bool = True,
181
+ smooth_v: bool = False,
182
+ return_lse: bool = False,
183
+ **kwargs: Any,
184
+ ) -> torch.Tensor:
185
+ """
186
+ SageAttention with INT8 quantization for Q and K, FP16 PV with FP16/FP32 accumulation, implemented using CUDA.
187
+
188
+ Parameters
189
+ ----------
190
+ q : torch.Tensor
191
+ The query tensor. Shape:
192
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
193
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
194
+
195
+ k : torch.Tensor
196
+ The key tensor. Shape:
197
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
198
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
199
+
200
+ v : torch.Tensor
201
+ The value tensor. Shape:
202
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
203
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
204
+
205
+ tensor_layout : str
206
+ The tensor layout, either "HND" or "NHD".
207
+ Default: "HND".
208
+
209
+ is_causal : bool
210
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
211
+ Default: False.
212
+
213
+ qk_quant_gran : str
214
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
215
+ Default: "per_thread".
216
+
217
+ sm_scale : Optional[float]
218
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
219
+
220
+ pv_accum_dtype : str
221
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32".
222
+ - "fp16": PV accumulation is done fully in FP16. This is the fastest option but may lead to numerical instability. The `smooth_v` option increases accuracy when the value tensor has a large bias (as in CogVideoX-2b).
223
+ - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead.
224
+ - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
225
+ Default: "fp32".
226
+
227
+ smooth_k : bool
228
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
229
+ Default: True.
230
+
231
+ smooth_v : bool
232
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
233
+ smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32".
234
+ Default: False.
235
+
236
+ return_lse : bool
237
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
238
+ Default: False.
239
+
240
+ Returns
241
+ -------
242
+ torch.Tensor
243
+ The output tensor. Shape:
244
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
245
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
246
+
247
+ torch.Tensor
248
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
249
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
250
+ Only returned if `return_lse` is True.
251
+
252
+ Note
253
+ ----
254
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
255
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
256
+ - All tensors must be on the same cuda device.
257
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
258
+ """
259
+
260
+ dtype = q.dtype
261
+ assert q.is_cuda, "Input tensors must be on cuda."
262
+ assert dtype in [torch.float16, torch.bfloat16], (
263
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
264
+ )
265
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
266
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
267
+ )
268
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
269
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
270
+
271
+ # FIXME(DefTruth): make sage attention work compatible with distributed
272
+ # env, for example, xDiT which launch by torchrun. Without this workaround,
273
+ # sage attention will run into illegal memory access error after first
274
+ # inference step in distributed env for multi gpus inference. This small
275
+ # workaround also make sage attention work compatible with torch.compile
276
+ # through non-fullgraph compile mode.
277
+ torch.cuda.set_device(v.device)
278
+
279
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
280
+ _is_caual = 1 if is_causal else 0
281
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
282
+ _return_lse = 1 if return_lse else 0
283
+
284
+ head_dim_og = q.size(-1)
285
+
286
+ if head_dim_og < 64:
287
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
288
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
289
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
290
+ elif head_dim_og > 64 and head_dim_og < 128:
291
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
292
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
293
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
294
+ elif head_dim_og > 128:
295
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
296
+
297
+ # assert last dim is contiguous
298
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
299
+ "Last dim of qkv must be contiguous."
300
+ )
301
+
302
+ if sm_scale is None:
303
+ sm_scale = head_dim_og**-0.5
304
+
305
+ seq_dim = 1 if _tensor_layout == 0 else 2
306
+ nh_dim = 2 if _tensor_layout == 0 else 1
307
+
308
+ if smooth_k:
309
+ km = k.mean(dim=seq_dim, keepdim=True)
310
+ nqheads = q.size(nh_dim)  # number of heads; the head dimension position depends on tensor_layout
311
+ nkheads = k.size(nh_dim)
312
+ q_per_kv_heads = nqheads // nkheads
313
+ if q_per_kv_heads > 1:
314
+ # nheads_k => nheads_q
315
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
316
+ else:
317
+ km_broadcast = km
318
+ if return_lse:
319
+ if tensor_layout == "NHD":
320
+ lse_correction = (
321
+ torch.matmul(
322
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
323
+ )
324
+ .squeeze(-1)
325
+ .to(torch.float32)
326
+ )
327
+ else:
328
+ lse_correction = (
329
+ torch.matmul(q, km_broadcast.transpose(2, 3))
330
+ .squeeze(-1)
331
+ .to(torch.float32)
332
+ )
333
+ else:
334
+ km = None
335
+
336
+ if qk_quant_gran == "per_warp":
337
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
338
+ q,
339
+ k,
340
+ km,
341
+ tensor_layout=tensor_layout,
342
+ BLKQ=128,
343
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
344
+ BLKK=64,
345
+ )
346
+ elif qk_quant_gran == "per_thread":
347
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
348
+ q,
349
+ k,
350
+ km,
351
+ tensor_layout=tensor_layout,
352
+ BLKQ=128,
353
+ WARPQ=(16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32),
354
+ BLKK=64,
355
+ WARPK=64,
356
+ )
357
+
358
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
359
+
360
+ if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v:
361
+ warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.")
362
+ smooth_v = False
363
+
364
+ if pv_accum_dtype == "fp32":
365
+ v = v.to(torch.float16)
366
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(
367
+ q_int8,
368
+ k_int8,
369
+ v,
370
+ o,
371
+ q_scale,
372
+ k_scale,
373
+ _tensor_layout,
374
+ _is_caual,
375
+ _qk_quant_gran,
376
+ sm_scale,
377
+ _return_lse,
378
+ )
379
+ elif pv_accum_dtype == "fp16":
380
+ if smooth_v:
381
+ smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
382
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
383
+ q_int8,
384
+ k_int8,
385
+ smoothed_v,
386
+ o,
387
+ q_scale,
388
+ k_scale,
389
+ vm,
390
+ _tensor_layout,
391
+ _is_caual,
392
+ _qk_quant_gran,
393
+ sm_scale,
394
+ _return_lse,
395
+ )
396
+ else:
397
+ v = v.to(torch.float16)
398
+ lse = ops.qk_int8_sv_f16_accum_f16_attn(  # was `_qattn_sm80` (undefined here)
399
+ q_int8,
400
+ k_int8,
401
+ v,
402
+ o,
403
+ q_scale,
404
+ k_scale,
405
+ _tensor_layout,
406
+ _is_caual,
407
+ _qk_quant_gran,
408
+ sm_scale,
409
+ _return_lse,
410
+ )
411
+ elif pv_accum_dtype == "fp16+fp32":
412
+ v = v.to(torch.float16)
413
+ lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(
414
+ q_int8,
415
+ k_int8,
416
+ v,
417
+ o,
418
+ q_scale,
419
+ k_scale,
420
+ _tensor_layout,
421
+ _is_caual,
422
+ _qk_quant_gran,
423
+ sm_scale,
424
+ _return_lse,
425
+ )
426
+ else:
427
+ raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")
428
+
429
+ o = o[..., :head_dim_og]
430
+
431
+ if return_lse:
432
+ return (
433
+ o,
434
+ lse / 1.44269504 + lse_correction * sm_scale
435
+ if smooth_k
436
+ else lse / 1.44269504,
437
+ )
438
+ else:
439
+ return o
440
+
441
+
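The ``lse / 1.44269504 + lse_correction * sm_scale`` returned above converts the kernel's base-2 logsumexp over ``q . (k - km)`` back to a natural-log logsumexp over ``q . k`` (1.44269504 is log2(e)). A small CPU sketch of the underlying identity, using plain tensors rather than the CUDA kernel:

```python
import torch

torch.manual_seed(0)
q = torch.randn(4, 64)            # 4 query rows
k = torch.randn(128, 64)          # 128 key rows
sm_scale = 64 ** -0.5

km = k.mean(dim=0, keepdim=True)  # smooth_k: subtract the per-channel key mean
lse_smoothed = torch.logsumexp((q @ (k - km).T) * sm_scale, dim=-1)
lse_correction = (q @ km.T).squeeze(-1)   # what the wrapper precomputes as `lse_correction`

lse_true = torch.logsumexp((q @ k.T) * sm_scale, dim=-1)
assert torch.allclose(lse_smoothed + lse_correction * sm_scale, lse_true, atol=1e-5)
```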
442
+ @torch.compiler.disable
443
+ def sageattn_qk_int8_pv_fp8_cuda(
444
+ q: torch.Tensor,
445
+ k: torch.Tensor,
446
+ v: torch.Tensor,
447
+ tensor_layout: str = "HND",
448
+ is_causal: bool = False,
449
+ qk_quant_gran: str = "per_thread",
450
+ sm_scale: Optional[float] = None,
451
+ pv_accum_dtype: str = "fp32+fp16",
452
+ smooth_k: bool = True,
453
+ smooth_v: bool = False,
454
+ return_lse: bool = False,
455
+ **kwargs: Any,
456
+ ) -> torch.Tensor:
457
+ """
458
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
459
+
460
+ Parameters
461
+ ----------
462
+ q : torch.Tensor
463
+ The query tensor. Shape:
464
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
465
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
466
+
467
+ k : torch.Tensor
468
+ The key tensor. Shape:
469
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
470
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
471
+
472
+ v : torch.Tensor
473
+ The value tensor. Shape:
474
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
475
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
476
+
477
+ tensor_layout : str
478
+ The tensor layout, either "HND" or "NHD".
479
+ Default: "HND".
480
+
481
+ is_causal : bool
482
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
483
+ Default: False.
484
+
485
+ qk_quant_gran : str
486
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
487
+ Default: "per_thread".
488
+
489
+ sm_scale : Optional[float]
490
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
491
+
492
+ pv_accum_dtype : str
493
+ The dtype of the accumulation of the product of the value tensor and the attention weights, one of "fp32", "fp32+fp32", or "fp32+fp16".
494
+ - "fp32": PV accumulation is done fully in FP32. However, due to a hardware limitation, only 22 bits of the FP32 accumulator are valid.
495
+ - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to an FP32 buffer every few iterations. This offers a balance between speed and accuracy.
+ - "fp32+fp16": PV accumulation is done in FP16 with a periodic FP32 buffer; in this path the value quantization scale is capped at 2.25 (see ``quant_v_scale_max`` in the body).
496
+ Default: "fp32+fp16".
497
+
498
+ smooth_k : bool
499
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
500
+ Default: True.
501
+
502
+ smooth_v : bool
503
+ Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
504
+ smooth_v will be ignored if pv_accum_dtype is "fp32+fp32".
505
+ Default: False.
506
+
507
+ return_lse : bool
508
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
509
+ Default: False.
510
+
511
+ Returns
512
+ -------
513
+ torch.Tensor
514
+ The output tensor. Shape:
515
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
516
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
517
+
518
+ torch.Tensor
519
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
520
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
521
+ Only returned if `return_lse` is True.
522
+
523
+ Note
524
+ ----
525
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
526
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
527
+ - All tensors must be on the same cuda device.
528
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
529
+ """
530
+
531
+ dtype = q.dtype
532
+ assert q.is_cuda, "Input tensors must be on cuda."
533
+ assert dtype in [torch.float16, torch.bfloat16], (
534
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
535
+ )
536
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
537
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
538
+ )
539
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
540
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
541
+
542
+ # cuda_major_version, cuda_minor_version = get_cuda_version()
543
+ # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16':
544
+ # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'")
545
+ # pv_accum_dtype = 'fp32+fp32'
546
+
547
+ # FIXME(DefTruth): make sage attention work compatible with distributed
548
+ # env, for example, xDiT which launch by torchrun. Without this workaround,
549
+ # sage attention will run into illegal memory access error after first
550
+ # inference step in distributed env for multi gpus inference. This small
551
+ # workaround also make sage attention work compatible with torch.compile
552
+ # through non-fullgraph compile mode.
553
+ torch.cuda.set_device(v.device)
554
+
555
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
556
+ _is_caual = 1 if is_causal else 0
557
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
558
+ _return_lse = 1 if return_lse else 0
559
+
560
+ head_dim_og = q.size(-1)
561
+
562
+ if head_dim_og < 64:
563
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
564
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
565
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
566
+ elif head_dim_og > 64 and head_dim_og < 128:
567
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
568
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
569
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
570
+ elif head_dim_og > 128:
571
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
572
+
573
+ # assert last dim is contiguous
574
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
575
+ "Last dim of qkv must be contiguous."
576
+ )
577
+
578
+ if sm_scale is None:
579
+ sm_scale = head_dim_og**-0.5
580
+
581
+ seq_dim = 1 if _tensor_layout == 0 else 2
582
+ nh_dim = 2 if _tensor_layout == 0 else 1
583
+
584
+ if smooth_k:
585
+ km = k.mean(dim=seq_dim, keepdim=True)
586
+ nqheads = q.size(nh_dim)  # number of heads; the head dimension position depends on tensor_layout
587
+ nkheads = k.size(nh_dim)
588
+ q_per_kv_heads = nqheads // nkheads
589
+ if q_per_kv_heads > 1:
590
+ # nheads_k => nheads_q
591
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
592
+ else:
593
+ km_broadcast = km
594
+ if return_lse:
595
+ if tensor_layout == "NHD":
596
+ lse_correction = (
597
+ torch.matmul(
598
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
599
+ )
600
+ .squeeze(-1)
601
+ .to(torch.float32)
602
+ )
603
+ else:
604
+ lse_correction = (
605
+ torch.matmul(q, km_broadcast.transpose(2, 3))
606
+ .squeeze(-1)
607
+ .to(torch.float32)
608
+ )
609
+ else:
610
+ km = None
611
+
612
+ if qk_quant_gran == "per_warp":
613
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
614
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64
615
+ )
616
+ elif qk_quant_gran == "per_thread":
617
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
618
+ q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64
619
+ )
620
+
621
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
622
+
623
+ if pv_accum_dtype == "fp32+fp32" and smooth_v:
624
+ warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
625
+ smooth_v = False
626
+
627
+ if pv_accum_dtype == "fp32+fp16" and smooth_v:
628
+ warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.")
629
+ smooth_v = False
630
+
631
+ quant_v_scale_max = 448.0
632
+ if pv_accum_dtype == "fp32+fp16":
633
+ quant_v_scale_max = 2.25
634
+
635
+ v_fp8, v_scale, vm = per_channel_fp8(
636
+ v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v
637
+ )
638
+ print("before kernel call")
639
+ if pv_accum_dtype == "fp32":
640
+ if smooth_v:
641
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(
642
+ q_int8,
643
+ k_int8,
644
+ v_fp8,
645
+ o,
646
+ q_scale,
647
+ k_scale,
648
+ v_scale,
649
+ vm,
650
+ _tensor_layout,
651
+ _is_caual,
652
+ _qk_quant_gran,
653
+ sm_scale,
654
+ _return_lse,
655
+ )
656
+ torch.cuda.synchronize()
657
+ else:
658
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
659
+ q_int8,
660
+ k_int8,
661
+ v_fp8,
662
+ o,
663
+ q_scale,
664
+ k_scale,
665
+ v_scale,
666
+ _tensor_layout,
667
+ _is_caual,
668
+ _qk_quant_gran,
669
+ sm_scale,
670
+ _return_lse,
671
+ )
672
+ torch.cuda.synchronize()
673
+ elif pv_accum_dtype == "fp32+fp32":
674
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
675
+ q_int8,
676
+ k_int8,
677
+ v_fp8,
678
+ o,
679
+ q_scale,
680
+ k_scale,
681
+ v_scale,
682
+ _tensor_layout,
683
+ _is_caual,
684
+ _qk_quant_gran,
685
+ sm_scale,
686
+ _return_lse,
687
+ )
688
+ torch.cuda.synchronize()
689
+ elif pv_accum_dtype == "fp32+fp16":
690
+ lse = ops.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(
691
+ q_int8,
692
+ k_int8,
693
+ v_fp8,
694
+ o,
695
+ q_scale,
696
+ k_scale,
697
+ v_scale,
698
+ _tensor_layout,
699
+ _is_caual,
700
+ _qk_quant_gran,
701
+ sm_scale,
702
+ _return_lse,
703
+ )
704
+ torch.cuda.synchronize()
705
+ o = o[..., :head_dim_og]
706
+ print("after kernel call")
707
+ if return_lse:
708
+ return (
709
+ o,
710
+ lse / 1.44269504 + lse_correction * sm_scale
711
+ if smooth_k
712
+ else lse / 1.44269504,
713
+ )
714
+ else:
715
+ return o
716
+
717
+
718
+ @torch.compiler.disable
719
+ def sageattn_qk_int8_pv_fp8_cuda_sm90(
720
+ q: torch.Tensor,
721
+ k: torch.Tensor,
722
+ v: torch.Tensor,
723
+ tensor_layout: str = "HND",
724
+ is_causal: bool = False,
725
+ qk_quant_gran: str = "per_thread",
726
+ sm_scale: Optional[float] = None,
727
+ pv_accum_dtype: str = "fp32+fp32",
728
+ smooth_k: bool = True,
729
+ return_lse: bool = False,
730
+ **kwargs: Any,
731
+ ) -> torch.Tensor:
732
+ """
733
+ SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
734
+
735
+ Parameters
736
+ ----------
737
+ q : torch.Tensor
738
+ The query tensor. Shape:
739
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
740
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
741
+
742
+ k : torch.Tensor
743
+ The key tensor. Shape:
744
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
745
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
746
+
747
+ v : torch.Tensor
748
+ The value tensor. Shape:
749
+ - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
750
+ - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
751
+
752
+ tensor_layout : str
753
+ The tensor layout, either "HND" or "NHD".
754
+ Default: "HND".
755
+
756
+ is_causal : bool
757
+ Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
758
+ Default: False.
759
+
760
+ qk_quant_gran : str
761
+ The granularity of quantization for Q and K, either "per_warp" or "per_thread".
762
+ Default: "per_thread".
763
+
764
+ sm_scale : Optional[float]
765
+ The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
766
+
767
+ pv_accum_dtype : str
768
+ The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
769
+ - "fp32": PV accumulation is done fully in FP32. However, due to a hardware limitation, only 22 bits of the FP32 accumulator are valid.
770
+ - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
771
+ Default: "fp32+fp32".
772
+
773
+ smooth_k : bool
774
+ Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
775
+ Default: True.
776
+
777
+ return_lse : bool
778
+ Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
779
+ Default: False.
780
+
781
+ Returns
782
+ -------
783
+ torch.Tensor
784
+ The output tensor. Shape:
785
+ - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
786
+ - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
787
+
788
+ torch.Tensor
789
+ The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
790
+ Shape: ``[batch_size, num_qo_heads, qo_len]``.
791
+ Only returned if `return_lse` is True.
792
+
793
+ Note
794
+ ----
795
+ - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
796
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
797
+ - All tensors must be on the same cuda device.
798
+ - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
799
+ """
800
+
801
+ dtype = q.dtype
802
+ assert q.is_cuda, "Input tensors must be on cuda."
803
+ assert dtype in [torch.float16, torch.bfloat16], (
804
+ "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
805
+ )
806
+ assert qk_quant_gran in ["per_warp", "per_thread"], (
807
+ "qk_quant_gran must be either 'per_warp' or 'per_thread'."
808
+ )
809
+ assert q.device == k.device == v.device, "All tensors must be on the same device."
810
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
811
+
812
+ torch.cuda.set_device(v.device)
813
+
814
+ _tensor_layout = 0 if tensor_layout == "NHD" else 1
815
+ _is_caual = 1 if is_causal else 0
816
+ _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
817
+ _return_lse = 1 if return_lse else 0
818
+
819
+ head_dim_og = q.size(-1)
820
+
821
+ if head_dim_og < 64:
822
+ q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
823
+ k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
824
+ v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
825
+ elif head_dim_og > 64 and head_dim_og < 128:
826
+ q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
827
+ k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
828
+ v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
829
+ elif head_dim_og > 128:
830
+ raise ValueError(f"Unsupported head_dim: {head_dim_og}")
831
+
832
+ # assert last dim is contiguous
833
+ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, (
834
+ "Last dim of qkv must be contiguous."
835
+ )
836
+
837
+ if sm_scale is None:
838
+ sm_scale = head_dim_og**-0.5
839
+
840
+ seq_dim = 1 if _tensor_layout == 0 else 2
841
+ nh_dim = 2 if _tensor_layout == 0 else 1
842
+
843
+ if smooth_k:
844
+ km = k.mean(dim=seq_dim, keepdim=True)
845
+ nqheads = q.size(nh_dim)  # number of heads; the head dimension position depends on tensor_layout
846
+ nkheads = k.size(nh_dim)
847
+ q_per_kv_heads = nqheads // nkheads
848
+ if q_per_kv_heads > 1:
849
+ # nheads_k => nheads_q
850
+ km_broadcast = torch.repeat_interleave(km, q_per_kv_heads, dim=nh_dim)
851
+ else:
852
+ km_broadcast = km
853
+ if return_lse:
854
+ if tensor_layout == "NHD":
855
+ lse_correction = (
856
+ torch.matmul(
857
+ q.transpose(1, 2), km_broadcast.transpose(1, 2).transpose(2, 3)
858
+ )
859
+ .squeeze(-1)
860
+ .to(torch.float32)
861
+ )
862
+ else:
863
+ lse_correction = (
864
+ torch.matmul(q, km_broadcast.transpose(2, 3))
865
+ .squeeze(-1)
866
+ .to(torch.float32)
867
+ )
868
+ else:
869
+ km = None
870
+
871
+ if qk_quant_gran == "per_warp":
872
+ q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(
873
+ q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128
874
+ )
875
+ elif qk_quant_gran == "per_thread":
876
+ q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(
877
+ q,
878
+ k,
879
+ km,
880
+ tensor_layout=tensor_layout,
881
+ BLKQ=64,
882
+ WARPQ=16,
883
+ BLKK=128,
884
+ WARPK=128,
885
+ )
886
+
887
+ o = torch.empty(q.size(), dtype=dtype, device=q.device)
888
+
889
+ # pad v to multiple of 128
890
+ # TODO: modify per_channel_fp8 kernel to handle this
891
+ kv_len = k.size(seq_dim)
892
+ v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
893
+ if v_pad_len > 0:
894
+ if tensor_layout == "HND":
895
+ v = torch.cat(
896
+ [
897
+ v,
898
+ torch.zeros(
899
+ v.size(0),
900
+ v.size(1),
901
+ v_pad_len,
902
+ v.size(3),
903
+ dtype=v.dtype,
904
+ device=v.device,
905
+ ),
906
+ ],
907
+ dim=2,
908
+ )
909
+ else:
910
+ v = torch.cat(
911
+ [
912
+ v,
913
+ torch.zeros(
914
+ v.size(0),
915
+ v_pad_len,
916
+ v.size(2),
917
+ v.size(3),
918
+ dtype=v.dtype,
919
+ device=v.device,
920
+ ),
921
+ ],
922
+ dim=1,
923
+ )
924
+
925
+ v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False)
926
+
927
+ if pv_accum_dtype == "fp32":
928
+ raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.")
929
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
930
+ q_int8,
931
+ k_int8,
932
+ v_fp8,
933
+ o,
934
+ q_scale,
935
+ k_scale,
936
+ v_scale,
937
+ _tensor_layout,
938
+ _is_caual,
939
+ _qk_quant_gran,
940
+ sm_scale,
941
+ _return_lse,
942
+ )
943
+ elif pv_accum_dtype == "fp32+fp32":
944
+ print(
945
+ "qint8",
946
+ q_int8.shape,
947
+ "qscale",
948
+ q_scale.shape,
949
+ "kint8",
950
+ k_int8.shape,
951
+ "kscale",
952
+ k_scale.shape,
953
+ "vfp8",
954
+ v_fp8.shape,
955
+ "vscale",
956
+ v_scale.shape,
957
+ )
958
+ lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
959
+ q_int8,
960
+ k_int8,
961
+ v_fp8,
962
+ o,
963
+ q_scale,
964
+ k_scale,
965
+ v_scale,
966
+ _tensor_layout,
967
+ _is_caual,
968
+ _qk_quant_gran,
969
+ sm_scale,
970
+ _return_lse,
971
+ )
972
+
973
+ o = o[..., :head_dim_og]
974
+
975
+ if return_lse:
976
+ return (
977
+ o,
978
+ lse / 1.44269504 + lse_correction * sm_scale
979
+ if smooth_k
980
+ else lse / 1.44269504,
981
+ )
982
+ else:
983
+ return o
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/layers.py ADDED
File without changes
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/quant.py ADDED
@@ -0,0 +1,326 @@
1
+ """
2
+ Copyright (c) 2024 by SageAttention team.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import torch
+ from typing import Optional
+
+ from ._ops import ops
+
+
+ def per_block_int8(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     km: Optional[torch.Tensor] = None,
+     BLKQ: int = 128,
+     BLKK: int = 64,
+     sm_scale: Optional[float] = None,
+     tensor_layout: str = "HND",
+ ):
+     """
+     Quantize the query tensor `q` and the key tensor `k` with per-block quantization.
+
+     Parameters
+     ----------
+     q : torch.Tensor
+         The query tensor. Shape:
+         - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
+         - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
+
+     k : torch.Tensor
+         The key tensor. Shape:
+         - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+         - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+     km : Optional[torch.Tensor]
+         The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
+         Should be of the same dtype as `k` if provided. Default is None.
+
+     sm_scale : Optional[float]
+         The scale factor for the softmax operation. Default is ``head_dim**-0.5``.
+         It will be multiplied by ``1.44269504`` (log2(e)) to work together with the Triton attention kernel.
+
+     tensor_layout : str
+         The tensor layout, either "HND" or "NHD".
+         Default: "HND".
+
+     Returns
+     -------
+     Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
+         A tuple containing:
+         - The quantized query tensor. Shape: same as `q` but with `int8` dtype.
+         - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, (qo_len + BLKQ - 1) // BLKQ]`` with `float32` dtype.
+         - The quantized key tensor. Shape: same as `k` but with `int8` dtype.
+         - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
+
+     Note
+     ----
+     - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
+     """
+
+     q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
+     k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
+
+     if tensor_layout == "HND":
+         b, h_qo, qo_len, head_dim = q.shape
+         _, h_kv, kv_len, _ = k.shape
+
+     elif tensor_layout == "NHD":
+         b, qo_len, h_qo, head_dim = q.shape
+         _, kv_len, h_kv, _ = k.shape
+
+     else:
+         raise ValueError(f"Unknown tensor layout: {tensor_layout}")
+
+     _tensor_layout = 0 if tensor_layout == "NHD" else 1
+
+     q_scale = torch.empty(
+         (b, h_qo, (qo_len + BLKQ - 1) // BLKQ), device=q.device, dtype=torch.float32
+     )
+     k_scale = torch.empty(
+         (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
+     )
+
+     if sm_scale is None:
+         sm_scale = head_dim**-0.5
+
+     sm_scale *= 1.44269504
+
+     ops.quant_per_block_int8_cuda(q, q_int8, q_scale, sm_scale, BLKQ, _tensor_layout)
+     if km is not None:
+         km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
+         ops.quant_per_block_int8_fuse_sub_mean_cuda(
+             k, km, k_int8, k_scale, BLKK, _tensor_layout
+         )
+     else:
+         # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
+         ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
+
+     return q_int8, q_scale, k_int8, k_scale
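A minimal usage sketch (not part of the commit), assuming a CUDA device, hypothetical shapes (batch 1, 8 heads, 1024 tokens, head_dim 128), and the module path `sage_attention.quant` from this build tree:

    import torch
    from sage_attention.quant import per_block_int8  # assumed import path for this build

    q = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")

    q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, tensor_layout="HND")
    # With the defaults BLKQ=128 and BLKK=64, the docstring above gives
    # q_scale: [1, 8, ceil(1024/128)] = [1, 8, 8] and k_scale: [1, 8, ceil(1024/64)] = [1, 8, 16]
    assert q_scale.shape == (1, 8, 8) and k_scale.shape == (1, 8, 16)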
+
+
+ def per_warp_int8(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     km: Optional[torch.Tensor] = None,
+     BLKQ: int = 128,
+     WARPQ: int = 32,
+     BLKK: int = 64,
+     tensor_layout: str = "HND",
+ ):
+     """
+     Quantize the query tensor `q` with per-warp quantization and the key tensor `k` with per-block quantization.
+     The warp size used to quantize `q` is 16 or 32, with a block size of 64 or 128.
+     The block size used to quantize `k` is 64 or 128.
+
+     Parameters
+     ----------
+     q : torch.Tensor
+         The query tensor. Shape:
+         - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
+         - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
+
+     k : torch.Tensor
+         The key tensor. Shape:
+         - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+         - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+     km : Optional[torch.Tensor]
+         The mean tensor of `k` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``.
+         Should be of the same dtype as `k` if provided. Default is None.
+
+     tensor_layout : str
+         The tensor layout, either "HND" or "NHD".
+         Default: "HND".
+
+     Returns
+     -------
+     Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
+         A tuple containing:
+         - The quantized query tensor. Shape: same as `q` but with `int8` dtype.
+         - The scale tensor of the query tensor. Shape: ``[batch_size, num_qo_heads, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)]`` with `float32` dtype.
+         - The quantized key tensor. Shape: same as `k` but with `int8` dtype.
+         - The scale tensor of the key tensor. Shape: ``[batch_size, num_kv_heads, (kv_len + BLKK - 1) // BLKK]`` with `float32` dtype.
+
+     Note
+     ----
+     - The tensors `q` and `k` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
+     """
+
+     q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
+     k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
+
+     if tensor_layout == "HND":
+         b, h_qo, qo_len, head_dim = q.shape
+         _, h_kv, kv_len, _ = k.shape
+
+     elif tensor_layout == "NHD":
+         b, qo_len, h_qo, head_dim = q.shape
+         _, kv_len, h_kv, _ = k.shape
+
+     else:
+         raise ValueError(f"Unknown tensor layout: {tensor_layout}")
+
+     _tensor_layout = 0 if tensor_layout == "NHD" else 1
+
+     q_scale = torch.empty(
+         (b, h_qo, ((qo_len + BLKQ - 1) // BLKQ) * (BLKQ // WARPQ)),
+         device=q.device,
+         dtype=torch.float32,
+     )
+     k_scale = torch.empty(
+         (b, h_kv, (kv_len + BLKK - 1) // BLKK), device=q.device, dtype=torch.float32
+     )
+
+     ops.quant_per_warp_int8_cuda(q, q_int8, q_scale, BLKQ, WARPQ, _tensor_layout)
+
+     if km is not None:
+         km = km.squeeze(1) if _tensor_layout == 0 else km.squeeze(2)
+         ops.quant_per_block_int8_fuse_sub_mean_cuda(
+             k, km, k_int8, k_scale, BLKK, _tensor_layout
+         )
+     else:
+         # The bound CUDA op expects an sm_scale argument; use 1.0 for K to avoid scaling
+         ops.quant_per_block_int8_cuda(k, k_int8, k_scale, 1.0, BLKK, _tensor_layout)
+
+     return q_int8, q_scale, k_int8, k_scale
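A companion sketch for `per_warp_int8` under the same hypothetical shapes and assumed import path; only the query scale gains the extra per-warp factor `BLKQ // WARPQ`:

    import torch
    from sage_attention.quant import per_warp_int8  # assumed import path for this build

    q = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")

    q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k, BLKQ=128, WARPQ=32, BLKK=64)
    # ceil(1024/128) * (128 // 32) = 8 * 4 = 32 query scales per head,
    # ceil(1024/64) = 16 key scales per head
    assert q_scale.shape == (1, 8, 32) and k_scale.shape == (1, 8, 16)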
+
+
+ def sub_mean(v: torch.Tensor, tensor_layout: str = "HND"):
+     """
+     Calculate the mean of the tensor `v` along the sequence length dimension and subtract it from `v`.
+     The result is stored as fp16.
+
+     Parameters
+     ----------
+     v : torch.Tensor
+         The input tensor. Shape:
+         - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+         - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+     tensor_layout : str
+         The tensor layout, either "HND" or "NHD".
+         Default: "HND".
+
+     Returns
+     -------
+     Tuple[torch.Tensor, torch.Tensor]
+         A tuple containing:
+         - The tensor `v_smoothed` with the mean subtracted, stored as fp16. Shape: same as `v`, with `float16` dtype.
+         - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]``, with the same dtype as `v`.
+
+     Note
+     ----
+     - The tensor `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
+     - The returned tensor `v_smoothed` will have dtype ``torch.float16`` regardless of the input dtype.
+     - The returned mean tensor will have the same dtype as the input tensor.
+     """
+
+     _tensor_layout = 0 if tensor_layout == "NHD" else 1
+     vm = v.mean(dim=1 if _tensor_layout == 0 else 2)
+
+     v_smoothed = torch.empty(v.shape, dtype=torch.float16, device=v.device)
+
+     # subtract the mean and store the result as fp16
+     ops.sub_mean_cuda(v, vm, v_smoothed, _tensor_layout)
+
+     return v_smoothed, vm
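A minimal sketch for `sub_mean`, assuming hypothetical HND shapes and the same assumed import path; it only illustrates the dtypes and shapes documented above:

    import torch
    from sage_attention.quant import sub_mean  # assumed import path for this build

    v = torch.randn(1, 8, 1024, 128, dtype=torch.bfloat16, device="cuda")
    v_smoothed, vm = sub_mean(v, tensor_layout="HND")
    # v_smoothed is always fp16; vm keeps v's dtype and drops the sequence dimension
    assert v_smoothed.dtype == torch.float16 and vm.shape == (1, 8, 128)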
+
+
+ def per_channel_fp8(
+     v: torch.Tensor,
+     tensor_layout: str = "HND",
+     scale_max: float = 448.0,
+     smooth_v: bool = True,
+ ):
+     """
+     Transpose, pad and permute the tensor `v`, then quantize it to fp8 with per-channel quantization.
+     `v` is first transposed along the head dimension and the sequence length dimension, then padded to a multiple of 64.
+     After that, the tensor is permuted along the sequence length dimension by ``[0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]``.
+     The quantization is done per channel, with the scale value and smoothing factor computed per channel.
+
+     Parameters
+     ----------
+     v : torch.Tensor
+         The input tensor. Shape:
+         - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
+         - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
+
+     tensor_layout : str
+         The tensor layout, either "HND" or "NHD".
+         Default: "HND".
+
+     scale_max : float
+         The maximum scale value for the quantization. Default is 448.0 (the upper bound of the E4M3 data format).
+
+     smooth_v : bool
+         Whether to smooth the quantized tensor. Default is True.
+
+     Returns
+     -------
+     Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
+         A tuple containing:
+         - The quantized tensor `v_fp8`. Shape:
+             - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, head_dim, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
+             - If `tensor_layout` is "NHD": ``[batch_size, head_dim, num_kv_heads, (kv_len + 63) // 64 * 64]``, with `float8_e4m3fn` dtype.
+         - The scale tensor of `v`. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
+         - The mean tensor of `v` along the sequence length dimension. Shape: ``[batch_size, num_kv_heads, head_dim]`` with `float32` dtype.
+
+     Note
+     ----
+     - The tensor `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
+     - The returned mean tensor will be None if `smooth_v` is False. Otherwise it will have dtype ``torch.float32``.
+     """
+
+     _tensor_layout = 0 if tensor_layout == "NHD" else 1
+
+     if tensor_layout == "HND":
+         b, h_kv, kv_len, head_dim = v.shape
+         padded_len = (kv_len + 63) // 64 * 64
+         v_transposed_permutted = torch.empty(
+             (b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device
+         )
+
+     elif tensor_layout == "NHD":
+         b, kv_len, h_kv, head_dim = v.shape
+         padded_len = (kv_len + 63) // 64 * 64
+         v_transposed_permutted = torch.empty(
+             (b, head_dim, h_kv, padded_len), dtype=v.dtype, device=v.device
+         )
+
+     else:
+         raise ValueError(f"Unknown tensor layout: {tensor_layout}")
+
+     ops.transpose_pad_permute_cuda(v, v_transposed_permutted, _tensor_layout)
+
+     v_fp8 = torch.empty(
+         v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device
+     )
+
+     v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
+     vm = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
+
+     if smooth_v:
+         ops.mean_scale_fuse_quant_cuda(
+             v_transposed_permutted,
+             v_fp8,
+             vm,
+             v_scale,
+             kv_len,
+             scale_max,
+             _tensor_layout,
+         )
+         return v_fp8, v_scale, vm
+     else:
+         ops.scale_fuse_quant_cuda(
+             v_transposed_permutted, v_fp8, v_scale, kv_len, scale_max, _tensor_layout
+         )
+         return v_fp8, v_scale, None
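A sketch for `per_channel_fp8` with hypothetical shapes and the same assumed import path; a `kv_len` of 1000 is used so the padding to the next multiple of 64 (1024) is visible:

    import torch
    from sage_attention.quant import per_channel_fp8  # assumed import path for this build

    v = torch.randn(1, 8, 1000, 128, dtype=torch.float16, device="cuda")
    v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout="HND", smooth_v=True)
    # kv_len=1000 is padded up to 1024; the HND output layout is [b, h_kv, head_dim, padded_len]
    assert v_fp8.shape == (1, 8, 128, 1024) and v_fp8.dtype == torch.float8_e4m3fn
    assert v_scale.shape == (1, 8, 128) and vm is not None

    # With smooth_v=False the mean slot of the returned tuple is None
    v_fp8, v_scale, vm = per_channel_fp8(v, smooth_v=False)
    assert vm is None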
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/quant_per_thread.py ADDED
@@ -0,0 +1,204 @@
+ """
+ Copyright (c) 2024 by SageAttention team.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ @triton.jit
+ def quant_query_per_thread_int8_kernel(Input, Output, Scale, L,
+                                        stride_iz, stride_ih, stride_in,
+                                        stride_oz, stride_oh, stride_on,
+                                        stride_sz, stride_sh,
+                                        C: tl.constexpr, BLK: tl.constexpr):
+     off_blk = tl.program_id(0) // 8
+     off_tld = tl.program_id(0) % 8
+     off_h = tl.program_id(1)
+     off_b = tl.program_id(2)
+
+     offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
+     offs_k = tl.arange(0, C)
+
+     input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
+     output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
+     scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
+
+     x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
+     x = x.to(tl.float32)
+     scale = tl.max(tl.abs(x)) / 127. + 0.0000001
+     x_int8 = x / scale
+     x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
+     x_int8 = x_int8.to(tl.int8)
+     tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
+     tl.store(scale_ptrs, scale)
+
+ @triton.jit
+ def quant_key_per_thread_int8_kernel(Input, Output, Scale, L,
+                                      stride_iz, stride_ih, stride_in,
+                                      stride_oz, stride_oh, stride_on,
+                                      stride_sz, stride_sh,
+                                      C: tl.constexpr, BLK: tl.constexpr):
+     off_blk = tl.program_id(0) // 4
+     off_tld = tl.program_id(0) % 4
+     off_h = tl.program_id(1)
+     off_b = tl.program_id(2)
+
+     # offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
+     # offs_k = tl.arange(0, C)
+
+     # input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
+     # output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
+     # scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
+
+     # x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
+     # x = x.to(tl.float32)
+     # scale = tl.max(tl.abs(x)) / 127. + 0.0000001
+     # x_int8 = x / scale
+     # x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
+     # x_int8 = x_int8.to(tl.int8)
+     # tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
+     # tl.store(scale_ptrs, scale)
+
+     offs_n0 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2
+     offs_n1 = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld * 2 + 1
+     offs_k = tl.arange(0, C)
+
+     input_ptrs0 = Input + off_b * stride_iz + off_h * stride_ih + offs_n0[:, None] * stride_in + offs_k[None, :]
+     input_ptrs1 = Input + off_b * stride_iz + off_h * stride_ih + offs_n1[:, None] * stride_in + offs_k[None, :]
+     output_ptrs0 = Output + off_b * stride_oz + off_h * stride_oh + offs_n0[:, None] * stride_on + offs_k[None, :]
+     output_ptrs1 = Output + off_b * stride_oz + off_h * stride_oh + offs_n1[:, None] * stride_on + offs_k[None, :]
+     scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
+
+     x0 = tl.load(input_ptrs0, mask=offs_n0[:, None] < L)
+     x1 = tl.load(input_ptrs1, mask=offs_n1[:, None] < L)
+     x0 = x0.to(tl.float32)
+     x1 = x1.to(tl.float32)
+     scale = max(tl.max(tl.abs(x0)), tl.max(tl.abs(x1))) / 127. + 0.0000001
+     x0_int8 = x0 / scale
+     x1_int8 = x1 / scale
+     x0_int8 += 0.5 * tl.where(x0_int8 >= 0, 1, -1)
+     x1_int8 += 0.5 * tl.where(x1_int8 >= 0, 1, -1)
+     x0_int8 = x0_int8.to(tl.int8)
+     x1_int8 = x1_int8.to(tl.int8)
+     tl.store(output_ptrs0, x0_int8, mask=offs_n0[:, None] < L)
+     tl.store(output_ptrs1, x1_int8, mask=offs_n1[:, None] < L)
+     tl.store(scale_ptrs, scale)
+
+ @triton.jit
+ def quant_query_per_thread_int4_kernel(Input, Output, Scale, L,
+                                        stride_iz, stride_ih, stride_in,
+                                        stride_oz, stride_oh, stride_on,
+                                        stride_sz, stride_sh,
+                                        C: tl.constexpr, BLK: tl.constexpr):
+     off_blk = tl.program_id(0) // 8
+     off_tld = tl.program_id(0) % 8
+     off_h = tl.program_id(1)
+     off_b = tl.program_id(2)
+
+     offs_n = off_blk * BLK + tl.arange(0, BLK // 8) * 8 + off_tld
+     offs_k = tl.arange(0, C)
+
+     input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
+     output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
+     scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 8 + off_tld
+
+     x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
+     x = x.to(tl.float32)
+     scale = tl.max(tl.abs(x)) / 7. + 0.0000001
+     x_int8 = x / scale
+     x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
+     x_int8 = x_int8.to(tl.int8)
+     tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
+     tl.store(scale_ptrs, scale)
+
+ @triton.jit
+ def quant_key_per_thread_int4_kernel(Input, Output, Scale, L,
+                                      stride_iz, stride_ih, stride_in,
+                                      stride_oz, stride_oh, stride_on,
+                                      stride_sz, stride_sh,
+                                      C: tl.constexpr, BLK: tl.constexpr):
+     off_blk = tl.program_id(0) // 4
+     off_tld = tl.program_id(0) % 4
+     off_h = tl.program_id(1)
+     off_b = tl.program_id(2)
+
+     offs_n = off_blk * BLK + tl.cat(tl.arange(0, BLK // 8) * 8, tl.arange(0, BLK // 8) * 8 + 1, True) + off_tld * 2
+     offs_k = tl.arange(0, C)
+
+     input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
+     output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
+     scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk * 4 + off_tld
+
+     x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
+     x = x.to(tl.float32)
+     scale = tl.max(tl.abs(x)) / 7. + 0.0000001
+     x_int8 = x / scale
+     x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
+     x_int8 = x_int8.to(tl.int8)
+     tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
+     tl.store(scale_ptrs, scale)
+
+ def per_thread_int8(q, k, km=None, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64, sm_scale=None, tensor_layout="HND"):
+     q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
+     k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
+
+     if km is not None:
+         k = k - km
+
+     if tensor_layout == "HND":
+         b, h_qo, qo_len, head_dim = q.shape
+         _, h_kv, kv_len, _ = k.shape
+
+         stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
+         stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
+         stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
+         stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
+     elif tensor_layout == "NHD":
+         b, qo_len, h_qo, head_dim = q.shape
+         _, kv_len, h_kv, _ = k.shape
+
+         stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
+         stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
+         stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
+         stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
+     else:
+         raise ValueError(f"Unknown tensor layout: {tensor_layout}")
+
+     q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8), device=q.device, dtype=torch.float32)
+     k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4), device=q.device, dtype=torch.float32)
+
+     if sm_scale is None:
+         sm_scale = head_dim**-0.5
+
+     grid = ((qo_len + BLKQ - 1) // BLKQ * (BLKQ // WARPQ) * 8, h_qo, b)
+     quant_query_per_thread_int8_kernel[grid](
+         q, q_int8, q_scale, qo_len,
+         stride_bz_q, stride_h_q, stride_seq_q,
+         stride_bz_qo, stride_h_qo, stride_seq_qo,
+         q_scale.stride(0), q_scale.stride(1),
+         C=head_dim, BLK=WARPQ
+     )
+
+     grid = ((kv_len + BLKK - 1) // BLKK * (BLKK // WARPK) * 4, h_kv, b)
+     quant_key_per_thread_int8_kernel[grid](
+         k, k_int8, k_scale, kv_len,
+         stride_bz_k, stride_h_k, stride_seq_k,
+         stride_bz_ko, stride_h_ko, stride_seq_ko,
+         k_scale.stride(0), k_scale.stride(1),
+         C=head_dim, BLK=WARPK
+     )
+
+     return q_int8, q_scale, k_int8, k_scale
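A sketch for the Triton per-thread path, with hypothetical shapes chosen so the sequence lengths are multiples of the block sizes and the same assumed import convention. Note that in this code path `sm_scale` is computed when left as None but is not folded into the quantized values by the kernels above.

    import torch
    from sage_attention.quant_per_thread import per_thread_int8  # assumed import path

    q = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")

    q_int8, q_scale, k_int8, k_scale = per_thread_int8(q, k, tensor_layout="HND")
    # ceil(1024/128) * (128//32) * 8 = 256 query scales per head,
    # ceil(1024/64) * (64//64) * 4 = 64 key scales per head
    assert q_scale.shape == (1, 8, 256) and k_scale.shape == (1, 8, 64)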