# sage_attention / build.toml
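#
# Build manifest in the Hugging Face kernel-builder format: a [general]
# table, a [torch] table with the PyTorch binding sources, and one
# [kernel.*] table per CUDA compilation unit, each with its own target
# compute capabilities and compiler flags.
# NOTE: the per-section comments below are editorial glosses inferred
# from the source file names, not documentation from the upstream repo.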
[general]
name = "sage_attention"
universal = false
cuda-minver = "12.0"
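
# PyTorch C++ binding sources that expose the CUDA kernels as torch ops.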
[torch]
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h",
]
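
# Shared device-side headers (async copy, MMA/WGMMA wrappers, math,
# reduction, and conversion utilities) used by the per-architecture
# kernels below; built for every supported capability.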
[kernel._qattn]
depends = ["torch"]
backend = "cuda"
cuda-minver = "12.0"
cuda-capabilities = [
  "8.0", "8.9", "9.0a",
]
src = [
  "sage_attention/cp_async.cuh",
  "sage_attention/dispatch_utils.h",
  "sage_attention/math.cuh",
  "sage_attention/mma.cuh",
  "sage_attention/numeric_conversion.cuh",
  "sage_attention/permuted_smem.cuh",
  "sage_attention/reduction_utils.cuh",
  "sage_attention/wgmma.cuh",
  "sage_attention/utils.cuh",
  "sage_attention/cuda_tensormap_shim.cuh",
]
cxx-flags = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
cuda-flags = [
  "-O3",
  "-std=c++17",
  "-U__CUDA_NO_HALF_OPERATORS__",   # re-enable half-precision operator overloads in the CUDA headers
  "-U__CUDA_NO_HALF_CONVERSIONS__", # likewise for half-precision conversions
  "--use_fast_math",
  "--threads=1",                    # one nvcc compilation thread per translation unit
  "-Xptxas=-v",                     # print per-kernel register and shared-memory usage
  "-diag-suppress=174",             # 174-D: "expression has no effect"
]
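
# SM80 (Ampere): INT8 QK^T with an FP16 SV path, per the file names.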
[kernel._qattn_sm80]
depends = ["torch"]
backend = "cuda"
cuda-minver = "12.0"
cuda-capabilities = [
  "8.0",
]
include = ["."]
src = [
  "sage_attention/qattn/qk_int_sv_f16_cuda_sm80.cu",
  "sage_attention/qattn/attn_cuda_sm80.h",
  "sage_attention/qattn/attn_utils.cuh",
]
cxx-flags = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
cuda-flags = [
  "-O3",
  "-std=c++17",
  "-U__CUDA_NO_HALF_OPERATORS__",
  "-U__CUDA_NO_HALF_CONVERSIONS__",
  "--use_fast_math",
  "--threads=1",
  "-Xptxas=-v",
  "-diag-suppress=174", # 174-D: "expression has no effect"
]
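
# SM89 (Ada): INT8 QK^T with FP8 SV kernels; the file names distinguish
# FP32 vs. FP16 accumulation and fused V-scale / V-mean variants.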
[kernel._qattn_sm89]
depends = ["torch"]
backend = "cuda"
cuda-minver = "12.0"
cuda-capabilities = [
  "8.9",
]
include = ["."]
src = [
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
  "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu",
  "sage_attention/qattn/attn_cuda_sm89.h",
  "sage_attention/qattn/qk_int_sv_f8_cuda_sm89.cuh",
  "sage_attention/qattn/attn_utils.cuh",
]
cxx-flags = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
cuda-flags = [
  "-O3",
  "-std=c++17",
  "-U__CUDA_NO_HALF_OPERATORS__",
  "-U__CUDA_NO_HALF_CONVERSIONS__",
  "--use_fast_math",
  "--threads=1",
  "-Xptxas=-v",
  "-diag-suppress=174", # 174-D: "expression has no effect"
]
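
# SM90a (Hopper): WGMMA-based INT8 QK^T with FP8 SV kernels; "9.0a" opts
# in to the architecture-specific Hopper instruction set (wgmma/TMA).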
[kernel._qattn_sm90]
depends = ["torch"]
backend = "cuda"
cuda-minver = "12.0"
cuda-capabilities = [
  "9.0a",
]
include = ["."]
src = [
  "sage_attention/qattn/qk_int_sv_f8_cuda_sm90.cu",
  "sage_attention/qattn/attn_cuda_sm90.h",
  "sage_attention/qattn/attn_utils.cuh",
]
cxx-flags = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
cuda-flags = [
  "-O3",
  "-std=c++17",
  "-U__CUDA_NO_HALF_OPERATORS__",
  "-U__CUDA_NO_HALF_CONVERSIONS__",
  "--use_fast_math",
  "--threads=1",
  "-Xptxas=-v",
  "-diag-suppress=174", # 174-D: "expression has no effect"
]
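
# Fused auxiliary kernels (presumably the fused quantization/preprocessing
# steps), built for every supported architecture.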
[kernel._fused]
depends = ["torch"]
backend = "cuda"
cuda-minver = "12.0"
cuda-capabilities = [
  "8.0", "8.9", "9.0a",
]
include = ["."]
src = [
  "sage_attention/fused/fused.cu",
  "sage_attention/fused/fused.h",
]
cxx-flags = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
cuda-flags = [
  "-O3",
  "-std=c++17",
  "-U__CUDA_NO_HALF_OPERATORS__",
  "-U__CUDA_NO_HALF_CONVERSIONS__",
  "--use_fast_math",
  "--threads=1",
  "-Xptxas=-v",
  "-diag-suppress=174", # 174-D: "expression has no effect"
]
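
# A minimal usage sketch once the kernels are built and pushed to the Hub,
# via the Hugging Face `kernels` library (the repo id below is a
# placeholder, not the actual published one):
#
#   from kernels import get_kernel
#   sage_attention = get_kernel("<org>/sage_attention")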