# Build manifest (TOML) for the "sage_attention" kernel package.
# NOTE(review): layout appears to follow the kernel-builder build.toml schema — confirm.

[general]
# Package name.
name = "sage_attention"
# NOTE(review): presumably false because the package ships compiled kernels
# (see the [kernel.*] tables with backend = "cuda") — confirm against the schema.
universal = false
# Minimum CUDA toolkit version for the package as a whole.
cuda-minver = "12.0"

# Sources for the Torch extension layer.
# NOTE(review): presumably the op registration/bindings, judging by the file names — confirm.
[torch]
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h",
]

# Kernel component "_qattn".
# Its src list contains only header files (.cuh/.h), no .cu translation units.
# NOTE(review): presumably these are helpers included by the per-architecture
# kernel components — confirm.
[kernel._qattn]
depends = ["torch"]
backend = "cuda"
cuda-minver = "12.0"
# Compute capabilities this component is built for.
cuda-capabilities = [
  "8.0",
  "8.9",
  "9.0a",
]
src = [
  "sage_attention/cp_async.cuh",
  "sage_attention/dispatch_utils.h",
  "sage_attention/math.cuh",
  "sage_attention/mma.cuh",
  "sage_attention/numeric_conversion.cuh",
  "sage_attention/permuted_smem.cuh",
  "sage_attention/reduction_utils.cuh",
  "sage_attention/wgmma.cuh",
  "sage_attention/utils.cuh",
  "sage_attention/cuda_tensormap_shim.cuh",
]
# Host compiler flags.
# NOTE(review): "-lgomp" is a linker option; confirm it is intended among compile flags.
cxx-flags = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
# nvcc flags.
cuda-flags = [
  "-O3",
  "-std=c++17",
  # Undefine the macros that switch off half-precision operators/conversions.
  "-U__CUDA_NO_HALF_OPERATORS__",
  "-U__CUDA_NO_HALF_CONVERSIONS__",
  "--use_fast_math",
  # Limit nvcc to a single compilation thread.
  "--threads=1",
  # Forward -v (verbose output) to ptxas.
  "-Xptxas=-v",
  # Silence compiler diagnostic number 174.
  "-diag-suppress=174",
]

# Kernel component "_qattn_sm80": built only for compute capability 8.0.
[kernel._qattn_sm80]
depends = ["torch"]
backend = "cuda"
cuda-minver = "12.0"
cuda-capabilities = [
  "8.0",
]
# Include search path, given relative to this manifest ("." entry).
include = ["."]
src = [
  "sage_attention/qattn/qk_int_sv_f16_cuda_sm80.cu",
  "sage_attention/qattn/attn_cuda_sm80.h",
  "sage_attention/qattn/attn_utils.cuh",
]
# Host compiler flags.
# NOTE(review): "-lgomp" is a linker option; confirm it is intended among compile flags.
cxx-flags = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
# nvcc flags.
cuda-flags = [
  "-O3",
  "-std=c++17",
  # Undefine the macros that switch off half-precision operators/conversions.
  "-U__CUDA_NO_HALF_OPERATORS__",
  "-U__CUDA_NO_HALF_CONVERSIONS__",
  "--use_fast_math",
  # Limit nvcc to a single compilation thread.
  "--threads=1",
  # Forward -v (verbose output) to ptxas.
  "-Xptxas=-v",
  # Silence compiler diagnostic number 174.
  "-diag-suppress=174",
]

# Kernel component "_qattn_sm89": built only for compute capability 8.9.
[kernel._qattn_sm89]
depends = ["torch"]
backend = "cuda"
cuda-minver = "12.0"
cuda-capabilities = [
  "8.9",
]
# Include search path, given relative to this manifest ("." entry).
include = ["."]
# Seven .cu translation units followed by the headers listed alongside them.
src = [
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu",
  "sage_attention/qattn/attn_cuda_sm89.h",
  "sage_attention/qattn/qk_int_sv_f8_cuda_sm89.cuh",
  "sage_attention/qattn/attn_utils.cuh",
]
# Host compiler flags.
# NOTE(review): "-lgomp" is a linker option; confirm it is intended among compile flags.
cxx-flags = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
# nvcc flags.
cuda-flags = [
  "-O3",
  "-std=c++17",
  # Undefine the macros that switch off half-precision operators/conversions.
  "-U__CUDA_NO_HALF_OPERATORS__",
  "-U__CUDA_NO_HALF_CONVERSIONS__",
  "--use_fast_math",
  # Limit nvcc to a single compilation thread.
  "--threads=1",
  # Forward -v (verbose output) to ptxas.
  "-Xptxas=-v",
  # Silence compiler diagnostic number 174.
  "-diag-suppress=174",
]

# Kernel component "_qattn_sm90": built only for compute capability 9.0a.
[kernel._qattn_sm90]
depends = ["torch"]
backend = "cuda"
cuda-minver = "12.0"
cuda-capabilities = [
  "9.0a",
]
# Include search path, given relative to this manifest ("." entry).
include = ["."]
src = [
  "sage_attention/qattn/qk_int_sv_f8_cuda_sm90.cu",
  "sage_attention/qattn/attn_cuda_sm90.h",
  "sage_attention/qattn/attn_utils.cuh",
]
# Host compiler flags.
# NOTE(review): "-lgomp" is a linker option; confirm it is intended among compile flags.
cxx-flags = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
# nvcc flags.
cuda-flags = [
  "-O3",
  "-std=c++17",
  # Undefine the macros that switch off half-precision operators/conversions.
  "-U__CUDA_NO_HALF_OPERATORS__",
  "-U__CUDA_NO_HALF_CONVERSIONS__",
  "--use_fast_math",
  # Limit nvcc to a single compilation thread.
  "--threads=1",
  # Forward -v (verbose output) to ptxas.
  "-Xptxas=-v",
  # Silence compiler diagnostic number 174.
  "-diag-suppress=174",
]

# Kernel component "_fused": built for compute capabilities 8.0, 8.9 and 9.0a.
[kernel._fused]
depends = ["torch"]
backend = "cuda"
cuda-minver = "12.0"
cuda-capabilities = [
  "8.0",
  "8.9",
  "9.0a",
]
# Include search path, given relative to this manifest ("." entry).
include = ["."]
src = [
  "sage_attention/fused/fused.cu",
  "sage_attention/fused/fused.h",
]
# Host compiler flags.
# NOTE(review): "-lgomp" is a linker option; confirm it is intended among compile flags.
cxx-flags = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
# nvcc flags.
cuda-flags = [
  "-O3",
  "-std=c++17",
  # Undefine the macros that switch off half-precision operators/conversions.
  "-U__CUDA_NO_HALF_OPERATORS__",
  "-U__CUDA_NO_HALF_CONVERSIONS__",
  "--use_fast_math",
  # Limit nvcc to a single compilation thread.
  "--threads=1",
  # Forward -v (verbose output) to ptxas.
  "-Xptxas=-v",
  # Silence compiler diagnostic number 174.
  "-diag-suppress=174",
]