# nunchaku-kontext / setup.py
# Commit 04eaca9 by D3vShoaib: "Add Git LFS support and remove binary files"
import os
import re
import subprocess
import sys
from datetime import date
import setuptools
import torch
from packaging import version as packaging_version
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
class CustomBuildExtension(BuildExtension):
    """``BuildExtension`` that folds compiler-specific flag lists in before building.

    Besides the ``"cxx"`` and ``"nvcc"`` entries, an extension's
    ``extra_compile_args`` dict may carry:

    * ``"msvc"`` / ``"nvcc_msvc"`` -- appended to ``"cxx"`` / ``"nvcc"`` when the
      host compiler type is ``"msvc"``;
    * ``"gcc"`` -- appended to ``"cxx"`` for any other host compiler type.

    Absent keys are treated as empty lists, so an extension may set any subset of
    them. An ``extra_compile_args`` given as a plain list (no per-compiler
    sections) is left untouched.
    """

    @staticmethod
    def _merge_compiler_flags(extra_compile_args: dict, compiler_type: str) -> None:
        """Extend ``"cxx"``/``"nvcc"`` of ``extra_compile_args`` in place for ``compiler_type``.

        Empty ``"cxx"``/``"nvcc"`` lists are created first if they are missing.
        """
        extra_compile_args.setdefault("cxx", [])
        extra_compile_args.setdefault("nvcc", [])
        if compiler_type == "msvc":
            extra_compile_args["cxx"] += extra_compile_args.get("msvc", [])
            extra_compile_args["nvcc"] += extra_compile_args.get("nvcc_msvc", [])
        else:
            extra_compile_args["cxx"] += extra_compile_args.get("gcc", [])

    def build_extensions(self):
        """Merge the per-compiler flags of every extension, then build normally."""
        for ext in self.extensions:
            # setuptools also accepts a plain list of flags; it has nothing to merge.
            if isinstance(ext.extra_compile_args, dict):
                self._merge_compiler_flags(ext.extra_compile_args, self.compiler.compiler_type)
        super().build_extensions()
def _nvcc_version() -> str:
    """Return the ``major.minor.patch`` version string printed by ``nvcc --version``.

    Runs ``$CUDA_HOME/bin/nvcc`` when ``CUDA_HOME`` is set, otherwise ``nvcc`` from ``PATH``.

    Raises:
        RuntimeError: ``"nvcc not found"`` when nvcc cannot be executed (missing
            binary or non-zero exit status), chained to the underlying error;
            ``"nvcc version not found"`` when the output contains no
            ``release X.Y, VX.Y.Z`` line.
    """
    nvcc_path = os.path.join(CUDA_HOME, "bin", "nvcc") if CUDA_HOME else "nvcc"
    try:
        nvcc_output = subprocess.check_output([nvcc_path, "--version"]).decode(errors="replace")
    except (OSError, subprocess.CalledProcessError) as err:
        raise RuntimeError("nvcc not found") from err
    # Parsing failures are reported separately instead of being relabelled "nvcc not found".
    match = re.search(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)", nvcc_output)
    if match is None:
        raise RuntimeError("nvcc version not found")
    return match.group(2)


def get_sm_targets() -> list[str]:
    """Return the CUDA SM architecture names to compile for, e.g. ``["80", "89"]``.

    The selection is controlled by the ``NUNCHAKU_INSTALL_MODE`` environment
    variable (default ``"FAST"``):

    * ``"FAST"`` -- the compute capabilities of the GPUs visible to ``torch.cuda``,
      in device order and without duplicates. A capability of 12.0 is reported as
      ``"120a"`` when nvcc is 12.8 or newer (otherwise as ``"120"``).
    * ``"ALL"`` -- a fixed list of architectures, plus ``"120a"`` when nvcc is
      12.8 or newer.

    Raises:
        RuntimeError: if nvcc cannot be run or its version cannot be parsed.
        ValueError: if ``NUNCHAKU_INSTALL_MODE`` is neither ``"FAST"`` nor ``"ALL"``.
    """
    nvcc_version = _nvcc_version()
    print(f"Found nvcc version: {nvcc_version}")
    support_sm120 = packaging_version.parse(nvcc_version) >= packaging_version.parse("12.8")

    install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST")
    if install_mode == "FAST":
        targets = []
        for device_index in range(torch.cuda.device_count()):
            major, minor = torch.cuda.get_device_capability(device_index)
            sm = f"{major}{minor}"
            if sm == "120" and support_sm120:
                sm = "120a"
            targets.append(sm)
        # Identical GPUs report the same capability; keep each target once so the
        # generated -gencode flags are not duplicated.
        return list(dict.fromkeys(targets))
    if install_mode == "ALL":
        # All supported architectures (except for experimental ones)
        sm_targets = ["75", "80", "86", "89", "90"]
        if support_sm120:
            sm_targets.append("120a")
        return sm_targets
    raise ValueError(f"Unknown install mode: {install_mode}")
# Source files compiled into the ``nunchaku._C`` CUDA extension.
FLUX_SOURCES = ["nunchaku/csrc/pybind.cpp"]
ext_modules = []
# Compile the CUDA extension only when a GPU is visible and the CUDA toolkit was
# located; otherwise install the package without the compiled extension.
if torch.cuda.is_available() and CUDA_HOME is not None:
    sm_targets = get_sm_targets()
    # One -gencode entry per target SM architecture.
    arch_flags = [f"-gencode=arch=compute_{sm},code=sm_{sm}" for sm in sm_targets]
    ext_modules.append(
        CUDAExtension(
            "nunchaku._C",
            FLUX_SOURCES,
            extra_compile_args={
                # "cxx" is kept for every host compiler, MSVC included, so it must only
                # hold flags all of them accept. GCC/Clang-style options live under
                # "gcc" and are appended to "cxx" by CustomBuildExtension on non-MSVC
                # builds (cl.exe does not recognise -O3 or -std=c++20).
                "cxx": [],
                "nvcc": [
                    "-O3",
                    "-std=c++20",
                    "--use_fast_math",
                    # NOTE(review): presumably undoes the -D__CUDA_NO_*__ defines that
                    # PyTorch adds to extension builds by default -- confirm.
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "-U__CUDA_NO_HALF2_OPERATORS__",
                ]
                + arch_flags,
                # Per-compiler sections merged into "cxx"/"nvcc" by CustomBuildExtension.
                "msvc": ["/std:c++20"],
                "gcc": ["-O3", "-std=c++20"],
                "nvcc_msvc": [],
            },
            include_dirs=[
                "third_party/cutlass/include",
                "third_party/cutlass/tools/util/include",
            ],
        )
    )
else:
    print("CUDA not available. Installing CPU-only version.")
# Package metadata. "build_ext" is routed through CustomBuildExtension so the
# per-compiler sections of each extension's extra_compile_args are merged.
setuptools.setup(
    name="flux-kontext",
    zip_safe=False,
    packages=setuptools.find_packages(),
    cmdclass=dict(build_ext=CustomBuildExtension),
    ext_modules=ext_modules,
)