| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | Import utilities: Utilities related to imports and our lazy inits. |
| | """ |
| |
|
| | import importlib.util |
| | import operator as op |
| | import os |
| | import sys |
| | from collections import OrderedDict |
| | from itertools import chain |
| | from types import ModuleType |
| | from typing import Any, Union |
| |
|
| | from huggingface_hub.utils import is_jinja_available |
| | from packaging import version |
| | from packaging.version import Version, parse |
| |
|
| | from . import logging |
| |
|
| |
|
| | |
| | if sys.version_info < (3, 8): |
| | import importlib_metadata |
| | else: |
| | import importlib.metadata as importlib_metadata |
| |
|
| |
|
logger = logging.get_logger(__name__)

# Upper-cased spellings accepted as "true" for the boolean environment switches below;
# ENV_VARS_TRUE_AND_AUTO_VALUES additionally accepts "AUTO".
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES | {"AUTO"}

# Backend switches, normalised to upper case. Note that the JAX switch is read from USE_FLAX.
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()
# Parsed straight to a bool.
DIFFUSERS_SLOW_IMPORT = os.environ.get("DIFFUSERS_SLOW_IMPORT", "FALSE").upper() in ENV_VARS_TRUE_VALUES

# Comparison strings accepted by the version helpers, mapped to the matching `operator` function.
STR_OPERATION_TO_FUNC = {
    ">": op.gt,
    ">=": op.ge,
    "==": op.eq,
    "!=": op.ne,
    "<=": op.le,
    "<": op.lt,
}
| |
|
# Detect PyTorch unless it was switched off through USE_TORCH, or USE_TF was explicitly enabled.
_torch_version = "N/A"
if USE_TORCH not in ENV_VARS_TRUE_AND_AUTO_VALUES or USE_TF in ENV_VARS_TRUE_VALUES:
    logger.info("Disabling PyTorch because USE_TORCH is set")
    _torch_available = False
else:
    _torch_available = importlib.util.find_spec("torch") is not None
    if _torch_available:
        try:
            _torch_version = importlib_metadata.version("torch")
        except importlib_metadata.PackageNotFoundError:
            _torch_available = False
        else:
            logger.info(f"PyTorch version {_torch_version} available.")
| |
|
# PyTorch/XLA: available only when the import spec exists AND distribution metadata is readable.
_torch_xla_available = importlib.util.find_spec("torch_xla") is not None
if _torch_xla_available:
    try:
        _torch_xla_version = importlib_metadata.version("torch_xla")
        logger.info(f"PyTorch XLA version {_torch_xla_version} available.")
    # `version()` signals a missing distribution with PackageNotFoundError; catch exactly that,
    # like every other detection block in this module.
    except importlib_metadata.PackageNotFoundError:
        _torch_xla_available = False


# torch_npu: same detection scheme as torch_xla above.
_torch_npu_available = importlib.util.find_spec("torch_npu") is not None
if _torch_npu_available:
    try:
        _torch_npu_version = importlib_metadata.version("torch_npu")
        logger.info(f"torch_npu version {_torch_npu_version} available.")
    except importlib_metadata.PackageNotFoundError:
        _torch_npu_available = False
| |
|
# JAX and Flax are treated as a single backend: both must be importable and have readable metadata.
_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX not in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _flax_available = False
else:
    _flax_available = all(importlib.util.find_spec(name) is not None for name in ("jax", "flax"))
    if _flax_available:
        try:
            _jax_version = importlib_metadata.version("jax")
            _flax_version = importlib_metadata.version("flax")
        except importlib_metadata.PackageNotFoundError:
            _flax_available = False
        else:
            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
| |
|
# Detect safetensors unless it was switched off through the USE_SAFETENSORS environment variable.
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _safetensors_available = importlib.util.find_spec("safetensors") is not None
    if _safetensors_available:
        try:
            _safetensors_version = importlib_metadata.version("safetensors")
            logger.info(f"Safetensors version {_safetensors_version} available.")
        except importlib_metadata.PackageNotFoundError:
            _safetensors_available = False
else:
    # This branch is driven by USE_SAFETENSORS (not USE_TF), so name the right variable in the log.
    logger.info("Disabling Safetensors because USE_SAFETENSORS is set")
    _safetensors_available = False
| |
|
# Each optional package below counts as available when its import spec exists and its
# distribution metadata can be read; the version is only logged on success.
_transformers_available = importlib.util.find_spec("transformers") is not None
try:
    _transformers_version = importlib_metadata.version("transformers")
except importlib_metadata.PackageNotFoundError:
    _transformers_available = False
else:
    logger.debug(f"Successfully imported transformers version {_transformers_version}")


_inflect_available = importlib.util.find_spec("inflect") is not None
try:
    _inflect_version = importlib_metadata.version("inflect")
except importlib_metadata.PackageNotFoundError:
    _inflect_available = False
else:
    logger.debug(f"Successfully imported inflect version {_inflect_version}")


_unidecode_available = importlib.util.find_spec("unidecode") is not None
try:
    _unidecode_version = importlib_metadata.version("unidecode")
except importlib_metadata.PackageNotFoundError:
    _unidecode_available = False
else:
    logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
| |
|
# onnxruntime can be installed under several distribution names; the first one whose
# metadata is found provides the version.
_onnxruntime_version = "N/A"
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
if _onnx_available:
    candidates = (
        "onnxruntime",
        "onnxruntime-gpu",
        "ort_nightly_gpu",
        "onnxruntime-directml",
        "onnxruntime-openvino",
        "ort_nightly_directml",
        "onnxruntime-rocm",
        "onnxruntime-training",
    )
    _onnxruntime_version = None
    for pkg in candidates:
        try:
            _onnxruntime_version = importlib_metadata.version(pkg)
        except importlib_metadata.PackageNotFoundError:
            continue
        break
    _onnx_available = _onnxruntime_version is not None
    if _onnx_available:
        logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
| |
|
| | |
| | |
# OpenCV ships under several distribution names; the first one with readable metadata wins.
# No outer try/except is needed: the loop already handles PackageNotFoundError for each candidate.
candidates = (
    "opencv-python",
    "opencv-contrib-python",
    "opencv-python-headless",
    "opencv-contrib-python-headless",
)
_opencv_version = None
for pkg in candidates:
    try:
        _opencv_version = importlib_metadata.version(pkg)
        break
    except importlib_metadata.PackageNotFoundError:
        pass
_opencv_available = _opencv_version is not None
if _opencv_available:
    logger.debug(f"Successfully imported cv2 version {_opencv_version}")
| |
|
# Available == import spec exists and distribution metadata is readable; version logged on success.
_scipy_available = importlib.util.find_spec("scipy") is not None
try:
    _scipy_version = importlib_metadata.version("scipy")
except importlib_metadata.PackageNotFoundError:
    _scipy_available = False
else:
    logger.debug(f"Successfully imported scipy version {_scipy_version}")


_librosa_available = importlib.util.find_spec("librosa") is not None
try:
    _librosa_version = importlib_metadata.version("librosa")
except importlib_metadata.PackageNotFoundError:
    _librosa_available = False
else:
    logger.debug(f"Successfully imported librosa version {_librosa_version}")


_accelerate_available = importlib.util.find_spec("accelerate") is not None
try:
    _accelerate_version = importlib_metadata.version("accelerate")
except importlib_metadata.PackageNotFoundError:
    _accelerate_available = False
else:
    logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
| |
|
# xformers detection. When PyTorch is also available, an installed xformers with PyTorch < 1.12
# aborts the import of this module with a ValueError.
_xformers_available = importlib.util.find_spec("xformers") is not None
try:
    _xformers_version = importlib_metadata.version("xformers")
    if _torch_available:
        _torch_version = importlib_metadata.version("torch")
except importlib_metadata.PackageNotFoundError:
    _xformers_available = False
else:
    if _torch_available and version.Version(_torch_version) < version.Version("1.12"):
        raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12")
    logger.debug(f"Successfully imported xformers version {_xformers_version}")
| |
|
# Available == import spec exists and distribution metadata is readable; version logged on success.
_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
try:
    _k_diffusion_version = importlib_metadata.version("k_diffusion")
except importlib_metadata.PackageNotFoundError:
    _k_diffusion_available = False
else:
    logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")


_note_seq_available = importlib.util.find_spec("note_seq") is not None
try:
    _note_seq_version = importlib_metadata.version("note_seq")
except importlib_metadata.PackageNotFoundError:
    _note_seq_available = False
else:
    logger.debug(f"Successfully imported note-seq version {_note_seq_version}")


_wandb_available = importlib.util.find_spec("wandb") is not None
try:
    _wandb_version = importlib_metadata.version("wandb")
except importlib_metadata.PackageNotFoundError:
    _wandb_available = False
else:
    logger.debug(f"Successfully imported wandb version {_wandb_version}")
| |
|
| |
|
# `find_spec` returns a ModuleSpec or None; compare against None so the availability flags are real
# booleans (otherwise `is_tensorboard_available()` / `is_compel_available()` would return a ModuleSpec).
_tensorboard_available = importlib.util.find_spec("tensorboard") is not None
try:
    _tensorboard_version = importlib_metadata.version("tensorboard")
    logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
except importlib_metadata.PackageNotFoundError:
    _tensorboard_available = False


_compel_available = importlib.util.find_spec("compel") is not None
try:
    _compel_version = importlib_metadata.version("compel")
    logger.debug(f"Successfully imported compel version {_compel_version}")
except importlib_metadata.PackageNotFoundError:
    _compel_available = False
| |
|
| |
|
# Available == import spec exists and distribution metadata is readable; version logged on success.
_ftfy_available = importlib.util.find_spec("ftfy") is not None
try:
    _ftfy_version = importlib_metadata.version("ftfy")
    logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
except importlib_metadata.PackageNotFoundError:
    _ftfy_available = False


# The import name is `bs4`, but the distribution is published as `beautifulsoup4`.
_bs4_available = importlib.util.find_spec("bs4") is not None
try:
    _bs4_version = importlib_metadata.version("beautifulsoup4")
    logger.debug(f"Successfully imported beautifulsoup4 version {_bs4_version}")
except importlib_metadata.PackageNotFoundError:
    _bs4_available = False
| |
|
# Available == import spec exists and distribution metadata is readable; version logged on success.
_torchsde_available = importlib.util.find_spec("torchsde") is not None
try:
    _torchsde_version = importlib_metadata.version("torchsde")
except importlib_metadata.PackageNotFoundError:
    _torchsde_available = False
else:
    logger.debug(f"Successfully imported torchsde version {_torchsde_version}")


# Import name `imwatermark`, distribution name `invisible-watermark`.
_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
try:
    _invisible_watermark_version = importlib_metadata.version("invisible-watermark")
except importlib_metadata.PackageNotFoundError:
    _invisible_watermark_available = False
else:
    logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}")


_peft_available = importlib.util.find_spec("peft") is not None
try:
    _peft_version = importlib_metadata.version("peft")
except importlib_metadata.PackageNotFoundError:
    _peft_available = False
else:
    logger.debug(f"Successfully imported peft version {_peft_version}")


_torchvision_available = importlib.util.find_spec("torchvision") is not None
try:
    _torchvision_version = importlib_metadata.version("torchvision")
except importlib_metadata.PackageNotFoundError:
    _torchvision_available = False
else:
    logger.debug(f"Successfully imported torchvision version {_torchvision_version}")


_matplotlib_available = importlib.util.find_spec("matplotlib") is not None
try:
    _matplotlib_version = importlib_metadata.version("matplotlib")
except importlib_metadata.PackageNotFoundError:
    _matplotlib_available = False
else:
    logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}")
| |
|
# timm: metadata is only queried when the import spec exists.
_timm_available = importlib.util.find_spec("timm") is not None
if _timm_available:
    try:
        _timm_version = importlib_metadata.version("timm")
    except importlib_metadata.PackageNotFoundError:
        _timm_available = False
    else:
        logger.info(f"Timm version {_timm_version} available.")


def is_timm_available():
    """Report whether timm was detected when this module was imported."""
    return _timm_available
| |
|
| |
|
# Available == import spec exists and distribution metadata is readable; version logged on success.
_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
try:
    _bitsandbytes_version = importlib_metadata.version("bitsandbytes")
except importlib_metadata.PackageNotFoundError:
    _bitsandbytes_available = False
else:
    logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}")
| |
|
| | |
# Notebook detection: the NameError guard covers interpreters where `get_ipython` is not defined.
# Inside IPython, a shell class deriving from ZMQInteractiveShell marks a notebook kernel.
_is_notebook = False
try:
    shell_class = get_ipython().__class__  # noqa: F821
    _is_notebook = any(base.__name__ == "ZMQInteractiveShell" for base in shell_class.__mro__)
except NameError:
    pass

# Colab is recognised purely by the presence of the `google.colab` module in sys.modules.
_is_google_colab = "google.colab" in sys.modules
| |
|
| |
|
def is_torch_available():
    """Report the PyTorch availability flag computed at import time."""
    return _torch_available


def is_torch_xla_available():
    """Report the PyTorch/XLA availability flag computed at import time."""
    return _torch_xla_available


def is_torch_npu_available():
    """Report the torch_npu availability flag computed at import time."""
    return _torch_npu_available


def is_flax_available():
    """Report the combined JAX + Flax availability flag computed at import time."""
    return _flax_available


def is_transformers_available():
    """Report the transformers availability flag computed at import time."""
    return _transformers_available


def is_inflect_available():
    """Report the inflect availability flag computed at import time."""
    return _inflect_available


def is_unidecode_available():
    """Report the unidecode availability flag computed at import time."""
    return _unidecode_available


def is_onnx_available():
    """Report the onnxruntime availability flag computed at import time."""
    return _onnx_available


def is_opencv_available():
    """Report the OpenCV availability flag computed at import time."""
    return _opencv_available


def is_scipy_available():
    """Report the scipy availability flag computed at import time."""
    return _scipy_available


def is_librosa_available():
    """Report the librosa availability flag computed at import time."""
    return _librosa_available


def is_xformers_available():
    """Report the xformers availability flag computed at import time."""
    return _xformers_available


def is_accelerate_available():
    """Report the accelerate availability flag computed at import time."""
    return _accelerate_available


def is_k_diffusion_available():
    """Report the k-diffusion availability flag computed at import time."""
    return _k_diffusion_available


def is_note_seq_available():
    """Report the note-seq availability flag computed at import time."""
    return _note_seq_available


def is_wandb_available():
    """Report the wandb availability flag computed at import time."""
    return _wandb_available


def is_tensorboard_available():
    """Report the tensorboard availability value computed at import time."""
    return _tensorboard_available


def is_compel_available():
    """Report the compel availability value computed at import time."""
    return _compel_available


def is_ftfy_available():
    """Report the ftfy availability flag computed at import time."""
    return _ftfy_available


def is_bs4_available():
    """Report the Beautiful Soup (bs4) availability flag computed at import time."""
    return _bs4_available


def is_torchsde_available():
    """Report the torchsde availability flag computed at import time."""
    return _torchsde_available


def is_invisible_watermark_available():
    """Report the invisible-watermark availability flag computed at import time."""
    return _invisible_watermark_available


def is_peft_available():
    """Report the peft availability flag computed at import time."""
    return _peft_available


def is_torchvision_available():
    """Report the torchvision availability flag computed at import time."""
    return _torchvision_available


def is_matplotlib_available():
    """Report the matplotlib availability flag computed at import time."""
    return _matplotlib_available


def is_safetensors_available():
    """Report the safetensors availability flag computed at import time."""
    return _safetensors_available


def is_bitsandbytes_available():
    """Report the bitsandbytes availability flag computed at import time."""
    return _bitsandbytes_available


def is_notebook():
    """Report whether an IPython notebook (ZMQ) shell was detected at import time."""
    return _is_notebook


def is_google_colab():
    """Report whether `google.colab` was already in `sys.modules` at import time."""
    return _is_google_colab
| |
|
| |
|
| | |
# ImportError message templates used by `requires_backends` through BACKENDS_MAPPING.
# `{0}` is replaced with the name of the object that needs the missing backend.
# The exact text of each string is user-facing; keep it unchanged unless the wording is intentionally revised.

FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
"""

