|
|
|
|
|
from setuptools import find_packages, setup |
|
|
|
import os |
|
import subprocess |
|
import sys |
|
import time |
|
import torch |
|
from torch.utils.cpp_extension import (BuildExtension, CppExtension, |
|
CUDAExtension) |
|
|
|
# Location (relative to the project root) of the auto-generated version module
# written by write_version_py() and read back by get_version()/get_hash().
version_file: str = 'basicsr/version.py'
|
|
|
|
|
def readme():
    """Return the text used as the package's ``long_description``.

    No README file is read here; the long description is always empty.
    """
    long_description = ''
    return long_description
|
|
|
|
|
|
|
|
|
|
|
def get_git_hash():
    """Return the full commit hash of the current git ``HEAD``.

    git is run with a minimal environment and the C locale so its output
    does not depend on the user's shell configuration.

    Returns:
        str: The hash printed by ``git rev-parse HEAD``, or ``'unknown'``
        when git is not installed, exits with a non-zero status (for
        example, the working directory is not a repository), or prints
        nothing.
    """

    def _minimal_ext_cmd(cmd):
        # Forward only the variables needed to locate and run git.
        env = {
            k: os.environ[k]
            for k in ('SYSTEMROOT', 'PATH', 'HOME') if k in os.environ
        }
        # Force the C locale so git's output is never localized.
        env['LANGUAGE'] = 'C'
        env['LANG'] = 'C'
        env['LC_ALL'] = 'C'
        # check=True turns a failing git into CalledProcessError instead of
        # silently yielding empty output (which previously produced a
        # malformed version such as "1.0.0+"). stderr is discarded so
        # "not a git repository" messages do not clutter the build log.
        return subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            env=env,
            check=True).stdout

    try:
        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
    except (OSError, subprocess.CalledProcessError):
        return 'unknown'

    sha = out.strip().decode('ascii')
    return sha or 'unknown'
|
|
|
|
|
def get_hash():
    """Return a short identifier for the current source revision.

    Resolution order:

    1. If ``.git`` exists in the working directory, the first 7 characters
       of the value returned by ``get_git_hash()``.
    2. Otherwise, if the generated version file exists, the text after the
       last ``'+'`` in ``basicsr.version.__version__``. If there is no
       ``'+'``, the whole string is returned.
    3. Otherwise ``'unknown'``.

    Raises:
        ImportError: If the version file exists but ``basicsr.version``
            cannot be imported. The underlying import error is chained as
            ``__cause__`` so the real failure stays visible.
    """
    if os.path.exists('.git'):
        return get_git_hash()[:7]

    if os.path.exists(version_file):
        try:
            from basicsr.version import __version__
        except ImportError as err:
            raise ImportError('Unable to get git version') from err
        return __version__.split('+')[-1]

    return 'unknown'
|
|
|
|
|
def write_version_py():
    """Generate the version module at ``version_file``.

    The short version (e.g. ``1.2.0``) is read from the ``VERSION`` file in
    the current working directory, and ``'+' + get_hash()`` is appended to
    form ``__version__``. The module written also defines ``short_version``
    and a ``version_info`` tuple (numeric parts as ints, other parts as
    strings). A ``# TIME:`` header records when the file was generated.
    """
    content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
