"""Build script for the ``spatial_correlation_sampler`` PyTorch extension.

Compiles the C++ sources under ``Correlation_Module`` (plus the CUDA kernel
unless ``CPU_ONLY`` is true) into the ``spatial_correlation_sampler_backend``
extension module, and packages the ``spatial_correlation_sampler`` Python
package found in the same directory.

Environment variables:
    GPU_ARCH: whitespace-separated list of CUDA architecture numbers
        (e.g. ``"61 75"``); each entry becomes a ``-gencode`` pair for nvcc.
        When unset or empty, no ``-gencode`` flags are passed.
"""
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
from os.path import join

# Set to True to build the CPU-only extension (no CUDA kernel, no USE_CUDA).
CPU_ONLY = False

project_root = 'Correlation_Module'

# C++ sources common to both the CPU and CUDA builds (relative to project_root).
source_files = ['correlation.cpp', 'correlation_sampler.cpp']

cxx_args = ['-std=c++17', '-fopenmp']


def generate_nvcc_args(gpu_archs):
    """Return nvcc ``-gencode`` flags for every architecture in *gpu_archs*.

    Each entry ``a`` contributes ``['-gencode', 'arch=compute_a,code=sm_a']``;
    an empty iterable yields an empty list.
    """
    nvcc_args = []
    for arch in gpu_archs:
        nvcc_args.extend(['-gencode', f'arch=compute_{arch},code=sm_{arch}'])
    return nvcc_args


gpu_arch = os.environ.get('GPU_ARCH', '').split()
nvcc_args = generate_nvcc_args(gpu_arch)

# Decode explicitly as UTF-8: without an encoding argument open() uses the
# locale's preferred encoding, so a non-UTF-8 build locale could fail on any
# non-ASCII characters in the README.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


def launch_setup():
    """Configure the extension and invoke ``setuptools.setup``.

    When ``CPU_ONLY`` is true a ``CppExtension`` is built from
    ``source_files``; otherwise a ``CUDAExtension`` is built that also
    compiles ``correlation_cuda_kernel.cu`` and defines the ``USE_CUDA``
    macro.
    """
    # Work on a copy: appending to the module-level list would add the CUDA
    # kernel again on every call and make it be compiled/linked twice.
    files = list(source_files)
    if CPU_ONLY:
        Extension = CppExtension
        macro = []
    else:
        Extension = CUDAExtension
        files.append('correlation_cuda_kernel.cu')
        macro = [("USE_CUDA", None)]

    sources = [join(project_root, filename) for filename in files]

    setup(
        name='spatial_correlation_sampler',
        version="0.4.0",
        author="Clément Pinard",
        author_email="clement.pinard@ensta-paristech.fr",
        description="Correlation module for pytorch",
        long_description=long_description,
        long_description_content_type="text/markdown",
        url="https://github.com/ClementPinard/Pytorch-Correlation-extension",
        install_requires=['torch>=1.1', 'numpy'],
        ext_modules=[
            Extension('spatial_correlation_sampler_backend',
                      sources,
                      define_macros=macro,
                      extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args},
                      # -fopenmp in cxx_args needs the GNU OpenMP runtime at link time.
                      extra_link_args=['-lgomp'])
        ],
        package_dir={'': project_root},
        packages=['spatial_correlation_sampler'],
        cmdclass={
            'build_ext': BuildExtension
        },
        classifiers=[
            "Programming Language :: Python :: 3",
            "License :: OSI Approved :: MIT License",
            "Operating System :: POSIX :: Linux",
            "Intended Audience :: Science/Research",
            "Topic :: Scientific/Engineering :: Artificial Intelligence"
        ])


if __name__ == '__main__':
    launch_setup()