INFLECT_IMPORT_ERROR = """
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
inflect`
"""

PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""

ONNX_IMPORT_ERROR = """
{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip
install onnxruntime`
"""

OPENCV_IMPORT_ERROR = """
{0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip
install opencv-python`
"""

SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
scipy`
"""

LIBROSA_IMPORT_ERROR = """
{0} requires the librosa library but it was not found in your environment. Checkout the instructions on the
installation page: https://librosa.org/doc/latest/install.html and follow the ones that match your environment.
"""

TRANSFORMERS_IMPORT_ERROR = """
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
install transformers`
"""

UNIDECODE_IMPORT_ERROR = """
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
Unidecode`
"""

K_DIFFUSION_IMPORT_ERROR = """
{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip
install k-diffusion`
"""

NOTE_SEQ_IMPORT_ERROR = """
{0} requires the note-seq library but it was not found in your environment. You can install it with pip: `pip
install note-seq`
"""

WANDB_IMPORT_ERROR = """
{0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip
install wandb`
"""

TENSORBOARD_IMPORT_ERROR = """
{0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip
install tensorboard`
"""

COMPEL_IMPORT_ERROR = """
{0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel`
"""

BS4_IMPORT_ERROR = """
{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
"""

FTFY_IMPORT_ERROR = """
{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the
installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""

TORCHSDE_IMPORT_ERROR = """
{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde`
"""

# The pip hint pins a minimum version of invisible-watermark.
INVISIBLE_WATERMARK_IMPORT_ERROR = """
{0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0`
"""

PEFT_IMPORT_ERROR = """
{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install peft`
"""

SAFETENSORS_IMPORT_ERROR = """
{0} requires the safetensors library but it was not found in your environment. You can install it with pip: `pip install safetensors`
"""

BITSANDBYTES_IMPORT_ERROR = """
{0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes`
"""
| |
|
# Backend key -> (availability check, ImportError message template), consumed by `requires_backends`.
# Keyword arguments keep their order (PEP 468), so iteration order matches the listing below.
BACKENDS_MAPPING = OrderedDict(
    bs4=(is_bs4_available, BS4_IMPORT_ERROR),
    flax=(is_flax_available, FLAX_IMPORT_ERROR),
    inflect=(is_inflect_available, INFLECT_IMPORT_ERROR),
    onnx=(is_onnx_available, ONNX_IMPORT_ERROR),
    opencv=(is_opencv_available, OPENCV_IMPORT_ERROR),
    scipy=(is_scipy_available, SCIPY_IMPORT_ERROR),
    torch=(is_torch_available, PYTORCH_IMPORT_ERROR),
    transformers=(is_transformers_available, TRANSFORMERS_IMPORT_ERROR),
    unidecode=(is_unidecode_available, UNIDECODE_IMPORT_ERROR),
    librosa=(is_librosa_available, LIBROSA_IMPORT_ERROR),
    k_diffusion=(is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR),
    note_seq=(is_note_seq_available, NOTE_SEQ_IMPORT_ERROR),
    wandb=(is_wandb_available, WANDB_IMPORT_ERROR),
    tensorboard=(is_tensorboard_available, TENSORBOARD_IMPORT_ERROR),
    compel=(is_compel_available, COMPEL_IMPORT_ERROR),
    ftfy=(is_ftfy_available, FTFY_IMPORT_ERROR),
    torchsde=(is_torchsde_available, TORCHSDE_IMPORT_ERROR),
    invisible_watermark=(is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR),
    peft=(is_peft_available, PEFT_IMPORT_ERROR),
    safetensors=(is_safetensors_available, SAFETENSORS_IMPORT_ERROR),
    bitsandbytes=(is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR),
)
| |
|
| |
|
def requires_backends(obj, backends):
    """
    Raise an `ImportError` when `obj` needs a backend that is not available.

    Args:
        obj:
            The class, function or instance requiring the backends. Its `__name__` (or, failing that, the name of
            its class) is substituted into the error messages.
        backends (`str`, `list` or `tuple`):
            One backend key, or several, from `BACKENDS_MAPPING`.

    Raises:
        ImportError: with the concatenated messages of every missing backend, or when one of a few specific
            pipelines is used with a `transformers` release older than it requires.
    """
    if not isinstance(backends, (list, tuple)):
        backends = [backends]

    name = getattr(obj, "__name__", obj.__class__.__name__)

    missing_messages = []
    for backend in backends:
        is_available, message_template = BACKENDS_MAPPING[backend]
        if not is_available():
            missing_messages.append(message_template.format(name))
    if missing_messages:
        raise ImportError("".join(missing_messages))

    # Pipelines requiring transformers >= 4.25.0.
    needs_transformers_4_25 = (
        "VersatileDiffusionTextToImagePipeline",
        "VersatileDiffusionPipeline",
        "VersatileDiffusionDualGuidedPipeline",
        "StableDiffusionImageVariationPipeline",
        "UnCLIPPipeline",
    )
    if name in needs_transformers_4_25 and is_transformers_version("<", "4.25.0"):
        raise ImportError(
            f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install"
            " --upgrade transformers \n```"
        )

    # Pipelines requiring transformers >= 4.26.0.
    needs_transformers_4_26 = ("StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline")
    if name in needs_transformers_4_26 and is_transformers_version("<", "4.26.0"):
        raise ImportError(
            f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install"
            " --upgrade transformers \n```"
        )
| |
|
| |
|
class DummyObject(type):
    """
    Metaclass for the dummy objects.

    Python only calls a metaclass `__getattr__` when normal attribute lookup on the class fails. For such
    attributes, a class created with this metaclass calls `requires_backends(cls, cls._backends)`, which raises the
    backend ImportError when a required backend is missing. Private names (leading underscore) raise a regular
    AttributeError instead, except for the allow-listed `_load_connected_pipes` and `_is_onnx`.
    """

    def __getattr__(cls, key):
        if key.startswith("_") and key not in ["_load_connected_pipes", "_is_onnx"]:
            # `type` defines no `__getattr__`, so delegating to super() cannot work; report the missing attribute
            # directly so `hasattr()` and similar lookups get a meaningful AttributeError.
            raise AttributeError(f"type object {cls.__name__!r} has no attribute {key!r}")
        requires_backends(cls, cls._backends)
| |
|
| |
|
| | |
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
    """
    Compares a library version to some requirement using a given operation.

    Args:
        library_or_version (`str` or `packaging.version.Version`):
            A library name or a version to check. A string is treated as a distribution name and its installed
            version is looked up with `importlib_metadata.version`.
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`. Must be a key of
            `STR_OPERATION_TO_FUNC`.
        requirement_version (`str`):
            The version to compare the library version against.

    Returns:
        The result of applying `operation` to the library version and `parse(requirement_version)`.

    Raises:
        ValueError: If `operation` is not one of the supported operator strings.
    """
    if operation not in STR_OPERATION_TO_FUNC:
        raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
    compare = STR_OPERATION_TO_FUNC[operation]
    if isinstance(library_or_version, str):
        library_or_version = parse(importlib_metadata.version(library_or_version))
    return compare(library_or_version, parse(requirement_version))
| |
|
| |
|
| | |
def is_torch_version(operation: str, version: str):
    """
    Compares the current PyTorch version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`.
        version (`str`):
            A string version of PyTorch.

    Returns:
        `False` when PyTorch is not available; otherwise the result of the comparison.
    """
    # Guard like the other `is_*_version` helpers: without it `_torch_version` is the placeholder "N/A",
    # which is not a parseable version.
    if not _torch_available:
        return False
    return compare_versions(parse(_torch_version), operation, version)
| |
|
| |
|
def is_transformers_version(operation: str, version: str):
    """
    Check the detected Transformers version against a reference version.

    Args:
        operation (`str`):
            Comparison operator as a string, for example `">"` or `"<="`.
        version (`str`):
            Reference version string.

    Returns:
        `False` if Transformers was not detected, otherwise the outcome of the comparison.
    """
    return _transformers_available and compare_versions(parse(_transformers_version), operation, version)
| |
|
| |
|
def is_accelerate_version(operation: str, version: str):
    """
    Check the detected Accelerate version against a reference version.

    Args:
        operation (`str`):
            Comparison operator as a string, for example `">"` or `"<="`.
        version (`str`):
            Reference version string.

    Returns:
        `False` if Accelerate was not detected, otherwise the outcome of the comparison.
    """
    return compare_versions(parse(_accelerate_version), operation, version) if _accelerate_available else False
| |
|
| |
|
def is_peft_version(operation: str, version: str):
    """
    Compares the current PEFT version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`.
        version (`str`):
            A version string.

    Returns:
        `False` when PEFT is not available; otherwise the result of the comparison.
    """
    # Check the availability flag: `_peft_version` is never assigned when PEFT is missing, so testing it
    # directly raised NameError instead of returning False.
    if not _peft_available:
        return False
    return compare_versions(parse(_peft_version), operation, version)
| |
|
| |
|
def is_k_diffusion_version(operation: str, version: str):
    """
    Check the detected k-diffusion version against a reference version.

    Args:
        operation (`str`):
            Comparison operator as a string, for example `">"` or `"<="`.
        version (`str`):
            Reference version string.

    Returns:
        `False` if k-diffusion was not detected, otherwise the outcome of the comparison.
    """
    if _k_diffusion_available:
        return compare_versions(parse(_k_diffusion_version), operation, version)
    return False
| |
|
| |
|
def get_objects_from_module(module):
    """
    Collect the public attributes of a module.

    Args:
        module (`ModuleType`):
            Module whose attributes are gathered.

    Returns:
        `dict`: Every name reported by `dir(module)` that does not start with an underscore, mapped to its value,
        in `dir()` order.
    """
    return {attr: getattr(module, attr) for attr in dir(module) if not attr.startswith("_")}
| |
|
| |
|
class OptionalDependencyNotAvailable(BaseException):
    """
    Signals that an optional dependency of Diffusers is missing from the current environment.

    It derives from `BaseException` rather than `Exception`, so a generic `except Exception:` handler does not
    intercept it.
    """
| |
|
| |
|
class _LazyModule(ModuleType):
    """
    Module class that surfaces all objects but only performs associated imports when the objects are requested.

    `import_structure` maps submodule names (relative to `name`) to the public names each submodule defines.
    Accessing a submodule name imports that submodule; accessing one of its listed names imports the submodule and
    returns the attribute. Resolved values are stored on the instance with `setattr`, so later lookups skip
    `__getattr__`. Names in `extra_objects` are returned directly without any import.
    """

    def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
        super().__init__(name)
        self._modules = set(import_structure.keys())
        # Reverse index: public object name -> submodule defining it. If a name is listed under several
        # submodules, the last one in `import_structure` iteration order wins.
        self._class_to_module = {
            obj_name: submodule for submodule, obj_names in import_structure.items() for obj_name in obj_names
        }
        # `__all__` also feeds `__dir__`, so lazily importable names are listed by `dir()`.
        self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        self._objects = {} if extra_objects is None else extra_objects
        self._name = name
        self._import_structure = import_structure

    def __dir__(self):
        result = super().__dir__()
        # Track membership in a set: `attr not in result` on the list would cost O(len(result)) per name.
        seen = set(result)
        for attr in self.__all__:
            if attr not in seen:
                result.append(attr)
                seen.add(attr)
        return result

    def __getattr__(self, name: str) -> Any:
        if name in self._objects:
            return self._objects[name]
        if name in self._modules:
            value = self._get_module(name)
        elif name in self._class_to_module:
            module = self._get_module(self._class_to_module[name])
            value = getattr(module, name)
        else:
            raise AttributeError(f"module {self.__name__} has no attribute {name}")

        # Cache on the instance so the next access is a plain attribute lookup.
        setattr(self, name, value)
        return value

    def _get_module(self, module_name: str):
        """Import the submodule `module_name` relative to this module, wrapping any failure in a RuntimeError."""
        try:
            return importlib.import_module("." + module_name, self.__name__)
        except Exception as e:
            raise RuntimeError(
                f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
                f" traceback):\n{e}"
            ) from e

    def __reduce__(self):
        # Pickles by re-running the constructor; module_spec and extra_objects are not carried over.
        return (self.__class__, (self._name, self.__file__, self._import_structure))
| |
|