import os from sys import argv from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension from distutils.sysconfig import get_config_vars (opt,) = get_config_vars("OPT") os.environ["OPT"] = " ".join( flag for flag in opt.split() if flag != "-Wstrict-prototypes" ) def _argparse(pattern, argv, is_flag=True, is_list=False): if is_flag: found = pattern in argv if found: argv.remove(pattern) return found, argv else: arr = [arg for arg in argv if pattern == arg.split("=")[0]] if is_list: if len(arr) == 0: # not found return False, argv else: assert "=" in arr[0], f"{arr[0]} requires a value." argv.remove(arr[0]) val = arr[0].split("=")[1] if "," in val: return val.split(","), argv else: return [val], argv else: if len(arr) == 0: # not found return False, argv else: assert "=" in arr[0], f"{arr[0]} requires a value." argv.remove(arr[0]) return arr[0].split("=")[1], argv INCLUDE_DIRS, argv = _argparse("--include_dirs", argv, False, is_list=True) include_dirs = [] if not (INCLUDE_DIRS is False): include_dirs += INCLUDE_DIRS setup( name="pointgroup_ops", packages=["pointgroup_ops"], package_dir={"pointgroup_ops": "functions"}, ext_modules=[ CUDAExtension( name="pointgroup_ops_cuda", sources=["src/bfs_cluster.cpp", "src/bfs_cluster_kernel.cu"], extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]}, ) ], include_dirs=[*include_dirs], cmdclass={"build_ext": BuildExtension}, )