# activation/setup.py
# (last touched by wyldecat, commit 6436ad6: "style: apply yapf, isort, and clang-format")
"""Local CUDA build for activation kernels.
Usage:
pip install -e . # editable install
python setup.py build_ext --inplace # build only
The built extension is named '_activation' and can be loaded via:
import _activation
torch.ops._activation.rms_norm(...)
"""
import os
from pathlib import Path
import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Anchor all paths to the directory containing this setup.py.
ROOT = Path(__file__).parent

# One .cu translation unit per kernel, all under activation/.
_KERNEL_NAMES = (
    "poly_norm",
    "fused_mul_poly_norm",
    "rms_norm",
    "fused_add_rms_norm",
    "grouped_poly_norm",
)
CUDA_SOURCES = [f"activation/{name}.cu" for name in _KERNEL_NAMES]

# C++ binding layer that exposes the kernels to torch.
CPP_SOURCES = ["torch-ext/torch_binding.cpp"]

# Include dirs: project root (for registration.h, activation/*.h)
# and torch-ext/ (for torch_binding.h)
INCLUDE_DIRS = [
    str(path) for path in (ROOT, ROOT / "activation", ROOT / "torch-ext")
]
# CUDA flags matching the existing kernel style
NVCC_FLAGS = [
    "-O3",
    "--use_fast_math",
    "-std=c++17",
    # Generate code for common architectures
    "-gencode=arch=compute_80,code=sm_80",  # A100
    "-gencode=arch=compute_89,code=sm_89",  # L40/4090
    "-gencode=arch=compute_90,code=sm_90",  # H100
]
# Check for B200 support (sm_100, requires CUDA 12.8+).
# torch.version.cuda is None on CPU-only (and ROCm) torch builds, which would
# crash the previous unconditional .split(); fall back to (0, 0) so the
# sm_100 gencode is simply skipped in that case.
_cuda_str = torch.version.cuda
cuda_version = (tuple(int(x) for x in _cuda_str.split(".")[:2])
                if _cuda_str else (0, 0))
if cuda_version >= (12, 8):
    NVCC_FLAGS.append("-gencode=arch=compute_100,code=sm_100")
# Host-compiler flags for the C++ binding sources.
CXX_FLAGS = ["-O3", "-std=c++17"]
# Single CUDA extension "_activation": C++ binding first, then the kernels,
# all resolved to absolute paths under the project root.
_ALL_SOURCES = [str(ROOT / rel) for rel in (*CPP_SOURCES, *CUDA_SOURCES)]
ext_modules = [
    CUDAExtension(
        name="_activation",
        sources=_ALL_SOURCES,
        include_dirs=INCLUDE_DIRS,
        extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
    ),
]
# Package metadata; the importable Python package "activation" is shipped
# from torch-ext/activation.
_SETUP_KWARGS = dict(
    name="activation",
    version="0.1.0",
    description="Custom CUDA normalization kernels for LLM training",
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
    packages=["activation"],
    package_dir={"activation": "torch-ext/activation"},
    python_requires=">=3.10",
    install_requires=["torch>=2.7"],
)
setup(**_SETUP_KWARGS)