Spaces:
Running
Running
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import os | |
import subprocess | |
import sys | |
from setuptools import Extension, find_packages, setup | |
# Abort immediately, with a readable message, on interpreters older than 3.6.
if not sys.version_info >= (3, 6):
    sys.exit("Sorry, Python >= 3.6 is required for fairseq.")
def write_version_py():
    """Compute the package version and record it in ``fairseq/version.py``.

    The base version is read from ``fairseq/version.txt`` (relative to the
    current working directory). If ``git rev-parse HEAD`` succeeds, the first
    seven characters of the commit hash are appended as ``+<sha>``.

    Returns:
        str: the version string that was written to ``fairseq/version.py``.

    Raises:
        OSError: if ``version.txt`` cannot be read or ``version.py`` cannot
            be written.
    """
    with open(os.path.join("fairseq", "version.txt"), encoding="utf-8") as f:
        version = f.read().strip()

    # append latest commit hash to version string
    try:
        sha = (
            subprocess.check_output(
                ["git", "rev-parse", "HEAD"],
                # keep git's "not a git repository" chatter out of build logs
                stderr=subprocess.DEVNULL,
            )
            .decode("ascii")
            .strip()
        )
    except (OSError, subprocess.SubprocessError, UnicodeDecodeError):
        # Best effort only: git may be missing, or this may not be a checkout.
        # Fall back to the plain version from version.txt.
        pass
    else:
        # Guard against empty output so we never emit a dangling "+".
        if sha:
            version += "+" + sha[:7]

    # write version info to fairseq/version.py
    with open(os.path.join("fairseq", "version.py"), "w", encoding="utf-8") as f:
        f.write('__version__ = "{}"\n'.format(version))
    return version
# Resolve the version string; this call also rewrites fairseq/version.py.
version = write_version_py()

# Decode the README explicitly as UTF-8 so the result does not depend on the
# platform's default locale encoding.
with open("README.md", encoding="utf-8") as f:
    readme = f.read()
# Compiler flags shared by the C++/Cython extensions: macOS selects libc++,
# every other platform requests the C++11 standard; both optimise with -O3.
extra_compile_args = (
    ["-stdlib=libc++", "-O3"]
    if sys.platform == "darwin"
    else ["-std=c++11", "-O3"]
)
class NumpyExtension(Extension):
    """Extension that appends NumPy's header directory to ``include_dirs``.

    NumPy is imported only when ``include_dirs`` is read, so this class can be
    instantiated before NumPy is installed.

    Source: https://stackoverflow.com/a/54128391
    """

    def __init__(self, *args, **kwargs):
        # Initialise the backing list first: the base constructor is expected
        # to assign ``include_dirs``, which is routed through the setter below.
        self.__include_dirs = []
        super().__init__(*args, **kwargs)

    @property
    def include_dirs(self):
        """Return the configured include dirs plus ``numpy.get_include()``."""
        import numpy

        return self.__include_dirs + [numpy.get_include()]

    @include_dirs.setter
    def include_dirs(self, dirs):
        """Store the user-supplied include dirs (without the NumPy entry)."""
        self.__include_dirs = dirs
# Extensions that do not need PyTorch at build time: the libbleu C++ module
# first, followed by the two Cython modules built through NumpyExtension.
extensions = [
    Extension(
        "fairseq.libbleu",
        sources=[
            "fairseq/clib/libbleu/libbleu.cpp",
            "fairseq/clib/libbleu/module.cpp",
        ],
        extra_compile_args=extra_compile_args,
    ),
]
extensions += [
    NumpyExtension(
        module_name,
        sources=[pyx_source],
        language="c++",
        extra_compile_args=extra_compile_args,
    )
    for module_name, pyx_source in (
        ("fairseq.data.data_utils_fast", "fairseq/data/data_utils_fast.pyx"),
        (
            "fairseq.data.token_block_utils_fast",
            "fairseq/data/token_block_utils_fast.pyx",
        ),
    )
]
# Custom setuptools commands; stays empty unless torch can be imported below.
cmdclass = {}

try:
    # torch is not available when generating docs
    from torch.utils import cpp_extension

    # Operators that only need a C++ toolchain.
    extensions.append(
        cpp_extension.CppExtension(
            "fairseq.libbase",
            sources=["fairseq/clib/libbase/balanced_assignment.cpp"],
        )
    )
    extensions += [
        cpp_extension.CppExtension(
            "fairseq.libnat",
            sources=["fairseq/clib/libnat/edit_dist.cpp"],
        ),
        cpp_extension.CppExtension(
            "alignment_train_cpu_binding",
            sources=["examples/operators/alignment_train_cpu.cpp"],
        ),
    ]

    # Operators with .cu sources are registered only when CUDA_HOME is set.
    # NOTE(review): these are declared with CppExtension rather than
    # CUDAExtension -- confirm this is intentional.
    if "CUDA_HOME" in os.environ:
        extensions += [
            cpp_extension.CppExtension(
                "fairseq.libnat_cuda",
                sources=[
                    "fairseq/clib/libnat_cuda/edit_dist.cu",
                    "fairseq/clib/libnat_cuda/binding.cpp",
                ],
            ),
            cpp_extension.CppExtension(
                "fairseq.ngram_repeat_block_cuda",
                sources=[
                    "fairseq/clib/cuda/ngram_repeat_block_cuda.cpp",
                    "fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu",
                ],
            ),
            cpp_extension.CppExtension(
                "alignment_train_cuda_binding",
                sources=[
                    "examples/operators/alignment_train_kernel.cu",
                    "examples/operators/alignment_train_cuda.cpp",
                ],
            ),
        ]

    # Build everything through torch's build_ext implementation.
    cmdclass.update(build_ext=cpp_extension.BuildExtension)
