|
|
|
import warnings
|
|
import os
|
|
from pathlib import Path
|
|
|
|
from packaging.version import parse, Version
|
|
from setuptools import setup, find_packages
|
|
import subprocess
|
|
|
|
|
|
import torch
|
|
from torch.utils.cpp_extension import (
|
|
BuildExtension,
|
|
CppExtension,
|
|
CUDAExtension,
|
|
CUDA_HOME,
|
|
)
|
|
|
|
# Distribution name and version reported to setuptools.
PACKAGE_NAME = "blackmamba"
VERSION = "0.0.1"

# Read README.md from the directory containing this setup script, not from the
# current working directory, so the build also works when invoked from
# elsewhere (e.g. `python /path/to/repo/setup.py bdist_wheel`).
_readme_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md")
with open(_readme_path, "r", encoding="utf-8") as fh:
    long_description = fh.read()
|
|
|
|
|
|
|
|
# Absolute directory holding this setup script; used to locate csrc/ headers.
this_dir = os.path.dirname(os.path.abspath(__file__))


def _env_flag(var_name):
    """Return True only if environment variable *var_name* equals exactly "TRUE".

    An unset variable defaults to "FALSE"; any other value (including "true"
    or "1") is treated as off.
    """
    return os.getenv(var_name, "FALSE") == "TRUE"


# Build switches read once at import time.
# NOTE(review): FORCE_BUILD is not referenced in the code shown here.
FORCE_BUILD = _env_flag("MAMBA_FORCE_BUILD")
# When set, no CUDA extension is configured.
SKIP_CUDA_BUILD = _env_flag("MAMBA_SKIP_CUDA_BUILD")
# When set, torch's recorded _GLIBCXX_USE_CXX11_ABI is overridden to True.
FORCE_CXX11_ABI = _env_flag("MAMBA_FORCE_CXX11_ABI")
|
|
|
|
|
|
def get_cuda_bare_metal_version(cuda_dir):
    """Query the nvcc found under *cuda_dir* and return its release version.

    Runs ``<cuda_dir>/bin/nvcc -V`` and takes the whitespace-separated token
    that follows the word ``release``, dropping everything from the first
    comma onward before parsing it with ``packaging.version.parse``.

    Returns:
        A ``(raw_output, version)`` tuple: the full text printed by nvcc and
        the parsed version object.

    Raises:
        subprocess.CalledProcessError: if nvcc exits with a non-zero status.
        ValueError: if the output contains no ``release`` token.
    """
    nvcc_path = cuda_dir + "/bin/nvcc"
    raw_output = subprocess.check_output([nvcc_path, "-V"], universal_newlines=True)

    tokens = raw_output.split()
    release_token = tokens[tokens.index("release") + 1]
    # The release token typically carries a trailing comma (e.g. "11.8,").
    version_text = release_token.partition(",")[0]

    return raw_output, parse(version_text)
|
|
|
|
|
|
def check_if_cuda_home_none(global_option: str) -> None:
    """Emit a warning if *global_option* was requested but CUDA_HOME is unset.

    CUDA_HOME comes from ``torch.utils.cpp_extension``. This function only
    warns; it never raises, so the caller continues either way.
    """
    if CUDA_HOME is None:
        warnings.warn(
            f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
            "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
            "only images whose names contain 'devel' will provide nvcc."
        )
|
|
|
|
|
|
def append_nvcc_threads(nvcc_extra_args):
    """Return a new list: *nvcc_extra_args* followed by ``--threads 4``.

    The input list is not modified.
    """
    thread_flags = ["--threads", "4"]
    return nvcc_extra_args + thread_flags
|
|
|
|
|
|
# Extensions handed to setup(); stays empty when MAMBA_SKIP_CUDA_BUILD=TRUE.
ext_modules = []

if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    _torch_version_parts = torch.__version__.split(".")
    TORCH_MAJOR = int(_torch_version_parts[0])
    TORCH_MINOR = int(_torch_version_parts[1])

    # Warns (does not raise) when nvcc/CUDA_HOME could not be located.
    check_if_cuda_home_none(PACKAGE_NAME)

    # Architecture-specific -gencode flags, only added when CUDA_HOME is known.
    cc_flag = []
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )

        _gencode_targets = [
            "arch=compute_70,code=sm_70",
            "arch=compute_80,code=sm_80",
        ]
        # sm_90 is only requested from nvcc 11.8 onward.
        if bare_metal_version >= Version("11.8"):
            _gencode_targets.append("arch=compute_90,code=sm_90")
        for _target in _gencode_targets:
            cc_flag.extend(["-gencode", _target])

    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    _selective_scan_sources = [
        "csrc/selective_scan/selective_scan.cpp",
        "csrc/selective_scan/selective_scan_fwd_fp32.cu",
        "csrc/selective_scan/selective_scan_fwd_fp16.cu",
        "csrc/selective_scan/selective_scan_fwd_bf16.cu",
        "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
        "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
        "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
        "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
        "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
        "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
    ]

    _nvcc_base_flags = [
        "-O3",
        "-std=c++17",
        # Presumably undoes torch's default -D__CUDA_NO_* defines so the
        # half/bfloat16 operator and conversion overloads are available.
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
        "--ptxas-options=-v",
        "-lineinfo",
    ]

    ext_modules.append(
        CUDAExtension(
            name="selective_scan_cuda",
            sources=_selective_scan_sources,
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"],
                "nvcc": append_nvcc_threads(_nvcc_base_flags + cc_flag),
            },
            include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
        )
    )
|
|
|
|
|
|
setup(
    name=PACKAGE_NAME,
    version=VERSION,
    description="Blackmamba state-space + MoE model",
    long_description=long_description,
    long_description_content_type="text/markdown",
    # `exclude` is a find_packages() argument; passed to setup() directly it is
    # an unknown distribution option that setuptools ignores with a warning.
    # find_packages' include patterns match exact names, so "ops.*" is needed
    # to also ship any subpackages of ops (include=["ops"] alone drops them).
    packages=find_packages(
        include=["ops", "ops.*"],
        exclude=(
            "csrc",
            "blackmamba.egg-info",
        ),
    ),
    ext_modules=ext_modules,
    # torch's BuildExtension drives nvcc/ninja for the CUDA sources.
    cmdclass={"build_ext": BuildExtension},
    python_requires=">=3.7",
    install_requires=[
        "torch",
        "packaging",
        "ninja",
        "einops",
        "triton",
        "transformers",
        "causal_conv1d>=1.1.0",
    ],
)