# Copyright 2022-present NAVER Corp. | |
# CC BY-NC-SA 4.0 | |
# Available only for non-commercial use | |
import os

from setuptools import setup
from torch import cuda
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Target CUDA architectures.
#
# By default no -gencode flags are passed, so torch's BuildExtension selects
# the architectures itself (see TORCH_CUDA_ARCH_LIST in the
# torch.utils.cpp_extension docs).
#
# Set CUDA_DEEPM_ALL_ARCHS=1 to compile for every architecture the installed
# PyTorch build was compiled for. This gives a more portable binary at the
# cost of a slower build.
if os.environ.get('CUDA_DEEPM_ALL_ARCHS', '0') not in ('', '0'):
    # torch reports flags as "-gencode compute=compute_XY,code=sm_XY", but
    # nvcc expects the virtual-architecture key to be spelled "arch=".
    all_cuda_archs = cuda.get_gencode_flags().replace('compute=', 'arch=').split()
else:
    all_cuda_archs = []

setup(
    name='cuda_deepm',
    ext_modules=[
        CUDAExtension(
            name='cuda_deepm',
            sources=["func.cpp", "kernels.cu"],
            extra_compile_args=dict(nvcc=['-O2'] + all_cuda_archs, cxx=['-O2']),
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    },
)