"""Local CUDA build for activation kernels.

Usage:
    pip install -e .                      # editable install
    python setup.py build_ext --inplace   # build only

The built extension is named '_activation' and can be loaded via:
    import _activation
    torch.ops._activation.rms_norm(...)
"""

import os
from pathlib import Path

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Directory containing this setup script; every source and include path below
# is resolved against it.
ROOT = Path(__file__).parent

# CUDA kernel translation units, relative to ROOT.
CUDA_SOURCES = [
    "activation/poly_norm.cu",
    "activation/fused_mul_poly_norm.cu",
    "activation/rms_norm.cu",
    "activation/fused_add_rms_norm.cu",
    "activation/grouped_poly_norm.cu",
]

# Host-side C++ sources, relative to ROOT.
CPP_SOURCES = [
    "torch-ext/torch_binding.cpp",
]

# Header search paths handed to both the host compiler and nvcc.
INCLUDE_DIRS = [
    str(ROOT),
    str(ROOT / "activation"),
    str(ROOT / "torch-ext"),
]

# Flags passed to nvcc for the .cu sources.
NVCC_FLAGS = [
    "-O3",
    "--use_fast_math",
    "-std=c++17",
    # Compute capabilities that are always compiled in (8.0, 8.9, 9.0).
    "-gencode=arch=compute_80,code=sm_80",
    "-gencode=arch=compute_89,code=sm_89",
    "-gencode=arch=compute_90,code=sm_90",
]

# (major, minor) of the CUDA version reported by torch. ``torch.version.cuda``
# is None on torch builds without CUDA support; fall back to "0.0" there so the
# script does not die with an AttributeError before setuptools gets to run.
# The fallback simply leaves the optional sm_100 target out.
cuda_version = tuple(
    int(part) for part in (torch.version.cuda or "0.0").split(".")[:2]
)
if cuda_version >= (12, 8):
    # compute capability 10.0 is only requested for CUDA 12.8 or newer.
    NVCC_FLAGS.append("-gencode=arch=compute_100,code=sm_100")

# Flags passed to the host C++ compiler.
CXX_FLAGS = ["-O3", "-std=c++17"]

ext_modules = [
    CUDAExtension(
        name="_activation",
        # C++ binding first, then the CUDA kernels, all anchored at ROOT.
        sources=[str(ROOT / s) for s in CPP_SOURCES + CUDA_SOURCES],
        include_dirs=INCLUDE_DIRS,
        extra_compile_args={
            "cxx": CXX_FLAGS,
            "nvcc": NVCC_FLAGS,
        },
    ),
]

setup(
    name="activation",
    version="0.1.0",
    description="Custom CUDA normalization kernels for LLM training",
    ext_modules=ext_modules,
    # BuildExtension is torch's build_ext command for C++/CUDA extensions.
    cmdclass={"build_ext": BuildExtension},
    packages=["activation"],
    # The Python package sources live under torch-ext/activation.
    package_dir={"activation": "torch-ext/activation"},
    python_requires=">=3.10",
    install_requires=["torch>=2.7"],
)
|
|