|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Import utilities: Utilities related to imports and our lazy inits. |
|
""" |
|
import importlib.util |
|
import os |
|
import sys |
|
from collections import OrderedDict |
|
|
|
from packaging import version |
|
|
|
from . import logging |
|
|
|
|
|
|
|
if sys.version_info < (3, 8): |
|
import importlib_metadata |
|
else: |
|
import importlib.metadata as importlib_metadata |
|
|
|
|
|
# Module-level logger used to report which optional backends were detected.
logger = logging.get_logger(__name__)

# Upper-cased environment values that count as an explicit opt-in.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
# The opt-in values plus "AUTO", which is the default used when a variable is unset.
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES | {"AUTO"}

# Framework toggles, upper-cased so they can be compared against the sets above.
USE_TF = os.getenv("USE_TF", "AUTO").upper()
USE_TORCH = os.getenv("USE_TORCH", "AUTO").upper()
# NOTE: the JAX/Flax toggle is read from the `USE_FLAX` environment variable.
USE_JAX = os.getenv("USE_FLAX", "AUTO").upper()
|
|
|
# PyTorch detection. The probe is skipped when USE_TORCH is not a true/auto value or when
# USE_TF is explicitly set to a true value.
_torch_version = "N/A"
if USE_TORCH not in ENV_VARS_TRUE_AND_AUTO_VALUES or USE_TF in ENV_VARS_TRUE_VALUES:
    logger.info("Disabling PyTorch because USE_TF is set")
    _torch_available = False
else:
    _torch_available = importlib.util.find_spec("torch") is not None
    if _torch_available:
        try:
            _torch_version = importlib_metadata.version("torch")
        except importlib_metadata.PackageNotFoundError:
            # An importable module without distribution metadata is treated as unavailable.
            _torch_available = False
        else:
            logger.info(f"PyTorch version {_torch_version} available.")
|
|
|
|
|
# TensorFlow detection. The probe is skipped when USE_TF is not a true/auto value or when
# USE_TORCH is explicitly set to a true value.
_tf_version = "N/A"
if USE_TF not in ENV_VARS_TRUE_AND_AUTO_VALUES or USE_TORCH in ENV_VARS_TRUE_VALUES:
    logger.info("Disabling Tensorflow because USE_TORCH is set")
    _tf_available = False
else:
    _tf_available = importlib.util.find_spec("tensorflow") is not None
    if _tf_available:
        # Distribution names that are looked up for version metadata, tried in this order.
        candidates = (
            "tensorflow",
            "tensorflow-cpu",
            "tensorflow-gpu",
            "tf-nightly",
            "tf-nightly-cpu",
            "tf-nightly-gpu",
            "intel-tensorflow",
            "intel-tensorflow-avx512",
            "tensorflow-rocm",
            "tensorflow-macos",
            "tensorflow-aarch64",
        )
        _tf_version = None

        # Stop at the first distribution whose metadata can be read.
        for pkg in candidates:
            try:
                _tf_version = importlib_metadata.version(pkg)
            except importlib_metadata.PackageNotFoundError:
                continue
            break
        _tf_available = _tf_version is not None

    if _tf_available and version.parse(_tf_version) < version.parse("2"):
        logger.info(f"TensorFlow found but with version {_tf_version}. Diffusers requires version 2 minimum.")
        _tf_available = False
    elif _tf_available:
        logger.info(f"TensorFlow version {_tf_version} available.")
|
|
|
|
|
# Flax detection: both `jax` and `flax` must be importable and expose distribution metadata.
if USE_JAX not in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _flax_available = False
else:
    _flax_available = all(importlib.util.find_spec(module_name) is not None for module_name in ("jax", "flax"))
    if _flax_available:
        try:
            _jax_version = importlib_metadata.version("jax")
            _flax_version = importlib_metadata.version("flax")
        except importlib_metadata.PackageNotFoundError:
            _flax_available = False
        else:
            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
|
|
|
|
|
# transformers detection: importable module plus readable distribution metadata.
_transformers_available = importlib.util.find_spec("transformers") is not None
try:
    _transformers_version = importlib_metadata.version("transformers")
except importlib_metadata.PackageNotFoundError:
    _transformers_available = False
else:
    logger.debug(f"Successfully imported transformers version {_transformers_version}")
|
|
|
|
|
# inflect detection: importable module plus readable distribution metadata.
_inflect_available = importlib.util.find_spec("inflect") is not None
try:
    _inflect_version = importlib_metadata.version("inflect")
except importlib_metadata.PackageNotFoundError:
    _inflect_available = False
else:
    logger.debug(f"Successfully imported inflect version {_inflect_version}")
|
|
|
|
|
# unidecode detection: importable module plus readable distribution metadata.
_unidecode_available = importlib.util.find_spec("unidecode") is not None
try:
    _unidecode_version = importlib_metadata.version("unidecode")
except importlib_metadata.PackageNotFoundError:
    _unidecode_available = False
else:
    logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
|
|
|
|
|
# modelcards detection: importable module plus readable distribution metadata.
_modelcards_available = importlib.util.find_spec("modelcards") is not None
try:
    _modelcards_version = importlib_metadata.version("modelcards")
except importlib_metadata.PackageNotFoundError:
    _modelcards_available = False
else:
    logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
|
|
|
|
|
# ONNX detection is keyed on the `onnxruntime` package: importable module plus readable metadata.
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
try:
    _onnxruntime_version = importlib_metadata.version("onnxruntime")
except importlib_metadata.PackageNotFoundError:
    _onnx_available = False
else:
    logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
|
|
|
|
|
# scipy detection: importable module plus readable distribution metadata.
_scipy_available = importlib.util.find_spec("scipy") is not None
try:
    _scipy_version = importlib_metadata.version("scipy")
    # The message names scipy (it previously said "transformers", a copy-paste slip that
    # made debug logs misleading).
    logger.debug(f"Successfully imported scipy version {_scipy_version}")
except importlib_metadata.PackageNotFoundError:
    # No distribution metadata: treat scipy as unavailable even if a module was found.
    _scipy_available = False
|
|
|
|
|
def is_torch_available():
    """Return the module-level `_torch_available` flag set during import of this module."""
    return _torch_available
|
|
|
|
|
def is_tf_available():
    """Return the module-level `_tf_available` flag set during import of this module."""
    return _tf_available
|
|
|
|
|
def is_flax_available():
    """Return the module-level `_flax_available` flag set during import of this module."""
    return _flax_available
|
|
|
|
|
def is_transformers_available():
    """Return the module-level `_transformers_available` flag set during import of this module."""
    return _transformers_available
|
|
|
|
|
def is_inflect_available():
    """Return the module-level `_inflect_available` flag set during import of this module."""
    return _inflect_available
|
|
|
|
|
def is_unidecode_available():
    """Return the module-level `_unidecode_available` flag set during import of this module."""
    return _unidecode_available
|
|
|
|
|
def is_modelcards_available():
    """Return the module-level `_modelcards_available` flag set during import of this module."""
    return _modelcards_available
|
|
|
|
|
def is_onnx_available():
    """Return the module-level `_onnx_available` flag set during import of this module."""
    return _onnx_available
|
|
|
|
|
def is_scipy_available():
    """Return the module-level `_scipy_available` flag set during import of this module."""
    return _scipy_available
|
|
|
|
|
|
|
# Error-message templates, one per optional backend. Each template has a single positional
# placeholder, `{0}`, which `requires_backends` fills with the name of the object that needs
# the backend.

# Flax
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
"""

