medmekk HF Staff commited on
Commit
a8031ce
·
1 Parent(s): 1369690

update builds

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. CMakeLists.txt +0 -325
  2. build.toml +6 -6
  3. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__init__.py +1 -2
  4. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc +0 -0
  5. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc +0 -0
  6. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc +0 -0
  7. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc +0 -0
  8. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc +0 -0
  9. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/_ops.py +3 -3
  10. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/{_sage_attention_af2d0c0_dirty.abi3.so → _sage_attention_1369690_dirty.abi3.so} +1 -1
  11. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/core.py +4 -18
  12. build/torch27-cxx11-cu126-x86_64-linux/sage_attention/layers.py +0 -0
  13. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__init__.py +1 -2
  14. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc +0 -0
  15. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc +0 -0
  16. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc +0 -0
  17. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc +0 -0
  18. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc +0 -0
  19. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/_ops.py +3 -3
  20. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/{_sage_attention_af2d0c0_dirty.abi3.so → _sage_attention_1369690_dirty.abi3.so} +1 -1
  21. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/core.py +4 -18
  22. build/torch27-cxx11-cu128-x86_64-linux/sage_attention/layers.py +0 -0
  23. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__init__.py +1 -2
  24. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc +0 -0
  25. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc +0 -0
  26. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc +0 -0
  27. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc +0 -0
  28. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc +0 -0
  29. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/_ops.py +3 -3
  30. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/{_sage_attention_af2d0c0_dirty.abi3.so → _sage_attention_1369690_dirty.abi3.so} +1 -1
  31. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/core.py +4 -18
  32. build/torch28-cxx11-cu126-x86_64-linux/sage_attention/layers.py +0 -0
  33. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__init__.py +1 -2
  34. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc +0 -0
  35. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc +0 -0
  36. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc +0 -0
  37. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc +0 -0
  38. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc +0 -0
  39. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/_ops.py +3 -3
  40. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/{_sage_attention_af2d0c0_dirty.abi3.so → _sage_attention_1369690_dirty.abi3.so} +2 -2
  41. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/core.py +4 -18
  42. build/torch28-cxx11-cu128-x86_64-linux/sage_attention/layers.py +0 -0
  43. build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__init__.py +1 -2
  44. build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc +0 -0
  45. build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc +0 -0
  46. build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc +0 -0
  47. build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc +0 -0
  48. build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc +0 -0
  49. build/torch28-cxx11-cu129-x86_64-linux/sage_attention/_ops.py +3 -3
  50. build/torch28-cxx11-cu129-x86_64-linux/sage_attention/{_sage_attention_af2d0c0_dirty.abi3.so → _sage_attention_1369690_dirty.abi3.so} +1 -1