except ImportError:
    # Without torch, only the extensions declared earlier are built.
    pass
if "READTHEDOCS" not in os.environ:
    dependency_links = []
else:
    # don't build extensions when generating docs
    extensions = []
    cmdclass.pop("build_ext", None)

    # use CPU build of PyTorch
    dependency_links = [
        "https://download.pytorch.org/whl/cpu/torch-1.7.0%2Bcpu-cp36-cp36m-linux_x86_64.whl"
    ]
if "clean" in sys.argv[1:]:
    # Source: https://bit.ly/2NLVsgE
    print("deleting Cython files...")
    import glob

    # Delete compiled artifacts in-process instead of shelling out to ``rm``:
    # this works on platforms without a POSIX shell, and ``**`` with
    # ``recursive=True`` matches zero or more directory levels, so it covers
    # both fairseq/*.so and arbitrarily nested files.
    for pattern in ("fairseq/**/*.so", "fairseq/**/*.pyd"):
        for artifact in glob.glob(pattern, recursive=True):
            try:
                os.remove(artifact)
            except OSError as err:
                # Mirror ``rm -f``: report the problem but keep cleaning.
                print("could not delete {}: {}".format(artifact, err))
# List the megatron ``mpu`` package explicitly, but only when its directory
# is actually present in the source tree.
extra_packages = (
    ["fairseq.model_parallel.megatron.mpu"]
    if os.path.exists(os.path.join("fairseq", "model_parallel", "megatron", "mpu"))
    else []
)
def do_setup(package_data):
    """Invoke ``setuptools.setup`` with fairseq's package metadata.

    Args:
        package_data: mapping forwarded unchanged as ``setup(package_data=...)``.

    Reads the module-level ``version``, ``readme``, ``dependency_links``,
    ``extra_packages``, ``extensions`` and ``cmdclass`` values.
    """
    trove_classifiers = [
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ]

    # Needed to run the build itself.
    build_requirements = [
        "cython",
        'numpy<1.20.0; python_version<"3.7"',
        'numpy; python_version>="3.7"',
        "setuptools>=18.0",
    ]

    # Needed at runtime by the installed package.
    runtime_requirements = [
        "cffi",
        "cython",
        'dataclasses; python_version<"3.7"',
        "hydra-core>=1.0.7,<1.1",
        "omegaconf<2.1",
        'numpy<1.20.0; python_version<"3.7"',
        'numpy; python_version>="3.7"',
        "regex",
        "sacrebleu>=1.4.12",
        # "torch",
        "tqdm",
        "bitarray",
        # "torchaudio>=0.8.0",
    ]

    # Everything discoverable except examples/scripts/tests, plus any
    # explicitly listed extra packages.
    discovered_packages = find_packages(
        exclude=[
            "examples",
            "examples.*",
            "scripts",
            "scripts.*",
            "tests",
            "tests.*",
        ]
    )

    console_scripts = [
        "fairseq-eval-lm = fairseq_cli.eval_lm:cli_main",
        "fairseq-generate = fairseq_cli.generate:cli_main",
        "fairseq-hydra-train = fairseq_cli.hydra_train:cli_main",
        "fairseq-interactive = fairseq_cli.interactive:cli_main",
        "fairseq-preprocess = fairseq_cli.preprocess:cli_main",
        "fairseq-score = fairseq_cli.score:cli_main",
        "fairseq-train = fairseq_cli.train:cli_main",
        "fairseq-validate = fairseq_cli.validate:cli_main",
    ]

    setup(
        name="fairseq",
        version=version,
        description="Facebook AI Research Sequence-to-Sequence Toolkit",
        url="https://github.com/pytorch/fairseq",
        classifiers=trove_classifiers,
        long_description=readme,
        long_description_content_type="text/markdown",
        setup_requires=build_requirements,
        install_requires=runtime_requirements,
        dependency_links=dependency_links,
        packages=discovered_packages + extra_packages,
        package_data=package_data,
        ext_modules=extensions,
        test_suite="tests",
        entry_points={"console_scripts": console_scripts},
        cmdclass=cmdclass,
        zip_safe=False,
    )
def get_files(path, relative_to="fairseq"):
    """List every non-``.pyc`` file under *path*, following symlinked dirs.

    Args:
        path: directory to walk; a non-existent path yields an empty list.
        relative_to: directory the returned paths are made relative to.

    Returns:
        list of str: file paths expressed relative to *relative_to*.
    """
    collected = []
    for dirpath, _subdirs, filenames in os.walk(path, followlinks=True):
        rel_dir = os.path.relpath(dirpath, relative_to)
        collected.extend(
            os.path.join(rel_dir, filename)
            for filename in filenames
            if not filename.endswith(".pyc")
        )
    return collected
if __name__ == "__main__":
    # Location inside the package where the top-level examples are exposed.
    fairseq_examples = os.path.join("fairseq", "examples")
    try:
        # symlink examples into fairseq package so package_data accepts them
        if not ("build_ext" in sys.argv[1:] or os.path.exists(fairseq_examples)):
            os.symlink(os.path.join("..", "examples"), fairseq_examples)

        bundled_files = get_files(fairseq_examples)
        bundled_files = bundled_files + get_files(os.path.join("fairseq", "config"))
        do_setup({"fairseq": bundled_files})
    finally:
        # Drop the symlink again, except for build_ext runs, which never
        # create it in the first place.
        if "build_ext" not in sys.argv[1:]:
            if os.path.islink(fairseq_examples):
                os.unlink(fairseq_examples)