# inflect
INFLECT_IMPORT_ERROR = """
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
inflect`
"""

# PyTorch
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""

# onnxruntime
ONNX_IMPORT_ERROR = """
{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip
install onnxruntime`
"""

# scipy
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
scipy`
"""

# TensorFlow
TENSORFLOW_IMPORT_ERROR = """
{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the
installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
"""

# transformers
TRANSFORMERS_IMPORT_ERROR = """
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
install transformers`
"""

# unidecode
UNIDECODE_IMPORT_ERROR = """
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
Unidecode`
"""
|
|
|
|
|
# Registry used by `requires_backends`: backend key -> (availability check, error template).
# Keyword arguments keep their written order, so iteration order matches the listing below.
BACKENDS_MAPPING = OrderedDict(
    flax=(is_flax_available, FLAX_IMPORT_ERROR),
    inflect=(is_inflect_available, INFLECT_IMPORT_ERROR),
    onnx=(is_onnx_available, ONNX_IMPORT_ERROR),
    scipy=(is_scipy_available, SCIPY_IMPORT_ERROR),
    tf=(is_tf_available, TENSORFLOW_IMPORT_ERROR),
    torch=(is_torch_available, PYTORCH_IMPORT_ERROR),
    transformers=(is_transformers_available, TRANSFORMERS_IMPORT_ERROR),
    unidecode=(is_unidecode_available, UNIDECODE_IMPORT_ERROR),
)
|
|
|
|
|
def requires_backends(obj, backends):
    """
    Raise an ImportError if any of the requested backends is unavailable.

    Args:
        obj: The object that needs the backends. Its `__name__` (or, when it has none, the name of its class) is
            inserted into the error message.
        backends: A backend key from `BACKENDS_MAPPING`, or a list/tuple of such keys.

    Raises:
        ImportError: With the concatenated messages of every unavailable backend.
        KeyError: If a backend key is not present in `BACKENDS_MAPPING`.
    """
    # A single key is wrapped so the loop below always sees a sequence.
    if not isinstance(backends, (list, tuple)):
        backends = [backends]

    if hasattr(obj, "__name__"):
        name = obj.__name__
    else:
        name = obj.__class__.__name__

    failed = []
    for backend in backends:
        available, msg = BACKENDS_MAPPING[backend]
        if not available():
            failed.append(msg.format(name))

    if failed:
        raise ImportError("".join(failed))
|
|
|
|
|
class DummyObject(type):
    """
    Metaclass for the dummy objects. Accessing a missing public attribute on a class that uses this metaclass calls
    `requires_backends` with the class and its `_backends`, which raises an ImportError when a backend is missing.
    """

    def __getattr__(cls, key):
        """
        Handle attribute lookups on the class that normal resolution could not satisfy.

        Args:
            key: Name of the attribute that was not found on the class.

        Raises:
            AttributeError: If `key` starts with an underscore (private or dunder names).
            ImportError: Raised by `requires_backends` when one of `cls._backends` is unavailable.
        """
        if key.startswith("_"):
            # `type` defines no `__getattr__`, so delegating to `super()` can never return a value; it only raised
            # a confusing "'super' object has no attribute '__getattr__'". Raise the conventional error instead so
            # `hasattr`/`getattr(..., default)` keep working and the message names the missing attribute.
            raise AttributeError(f"type object '{cls.__name__}' has no attribute '{key}'")
        requires_backends(cls, cls._backends)
|
|