CMakeLists.txt DELETED
@@ -1,325 +0,0 @@
1
- cmake_minimum_required(VERSION 3.26)
2
- project(sage_attention LANGUAGES CXX)
3
-
4
- set(TARGET_DEVICE "cuda" CACHE STRING "Target device backend for kernel")
5
-
6
- install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
7
-
8
- include(FetchContent)
9
- file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
10
- message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
11
-
12
- set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
13
-
14
- set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101")
15
-
16
- include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
17
-
18
- if(DEFINED Python_EXECUTABLE)
19
- # Allow passing through the interpreter (e.g. from setup.py).
20
- find_package(Python COMPONENTS Development Development.SABIModule Interpreter)
21
- if (NOT Python_FOUND)
22
- message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
23
- endif()
24
- else()
25
- find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
26
- endif()
27
-
28
- append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
29
-
30
- find_package(Torch REQUIRED)
31
-
32
- if (NOT TARGET_DEVICE STREQUAL "cuda" AND
33
- NOT TARGET_DEVICE STREQUAL "rocm")
34
- return()
35
- endif()
36
-
37
- if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
38
- CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
39
- set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0+PTX")
40
- else()
41
- set(CUDA_DEFAULT_KERNEL_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0+PTX")
42
- endif()
43
-
44
- if (NOT HIP_FOUND AND CUDA_FOUND)
45
- set(GPU_LANG "CUDA")
46
-
47
-
48
-
49
- elseif(HIP_FOUND)
50
- set(GPU_LANG "HIP")
51
-
52
- # Importing torch recognizes and sets up some HIP/ROCm configuration but does
53
- # not let cmake recognize .hip files. In order to get cmake to understand the
54
- # .hip extension automatically, HIP must be enabled explicitly.
55
- enable_language(HIP)
56
- else()
57
- message(FATAL_ERROR "Can't find CUDA or HIP installation.")
58
- endif()
59
-
60
-
61
- if(GPU_LANG STREQUAL "CUDA")
62
- clear_cuda_arches(CUDA_ARCH_FLAGS)
63
- extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
64
- message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")
65
- # Filter the target architectures by the supported supported archs
66
- # since for some files we will build for all CUDA_ARCHS.
67
- cuda_archs_loose_intersection(CUDA_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
68
- message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")
69
-
70
- if(NVCC_THREADS AND GPU_LANG STREQUAL "CUDA")
71
- list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}")
72
- endif()
73
-
74
- add_compile_definitions(CUDA_KERNEL)
75
- elseif(GPU_LANG STREQUAL "HIP")
76
- set(ROCM_ARCHS "${HIP_SUPPORTED_ARCHS}")
77
- # TODO: remove this once we can set specific archs per source file set.
78
- override_gpu_arches(GPU_ARCHES
79
- ${GPU_LANG}
80
- "${${GPU_LANG}_SUPPORTED_ARCHS}")
81
-
82
- add_compile_definitions(ROCM_KERNEL)
83
- else()
84
- override_gpu_arches(GPU_ARCHES
85
- ${GPU_LANG}
86
- "${${GPU_LANG}_SUPPORTED_ARCHS}")
87
- endif()
88
-
89
- get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
90
- list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})
91
-
92
- set(TORCH_sage_attention_SRC
93
- torch-ext/torch_binding.cpp torch-ext/torch_binding.h
94
- )
95
-
96
-
97
- list(APPEND SRC "${TORCH_sage_attention_SRC}")
98
-
99
-
100
- set(_qattn_sm80_SRC
101
- "sage_attention/qattn/qk_int_sv_f16_cuda_sm80.cu"
102
- "sage_attention/qattn/attn_cuda_sm80.h"
103
- "sage_attention/qattn/attn_utils.cuh"
104
- )
105
-
106
- # TODO: check if CLion support this:
107
- # https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
108
- set_source_files_properties(
109
- ${_qattn_sm80_SRC}
110
- PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/.")
111
-
112
- if(GPU_LANG STREQUAL "CUDA")
113
- cuda_archs_loose_intersection(_qattn_sm80_ARCHS "8.0" "${CUDA_ARCHS}")
114
- message(STATUS "Capabilities for kernel _qattn_sm80: ${_qattn_sm80_ARCHS}")
115
- set_gencode_flags_for_srcs(SRCS "${_qattn_sm80_SRC}" CUDA_ARCHS "${_qattn_sm80_ARCHS}")
116
-
117
-
118
- foreach(_KERNEL_SRC ${_qattn_sm80_SRC})
119
- if(_KERNEL_SRC MATCHES ".*\\.cu$")
120
- set_property(
121
- SOURCE ${_KERNEL_SRC}
122
- APPEND PROPERTY
123
- COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>"
124
- )
125
- endif()
126
- endforeach()
127
-
128
- foreach(_KERNEL_SRC ${_qattn_sm80_SRC})
129
- set_property(
130
- SOURCE ${_KERNEL_SRC}
131
- APPEND PROPERTY
132
- COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-lgomp;-std=c++17;-DENABLE_BF16>"
133
- )
134
- endforeach()
135
-
136
- list(APPEND SRC "${_qattn_sm80_SRC}")
137
- endif()
138
-
139
-
140
-
141
- set(_qattn_sm90_SRC
142
- "sage_attention/qattn/qk_int_sv_f8_cuda_sm90.cu"
143
- "sage_attention/qattn/attn_cuda_sm90.h"
144
- "sage_attention/qattn/attn_utils.cuh"
145
- "sage_attention/cuda_tensormap_shim.cuh"
146
- )
147
-
148
- # TODO: check if CLion support this:
149
- # https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
150
- set_source_files_properties(
151
- ${_qattn_sm90_SRC}
152
- PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/.")
153
-
154
- if(GPU_LANG STREQUAL "CUDA")
155
- cuda_archs_loose_intersection(_qattn_sm90_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
156
- message(STATUS "Capabilities for kernel _qattn_sm90: ${_qattn_sm90_ARCHS}")
157
- set_gencode_flags_for_srcs(SRCS "${_qattn_sm90_SRC}" CUDA_ARCHS "${_qattn_sm90_ARCHS}")
158
-
159
-
160
- foreach(_KERNEL_SRC ${_qattn_sm90_SRC})
161
- if(_KERNEL_SRC MATCHES ".*\\.cu$")
162
- set_property(
163
- SOURCE ${_KERNEL_SRC}
164
- APPEND PROPERTY
165
- COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>"
166
- )
167
- endif()
168
- endforeach()
169
-
170
- foreach(_KERNEL_SRC ${_qattn_sm90_SRC})
171
- set_property(
172
- SOURCE ${_KERNEL_SRC}
173
- APPEND PROPERTY
174
- COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-lgomp;-std=c++17;-DENABLE_BF16>"
175
- )
176
- endforeach()
177
-
178
- list(APPEND SRC "${_qattn_sm90_SRC}")
179
- endif()
180
-
181
-
182
-
183
- set(_qattn_sm89_SRC
184
- "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu"
185
- "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu"
186
- "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu"
187
- "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu"
188
- "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu"
189
- "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu"
190
- "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu"
191
- "sage_attention/qattn/attn_cuda_sm89.h"
192
- "sage_attention/qattn/qk_int_sv_f8_cuda_sm89.cuh"
193
- "sage_attention/qattn/attn_utils.cuh"
194
- )
195
-
196
- # TODO: check if CLion support this:
197
- # https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
198
- set_source_files_properties(
199
- ${_qattn_sm89_SRC}
200
- PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/.")
201
-
202
- if(GPU_LANG STREQUAL "CUDA")
203
- cuda_archs_loose_intersection(_qattn_sm89_ARCHS "8.9" "${CUDA_ARCHS}")
204
- message(STATUS "Capabilities for kernel _qattn_sm89: ${_qattn_sm89_ARCHS}")
205
- set_gencode_flags_for_srcs(SRCS "${_qattn_sm89_SRC}" CUDA_ARCHS "${_qattn_sm89_ARCHS}")
206
-
207
-
208
- foreach(_KERNEL_SRC ${_qattn_sm89_SRC})
209
- if(_KERNEL_SRC MATCHES ".*\\.cu$")
210
- set_property(
211
- SOURCE ${_KERNEL_SRC}
212
- APPEND PROPERTY
213
- COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>"
214
- )
215
- endif()
216
- endforeach()
217
-
218
- foreach(_KERNEL_SRC ${_qattn_sm89_SRC})
219
- set_property(
220
- SOURCE ${_KERNEL_SRC}
221
- APPEND PROPERTY
222
- COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-lgomp;-std=c++17;-DENABLE_BF16>"
223
- )
224
- endforeach()
225
-
226
- list(APPEND SRC "${_qattn_sm89_SRC}")
227
- endif()
228
-
229
-
230
-
231
- set(_fused_SRC
232
- "sage_attention/fused/fused.cu"
233
- "sage_attention/fused/fused.h"
234
- )
235
-
236
- # TODO: check if CLion support this:
237
- # https://youtrack.jetbrains.com/issue/CPP-16510/CLion-does-not-handle-per-file-include-directories
238
- set_source_files_properties(
239
- ${_fused_SRC}
240
- PROPERTIES INCLUDE_DIRECTORIES "${CMAKE_SOURCE_DIR}/.")
241
-
242
- if(GPU_LANG STREQUAL "CUDA")
243
- cuda_archs_loose_intersection(_fused_ARCHS "8.0;8.9;9.0;9.0a" "${CUDA_ARCHS}")
244
- message(STATUS "Capabilities for kernel _fused: ${_fused_ARCHS}")
245
- set_gencode_flags_for_srcs(SRCS "${_fused_SRC}" CUDA_ARCHS "${_fused_ARCHS}")
246
-
247
-
248
- foreach(_KERNEL_SRC ${_fused_SRC})
249
- if(_KERNEL_SRC MATCHES ".*\\.cu$")
250
- set_property(
251
- SOURCE ${_KERNEL_SRC}
252
- APPEND PROPERTY
253
- COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>"
254
- )
255
- endif()
256
- endforeach()
257
-
258
- foreach(_KERNEL_SRC ${_fused_SRC})
259
- set_property(
260
- SOURCE ${_KERNEL_SRC}
261
- APPEND PROPERTY
262
- COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-lgomp;-std=c++17;-DENABLE_BF16>"
263
- )
264
- endforeach()
265
-
266
- list(APPEND SRC "${_fused_SRC}")
267
- endif()
268
-
269
-
270
-
271
- set(_qattn_SRC
272
- "sage_attention/cp_async.cuh"
273
- "sage_attention/dispatch_utils.h"
274
- "sage_attention/math.cuh"
275
- "sage_attention/mma.cuh"
276
- "sage_attention/numeric_conversion.cuh"
277
- "sage_attention/permuted_smem.cuh"
278
- "sage_attention/reduction_utils.cuh"
279
- "sage_attention/wgmma.cuh"
280
- "sage_attention/utils.cuh"
281
- "sage_attention/cuda_tensormap_shim.cuh"
282
- )
283
-
284
-
285
- if(GPU_LANG STREQUAL "CUDA")
286
- cuda_archs_loose_intersection(_qattn_ARCHS "8.0;8.9;9.0;9.0a" "${CUDA_ARCHS}")
287
- message(STATUS "Capabilities for kernel _qattn: ${_qattn_ARCHS}")
288
- set_gencode_flags_for_srcs(SRCS "${_qattn_SRC}" CUDA_ARCHS "${_qattn_ARCHS}")
289
-
290
-
291
- foreach(_KERNEL_SRC ${_qattn_SRC})
292
- if(_KERNEL_SRC MATCHES ".*\\.cu$")
293
- set_property(
294
- SOURCE ${_KERNEL_SRC}
295
- APPEND PROPERTY
296
- COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-O3;-std=c++17;-U__CUDA_NO_HALF_OPERATORS__;-U__CUDA_NO_HALF_CONVERSIONS__;--use_fast_math;--threads=1;-Xptxas=-v;-diag-suppress=174>"
297
- )
298
- endif()
299
- endforeach()
300
-
301
- foreach(_KERNEL_SRC ${_qattn_SRC})
302
- set_property(
303
- SOURCE ${_KERNEL_SRC}
304
- APPEND PROPERTY
305
- COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CXX>:-g;-O3;-fopenmp;-lgomp;-std=c++17;-DENABLE_BF16>"
306
- )
307
- endforeach()
308
-
309
- list(APPEND SRC "${_qattn_SRC}")
310
- endif()
311
-
312
-
313
- define_gpu_extension_target(
314
- _sage_attention_57cb7ec_dirty
315
- DESTINATION _sage_attention_57cb7ec_dirty
316
- LANGUAGE ${GPU_LANG}
317
- SOURCES ${SRC}
318
- COMPILE_FLAGS ${GPU_FLAGS}
319
- ARCHITECTURES ${GPU_ARCHES}
320
- #INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
321
- USE_SABI 3
322
- WITH_SOABI)
323
-
324
- target_link_options(_sage_attention_57cb7ec_dirty PRIVATE -static-libstdc++)
325
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build.toml CHANGED
@@ -1,7 +1,7 @@
1
  [general]
2
  name = "sage_attention"
3
  universal = false
4
- cuda-minver = "12.4"
5
 
6
  [torch]
7
  src = [
@@ -12,7 +12,7 @@ src = [
12
  [kernel._qattn]
13
  depends = ["torch"]
14
  backend = "cuda"
15
- cuda-minver = "12.4"
16
  cuda-capabilities = [
17
  "8.0", "8.9", "9.0a"
18
  ]
@@ -43,7 +43,7 @@ cuda-flags = [
43
  [kernel._qattn_sm80]
44
  depends = ["torch"]
45
  backend = "cuda"
46
- cuda-minver = "12.4"
47
  cuda-capabilities = [
48
  "8.0"
49
  ]
@@ -69,7 +69,7 @@ cuda-flags = [
69
  [kernel._qattn_sm89]
70
  depends = ["torch"]
71
  backend = "cuda"
72
- cuda-minver = "12.4"
73
  cuda-capabilities = [
74
  "8.9",
75
  ]
@@ -102,7 +102,7 @@ cuda-flags = [
102
  [kernel._qattn_sm90]
103
  depends = ["torch"]
104
  backend = "cuda"
105
- cuda-minver = "12.4"
106
  cuda-capabilities = [
107
  "9.0a",
108
  ]
@@ -127,7 +127,7 @@ cuda-flags = [
127
  [kernel._fused]
128
  depends = ["torch"]
129
  backend = "cuda"
130
- cuda-minver = "12.4"
131
  cuda-capabilities = [
132
  "8.0", "8.9", "9.0a",
133
  ]
 
1
  [general]
2
  name = "sage_attention"
3
  universal = false
4
+ cuda-minver = "12.0"
5
 
6
  [torch]
7
  src = [
 
12
  [kernel._qattn]
13
  depends = ["torch"]
14
  backend = "cuda"
15
+ cuda-minver = "12.0"
16
  cuda-capabilities = [
17
  "8.0", "8.9", "9.0a"
18
  ]
 
43
  [kernel._qattn_sm80]
44
  depends = ["torch"]
45
  backend = "cuda"
46
+ cuda-minver = "12.0"
47
  cuda-capabilities = [
48
  "8.0"
49
  ]
 
69
  [kernel._qattn_sm89]
70
  depends = ["torch"]
71
  backend = "cuda"
72
+ cuda-minver = "12.0"
73
  cuda-capabilities = [
74
  "8.9",
75
  ]
 
102
  [kernel._qattn_sm90]
103
  depends = ["torch"]
104
  backend = "cuda"
105
+ cuda-minver = "12.0"
106
  cuda-capabilities = [
107
  "9.0a",
108
  ]
 
127
  [kernel._fused]
128
  depends = ["torch"]
129
  backend = "cuda"
130
+ cuda-minver = "12.0"
131
  cuda-capabilities = [
132
  "8.0", "8.9", "9.0a",
133
  ]
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
- from .core import sageattn, sageattn_qk_int8_pv_fp8_cuda
3
 
4
 
5
  __all__ = [
@@ -8,5 +8,4 @@ __all__ = [
8
  "sub_mean",
9
  "per_channel_fp8",
10
  "sageattn",
11
- "sageattn_qk_int8_pv_fp8_cuda",
12
  ]
 
1
  from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
+ from .core import sageattn
3
 
4
 
5
  __all__ = [
 
8
  "sub_mean",
9
  "per_channel_fp8",
10
  "sageattn",
 
11
  ]
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc and b/build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc differ
 
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc and b/build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc differ
 
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc and b/build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc differ
 
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc and b/build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc differ
 
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc and b/build/torch27-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc differ
 
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _sage_attention_af2d0c0_dirty
3
- ops = torch.ops._sage_attention_af2d0c0_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_sage_attention_af2d0c0_dirty::{op_name}"
 
1
  import torch
2
+ from . import _sage_attention_1369690_dirty
3
+ ops = torch.ops._sage_attention_1369690_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_sage_attention_1369690_dirty::{op_name}"
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/{_sage_attention_af2d0c0_dirty.abi3.so → _sage_attention_1369690_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:83f3b3d1c1371cf577a4e2c2fa3bbeef137aa93a89cf380816c14e650b1449f6
3
  size 26037568
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c7a3fa0f2b5db528e3854fcb72e3bc5936ed760336b96bf0e183d19fada3767
3
  size 26037568
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/core.py CHANGED
@@ -363,7 +363,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
363
 
364
  if pv_accum_dtype == "fp32":
365
  v = v.to(torch.float16)
366
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(
367
  q_int8,
368
  k_int8,
369
  v,
@@ -379,7 +379,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
379
  elif pv_accum_dtype == "fp16":
380
  if smooth_v:
381
  smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
382
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
383
  q_int8,
384
  k_int8,
385
  smoothed_v,
@@ -395,7 +395,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
395
  )
396
  else:
397
  v = v.to(torch.float16)
398
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(
399
  q_int8,
400
  k_int8,
401
  v,
@@ -410,7 +410,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
410
  )
411
  elif pv_accum_dtype == "fp16+fp32":
412
  v = v.to(torch.float16)
413
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(
414
  q_int8,
415
  k_int8,
416
  v,
@@ -941,20 +941,6 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
941
  _return_lse,
942
  )
943
  elif pv_accum_dtype == "fp32+fp32":
944
- print(
945
- "qint8",
946
- q_int8.shape,
947
- "qscale",
948
- q_scale.shape,
949
- "kint8",
950
- k_int8.shape,
951
- "kscale",
952
- k_scale.shape,
953
- "vfp8",
954
- v_fp8.shape,
955
- "vscale",
956
- v_scale.shape,
957
- )
958
  lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
959
  q_int8,
960
  k_int8,
 
363
 
364
  if pv_accum_dtype == "fp32":
365
  v = v.to(torch.float16)
366
+ lse = ops.qk_int8_sv_f16_accum_f32_attn(
367
  q_int8,
368
  k_int8,
369
  v,
 
379
  elif pv_accum_dtype == "fp16":
380
  if smooth_v:
381
  smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
382
+ lse = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
383
  q_int8,
384
  k_int8,
385
  smoothed_v,
 
395
  )
396
  else:
397
  v = v.to(torch.float16)
398
+ lse = ops.qk_int8_sv_f16_accum_f16_attn(
399
  q_int8,
400
  k_int8,
401
  v,
 
410
  )
411
  elif pv_accum_dtype == "fp16+fp32":
412
  v = v.to(torch.float16)
413
+ lse = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf(
414
  q_int8,
415
  k_int8,
416
  v,
 
941
  _return_lse,
942
  )
943
  elif pv_accum_dtype == "fp32+fp32":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
944
  lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
945
  q_int8,
946
  k_int8,
build/torch27-cxx11-cu126-x86_64-linux/sage_attention/layers.py DELETED
File without changes
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
- from .core import sageattn, sageattn_qk_int8_pv_fp8_cuda
3
 
4
 
5
  __all__ = [
@@ -8,5 +8,4 @@ __all__ = [
8
  "sub_mean",
9
  "per_channel_fp8",
10
  "sageattn",
11
- "sageattn_qk_int8_pv_fp8_cuda",
12
  ]
 
1
  from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
+ from .core import sageattn
3
 
4
 
5
  __all__ = [
 
8
  "sub_mean",
9
  "per_channel_fp8",
10
  "sageattn",
 
11
  ]
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc and b/build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc differ
 
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc and b/build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc differ
 
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc and b/build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc differ
 
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc and b/build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc differ
 
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc CHANGED
Binary files a/build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc and b/build/torch27-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc differ
 
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _sage_attention_af2d0c0_dirty
3
- ops = torch.ops._sage_attention_af2d0c0_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_sage_attention_af2d0c0_dirty::{op_name}"
 
1
  import torch
2
+ from . import _sage_attention_1369690_dirty
3
+ ops = torch.ops._sage_attention_1369690_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_sage_attention_1369690_dirty::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/{_sage_attention_af2d0c0_dirty.abi3.so → _sage_attention_1369690_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:871d2abf021f7175f2a66cd9f3599fdd88c9be0c98df1bb4d09f9905d955405f
3
  size 26553840
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1728ddb8a13631676b67cb867e7af21388f2f4d23279805bb3b5fa11bc6119c1
3
  size 26553840
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/core.py CHANGED
@@ -363,7 +363,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
363
 
364
  if pv_accum_dtype == "fp32":
365
  v = v.to(torch.float16)
366
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(
367
  q_int8,
368
  k_int8,
369
  v,
@@ -379,7 +379,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
379
  elif pv_accum_dtype == "fp16":
380
  if smooth_v:
381
  smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
382
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
383
  q_int8,
384
  k_int8,
385
  smoothed_v,
@@ -395,7 +395,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
395
  )
396
  else:
397
  v = v.to(torch.float16)
398
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(
399
  q_int8,
400
  k_int8,
401
  v,
@@ -410,7 +410,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
410
  )
411
  elif pv_accum_dtype == "fp16+fp32":
412
  v = v.to(torch.float16)
413
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(
414
  q_int8,
415
  k_int8,
416
  v,
@@ -941,20 +941,6 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
941
  _return_lse,
942
  )
943
  elif pv_accum_dtype == "fp32+fp32":
944
- print(
945
- "qint8",
946
- q_int8.shape,
947
- "qscale",
948
- q_scale.shape,
949
- "kint8",
950
- k_int8.shape,
951
- "kscale",
952
- k_scale.shape,
953
- "vfp8",
954
- v_fp8.shape,
955
- "vscale",
956
- v_scale.shape,
957
- )
958
  lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
959
  q_int8,
960
  k_int8,
 
363
 
364
  if pv_accum_dtype == "fp32":
365
  v = v.to(torch.float16)
366
+ lse = ops.qk_int8_sv_f16_accum_f32_attn(
367
  q_int8,
368
  k_int8,
369
  v,
 
379
  elif pv_accum_dtype == "fp16":
380
  if smooth_v:
381
  smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
382
+ lse = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
383
  q_int8,
384
  k_int8,
385
  smoothed_v,
 
395
  )
396
  else:
397
  v = v.to(torch.float16)
398
+ lse = ops.qk_int8_sv_f16_accum_f16_attn(
399
  q_int8,
400
  k_int8,
401
  v,
 
410
  )
411
  elif pv_accum_dtype == "fp16+fp32":
412
  v = v.to(torch.float16)
413
+ lse = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf(
414
  q_int8,
415
  k_int8,
416
  v,
 
941
  _return_lse,
942
  )
943
  elif pv_accum_dtype == "fp32+fp32":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
944
  lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
945
  q_int8,
946
  k_int8,
build/torch27-cxx11-cu128-x86_64-linux/sage_attention/layers.py DELETED
File without changes
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
- from .core import sageattn, sageattn_qk_int8_pv_fp8_cuda
3
 
4
 
5
  __all__ = [
@@ -8,5 +8,4 @@ __all__ = [
8
  "sub_mean",
9
  "per_channel_fp8",
10
  "sageattn",
11
- "sageattn_qk_int8_pv_fp8_cuda",
12
  ]
 
1
  from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
+ from .core import sageattn
3
 
4
 
5
  __all__ = [
 
8
  "sub_mean",
9
  "per_channel_fp8",
10
  "sageattn",
 
11
  ]
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc differ
 
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc and b/build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc differ
 
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc and b/build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc differ
 
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc and b/build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc differ
 
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc and b/build/torch28-cxx11-cu126-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc differ
 
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _sage_attention_af2d0c0_dirty
3
- ops = torch.ops._sage_attention_af2d0c0_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_sage_attention_af2d0c0_dirty::{op_name}"
 
1
  import torch
2
+ from . import _sage_attention_1369690_dirty
3
+ ops = torch.ops._sage_attention_1369690_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_sage_attention_1369690_dirty::{op_name}"
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/{_sage_attention_af2d0c0_dirty.abi3.so → _sage_attention_1369690_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:26b18ae63bccd4c5926533ffa1d0995e7bf3faf7919c0c55e1b829267ac73afd
3
  size 26037392
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:636461b53d3b3c4c1cd1940bc6ecb32728cb0f80bb347cf52591afd0ea121c8c
3
  size 26037392
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/core.py CHANGED
@@ -363,7 +363,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
363
 
364
  if pv_accum_dtype == "fp32":
365
  v = v.to(torch.float16)
366
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(
367
  q_int8,
368
  k_int8,
369
  v,
@@ -379,7 +379,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
379
  elif pv_accum_dtype == "fp16":
380
  if smooth_v:
381
  smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
382
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
383
  q_int8,
384
  k_int8,
385
  smoothed_v,
@@ -395,7 +395,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
395
  )
396
  else:
397
  v = v.to(torch.float16)
398
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(
399
  q_int8,
400
  k_int8,
401
  v,
@@ -410,7 +410,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
410
  )
411
  elif pv_accum_dtype == "fp16+fp32":
412
  v = v.to(torch.float16)
413
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(
414
  q_int8,
415
  k_int8,
416
  v,
@@ -941,20 +941,6 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
941
  _return_lse,
942
  )
943
  elif pv_accum_dtype == "fp32+fp32":
944
- print(
945
- "qint8",
946
- q_int8.shape,
947
- "qscale",
948
- q_scale.shape,
949
- "kint8",
950
- k_int8.shape,
951
- "kscale",
952
- k_scale.shape,
953
- "vfp8",
954
- v_fp8.shape,
955
- "vscale",
956
- v_scale.shape,
957
- )
958
  lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
959
  q_int8,
960
  k_int8,
 
363
 
364
  if pv_accum_dtype == "fp32":
365
  v = v.to(torch.float16)
366
+ lse = ops.qk_int8_sv_f16_accum_f32_attn(
367
  q_int8,
368
  k_int8,
369
  v,
 
379
  elif pv_accum_dtype == "fp16":
380
  if smooth_v:
381
  smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
382
+ lse = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
383
  q_int8,
384
  k_int8,
385
  smoothed_v,
 
395
  )
396
  else:
397
  v = v.to(torch.float16)
398
+ lse = ops.qk_int8_sv_f16_accum_f16_attn(
399
  q_int8,
400
  k_int8,
401
  v,
 
410
  )
411
  elif pv_accum_dtype == "fp16+fp32":
412
  v = v.to(torch.float16)
413
+ lse = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf(
414
  q_int8,
415
  k_int8,
416
  v,
 
941
  _return_lse,
942
  )
943
  elif pv_accum_dtype == "fp32+fp32":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
944
  lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
945
  q_int8,
946
  k_int8,
build/torch28-cxx11-cu126-x86_64-linux/sage_attention/layers.py DELETED
File without changes
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
- from .core import sageattn, sageattn_qk_int8_pv_fp8_cuda
3
 
4
 
5
  __all__ = [
@@ -8,5 +8,4 @@ __all__ = [
8
  "sub_mean",
9
  "per_channel_fp8",
10
  "sageattn",
11
- "sageattn_qk_int8_pv_fp8_cuda",
12
  ]
 
1
  from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
+ from .core import sageattn
3
 
4
 
5
  __all__ = [
 
8
  "sub_mean",
9
  "per_channel_fp8",
10
  "sageattn",
 
11
  ]
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc differ
 
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc and b/build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc differ
 
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc and b/build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc differ
 
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc and b/build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc differ
 
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc and b/build/torch28-cxx11-cu128-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc differ
 
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _sage_attention_af2d0c0_dirty
3
- ops = torch.ops._sage_attention_af2d0c0_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_sage_attention_af2d0c0_dirty::{op_name}"
 
1
  import torch
2
+ from . import _sage_attention_1369690_dirty
3
+ ops = torch.ops._sage_attention_1369690_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_sage_attention_1369690_dirty::{op_name}"
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/{_sage_attention_af2d0c0_dirty.abi3.so → _sage_attention_1369690_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2681241cb3fee535e10ba52179293982bca60a5fed972404fdec8ae5fa848175
3
- size 26549824
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2bfc9cab8f63665571d07c111a94edb3bec9a17aba7721a4c67be5392db0841d
3
+ size 26553920
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/core.py CHANGED
@@ -363,7 +363,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
363
 
364
  if pv_accum_dtype == "fp32":
365
  v = v.to(torch.float16)
366
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(
367
  q_int8,
368
  k_int8,
369
  v,
@@ -379,7 +379,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
379
  elif pv_accum_dtype == "fp16":
380
  if smooth_v:
381
  smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
382
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
383
  q_int8,
384
  k_int8,
385
  smoothed_v,
@@ -395,7 +395,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
395
  )
396
  else:
397
  v = v.to(torch.float16)
398
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(
399
  q_int8,
400
  k_int8,
401
  v,
@@ -410,7 +410,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
410
  )
411
  elif pv_accum_dtype == "fp16+fp32":
412
  v = v.to(torch.float16)
413
- lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(
414
  q_int8,
415
  k_int8,
416
  v,
@@ -941,20 +941,6 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
941
  _return_lse,
942
  )
943
  elif pv_accum_dtype == "fp32+fp32":
944
- print(
945
- "qint8",
946
- q_int8.shape,
947
- "qscale",
948
- q_scale.shape,
949
- "kint8",
950
- k_int8.shape,
951
- "kscale",
952
- k_scale.shape,
953
- "vfp8",
954
- v_fp8.shape,
955
- "vscale",
956
- v_scale.shape,
957
- )
958
  lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
959
  q_int8,
960
  k_int8,
 
363
 
364
  if pv_accum_dtype == "fp32":
365
  v = v.to(torch.float16)
366
+ lse = ops.qk_int8_sv_f16_accum_f32_attn(
367
  q_int8,
368
  k_int8,
369
  v,
 
379
  elif pv_accum_dtype == "fp16":
380
  if smooth_v:
381
  smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
382
+ lse = ops.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
383
  q_int8,
384
  k_int8,
385
  smoothed_v,
 
395
  )
396
  else:
397
  v = v.to(torch.float16)
398
+ lse = ops.qk_int8_sv_f16_accum_f16_attn(
399
  q_int8,
400
  k_int8,
401
  v,
 
410
  )
411
  elif pv_accum_dtype == "fp16+fp32":
412
  v = v.to(torch.float16)
413
+ lse = ops.qk_int8_sv_f16_accum_f16_attn_inst_buf(
414
  q_int8,
415
  k_int8,
416
  v,
 
941
  _return_lse,
942
  )
943
  elif pv_accum_dtype == "fp32+fp32":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
944
  lse = ops.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(
945
  q_int8,
946
  k_int8,
build/torch28-cxx11-cu128-x86_64-linux/sage_attention/layers.py DELETED
File without changes
build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
- from .core import sageattn, sageattn_qk_int8_pv_fp8_cuda
3
 
4
 
5
  __all__ = [
@@ -8,5 +8,4 @@ __all__ = [
8
  "sub_mean",
9
  "per_channel_fp8",
10
  "sageattn",
11
- "sageattn_qk_int8_pv_fp8_cuda",
12
  ]
 
1
  from .quant import per_block_int8, per_warp_int8, sub_mean, per_channel_fp8
2
+ from .core import sageattn
3
 
4
 
5
  __all__ = [
 
8
  "sub_mean",
9
  "per_channel_fp8",
10
  "sageattn",
 
11
  ]
build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/__init__.cpython-313.pyc differ
 
build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/_ops.cpython-313.pyc differ
 
build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/core.cpython-313.pyc differ
 
build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/quant.cpython-313.pyc differ
 
build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc CHANGED
Binary files a/build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/sage_attention/__pycache__/quant_per_thread.cpython-313.pyc differ
 
build/torch28-cxx11-cu129-x86_64-linux/sage_attention/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _sage_attention_af2d0c0_dirty
3
- ops = torch.ops._sage_attention_af2d0c0_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_sage_attention_af2d0c0_dirty::{op_name}"
 
1
  import torch
2
+ from . import _sage_attention_1369690_dirty
3
+ ops = torch.ops._sage_attention_1369690_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_sage_attention_1369690_dirty::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/sage_attention/{_sage_attention_af2d0c0_dirty.abi3.so → _sage_attention_1369690_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ff47cafcc3abed4dc02589ee11c315f3b88f65a0510caa89a07825ccd8ea1a48
3
  size 26608048
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fef61629b7537ad41d31cf5715a11a38ce3f7cc97b0d5bf26356492b36ad5c29
3
  size 26608048