short_version = '{}'
version_info = ({})
"""
    sha = get_hash()
    with open('VERSION', 'r', encoding='utf-8') as f:
        short_version = f.read().strip()

    parts = [
        x if x.isdigit() else f'"{x}"' for x in short_version.split('.')
    ]
    # A one-element tuple needs a trailing comma; otherwise "(1)" would be
    # written, which Python parses as a plain int rather than a tuple.
    version_info = ', '.join(parts) + (',' if len(parts) == 1 else '')
    version = short_version + '+' + sha

    version_file_str = content.format(time.asctime(), version, short_version,
                                      version_info)
    with open(version_file, 'w', encoding='utf-8') as f:
        f.write(version_file_str)
|
|
|
|
|
def get_version(path=None):
    """Return the ``__version__`` defined in the generated version module.

    The module is executed in an isolated namespace rather than imported,
    so the ``basicsr`` package itself is not imported during setup.

    Args:
        path (str, optional): Version module to read. Defaults to the
            module-level ``version_file``.

    Returns:
        str: The value bound to ``__version__`` by that module.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If the module does not define ``__version__``.
    """
    if path is None:
        path = version_file
    with open(path, 'r', encoding='utf-8') as f:
        source = f.read()
    # Use an explicit namespace. Names that exec() binds into a function's
    # locals() are not reliably visible through a later locals() call, and
    # since PEP 667 (Python 3.13) each locals() call returns a fresh
    # snapshot, which made the old lookup raise KeyError.
    namespace = {}
    exec(compile(source, path, 'exec'), namespace)
    return namespace['__version__']
|
|
|
|
|
def make_cuda_ext(name, module, sources, sources_cuda=None):
    """Create a torch C++ (or C++/CUDA) extension for one op package.

    A ``CUDAExtension`` is built when ``torch.cuda.is_available()`` is true
    or the ``FORCE_CUDA`` environment variable equals ``'1'``. In that case
    ``WITH_CUDA`` is defined, extra nvcc flags are set, and ``sources_cuda``
    is compiled in addition to ``sources``. Otherwise a ``CppExtension`` is
    built from ``sources`` alone, and a notice is printed.

    Args:
        name (str): Extension name; the full module name is
            ``f'{module}.{name}'``.
        module (str): Dotted package path. Its components also form the
            directory that every source path is relative to.
        sources (list[str]): Sources compiled in every configuration. The
            caller's list is not modified.
        sources_cuda (list[str], optional): Sources compiled only in the
            CUDA configuration.

    Returns:
        The configured ``CUDAExtension`` or ``CppExtension`` instance.
    """
    if sources_cuda is None:
        sources_cuda = []
    # Work on a copy: the CUDA branch extends this list, and the caller's
    # list must not be mutated.
    sources = list(sources)
    define_macros = []
    extra_compile_args = {'cxx': []}

    if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
        define_macros += [('WITH_CUDA', None)]
        extension = CUDAExtension
        extra_compile_args['nvcc'] = [
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]
        sources += sources_cuda
    else:
        print(f'Compiling {name} without CUDA')
        extension = CppExtension

    return extension(
        name=f'{module}.{name}',
        sources=[os.path.join(*module.split('.'), p) for p in sources],
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)
|
|
|
|
|
def get_requirements(filename='requirements.txt'):
    """Return the ``install_requires`` list passed to ``setup()``.

    Runtime dependencies are intentionally not declared through setuptools:
    an empty list is always returned, and ``filename`` is not read. The
    parameter is kept only for backward compatibility with existing callers.
    The former file-reading logic sat after an unconditional ``return []``
    and could never execute, so it has been removed.

    Args:
        filename (str): Unused; retained for API compatibility.

    Returns:
        list[str]: Always an empty list.
    """
    return []
|
|
|
|
|
if __name__ == '__main__':
    # "--no_cuda_ext" is a custom flag. It must be removed from argv before
    # setuptools parses the command line, and it disables the compiled ops.
    if '--no_cuda_ext' not in sys.argv:
        # (extension name, package, common sources, CUDA-only sources);
        # source paths are relative to the package directory.
        op_specs = (
            ('deform_conv_ext', 'basicsr.models.ops.dcn',
             ['src/deform_conv_ext.cpp'],
             ['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
            ('fused_act_ext', 'basicsr.models.ops.fused_act',
             ['src/fused_bias_act.cpp'],
             ['src/fused_bias_act_kernel.cu']),
            ('upfirdn2d_ext', 'basicsr.models.ops.upfirdn2d',
             ['src/upfirdn2d.cpp'],
             ['src/upfirdn2d_kernel.cu']),
        )
        ext_modules = [
            make_cuda_ext(
                name=ext_name,
                module=package,
                sources=common_srcs,
                sources_cuda=cuda_srcs)
            for ext_name, package, common_srcs, cuda_srcs in op_specs
        ]
    else:
        sys.argv.remove('--no_cuda_ext')
        ext_modules = []

    # The version module must exist before get_version() reads it below.
    write_version_py()

    non_package_dirs = ('options', 'datasets', 'experiments', 'results',
                        'tb_logger', 'wandb')
    trove_classifiers = [
        'Development Status :: 4 - Beta',
        'License :: OSI Approved :: Apache Software License',
        'Operating System :: OS Independent',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
    ]

    setup(
        name='basicsr',
        version=get_version(),
        description='Open Source Image and Video Super-Resolution Toolbox',
        long_description=readme(),
        author='Xintao Wang',
        author_email='xintao.wang@outlook.com',
        keywords='computer vision, restoration, super resolution',
        url='https://github.com/xinntao/BasicSR',
        packages=find_packages(exclude=non_package_dirs),
        classifiers=trove_classifiers,
        license='Apache License 2.0',
        setup_requires=['cython', 'numpy'],
        install_requires=get_requirements(),
        ext_modules=ext_modules,
        cmdclass={'build_ext': BuildExtension},
        zip_safe=False,
    )
|
|