# ziqima's picture
# initial commit
# 4893ce0
import os
from sys import argv
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from distutils.sysconfig import get_config_vars
# Remove "-Wstrict-prototypes" from the interpreter's default OPT flags before
# the extension is compiled (presumably because it is a C-only warning that
# produces noise when the same flags are used for C++/CUDA sources).
# get_config_vars() returns None for a variable that the build configuration
# does not define, so OPT is only rewritten when a value is actually present.
(opt,) = get_config_vars("OPT")
if opt is not None:
    os.environ["OPT"] = " ".join(
        flag for flag in opt.split() if flag != "-Wstrict-prototypes"
    )
def _argparse(pattern, argv, is_flag=True, is_list=False):
if is_flag:
found = pattern in argv
if found:
argv.remove(pattern)
return found, argv
else:
arr = [arg for arg in argv if pattern == arg.split("=")[0]]
if is_list:
if len(arr) == 0: # not found
return False, argv
else:
assert "=" in arr[0], f"{arr[0]} requires a value."
argv.remove(arr[0])
val = arr[0].split("=")[1]
if "," in val:
return val.split(","), argv
else:
return [val], argv
else:
if len(arr) == 0: # not found
return False, argv
else:
assert "=" in arr[0], f"{arr[0]} requires a value."
argv.remove(arr[0])
return arr[0].split("=")[1], argv
# Pull the optional ``--include_dirs=a,b`` option out of argv. _argparse
# removes the argument from the list when it is present (presumably so that
# setup() does not see an option it does not know about).
INCLUDE_DIRS, argv = _argparse("--include_dirs", argv, False, is_list=True)

# _argparse reports "option absent" as False; treat that as no extra dirs.
include_dirs = [] if INCLUDE_DIRS is False else list(INCLUDE_DIRS)

# The single CUDA extension, built from one .cpp and one .cu source file.
_cuda_extension = CUDAExtension(
    name="pointgroup_ops_cuda",
    sources=["src/bfs_cluster.cpp", "src/bfs_cluster_kernel.cu"],
    extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
)

setup(
    name="pointgroup_ops",
    packages=["pointgroup_ops"],
    # The package's Python sources live in the "functions" directory.
    package_dir={"pointgroup_ops": "functions"},
    ext_modules=[_cuda_extension],
    include_dirs=list(include_dirs),
    cmdclass={"build_ext": BuildExtension},
)