|
|
|
|
|
|
|
|
|
"""isort:skip_file""" |
|
|
|
import functools |
|
import importlib |
|
|
|
|
|
# Import names of the modules this file requires; each entry is passed to
# importlib.import_module by the availability check that follows.
# NOTE(review): presumably also read by torch.hub as the hubconf dependency
# list -- confirm against the torch.hub documentation.
dependencies = [
    "dataclasses",
    "hydra",
    "numpy",
    "omegaconf",
    "regex",
    "requests",
    "torch",
]
|
|
|
|
|
|
|
# Attempt to import every declared dependency before doing anything else, and
# collect all failures so that a single error lists everything that is missing
# instead of stopping at the first one.
missing_deps = []
for dep in dependencies:
    try:
        importlib.import_module(dep)
    except ImportError:
        # The "hydra" module is reported as "hydra-core", the name of the
        # package that provides it on PyPI, so the message names something
        # the user can actually install.
        missing_deps.append("hydra-core" if dep == "hydra" else dep)

if missing_deps:
    raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps)))
|
|
|
|
|
|
|
from fairseq.hub_utils import ( |
|
BPEHubInterface as bpe, |
|
TokenizerHubInterface as tokenizer, |
|
) |
|
from fairseq.models import MODEL_REGISTRY |
|
|
|
|
|
|
|
|
|
# Probe for fairseq.data.token_block_utils_fast; if it cannot be imported,
# try to build extensions in place by running the setup.py that sits next to
# this file with "build_ext --inplace".
# NOTE(review): presumably token_block_utils_fast is one of the Cython-built
# modules the message below refers to -- confirm against setup.py.
try:
    import fairseq.data.token_block_utils_fast  # noqa: F401
except ImportError:
    try:
        # cython is imported only to find out whether it is installed; an
        # ImportError here (or from anything else inside this try, including
        # the setup run) falls through to the best-effort message below.
        import cython  # noqa: F401
        import os
        from setuptools import sandbox

        _setup_script = os.path.join(os.path.dirname(__file__), "setup.py")
        sandbox.run_setup(_setup_script, ["build_ext", "--inplace"])
    except ImportError:
        # Deliberately non-fatal: only a message is printed and loading
        # continues.
        print(
            "Unable to build Cython components. Please make sure Cython is "
            "installed if the torch.hub model you are loading depends on it."
        )
|
|
|
|
|
|
|
# Publish every name returned by each registered class's hub_models() as a
# module-level attribute: a partial of that class's from_pretrained with the
# model name already bound as its first positional argument.
# NOTE(review): presumably these attributes are the entry points torch.hub
# resolves by name -- confirm against the torch.hub documentation.
for _cls in MODEL_REGISTRY.values():
    for model_name in _cls.hub_models():
        globals()[model_name] = functools.partial(_cls.from_pretrained, model_name)
|
|