diff --git a/my_diffusers/__init__.py b/my_diffusers/__init__.py deleted file mode 100644 index 6c57032ebc7396a11717bfc60cffd1bd91c5e290..0000000000000000000000000000000000000000 --- a/my_diffusers/__init__.py +++ /dev/null @@ -1,204 +0,0 @@ -__version__ = "0.14.0" - -from .configuration_utils import ConfigMixin -from .utils import ( - OptionalDependencyNotAvailable, - is_flax_available, - is_inflect_available, - is_k_diffusion_available, - is_k_diffusion_version, - is_librosa_available, - is_onnx_available, - is_scipy_available, - is_torch_available, - is_transformers_available, - is_transformers_version, - is_unidecode_available, - logging, -) - - -try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils.dummy_onnx_objects import * # noqa F403 -else: - from .pipelines import OnnxRuntimeModel - -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils.dummy_pt_objects import * # noqa F403 -else: - from .models import ( - AutoencoderKL, - ControlNetModel, - ModelMixin, - PriorTransformer, - Transformer2DModel, - UNet1DModel, - UNet2DConditionModel, - UNet2DModel, - VQModel, - ) - from .optimization import ( - get_constant_schedule, - get_constant_schedule_with_warmup, - get_cosine_schedule_with_warmup, - get_cosine_with_hard_restarts_schedule_with_warmup, - get_linear_schedule_with_warmup, - get_polynomial_decay_schedule_with_warmup, - get_scheduler, - ) - from .pipelines import ( - AudioPipelineOutput, - DanceDiffusionPipeline, - DDIMPipeline, - DDPMPipeline, - DiffusionPipeline, - DiTPipeline, - ImagePipelineOutput, - KarrasVePipeline, - LDMPipeline, - LDMSuperResolutionPipeline, - PNDMPipeline, - RePaintPipeline, - ScoreSdeVePipeline, - ) - from .schedulers import ( - DDIMInverseScheduler, - DDIMScheduler, - DDPMScheduler, - DEISMultistepScheduler, - DPMSolverMultistepScheduler, - DPMSolverSinglestepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - IPNDMScheduler, - KarrasVeScheduler, - KDPM2AncestralDiscreteScheduler, - KDPM2DiscreteScheduler, - PNDMScheduler, - RePaintScheduler, - SchedulerMixin, - ScoreSdeVeScheduler, - UnCLIPScheduler, - UniPCMultistepScheduler, - VQDiffusionScheduler, - ) - from .training_utils import EMAModel - -try: - if not (is_torch_available() and is_scipy_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils.dummy_torch_and_scipy_objects import * # noqa F403 -else: - from .schedulers import LMSDiscreteScheduler - - -try: - if not (is_torch_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils.dummy_torch_and_transformers_objects import * # noqa F403 -else: - from .pipelines import ( - AltDiffusionImg2ImgPipeline, - AltDiffusionPipeline, - CycleDiffusionPipeline, - LDMTextToImagePipeline, - PaintByExamplePipeline, - SemanticStableDiffusionPipeline, - StableDiffusionAttendAndExcitePipeline, - StableDiffusionControlNetPipeline, - StableDiffusionDepth2ImgPipeline, - StableDiffusionImageVariationPipeline, - StableDiffusionImg2ImgPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionInpaintPipelineLegacy, - StableDiffusionInstructPix2PixPipeline, - StableDiffusionLatentUpscalePipeline, - StableDiffusionPanoramaPipeline, - StableDiffusionPipeline, - StableDiffusionPipelineSafe, - StableDiffusionPix2PixZeroPipeline, - 
StableDiffusionSAGPipeline, - StableDiffusionUpscalePipeline, - StableUnCLIPImg2ImgPipeline, - StableUnCLIPPipeline, - UnCLIPImageVariationPipeline, - UnCLIPPipeline, - VersatileDiffusionDualGuidedPipeline, - VersatileDiffusionImageVariationPipeline, - VersatileDiffusionPipeline, - VersatileDiffusionTextToImagePipeline, - VQDiffusionPipeline, - ) - -try: - if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 -else: - from .pipelines import StableDiffusionKDiffusionPipeline - -try: - if not (is_torch_available() and is_transformers_available() and is_onnx_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 -else: - from .pipelines import ( - OnnxStableDiffusionImg2ImgPipeline, - OnnxStableDiffusionInpaintPipeline, - OnnxStableDiffusionInpaintPipelineLegacy, - OnnxStableDiffusionPipeline, - StableDiffusionOnnxPipeline, - ) - -try: - if not (is_torch_available() and is_librosa_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils.dummy_torch_and_librosa_objects import * # noqa F403 -else: - from .pipelines import AudioDiffusionPipeline, Mel - -try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils.dummy_flax_objects import * # noqa F403 -else: - from .models.modeling_flax_utils import FlaxModelMixin - from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel - from .models.vae_flax import FlaxAutoencoderKL - from .pipelines import FlaxDiffusionPipeline - from .schedulers import ( - FlaxDDIMScheduler, - FlaxDDPMScheduler, - FlaxDPMSolverMultistepScheduler, - FlaxKarrasVeScheduler, - FlaxLMSDiscreteScheduler, - FlaxPNDMScheduler, - FlaxSchedulerMixin, - FlaxScoreSdeVeScheduler, - ) - - -try: - if not (is_flax_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils.dummy_flax_and_transformers_objects import * # noqa F403 -else: - from .pipelines import ( - FlaxStableDiffusionImg2ImgPipeline, - FlaxStableDiffusionInpaintPipeline, - FlaxStableDiffusionPipeline, - ) diff --git a/my_diffusers/__pycache__/__init__.cpython-39.pyc b/my_diffusers/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 6fdd817e9d807f2db8fbb801f933718106d5e484..0000000000000000000000000000000000000000 Binary files a/my_diffusers/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/my_diffusers/__pycache__/configuration_utils.cpython-39.pyc b/my_diffusers/__pycache__/configuration_utils.cpython-39.pyc deleted file mode 100644 index 82920c60ab078bc34cc494b81c711bb654f96f65..0000000000000000000000000000000000000000 Binary files a/my_diffusers/__pycache__/configuration_utils.cpython-39.pyc and /dev/null differ diff --git a/my_diffusers/commands/__init__.py b/my_diffusers/commands/__init__.py deleted file mode 100644 index 4ad4af9199bbe297dbc6679fd9ecb46baa976053..0000000000000000000000000000000000000000 --- a/my_diffusers/commands/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from argparse import ArgumentParser - - -class BaseDiffusersCLICommand(ABC): - @staticmethod - @abstractmethod - def register_subcommand(parser: ArgumentParser): - raise NotImplementedError() - - @abstractmethod - def run(self): - raise NotImplementedError() diff --git a/my_diffusers/commands/__pycache__/__init__.cpython-311.pyc b/my_diffusers/commands/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index ab533a8304660284460a08a363d2d782c2d46dbd..0000000000000000000000000000000000000000 Binary files a/my_diffusers/commands/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/commands/__pycache__/diffusers_cli.cpython-311.pyc b/my_diffusers/commands/__pycache__/diffusers_cli.cpython-311.pyc deleted file mode 100644 index 587aa5733aa806bc41ac4c2f7087cee442dd693f..0000000000000000000000000000000000000000 Binary files a/my_diffusers/commands/__pycache__/diffusers_cli.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/commands/__pycache__/env.cpython-311.pyc b/my_diffusers/commands/__pycache__/env.cpython-311.pyc deleted file mode 100644 index 6e27ede8b70c6dd28d6167d50952e1eec4aef4a4..0000000000000000000000000000000000000000 Binary files a/my_diffusers/commands/__pycache__/env.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/commands/diffusers_cli.py b/my_diffusers/commands/diffusers_cli.py deleted file mode 100644 index 74ad29a786d7f77e982242d7020170cb4d031c41..0000000000000000000000000000000000000000 --- a/my_diffusers/commands/diffusers_cli.py +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env python -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
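The commands/__init__.py file removed above defines the abstract contract that every vendored CLI subcommand follows, and diffusers_cli.py (whose body follows below) simply registers such commands on an argparse sub-parser and dispatches to `args.func(args).run()`. A minimal sketch of plugging a new subcommand into that contract, assuming the deleted package layout; `HelloCommand` is a hypothetical example, not part of the original code:

```python
# Sketch only: mirrors how EnvironmentCommand registers itself in the deleted CLI.
# HelloCommand is hypothetical; the package layout is the one removed by this diff.
from my_diffusers.commands import BaseDiffusersCLICommand


class HelloCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser):
        # `parser` is the sub-parsers object created in diffusers_cli.main()
        hello_parser = parser.add_parser("hello")
        # main() calls args.func(args) and then .run() on the returned command
        hello_parser.set_defaults(func=lambda args: HelloCommand())

    def run(self):
        print("hello from diffusers-cli")
```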
- -from argparse import ArgumentParser - -from .env import EnvironmentCommand - - -def main(): - parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli []") - commands_parser = parser.add_subparsers(help="diffusers-cli command helpers") - - # Register commands - EnvironmentCommand.register_subcommand(commands_parser) - - # Let's go - args = parser.parse_args() - - if not hasattr(args, "func"): - parser.print_help() - exit(1) - - # Run - service = args.func(args) - service.run() - - -if __name__ == "__main__": - main() diff --git a/my_diffusers/commands/env.py b/my_diffusers/commands/env.py deleted file mode 100644 index db9de720942b5efcff921d7e2503e3ae8813561e..0000000000000000000000000000000000000000 --- a/my_diffusers/commands/env.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import platform -from argparse import ArgumentParser - -import huggingface_hub - -from .. import __version__ as version -from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available -from . import BaseDiffusersCLICommand - - -def info_command_factory(_): - return EnvironmentCommand() - - -class EnvironmentCommand(BaseDiffusersCLICommand): - @staticmethod - def register_subcommand(parser: ArgumentParser): - download_parser = parser.add_parser("env") - download_parser.set_defaults(func=info_command_factory) - - def run(self): - hub_version = huggingface_hub.__version__ - - pt_version = "not installed" - pt_cuda_available = "NA" - if is_torch_available(): - import torch - - pt_version = torch.__version__ - pt_cuda_available = torch.cuda.is_available() - - transformers_version = "not installed" - if is_transformers_available(): - import transformers - - transformers_version = transformers.__version__ - - accelerate_version = "not installed" - if is_accelerate_available(): - import accelerate - - accelerate_version = accelerate.__version__ - - xformers_version = "not installed" - if is_xformers_available(): - import xformers - - xformers_version = xformers.__version__ - - info = { - "`diffusers` version": version, - "Platform": platform.platform(), - "Python version": platform.python_version(), - "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", - "Huggingface_hub version": hub_version, - "Transformers version": transformers_version, - "Accelerate version": accelerate_version, - "xFormers version": xformers_version, - "Using GPU in script?": "", - "Using distributed or parallel set-up in script?": "", - } - - print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") - print(self.format_dict(info)) - - return info - - @staticmethod - def format_dict(d): - return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" diff --git a/my_diffusers/configuration_utils.py b/my_diffusers/configuration_utils.py deleted file mode 100644 index 
47b201b92cff2c9c90637c74020664578caf9e21..0000000000000000000000000000000000000000 --- a/my_diffusers/configuration_utils.py +++ /dev/null @@ -1,615 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" ConfigMixin base class and utilities.""" -import dataclasses -import functools -import importlib -import inspect -import json -import os -import re -from collections import OrderedDict -from pathlib import PosixPath -from typing import Any, Dict, Tuple, Union - -import numpy as np -from huggingface_hub import hf_hub_download -from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError -from requests import HTTPError - -from . import __version__ -from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging - - -logger = logging.get_logger(__name__) - -_re_configuration_file = re.compile(r"config\.(.*)\.json") - - -class FrozenDict(OrderedDict): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - for key, value in self.items(): - setattr(self, key, value) - - self.__frozen = True - - def __delitem__(self, *args, **kwargs): - raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") - - def setdefault(self, *args, **kwargs): - raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") - - def pop(self, *args, **kwargs): - raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") - - def update(self, *args, **kwargs): - raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") - - def __setattr__(self, name, value): - if hasattr(self, "__frozen") and self.__frozen: - raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") - super().__setattr__(name, value) - - def __setitem__(self, name, value): - if hasattr(self, "__frozen") and self.__frozen: - raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") - super().__setitem__(name, value) - - -class ConfigMixin: - r""" - Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all - methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with - - [`~ConfigMixin.from_config`] - - [`~ConfigMixin.save_config`] - - Class attributes: - - **config_name** (`str`) -- A filename under which the config should stored when calling - [`~ConfigMixin.save_config`] (should be overridden by parent class). - - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be - overridden by subclass). - - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass). - - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. 
Note that the init function - should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by - subclass). - """ - config_name = None - ignore_for_config = [] - has_compatibles = False - - _deprecated_kwargs = [] - - def register_to_config(self, **kwargs): - if self.config_name is None: - raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") - # Special case for `kwargs` used in deprecation warning added to schedulers - # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, - # or solve in a more general way. - kwargs.pop("kwargs", None) - for key, value in kwargs.items(): - try: - setattr(self, key, value) - except AttributeError as err: - logger.error(f"Can't set {key} with value {value} for {self}") - raise err - - if not hasattr(self, "_internal_dict"): - internal_dict = kwargs - else: - previous_dict = dict(self._internal_dict) - internal_dict = {**self._internal_dict, **kwargs} - logger.debug(f"Updating config from {previous_dict} to {internal_dict}") - - self._internal_dict = FrozenDict(internal_dict) - - def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): - """ - Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the - [`~ConfigMixin.from_config`] class method. - - Args: - save_directory (`str` or `os.PathLike`): - Directory where the configuration JSON file will be saved (will be created if it does not exist). - """ - if os.path.isfile(save_directory): - raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") - - os.makedirs(save_directory, exist_ok=True) - - # If we save using the predefined names, we can load using `from_config` - output_config_file = os.path.join(save_directory, self.config_name) - - self.to_json_file(output_config_file) - logger.info(f"Configuration saved in {output_config_file}") - - @classmethod - def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): - r""" - Instantiate a Python class from a config dictionary - - Parameters: - config (`Dict[str, Any]`): - A config dictionary from which the Python class will be instantiated. Make sure to only load - configuration files of compatible classes. - return_unused_kwargs (`bool`, *optional*, defaults to `False`): - Whether kwargs that are not consumed by the Python class should be returned or not. - - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the Python class. - `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually - overwrite same named arguments of `config`. - - Examples: - - ```python - >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler - - >>> # Download scheduler from huggingface.co and cache. 
- >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32") - - >>> # Instantiate DDIM scheduler class with same config as DDPM - >>> scheduler = DDIMScheduler.from_config(scheduler.config) - - >>> # Instantiate PNDM scheduler class with same config as DDPM - >>> scheduler = PNDMScheduler.from_config(scheduler.config) - ``` - """ - # <===== TO BE REMOVED WITH DEPRECATION - # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated - if "pretrained_model_name_or_path" in kwargs: - config = kwargs.pop("pretrained_model_name_or_path") - - if config is None: - raise ValueError("Please make sure to provide a config as the first positional argument.") - # ======> - - if not isinstance(config, dict): - deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`." - if "Scheduler" in cls.__name__: - deprecation_message += ( - f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead." - " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will" - " be removed in v1.0.0." - ) - elif "Model" in cls.__name__: - deprecation_message += ( - f"If you were trying to load a model, please use {cls}.load_config(...) followed by" - f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary" - " instead. This functionality will be removed in v1.0.0." - ) - deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False) - config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs) - - init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) - - # Allow dtype to be specified on initialization - if "dtype" in unused_kwargs: - init_dict["dtype"] = unused_kwargs.pop("dtype") - - # add possible deprecated kwargs - for deprecated_kwarg in cls._deprecated_kwargs: - if deprecated_kwarg in unused_kwargs: - init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg) - - # Return model and optionally state and/or unused_kwargs - model = cls(**init_dict) - - # make sure to also save config parameters that might be used for compatible classes - model.register_to_config(**hidden_dict) - - # add hidden kwargs of compatible classes to unused_kwargs - unused_kwargs = {**unused_kwargs, **hidden_dict} - - if return_unused_kwargs: - return (model, unused_kwargs) - else: - return model - - @classmethod - def get_config_dict(cls, *args, **kwargs): - deprecation_message = ( - f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be" - " removed in version v1.0.0" - ) - deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False) - return cls.load_config(*args, **kwargs) - - @classmethod - def load_config( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - r""" - Instantiate a Python class from a config dictionary - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an - organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g., - `./my_model_directory/`. 
- - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo (either remote in - huggingface.co or downloaded locally), you can specify the folder name here. - - - - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). - - - - - - Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to - use this method in a firewalled environment. - - - """ - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - use_auth_token = kwargs.pop("use_auth_token", None) - local_files_only = kwargs.pop("local_files_only", False) - revision = kwargs.pop("revision", None) - _ = kwargs.pop("mirror", None) - subfolder = kwargs.pop("subfolder", None) - - user_agent = {"file_type": "config"} - - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - - if cls.config_name is None: - raise ValueError( - "`self.config_name` is not defined. Note that one should not load a config from " - "`ConfigMixin`. 
Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" - ) - - if os.path.isfile(pretrained_model_name_or_path): - config_file = pretrained_model_name_or_path - elif os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): - # Load from a PyTorch checkpoint - config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) - elif subfolder is not None and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) - ): - config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) - else: - raise EnvironmentError( - f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." - ) - else: - try: - # Load from URL or cache if already cached - config_file = hf_hub_download( - pretrained_model_name_or_path, - filename=cls.config_name, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - user_agent=user_agent, - subfolder=subfolder, - revision=revision, - ) - - except RepositoryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" - " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a" - " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli" - " login`." - ) - except RevisionNotFoundError: - raise EnvironmentError( - f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for" - " this model name. Check the model page at" - f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." - ) - except EntryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}." - ) - except HTTPError as err: - raise EnvironmentError( - "There was a specific connection error when trying to load" - f" {pretrained_model_name_or_path}:\n{err}" - ) - except ValueError: - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" - f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" - f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to" - " run the library in offline mode at" - " 'https://huggingface.co/docs/diffusers/installation#offline-mode'." - ) - except EnvironmentError: - raise EnvironmentError( - f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from " - "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing a {cls.config_name} file" - ) - - try: - # Load config dict - config_dict = cls._dict_from_json_file(config_file) - except (json.JSONDecodeError, UnicodeDecodeError): - raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") - - if return_unused_kwargs: - return config_dict, kwargs - - return config_dict - - @staticmethod - def _get_init_keys(cls): - return set(dict(inspect.signature(cls.__init__).parameters).keys()) - - @classmethod - def extract_init_dict(cls, config_dict, **kwargs): - # 0. 
Copy origin config dict - original_dict = {k: v for k, v in config_dict.items()} - - # 1. Retrieve expected config attributes from __init__ signature - expected_keys = cls._get_init_keys(cls) - expected_keys.remove("self") - # remove general kwargs if present in dict - if "kwargs" in expected_keys: - expected_keys.remove("kwargs") - # remove flax internal keys - if hasattr(cls, "_flax_internal_args"): - for arg in cls._flax_internal_args: - expected_keys.remove(arg) - - # 2. Remove attributes that cannot be expected from expected config attributes - # remove keys to be ignored - if len(cls.ignore_for_config) > 0: - expected_keys = expected_keys - set(cls.ignore_for_config) - - # load diffusers library to import compatible and original scheduler - diffusers_library = importlib.import_module(__name__.split(".")[0]) - - if cls.has_compatibles: - compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)] - else: - compatible_classes = [] - - expected_keys_comp_cls = set() - for c in compatible_classes: - expected_keys_c = cls._get_init_keys(c) - expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c) - expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls) - config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls} - - # remove attributes from orig class that cannot be expected - orig_cls_name = config_dict.pop("_class_name", cls.__name__) - if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name): - orig_cls = getattr(diffusers_library, orig_cls_name) - unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys - config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig} - - # remove private attributes - config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} - - # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments - init_dict = {} - for key in expected_keys: - # if config param is passed to kwarg and is present in config dict - # it should overwrite existing config dict key - if key in kwargs and key in config_dict: - config_dict[key] = kwargs.pop(key) - - if key in kwargs: - # overwrite key - init_dict[key] = kwargs.pop(key) - elif key in config_dict: - # use value from config dict - init_dict[key] = config_dict.pop(key) - - # 4. Give nice warning if unexpected values have been passed - if len(config_dict) > 0: - logger.warning( - f"The config attributes {config_dict} were passed to {cls.__name__}, " - "but are not expected and will be ignored. Please verify your " - f"{cls.config_name} configuration file." - ) - - # 5. Give nice info if config attributes are initiliazed to default because they have not been passed - passed_keys = set(init_dict.keys()) - if len(expected_keys - passed_keys) > 0: - logger.info( - f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." - ) - - # 6. Define unused keyword arguments - unused_kwargs = {**config_dict, **kwargs} - - # 7. 
Define "hidden" config parameters that were saved for compatible classes - hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} - - return init_dict, unused_kwargs, hidden_config_dict - - @classmethod - def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): - with open(json_file, "r", encoding="utf-8") as reader: - text = reader.read() - return json.loads(text) - - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - - @property - def config(self) -> Dict[str, Any]: - """ - Returns the config of the class as a frozen dictionary - - Returns: - `Dict[str, Any]`: Config of the class. - """ - return self._internal_dict - - def to_json_string(self) -> str: - """ - Serializes this instance to a JSON string. - - Returns: - `str`: String containing all the attributes that make up this configuration instance in JSON format. - """ - config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} - config_dict["_class_name"] = self.__class__.__name__ - config_dict["_diffusers_version"] = __version__ - - def to_json_saveable(value): - if isinstance(value, np.ndarray): - value = value.tolist() - elif isinstance(value, PosixPath): - value = str(value) - return value - - config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} - return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - - def to_json_file(self, json_file_path: Union[str, os.PathLike]): - """ - Save this instance to a JSON file. - - Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file in which this configuration instance's parameters will be saved. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - writer.write(self.to_json_string()) - - -def register_to_config(init): - r""" - Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are - automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that - shouldn't be registered in the config, use the `ignore_for_config` class variable - - Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! - """ - - @functools.wraps(init) - def inner_init(self, *args, **kwargs): - # Ignore private kwargs in the init. - init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} - config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} - if not isinstance(self, ConfigMixin): - raise RuntimeError( - f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " - "not inherit from `ConfigMixin`." 
- ) - - ignore = getattr(self, "ignore_for_config", []) - # Get positional arguments aligned with kwargs - new_kwargs = {} - signature = inspect.signature(init) - parameters = { - name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore - } - for arg, name in zip(args, parameters.keys()): - new_kwargs[name] = arg - - # Then add all kwargs - new_kwargs.update( - { - k: init_kwargs.get(k, default) - for k, default in parameters.items() - if k not in ignore and k not in new_kwargs - } - ) - new_kwargs = {**config_init_kwargs, **new_kwargs} - getattr(self, "register_to_config")(**new_kwargs) - init(self, *args, **init_kwargs) - - return inner_init - - -def flax_register_to_config(cls): - original_init = cls.__init__ - - @functools.wraps(original_init) - def init(self, *args, **kwargs): - if not isinstance(self, ConfigMixin): - raise RuntimeError( - f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " - "not inherit from `ConfigMixin`." - ) - - # Ignore private kwargs in the init. Retrieve all passed attributes - init_kwargs = {k: v for k, v in kwargs.items()} - - # Retrieve default values - fields = dataclasses.fields(self) - default_kwargs = {} - for field in fields: - # ignore flax specific attributes - if field.name in self._flax_internal_args: - continue - if type(field.default) == dataclasses._MISSING_TYPE: - default_kwargs[field.name] = None - else: - default_kwargs[field.name] = getattr(self, field.name) - - # Make sure init_kwargs override default kwargs - new_kwargs = {**default_kwargs, **init_kwargs} - # dtype should be part of `init_kwargs`, but not `new_kwargs` - if "dtype" in new_kwargs: - new_kwargs.pop("dtype") - - # Get positional arguments aligned with kwargs - for i, arg in enumerate(args): - name = fields[i].name - new_kwargs[name] = arg - - getattr(self, "register_to_config")(**new_kwargs) - original_init(self, *args, **kwargs) - - cls.__init__ = init - return cls diff --git a/my_diffusers/dependency_versions_check.py b/my_diffusers/dependency_versions_check.py deleted file mode 100644 index 4f8578c52957bf6c06decb0d97d3139437f0078f..0000000000000000000000000000000000000000 --- a/my_diffusers/dependency_versions_check.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
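The configuration_utils.py module deleted above revolves around `ConfigMixin` plus the `@register_to_config` decorator, which captures a subclass's `__init__` arguments into a frozen config dict. A minimal sketch of the usual wiring, assuming the deleted module; `MyScheduler` and its arguments are made up for illustration:

```python
# Sketch only: MyScheduler is a hypothetical ConfigMixin subclass.
from my_diffusers.configuration_utils import ConfigMixin, register_to_config


class MyScheduler(ConfigMixin):
    # filename used by save_config() / load_config()
    config_name = "scheduler_config.json"

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 0.0001):
        # the decorator has already stored the arguments in self.config
        pass


scheduler = MyScheduler(num_train_timesteps=500)
print(scheduler.config.num_train_timesteps)  # 500, read back from the FrozenDict
scheduler.save_config("./my_scheduler")      # writes scheduler_config.json
```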
-import sys - -from .dependency_versions_table import deps -from .utils.versions import require_version, require_version_core - - -# define which module versions we always want to check at run time -# (usually the ones defined in `install_requires` in setup.py) -# -# order specific notes: -# - tqdm must be checked before tokenizers - -pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split() -if sys.version_info < (3, 7): - pkgs_to_check_at_runtime.append("dataclasses") -if sys.version_info < (3, 8): - pkgs_to_check_at_runtime.append("importlib_metadata") - -for pkg in pkgs_to_check_at_runtime: - if pkg in deps: - if pkg == "tokenizers": - # must be loaded here, or else tqdm check may fail - from .utils import is_tokenizers_available - - if not is_tokenizers_available(): - continue # not required, check version only if installed - - require_version_core(deps[pkg]) - else: - raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") - - -def dep_version_check(pkg, hint=None): - require_version(deps[pkg], hint) diff --git a/my_diffusers/dependency_versions_table.py b/my_diffusers/dependency_versions_table.py deleted file mode 100644 index a84d33706ef8a69ab4f01b73b204c9e06e956c9d..0000000000000000000000000000000000000000 --- a/my_diffusers/dependency_versions_table.py +++ /dev/null @@ -1,35 +0,0 @@ -# THIS FILE HAS BEEN AUTOGENERATED. To update: -# 1. modify the `_deps` dict in setup.py -# 2. run `make deps_table_update`` -deps = { - "Pillow": "Pillow", - "accelerate": "accelerate>=0.11.0", - "black": "black~=23.1", - "datasets": "datasets", - "filelock": "filelock", - "flax": "flax>=0.4.1", - "hf-doc-builder": "hf-doc-builder>=0.3.0", - "huggingface-hub": "huggingface-hub>=0.10.0", - "importlib_metadata": "importlib_metadata", - "isort": "isort>=5.5.4", - "jax": "jax>=0.2.8,!=0.3.2", - "jaxlib": "jaxlib>=0.1.65", - "Jinja2": "Jinja2", - "k-diffusion": "k-diffusion>=0.0.12", - "librosa": "librosa", - "numpy": "numpy", - "parameterized": "parameterized", - "pytest": "pytest", - "pytest-timeout": "pytest-timeout", - "pytest-xdist": "pytest-xdist", - "ruff": "ruff>=0.0.241", - "safetensors": "safetensors", - "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", - "scipy": "scipy", - "regex": "regex!=2019.12.17", - "requests": "requests", - "tensorboard": "tensorboard", - "torch": "torch>=1.4", - "torchvision": "torchvision", - "transformers": "transformers>=4.25.1", -} diff --git a/my_diffusers/experimental/__init__.py b/my_diffusers/experimental/__init__.py deleted file mode 100644 index ebc8155403016dfd8ad7fb78d246f9da9098ac50..0000000000000000000000000000000000000000 --- a/my_diffusers/experimental/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .rl import ValueGuidedRLPipeline diff --git a/my_diffusers/experimental/__pycache__/__init__.cpython-311.pyc b/my_diffusers/experimental/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 3db477909dc3a7de83af4789a3c844190a2bc349..0000000000000000000000000000000000000000 Binary files a/my_diffusers/experimental/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/experimental/rl/__init__.py b/my_diffusers/experimental/rl/__init__.py deleted file mode 100644 index 7b338d3173e12d478b6b6d6fd0e50650a0ab5a4c..0000000000000000000000000000000000000000 --- a/my_diffusers/experimental/rl/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .value_guided_sampling import ValueGuidedRLPipeline diff --git 
a/my_diffusers/experimental/rl/__pycache__/__init__.cpython-311.pyc b/my_diffusers/experimental/rl/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index c2f0538797be495545f7ed94fbf5e2f20d32cd0b..0000000000000000000000000000000000000000 Binary files a/my_diffusers/experimental/rl/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/experimental/rl/__pycache__/value_guided_sampling.cpython-311.pyc b/my_diffusers/experimental/rl/__pycache__/value_guided_sampling.cpython-311.pyc deleted file mode 100644 index 7531803518d0d1c1a431c44675c15df11aaa9012..0000000000000000000000000000000000000000 Binary files a/my_diffusers/experimental/rl/__pycache__/value_guided_sampling.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/experimental/rl/value_guided_sampling.py b/my_diffusers/experimental/rl/value_guided_sampling.py deleted file mode 100644 index 7de33a795c77d747653f7e382531dba48ead6879..0000000000000000000000000000000000000000 --- a/my_diffusers/experimental/rl/value_guided_sampling.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import torch -import tqdm - -from ...models.unet_1d import UNet1DModel -from ...pipelines import DiffusionPipeline -from ...utils import randn_tensor -from ...utils.dummy_pt_objects import DDPMScheduler - - -class ValueGuidedRLPipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - Pipeline for sampling actions from a diffusion model trained to predict sequences of states. - - Original implementation inspired by this repository: https://github.com/jannerm/diffuser. - - Parameters: - value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories base on reward. - unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this - application is [`DDPMScheduler`]. - env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models. 
- """ - - def __init__( - self, - value_function: UNet1DModel, - unet: UNet1DModel, - scheduler: DDPMScheduler, - env, - ): - super().__init__() - self.value_function = value_function - self.unet = unet - self.scheduler = scheduler - self.env = env - self.data = env.get_dataset() - self.means = dict() - for key in self.data.keys(): - try: - self.means[key] = self.data[key].mean() - except: # noqa: E722 - pass - self.stds = dict() - for key in self.data.keys(): - try: - self.stds[key] = self.data[key].std() - except: # noqa: E722 - pass - self.state_dim = env.observation_space.shape[0] - self.action_dim = env.action_space.shape[0] - - def normalize(self, x_in, key): - return (x_in - self.means[key]) / self.stds[key] - - def de_normalize(self, x_in, key): - return x_in * self.stds[key] + self.means[key] - - def to_torch(self, x_in): - if type(x_in) is dict: - return {k: self.to_torch(v) for k, v in x_in.items()} - elif torch.is_tensor(x_in): - return x_in.to(self.unet.device) - return torch.tensor(x_in, device=self.unet.device) - - def reset_x0(self, x_in, cond, act_dim): - for key, val in cond.items(): - x_in[:, key, act_dim:] = val.clone() - return x_in - - def run_diffusion(self, x, conditions, n_guide_steps, scale): - batch_size = x.shape[0] - y = None - for i in tqdm.tqdm(self.scheduler.timesteps): - # create batch of timesteps to pass into model - timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) - for _ in range(n_guide_steps): - with torch.enable_grad(): - x.requires_grad_() - - # permute to match dimension for pre-trained models - y = self.value_function(x.permute(0, 2, 1), timesteps).sample - grad = torch.autograd.grad([y.sum()], [x])[0] - - posterior_variance = self.scheduler._get_variance(i) - model_std = torch.exp(0.5 * posterior_variance) - grad = model_std * grad - - grad[timesteps < 2] = 0 - x = x.detach() - x = x + scale * grad - x = self.reset_x0(x, conditions, self.action_dim) - - prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) - - # TODO: verify deprecation of this kwarg - x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] - - # apply conditions to the trajectory (set the initial state) - x = self.reset_x0(x, conditions, self.action_dim) - x = self.to_torch(x) - return x, y - - def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): - # normalize the observations and create batch dimension - obs = self.normalize(obs, "observations") - obs = obs[None].repeat(batch_size, axis=0) - - conditions = {0: self.to_torch(obs)} - shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) - - # generate initial noise and apply our conditions (to make the trajectories start at current state) - x1 = randn_tensor(shape, device=self.unet.device) - x = self.reset_x0(x1, conditions, self.action_dim) - x = self.to_torch(x) - - # run the diffusion process - x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) - - # sort output trajectories by value - sorted_idx = y.argsort(0, descending=True).squeeze() - sorted_values = x[sorted_idx] - actions = sorted_values[:, :, : self.action_dim] - actions = actions.detach().cpu().numpy() - denorm_actions = self.de_normalize(actions, key="actions") - - # select the action with the highest value - if y is not None: - selected_index = 0 - else: - # if we didn't run value guiding, select a random action - selected_index = np.random.randint(0, batch_size) - - denorm_actions = denorm_actions[selected_index, 0] - return 
denorm_actions diff --git a/my_diffusers/loaders.py b/my_diffusers/loaders.py deleted file mode 100644 index beb2b9ef4e7e7e4ed43216ca71aafecf0f373cbb..0000000000000000000000000000000000000000 --- a/my_diffusers/loaders.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from collections import defaultdict -from typing import Callable, Dict, Union - -import torch - -from .models.cross_attention import LoRACrossAttnProcessor -from .models.modeling_utils import _get_model_file -from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging - - -logger = logging.get_logger(__name__) - - -LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" - - -class AttnProcsLayers(torch.nn.Module): - def __init__(self, state_dict: Dict[str, torch.Tensor]): - super().__init__() - self.layers = torch.nn.ModuleList(state_dict.values()) - self.mapping = {k: v for k, v in enumerate(state_dict.keys())} - self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} - - # we add a hook to state_dict() and load_state_dict() so that the - # naming fits with `unet.attn_processors` - def map_to(module, state_dict, *args, **kwargs): - new_state_dict = {} - for key, value in state_dict.items(): - num = int(key.split(".")[1]) # 0 is always "layers" - new_key = key.replace(f"layers.{num}", module.mapping[num]) - new_state_dict[new_key] = value - - return new_state_dict - - def map_from(module, state_dict, *args, **kwargs): - all_keys = list(state_dict.keys()) - for key in all_keys: - replace_key = key.split(".processor")[0] + ".processor" - new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") - state_dict[new_key] = state_dict[key] - del state_dict[key] - - self._register_state_dict_hook(map_to) - self._register_load_state_dict_pre_hook(map_from, with_module=True) - - -class UNet2DConditionLoadersMixin: - def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): - r""" - Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be - defined in - [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) - and be a `torch.nn.Module` class. - - - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., - `./my_model_directory/`. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). 
- - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `diffusers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo (either remote in - huggingface.co or downloaded locally), you can specify the folder name here. - - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. - - - - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). - - - - - - Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use - this method in a firewalled environment. 
- - - """ - - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME) - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = torch.load(model_file, map_location="cpu") - else: - state_dict = pretrained_model_name_or_path_or_dict - - # fill attn processors - attn_processors = {} - - is_lora = all("lora" in k for k in state_dict.keys()) - - if is_lora: - lora_grouped_dict = defaultdict(dict) - for key, value in state_dict.items(): - attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) - lora_grouped_dict[attn_processor_key][sub_key] = value - - for key, value_dict in lora_grouped_dict.items(): - rank = value_dict["to_k_lora.down.weight"].shape[0] - cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] - hidden_size = value_dict["to_k_lora.up.weight"].shape[0] - - attn_processors[key] = LoRACrossAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank - ) - attn_processors[key].load_state_dict(value_dict) - - else: - raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.") - - # set correct dtype & device - attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()} - - # set layers - self.set_attn_processor(attn_processors) - - def save_attn_procs( - self, - save_directory: Union[str, os.PathLike], - is_main_process: bool = True, - weights_name: str = LORA_WEIGHT_NAME, - save_function: Callable = None, - ): - r""" - Save an attention processor to a directory, so that it can be re-loaded using the - `[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful when in distributed training like - TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on - the main process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful on distributed training like TPUs when one - need to replace `torch.save` by another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. 
- """ - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - if save_function is None: - save_function = torch.save - - os.makedirs(save_directory, exist_ok=True) - - model_to_save = AttnProcsLayers(self.attn_processors) - - # Save the model - state_dict = model_to_save.state_dict() - - # Clean the folder from a previous save - for filename in os.listdir(save_directory): - full_filename = os.path.join(save_directory, filename) - # If we have a shard file that is not going to be replaced, we delete it, but only from the main process - # in distributed settings to avoid race conditions. - weights_no_suffix = weights_name.replace(".bin", "") - if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process: - os.remove(full_filename) - - # Save the model - save_function(state_dict, os.path.join(save_directory, weights_name)) - - logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}") diff --git a/my_diffusers/models/__init__.py b/my_diffusers/models/__init__.py deleted file mode 100644 index e0b2cddd4bf9ecdfa8bfb62c471de11531670d0d..0000000000000000000000000000000000000000 --- a/my_diffusers/models/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from ..utils import is_flax_available, is_torch_available - - -if is_torch_available(): - from .autoencoder_kl import AutoencoderKL - from .controlnet import ControlNetModel - from .dual_transformer_2d import DualTransformer2DModel - from .modeling_utils import ModelMixin - from .prior_transformer import PriorTransformer - from .transformer_2d import Transformer2DModel - from .unet_1d import UNet1DModel - from .unet_2d import UNet2DModel - from .unet_2d_condition import UNet2DConditionModel - from .vq_model import VQModel - -if is_flax_available(): - from .unet_2d_condition_flax import FlaxUNet2DConditionModel - from .vae_flax import FlaxAutoencoderKL diff --git a/my_diffusers/models/__pycache__/__init__.cpython-311.pyc b/my_diffusers/models/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 3fab395fa75cdde012421e69eb50570b8a42d7ae..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/attention.cpython-311.pyc b/my_diffusers/models/__pycache__/attention.cpython-311.pyc deleted file mode 100644 index b9488b3fb5104e726379c825b304b66d688a70df..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/attention.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/attention_flax.cpython-311.pyc b/my_diffusers/models/__pycache__/attention_flax.cpython-311.pyc deleted file mode 100644 index 7129c1d17afab2306499033d50251ae761ce9494..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/attention_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/autoencoder_kl.cpython-311.pyc b/my_diffusers/models/__pycache__/autoencoder_kl.cpython-311.pyc deleted file mode 100644 index a86b8879f407467b60599f62a5f635b4702e7145..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/autoencoder_kl.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/controlnet.cpython-311.pyc b/my_diffusers/models/__pycache__/controlnet.cpython-311.pyc deleted file mode 100644 index 771031013aa8453b5baf227640fbbd336f997a9e..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/controlnet.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/cross_attention.cpython-311.pyc b/my_diffusers/models/__pycache__/cross_attention.cpython-311.pyc deleted file mode 100644 index 2b19f26b714010835912ebae19d027929a234c06..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/cross_attention.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/dual_transformer_2d.cpython-311.pyc b/my_diffusers/models/__pycache__/dual_transformer_2d.cpython-311.pyc deleted file mode 100644 index 6f7a05a42888e909511168bdfbf5ade595f4a513..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/dual_transformer_2d.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/embeddings.cpython-311.pyc b/my_diffusers/models/__pycache__/embeddings.cpython-311.pyc deleted file mode 100644 index 8078adb37c4b1ae9ace0e8bde811e83ad04645bc..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/embeddings.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/embeddings_flax.cpython-311.pyc 
b/my_diffusers/models/__pycache__/embeddings_flax.cpython-311.pyc deleted file mode 100644 index f24ca64635f5fe1913b4ac299b6bffbcfd8f791e..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/embeddings_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/modeling_flax_pytorch_utils.cpython-311.pyc b/my_diffusers/models/__pycache__/modeling_flax_pytorch_utils.cpython-311.pyc deleted file mode 100644 index 3e9ae6f00aadd50ccc8a37f41395fea124d52385..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/modeling_flax_pytorch_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/modeling_flax_utils.cpython-311.pyc b/my_diffusers/models/__pycache__/modeling_flax_utils.cpython-311.pyc deleted file mode 100644 index 0dcba36bdbdfdc3cd36897a913501367a3094e7c..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/modeling_flax_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/modeling_pytorch_flax_utils.cpython-311.pyc b/my_diffusers/models/__pycache__/modeling_pytorch_flax_utils.cpython-311.pyc deleted file mode 100644 index b48e19cc9490c2559e500ec25029a650e39120db..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/modeling_pytorch_flax_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/modeling_utils.cpython-311.pyc b/my_diffusers/models/__pycache__/modeling_utils.cpython-311.pyc deleted file mode 100644 index cddf05f2eee175b2ab4c23fe4adeb32f419a67dc..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/modeling_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/prior_transformer.cpython-311.pyc b/my_diffusers/models/__pycache__/prior_transformer.cpython-311.pyc deleted file mode 100644 index 866d6427cb9c9dc5f1506eedb382b199b6f73a96..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/prior_transformer.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/resnet.cpython-311.pyc b/my_diffusers/models/__pycache__/resnet.cpython-311.pyc deleted file mode 100644 index 9af41f1d06f4a9e61fb5354ab5a210e647d01d8a..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/resnet.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/resnet_flax.cpython-311.pyc b/my_diffusers/models/__pycache__/resnet_flax.cpython-311.pyc deleted file mode 100644 index c2db537b688c7b7c0be4a97944446f4c99d380b9..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/resnet_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/transformer_2d.cpython-311.pyc b/my_diffusers/models/__pycache__/transformer_2d.cpython-311.pyc deleted file mode 100644 index 055407f6d2b5d6acab80a8049642e29c746f4482..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/transformer_2d.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/unet_1d.cpython-311.pyc b/my_diffusers/models/__pycache__/unet_1d.cpython-311.pyc deleted file mode 100644 index c49f8d5940d985845134c2f7a36e2b10e8e0cb88..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/unet_1d.cpython-311.pyc and /dev/null differ diff --git 
a/my_diffusers/models/__pycache__/unet_1d_blocks.cpython-311.pyc b/my_diffusers/models/__pycache__/unet_1d_blocks.cpython-311.pyc deleted file mode 100644 index 2eafbdf5cb7977d0662014601a7f8b27a687a13d..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/unet_1d_blocks.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/unet_2d.cpython-311.pyc b/my_diffusers/models/__pycache__/unet_2d.cpython-311.pyc deleted file mode 100644 index f34b3c20759575a8a16e969153791d26977df0f9..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/unet_2d.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/unet_2d_blocks.cpython-311.pyc b/my_diffusers/models/__pycache__/unet_2d_blocks.cpython-311.pyc deleted file mode 100644 index de37d3597979e2f47afaa088847c16cf38bdc92f..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/unet_2d_blocks.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/unet_2d_blocks_flax.cpython-311.pyc b/my_diffusers/models/__pycache__/unet_2d_blocks_flax.cpython-311.pyc deleted file mode 100644 index 13fb80ce58aefff40fb3483fe15c6b41d6c361b6..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/unet_2d_blocks_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/unet_2d_condition.cpython-311.pyc b/my_diffusers/models/__pycache__/unet_2d_condition.cpython-311.pyc deleted file mode 100644 index b341bfeeb25b61a05a9bd8dc49d0d4f5bcef412e..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/unet_2d_condition.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/unet_2d_condition_flax.cpython-311.pyc b/my_diffusers/models/__pycache__/unet_2d_condition_flax.cpython-311.pyc deleted file mode 100644 index fe691c25807ed8700bbc1e779c56c443575cef29..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/unet_2d_condition_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/vae.cpython-311.pyc b/my_diffusers/models/__pycache__/vae.cpython-311.pyc deleted file mode 100644 index a71a484e9d03ca130c9d6eecaca5432e9c37d6cb..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/vae.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/vae_flax.cpython-311.pyc b/my_diffusers/models/__pycache__/vae_flax.cpython-311.pyc deleted file mode 100644 index 9bf2d3cc6ceb6b45230cf69ed937ee082f839ea5..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/vae_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/__pycache__/vq_model.cpython-311.pyc b/my_diffusers/models/__pycache__/vq_model.cpython-311.pyc deleted file mode 100644 index 7f14c3ebacd43c09694133b72ce3ed759a47846f..0000000000000000000000000000000000000000 Binary files a/my_diffusers/models/__pycache__/vq_model.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/models/attention.py b/my_diffusers/models/attention.py deleted file mode 100644 index b476e762675c1972335265aa62a99baf5ae22c5a..0000000000000000000000000000000000000000 --- a/my_diffusers/models/attention.py +++ /dev/null @@ -1,517 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math -from typing import Callable, Optional - -import torch -import torch.nn.functional as F -from torch import nn - -from ..utils.import_utils import is_xformers_available -from .cross_attention import CrossAttention -from .embeddings import CombinedTimestepLabelEmbeddings - - -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted - to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - Uses three q, k, v linear layers to compute attention. - - Parameters: - channels (`int`): The number of channels in the input and output. - num_head_channels (`int`, *optional*): - The number of channels in each head. If None, then `num_heads` = 1. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. - rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. - eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. - """ - - # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. 
Do not use it anymore - - def __init__( - self, - channels: int, - num_head_channels: Optional[int] = None, - norm_num_groups: int = 32, - rescale_output_factor: float = 1.0, - eps: float = 1e-5, - ): - super().__init__() - self.channels = channels - - self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 - self.num_head_size = num_head_channels - self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) - - # define q,k,v as linear layers - self.query = nn.Linear(channels, channels) - self.key = nn.Linear(channels, channels) - self.value = nn.Linear(channels, channels) - - self.rescale_output_factor = rescale_output_factor - self.proj_attn = nn.Linear(channels, channels, 1) - - self._use_memory_efficient_attention_xformers = False - self._attention_op = None - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None - ): - if use_memory_efficient_attention_xformers: - if not is_xformers_available(): - raise ModuleNotFoundError( - ( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers" - ), - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is" - " only available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - self._attention_op = attention_op - - def forward(self, hidden_states): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - scale = 1 / math.sqrt(self.channels / self.num_heads) - - query_proj = self.reshape_heads_to_batch_dim(query_proj) - key_proj = self.reshape_heads_to_batch_dim(key_proj) - value_proj = self.reshape_heads_to_batch_dim(value_proj) - - if self._use_memory_efficient_attention_xformers: - # Memory efficient attention - hidden_states = xformers.ops.memory_efficient_attention( - query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op - ) - hidden_states = hidden_states.to(query_proj.dtype) - else: - attention_scores = torch.baddbmm( - torch.empty( - query_proj.shape[0], - query_proj.shape[1], - key_proj.shape[1], - dtype=query_proj.dtype, - device=query_proj.device, - ), - query_proj, - key_proj.transpose(-1, -2), - beta=0, - alpha=scale, - ) - attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) - hidden_states = torch.bmm(attention_probs, value_proj) - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - -class BasicTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
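Example (an illustrative sketch; the dimensions mirror a typical Stable Diffusion UNet block, and the 77x768 context stands in for CLIP text embeddings):

    import torch
    from my_diffusers.models.attention import BasicTransformerBlock

    # 320-dim block with 8 heads of 40 channels, cross-attending over a 768-dim context.
    block = BasicTransformerBlock(
        dim=320, num_attention_heads=8, attention_head_dim=40, cross_attention_dim=768
    )
    hidden_states = torch.randn(2, 4096, 320)   # (batch, tokens, channels)
    context = torch.randn(2, 77, 768)           # e.g. text encoder output
    out = block(hidden_states, encoder_hidden_states=context)  # same shape as hidden_states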
- """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - final_dropout: bool = False, - ): - super().__init__() - self.only_cross_attention = only_cross_attention - - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - # 1. Self-Attn - self.attn1 = CrossAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) - - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) - - # 2. Cross-Attn - if cross_attention_dim is not None: - self.attn2 = CrossAttention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) # is self-attn if encoder_hidden_states is none - else: - self.attn2 = None - - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif self.use_ada_layer_norm_zero: - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - - if cross_attention_dim is not None: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - ) - else: - self.norm2 = None - - # 3. Feed-forward - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - timestep=None, - attention_mask=None, - cross_attention_kwargs=None, - class_labels=None, - ): - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - else: - norm_hidden_states = self.norm1(hidden_states) - - # 1. 
Self-Attention - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = attn_output + hidden_states - - if self.attn2 is not None: - norm_hidden_states = ( - self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) - ) - - # 2. Cross-Attention - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 3. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self.use_ada_layer_norm_zero: - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - ff_output = self.ff(norm_hidden_states) - - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = ff_output + hidden_states - - return hidden_states - - -class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - """ - - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - mult: int = 4, - dropout: float = 0.0, - activation_fn: str = "geglu", - final_dropout: bool = False, - ): - super().__init__() - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - - if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim) - if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh") - elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim) - elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim) - - self.net = nn.ModuleList([]) - # project in - self.net.append(act_fn) - # project dropout - self.net.append(nn.Dropout(dropout)) - # project out - self.net.append(nn.Linear(inner_dim, dim_out)) - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - if final_dropout: - self.net.append(nn.Dropout(dropout)) - - def forward(self, hidden_states): - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states - - -class GELU(nn.Module): - r""" - GELU activation function with tanh approximation support with `approximate="tanh"`. 
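Example (a minimal sketch with illustrative shapes):

    import torch
    from my_diffusers.models.attention import GELU

    # Projects from 320 to 1280 channels, then applies the tanh-approximated GELU.
    act = GELU(dim_in=320, dim_out=1280, approximate="tanh")
    y = act(torch.randn(2, 77, 320))   # -> shape (2, 77, 1280)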
- """ - - def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out) - self.approximate = approximate - - def gelu(self, gate): - if gate.device.type != "mps": - return F.gelu(gate, approximate=self.approximate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) - - def forward(self, hidden_states): - hidden_states = self.proj(hidden_states) - hidden_states = self.gelu(hidden_states) - return hidden_states - - -class GEGLU(nn.Module): - r""" - A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. - - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - """ - - def __init__(self, dim_in: int, dim_out: int): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def gelu(self, gate): - if gate.device.type != "mps": - return F.gelu(gate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) - - def forward(self, hidden_states): - hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) - return hidden_states * self.gelu(gate) - - -class ApproximateGELU(nn.Module): - """ - The approximate form of Gaussian Error Linear Unit (GELU) - - For more details, see section 2: https://arxiv.org/abs/1606.08415 - """ - - def __init__(self, dim_in: int, dim_out: int): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out) - - def forward(self, x): - x = self.proj(x) - return x * torch.sigmoid(1.702 * x) - - -class AdaLayerNorm(nn.Module): - """ - Norm layer modified to incorporate timestep embeddings. - """ - - def __init__(self, embedding_dim, num_embeddings): - super().__init__() - self.emb = nn.Embedding(num_embeddings, embedding_dim) - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, embedding_dim * 2) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) - - def forward(self, x, timestep): - emb = self.linear(self.silu(self.emb(timestep))) - scale, shift = torch.chunk(emb, 2) - x = self.norm(x) * (1 + scale) + shift - return x - - -class AdaLayerNormZero(nn.Module): - """ - Norm layer adaptive layer norm zero (adaLN-Zero). - """ - - def __init__(self, embedding_dim, num_embeddings): - super().__init__() - - self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) - - self.silu = nn.SiLU() - self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) - - def forward(self, x, timestep, class_labels, hidden_dtype=None): - emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype))) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) - x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] - return x, gate_msa, shift_mlp, scale_mlp, gate_mlp - - -class AdaGroupNorm(nn.Module): - """ - GroupNorm layer modified to incorporate timestep embeddings. 
- """ - - def __init__( - self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5 - ): - super().__init__() - self.num_groups = num_groups - self.eps = eps - self.act = None - if act_fn == "swish": - self.act = lambda x: F.silu(x) - elif act_fn == "mish": - self.act = nn.Mish() - elif act_fn == "silu": - self.act = nn.SiLU() - elif act_fn == "gelu": - self.act = nn.GELU() - - self.linear = nn.Linear(embedding_dim, out_dim * 2) - - def forward(self, x, emb): - if self.act: - emb = self.act(emb) - emb = self.linear(emb) - emb = emb[:, :, None, None] - scale, shift = emb.chunk(2, dim=1) - - x = F.group_norm(x, self.num_groups, eps=self.eps) - x = x * (1 + scale) + shift - return x diff --git a/my_diffusers/models/attention_flax.py b/my_diffusers/models/attention_flax.py deleted file mode 100644 index b0717356bec16a15bff858b696c67869757216fa..0000000000000000000000000000000000000000 --- a/my_diffusers/models/attention_flax.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import flax.linen as nn -import jax.numpy as jnp - - -class FlaxCrossAttention(nn.Module): - r""" - A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762 - - Parameters: - query_dim (:obj:`int`): - Input hidden states dimension - heads (:obj:`int`, *optional*, defaults to 8): - Number of heads - dim_head (:obj:`int`, *optional*, defaults to 64): - Hidden states dimension inside each head - dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - - """ - query_dim: int - heads: int = 8 - dim_head: int = 64 - dropout: float = 0.0 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - inner_dim = self.dim_head * self.heads - self.scale = self.dim_head**-0.5 - - # Weights were exported with old names {to_q, to_k, to_v, to_out} - self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q") - self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k") - self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v") - - self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0") - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def __call__(self, hidden_states, context=None, 
deterministic=True): - context = hidden_states if context is None else context - - query_proj = self.query(hidden_states) - key_proj = self.key(context) - value_proj = self.value(context) - - query_states = self.reshape_heads_to_batch_dim(query_proj) - key_states = self.reshape_heads_to_batch_dim(key_proj) - value_states = self.reshape_heads_to_batch_dim(value_proj) - - # compute attentions - attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) - attention_scores = attention_scores * self.scale - attention_probs = nn.softmax(attention_scores, axis=2) - - # attend to values - hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - hidden_states = self.proj_attn(hidden_states) - return hidden_states - - -class FlaxBasicTransformerBlock(nn.Module): - r""" - A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in: - https://arxiv.org/abs/1706.03762 - - - Parameters: - dim (:obj:`int`): - Inner hidden states dimension - n_heads (:obj:`int`): - Number of heads - d_head (:obj:`int`): - Hidden states dimension inside each head - dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - only_cross_attention (`bool`, defaults to `False`): - Whether to only apply cross attention. - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - dim: int - n_heads: int - d_head: int - dropout: float = 0.0 - only_cross_attention: bool = False - dtype: jnp.dtype = jnp.float32 - - def setup(self): - # self attention (or cross_attention if only_cross_attention is True) - self.attn1 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) - # cross attention - self.attn2 = FlaxCrossAttention(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) - self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) - self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) - self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) - self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) - - def __call__(self, hidden_states, context, deterministic=True): - # self attention - residual = hidden_states - if self.only_cross_attention: - hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic) - else: - hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic) - hidden_states = hidden_states + residual - - # cross attention - residual = hidden_states - hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic) - hidden_states = hidden_states + residual - - # feed forward - residual = hidden_states - hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic) - hidden_states = hidden_states + residual - - return hidden_states - - -class FlaxTransformer2DModel(nn.Module): - r""" - A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in: - https://arxiv.org/pdf/1506.02025.pdf - - - Parameters: - in_channels (:obj:`int`): - Input number of channels - n_heads (:obj:`int`): - Number of heads - d_head (:obj:`int`): - Hidden states dimension inside each head - depth (:obj:`int`, *optional*, defaults to 1): - Number of transformers block - dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - use_linear_projection (`bool`, defaults to `False`): tbd - only_cross_attention (`bool`, defaults to 
`False`): tbd - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - in_channels: int - n_heads: int - d_head: int - depth: int = 1 - dropout: float = 0.0 - use_linear_projection: bool = False - only_cross_attention: bool = False - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) - - inner_dim = self.n_heads * self.d_head - if self.use_linear_projection: - self.proj_in = nn.Dense(inner_dim, dtype=self.dtype) - else: - self.proj_in = nn.Conv( - inner_dim, - kernel_size=(1, 1), - strides=(1, 1), - padding="VALID", - dtype=self.dtype, - ) - - self.transformer_blocks = [ - FlaxBasicTransformerBlock( - inner_dim, - self.n_heads, - self.d_head, - dropout=self.dropout, - only_cross_attention=self.only_cross_attention, - dtype=self.dtype, - ) - for _ in range(self.depth) - ] - - if self.use_linear_projection: - self.proj_out = nn.Dense(inner_dim, dtype=self.dtype) - else: - self.proj_out = nn.Conv( - inner_dim, - kernel_size=(1, 1), - strides=(1, 1), - padding="VALID", - dtype=self.dtype, - ) - - def __call__(self, hidden_states, context, deterministic=True): - batch, height, width, channels = hidden_states.shape - residual = hidden_states - hidden_states = self.norm(hidden_states) - if self.use_linear_projection: - hidden_states = hidden_states.reshape(batch, height * width, channels) - hidden_states = self.proj_in(hidden_states) - else: - hidden_states = self.proj_in(hidden_states) - hidden_states = hidden_states.reshape(batch, height * width, channels) - - for transformer_block in self.transformer_blocks: - hidden_states = transformer_block(hidden_states, context, deterministic=deterministic) - - if self.use_linear_projection: - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch, height, width, channels) - else: - hidden_states = hidden_states.reshape(batch, height, width, channels) - hidden_states = self.proj_out(hidden_states) - - hidden_states = hidden_states + residual - return hidden_states - - -class FlaxFeedForward(nn.Module): - r""" - Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's - [`FeedForward`] class, with the following simplifications: - - The activation function is currently hardcoded to a gated linear unit from: - https://arxiv.org/abs/2002.05202 - - `dim_out` is equal to `dim`. - - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`]. - - Parameters: - dim (:obj:`int`): - Inner hidden states dimension - dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - dim: int - dropout: float = 0.0 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - # The second linear layer needs to be called - # net_2 for now to match the index of the Sequential layer - self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype) - self.net_2 = nn.Dense(self.dim, dtype=self.dtype) - - def __call__(self, hidden_states, deterministic=True): - hidden_states = self.net_0(hidden_states) - hidden_states = self.net_2(hidden_states) - return hidden_states - - -class FlaxGEGLU(nn.Module): - r""" - Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from - https://arxiv.org/abs/2002.05202. 
- - Parameters: - dim (:obj:`int`): - Input hidden states dimension - dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - dim: int - dropout: float = 0.0 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - inner_dim = self.dim * 4 - self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype) - - def __call__(self, hidden_states, deterministic=True): - hidden_states = self.proj(hidden_states) - hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) - return hidden_linear * nn.gelu(hidden_gelu) diff --git a/my_diffusers/models/autoencoder_kl.py b/my_diffusers/models/autoencoder_kl.py deleted file mode 100644 index 9cb0a4b2432bcf0602627a725be94051bb30f57b..0000000000000000000000000000000000000000 --- a/my_diffusers/models/autoencoder_kl.py +++ /dev/null @@ -1,320 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, apply_forward_hook -from .modeling_utils import ModelMixin -from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder - - -@dataclass -class AutoencoderKLOutput(BaseOutput): - """ - Output of AutoencoderKL encoding method. - - Args: - latent_dist (`DiagonalGaussianDistribution`): - Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. - `DiagonalGaussianDistribution` allows for sampling latents from the distribution. - """ - - latent_dist: "DiagonalGaussianDistribution" - - -class AutoencoderKL(ModelMixin, ConfigMixin): - r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma - and Max Welling. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - in_channels (int, *optional*, defaults to 3): Number of channels in the input image. - out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(64,)`): Tuple of block output channels. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. 
- sample_size (`int`, *optional*, defaults to `32`): TODO - scaling_factor (`float`, *optional*, defaults to 0.18215): - The component-wise standard deviation of the trained latent space computed using the first batch of the - training set. This is used to scale the latent space to have unit variance when training the diffusion - model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the - diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 - / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image - Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. - """ - - @register_to_config - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock2D",), - up_block_types: Tuple[str] = ("UpDecoderBlock2D",), - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 4, - norm_num_groups: int = 32, - sample_size: int = 32, - scaling_factor: float = 0.18215, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = Encoder( - in_channels=in_channels, - out_channels=latent_channels, - down_block_types=down_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=True, - ) - - # pass init params to Decoder - self.decoder = Decoder( - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - norm_num_groups=norm_num_groups, - act_fn=act_fn, - ) - - self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) - self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) - - self.use_slicing = False - self.use_tiling = False - - # only relevant if vae tiling is enabled - self.tile_sample_min_size = self.config.sample_size - sample_size = ( - self.config.sample_size[0] - if isinstance(self.config.sample_size, (list, tuple)) - else self.config.sample_size - ) - self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1))) - self.tile_overlap_factor = 0.25 - - def enable_tiling(self, use_tiling: bool = True): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow - the processing of larger images. - """ - self.use_tiling = use_tiling - - def disable_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.enable_tiling(False) - - def enable_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing - decoding in one step. 
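A usage sketch for the slicing/tiling toggles (the checkpoint path and latent shape below are placeholders, not part of this file):

    import torch
    from my_diffusers.models.autoencoder_kl import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("path/to/vae")   # placeholder checkpoint
    vae.enable_tiling()     # encode/decode large inputs tile by tile
    vae.enable_slicing()    # decode batched latents one sample at a time
    latents = torch.randn(4, vae.config.latent_channels, 64, 64)
    images = vae.decode(latents).sample   # decoded images; the upsampling factor depends on the config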
- """ - self.use_slicing = False - - @apply_forward_hook - def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: - if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): - return self.tiled_encode(x, return_dict=return_dict) - - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) - - def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): - return self.tiled_decode(z, return_dict=return_dict) - - z = self.post_quant_conv(z) - dec = self.decoder(z) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - @apply_forward_hook - def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - if self.use_slicing and z.shape[0] > 1: - decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] - decoded = torch.cat(decoded_slices) - else: - decoded = self._decode(z).sample - - if not return_dict: - return (decoded,) - - return DecoderOutput(sample=decoded) - - def blend_v(self, a, b, blend_extent): - for y in range(blend_extent): - b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) - return b - - def blend_h(self, a, b, blend_extent): - for x in range(blend_extent): - b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) - return b - - def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: - r"""Encode a batch of images using a tiled encoder. - Args: - When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several - steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is: - different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the - tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the - look of the output, but they should be much less noticeable. - x (`torch.FloatTensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple. - """ - overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) - row_limit = self.tile_latent_min_size - blend_extent - - # Split the image into 512x512 tiles and encode them separately. 
- rows = [] - for i in range(0, x.shape[2], overlap_size): - row = [] - for j in range(0, x.shape[3], overlap_size): - tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] - tile = self.encoder(tile) - tile = self.quant_conv(tile) - row.append(tile) - rows.append(row) - result_rows = [] - for i, row in enumerate(rows): - result_row = [] - for j, tile in enumerate(row): - # blend the above tile and the left tile - # to the current tile and add the current tile to the result row - if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent) - if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent) - result_row.append(tile[:, :, :row_limit, :row_limit]) - result_rows.append(torch.cat(result_row, dim=3)) - - moments = torch.cat(result_rows, dim=2) - posterior = DiagonalGaussianDistribution(moments) - - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) - - def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - r"""Decode a batch of images using a tiled decoder. - Args: - When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several - steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is: - different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the - tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the - look of the output, but they should be much less noticeable. - z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to - `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - """ - overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) - blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) - row_limit = self.tile_sample_min_size - blend_extent - - # Split z into overlapping 64x64 tiles and decode them separately. - # The tiles have an overlap to avoid seams between tiles. - rows = [] - for i in range(0, z.shape[2], overlap_size): - row = [] - for j in range(0, z.shape[3], overlap_size): - tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] - tile = self.post_quant_conv(tile) - decoded = self.decoder(tile) - row.append(decoded) - rows.append(row) - result_rows = [] - for i, row in enumerate(rows): - result_row = [] - for j, tile in enumerate(row): - # blend the above tile and the left tile - # to the current tile and add the current tile to the result row - if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent) - if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent) - result_row.append(tile[:, :, :row_limit, :row_limit]) - result_rows.append(torch.cat(result_row, dim=3)) - - dec = torch.cat(result_rows, dim=2) - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - def forward( - self, - sample: torch.FloatTensor, - sample_posterior: bool = False, - return_dict: bool = True, - generator: Optional[torch.Generator] = None, - ) -> Union[DecoderOutput, torch.FloatTensor]: - r""" - Args: - sample (`torch.FloatTensor`): Input sample. - sample_posterior (`bool`, *optional*, defaults to `False`): - Whether to sample from the posterior. 
- return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - """ - x = sample - posterior = self.encode(x).latent_dist - if sample_posterior: - z = posterior.sample(generator=generator) - else: - z = posterior.mode() - dec = self.decode(z).sample - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) diff --git a/my_diffusers/models/controlnet.py b/my_diffusers/models/controlnet.py deleted file mode 100644 index 5b55da5af37273c2ca1e9296ef45fd29e9830138..0000000000000000000000000000000000000000 --- a/my_diffusers/models/controlnet.py +++ /dev/null @@ -1,506 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from torch import nn -from torch.nn import functional as F - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging -from .cross_attention import AttnProcessor -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .unet_2d_blocks import ( - CrossAttnDownBlock2D, - DownBlock2D, - UNetMidBlock2DCrossAttn, - get_down_block, -) - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class ControlNetOutput(BaseOutput): - down_block_res_samples: Tuple[torch.Tensor] - mid_block_res_sample: torch.Tensor - - -class ControlNetConditioningEmbedding(nn.Module): - """ - Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN - [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized - training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the - convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides - (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full - model) to encode image-space conditions ... into feature maps ..." 
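Example (an illustrative sketch; the 512x512 conditioning image and 320 embedding channels match the usual Stable Diffusion setup but are not mandated by this class):

    import torch
    from my_diffusers.models.controlnet import ControlNetConditioningEmbedding

    cond_embed = ControlNetConditioningEmbedding(conditioning_embedding_channels=320)
    cond = torch.randn(1, 3, 512, 512)   # image-space condition, e.g. a Canny edge map
    feat = cond_embed(cond)              # -> (1, 320, 64, 64), the UNet's latent resolution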
- """ - - def __init__( - self, - conditioning_embedding_channels: int, - conditioning_channels: int = 3, - block_out_channels: Tuple[int] = (16, 32, 96, 256), - ): - super().__init__() - - self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) - - self.blocks = nn.ModuleList([]) - - for i in range(len(block_out_channels) - 1): - channel_in = block_out_channels[i] - channel_out = block_out_channels[i + 1] - self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) - self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) - - self.conv_out = zero_module( - nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) - ) - - def forward(self, conditioning): - embedding = self.conv_in(conditioning) - embedding = F.silu(embedding) - - for block in self.blocks: - embedding = block(embedding) - embedding = F.silu(embedding) - - embedding = self.conv_out(embedding) - - return embedding - - -class ControlNetModel(ModelMixin, ConfigMixin): - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - in_channels: int = 4, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, - attention_head_dim: Union[int, Tuple[int]] = 8, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - projection_class_embeddings_input_dim: Optional[int] = None, - controlnet_conditioning_channel_order: str = "rgb", - conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), - ): - super().__init__() - - # Check inputs - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." 
- ) - - # input - conv_in_kernel = 3 - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) - - # time - time_embed_dim = block_out_channels[0] * 4 - - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - ) - - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - # The projection `class_embed_type` is the same as the timestep `class_embed_type` except - # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings - # 2. it projects from an arbitrary input dimension. - # - # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. - # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. - # As a result, `TimestepEmbedding` can be passed arbitrary vectors. - self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - else: - self.class_embedding = None - - # control net conditioning embedding - self.controlnet_cond_embedding = ControlNetConditioningEmbedding( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - ) - - self.down_blocks = nn.ModuleList([]) - self.controlnet_down_blocks = nn.ModuleList([]) - - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - # down - output_channel = block_out_channels[0] - - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[i], - downsample_padding=downsample_padding, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.down_blocks.append(down_block) - - for _ in range(layers_per_block): - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - 
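# NOTE: every 1x1 projection appended to `controlnet_down_blocks` here (and the
# `controlnet_mid_block` created further down) is wrapped in `zero_module`, which
# zero-initialises all of its parameters. At the start of training the ControlNet branch
# therefore adds exactly zero to the (typically frozen) UNet's residuals, and the
# conditioning signal is introduced gradually as these "zero convolutions" learn
# non-zero weights.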
self.controlnet_down_blocks.append(controlnet_block) - - if not is_final_block: - controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_down_blocks.append(controlnet_block) - - # mid - mid_block_channel = block_out_channels[-1] - - controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) - controlnet_block = zero_module(controlnet_block) - self.controlnet_mid_block = controlnet_block - - self.mid_block = UNetMidBlock2DCrossAttn( - in_channels=mid_block_channel, - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - - @property - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttnProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]): - if hasattr(module, "set_processor"): - processors[f"{name}.processor"] = module.processor - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]): - r""" - Parameters: - `processor (`dict` of `AttnProcessor` or `AttnProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `CrossAttention` layers. - In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice - def set_attention_slice(self, slice_size): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. 
This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_slicable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_slicable_dims(module) - - num_slicable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_slicable_layers * [1] - - slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. - # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): - module.gradient_checkpointing = value - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - controlnet_cond: torch.FloatTensor, - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[ControlNetOutput, Tuple]: - # check channel order - channel_order = self.config.controlnet_conditioning_channel_order - - if channel_order == "rgb": - # in rgb order by default - ... - elif channel_order == "bgr": - controlnet_cond = torch.flip(controlnet_cond, dims=[1]) - else: - raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") - - # prepare attention_mask - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # 1. 
time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) - - emb = self.time_embedding(t_emb, timestep_cond) - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb - - # 2. pre-process - sample = self.conv_in(sample) - - controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - - sample += controlnet_cond - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - if self.mid_block is not None: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - - # 5. Control net blocks - - controlnet_down_block_res_samples = () - - for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): - down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples += (down_block_res_sample,) - - down_block_res_samples = controlnet_down_block_res_samples - - mid_block_res_sample = self.controlnet_mid_block(sample) - - if not return_dict: - return (down_block_res_samples, mid_block_res_sample) - - return ControlNetOutput( - down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample - ) - - -def zero_module(module): - for p in module.parameters(): - nn.init.zeros_(p) - return module diff --git a/my_diffusers/models/cross_attention.py b/my_diffusers/models/cross_attention.py deleted file mode 100644 index a87f540b370e482e0cf34035702b0cf7c6c85ccf..0000000000000000000000000000000000000000 --- a/my_diffusers/models/cross_attention.py +++ /dev/null @@ -1,681 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Callable, Optional, Union - -import torch -import torch.nn.functional as F -from torch import nn - -from ..utils import deprecate, logging -from ..utils.import_utils import is_xformers_available - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None - - -class CrossAttention(nn.Module): - r""" - A cross attention layer. - - Parameters: - query_dim (`int`): The number of channels in the query. - cross_attention_dim (`int`, *optional*): - The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. - heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. - dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - bias (`bool`, *optional*, defaults to False): - Set to `True` for the query, key, and value linear layers to contain a bias parameter. - """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias=False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - cross_attention_norm: bool = False, - added_kv_proj_dim: Optional[int] = None, - norm_num_groups: Optional[int] = None, - processor: Optional["AttnProcessor"] = None, - ): - super().__init__() - inner_dim = dim_head * heads - cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - self.cross_attention_norm = cross_attention_norm - - self.scale = dim_head**-0.5 - - self.heads = heads - # for slice_size > 0 the attention score computation - # is split across the batch axis to save memory - # You can set slice_size with `set_attention_slice` - self.sliceable_head_dim = heads - - self.added_kv_proj_dim = added_kv_proj_dim - - if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) - else: - self.group_norm = None - - if cross_attention_norm: - self.norm_cross = nn.LayerNorm(cross_attention_dim) - - self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) - - if self.added_kv_proj_dim is not None: - self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) - self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) - - self.to_out = nn.ModuleList([]) - self.to_out.append(nn.Linear(inner_dim, query_dim)) - self.to_out.append(nn.Dropout(dropout)) - - # set attention processor - # We use the AttnProcessor2_0 by default when torch2.x is used which uses - # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention - if processor is None: - processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else 
CrossAttnProcessor() - self.set_processor(processor) - - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None - ): - is_lora = hasattr(self, "processor") and isinstance( - self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor) - ) - - if use_memory_efficient_attention_xformers: - if self.added_kv_proj_dim is not None: - # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP - # which uses this type of cross attention ONLY because the attention mask of format - # [0, ..., -10.000, ..., 0, ...,] is not supported - raise NotImplementedError( - "Memory efficient attention with `xformers` is currently not supported when" - " `self.added_kv_proj_dim` is defined." - ) - elif not is_xformers_available(): - raise ModuleNotFoundError( - ( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers" - ), - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" - " only available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - - if is_lora: - processor = LoRAXFormersCrossAttnProcessor( - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - rank=self.processor.rank, - attention_op=attention_op, - ) - processor.load_state_dict(self.processor.state_dict()) - processor.to(self.processor.to_q_lora.up.weight.device) - else: - processor = XFormersCrossAttnProcessor(attention_op=attention_op) - else: - if is_lora: - processor = LoRACrossAttnProcessor( - hidden_size=self.processor.hidden_size, - cross_attention_dim=self.processor.cross_attention_dim, - rank=self.processor.rank, - ) - processor.load_state_dict(self.processor.state_dict()) - processor.to(self.processor.to_q_lora.up.weight.device) - else: - processor = CrossAttnProcessor() - - self.set_processor(processor) - - def set_attention_slice(self, slice_size): - if slice_size is not None and slice_size > self.sliceable_head_dim: - raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") - - if slice_size is not None and self.added_kv_proj_dim is not None: - processor = SlicedAttnAddedKVProcessor(slice_size) - elif slice_size is not None: - processor = SlicedAttnProcessor(slice_size) - elif self.added_kv_proj_dim is not None: - processor = CrossAttnAddedKVProcessor() - else: - processor = CrossAttnProcessor() - - self.set_processor(processor) - - def set_processor(self, processor: "AttnProcessor"): - # if current processor is in `self._modules` and if passed `processor` is not, we need to - # pop `processor` from `self._modules` - if ( - hasattr(self, "processor") - and isinstance(self.processor, torch.nn.Module) - and not isinstance(processor, torch.nn.Module) - ): - logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") - self._modules.pop("processor") - - self.processor = processor - - def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): - # The `CrossAttention` class can call different attention processors / 
attention functions - # here we simply pass along all tensors to the selected processor class - # For standard processors that are defined here, `**cross_attention_kwargs` is empty - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - def batch_to_head_dim(self, tensor): - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def head_to_batch_dim(self, tensor): - head_size = self.heads - batch_size, seq_len, dim = tensor.shape - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def get_attention_scores(self, query, key, attention_mask=None): - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - - if attention_mask is None: - baddbmm_input = torch.empty( - query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device - ) - beta = 0 - else: - baddbmm_input = attention_mask - beta = 1 - - attention_scores = torch.baddbmm( - baddbmm_input, - query, - key.transpose(-1, -2), - beta=beta, - alpha=self.scale, - ) - - if self.upcast_softmax: - attention_scores = attention_scores.float() - - attention_probs = attention_scores.softmax(dim=-1) - attention_probs = attention_probs.to(dtype) - - return attention_probs - - def prepare_attention_mask(self, attention_mask, target_length, batch_size=None): - if batch_size is None: - deprecate( - "batch_size=None", - "0.0.15", - message=( - "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect" - " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to" - " `prepare_attention_mask` when preparing the attention_mask." - ), - ) - batch_size = 1 - - head_size = self.heads - if attention_mask is None: - return attention_mask - - if attention_mask.shape[-1] != target_length: - if attention_mask.device.type == "mps": - # HACK: MPS: Does not support padding by greater than dimension of input tensor. - # Instead, we can manually construct the padding tensor. 
- padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) - padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) - attention_mask = torch.cat([attention_mask, padding], dim=2) - else: - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - - if attention_mask.shape[0] < batch_size * head_size: - attention_mask = attention_mask.repeat_interleave(head_size, dim=0) - return attention_mask - - -class CrossAttnProcessor: - def __call__( - self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class LoRALinearLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4): - super().__init__() - - if rank > min(in_features, out_features): - raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") - - self.down = nn.Linear(in_features, rank, bias=False) - self.up = nn.Linear(rank, out_features, bias=False) - - nn.init.normal_(self.down.weight, std=1 / rank) - nn.init.zeros_(self.up.weight) - - def forward(self, hidden_states): - orig_dtype = hidden_states.dtype - dtype = self.down.weight.dtype - - down_hidden_states = self.down(hidden_states.to(dtype)) - up_hidden_states = self.up(down_hidden_states) - - return up_hidden_states.to(orig_dtype) - - -class LoRACrossAttnProcessor(nn.Module): - def __init__(self, hidden_size, cross_attention_dim=None, rank=4): - super().__init__() - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.rank = rank - - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - - def __call__( - self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 - ): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) - query = attn.head_to_batch_dim(query) - - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - - key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) - - key = attn.head_to_batch_dim(key) - value = 
attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class CrossAttnAddedKVProcessor: - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): - residual = hidden_states - hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) - batch_size, sequence_length, _ = hidden_states.shape - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - query = attn.head_to_batch_dim(query) - - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) - encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) - - key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual - - return hidden_states - - -class XFormersCrossAttnProcessor: - def __init__(self, attention_op: Optional[Callable] = None): - self.attention_op = attention_op - - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = hidden_states.shape - - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query).contiguous() - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - hidden_states = xformers.ops.memory_efficient_attention( - query, key, value, attn_bias=attention_mask, op=self.attention_op - ) - hidden_states = hidden_states.to(query.dtype) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -class AttnProcessor2_0: - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please 
upgrade PyTorch to 2.0.") - - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, inner_dim = hidden_states.shape - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - head_dim = inner_dim // attn.heads - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -class LoRAXFormersCrossAttnProcessor(nn.Module): - def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None): - super().__init__() - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.rank = rank - self.attention_op = attention_op - - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - - def __call__( - self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 - ): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) - query = attn.head_to_batch_dim(query).contiguous() - - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - - key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) - - key = attn.head_to_batch_dim(key).contiguous() - value = attn.head_to_batch_dim(value).contiguous() - - hidden_states = xformers.ops.memory_efficient_attention( - query, key, value, attn_bias=attention_mask, op=self.attention_op - ) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class SlicedAttnProcessor: - def __init__(self, slice_size): - self.slice_size = slice_size - - def __call__(self, attn: CrossAttention, 
hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = hidden_states.shape - - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - batch_size_attention = query.shape[0] - hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype - ) - - for i in range(hidden_states.shape[0] // self.slice_size): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size - - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class SlicedAttnAddedKVProcessor: - def __init__(self, slice_size): - self.slice_size = slice_size - - def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): - residual = hidden_states - hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) - encoder_hidden_states = encoder_hidden_states.transpose(1, 2) - - batch_size, sequence_length, _ = hidden_states.shape - - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) - encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) - - key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) - - batch_size_attention = query.shape[0] - hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype - ) - - for i in range(hidden_states.shape[0] // self.slice_size): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size - - query_slice = query[start_idx:end_idx] - key_slice = key[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) - - attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) - - 
hidden_states[start_idx:end_idx] = attn_slice - - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) - hidden_states = hidden_states + residual - - return hidden_states - - -AttnProcessor = Union[ - CrossAttnProcessor, - XFormersCrossAttnProcessor, - SlicedAttnProcessor, - CrossAttnAddedKVProcessor, - SlicedAttnAddedKVProcessor, - LoRACrossAttnProcessor, - LoRAXFormersCrossAttnProcessor, -] diff --git a/my_diffusers/models/dual_transformer_2d.py b/my_diffusers/models/dual_transformer_2d.py deleted file mode 100644 index 8b805c98147c82f674b55c448aa86332a1b04932..0000000000000000000000000000000000000000 --- a/my_diffusers/models/dual_transformer_2d.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -from torch import nn - -from .transformer_2d import Transformer2DModel, Transformer2DModelOutput - - -class DualTransformer2DModel(nn.Module): - """ - Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - Pass if the input is continuous. The number of channels in the input and output. - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. - sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. - Note that this is fixed at training time as it is used for learning a number of position embeddings. See - `ImagePositionalEmbeddings`. - num_vector_embeds (`int`, *optional*): - Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. - The number of diffusion steps used during training. Note that this is fixed at training time as it is used - to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for - up to but not more than steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the TransformerBlocks' attention should contain a bias parameter. 
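To put rough numbers on the memory trade-off made by the sliced processors above, here is a back-of-the-envelope sketch; the sizes are hypothetical:

batch, heads, seq_len, slice_size = 2, 8, 4096, 4   # e.g. one 64x64 latent feature map
full_scores   = batch * heads * seq_len * seq_len   # unsliced: ~268M score elements live at once
sliced_scores = slice_size * seq_len * seq_len      # per iteration: ~67M elements
n_steps       = (batch * heads) // slice_size       # SlicedAttnProcessor loops 4 times instead
# a CrossAttention module usually ends up with this processor via set_attention_slice(slice_size)
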
- """ - - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - num_vector_embeds: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - ): - super().__init__() - self.transformers = nn.ModuleList( - [ - Transformer2DModel( - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - in_channels=in_channels, - num_layers=num_layers, - dropout=dropout, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attention_bias=attention_bias, - sample_size=sample_size, - num_vector_embeds=num_vector_embeds, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - ) - for _ in range(2) - ] - ) - - # Variables that can be set by a pipeline: - - # The ratio of transformer1 to transformer2's output states to be combined during inference - self.mix_ratio = 0.5 - - # The shape of `encoder_hidden_states` is expected to be - # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` - self.condition_lengths = [77, 257] - - # Which transformer to use to encode which condition. - # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` - self.transformer_index_for_condition = [1, 0] - - def forward( - self, - hidden_states, - encoder_hidden_states, - timestep=None, - attention_mask=None, - cross_attention_kwargs=None, - return_dict: bool = True, - ): - """ - Args: - hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. - When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input - hidden_states - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.long`, *optional*): - Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. - attention_mask (`torch.FloatTensor`, *optional*): - Optional attention mask to be applied in CrossAttention - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - - Returns: - [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: - [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. 
- """ - input_states = hidden_states - - encoded_states = [] - tokens_start = 0 - # attention_mask is not used yet - for i in range(2): - # for each of the two transformers, pass the corresponding condition tokens - condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] - transformer_index = self.transformer_index_for_condition[i] - encoded_state = self.transformers[transformer_index]( - input_states, - encoder_hidden_states=condition_state, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - return_dict=False, - )[0] - encoded_states.append(encoded_state - input_states) - tokens_start += self.condition_lengths[i] - - output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) - output_states = output_states + input_states - - if not return_dict: - return (output_states,) - - return Transformer2DModelOutput(sample=output_states) diff --git a/my_diffusers/models/embeddings.py b/my_diffusers/models/embeddings.py deleted file mode 100644 index d12e75344ba1fc7e5703be7dd32609e248465c5c..0000000000000000000000000000000000000000 --- a/my_diffusers/models/embeddings.py +++ /dev/null @@ -1,379 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math -from typing import Optional - -import numpy as np -import torch -from torch import nn - - -def get_timestep_embedding( - timesteps: torch.Tensor, - embedding_dim: int, - flip_sin_to_cos: bool = False, - downscale_freq_shift: float = 1, - scale: float = 1, - max_period: int = 10000, -): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the - embeddings. :return: an [N x dim] Tensor of positional embeddings. 
- """ - assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" - - half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) - exponent = exponent / (half_dim - downscale_freq_shift) - - emb = torch.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] - - # scale embeddings - emb = scale * emb - - # concat sine and cosine embeddings - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) - - # flip sine and cosine embeddings - if flip_sin_to_cos: - emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) - - # zero pad - if embedding_dim % 2 == 1: - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): - """ - grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or - [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid_h = np.arange(grid_size, dtype=np.float32) - grid_w = np.arange(grid_size, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size, grid_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token and extra_tokens > 0: - pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) - """ - if embed_dim % 2 != 0: - raise ValueError("embed_dim must be divisible by 2") - - omega = np.arange(embed_dim // 2, dtype=np.float64) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - -class PatchEmbed(nn.Module): - """2D Image to Patch Embedding""" - - def __init__( - self, - height=224, - width=224, - patch_size=16, - in_channels=3, - embed_dim=768, - layer_norm=False, - flatten=True, - bias=True, - ): - super().__init__() - - num_patches = (height // patch_size) * (width // patch_size) - self.flatten = flatten - self.layer_norm = layer_norm - - self.proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias - ) - if layer_norm: - self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) - else: - self.norm = None - - pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) - self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) - - def forward(self, latent): - latent = self.proj(latent) - if self.flatten: - latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC - if self.layer_norm: - latent = self.norm(latent) - return latent + self.pos_embed - - -class 
TimestepEmbedding(nn.Module): - def __init__( - self, - in_channels: int, - time_embed_dim: int, - act_fn: str = "silu", - out_dim: int = None, - post_act_fn: Optional[str] = None, - cond_proj_dim=None, - ): - super().__init__() - - self.linear_1 = nn.Linear(in_channels, time_embed_dim) - - if cond_proj_dim is not None: - self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) - else: - self.cond_proj = None - - if act_fn == "silu": - self.act = nn.SiLU() - elif act_fn == "mish": - self.act = nn.Mish() - elif act_fn == "gelu": - self.act = nn.GELU() - else: - raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") - - if out_dim is not None: - time_embed_dim_out = out_dim - else: - time_embed_dim_out = time_embed_dim - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) - - if post_act_fn is None: - self.post_act = None - elif post_act_fn == "silu": - self.post_act = nn.SiLU() - elif post_act_fn == "mish": - self.post_act = nn.Mish() - elif post_act_fn == "gelu": - self.post_act = nn.GELU() - else: - raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") - - def forward(self, sample, condition=None): - if condition is not None: - sample = sample + self.cond_proj(condition) - sample = self.linear_1(sample) - - if self.act is not None: - sample = self.act(sample) - - sample = self.linear_2(sample) - - if self.post_act is not None: - sample = self.post_act(sample) - return sample - - -class Timesteps(nn.Module): - def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): - super().__init__() - self.num_channels = num_channels - self.flip_sin_to_cos = flip_sin_to_cos - self.downscale_freq_shift = downscale_freq_shift - - def forward(self, timesteps): - t_emb = get_timestep_embedding( - timesteps, - self.num_channels, - flip_sin_to_cos=self.flip_sin_to_cos, - downscale_freq_shift=self.downscale_freq_shift, - ) - return t_emb - - -class GaussianFourierProjection(nn.Module): - """Gaussian Fourier embeddings for noise levels.""" - - def __init__( - self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False - ): - super().__init__() - self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) - self.log = log - self.flip_sin_to_cos = flip_sin_to_cos - - if set_W_to_weight: - # to delete later - self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) - - self.weight = self.W - - def forward(self, x): - if self.log: - x = torch.log(x) - - x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi - - if self.flip_sin_to_cos: - out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) - else: - out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) - return out - - -class ImagePositionalEmbeddings(nn.Module): - """ - Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the - height and width of the latent space. - - For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 - - For VQ-diffusion: - - Output vector embeddings are used as input for the transformer. - - Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. - - Args: - num_embed (`int`): - Number of embeddings for the latent pixels embeddings. - height (`int`): - Height of the latent image i.e. the number of height embeddings. 
- width (`int`): - Width of the latent image i.e. the number of width embeddings. - embed_dim (`int`): - Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. - """ - - def __init__( - self, - num_embed: int, - height: int, - width: int, - embed_dim: int, - ): - super().__init__() - - self.height = height - self.width = width - self.num_embed = num_embed - self.embed_dim = embed_dim - - self.emb = nn.Embedding(self.num_embed, embed_dim) - self.height_emb = nn.Embedding(self.height, embed_dim) - self.width_emb = nn.Embedding(self.width, embed_dim) - - def forward(self, index): - emb = self.emb(index) - - height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) - - # 1 x H x D -> 1 x H x 1 x D - height_emb = height_emb.unsqueeze(2) - - width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) - - # 1 x W x D -> 1 x 1 x W x D - width_emb = width_emb.unsqueeze(1) - - pos_emb = height_emb + width_emb - - # 1 x H x W x D -> 1 x L xD - pos_emb = pos_emb.view(1, self.height * self.width, -1) - - emb = emb + pos_emb[:, : emb.shape[1], :] - - return emb - - -class LabelEmbedding(nn.Module): - """ - Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. - - Args: - num_classes (`int`): The number of classes. - hidden_size (`int`): The size of the vector embeddings. - dropout_prob (`float`): The probability of dropping a label. - """ - - def __init__(self, num_classes, hidden_size, dropout_prob): - super().__init__() - use_cfg_embedding = dropout_prob > 0 - self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) - self.num_classes = num_classes - self.dropout_prob = dropout_prob - - def token_drop(self, labels, force_drop_ids=None): - """ - Drops labels to enable classifier-free guidance. - """ - if force_drop_ids is None: - drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob - else: - drop_ids = torch.tensor(force_drop_ids == 1) - labels = torch.where(drop_ids, self.num_classes, labels) - return labels - - def forward(self, labels, force_drop_ids=None): - use_dropout = self.dropout_prob > 0 - if (self.training and use_dropout) or (force_drop_ids is not None): - labels = self.token_drop(labels, force_drop_ids) - embeddings = self.embedding_table(labels) - return embeddings - - -class CombinedTimestepLabelEmbeddings(nn.Module): - def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): - super().__init__() - - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) - - def forward(self, timestep, class_labels, hidden_dtype=None): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) - - class_labels = self.class_embedder(class_labels) # (N, D) - - conditioning = timesteps_emb + class_labels # (N, D) - - return conditioning diff --git a/my_diffusers/models/embeddings_flax.py b/my_diffusers/models/embeddings_flax.py deleted file mode 100644 index 88c2c45e4655b8013fa96e0b4408e3ec0a87c2c7..0000000000000000000000000000000000000000 --- a/my_diffusers/models/embeddings_flax.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math - -import flax.linen as nn -import jax.numpy as jnp - - -def get_sinusoidal_embeddings( - timesteps: jnp.ndarray, - embedding_dim: int, - freq_shift: float = 1, - min_timescale: float = 1, - max_timescale: float = 1.0e4, - flip_sin_to_cos: bool = False, - scale: float = 1.0, -) -> jnp.ndarray: - """Returns the positional encoding (same as Tensor2Tensor). - - Args: - timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - embedding_dim: The number of output channels. - min_timescale: The smallest time unit (should probably be 0.0). - max_timescale: The largest time unit. - Returns: - a Tensor of timing signals [N, num_channels] - """ - assert timesteps.ndim == 1, "Timesteps should be a 1d-array" - assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even" - num_timescales = float(embedding_dim // 2) - log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift) - inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) - emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) - - # scale embeddings - scaled_time = scale * emb - - if flip_sin_to_cos: - signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1) - else: - signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1) - signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) - return signal - - -class FlaxTimestepEmbedding(nn.Module): - r""" - Time step Embedding Module. Learns embeddings for input time steps. - - Args: - time_embed_dim (`int`, *optional*, defaults to `32`): - Time step embedding dimension - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - time_embed_dim: int = 32 - dtype: jnp.dtype = jnp.float32 - - @nn.compact - def __call__(self, temb): - temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) - temb = nn.silu(temb) - temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) - return temb - - -class FlaxTimesteps(nn.Module): - r""" - Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 - - Args: - dim (`int`, *optional*, defaults to `32`): - Time step embedding dimension - """ - dim: int = 32 - flip_sin_to_cos: bool = False - freq_shift: float = 1 - - @nn.compact - def __call__(self, timesteps): - return get_sinusoidal_embeddings( - timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift - ) diff --git a/my_diffusers/models/modeling_flax_pytorch_utils.py b/my_diffusers/models/modeling_flax_pytorch_utils.py deleted file mode 100644 index f9de83f87dab84d2e7fdd77b835db787cb4f1cb6..0000000000000000000000000000000000000000 --- a/my_diffusers/models/modeling_flax_pytorch_utils.py +++ /dev/null @@ -1,118 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. 
team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch - Flax general utilities.""" -import re - -import jax.numpy as jnp -from flax.traverse_util import flatten_dict, unflatten_dict -from jax.random import PRNGKey - -from ..utils import logging - - -logger = logging.get_logger(__name__) - - -def rename_key(key): - regex = r"\w+[.]\d+" - pats = re.findall(regex, key) - for pat in pats: - key = key.replace(pat, "_".join(pat.split("."))) - return key - - -##################### -# PyTorch => Flax # -##################### - - -# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 -# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py -def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): - """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" - - # conv norm or layer norm - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) - if ( - any("norm" in str_ for str_ in pt_tuple_key) - and (pt_tuple_key[-1] == "bias") - and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict) - and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict) - ): - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) - return renamed_pt_tuple_key, pt_tensor - elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) - return renamed_pt_tuple_key, pt_tensor - - # embedding - if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict: - pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) - return renamed_pt_tuple_key, pt_tensor - - # conv layer - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) - if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: - pt_tensor = pt_tensor.transpose(2, 3, 1, 0) - return renamed_pt_tuple_key, pt_tensor - - # linear layer - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) - if pt_tuple_key[-1] == "weight": - pt_tensor = pt_tensor.T - return renamed_pt_tuple_key, pt_tensor - - # old PyTorch layer norm weight - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) - if pt_tuple_key[-1] == "gamma": - return renamed_pt_tuple_key, pt_tensor - - # old PyTorch layer norm bias - renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",) - if pt_tuple_key[-1] == "beta": - return renamed_pt_tuple_key, pt_tensor - - return pt_tuple_key, pt_tensor - - -def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): - # Step 1: Convert pytorch tensor to numpy - pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} - - # Step 2: Since the model is stateless, get random Flax params - random_flax_params = flax_model.init_weights(PRNGKey(init_key)) - - random_flax_state_dict = flatten_dict(random_flax_params) - flax_state_dict = {} - - # Need to change some parameters name to 
match Flax names - for pt_key, pt_tensor in pt_state_dict.items(): - renamed_pt_key = rename_key(pt_key) - pt_tuple_key = tuple(renamed_pt_key.split(".")) - - # Correctly rename weight parameters - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict) - - if flax_key in random_flax_state_dict: - if flax_tensor.shape != random_flax_state_dict[flax_key].shape: - raise ValueError( - f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " - f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." - ) - - # also add unexpected weight so that warning is thrown - flax_state_dict[flax_key] = jnp.asarray(flax_tensor) - - return unflatten_dict(flax_state_dict) diff --git a/my_diffusers/models/modeling_flax_utils.py b/my_diffusers/models/modeling_flax_utils.py deleted file mode 100644 index 58c492a974a30624cbf1bd638ad3ed5202ef3862..0000000000000000000000000000000000000000 --- a/my_diffusers/models/modeling_flax_utils.py +++ /dev/null @@ -1,526 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from pickle import UnpicklingError -from typing import Any, Dict, Union - -import jax -import jax.numpy as jnp -import msgpack.exceptions -from flax.core.frozen_dict import FrozenDict, unfreeze -from flax.serialization import from_bytes, to_bytes -from flax.traverse_util import flatten_dict, unflatten_dict -from huggingface_hub import hf_hub_download -from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError -from requests import HTTPError - -from .. import __version__, is_torch_available -from ..utils import ( - CONFIG_NAME, - DIFFUSERS_CACHE, - FLAX_WEIGHTS_NAME, - HUGGINGFACE_CO_RESOLVE_ENDPOINT, - WEIGHTS_NAME, - logging, -) -from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax - - -logger = logging.get_logger(__name__) - - -class FlaxModelMixin: - r""" - Base class for all flax models. - - [`FlaxModelMixin`] takes care of storing the configuration of the models and handles methods for loading, - downloading and saving models. - """ - config_name = CONFIG_NAME - _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] - _flax_internal_args = ["name", "parent", "dtype"] - - @classmethod - def _from_config(cls, config, **kwargs): - """ - All context managers that the model should be initialized under go here. - """ - return cls(config, **kwargs) - - def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: - """ - Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. 
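Most of `rename_key_and_reshape_tensor` and the conversion loop above reduces to two layout conventions: a PyTorch `weight` becomes a Flax `kernel`, transposed from OIHW to HWIO for convolutions and transposed outright for linear layers (norm scales and embeddings get their own names). A NumPy sketch of just those reshapes; `pt_weight_to_flax_kernel` is a hypothetical helper, not the library function:

```python
import numpy as np


def pt_weight_to_flax_kernel(name: str, tensor: np.ndarray):
    """Map a PyTorch 'weight' tensor onto the Flax 'kernel' naming and layout convention."""
    if tensor.ndim == 4:
        # conv: PyTorch (out, in, H, W) -> Flax (H, W, in, out)
        return name.replace("weight", "kernel"), tensor.transpose(2, 3, 1, 0)
    # linear: PyTorch (out, in) -> Flax (in, out)
    return name.replace("weight", "kernel"), tensor.T


conv_w = np.zeros((64, 32, 3, 3), dtype=np.float32)  # OIHW
linear_w = np.zeros((128, 256), dtype=np.float32)    # (out, in)

print(pt_weight_to_flax_kernel("conv_in.weight", conv_w)[1].shape)     # (3, 3, 32, 64)
print(pt_weight_to_flax_kernel("proj_out.weight", linear_w)[1].shape)  # (256, 128)
```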
- """ - - # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 - def conditional_cast(param): - if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): - param = param.astype(dtype) - return param - - if mask is None: - return jax.tree_map(conditional_cast, params) - - flat_params = flatten_dict(params) - flat_mask, _ = jax.tree_flatten(mask) - - for masked, key in zip(flat_mask, flat_params.keys()): - if masked: - param = flat_params[key] - flat_params[key] = conditional_cast(param) - - return unflatten_dict(flat_params) - - def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): - r""" - Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast - the `params` in place. - - This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full - half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. - - Arguments: - params (`Union[Dict, FrozenDict]`): - A `PyTree` of model parameters. - mask (`Union[Dict, FrozenDict]`): - A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params - you want to cast, and should be `False` for those you want to skip. - - Examples: - - ```python - >>> from diffusers import FlaxUNet2DConditionModel - - >>> # load model - >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") - >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision - >>> params = model.to_bf16(params) - >>> # If you don't want to cast certain parameters (for example layer norm bias and scale) - >>> # then pass the mask as follows - >>> from flax import traverse_util - - >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") - >>> flat_params = traverse_util.flatten_dict(params) - >>> mask = { - ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) - ... for path in flat_params - ... } - >>> mask = traverse_util.unflatten_dict(mask) - >>> params = model.to_bf16(params, mask) - ```""" - return self._cast_floating_to(params, jnp.bfloat16, mask) - - def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): - r""" - Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the - model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place. - - Arguments: - params (`Union[Dict, FrozenDict]`): - A `PyTree` of model parameters. - mask (`Union[Dict, FrozenDict]`): - A `PyTree` with same structure as the `params` tree. 
The leaves should be booleans, `True` for params - you want to cast, and should be `False` for those you want to skip - - Examples: - - ```python - >>> from diffusers import FlaxUNet2DConditionModel - - >>> # Download model and configuration from huggingface.co - >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") - >>> # By default, the model params will be in fp32, to illustrate the use of this method, - >>> # we'll first cast to fp16 and back to fp32 - >>> params = model.to_f16(params) - >>> # now cast back to fp32 - >>> params = model.to_fp32(params) - ```""" - return self._cast_floating_to(params, jnp.float32, mask) - - def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): - r""" - Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the - `params` in place. - - This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full - half-precision training or to save weights in float16 for inference in order to save memory and improve speed. - - Arguments: - params (`Union[Dict, FrozenDict]`): - A `PyTree` of model parameters. - mask (`Union[Dict, FrozenDict]`): - A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params - you want to cast, and should be `False` for those you want to skip - - Examples: - - ```python - >>> from diffusers import FlaxUNet2DConditionModel - - >>> # load model - >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") - >>> # By default, the model params will be in fp32, to cast these to float16 - >>> params = model.to_fp16(params) - >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) - >>> # then pass the mask as follows - >>> from flax import traverse_util - - >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") - >>> flat_params = traverse_util.flatten_dict(params) - >>> mask = { - ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) - ... for path in flat_params - ... } - >>> mask = traverse_util.unflatten_dict(mask) - >>> params = model.to_fp16(params, mask) - ```""" - return self._cast_floating_to(params, jnp.float16, mask) - - def init_weights(self, rng: jax.random.KeyArray) -> Dict: - raise NotImplementedError(f"init_weights method has to be implemented for {self}") - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Union[str, os.PathLike], - dtype: jnp.dtype = jnp.float32, - *model_args, - **kwargs, - ): - r""" - Instantiate a pretrained flax model from a pre-trained model configuration. - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. - - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids are namespaced under a user or organization name, like - `runwayml/stable-diffusion-v1-5`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], - e.g., `./my_model_directory/`. 
- dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): - The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. - - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - - If you wish to change the dtype of the model parameters, see [`~ModelMixin.to_fp16`] and - [`~ModelMixin.to_bf16`]. - model_args (sequence of positional arguments, *optional*): - All remaining positional arguments will be passed to the underlying model's `__init__` method. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - from_pt (`bool`, *optional*, defaults to `False`): - Load the model weights from a PyTorch checkpoint save file. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). Behaves differently depending on whether a `config` is provided or - automatically loaded: - - - If a configuration is provided with `config`, `**kwargs` will be directly passed to the - underlying model's `__init__` method (we assume all relevant updates to the configuration have - already been done) - - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that corresponds to - a configuration attribute will be used to override said attribute with the supplied `kwargs` - value. Remaining keys that do not correspond to any configuration attribute will be passed to the - underlying model's `__init__` function. - - Examples: - - ```python - >>> from diffusers import FlaxUNet2DConditionModel - - >>> # Download model and configuration from huggingface.co and cache. - >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") - >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). 
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/") - ```""" - config = kwargs.pop("config", None) - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - force_download = kwargs.pop("force_download", False) - from_pt = kwargs.pop("from_pt", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - - user_agent = { - "diffusers": __version__, - "file_type": "model", - "framework": "flax", - } - - # Load config if we don't provide a configuration - config_path = config if config is not None else pretrained_model_name_or_path - model, model_kwargs = cls.from_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - # model args - dtype=dtype, - **kwargs, - ) - - # Load model - pretrained_path_with_subfolder = ( - pretrained_model_name_or_path - if subfolder is None - else os.path.join(pretrained_model_name_or_path, subfolder) - ) - if os.path.isdir(pretrained_path_with_subfolder): - if from_pt: - if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): - raise EnvironmentError( - f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} " - ) - model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME) - elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)): - # Load from a Flax checkpoint - model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME) - # Check if pytorch weights exist instead - elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): - raise EnvironmentError( - f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model" - " using `from_pt=True`." - ) - else: - raise EnvironmentError( - f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " - f"{pretrained_path_with_subfolder}." - ) - else: - try: - model_file = hf_hub_download( - pretrained_model_name_or_path, - filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - user_agent=user_agent, - subfolder=subfolder, - revision=revision, - ) - - except RepositoryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " - "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " - "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " - "login`." - ) - except RevisionNotFoundError: - raise EnvironmentError( - f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " - "this model name. Check the model page at " - f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." - ) - except EntryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}." 
- ) - except HTTPError as err: - raise EnvironmentError( - f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n" - f"{err}" - ) - except ValueError: - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" - f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" - f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your" - " internet connection or see how to run the library in offline mode at" - " 'https://huggingface.co/docs/transformers/installation#offline-mode'." - ) - except EnvironmentError: - raise EnvironmentError( - f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " - "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." - ) - - if from_pt: - if is_torch_available(): - from .modeling_utils import load_state_dict - else: - raise EnvironmentError( - "Can't load the model in PyTorch format because PyTorch is not installed. " - "Please, install PyTorch or use native Flax weights." - ) - - # Step 1: Get the pytorch file - pytorch_model_file = load_state_dict(model_file) - - # Step 2: Convert the weights - state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model) - else: - try: - with open(model_file, "rb") as state_f: - state = from_bytes(cls, state_f.read()) - except (UnpicklingError, msgpack.exceptions.ExtraData) as e: - try: - with open(model_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please" - " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" - " folder you cloned." - ) - else: - raise ValueError from e - except (UnicodeDecodeError, ValueError): - raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") - # make sure all arrays are stored as jnp.ndarray - # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: - # https://github.com/google/flax/issues/1261 - state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state) - - # flatten dicts - state = flatten_dict(state) - - params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0)) - required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) - - shape_state = flatten_dict(unfreeze(params_shape_tree)) - - missing_keys = required_params - set(state.keys()) - unexpected_keys = set(state.keys()) - required_params - - if missing_keys: - logger.warning( - f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " - "Make sure to call model.init_weights to initialize the missing weights." - ) - cls._missing_keys = missing_keys - - for key in state.keys(): - if key in shape_state and state[key].shape != shape_state[key].shape: - raise ValueError( - f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " - f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. 
" - ) - - # remove unexpected keys to not be saved again - for unexpected_key in unexpected_keys: - del state[unexpected_key] - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" - f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" - f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" - " with another architecture." - ) - else: - logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") - - if len(missing_keys) > 0: - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" - " TRAIN this model on a down-stream task to be able to use it for predictions and inference." - ) - else: - logger.info( - f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" - f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" - f" was trained on, you can already use {model.__class__.__name__} for predictions without further" - " training." - ) - - return model, unflatten_dict(state) - - def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - params: Union[Dict, FrozenDict], - is_main_process: bool = True, - ): - """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the - `[`~FlaxModelMixin.from_pretrained`]` class method - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - params (`Union[Dict, FrozenDict]`): - A `PyTree` of model parameters. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful when in distributed training like - TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on - the main process to avoid race conditions. - """ - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - os.makedirs(save_directory, exist_ok=True) - - model_to_save = self - - # Attach architecture to the config - # Save the config - if is_main_process: - model_to_save.save_config(save_directory) - - # save model - output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME) - with open(output_model_file, "wb") as f: - model_bytes = to_bytes(params) - f.write(model_bytes) - - logger.info(f"Model weights saved in {output_model_file}") diff --git a/my_diffusers/models/modeling_pytorch_flax_utils.py b/my_diffusers/models/modeling_pytorch_flax_utils.py deleted file mode 100644 index 80975641412b79e5f75040c14a5a18df429acfe6..0000000000000000000000000000000000000000 --- a/my_diffusers/models/modeling_pytorch_flax_utils.py +++ /dev/null @@ -1,155 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch - Flax general utilities.""" - -from pickle import UnpicklingError - -import jax -import jax.numpy as jnp -import numpy as np -from flax.serialization import from_bytes -from flax.traverse_util import flatten_dict - -from ..utils import logging - - -logger = logging.get_logger(__name__) - - -##################### -# Flax => PyTorch # -##################### - - -# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352 -def load_flax_checkpoint_in_pytorch_model(pt_model, model_file): - try: - with open(model_file, "rb") as flax_state_f: - flax_state = from_bytes(None, flax_state_f.read()) - except UnpicklingError as e: - try: - with open(model_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please" - " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" - " folder you cloned." - ) - else: - raise ValueError from e - except (UnicodeDecodeError, ValueError): - raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") - - return load_flax_weights_in_pytorch_model(pt_model, flax_state) - - -def load_flax_weights_in_pytorch_model(pt_model, flax_state): - """Load flax checkpoints in a PyTorch model""" - - try: - import torch # noqa: F401 - except ImportError: - logger.error( - "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see" - " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" - " instructions." - ) - raise - - # check if we have bf16 weights - is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() - if any(is_type_bf16): - # convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16 - - # and bf16 is not fully supported in PT yet. - logger.warning( - "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` " - "before loading those in PyTorch model." 
- ) - flax_state = jax.tree_util.tree_map( - lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state - ) - - pt_model.base_model_prefix = "" - - flax_state_dict = flatten_dict(flax_state, sep=".") - pt_model_dict = pt_model.state_dict() - - # keep track of unexpected & missing keys - unexpected_keys = [] - missing_keys = set(pt_model_dict.keys()) - - for flax_key_tuple, flax_tensor in flax_state_dict.items(): - flax_key_tuple_array = flax_key_tuple.split(".") - - if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4: - flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] - flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1)) - elif flax_key_tuple_array[-1] == "kernel": - flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] - flax_tensor = flax_tensor.T - elif flax_key_tuple_array[-1] == "scale": - flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] - - if "time_embedding" not in flax_key_tuple_array: - for i, flax_key_tuple_string in enumerate(flax_key_tuple_array): - flax_key_tuple_array[i] = ( - flax_key_tuple_string.replace("_0", ".0") - .replace("_1", ".1") - .replace("_2", ".2") - .replace("_3", ".3") - ) - - flax_key = ".".join(flax_key_tuple_array) - - if flax_key in pt_model_dict: - if flax_tensor.shape != pt_model_dict[flax_key].shape: - raise ValueError( - f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected " - f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}." - ) - else: - # add weight to pytorch dict - flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor - pt_model_dict[flax_key] = torch.from_numpy(flax_tensor) - # remove from missing keys - missing_keys.remove(flax_key) - else: - # weight is not expected by PyTorch model - unexpected_keys.append(flax_key) - - pt_model.load_state_dict(pt_model_dict) - - # re-transform missing_keys to list - missing_keys = list(missing_keys) - - if len(unexpected_keys) > 0: - logger.warning( - "Some weights of the Flax model were not used when initializing the PyTorch model" - f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing" - f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture" - " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This" - f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect" - " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a" - " FlaxBertForSequenceClassification model)." - ) - if len(missing_keys) > 0: - logger.warning( - f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly" - f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to" - " use it for predictions and inference." - ) - - return pt_model diff --git a/my_diffusers/models/modeling_utils.py b/my_diffusers/models/modeling_utils.py deleted file mode 100644 index 87686595f766fbff85e064059e1b46850a23059f..0000000000000000000000000000000000000000 --- a/my_diffusers/models/modeling_utils.py +++ /dev/null @@ -1,907 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
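Going the other way, `load_flax_weights_in_pytorch_model` above renames Flax `kernel`/`scale` entries back to PyTorch `weight`, transposes conv kernels from HWIO to OIHW, and rewrites Flax-style `_0` block suffixes as PyTorch's `.0` indices (the real function additionally special-cases time-embedding layers). A compact sketch of those rules with a hypothetical helper:

```python
import re

import numpy as np
import torch


def flax_entry_to_pt(key: str, tensor: np.ndarray):
    """Apply the Flax -> PyTorch naming and layout conventions to one checkpoint entry."""
    parts = key.split(".")
    if parts[-1] == "kernel" and tensor.ndim == 4:
        parts[-1] = "weight"
        tensor = tensor.transpose(3, 2, 0, 1)  # conv: HWIO -> OIHW
    elif parts[-1] == "kernel":
        parts[-1] = "weight"
        tensor = tensor.T                      # dense: (in, out) -> (out, in)
    elif parts[-1] == "scale":
        parts[-1] = "weight"                   # norm scale -> weight
    # Flax names repeated blocks like "resnets_0"; PyTorch ModuleLists use "resnets.0".
    new_key = ".".join(re.sub(r"_(\d+)$", r".\1", p) for p in parts)
    return new_key, torch.from_numpy(np.ascontiguousarray(tensor))


key, weight = flax_entry_to_pt("down_blocks_0.resnets_0.conv1.kernel", np.zeros((3, 3, 32, 64), np.float32))
print(key, tuple(weight.shape))  # down_blocks.0.resnets.0.conv1.weight (64, 32, 3, 3)
```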
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import os -import warnings -from functools import partial -from typing import Callable, List, Optional, Tuple, Union - -import torch -from huggingface_hub import hf_hub_download -from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError -from packaging import version -from requests import HTTPError -from torch import Tensor, device - -from .. import __version__ -from ..utils import ( - CONFIG_NAME, - DEPRECATED_REVISION_ARGS, - DIFFUSERS_CACHE, - FLAX_WEIGHTS_NAME, - HF_HUB_OFFLINE, - HUGGINGFACE_CO_RESOLVE_ENDPOINT, - SAFETENSORS_WEIGHTS_NAME, - WEIGHTS_NAME, - is_accelerate_available, - is_safetensors_available, - is_torch_version, - logging, -) - - -logger = logging.get_logger(__name__) - - -if is_torch_version(">=", "1.9.0"): - _LOW_CPU_MEM_USAGE_DEFAULT = True -else: - _LOW_CPU_MEM_USAGE_DEFAULT = False - - -if is_accelerate_available(): - import accelerate - from accelerate.utils import set_module_tensor_to_device - from accelerate.utils.versions import is_torch_version - -if is_safetensors_available(): - import safetensors - - -def get_parameter_device(parameter: torch.nn.Module): - try: - return next(parameter.parameters()).device - except StopIteration: - # For torch.nn.DataParallel compatibility in PyTorch 1.5 - - def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = parameter._named_members(get_members_fn=find_tensor_attributes) - first_tuple = next(gen) - return first_tuple[1].device - - -def get_parameter_dtype(parameter: torch.nn.Module): - try: - return next(parameter.parameters()).dtype - except StopIteration: - # For torch.nn.DataParallel compatibility in PyTorch 1.5 - - def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = parameter._named_members(get_members_fn=find_tensor_attributes) - first_tuple = next(gen) - return first_tuple[1].dtype - - -def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): - """ - Reads a checkpoint file, returning properly formatted errors if they arise. - """ - try: - if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant): - return torch.load(checkpoint_file, map_location="cpu") - else: - return safetensors.torch.load_file(checkpoint_file, device="cpu") - except Exception as e: - try: - with open(checkpoint_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please install " - "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " - "you cloned." - ) - else: - raise ValueError( - f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " - "model. 
Make sure you have saved the model properly." - ) from e - except (UnicodeDecodeError, ValueError): - raise OSError( - f"Unable to load weights from checkpoint file for '{checkpoint_file}' " - f"at '{checkpoint_file}'. " - "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." - ) - - -def _load_state_dict_into_model(model_to_load, state_dict): - # Convert old format to new format if needed from a PyTorch state_dict - # copy state_dict so _load_from_state_dict can modify it - state_dict = state_dict.copy() - error_msgs = [] - - # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants - # so we need to apply the function recursively. - def load(module: torch.nn.Module, prefix=""): - args = (state_dict, prefix, {}, True, [], [], error_msgs) - module._load_from_state_dict(*args) - - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - - load(model_to_load) - - return error_msgs - - -def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: - if variant is not None: - splits = weights_name.split(".") - splits = splits[:-1] + [variant] + splits[-1:] - weights_name = ".".join(splits) - - return weights_name - - -class ModelMixin(torch.nn.Module): - r""" - Base class for all models. - - [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading - and saving models. - - - **config_name** ([`str`]) -- A filename under which the model should be stored when calling - [`~models.ModelMixin.save_pretrained`]. - """ - config_name = CONFIG_NAME - _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] - _supports_gradient_checkpointing = False - - def __init__(self): - super().__init__() - - @property - def is_gradient_checkpointing(self) -> bool: - """ - Whether gradient checkpointing is activated for this model or not. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - """ - return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) - - def enable_gradient_checkpointing(self): - """ - Activates gradient checkpointing for the current model. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - """ - if not self._supports_gradient_checkpointing: - raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") - self.apply(partial(self._set_gradient_checkpointing, value=True)) - - def disable_gradient_checkpointing(self): - """ - Deactivates gradient checkpointing for the current model. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - """ - if self._supports_gradient_checkpointing: - self.apply(partial(self._set_gradient_checkpointing, value=False)) - - def set_use_memory_efficient_attention_xformers( - self, valid: bool, attention_op: Optional[Callable] = None - ) -> None: - # Recursively walk through all the children. 
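`_add_variant` above simply splices a variant tag in front of the file extension, which is how variant weight files such as fp16 checkpoints get their names. A quick sketch of the same splice (a standalone re-implementation, not an import from the library):

```python
from typing import Optional


def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    """Insert a variant tag before the file extension, e.g. '...model.fp16.bin'."""
    if variant is None:
        return weights_name
    parts = weights_name.split(".")
    return ".".join(parts[:-1] + [variant] + parts[-1:])


print(add_variant("diffusion_pytorch_model.bin", "fp16"))        # diffusion_pytorch_model.fp16.bin
print(add_variant("diffusion_pytorch_model.safetensors", None))  # unchanged
```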
- # Any children which exposes the set_use_memory_efficient_attention_xformers method - # gets the message - def fn_recursive_set_mem_eff(module: torch.nn.Module): - if hasattr(module, "set_use_memory_efficient_attention_xformers"): - module.set_use_memory_efficient_attention_xformers(valid, attention_op) - - for child in module.children(): - fn_recursive_set_mem_eff(child) - - for module in self.children(): - if isinstance(module, torch.nn.Module): - fn_recursive_set_mem_eff(module) - - def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): - r""" - Enable memory efficient attention as implemented in xformers. - - When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference - time. Speed up at training time is not guaranteed. - - Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention - is used. - - Parameters: - attention_op (`Callable`, *optional*): - Override the default `None` operator for use as `op` argument to the - [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention) - function of xFormers. - - Examples: - - ```py - >>> import torch - >>> from diffusers import UNet2DConditionModel - >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp - - >>> model = UNet2DConditionModel.from_pretrained( - ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16 - ... ) - >>> model = model.to("cuda") - >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) - ``` - """ - self.set_use_memory_efficient_attention_xformers(True, attention_op) - - def disable_xformers_memory_efficient_attention(self): - r""" - Disable memory efficient attention as implemented in xformers. - """ - self.set_use_memory_efficient_attention_xformers(False) - - def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - is_main_process: bool = True, - save_function: Callable = None, - safe_serialization: bool = False, - variant: Optional[str] = None, - ): - """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the - `[`~models.ModelMixin.from_pretrained`]` class method. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful when in distributed training like - TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on - the main process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful on distributed training like TPUs when one - need to replace `torch.save` by another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `False`): - Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). - variant (`str`, *optional*): - If specified, weights are saved in the format pytorch_model..bin. 
- """ - if safe_serialization and not is_safetensors_available(): - raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.") - - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - os.makedirs(save_directory, exist_ok=True) - - model_to_save = self - - # Attach architecture to the config - # Save the config - if is_main_process: - model_to_save.save_config(save_directory) - - # Save the model - state_dict = model_to_save.state_dict() - - weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME - weights_name = _add_variant(weights_name, variant) - - # Save the model - if safe_serialization: - safetensors.torch.save_file( - state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"} - ) - else: - torch.save(state_dict, os.path.join(save_directory, weights_name)) - - logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}") - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): - r""" - Instantiate a pretrained pytorch model from a pre-trained model configuration. - - The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train - the model, you should first set it back in training mode with `model.train()`. - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. - - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., - `./my_model_directory/`. - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype - will be automatically derived from the model's weights. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). 
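The saving branch above boils down to: pick a weights filename based on `safe_serialization` (and `variant`), then write the state dict either with `safetensors` or with `torch.save`. A condensed sketch of that decision, using a toy module and illustrative file names rather than the library's constants:

```python
import os

import torch
import torch.nn as nn


def save_weights(model: nn.Module, save_directory: str, safe_serialization: bool = False) -> str:
    """Write a state dict either as a .safetensors file or as a pickled .bin file."""
    os.makedirs(save_directory, exist_ok=True)
    state_dict = model.state_dict()

    if safe_serialization:
        import safetensors.torch  # optional dependency

        path = os.path.join(save_directory, "model.safetensors")
        safetensors.torch.save_file(state_dict, path, metadata={"format": "pt"})
    else:
        path = os.path.join(save_directory, "model.bin")
        torch.save(state_dict, path)
    return path


print(save_weights(nn.Linear(4, 4), "/tmp/toy_model"))
```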
- use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `diffusers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - from_flax (`bool`, *optional*, defaults to `False`): - Load the model weights from a Flax checkpoint save file. - subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo (either remote in - huggingface.co or downloaded locally), you can specify the folder name here. - - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. - device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): - A map that specifies where each submodule should go. It doesn't need to be refined to each - parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the - same device. - - To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For - more information about each option see [designing a device - map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading by not initializing the weights and only loading the pre-trained weights. This - also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the - model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, - setting this argument to `True` will raise an error. - variant (`str`, *optional*): - If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is - ignored when using `from_flax`. - - - - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). - - - - - - Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use - this method in a firewalled environment. 
- - - - """ - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) - force_download = kwargs.pop("force_download", False) - from_flax = kwargs.pop("from_flax", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - output_loading_info = kwargs.pop("output_loading_info", False) - local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - torch_dtype = kwargs.pop("torch_dtype", None) - subfolder = kwargs.pop("subfolder", None) - device_map = kwargs.pop("device_map", None) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - variant = kwargs.pop("variant", None) - - if low_cpu_mem_usage and not is_accelerate_available(): - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" - " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" - " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" - " install accelerate\n```\n." - ) - - if device_map is not None and not is_accelerate_available(): - raise NotImplementedError( - "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" - " `device_map=None`. You can install accelerate with `pip install accelerate`." - ) - - # Check if we can handle device_map and dispatching the weights - if device_map is not None and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `device_map=None`." - ) - - if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `low_cpu_mem_usage=False`." - ) - - if low_cpu_mem_usage is False and device_map is not None: - raise ValueError( - f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" - " dispatching. Please make sure to set `low_cpu_mem_usage=True`." - ) - - user_agent = { - "diffusers": __version__, - "file_type": "model", - "framework": "pytorch", - } - - # Load config if we don't provide a configuration - config_path = pretrained_model_name_or_path - - # This variable will flag if we're loading a sharded checkpoint. 
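The keyword arguments popped above correspond to the loading options documented in the docstring; in practice most calls set only a handful of them. A usage sketch, assuming network access (or a local cache) and reusing the example model id from the docstrings; the final line additionally assumes a CUDA device:

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # example model id from the docstrings above
    subfolder="unet",                  # the UNet lives in a subfolder of the repo
    torch_dtype=torch.float16,         # cast the loaded weights after loading
    low_cpu_mem_usage=True,            # requires `accelerate`
)
unet = unet.to("cuda")  # assumes a CUDA device is available
```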
In this case the archive file is just the - # Load model - - model_file = None - if from_flax: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=FLAX_WEIGHTS_NAME, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - config, unused_kwargs = cls.load_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - device_map=device_map, - **kwargs, - ) - model = cls.from_config(config, **unused_kwargs) - - # Convert the weights - from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model - - model = load_flax_checkpoint_in_pytorch_model(model, model_file) - else: - if is_safetensors_available(): - try: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - except: # noqa: E722 - pass - if model_file is None: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=_add_variant(WEIGHTS_NAME, variant), - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - - if low_cpu_mem_usage: - # Instantiate model with empty weights - with accelerate.init_empty_weights(): - config, unused_kwargs = cls.load_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - device_map=device_map, - **kwargs, - ) - model = cls.from_config(config, **unused_kwargs) - - # if device_map is None, load the state dict and move the params from meta device to the cpu - if device_map is None: - param_device = "cpu" - state_dict = load_state_dict(model_file, variant=variant) - # move the params from meta device to cpu - missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) - if len(missing_keys) > 0: - raise ValueError( - f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" - f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" - " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize" - " those weights or else make sure your checkpoint file is correct." - ) - - for param_name, param in state_dict.items(): - accepts_dtype = "dtype" in set( - inspect.signature(set_module_tensor_to_device).parameters.keys() - ) - if accepts_dtype: - set_module_tensor_to_device( - model, param_name, param_device, value=param, dtype=torch_dtype - ) - else: - set_module_tensor_to_device(model, param_name, param_device, value=param) - else: # else let accelerate handle loading and dispatching. 
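When `low_cpu_mem_usage` is enabled, the branch above instantiates the model on the meta device via `accelerate.init_empty_weights` and then materializes parameters one by one with `set_module_tensor_to_device`. A minimal sketch of that pattern on a toy module (not a diffusers model):

```python
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Build the module without allocating real storage for its parameters.
with init_empty_weights():
    model = nn.Linear(1024, 1024)

print(model.weight.device)  # meta

# Materialize each parameter on CPU from a (here: freshly created) state dict.
state_dict = nn.Linear(1024, 1024).state_dict()
for name, tensor in state_dict.items():
    set_module_tensor_to_device(model, name, "cpu", value=tensor)

print(model.weight.device)  # cpu
```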
- # Load weights and dispatch according to the device_map - # by deafult the device_map is None and the weights are loaded on the CPU - accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype) - - loading_info = { - "missing_keys": [], - "unexpected_keys": [], - "mismatched_keys": [], - "error_msgs": [], - } - else: - config, unused_kwargs = cls.load_config( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - device_map=device_map, - **kwargs, - ) - model = cls.from_config(config, **unused_kwargs) - - state_dict = load_state_dict(model_file, variant=variant) - - model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( - model, - state_dict, - model_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=ignore_mismatched_sizes, - ) - - loading_info = { - "missing_keys": missing_keys, - "unexpected_keys": unexpected_keys, - "mismatched_keys": mismatched_keys, - "error_msgs": error_msgs, - } - - if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): - raise ValueError( - f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." - ) - elif torch_dtype is not None: - model = model.to(torch_dtype) - - model.register_to_config(_name_or_path=pretrained_model_name_or_path) - - # Set model in evaluation mode to deactivate DropOut modules by default - model.eval() - if output_loading_info: - return model, loading_info - - return model - - @classmethod - def _load_pretrained_model( - cls, - model, - state_dict, - resolved_archive_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=False, - ): - # Retrieve missing & unexpected_keys - model_state_dict = model.state_dict() - loaded_keys = [k for k in state_dict.keys()] - - expected_keys = list(model_state_dict.keys()) - - original_loaded_keys = loaded_keys - - missing_keys = list(set(expected_keys) - set(loaded_keys)) - unexpected_keys = list(set(loaded_keys) - set(expected_keys)) - - # Make sure we are able to load base models as well as derived models (with heads) - model_to_load = model - - def _find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - ignore_mismatched_sizes, - ): - mismatched_keys = [] - if ignore_mismatched_sizes: - for checkpoint_key in loaded_keys: - model_key = checkpoint_key - - if ( - model_key in model_state_dict - and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape - ): - mismatched_keys.append( - (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) - ) - del state_dict[checkpoint_key] - return mismatched_keys - - if state_dict is not None: - # Whole checkpoint - mismatched_keys = _find_mismatched_keys( - state_dict, - model_state_dict, - original_loaded_keys, - ignore_mismatched_sizes, - ) - error_msgs = _load_state_dict_into_model(model_to_load, state_dict) - - if len(error_msgs) > 0: - error_msg = "\n\t".join(error_msgs) - if "size mismatch" in error_msg: - error_msg += ( - "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
- ) - raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") - - if len(unexpected_keys) > 0: - logger.warning( - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" - f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" - f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task" - " or with another architecture (e.g. initializing a BertForSequenceClassification model from a" - " BertForPreTraining model).\n- This IS NOT expected if you are initializing" - f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly" - " identical (initializing a BertForSequenceClassification model from a" - " BertForSequenceClassification model)." - ) - else: - logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") - if len(missing_keys) > 0: - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" - " TRAIN this model on a down-stream task to be able to use it for predictions and inference." - ) - elif len(mismatched_keys) == 0: - logger.info( - f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" - f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the" - f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions" - " without further training." - ) - if len(mismatched_keys) > 0: - mismatched_warning = "\n".join( - [ - f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" - for key, shape1, shape2 in mismatched_keys - ] - ) - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" - f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be" - " able to use it for predictions and inference." - ) - - return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs - - @property - def device(self) -> device: - """ - `torch.device`: The device on which the module is (assuming that all the module parameters are on the same - device). - """ - return get_parameter_device(self) - - @property - def dtype(self) -> torch.dtype: - """ - `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). - """ - return get_parameter_dtype(self) - - def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: - """ - Get number of (optionally, trainable or non-embeddings) parameters in the module. - - Args: - only_trainable (`bool`, *optional*, defaults to `False`): - Whether or not to return only the number of trainable parameters - - exclude_embeddings (`bool`, *optional*, defaults to `False`): - Whether or not to return only the number of non-embeddings parameters - - Returns: - `int`: The number of parameters. 
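For orientation, a minimal usage sketch of the loading path shown above. This is a hedged example, not part of the diff: the repository id and `subfolder` are placeholders, and the import assumes the `my_diffusers` package layout used in this patch. The keyword arguments mirror the ones popped at the top of `from_pretrained` (`torch_dtype`, `low_cpu_mem_usage`, `output_loading_info`), and `device`, `dtype`, and `num_parameters` are the helpers documented in this hunk.

```python
# Hedged sketch only: checkpoint id and subfolder are hypothetical placeholders.
import torch
from my_diffusers import UNet2DConditionModel  # a ModelMixin subclass from this package

model, loading_info = UNet2DConditionModel.from_pretrained(
    "some-org/some-diffusion-checkpoint",  # hypothetical Hub repo id
    subfolder="unet",
    torch_dtype=torch.float16,      # must be a torch.dtype; applied after the weights load
    low_cpu_mem_usage=True,         # requires `accelerate` and torch >= 1.9.0
    output_loading_info=True,       # additionally return missing/unexpected/mismatched keys
)
print(loading_info["missing_keys"], loading_info["unexpected_keys"])
print(model.device, model.dtype, model.num_parameters(only_trainable=False))
```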
- """ - - if exclude_embeddings: - embedding_param_names = [ - f"{name}.weight" - for name, module_type in self.named_modules() - if isinstance(module_type, torch.nn.Embedding) - ] - non_embedding_parameters = [ - parameter for name, parameter in self.named_parameters() if name not in embedding_param_names - ] - return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) - else: - return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) - - -def _get_model_file( - pretrained_model_name_or_path, - *, - weights_name, - subfolder, - cache_dir, - force_download, - proxies, - resume_download, - local_files_only, - use_auth_token, - user_agent, - revision, -): - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isfile(pretrained_model_name_or_path): - return pretrained_model_name_or_path - elif os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): - # Load from a PyTorch checkpoint - model_file = os.path.join(pretrained_model_name_or_path, weights_name) - return model_file - elif subfolder is not None and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, weights_name) - ): - model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name) - return model_file - else: - raise EnvironmentError( - f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." - ) - else: - # 1. First check if deprecated way of loading from branches is used - if ( - revision in DEPRECATED_REVISION_ARGS - and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME) - and version.parse(version.parse(__version__).base_version) >= version.parse("0.15.0") - ): - try: - model_file = hf_hub_download( - pretrained_model_name_or_path, - filename=_add_variant(weights_name, revision), - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - user_agent=user_agent, - subfolder=subfolder, - revision=revision, - ) - warnings.warn( - f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.", - FutureWarning, - ) - return model_file - except: # noqa: E722 - warnings.warn( - f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name)}' so that the correct variant file can be added.", - FutureWarning, - ) - try: - # 2. 
Load model file as usual - model_file = hf_hub_download( - pretrained_model_name_or_path, - filename=weights_name, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - user_agent=user_agent, - subfolder=subfolder, - revision=revision, - ) - return model_file - - except RepositoryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " - "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " - "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " - "login`." - ) - except RevisionNotFoundError: - raise EnvironmentError( - f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " - "this model name. Check the model page at " - f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." - ) - except EntryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}." - ) - except HTTPError as err: - raise EnvironmentError( - f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" - ) - except ValueError: - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" - f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" - f" directory containing a file named {weights_name} or" - " \nCheckout your internet connection or see how to run the library in" - " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." - ) - except EnvironmentError: - raise EnvironmentError( - f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " - "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing a file named {weights_name}" - ) diff --git a/my_diffusers/models/prior_transformer.py b/my_diffusers/models/prior_transformer.py deleted file mode 100644 index b245612e6fc16800cd6f0cb2560d681f1360d60b..0000000000000000000000000000000000000000 --- a/my_diffusers/models/prior_transformer.py +++ /dev/null @@ -1,194 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Union - -import torch -import torch.nn.functional as F -from torch import nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from .attention import BasicTransformerBlock -from .embeddings import TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin - - -@dataclass -class PriorTransformerOutput(BaseOutput): - """ - Args: - predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): - The predicted CLIP image embedding conditioned on the CLIP text embedding input. - """ - - predicted_image_embedding: torch.FloatTensor - - -class PriorTransformer(ModelMixin, ConfigMixin): - """ - The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the - transformer predicts the image embeddings through a denoising diffusion process. - - This model inherits from [`ModelMixin`]. 
Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) - - For more details, see the original paper: https://arxiv.org/abs/2204.06125 - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. - num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. - embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP - image embeddings and text embeddings are both the same dimension. - num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the - length of the prompt after it has been tokenized. - additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the - projected hidden_states. The actual length of the used hidden_states is `num_embeddings + - additional_embeddings`. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - - """ - - @register_to_config - def __init__( - self, - num_attention_heads: int = 32, - attention_head_dim: int = 64, - num_layers: int = 20, - embedding_dim: int = 768, - num_embeddings=77, - additional_embeddings=4, - dropout: float = 0.0, - ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - self.additional_embeddings = additional_embeddings - - self.time_proj = Timesteps(inner_dim, True, 0) - self.time_embedding = TimestepEmbedding(inner_dim, inner_dim) - - self.proj_in = nn.Linear(embedding_dim, inner_dim) - - self.embedding_proj = nn.Linear(embedding_dim, inner_dim) - self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) - - self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim)) - - self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - activation_fn="gelu", - attention_bias=True, - ) - for d in range(num_layers) - ] - ) - - self.norm_out = nn.LayerNorm(inner_dim) - self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim) - - causal_attention_mask = torch.full( - [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0 - ) - causal_attention_mask.triu_(1) - causal_attention_mask = causal_attention_mask[None, ...] - self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False) - - self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim)) - self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim)) - - def forward( - self, - hidden_states, - timestep: Union[torch.Tensor, float, int], - proj_embedding: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor, - attention_mask: Optional[torch.BoolTensor] = None, - return_dict: bool = True, - ): - """ - Args: - hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): - x_t, the currently predicted image embeddings. - timestep (`torch.long`): - Current denoising step. - proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): - Projected embedding vector the denoising process is conditioned on. 
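As a rough illustration of the constructor and `forward` signature documented above, here is a hedged sketch with a deliberately tiny, hypothetical configuration (real checkpoints use far larger dimensions); the tensor shapes follow the argument descriptions in this docstring and may need minor adjustment against the surrounding attention API.

```python
# Sketch only: tiny config chosen for illustration, not a real checkpoint configuration.
import torch
from my_diffusers.models.prior_transformer import PriorTransformer

prior = PriorTransformer(
    num_attention_heads=2, attention_head_dim=4, num_layers=2,
    embedding_dim=8, num_embeddings=7, additional_embeddings=4,
)
batch = 2
out = prior(
    hidden_states=torch.randn(batch, 8),             # x_t, current image-embedding estimate
    timestep=10,                                     # current denoising step
    proj_embedding=torch.randn(batch, 8),            # projected conditioning embedding
    encoder_hidden_states=torch.randn(batch, 7, 8),  # text token hidden states
    attention_mask=torch.ones(batch, 7, dtype=torch.bool),
)
assert out.predicted_image_embedding.shape == (batch, 8)
```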
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`): - Hidden states of the text embeddings the denoising process is conditioned on. - attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): - Text mask for the text embeddings. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain - tuple. - - Returns: - [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: - [`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - batch_size = hidden_states.shape[0] - - timesteps = timestep - if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(hidden_states.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device) - - timesteps_projected = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might be fp16, so we need to cast here. - timesteps_projected = timesteps_projected.to(dtype=self.dtype) - time_embeddings = self.time_embedding(timesteps_projected) - - proj_embeddings = self.embedding_proj(proj_embedding) - encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) - hidden_states = self.proj_in(hidden_states) - prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) - positional_embeddings = self.positional_embedding.to(hidden_states.dtype) - - hidden_states = torch.cat( - [ - encoder_hidden_states, - proj_embeddings[:, None, :], - time_embeddings[:, None, :], - hidden_states[:, None, :], - prd_embedding, - ], - dim=1, - ) - - hidden_states = hidden_states + positional_embeddings - - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 - attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0) - attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) - attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) - - for block in self.transformer_blocks: - hidden_states = block(hidden_states, attention_mask=attention_mask) - - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states[:, -1] - predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) - - if not return_dict: - return (predicted_image_embedding,) - - return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding) - - def post_process_latents(self, prior_latents): - prior_latents = (prior_latents * self.clip_std) + self.clip_mean - return prior_latents diff --git a/my_diffusers/models/resnet.py b/my_diffusers/models/resnet.py deleted file mode 100644 index 7c14a7c4832dc399fc653fff9182ed3ebf3f6045..0000000000000000000000000000000000000000 --- a/my_diffusers/models/resnet.py +++ /dev/null @@ -1,766 +0,0 @@ -from functools import partial -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .attention import AdaGroupNorm - - -class Upsample1D(nn.Module): - """ - An upsampling 
layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - use_conv_transpose: - out_channels: - """ - - def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - - self.conv = None - if use_conv_transpose: - self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.use_conv_transpose: - return self.conv(x) - - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - - if self.use_conv: - x = self.conv(x) - - return x - - -class Downsample1D(nn.Module): - """ - A downsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - out_channels: - padding: - """ - - def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - - if use_conv: - self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding) - else: - assert self.channels == self.out_channels - self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.conv(x) - - -class Upsample2D(nn.Module): - """ - An upsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - use_conv_transpose: - out_channels: - """ - - def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - - conv = None - if use_conv_transpose: - conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if name == "conv": - self.conv = conv - else: - self.Conv2d_0 = conv - - def forward(self, hidden_states, output_size=None): - assert hidden_states.shape[1] == self.channels - - if self.use_conv_transpose: - return self.conv(hidden_states) - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch - # https://github.com/pytorch/pytorch/issues/86679 - dtype = hidden_states.dtype - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(torch.float32) - - # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - - # if `output_size` is passed we force the interpolation output - # size and do not make use of `scale_factor=2` - if output_size is None: - hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") - else: - hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(dtype) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if self.use_conv: - if self.name == "conv": - hidden_states = self.conv(hidden_states) - else: - hidden_states = self.Conv2d_0(hidden_states) - - return hidden_states - - -class Downsample2D(nn.Module): - """ - A downsampling layer with an optional convolution. - - Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - out_channels: - padding: - """ - - def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - - if use_conv: - conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding) - else: - assert self.channels == self.out_channels - conv = nn.AvgPool2d(kernel_size=stride, stride=stride) - - # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed - if name == "conv": - self.Conv2d_0 = conv - self.conv = conv - elif name == "Conv2d_0": - self.conv = conv - else: - self.conv = conv - - def forward(self, hidden_states): - assert hidden_states.shape[1] == self.channels - if self.use_conv and self.padding == 0: - pad = (0, 1, 0, 1) - hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) - - assert hidden_states.shape[1] == self.channels - hidden_states = self.conv(hidden_states) - - return hidden_states - - -class FirUpsample2D(nn.Module): - def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) - self.use_conv = use_conv - self.fir_kernel = fir_kernel - self.out_channels = out_channels - - def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): - """Fused `upsample_2d()` followed by `Conv2d()`. - - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of - arbitrary order. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - weight: Weight tensor of the shape `[filterH, filterW, inChannels, - outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). 
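A small shape sketch for the `Upsample2D` and `Downsample2D` blocks defined above; the import path assumes this diff's module layout, and the channel/spatial sizes are arbitrary.

```python
# Sketch: nearest-neighbor x2 upsampling and strided-conv downsampling on (N, C, H, W) tensors.
import torch
from my_diffusers.models.resnet import Upsample2D, Downsample2D

x = torch.randn(1, 32, 16, 16)
up = Upsample2D(32, use_conv=True)                 # interpolate by 2, then 3x3 conv
down = Downsample2D(32, use_conv=True, padding=1)  # 3x3 conv with stride 2

assert up(x).shape == (1, 32, 32, 32)
assert down(x).shape == (1, 32, 8, 8)
```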
- - Returns: - output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same - datatype as `hidden_states`. - """ - - assert isinstance(factor, int) and factor >= 1 - - # Setup filter kernel. - if kernel is None: - kernel = [1] * factor - - # setup kernel - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - - kernel = kernel * (gain * (factor**2)) - - if self.use_conv: - convH = weight.shape[2] - convW = weight.shape[3] - inC = weight.shape[1] - - pad_value = (kernel.shape[0] - factor) - (convW - 1) - - stride = (factor, factor) - # Determine data dimensions. - output_shape = ( - (hidden_states.shape[2] - 1) * factor + convH, - (hidden_states.shape[3] - 1) * factor + convW, - ) - output_padding = ( - output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, - output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, - ) - assert output_padding[0] >= 0 and output_padding[1] >= 0 - num_groups = hidden_states.shape[1] // inC - - # Transpose weights. - weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) - weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) - weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) - - inverse_conv = F.conv_transpose2d( - hidden_states, weight, stride=stride, output_padding=output_padding, padding=0 - ) - - output = upfirdn2d_native( - inverse_conv, - torch.tensor(kernel, device=inverse_conv.device), - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), - ) - else: - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - torch.tensor(kernel, device=hidden_states.device), - up=factor, - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), - ) - - return output - - def forward(self, hidden_states): - if self.use_conv: - height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) - height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) - - return height - - -class FirDownsample2D(nn.Module): - def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): - super().__init__() - out_channels = out_channels if out_channels else channels - if use_conv: - self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) - self.fir_kernel = fir_kernel - self.use_conv = use_conv - self.out_channels = out_channels - - def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): - """Fused `Conv2d()` followed by `downsample_2d()`. - Padding is performed only once at the beginning, not between the operations. The fused op is considerably more - efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of - arbitrary order. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - weight: - Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be - performed by `inChannels = x.shape[0] // numGroups`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * - factor`, which corresponds to average pooling. - factor: Integer downsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). 
- - Returns: - output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and - same datatype as `x`. - """ - - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - # setup kernel - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - - kernel = kernel * gain - - if self.use_conv: - _, _, convH, convW = weight.shape - pad_value = (kernel.shape[0] - factor) + (convW - 1) - stride_value = [factor, factor] - upfirdn_input = upfirdn2d_native( - hidden_states, - torch.tensor(kernel, device=hidden_states.device), - pad=((pad_value + 1) // 2, pad_value // 2), - ) - output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) - else: - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - torch.tensor(kernel, device=hidden_states.device), - down=factor, - pad=((pad_value + 1) // 2, pad_value // 2), - ) - - return output - - def forward(self, hidden_states): - if self.use_conv: - downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) - hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) - else: - hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) - - return hidden_states - - -# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead -class KDownsample2D(nn.Module): - def __init__(self, pad_mode="reflect"): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) - self.pad = kernel_1d.shape[1] // 2 - 1 - self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) - - def forward(self, x): - x = F.pad(x, (self.pad,) * 4, self.pad_mode) - weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) - indices = torch.arange(x.shape[1], device=x.device) - weight[indices, indices] = self.kernel.to(weight) - return F.conv2d(x, weight, stride=2) - - -class KUpsample2D(nn.Module): - def __init__(self, pad_mode="reflect"): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 - self.pad = kernel_1d.shape[1] // 2 - 1 - self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) - - def forward(self, x): - x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) - weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) - indices = torch.arange(x.shape[1], device=x.device) - weight[indices, indices] = self.kernel.to(weight) - return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) - - -class ResnetBlock2D(nn.Module): - r""" - A Resnet block. - - Parameters: - in_channels (`int`): The number of channels in the input. - out_channels (`int`, *optional*, default to be `None`): - The number of output channels for the first conv2d layer. If None, same as `in_channels`. - dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. - temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. - groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. - groups_out (`int`, *optional*, default to None): - The number of groups to use for the second normalization layer. if set to None, same as `groups`. 
- eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. - non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. - time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. - By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or - "ada_group" for a stronger conditioning with scale and shift. - kernal (`torch.FloatTensor`, optional, default to None): FIR filter, see - [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. - output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. - use_in_shortcut (`bool`, *optional*, default to `True`): - If `True`, add a 1x1 nn.conv2d layer for skip-connection. - up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. - down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. - conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the - `conv_shortcut` output. - conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. - If None, same as `out_channels`. - """ - - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout=0.0, - temb_channels=512, - groups=32, - groups_out=None, - pre_norm=True, - eps=1e-6, - non_linearity="swish", - time_embedding_norm="default", # default, scale_shift, ada_group - kernel=None, - output_scale_factor=1.0, - use_in_shortcut=None, - up=False, - down=False, - conv_shortcut_bias: bool = True, - conv_2d_out_channels: Optional[int] = None, - ): - super().__init__() - self.pre_norm = pre_norm - self.pre_norm = True - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.up = up - self.down = down - self.output_scale_factor = output_scale_factor - self.time_embedding_norm = time_embedding_norm - - if groups_out is None: - groups_out = groups - - if self.time_embedding_norm == "ada_group": - self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) - else: - self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - - self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - - if temb_channels is not None: - if self.time_embedding_norm == "default": - self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) - elif self.time_embedding_norm == "scale_shift": - self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) - elif self.time_embedding_norm == "ada_group": - self.time_emb_proj = None - else: - raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") - else: - self.time_emb_proj = None - - if self.time_embedding_norm == "ada_group": - self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) - else: - self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) - - self.dropout = torch.nn.Dropout(dropout) - conv_2d_out_channels = conv_2d_out_channels or out_channels - self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) - - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = nn.Mish() - elif non_linearity == 
"silu": - self.nonlinearity = nn.SiLU() - elif non_linearity == "gelu": - self.nonlinearity = nn.GELU() - - self.upsample = self.downsample = None - if self.up: - if kernel == "fir": - fir_kernel = (1, 3, 3, 1) - self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) - elif kernel == "sde_vp": - self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") - else: - self.upsample = Upsample2D(in_channels, use_conv=False) - elif self.down: - if kernel == "fir": - fir_kernel = (1, 3, 3, 1) - self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) - elif kernel == "sde_vp": - self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) - else: - self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") - - self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut - - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = torch.nn.Conv2d( - in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias - ) - - def forward(self, input_tensor, temb): - hidden_states = input_tensor - - if self.time_embedding_norm == "ada_group": - hidden_states = self.norm1(hidden_states, temb) - else: - hidden_states = self.norm1(hidden_states) - - hidden_states = self.nonlinearity(hidden_states) - - if self.upsample is not None: - # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - input_tensor = input_tensor.contiguous() - hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor) - hidden_states = self.upsample(hidden_states) - elif self.downsample is not None: - input_tensor = self.downsample(input_tensor) - hidden_states = self.downsample(hidden_states) - - hidden_states = self.conv1(hidden_states) - - if self.time_emb_proj is not None: - temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] - - if temb is not None and self.time_embedding_norm == "default": - hidden_states = hidden_states + temb - - if self.time_embedding_norm == "ada_group": - hidden_states = self.norm2(hidden_states, temb) - else: - hidden_states = self.norm2(hidden_states) - - if temb is not None and self.time_embedding_norm == "scale_shift": - scale, shift = torch.chunk(temb, 2, dim=1) - hidden_states = hidden_states * (1 + scale) + shift - - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) - - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - - return output_tensor - - -class Mish(torch.nn.Module): - def forward(self, hidden_states): - return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) - - -# unet_rl.py -def rearrange_dims(tensor): - if len(tensor.shape) == 2: - return tensor[:, :, None] - if len(tensor.shape) == 3: - return tensor[:, :, None, :] - elif len(tensor.shape) == 4: - return tensor[:, :, 0, :] - else: - raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") - - -class Conv1dBlock(nn.Module): - """ - Conv1d --> GroupNorm --> Mish - """ - - def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): - super().__init__() - - self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2) - self.group_norm = nn.GroupNorm(n_groups, out_channels) - 
self.mish = nn.Mish() - - def forward(self, x): - x = self.conv1d(x) - x = rearrange_dims(x) - x = self.group_norm(x) - x = rearrange_dims(x) - x = self.mish(x) - return x - - -# unet_rl.py -class ResidualTemporalBlock1D(nn.Module): - def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): - super().__init__() - self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) - self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) - - self.time_emb_act = nn.Mish() - self.time_emb = nn.Linear(embed_dim, out_channels) - - self.residual_conv = ( - nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() - ) - - def forward(self, x, t): - """ - Args: - x : [ batch_size x inp_channels x horizon ] - t : [ batch_size x embed_dim ] - - returns: - out : [ batch_size x out_channels x horizon ] - """ - t = self.time_emb_act(t) - t = self.time_emb(t) - out = self.conv_in(x) + rearrange_dims(t) - out = self.conv_out(out) - return out + self.residual_conv(x) - - -def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): - r"""Upsample2D a batch of 2D images with the given filter. - Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given - filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified - `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is - a: multiple of the upsampling factor. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. - factor: Integer upsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). - - Returns: - output: Tensor of the shape `[N, C, H * factor, W * factor]` - """ - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - - kernel = kernel * (gain * (factor**2)) - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, - kernel.to(device=hidden_states.device), - up=factor, - pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), - ) - return output - - -def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): - r"""Downsample2D a batch of 2D images with the given filter. - Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the - given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the - specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its - shape is a multiple of the downsampling factor. - - Args: - hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. - kernel: FIR filter of the shape `[firH, firW]` or `[firN]` - (separable). The default is `[1] * factor`, which corresponds to average pooling. - factor: Integer downsampling factor (default: 2). - gain: Scaling factor for signal magnitude (default: 1.0). 
- - Returns: - output: Tensor of the shape `[N, C, H // factor, W // factor]` - """ - - assert isinstance(factor, int) and factor >= 1 - if kernel is None: - kernel = [1] * factor - - kernel = torch.tensor(kernel, dtype=torch.float32) - if kernel.ndim == 1: - kernel = torch.outer(kernel, kernel) - kernel /= torch.sum(kernel) - - kernel = kernel * gain - pad_value = kernel.shape[0] - factor - output = upfirdn2d_native( - hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2) - ) - return output - - -def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): - up_x = up_y = up - down_x = down_y = down - pad_x0 = pad_y0 = pad[0] - pad_x1 = pad_y1 = pad[1] - - _, channel, in_h, in_w = tensor.shape - tensor = tensor.reshape(-1, in_h, in_w, 1) - - _, in_h, in_w, minor = tensor.shape - kernel_h, kernel_w = kernel.shape - - out = tensor.view(-1, in_h, 1, in_w, 1, minor) - out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) - out = out.view(-1, in_h * up_y, in_w * up_x, minor) - - out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) - out = out.to(tensor.device) # Move back to mps if necessary - out = out[ - :, - max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), - max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), - :, - ] - - out = out.permute(0, 3, 1, 2) - out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) - w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) - out = F.conv2d(out, w) - out = out.reshape( - -1, - minor, - in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, - in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, - ) - out = out.permute(0, 2, 3, 1) - out = out[:, ::down_y, ::down_x, :] - - out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 - out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 - - return out.view(-1, channel, out_h, out_w) diff --git a/my_diffusers/models/resnet_flax.py b/my_diffusers/models/resnet_flax.py deleted file mode 100644 index 9a391f4b947e74beda03f26e376141b2b3c21502..0000000000000000000000000000000000000000 --- a/my_diffusers/models/resnet_flax.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
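The functional FIR helpers above (`upsample_2d`, `downsample_2d`, built on `upfirdn2d_native`) can be exercised roughly as follows; this is a sketch under the same `my_diffusers` layout assumption, with the standard `(1, 3, 3, 1)` FIR kernel used elsewhere in the file.

```python
# Sketch: FIR-filtered x2 upsampling followed by x2 downsampling restores the original shape.
import torch
from my_diffusers.models.resnet import upsample_2d, downsample_2d

x = torch.randn(1, 3, 8, 8)
up = upsample_2d(x, kernel=(1, 3, 3, 1), factor=2)        # -> (1, 3, 16, 16)
down = downsample_2d(up, kernel=(1, 3, 3, 1), factor=2)   # -> (1, 3, 8, 8)
assert up.shape == (1, 3, 16, 16) and down.shape == (1, 3, 8, 8)
```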
-import flax.linen as nn -import jax -import jax.numpy as jnp - - -class FlaxUpsample2D(nn.Module): - out_channels: int - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.conv = nn.Conv( - self.out_channels, - kernel_size=(3, 3), - strides=(1, 1), - padding=((1, 1), (1, 1)), - dtype=self.dtype, - ) - - def __call__(self, hidden_states): - batch, height, width, channels = hidden_states.shape - hidden_states = jax.image.resize( - hidden_states, - shape=(batch, height * 2, width * 2, channels), - method="nearest", - ) - hidden_states = self.conv(hidden_states) - return hidden_states - - -class FlaxDownsample2D(nn.Module): - out_channels: int - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.conv = nn.Conv( - self.out_channels, - kernel_size=(3, 3), - strides=(2, 2), - padding=((1, 1), (1, 1)), # padding="VALID", - dtype=self.dtype, - ) - - def __call__(self, hidden_states): - # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim - # hidden_states = jnp.pad(hidden_states, pad_width=pad) - hidden_states = self.conv(hidden_states) - return hidden_states - - -class FlaxResnetBlock2D(nn.Module): - in_channels: int - out_channels: int = None - dropout_prob: float = 0.0 - use_nin_shortcut: bool = None - dtype: jnp.dtype = jnp.float32 - - def setup(self): - out_channels = self.in_channels if self.out_channels is None else self.out_channels - - self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5) - self.conv1 = nn.Conv( - out_channels, - kernel_size=(3, 3), - strides=(1, 1), - padding=((1, 1), (1, 1)), - dtype=self.dtype, - ) - - self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype) - - self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5) - self.dropout = nn.Dropout(self.dropout_prob) - self.conv2 = nn.Conv( - out_channels, - kernel_size=(3, 3), - strides=(1, 1), - padding=((1, 1), (1, 1)), - dtype=self.dtype, - ) - - use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut - - self.conv_shortcut = None - if use_nin_shortcut: - self.conv_shortcut = nn.Conv( - out_channels, - kernel_size=(1, 1), - strides=(1, 1), - padding="VALID", - dtype=self.dtype, - ) - - def __call__(self, hidden_states, temb, deterministic=True): - residual = hidden_states - hidden_states = self.norm1(hidden_states) - hidden_states = nn.swish(hidden_states) - hidden_states = self.conv1(hidden_states) - - temb = self.time_emb_proj(nn.swish(temb)) - temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) - hidden_states = hidden_states + temb - - hidden_states = self.norm2(hidden_states) - hidden_states = nn.swish(hidden_states) - hidden_states = self.dropout(hidden_states, deterministic) - hidden_states = self.conv2(hidden_states) - - if self.conv_shortcut is not None: - residual = self.conv_shortcut(residual) - - return hidden_states + residual diff --git a/my_diffusers/models/transformer_2d.py b/my_diffusers/models/transformer_2d.py deleted file mode 100644 index 2515c54bc2271fb588b5a4b42737ee3168be5e81..0000000000000000000000000000000000000000 --- a/my_diffusers/models/transformer_2d.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Optional - -import torch -import torch.nn.functional as F -from torch import nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..models.embeddings import ImagePositionalEmbeddings -from ..utils import BaseOutput, deprecate -from .attention import BasicTransformerBlock -from .embeddings import PatchEmbed -from .modeling_utils import ModelMixin - - -@dataclass -class Transformer2DModelOutput(BaseOutput): - """ - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): - Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions - for the unnoised latent pixels. - """ - - sample: torch.FloatTensor - - -class Transformer2DModel(ModelMixin, ConfigMixin): - """ - Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual - embeddings) inputs. - - When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard - transformer action. Finally, reshape to image. - - When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional - embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict - classes of unnoised image. - - Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised - image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - Pass if the input is continuous. The number of channels in the input and output. - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. - sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. - Note that this is fixed at training time as it is used for learning a number of position embeddings. See - `ImagePositionalEmbeddings`. - num_vector_embeds (`int`, *optional*): - Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. - Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. - The number of diffusion steps used during training. 
Note that this is fixed at training time as it is used - to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for - up to but not more than steps than `num_embeds_ada_norm`. - attention_bias (`bool`, *optional*): - Configure if the TransformerBlocks' attention should contain a bias parameter. - """ - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: Optional[int] = None, - num_vector_embeds: Optional[int] = None, - patch_size: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - norm_type: str = "layer_norm", - norm_elementwise_affine: bool = True, - ): - super().__init__() - self.use_linear_projection = use_linear_projection - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - - # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` - # Define whether input is continuous or discrete depending on configuration - self.is_input_continuous = (in_channels is not None) and (patch_size is None) - self.is_input_vectorized = num_vector_embeds is not None - self.is_input_patches = in_channels is not None and patch_size is not None - - if norm_type == "layer_norm" and num_embeds_ada_norm is not None: - deprecation_message = ( - f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" - " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." - " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" - " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" - " would be very nice if you could open a Pull request for the `transformer/config.json` file" - ) - deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) - norm_type = "ada_norm" - - if self.is_input_continuous and self.is_input_vectorized: - raise ValueError( - f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" - " sure that either `in_channels` or `num_vector_embeds` is None." - ) - elif self.is_input_vectorized and self.is_input_patches: - raise ValueError( - f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" - " sure that either `num_vector_embeds` or `num_patches` is None." - ) - elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: - raise ValueError( - f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" - f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." - ) - - # 2. 
Define input layers - if self.is_input_continuous: - self.in_channels = in_channels - - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - if use_linear_projection: - self.proj_in = nn.Linear(in_channels, inner_dim) - else: - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - elif self.is_input_vectorized: - assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" - assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" - - self.height = sample_size - self.width = sample_size - self.num_vector_embeds = num_vector_embeds - self.num_latent_pixels = self.height * self.width - - self.latent_image_embedding = ImagePositionalEmbeddings( - num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width - ) - elif self.is_input_patches: - assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" - - self.height = sample_size - self.width = sample_size - - self.patch_size = patch_size - self.pos_embed = PatchEmbed( - height=sample_size, - width=sample_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - ) - - # 3. Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - ) - for d in range(num_layers) - ] - ) - - # 4. Define output layers - self.out_channels = in_channels if out_channels is None else out_channels - if self.is_input_continuous: - # TODO: should use out_channels for continous projections - if use_linear_projection: - self.proj_out = nn.Linear(inner_dim, in_channels) - else: - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - elif self.is_input_vectorized: - self.norm_out = nn.LayerNorm(inner_dim) - self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - elif self.is_input_patches: - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) - self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - cross_attention_kwargs=None, - return_dict: bool = True, - ): - """ - Args: - hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. - When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input - hidden_states - encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - timestep ( `torch.long`, *optional*): - Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. - class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels - conditioning. 
- return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - - Returns: - [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: - [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - # 1. Input - if self.is_input_continuous: - batch, _, height, width = hidden_states.shape - residual = hidden_states - - hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states) - elif self.is_input_vectorized: - hidden_states = self.latent_image_embedding(hidden_states) - elif self.is_input_patches: - hidden_states = self.pos_embed(hidden_states) - - # 2. Blocks - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - cross_attention_kwargs=cross_attention_kwargs, - class_labels=class_labels, - ) - - # 3. Output - if self.is_input_continuous: - if not self.use_linear_projection: - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = self.proj_out(hidden_states) - else: - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - - output = hidden_states + residual - elif self.is_input_vectorized: - hidden_states = self.norm_out(hidden_states) - logits = self.out(hidden_states) - # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) - logits = logits.permute(0, 2, 1) - - # log(p(x_0)) - output = F.log_softmax(logits.double(), dim=1).float() - elif self.is_input_patches: - # TODO: cleanup! - conditioning = self.transformer_blocks[0].norm1.emb( - timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] - hidden_states = self.proj_out_2(hidden_states) - - # unpatchify - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) - ) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) diff --git a/my_diffusers/models/unet_1d.py b/my_diffusers/models/unet_1d.py deleted file mode 100644 index 19166a2b260787381a36d70a6db7ca479a1e3107..0000000000000000000000000000000000000000 --- a/my_diffusers/models/unet_1d.py +++ /dev/null @@ -1,246 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block - - -@dataclass -class UNet1DOutput(BaseOutput): - """ - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`): - Hidden states output. Output of last layer of model. - """ - - sample: torch.FloatTensor - - -class UNet1DModel(ModelMixin, ConfigMixin): - r""" - UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. - in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. - time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. - freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding. - flip_sin_to_cos (`bool`, *optional*, defaults to : - obj:`False`): Whether to flip sin to cos for fourier time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(32, 32, 64)`): Tuple of block output channels. - mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet. - out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet. - act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks. - norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks. - layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block. - downsample_each_block (`int`, *optional*, defaults to False: - experimental feature for using a UNet without upsampling. 
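    Example:

        A minimal shape-level sketch, not an official recipe. `extra_in_channels=16` is an
        assumption made here so that the 16-channel Fourier time embedding, which
        `DownBlock1DNoSkip` concatenates onto the 2 input channels at runtime, matches the
        channel count the first block is built for; every other value is the default
        registered in `__init__` below.

            import torch

            model = UNet1DModel(extra_in_channels=16)
            # a length of 2048 survives the three stride-2 stages (two downsampling blocks
            # plus the mid block) and round-trips back to the input length
            sample = torch.randn(1, 2, 2048)           # (batch_size, in_channels, length)
            out = model(sample, timestep=10).sample    # same shape as the input: (1, 2, 2048)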
- """ - - @register_to_config - def __init__( - self, - sample_size: int = 65536, - sample_rate: Optional[int] = None, - in_channels: int = 2, - out_channels: int = 2, - extra_in_channels: int = 0, - time_embedding_type: str = "fourier", - flip_sin_to_cos: bool = True, - use_timestep_embedding: bool = False, - freq_shift: float = 0.0, - down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), - up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), - mid_block_type: Tuple[str] = "UNetMidBlock1D", - out_block_type: str = None, - block_out_channels: Tuple[int] = (32, 32, 64), - act_fn: str = None, - norm_num_groups: int = 8, - layers_per_block: int = 1, - downsample_each_block: bool = False, - ): - super().__init__() - self.sample_size = sample_size - - # time - if time_embedding_type == "fourier": - self.time_proj = GaussianFourierProjection( - embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos - ) - timestep_input_dim = 2 * block_out_channels[0] - elif time_embedding_type == "positional": - self.time_proj = Timesteps( - block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift - ) - timestep_input_dim = block_out_channels[0] - - if use_timestep_embedding: - time_embed_dim = block_out_channels[0] * 4 - self.time_mlp = TimestepEmbedding( - in_channels=timestep_input_dim, - time_embed_dim=time_embed_dim, - act_fn=act_fn, - out_dim=block_out_channels[0], - ) - - self.down_blocks = nn.ModuleList([]) - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - self.out_block = None - - # down - output_channel = in_channels - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - - if i == 0: - input_channel += extra_in_channels - - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=block_out_channels[0], - add_downsample=not is_final_block or downsample_each_block, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = get_mid_block( - mid_block_type, - in_channels=block_out_channels[-1], - mid_channels=block_out_channels[-1], - out_channels=block_out_channels[-1], - embed_dim=block_out_channels[0], - num_layers=layers_per_block, - add_downsample=downsample_each_block, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - if out_block_type is None: - final_upsample_channels = out_channels - else: - final_upsample_channels = block_out_channels[0] - - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = ( - reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels - ) - - is_final_block = i == len(block_out_channels) - 1 - - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block, - in_channels=prev_output_channel, - out_channels=output_channel, - temb_channels=block_out_channels[0], - add_upsample=not is_final_block, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) - self.out_block = get_out_block( - out_block_type=out_block_type, - num_groups_out=num_groups_out, - embed_dim=block_out_channels[0], - 
out_channels=out_channels, - act_fn=act_fn, - fc_dim=block_out_channels[-1] // 4, - ) - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - return_dict: bool = True, - ) -> Union[UNet1DOutput, Tuple]: - r""" - Args: - sample (`torch.FloatTensor`): `(batch_size, sample_size, num_channels)` noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True, - otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - """ - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - timestep_embed = self.time_proj(timesteps) - if self.config.use_timestep_embedding: - timestep_embed = self.time_mlp(timestep_embed) - else: - timestep_embed = timestep_embed[..., None] - timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) - timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:])) - - # 2. down - down_block_res_samples = () - for downsample_block in self.down_blocks: - sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed) - down_block_res_samples += res_samples - - # 3. mid - if self.mid_block: - sample = self.mid_block(sample, timestep_embed) - - # 4. up - for i, upsample_block in enumerate(self.up_blocks): - res_samples = down_block_res_samples[-1:] - down_block_res_samples = down_block_res_samples[:-1] - sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed) - - # 5. post-process - if self.out_block: - sample = self.out_block(sample, timestep_embed) - - if not return_dict: - return (sample,) - - return UNet1DOutput(sample=sample) diff --git a/my_diffusers/models/unet_1d_blocks.py b/my_diffusers/models/unet_1d_blocks.py deleted file mode 100644 index a30f1f8e002e4611e9989e91114fcef810162e63..0000000000000000000000000000000000000000 --- a/my_diffusers/models/unet_1d_blocks.py +++ /dev/null @@ -1,668 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
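The blocks defined in this file follow a simple skip-connection contract: the down blocks used by `UNet1DModel` return `(hidden_states, res_tuple)`, and the paired up block concatenates that residual back along the channel dimension before it upsamples, which is why its first `ResConvBlock` is built for `2 * in_channels`. Below is a shape-level sketch of that contract using the `DownBlock1D`/`UpBlock1D` pair defined further down; it assumes this module has been imported so those names resolve, and the sizes are purely illustrative.

import torch

down = DownBlock1D(out_channels=32, in_channels=32)    # cubic downsample, then three ResConvBlocks
up = UpBlock1D(in_channels=32, out_channels=32)         # first ResConvBlock expects 2 * 32 channels

x = torch.randn(1, 32, 256)                             # (batch, channels, length)
h, res = down(x)                                        # h: (1, 32, 128); res == (h,)
y = up(h, res_hidden_states_tuple=res)                  # concat -> (1, 64, 128), upsample -> (1, 32, 256)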
-import math - -import torch -import torch.nn.functional as F -from torch import nn - -from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims - - -class DownResnetBlock1D(nn.Module): - def __init__( - self, - in_channels, - out_channels=None, - num_layers=1, - conv_shortcut=False, - temb_channels=32, - groups=32, - groups_out=None, - non_linearity=None, - time_embedding_norm="default", - output_scale_factor=1.0, - add_downsample=True, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.time_embedding_norm = time_embedding_norm - self.add_downsample = add_downsample - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - # there will always be at least one resnet - resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)] - - for _ in range(num_layers): - resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) - - self.resnets = nn.ModuleList(resnets) - - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - else: - self.nonlinearity = None - - self.downsample = None - if add_downsample: - self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) - - def forward(self, hidden_states, temb=None): - output_states = () - - hidden_states = self.resnets[0](hidden_states, temb) - for resnet in self.resnets[1:]: - hidden_states = resnet(hidden_states, temb) - - output_states += (hidden_states,) - - if self.nonlinearity is not None: - hidden_states = self.nonlinearity(hidden_states) - - if self.downsample is not None: - hidden_states = self.downsample(hidden_states) - - return hidden_states, output_states - - -class UpResnetBlock1D(nn.Module): - def __init__( - self, - in_channels, - out_channels=None, - num_layers=1, - temb_channels=32, - groups=32, - groups_out=None, - non_linearity=None, - time_embedding_norm="default", - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.time_embedding_norm = time_embedding_norm - self.add_upsample = add_upsample - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - # there will always be at least one resnet - resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)] - - for _ in range(num_layers): - resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) - - self.resnets = nn.ModuleList(resnets) - - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - else: - self.nonlinearity = None - - self.upsample = None - if add_upsample: - self.upsample = Upsample1D(out_channels, use_conv_transpose=True) - - def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None): - if res_hidden_states_tuple is not None: - res_hidden_states = res_hidden_states_tuple[-1] - hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) - - hidden_states = 
self.resnets[0](hidden_states, temb) - for resnet in self.resnets[1:]: - hidden_states = resnet(hidden_states, temb) - - if self.nonlinearity is not None: - hidden_states = self.nonlinearity(hidden_states) - - if self.upsample is not None: - hidden_states = self.upsample(hidden_states) - - return hidden_states - - -class ValueFunctionMidBlock1D(nn.Module): - def __init__(self, in_channels, out_channels, embed_dim): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.embed_dim = embed_dim - - self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim) - self.down1 = Downsample1D(out_channels // 2, use_conv=True) - self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) - self.down2 = Downsample1D(out_channels // 4, use_conv=True) - - def forward(self, x, temb=None): - x = self.res1(x, temb) - x = self.down1(x) - x = self.res2(x, temb) - x = self.down2(x) - return x - - -class MidResTemporalBlock1D(nn.Module): - def __init__( - self, - in_channels, - out_channels, - embed_dim, - num_layers: int = 1, - add_downsample: bool = False, - add_upsample: bool = False, - non_linearity=None, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.add_downsample = add_downsample - - # there will always be at least one resnet - resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)] - - for _ in range(num_layers): - resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim)) - - self.resnets = nn.ModuleList(resnets) - - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - else: - self.nonlinearity = None - - self.upsample = None - if add_upsample: - self.upsample = Downsample1D(out_channels, use_conv=True) - - self.downsample = None - if add_downsample: - self.downsample = Downsample1D(out_channels, use_conv=True) - - if self.upsample and self.downsample: - raise ValueError("Block cannot downsample and upsample") - - def forward(self, hidden_states, temb): - hidden_states = self.resnets[0](hidden_states, temb) - for resnet in self.resnets[1:]: - hidden_states = resnet(hidden_states, temb) - - if self.upsample: - hidden_states = self.upsample(hidden_states) - if self.downsample: - self.downsample = self.downsample(hidden_states) - - return hidden_states - - -class OutConv1DBlock(nn.Module): - def __init__(self, num_groups_out, out_channels, embed_dim, act_fn): - super().__init__() - self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) - self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) - if act_fn == "silu": - self.final_conv1d_act = nn.SiLU() - if act_fn == "mish": - self.final_conv1d_act = nn.Mish() - self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) - - def forward(self, hidden_states, temb=None): - hidden_states = self.final_conv1d_1(hidden_states) - hidden_states = rearrange_dims(hidden_states) - hidden_states = self.final_conv1d_gn(hidden_states) - hidden_states = rearrange_dims(hidden_states) - hidden_states = self.final_conv1d_act(hidden_states) - hidden_states = self.final_conv1d_2(hidden_states) - return hidden_states - - -class OutValueFunctionBlock(nn.Module): - def __init__(self, fc_dim, embed_dim): - super().__init__() - self.final_block = nn.ModuleList( - [ - nn.Linear(fc_dim + embed_dim, fc_dim // 
2), - nn.Mish(), - nn.Linear(fc_dim // 2, 1), - ] - ) - - def forward(self, hidden_states, temb): - hidden_states = hidden_states.view(hidden_states.shape[0], -1) - hidden_states = torch.cat((hidden_states, temb), dim=-1) - for layer in self.final_block: - hidden_states = layer(hidden_states) - - return hidden_states - - -_kernels = { - "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], - "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875], - "lanczos3": [ - 0.003689131001010537, - 0.015056144446134567, - -0.03399861603975296, - -0.066637322306633, - 0.13550527393817902, - 0.44638532400131226, - 0.44638532400131226, - 0.13550527393817902, - -0.066637322306633, - -0.03399861603975296, - 0.015056144446134567, - 0.003689131001010537, - ], -} - - -class Downsample1d(nn.Module): - def __init__(self, kernel="linear", pad_mode="reflect"): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor(_kernels[kernel]) - self.pad = kernel_1d.shape[0] // 2 - 1 - self.register_buffer("kernel", kernel_1d) - - def forward(self, hidden_states): - hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) - weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) - indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) - weight[indices, indices] = self.kernel.to(weight) - return F.conv1d(hidden_states, weight, stride=2) - - -class Upsample1d(nn.Module): - def __init__(self, kernel="linear", pad_mode="reflect"): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor(_kernels[kernel]) * 2 - self.pad = kernel_1d.shape[0] // 2 - 1 - self.register_buffer("kernel", kernel_1d) - - def forward(self, hidden_states, temb=None): - hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) - weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) - indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) - weight[indices, indices] = self.kernel.to(weight) - return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) - - -class SelfAttention1d(nn.Module): - def __init__(self, in_channels, n_head=1, dropout_rate=0.0): - super().__init__() - self.channels = in_channels - self.group_norm = nn.GroupNorm(1, num_channels=in_channels) - self.num_heads = n_head - - self.query = nn.Linear(self.channels, self.channels) - self.key = nn.Linear(self.channels, self.channels) - self.value = nn.Linear(self.channels, self.channels) - - self.proj_attn = nn.Linear(self.channels, self.channels, 1) - - self.dropout = nn.Dropout(dropout_rate, inplace=True) - - def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: - new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) - # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) - new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) - return new_projection - - def forward(self, hidden_states): - residual = hidden_states - batch, channel_dim, seq = hidden_states.shape - - hidden_states = self.group_norm(hidden_states) - hidden_states = hidden_states.transpose(1, 2) - - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - query_states = self.transpose_for_scores(query_proj) - key_states = self.transpose_for_scores(key_proj) - value_states = self.transpose_for_scores(value_proj) - - 
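        # NOTE: `scale` below is 1/sqrt(sqrt(head_dim)) and is applied to *both* the query and
        # the key, so the attention scores are effectively divided by sqrt(head_dim) (standard
        # scaled dot-product attention) while each intermediate product stays closer to unit
        # range, which is friendlier to reduced-precision (e.g. fp16) execution.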
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1])) - - attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) - attention_probs = torch.softmax(attention_scores, dim=-1) - - # compute attention output - hidden_states = torch.matmul(attention_probs, value_states) - - hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() - new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) - hidden_states = hidden_states.view(new_hidden_states_shape) - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - hidden_states = hidden_states.transpose(1, 2) - hidden_states = self.dropout(hidden_states) - - output = hidden_states + residual - - return output - - -class ResConvBlock(nn.Module): - def __init__(self, in_channels, mid_channels, out_channels, is_last=False): - super().__init__() - self.is_last = is_last - self.has_conv_skip = in_channels != out_channels - - if self.has_conv_skip: - self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False) - - self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2) - self.group_norm_1 = nn.GroupNorm(1, mid_channels) - self.gelu_1 = nn.GELU() - self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2) - - if not self.is_last: - self.group_norm_2 = nn.GroupNorm(1, out_channels) - self.gelu_2 = nn.GELU() - - def forward(self, hidden_states): - residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states - - hidden_states = self.conv_1(hidden_states) - hidden_states = self.group_norm_1(hidden_states) - hidden_states = self.gelu_1(hidden_states) - hidden_states = self.conv_2(hidden_states) - - if not self.is_last: - hidden_states = self.group_norm_2(hidden_states) - hidden_states = self.gelu_2(hidden_states) - - output = hidden_states + residual - return output - - -class UNetMidBlock1D(nn.Module): - def __init__(self, mid_channels, in_channels, out_channels=None): - super().__init__() - - out_channels = in_channels if out_channels is None else out_channels - - # there is always at least one resnet - self.down = Downsample1d("cubic") - resnets = [ - ResConvBlock(in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - attentions = [ - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(out_channels, out_channels // 32), - ] - self.up = Upsample1d(kernel="cubic") - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None): - hidden_states = self.down(hidden_states) - for attn, resnet in zip(self.attentions, self.resnets): - hidden_states = resnet(hidden_states) - hidden_states = attn(hidden_states) - - hidden_states = self.up(hidden_states) - - return hidden_states - - -class AttnDownBlock1D(nn.Module): - def __init__(self, out_channels, in_channels, mid_channels=None): - super().__init__() - mid_channels = out_channels if mid_channels is None else mid_channels - - self.down = Downsample1d("cubic") - resnets = [ - ResConvBlock(in_channels, mid_channels, 
mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - attentions = [ - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(out_channels, out_channels // 32), - ] - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None): - hidden_states = self.down(hidden_states) - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states) - hidden_states = attn(hidden_states) - - return hidden_states, (hidden_states,) - - -class DownBlock1D(nn.Module): - def __init__(self, out_channels, in_channels, mid_channels=None): - super().__init__() - mid_channels = out_channels if mid_channels is None else mid_channels - - self.down = Downsample1d("cubic") - resnets = [ - ResConvBlock(in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None): - hidden_states = self.down(hidden_states) - - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - - return hidden_states, (hidden_states,) - - -class DownBlock1DNoSkip(nn.Module): - def __init__(self, out_channels, in_channels, mid_channels=None): - super().__init__() - mid_channels = out_channels if mid_channels is None else mid_channels - - resnets = [ - ResConvBlock(in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None): - hidden_states = torch.cat([hidden_states, temb], dim=1) - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - - return hidden_states, (hidden_states,) - - -class AttnUpBlock1D(nn.Module): - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - mid_channels = out_channels if mid_channels is None else mid_channels - - resnets = [ - ResConvBlock(2 * in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - attentions = [ - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(mid_channels, mid_channels // 32), - SelfAttention1d(out_channels, out_channels // 32), - ] - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.up = Upsample1d(kernel="cubic") - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): - res_hidden_states = res_hidden_states_tuple[-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states) - hidden_states = attn(hidden_states) - - hidden_states = self.up(hidden_states) - - return hidden_states - - -class UpBlock1D(nn.Module): - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - mid_channels = in_channels if mid_channels is None else mid_channels - - resnets = [ - ResConvBlock(2 * in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels), - ] - - self.resnets = nn.ModuleList(resnets) - self.up = 
Upsample1d(kernel="cubic") - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): - res_hidden_states = res_hidden_states_tuple[-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - - hidden_states = self.up(hidden_states) - - return hidden_states - - -class UpBlock1DNoSkip(nn.Module): - def __init__(self, in_channels, out_channels, mid_channels=None): - super().__init__() - mid_channels = in_channels if mid_channels is None else mid_channels - - resnets = [ - ResConvBlock(2 * in_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, mid_channels), - ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), - ] - - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): - res_hidden_states = res_hidden_states_tuple[-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - for resnet in self.resnets: - hidden_states = resnet(hidden_states) - - return hidden_states - - -def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample): - if down_block_type == "DownResnetBlock1D": - return DownResnetBlock1D( - in_channels=in_channels, - num_layers=num_layers, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - ) - elif down_block_type == "DownBlock1D": - return DownBlock1D(out_channels=out_channels, in_channels=in_channels) - elif down_block_type == "AttnDownBlock1D": - return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) - elif down_block_type == "DownBlock1DNoSkip": - return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) - raise ValueError(f"{down_block_type} does not exist.") - - -def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample): - if up_block_type == "UpResnetBlock1D": - return UpResnetBlock1D( - in_channels=in_channels, - num_layers=num_layers, - out_channels=out_channels, - temb_channels=temb_channels, - add_upsample=add_upsample, - ) - elif up_block_type == "UpBlock1D": - return UpBlock1D(in_channels=in_channels, out_channels=out_channels) - elif up_block_type == "AttnUpBlock1D": - return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) - elif up_block_type == "UpBlock1DNoSkip": - return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) - raise ValueError(f"{up_block_type} does not exist.") - - -def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample): - if mid_block_type == "MidResTemporalBlock1D": - return MidResTemporalBlock1D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - embed_dim=embed_dim, - add_downsample=add_downsample, - ) - elif mid_block_type == "ValueFunctionMidBlock1D": - return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim) - elif mid_block_type == "UNetMidBlock1D": - return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) - raise ValueError(f"{mid_block_type} does not exist.") - - -def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim): - if out_block_type == "OutConv1DBlock": - return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) - elif out_block_type == "ValueFunction": - return 
OutValueFunctionBlock(fc_dim, embed_dim) - return None diff --git a/my_diffusers/models/unet_2d.py b/my_diffusers/models/unet_2d.py deleted file mode 100644 index 8b1082a242b1e1be1d37787cf022e85c0d2db35f..0000000000000000000000000000000000000000 --- a/my_diffusers/models/unet_2d.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block - - -@dataclass -class UNet2DOutput(BaseOutput): - """ - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Hidden states output. Output of last layer of model. - """ - - sample: torch.FloatTensor - - -class UNet2DModel(ModelMixin, ConfigMixin): - r""" - UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. - out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. - freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. - flip_sin_to_cos (`bool`, *optional*, defaults to : - obj:`True`): Whether to flip sin to cos for fourier time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block - types. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): - The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(224, 448, 672, 896)`): Tuple of block output channels. - layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. - mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. 
- downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. - norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization. - norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization. - resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately - summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. - num_class_embeds (`int`, *optional*, defaults to None): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - """ - - @register_to_config - def __init__( - self, - sample_size: Optional[Union[int, Tuple[int, int]]] = None, - in_channels: int = 3, - out_channels: int = 3, - center_input_sample: bool = False, - time_embedding_type: str = "positional", - freq_shift: int = 0, - flip_sin_to_cos: bool = True, - down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), - up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), - block_out_channels: Tuple[int] = (224, 448, 672, 896), - layers_per_block: int = 2, - mid_block_scale_factor: float = 1, - downsample_padding: int = 1, - act_fn: str = "silu", - attention_head_dim: Optional[int] = 8, - norm_num_groups: int = 32, - norm_eps: float = 1e-5, - resnet_time_scale_shift: str = "default", - add_attention: bool = True, - class_embed_type: Optional[str] = None, - num_class_embeds: Optional[int] = None, - ): - super().__init__() - - self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 - - # Check inputs - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 
- ) - - # input - self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) - - # time - if time_embedding_type == "fourier": - self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16) - timestep_input_dim = 2 * block_out_channels[0] - elif time_embedding_type == "positional": - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - else: - self.class_embedding = None - - self.down_blocks = nn.ModuleList([]) - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attn_num_head_channels=attention_head_dim, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - attn_num_head_channels=attention_head_dim, - resnet_groups=norm_num_groups, - add_attention=add_attention, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attn_num_head_channels=attention_head_dim, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - class_labels: Optional[torch.Tensor] 
= None, - return_dict: bool = True, - ) -> Union[UNet2DOutput, Tuple]: - r""" - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps - class_labels (`torch.FloatTensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. - - Returns: - [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True, - otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - """ - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) - emb = self.time_embedding(t_emb) - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when doing class conditioning") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb - - # 2. pre-process - skip_sample = sample - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "skip_conv"): - sample, res_samples, skip_sample = downsample_block( - hidden_states=sample, temb=emb, skip_sample=skip_sample - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - # 4. mid - sample = self.mid_block(sample, emb) - - # 5. up - skip_sample = None - for upsample_block in self.up_blocks: - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - if hasattr(upsample_block, "skip_conv"): - sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) - else: - sample = upsample_block(sample, res_samples, emb) - - # 6. 
post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if skip_sample is not None: - sample += skip_sample - - if self.config.time_embedding_type == "fourier": - timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) - sample = sample / timesteps - - if not return_dict: - return (sample,) - - return UNet2DOutput(sample=sample) diff --git a/my_diffusers/models/unet_2d_blocks.py b/my_diffusers/models/unet_2d_blocks.py deleted file mode 100644 index 8269b77f54d47692b4053b770a900dac5d079d76..0000000000000000000000000000000000000000 --- a/my_diffusers/models/unet_2d_blocks.py +++ /dev/null @@ -1,2749 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -import numpy as np -import torch -from torch import nn - -from .attention import AdaGroupNorm, AttentionBlock -from .cross_attention import CrossAttention, CrossAttnAddedKVProcessor -from .dual_transformer_2d import DualTransformer2DModel -from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D -from .transformer_2d import Transformer2DModel - - -def get_down_block( - down_block_type, - num_layers, - in_channels, - out_channels, - temb_channels, - add_downsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - resnet_groups=None, - cross_attention_dim=None, - downsample_padding=None, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", -): - down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type - if down_block_type == "DownBlock2D": - return DownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "ResnetDownsampleBlock2D": - return ResnetDownsampleBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "AttnDownBlock2D": - return AttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "CrossAttnDownBlock2D": - if 
cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") - return CrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "SimpleCrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") - return SimpleCrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "SkipDownBlock2D": - return SkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "AttnSkipDownBlock2D": - return AttnSkipDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "DownEncoderBlock2D": - return DownEncoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "AttnDownEncoderBlock2D": - return AttnDownEncoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "KDownBlock2D": - return KDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ) - elif down_block_type == "KCrossAttnDownBlock2D": - return KCrossAttnDownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - cross_attention_dim=cross_attention_dim, - 
attn_num_head_channels=attn_num_head_channels, - add_self_attention=True if not add_downsample else False, - ) - raise ValueError(f"{down_block_type} does not exist.") - - -def get_up_block( - up_block_type, - num_layers, - in_channels, - out_channels, - prev_output_channel, - temb_channels, - add_upsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - resnet_groups=None, - cross_attention_dim=None, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", -): - up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type - if up_block_type == "UpBlock2D": - return UpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "ResnetUpsampleBlock2D": - return ResnetUpsampleBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "CrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") - return CrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "SimpleCrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") - return SimpleCrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "AttnUpBlock2D": - return AttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "SkipUpBlock2D": - return SkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - 
add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "AttnSkipUpBlock2D": - return AttnSkipUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "UpDecoderBlock2D": - return UpDecoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "AttnUpDecoderBlock2D": - return AttnUpDecoderBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - attn_num_head_channels=attn_num_head_channels, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "KUpBlock2D": - return KUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ) - elif up_block_type == "KCrossAttnUpBlock2D": - return KCrossAttnUpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - ) - - raise ValueError(f"{up_block_type} does not exist.") - - -class UNetMidBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - add_attention: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - ): - super().__init__() - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - self.add_attention = add_attention - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - if self.add_attention: - attentions.append( - AttentionBlock( - in_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - else: - attentions.append(None) - - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = 
nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - hidden_states = attn(hidden_states) - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class UNetMidBlock2DCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - dual_cross_attention=False, - use_linear_projection=False, - upcast_attention=False, - ): - super().__init__() - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None - ): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class UNetMidBlock2DSimpleCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - ): - super().__init__() - - self.has_cross_attention = True - - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = resnet_groups 
if resnet_groups is not None else min(in_channels // 4, 32) - - self.num_heads = in_channels // self.attn_num_head_channels - - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - attentions.append( - CrossAttention( - query_dim=in_channels, - cross_attention_dim=in_channels, - heads=self.num_heads, - dim_head=attn_num_head_channels, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None - ): - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - # attn - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - # resnet - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class AttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - 
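# Each resnet/attention pair appends its output to `output_states`; together with the
# final downsampled tensor, this tuple is returned to the parent UNet, which keeps it
# (as `down_block_res_samples` in the UNet forward pass) and later concatenates each
# entry onto the matching up block's input as a skip connection.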
output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class CrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None - ): - # TODO(Patrick, William) - attention mask is not used - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - 
hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class DownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - else: - hidden_states = resnet(hidden_states, temb) - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class DownEncoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states): - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states - - -class AttnDownEncoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - 
resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - def forward(self, hidden_states): - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None) - hidden_states = attn(hidden_states) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states - - -class AttnSkipDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=np.sqrt(2.0), - downsample_padding=1, - add_downsample=True, - ): - super().__init__() - self.attentions = nn.ModuleList([]) - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - self.resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(in_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - self.attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - ) - ) - - if add_downsample: - self.resnet_down = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - down=True, - kernel="fir", - ) - self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) - self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) - else: - self.resnet_down = None - self.downsamplers = None - self.skip_conv = None - - def forward(self, hidden_states, temb=None, skip_sample=None): - output_states = () - - for resnet, attn in 
zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - output_states += (hidden_states,) - - if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb) - for downsampler in self.downsamplers: - skip_sample = downsampler(skip_sample) - - hidden_states = self.skip_conv(skip_sample) + hidden_states - - output_states += (hidden_states,) - - return hidden_states, output_states, skip_sample - - -class SkipDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - output_scale_factor=np.sqrt(2.0), - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - self.resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(in_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - if add_downsample: - self.resnet_down = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - down=True, - kernel="fir", - ) - self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) - self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) - else: - self.resnet_down = None - self.downsamplers = None - self.skip_conv = None - - def forward(self, hidden_states, temb=None, skip_sample=None): - output_states = () - - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) - - if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb) - for downsampler in self.downsamplers: - skip_sample = downsampler(skip_sample) - - hidden_states = self.skip_conv(skip_sample) + hidden_states - - output_states += (hidden_states,) - - return hidden_states, output_states, skip_sample - - -class ResnetDownsampleBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - 
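# Unlike DownBlock2D, which halves the resolution with a strided Downsample2D
# convolution, this block downsamples with an extra ResnetBlock2D built with down=True.
# Illustrative usage sketch only; the channel counts, batch size, and spatial shapes
# below are assumed for the example, not taken from this file:
#   block = ResnetDownsampleBlock2D(in_channels=64, out_channels=128, temb_channels=512)
#   temb = torch.randn(1, 512)
#   h, states = block(torch.randn(1, 64, 32, 32), temb)   # h: (1, 128, 16, 16)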
if add_downsample: - self.downsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - down=True, - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - else: - hidden_states = resnet(hidden_states, temb) - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class SimpleCrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_downsample=True, - ): - super().__init__() - - self.has_cross_attention = True - - resnets = [] - attentions = [] - - self.attn_num_head_channels = attn_num_head_channels - self.num_heads = out_channels // self.attn_num_head_channels - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - CrossAttention( - query_dim=out_channels, - cross_attention_dim=out_channels, - heads=self.num_heads, - dim_head=attn_num_head_channels, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - down=True, - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None - ): - output_states = () - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - - for resnet, attn in zip(self.resnets, self.attentions): - # resnet - hidden_states = resnet(hidden_states, temb) - - # attn - hidden_states = attn( - 
hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -class KDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 4, - resnet_eps: float = 1e-5, - resnet_act_fn: str = "gelu", - resnet_group_size: int = 32, - add_downsample=False, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - groups = in_channels // resnet_group_size - groups_out = out_channels // resnet_group_size - - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - dropout=dropout, - temb_channels=temb_channels, - groups=groups, - groups_out=groups_out, - eps=resnet_eps, - non_linearity=resnet_act_fn, - time_embedding_norm="ada_group", - conv_shortcut_bias=False, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - # YiYi's comments- might be able to use FirDownsample2D, look into details later - self.downsamplers = nn.ModuleList([KDownsample2D()]) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - else: - hidden_states = resnet(hidden_states, temb) - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states, output_states - - -class KCrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - cross_attention_dim: int, - dropout: float = 0.0, - num_layers: int = 4, - resnet_group_size: int = 32, - add_downsample=True, - attn_num_head_channels: int = 64, - add_self_attention: bool = False, - resnet_eps: float = 1e-5, - resnet_act_fn: str = "gelu", - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - groups = in_channels // resnet_group_size - groups_out = out_channels // resnet_group_size - - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - dropout=dropout, - temb_channels=temb_channels, - groups=groups, - groups_out=groups_out, - eps=resnet_eps, - non_linearity=resnet_act_fn, - time_embedding_norm="ada_group", - conv_shortcut_bias=False, - ) - ) - attentions.append( - KAttentionBlock( - out_channels, - out_channels // attn_num_head_channels, - attn_num_head_channels, - cross_attention_dim=cross_attention_dim, - temb_channels=temb_channels, - attention_bias=True, - add_self_attention=add_self_attention, - cross_attention_norm=True, - group_size=resnet_group_size, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.attentions = nn.ModuleList(attentions) - - if add_downsample: - self.downsamplers = 
nn.ModuleList([KDownsample2D()]) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None - ): - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - ) - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - emb=temb, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - - if self.downsamplers is None: - output_states += (None,) - else: - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - return hidden_states, output_states - - -class AttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class CrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - 
out_channels: int, - prev_output_channel: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, - ): - # TODO(Patrick, William) - attention mask is not used - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return 
hidden_states - - -class UpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - else: - hidden_states = resnet(hidden_states, temb) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -class UpDecoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - def forward(self, hidden_states): - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class AttnUpDecoderBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - 
resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - def forward(self, hidden_states): - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None) - hidden_states = attn(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class AttnSkipUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=np.sqrt(2.0), - upsample_padding=1, - add_upsample=True, - ): - super().__init__() - self.attentions = nn.ModuleList([]) - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - self.resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(resnet_in_channels + res_skip_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - ) - ) - - self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) - if add_upsample: - self.resnet_up = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - up=True, - kernel="fir", - ) - self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - self.skip_norm = torch.nn.GroupNorm( - num_groups=min(out_channels // 4, 32), 
num_channels=out_channels, eps=resnet_eps, affine=True - ) - self.act = nn.SiLU() - else: - self.resnet_up = None - self.skip_conv = None - self.skip_norm = None - self.act = None - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - hidden_states = self.attentions[0](hidden_states) - - if skip_sample is not None: - skip_sample = self.upsampler(skip_sample) - else: - skip_sample = 0 - - if self.resnet_up is not None: - skip_sample_states = self.skip_norm(hidden_states) - skip_sample_states = self.act(skip_sample_states) - skip_sample_states = self.skip_conv(skip_sample_states) - - skip_sample = skip_sample + skip_sample_states - - hidden_states = self.resnet_up(hidden_states, temb) - - return hidden_states, skip_sample - - -class SkipUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_pre_norm: bool = True, - output_scale_factor=np.sqrt(2.0), - add_upsample=True, - upsample_padding=1, - ): - super().__init__() - self.resnets = nn.ModuleList([]) - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - self.resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min((resnet_in_channels + res_skip_channels) // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) - if add_upsample: - self.resnet_up = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - up=True, - kernel="fir", - ) - self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) - self.skip_norm = torch.nn.GroupNorm( - num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True - ) - self.act = nn.SiLU() - else: - self.resnet_up = None - self.skip_conv = None - self.skip_norm = None - self.act = None - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - if skip_sample is not None: - skip_sample = self.upsampler(skip_sample) - else: - skip_sample = 0 - - if self.resnet_up 
is not None: - skip_sample_states = self.skip_norm(hidden_states) - skip_sample_states = self.act(skip_sample_states) - skip_sample_states = self.skip_conv(skip_sample_states) - - skip_sample = skip_sample + skip_sample_states - - hidden_states = self.resnet_up(hidden_states, temb) - - return hidden_states, skip_sample - - -class ResnetUpsampleBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - up=True, - ) - ] - ) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - else: - hidden_states = resnet(hidden_states, temb) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb) - - return hidden_states - - -class SimpleCrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - self.num_heads = out_channels // self.attn_num_head_channels - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - 
ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - CrossAttention( - query_dim=out_channels, - cross_attention_dim=out_channels, - heads=self.num_heads, - dim_head=attn_num_head_channels, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList( - [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - up=True, - ) - ] - ) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - upsample_size=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - for resnet, attn in zip(self.resnets, self.attentions): - # resnet - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - hidden_states = resnet(hidden_states, temb) - - # attn - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb) - - return hidden_states - - -class KUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 5, - resnet_eps: float = 1e-5, - resnet_act_fn: str = "gelu", - resnet_group_size: Optional[int] = 32, - add_upsample=True, - ): - super().__init__() - resnets = [] - k_in_channels = 2 * out_channels - k_out_channels = in_channels - num_layers = num_layers - 1 - - for i in range(num_layers): - in_channels = k_in_channels if i == 0 else out_channels - groups = in_channels // resnet_group_size - groups_out = out_channels // resnet_group_size - - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=k_out_channels if (i == num_layers - 1) else out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=groups, - groups_out=groups_out, - dropout=dropout, - non_linearity=resnet_act_fn, - time_embedding_norm="ada_group", - conv_shortcut_bias=False, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([KUpsample2D()]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): - res_hidden_states_tuple = res_hidden_states_tuple[-1] - if res_hidden_states_tuple is not None: - hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], 
dim=1) - - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - else: - hidden_states = resnet(hidden_states, temb) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -class KCrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 4, - resnet_eps: float = 1e-5, - resnet_act_fn: str = "gelu", - resnet_group_size: int = 32, - attn_num_head_channels=1, # attention dim_head - cross_attention_dim: int = 768, - add_upsample: bool = True, - upcast_attention: bool = False, - ): - super().__init__() - resnets = [] - attentions = [] - - is_first_block = in_channels == out_channels == temb_channels - is_middle_block = in_channels != out_channels - add_self_attention = True if is_first_block else False - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - # in_channels, and out_channels for the block (k-unet) - k_in_channels = out_channels if is_first_block else 2 * out_channels - k_out_channels = in_channels - - num_layers = num_layers - 1 - - for i in range(num_layers): - in_channels = k_in_channels if i == 0 else out_channels - groups = in_channels // resnet_group_size - groups_out = out_channels // resnet_group_size - - if is_middle_block and (i == num_layers - 1): - conv_2d_out_channels = k_out_channels - else: - conv_2d_out_channels = None - - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - conv_2d_out_channels=conv_2d_out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=groups, - groups_out=groups_out, - dropout=dropout, - non_linearity=resnet_act_fn, - time_embedding_norm="ada_group", - conv_shortcut_bias=False, - ) - ) - attentions.append( - KAttentionBlock( - k_out_channels if (i == num_layers - 1) else out_channels, - k_out_channels // attn_num_head_channels - if (i == num_layers - 1) - else out_channels // attn_num_head_channels, - attn_num_head_channels, - cross_attention_dim=cross_attention_dim, - temb_channels=temb_channels, - attention_bias=True, - add_self_attention=add_self_attention, - cross_attention_norm=True, - upcast_attention=upcast_attention, - ) - ) - - self.resnets = nn.ModuleList(resnets) - self.attentions = nn.ModuleList(attentions) - - if add_upsample: - self.upsamplers = nn.ModuleList([KUpsample2D()]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, - ): - res_hidden_states_tuple = res_hidden_states_tuple[-1] - if res_hidden_states_tuple is not None: - hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) - - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = 
torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - attention_mask, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - emb=temb, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) - - return hidden_states - - -# can potentially later be renamed to `No-feed-forward` attention -class KAttentionBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout: float = 0.0, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - upcast_attention: bool = False, - temb_channels: int = 768, # for ada_group_norm - add_self_attention: bool = False, - cross_attention_norm: bool = False, - group_size: int = 32, - ): - super().__init__() - self.add_self_attention = add_self_attention - - # 1. Self-Attn - if add_self_attention: - self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) - self.attn1 = CrossAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=None, - cross_attention_norm=None, - ) - - # 2. Cross-Attn - self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) - self.attn2 = CrossAttention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - cross_attention_norm=cross_attention_norm, - ) - - def _to_3d(self, hidden_states, height, weight): - return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) - - def _to_4d(self, hidden_states, height, weight): - return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) - - def forward( - self, - hidden_states, - encoder_hidden_states=None, - emb=None, - attention_mask=None, - cross_attention_kwargs=None, - ): - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - - # 1. 
Self-Attention - if self.add_self_attention: - norm_hidden_states = self.norm1(hidden_states, emb) - - height, weight = norm_hidden_states.shape[2:] - norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=None, - **cross_attention_kwargs, - ) - attn_output = self._to_4d(attn_output, height, weight) - - hidden_states = attn_output + hidden_states - - # 2. Cross-Attention/None - norm_hidden_states = self.norm2(hidden_states, emb) - - height, weight = norm_hidden_states.shape[2:] - norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - **cross_attention_kwargs, - ) - attn_output = self._to_4d(attn_output, height, weight) - - hidden_states = attn_output + hidden_states - - return hidden_states diff --git a/my_diffusers/models/unet_2d_blocks_flax.py b/my_diffusers/models/unet_2d_blocks_flax.py deleted file mode 100644 index 8e9690d332c9743ff0411b170238e9c8c37699e1..0000000000000000000000000000000000000000 --- a/my_diffusers/models/unet_2d_blocks_flax.py +++ /dev/null @@ -1,365 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
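# Hedged usage sketch (not part of the original file): the `KAttentionBlock`
# defined above expects a 4D feature map plus a timestep embedding for its
# AdaGroupNorm layers, which is easy to miss from the class definition alone.
# The import path and parameter values below are illustrative only and assume
# the (deleted) `my_diffusers` package, which mirrors diffusers==0.14.0, is on
# the Python path together with torch.
import torch
from my_diffusers.models.unet_2d_blocks import KAttentionBlock

block = KAttentionBlock(
    dim=64,                  # channels of the incoming feature map
    num_attention_heads=2,
    attention_head_dim=32,   # 2 heads * 32 dims per head == dim
    temb_channels=128,       # width of the timestep embedding fed to AdaGroupNorm
)
hidden_states = torch.randn(1, 64, 16, 16)  # (batch, dim, height, width)
temb = torch.randn(1, 128)                  # timestep embedding
out = block(hidden_states, emb=temb)        # residual output, shape (1, 64, 16, 16)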
- -import flax.linen as nn -import jax.numpy as jnp - -from .attention_flax import FlaxTransformer2DModel -from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D - - -class FlaxCrossAttnDownBlock2D(nn.Module): - r""" - Cross Attention 2D Downsizing block - original architecture from Unet transformers: - https://arxiv.org/abs/2103.06104 - - Parameters: - in_channels (:obj:`int`): - Input channels - out_channels (:obj:`int`): - Output channels - dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): - Number of attention blocks layers - attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): - Number of attention heads of each spatial transformer block - add_downsample (:obj:`bool`, *optional*, defaults to `True`): - Whether to add downsampling layer before each final output - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - in_channels: int - out_channels: int - dropout: float = 0.0 - num_layers: int = 1 - attn_num_head_channels: int = 1 - add_downsample: bool = True - use_linear_projection: bool = False - only_cross_attention: bool = False - dtype: jnp.dtype = jnp.float32 - - def setup(self): - resnets = [] - attentions = [] - - for i in range(self.num_layers): - in_channels = self.in_channels if i == 0 else self.out_channels - - res_block = FlaxResnetBlock2D( - in_channels=in_channels, - out_channels=self.out_channels, - dropout_prob=self.dropout, - dtype=self.dtype, - ) - resnets.append(res_block) - - attn_block = FlaxTransformer2DModel( - in_channels=self.out_channels, - n_heads=self.attn_num_head_channels, - d_head=self.out_channels // self.attn_num_head_channels, - depth=1, - use_linear_projection=self.use_linear_projection, - only_cross_attention=self.only_cross_attention, - dtype=self.dtype, - ) - attentions.append(attn_block) - - self.resnets = resnets - self.attentions = attentions - - if self.add_downsample: - self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) - - def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb, deterministic=deterministic) - hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) - output_states += (hidden_states,) - - if self.add_downsample: - hidden_states = self.downsamplers_0(hidden_states) - output_states += (hidden_states,) - - return hidden_states, output_states - - -class FlaxDownBlock2D(nn.Module): - r""" - Flax 2D downsizing block - - Parameters: - in_channels (:obj:`int`): - Input channels - out_channels (:obj:`int`): - Output channels - dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): - Number of attention blocks layers - add_downsample (:obj:`bool`, *optional*, defaults to `True`): - Whether to add downsampling layer before each final output - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - in_channels: int - out_channels: int - dropout: float = 0.0 - num_layers: int = 1 - add_downsample: bool = True - dtype: jnp.dtype = jnp.float32 - - def setup(self): - resnets = [] - - for i in range(self.num_layers): - in_channels = self.in_channels if i == 0 else self.out_channels - - res_block = FlaxResnetBlock2D( - in_channels=in_channels, - out_channels=self.out_channels, - 
dropout_prob=self.dropout, - dtype=self.dtype, - ) - resnets.append(res_block) - self.resnets = resnets - - if self.add_downsample: - self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) - - def __call__(self, hidden_states, temb, deterministic=True): - output_states = () - - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb, deterministic=deterministic) - output_states += (hidden_states,) - - if self.add_downsample: - hidden_states = self.downsamplers_0(hidden_states) - output_states += (hidden_states,) - - return hidden_states, output_states - - -class FlaxCrossAttnUpBlock2D(nn.Module): - r""" - Cross Attention 2D Upsampling block - original architecture from Unet transformers: - https://arxiv.org/abs/2103.06104 - - Parameters: - in_channels (:obj:`int`): - Input channels - out_channels (:obj:`int`): - Output channels - dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): - Number of attention blocks layers - attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): - Number of attention heads of each spatial transformer block - add_upsample (:obj:`bool`, *optional*, defaults to `True`): - Whether to add upsampling layer before each final output - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - in_channels: int - out_channels: int - prev_output_channel: int - dropout: float = 0.0 - num_layers: int = 1 - attn_num_head_channels: int = 1 - add_upsample: bool = True - use_linear_projection: bool = False - only_cross_attention: bool = False - dtype: jnp.dtype = jnp.float32 - - def setup(self): - resnets = [] - attentions = [] - - for i in range(self.num_layers): - res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels - resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels - - res_block = FlaxResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=self.out_channels, - dropout_prob=self.dropout, - dtype=self.dtype, - ) - resnets.append(res_block) - - attn_block = FlaxTransformer2DModel( - in_channels=self.out_channels, - n_heads=self.attn_num_head_channels, - d_head=self.out_channels // self.attn_num_head_channels, - depth=1, - use_linear_projection=self.use_linear_projection, - only_cross_attention=self.only_cross_attention, - dtype=self.dtype, - ) - attentions.append(attn_block) - - self.resnets = resnets - self.attentions = attentions - - if self.add_upsample: - self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) - - def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) - - hidden_states = resnet(hidden_states, temb, deterministic=deterministic) - hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) - - if self.add_upsample: - hidden_states = self.upsamplers_0(hidden_states) - - return hidden_states - - -class FlaxUpBlock2D(nn.Module): - r""" - Flax 2D upsampling block - - Parameters: - in_channels (:obj:`int`): - Input channels - out_channels (:obj:`int`): - Output channels - prev_output_channel (:obj:`int`): - Output channels from the previous block - dropout 
(:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): - Number of attention blocks layers - add_downsample (:obj:`bool`, *optional*, defaults to `True`): - Whether to add downsampling layer before each final output - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - in_channels: int - out_channels: int - prev_output_channel: int - dropout: float = 0.0 - num_layers: int = 1 - add_upsample: bool = True - dtype: jnp.dtype = jnp.float32 - - def setup(self): - resnets = [] - - for i in range(self.num_layers): - res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels - resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels - - res_block = FlaxResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=self.out_channels, - dropout_prob=self.dropout, - dtype=self.dtype, - ) - resnets.append(res_block) - - self.resnets = resnets - - if self.add_upsample: - self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) - - def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) - - hidden_states = resnet(hidden_states, temb, deterministic=deterministic) - - if self.add_upsample: - hidden_states = self.upsamplers_0(hidden_states) - - return hidden_states - - -class FlaxUNetMidBlock2DCrossAttn(nn.Module): - r""" - Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104 - - Parameters: - in_channels (:obj:`int`): - Input channels - dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): - Number of attention blocks layers - attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): - Number of attention heads of each spatial transformer block - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - in_channels: int - dropout: float = 0.0 - num_layers: int = 1 - attn_num_head_channels: int = 1 - use_linear_projection: bool = False - dtype: jnp.dtype = jnp.float32 - - def setup(self): - # there is always at least one resnet - resnets = [ - FlaxResnetBlock2D( - in_channels=self.in_channels, - out_channels=self.in_channels, - dropout_prob=self.dropout, - dtype=self.dtype, - ) - ] - - attentions = [] - - for _ in range(self.num_layers): - attn_block = FlaxTransformer2DModel( - in_channels=self.in_channels, - n_heads=self.attn_num_head_channels, - d_head=self.in_channels // self.attn_num_head_channels, - depth=1, - use_linear_projection=self.use_linear_projection, - dtype=self.dtype, - ) - attentions.append(attn_block) - - res_block = FlaxResnetBlock2D( - in_channels=self.in_channels, - out_channels=self.in_channels, - dropout_prob=self.dropout, - dtype=self.dtype, - ) - resnets.append(res_block) - - self.resnets = resnets - self.attentions = attentions - - def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) - hidden_states = resnet(hidden_states, temb, 
deterministic=deterministic) - - return hidden_states diff --git a/my_diffusers/models/unet_2d_condition.py b/my_diffusers/models/unet_2d_condition.py deleted file mode 100644 index a5f997e9b400840e6ac55caac8890ff39fdf7fa0..0000000000000000000000000000000000000000 --- a/my_diffusers/models/unet_2d_condition.py +++ /dev/null @@ -1,654 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.utils.checkpoint - -from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import UNet2DConditionLoadersMixin -from ..utils import BaseOutput, logging -from .cross_attention import AttnProcessor -from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .modeling_utils import ModelMixin -from .unet_2d_blocks import ( - CrossAttnDownBlock2D, - CrossAttnUpBlock2D, - DownBlock2D, - UNetMidBlock2DCrossAttn, - UNetMidBlock2DSimpleCrossAttn, - UpBlock2D, - get_down_block, - get_up_block, -) - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class UNet2DConditionOutput(BaseOutput): - """ - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. - """ - - sample: torch.FloatTensor - - -class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): - r""" - UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep - and returns sample shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): - The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the - mid block layer if `None`. 
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): - The tuple of upsample blocks to use. - only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): - Whether to include self-attention in the basic transformer blocks, see - [`~models.attention.BasicTransformerBlock`]. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - If `None`, it will skip the normalization and activation layers in post-processing - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. - resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately - summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`. - num_class_embeds (`int`, *optional*, defaults to None): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - time_embedding_type (`str`, *optional*, default to `positional`): - The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. - timestep_post_act (`str, *optional*, default to `None`): - The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. - time_cond_proj_dim (`int`, *optional*, default to `None`): - The dimension of `cond_proj` layer in timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. - conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. - projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when - using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. 
- """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, - attention_head_dim: Union[int, Tuple[int]] = 8, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - time_embedding_type: str = "positional", - timestep_post_act: Optional[str] = None, - time_cond_proj_dim: Optional[int] = None, - conv_in_kernel: int = 3, - conv_out_kernel: int = 3, - projection_class_embeddings_input_dim: Optional[int] = None, - ): - super().__init__() - - self.sample_size = sample_size - - # Check inputs - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." - ) - - # input - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) - - # time - if time_embedding_type == "fourier": - time_embed_dim = block_out_channels[0] * 2 - if time_embed_dim % 2 != 0: - raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") - self.time_proj = GaussianFourierProjection( - time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos - ) - timestep_input_dim = time_embed_dim - elif time_embedding_type == "positional": - time_embed_dim = block_out_channels[0] * 4 - - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - else: - raise ValueError( - f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`." 
- ) - - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - post_act_fn=timestep_post_act, - cond_proj_dim=time_cond_proj_dim, - ) - - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - # The projection `class_embed_type` is the same as the timestep `class_embed_type` except - # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings - # 2. it projects from an arbitrary input dimension. - # - # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. - # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. - # As a result, `TimestepEmbedding` can be passed arbitrary vectors. - self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - else: - self.class_embedding = None - - self.down_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) - - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[i], - downsample_padding=downsample_padding, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.down_blocks.append(down_block) - - # mid - if mid_block_type == "UNetMidBlock2DCrossAttn": - self.mid_block = UNetMidBlock2DCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": - self.mid_block = UNetMidBlock2DSimpleCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, - 
attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif mid_block_type is None: - self.mid_block = None - else: - raise ValueError(f"unknown mid_block_type : {mid_block_type}") - - # count how many layers upsample the images - self.num_upsamplers = 0 - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_attention_head_dim = list(reversed(attention_head_dim)) - only_cross_attention = list(reversed(only_cross_attention)) - - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 - - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - # add upsample block for all BUT final layer - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=reversed_attention_head_dim[i], - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_num_groups is not None: - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps - ) - self.conv_act = nn.SiLU() - else: - self.conv_norm_out = None - self.conv_act = None - - conv_out_padding = (conv_out_kernel - 1) // 2 - self.conv_out = nn.Conv2d( - block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding - ) - - @property - def attn_processors(self) -> Dict[str, AttnProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]): - if hasattr(module, "set_processor"): - processors[f"{name}.processor"] = module.processor - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]): - r""" - Parameters: - `processor (`dict` of `AttnProcessor` or `AttnProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `CrossAttention` layers. - In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. 
This is strongly recommended when setting trainablae attention processors.: - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def set_attention_slice(self, slice_size): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_slicable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_slicable_dims(module) - - num_slicable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_slicable_layers * [1] - - slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. 
- # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): - module.gradient_checkpointing = value - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - mid_block_additional_residual: Optional[torch.Tensor] = None, - return_dict: bool = True, - ) -> Union[UNet2DConditionOutput, Tuple]: - r""" - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - - Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - default_overall_up_factor = 2**self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False - upsample_size = None - - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - - # prepare attention_mask - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) - - emb = self.time_embedding(t_emb, timestep_cond) - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb - - # 2. pre-process - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - if down_block_additional_residuals is not None: - new_down_block_res_samples = () - - for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals - ): - down_block_res_sample += down_block_additional_residual - new_down_block_res_samples += (down_block_res_sample,) - - down_block_res_samples = new_down_block_res_samples - - # 4. mid - if self.mid_block is not None: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - - if mid_block_additional_residual is not None: - sample += mid_block_additional_residual - - # 5. up - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - upsample_size=upsample_size, - attention_mask=attention_mask, - ) - else: - sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size - ) - - # 6. 
post-process - if self.conv_norm_out: - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) diff --git a/my_diffusers/models/unet_2d_condition_flax.py b/my_diffusers/models/unet_2d_condition_flax.py deleted file mode 100644 index a40473a25f558e3f53820b5e26c909244de1592f..0000000000000000000000000000000000000000 --- a/my_diffusers/models/unet_2d_condition_flax.py +++ /dev/null @@ -1,321 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Tuple, Union - -import flax -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict - -from ..configuration_utils import ConfigMixin, flax_register_to_config -from ..utils import BaseOutput -from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps -from .modeling_flax_utils import FlaxModelMixin -from .unet_2d_blocks_flax import ( - FlaxCrossAttnDownBlock2D, - FlaxCrossAttnUpBlock2D, - FlaxDownBlock2D, - FlaxUNetMidBlock2DCrossAttn, - FlaxUpBlock2D, -) - - -@flax.struct.dataclass -class FlaxUNet2DConditionOutput(BaseOutput): - """ - Args: - sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): - Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. - """ - - sample: jnp.ndarray - - -@flax_register_to_config -class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): - r""" - FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a - timestep and returns sample shaped output. - - This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) - - Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) - subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to - general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - sample_size (`int`, *optional*): - The size of the input sample. - in_channels (`int`, *optional*, defaults to 4): - The number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): - The number of channels in the output. 
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D", - "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D" - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): - The tuple of upsample blocks to use. The corresponding class names will be: "FlaxUpBlock2D", - "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D" - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8): - The dimension of the attention heads. - cross_attention_dim (`int`, *optional*, defaults to 768): - The dimension of the cross attention features. - dropout (`float`, *optional*, defaults to 0): - Dropout probability for down, up and bottleneck blocks. - flip_sin_to_cos (`bool`, *optional*, defaults to `True`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - - """ - - sample_size: int = 32 - in_channels: int = 4 - out_channels: int = 4 - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ) - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") - only_cross_attention: Union[bool, Tuple[bool]] = False - block_out_channels: Tuple[int] = (320, 640, 1280, 1280) - layers_per_block: int = 2 - attention_head_dim: Union[int, Tuple[int]] = 8 - cross_attention_dim: int = 1280 - dropout: float = 0.0 - use_linear_projection: bool = False - dtype: jnp.dtype = jnp.float32 - flip_sin_to_cos: bool = True - freq_shift: int = 0 - - def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: - # init input tensors - sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) - sample = jnp.zeros(sample_shape, dtype=jnp.float32) - timesteps = jnp.ones((1,), dtype=jnp.int32) - encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"] - - def setup(self): - block_out_channels = self.block_out_channels - time_embed_dim = block_out_channels[0] * 4 - - # input - self.conv_in = nn.Conv( - block_out_channels[0], - kernel_size=(3, 3), - strides=(1, 1), - padding=((1, 1), (1, 1)), - dtype=self.dtype, - ) - - # time - self.time_proj = FlaxTimesteps( - block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift - ) - self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) - - only_cross_attention = self.only_cross_attention - if isinstance(only_cross_attention, bool): - only_cross_attention = (only_cross_attention,) * len(self.down_block_types) - - attention_head_dim = self.attention_head_dim - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(self.down_block_types) - - # down - down_blocks = [] - output_channel = 
block_out_channels[0] - for i, down_block_type in enumerate(self.down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - if down_block_type == "CrossAttnDownBlock2D": - down_block = FlaxCrossAttnDownBlock2D( - in_channels=input_channel, - out_channels=output_channel, - dropout=self.dropout, - num_layers=self.layers_per_block, - attn_num_head_channels=attention_head_dim[i], - add_downsample=not is_final_block, - use_linear_projection=self.use_linear_projection, - only_cross_attention=only_cross_attention[i], - dtype=self.dtype, - ) - else: - down_block = FlaxDownBlock2D( - in_channels=input_channel, - out_channels=output_channel, - dropout=self.dropout, - num_layers=self.layers_per_block, - add_downsample=not is_final_block, - dtype=self.dtype, - ) - - down_blocks.append(down_block) - self.down_blocks = down_blocks - - # mid - self.mid_block = FlaxUNetMidBlock2DCrossAttn( - in_channels=block_out_channels[-1], - dropout=self.dropout, - attn_num_head_channels=attention_head_dim[-1], - use_linear_projection=self.use_linear_projection, - dtype=self.dtype, - ) - - # up - up_blocks = [] - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_attention_head_dim = list(reversed(attention_head_dim)) - only_cross_attention = list(reversed(only_cross_attention)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(self.up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - is_final_block = i == len(block_out_channels) - 1 - - if up_block_type == "CrossAttnUpBlock2D": - up_block = FlaxCrossAttnUpBlock2D( - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - num_layers=self.layers_per_block + 1, - attn_num_head_channels=reversed_attention_head_dim[i], - add_upsample=not is_final_block, - dropout=self.dropout, - use_linear_projection=self.use_linear_projection, - only_cross_attention=only_cross_attention[i], - dtype=self.dtype, - ) - else: - up_block = FlaxUpBlock2D( - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - num_layers=self.layers_per_block + 1, - add_upsample=not is_final_block, - dropout=self.dropout, - dtype=self.dtype, - ) - - up_blocks.append(up_block) - prev_output_channel = output_channel - self.up_blocks = up_blocks - - # out - self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5) - self.conv_out = nn.Conv( - self.out_channels, - kernel_size=(3, 3), - strides=(1, 1), - padding=((1, 1), (1, 1)), - dtype=self.dtype, - ) - - def __call__( - self, - sample, - timesteps, - encoder_hidden_states, - return_dict: bool = True, - train: bool = False, - ) -> Union[FlaxUNet2DConditionOutput, Tuple]: - r""" - Args: - sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor - timestep (`jnp.ndarray` or `float` or `int`): timesteps - encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a - plain tuple. - train (`bool`, *optional*, defaults to `False`): - Use deterministic functions and disable dropout when not training. 
- - Returns: - [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is the sample tensor. - """ - # 1. time - if not isinstance(timesteps, jnp.ndarray): - timesteps = jnp.array([timesteps], dtype=jnp.int32) - elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: - timesteps = timesteps.astype(dtype=jnp.float32) - timesteps = jnp.expand_dims(timesteps, 0) - - t_emb = self.time_proj(timesteps) - t_emb = self.time_embedding(t_emb) - - # 2. pre-process - sample = jnp.transpose(sample, (0, 2, 3, 1)) - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for down_block in self.down_blocks: - if isinstance(down_block, FlaxCrossAttnDownBlock2D): - sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train) - else: - sample, res_samples = down_block(sample, t_emb, deterministic=not train) - down_block_res_samples += res_samples - - # 4. mid - sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) - - # 5. up - for up_block in self.up_blocks: - res_samples = down_block_res_samples[-(self.layers_per_block + 1) :] - down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)] - if isinstance(up_block, FlaxCrossAttnUpBlock2D): - sample = up_block( - sample, - temb=t_emb, - encoder_hidden_states=encoder_hidden_states, - res_hidden_states_tuple=res_samples, - deterministic=not train, - ) - else: - sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train) - - # 6. post-process - sample = self.conv_norm_out(sample) - sample = nn.silu(sample) - sample = self.conv_out(sample) - sample = jnp.transpose(sample, (0, 3, 1, 2)) - - if not return_dict: - return (sample,) - - return FlaxUNet2DConditionOutput(sample=sample) diff --git a/my_diffusers/models/vae.py b/my_diffusers/models/vae.py deleted file mode 100644 index c5142a8f15b7dccabc53b2a7abf90b95ee499558..0000000000000000000000000000000000000000 --- a/my_diffusers/models/vae.py +++ /dev/null @@ -1,356 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Optional - -import numpy as np -import torch -import torch.nn as nn - -from ..utils import BaseOutput, randn_tensor -from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block - - -@dataclass -class DecoderOutput(BaseOutput): - """ - Output of decoding method. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Decoded output sample of the model. Output of the last layer of the model. 
- """ - - sample: torch.FloatTensor - - -class Encoder(nn.Module): - def __init__( - self, - in_channels=3, - out_channels=3, - down_block_types=("DownEncoderBlock2D",), - block_out_channels=(64,), - layers_per_block=2, - norm_num_groups=32, - act_fn="silu", - double_z=True, - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) - - self.mid_block = None - self.down_blocks = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=self.layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - add_downsample=not is_final_block, - resnet_eps=1e-6, - downsample_padding=0, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attn_num_head_channels=None, - temb_channels=None, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attn_num_head_channels=None, - resnet_groups=norm_num_groups, - temb_channels=None, - ) - - # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) - self.conv_act = nn.SiLU() - - conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) - - def forward(self, x): - sample = x - sample = self.conv_in(sample) - - # down - for down_block in self.down_blocks: - sample = down_block(sample) - - # middle - sample = self.mid_block(sample) - - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - return sample - - -class Decoder(nn.Module): - def __init__( - self, - in_channels=3, - out_channels=3, - up_block_types=("UpDecoderBlock2D",), - block_out_channels=(64,), - layers_per_block=2, - norm_num_groups=32, - act_fn="silu", - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1) - - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - # mid - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attn_num_head_channels=None, - resnet_groups=norm_num_groups, - temb_channels=None, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = get_up_block( - up_block_type, - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - prev_output_channel=None, - add_upsample=not is_final_block, - resnet_eps=1e-6, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - attn_num_head_channels=None, - temb_channels=None, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - 
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) - self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) - - def forward(self, z): - sample = z - sample = self.conv_in(sample) - - # middle - sample = self.mid_block(sample) - - # up - for up_block in self.up_blocks: - sample = up_block(sample) - - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - return sample - - -class VectorQuantizer(nn.Module): - """ - Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix - multiplications and allows for post-hoc remapping of indices. - """ - - # NOTE: due to a bug the beta term was applied to the wrong term. for - # backwards compatibility we use the buggy version by default, but you can - # specify legacy=False to fix it. - def __init__( - self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True - ): - super().__init__() - self.n_e = n_e - self.vq_embed_dim = vq_embed_dim - self.beta = beta - self.legacy = legacy - - self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - print( - f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." 
- ) - else: - self.re_embed = n_e - - self.sane_index_shape = sane_index_shape - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - match = (inds[:, :, None] == used[None, None, ...]).long() - new = match.argmax(-1) - unknown = match.sum(2) < 1 - if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds >= self.used.shape[0]] = 0 # simply set to zero - back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) - return back.reshape(ishape) - - def forward(self, z): - # reshape z -> (batch, height, width, channel) and flatten - z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.vq_embed_dim) - - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1) - - z_q = self.embedding(min_encoding_indices).view(z.shape) - perplexity = None - min_encodings = None - - # compute loss for embedding - if not self.legacy: - loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) - else: - loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten - - if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) - - return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - - def get_codebook_entry(self, indices, shape): - # shape specifying (batch, height, width, channel) - if self.remap is not None: - indices = indices.reshape(shape[0], -1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like( - self.mean, device=self.parameters.device, dtype=self.parameters.dtype - ) - - def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: - # make sure sample is on the same device as the parameters and has same dtype - sample = randn_tensor( - self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype - ) - x = self.mean + self.std * sample - 
return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, sample, dims=[1, 2, 3]): - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) - - def mode(self): - return self.mean diff --git a/my_diffusers/models/vae_flax.py b/my_diffusers/models/vae_flax.py deleted file mode 100644 index 994e3bb06adc0318acade07a4c29f95d71552320..0000000000000000000000000000000000000000 --- a/my_diffusers/models/vae_flax.py +++ /dev/null @@ -1,866 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers - -import math -from functools import partial -from typing import Tuple - -import flax -import flax.linen as nn -import jax -import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict - -from ..configuration_utils import ConfigMixin, flax_register_to_config -from ..utils import BaseOutput -from .modeling_flax_utils import FlaxModelMixin - - -@flax.struct.dataclass -class FlaxDecoderOutput(BaseOutput): - """ - Output of decoding method. - - Args: - sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): - Decoded output sample of the model. Output of the last layer of the model. - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - - sample: jnp.ndarray - - -@flax.struct.dataclass -class FlaxAutoencoderKLOutput(BaseOutput): - """ - Output of AutoencoderKL encoding method. - - Args: - latent_dist (`FlaxDiagonalGaussianDistribution`): - Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`. - `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution. 
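The `VectorQuantizer.forward` above snaps each spatial position of the encoder output to its nearest codebook vector and then copies gradients straight through that non-differentiable lookup. Below is a minimal, self-contained sketch of the same steps (nearest-neighbour lookup via `torch.cdist`, the two MSE terms, and the straight-through trick); the tensor sizes are arbitrary, and the `beta` weighting and index remapping options are omitted.

```python
# Sketch mirroring VectorQuantizer.forward above; not the deleted class itself.
import torch

torch.manual_seed(0)
n_e, d = 16, 4                                   # codebook size, embedding dim
codebook = torch.nn.Embedding(n_e, d)

z = torch.randn(2, d, 8, 8, requires_grad=True)  # encoder output, NCHW
z_perm = z.permute(0, 2, 3, 1).contiguous()      # to NHWC, as in the forward above
z_flat = z_perm.view(-1, d)

indices = torch.argmin(torch.cdist(z_flat, codebook.weight), dim=1)
z_q = codebook(indices).view(z_perm.shape)

# codebook / commitment losses (beta weighting omitted for brevity)
loss = torch.mean((z_q.detach() - z_perm) ** 2) + torch.mean((z_q - z_perm.detach()) ** 2)

# straight-through estimator: forward uses z_q, backward sees the identity w.r.t. z
z_q = z_perm + (z_q - z_perm).detach()
z_q = z_q.permute(0, 3, 1, 2).contiguous()       # back to NCHW
print(z_q.shape, loss.item())
```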
- """ - - latent_dist: "FlaxDiagonalGaussianDistribution" - - -class FlaxUpsample2D(nn.Module): - """ - Flax implementation of 2D Upsample layer - - Args: - in_channels (`int`): - Input channels - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - - in_channels: int - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.conv = nn.Conv( - self.in_channels, - kernel_size=(3, 3), - strides=(1, 1), - padding=((1, 1), (1, 1)), - dtype=self.dtype, - ) - - def __call__(self, hidden_states): - batch, height, width, channels = hidden_states.shape - hidden_states = jax.image.resize( - hidden_states, - shape=(batch, height * 2, width * 2, channels), - method="nearest", - ) - hidden_states = self.conv(hidden_states) - return hidden_states - - -class FlaxDownsample2D(nn.Module): - """ - Flax implementation of 2D Downsample layer - - Args: - in_channels (`int`): - Input channels - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - - in_channels: int - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.conv = nn.Conv( - self.in_channels, - kernel_size=(3, 3), - strides=(2, 2), - padding="VALID", - dtype=self.dtype, - ) - - def __call__(self, hidden_states): - pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim - hidden_states = jnp.pad(hidden_states, pad_width=pad) - hidden_states = self.conv(hidden_states) - return hidden_states - - -class FlaxResnetBlock2D(nn.Module): - """ - Flax implementation of 2D Resnet Block. - - Args: - in_channels (`int`): - Input channels - out_channels (`int`): - Output channels - dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - groups (:obj:`int`, *optional*, defaults to `32`): - The number of groups to use for group norm. - use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`): - Whether to use `nin_shortcut`. 
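`FlaxDownsample2D` above pads only the bottom and right edges before a 3x3, stride-2 `VALID` convolution, which halves even spatial sizes exactly. The following small sketch reproduces just that padding trick in a throwaway Flax module; the channel count and input size are made up.

```python
# Minimal reproduction of the asymmetric-pad + stride-2 VALID conv trick above.
import jax
import jax.numpy as jnp
import flax.linen as nn


class Down(nn.Module):
    features: int = 4

    @nn.compact
    def __call__(self, x):
        x = jnp.pad(x, ((0, 0), (0, 1), (0, 1), (0, 0)))  # pad H and W on one side
        return nn.Conv(self.features, (3, 3), strides=(2, 2), padding="VALID")(x)


x = jnp.ones((1, 8, 8, 3))                                # NHWC, as in the Flax code
y, _ = Down().init_with_output(jax.random.PRNGKey(0), x)
print(y.shape)  # (1, 4, 4, 4): 8 -> padded to 9 -> (9 - 3) // 2 + 1 = 4
```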
This activates a new layer inside ResNet block - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - - in_channels: int - out_channels: int = None - dropout: float = 0.0 - groups: int = 32 - use_nin_shortcut: bool = None - dtype: jnp.dtype = jnp.float32 - - def setup(self): - out_channels = self.in_channels if self.out_channels is None else self.out_channels - - self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) - self.conv1 = nn.Conv( - out_channels, - kernel_size=(3, 3), - strides=(1, 1), - padding=((1, 1), (1, 1)), - dtype=self.dtype, - ) - - self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) - self.dropout_layer = nn.Dropout(self.dropout) - self.conv2 = nn.Conv( - out_channels, - kernel_size=(3, 3), - strides=(1, 1), - padding=((1, 1), (1, 1)), - dtype=self.dtype, - ) - - use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut - - self.conv_shortcut = None - if use_nin_shortcut: - self.conv_shortcut = nn.Conv( - out_channels, - kernel_size=(1, 1), - strides=(1, 1), - padding="VALID", - dtype=self.dtype, - ) - - def __call__(self, hidden_states, deterministic=True): - residual = hidden_states - hidden_states = self.norm1(hidden_states) - hidden_states = nn.swish(hidden_states) - hidden_states = self.conv1(hidden_states) - - hidden_states = self.norm2(hidden_states) - hidden_states = nn.swish(hidden_states) - hidden_states = self.dropout_layer(hidden_states, deterministic) - hidden_states = self.conv2(hidden_states) - - if self.conv_shortcut is not None: - residual = self.conv_shortcut(residual) - - return hidden_states + residual - - -class FlaxAttentionBlock(nn.Module): - r""" - Flax Convolutional based multi-head attention block for diffusion-based VAE. 
- - Parameters: - channels (:obj:`int`): - Input channels - num_head_channels (:obj:`int`, *optional*, defaults to `None`): - Number of attention heads - num_groups (:obj:`int`, *optional*, defaults to `32`): - The number of groups to use for group norm - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - - """ - channels: int - num_head_channels: int = None - num_groups: int = 32 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1 - - dense = partial(nn.Dense, self.channels, dtype=self.dtype) - - self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6) - self.query, self.key, self.value = dense(), dense(), dense() - self.proj_attn = dense() - - def transpose_for_scores(self, projection): - new_projection_shape = projection.shape[:-1] + (self.num_heads, -1) - # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) - new_projection = projection.reshape(new_projection_shape) - # (B, T, H, D) -> (B, H, T, D) - new_projection = jnp.transpose(new_projection, (0, 2, 1, 3)) - return new_projection - - def __call__(self, hidden_states): - residual = hidden_states - batch, height, width, channels = hidden_states.shape - - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.reshape((batch, height * width, channels)) - - query = self.query(hidden_states) - key = self.key(hidden_states) - value = self.value(hidden_states) - - # transpose - query = self.transpose_for_scores(query) - key = self.transpose_for_scores(key) - value = self.transpose_for_scores(value) - - # compute attentions - scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) - attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale) - attn_weights = nn.softmax(attn_weights, axis=-1) - - # attend to values - hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights) - - hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3)) - new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,) - hidden_states = hidden_states.reshape(new_hidden_states_shape) - - hidden_states = self.proj_attn(hidden_states) - hidden_states = hidden_states.reshape((batch, height, width, channels)) - hidden_states = hidden_states + residual - return hidden_states - - -class FlaxDownEncoderBlock2D(nn.Module): - r""" - Flax Resnet blocks-based Encoder block for diffusion-based VAE. 
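`FlaxAttentionBlock` above scales both the query and the key by `1 / sqrt(sqrt(d))` before the dot product rather than dividing the logits by `sqrt(d)` afterwards; the two are algebraically identical, but the split form keeps intermediate values smaller, which helps in reduced precision. A quick numerical check with arbitrary shapes:

```python
# Verifies (q * s) . (k * s) == q . k / sqrt(d) for s = d ** -0.25.
import math
import jax
import jax.numpy as jnp

d = 64
q_rng, k_rng = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(q_rng, (1, 1, 16, d))   # (batch, heads, tokens, dim)
k = jax.random.normal(k_rng, (1, 1, 16, d))

scale = 1 / math.sqrt(math.sqrt(d))
split_scaled = jnp.einsum("...qc,...kc->...qk", q * scale, k * scale)
plain = jnp.einsum("...qc,...kc->...qk", q, k) / math.sqrt(d)
print(jnp.max(jnp.abs(split_scaled - plain)))  # ~0 up to float32 rounding
```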
- - Parameters: - in_channels (:obj:`int`): - Input channels - out_channels (:obj:`int`): - Output channels - dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): - Number of Resnet layer block - resnet_groups (:obj:`int`, *optional*, defaults to `32`): - The number of groups to use for the Resnet block group norm - add_downsample (:obj:`bool`, *optional*, defaults to `True`): - Whether to add downsample layer - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - in_channels: int - out_channels: int - dropout: float = 0.0 - num_layers: int = 1 - resnet_groups: int = 32 - add_downsample: bool = True - dtype: jnp.dtype = jnp.float32 - - def setup(self): - resnets = [] - for i in range(self.num_layers): - in_channels = self.in_channels if i == 0 else self.out_channels - - res_block = FlaxResnetBlock2D( - in_channels=in_channels, - out_channels=self.out_channels, - dropout=self.dropout, - groups=self.resnet_groups, - dtype=self.dtype, - ) - resnets.append(res_block) - self.resnets = resnets - - if self.add_downsample: - self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) - - def __call__(self, hidden_states, deterministic=True): - for resnet in self.resnets: - hidden_states = resnet(hidden_states, deterministic=deterministic) - - if self.add_downsample: - hidden_states = self.downsamplers_0(hidden_states) - - return hidden_states - - -class FlaxUpDecoderBlock2D(nn.Module): - r""" - Flax Resnet blocks-based Decoder block for diffusion-based VAE. - - Parameters: - in_channels (:obj:`int`): - Input channels - out_channels (:obj:`int`): - Output channels - dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): - Number of Resnet layer block - resnet_groups (:obj:`int`, *optional*, defaults to `32`): - The number of groups to use for the Resnet block group norm - add_upsample (:obj:`bool`, *optional*, defaults to `True`): - Whether to add upsample layer - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - in_channels: int - out_channels: int - dropout: float = 0.0 - num_layers: int = 1 - resnet_groups: int = 32 - add_upsample: bool = True - dtype: jnp.dtype = jnp.float32 - - def setup(self): - resnets = [] - for i in range(self.num_layers): - in_channels = self.in_channels if i == 0 else self.out_channels - res_block = FlaxResnetBlock2D( - in_channels=in_channels, - out_channels=self.out_channels, - dropout=self.dropout, - groups=self.resnet_groups, - dtype=self.dtype, - ) - resnets.append(res_block) - - self.resnets = resnets - - if self.add_upsample: - self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) - - def __call__(self, hidden_states, deterministic=True): - for resnet in self.resnets: - hidden_states = resnet(hidden_states, deterministic=deterministic) - - if self.add_upsample: - hidden_states = self.upsamplers_0(hidden_states) - - return hidden_states - - -class FlaxUNetMidBlock2D(nn.Module): - r""" - Flax Unet Mid-Block module. 
- - Parameters: - in_channels (:obj:`int`): - Input channels - dropout (:obj:`float`, *optional*, defaults to 0.0): - Dropout rate - num_layers (:obj:`int`, *optional*, defaults to 1): - Number of Resnet layer block - resnet_groups (:obj:`int`, *optional*, defaults to `32`): - The number of groups to use for the Resnet and Attention block group norm - attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`): - Number of attention heads for each attention block - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - in_channels: int - dropout: float = 0.0 - num_layers: int = 1 - resnet_groups: int = 32 - attn_num_head_channels: int = 1 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32) - - # there is always at least one resnet - resnets = [ - FlaxResnetBlock2D( - in_channels=self.in_channels, - out_channels=self.in_channels, - dropout=self.dropout, - groups=resnet_groups, - dtype=self.dtype, - ) - ] - - attentions = [] - - for _ in range(self.num_layers): - attn_block = FlaxAttentionBlock( - channels=self.in_channels, - num_head_channels=self.attn_num_head_channels, - num_groups=resnet_groups, - dtype=self.dtype, - ) - attentions.append(attn_block) - - res_block = FlaxResnetBlock2D( - in_channels=self.in_channels, - out_channels=self.in_channels, - dropout=self.dropout, - groups=resnet_groups, - dtype=self.dtype, - ) - resnets.append(res_block) - - self.resnets = resnets - self.attentions = attentions - - def __call__(self, hidden_states, deterministic=True): - hidden_states = self.resnets[0](hidden_states, deterministic=deterministic) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states) - hidden_states = resnet(hidden_states, deterministic=deterministic) - - return hidden_states - - -class FlaxEncoder(nn.Module): - r""" - Flax Implementation of VAE Encoder. - - This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) - subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to - general usage and behavior. 
- - Finally, this model supports inherent JAX features such as: - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - in_channels (:obj:`int`, *optional*, defaults to 3): - Input channels - out_channels (:obj:`int`, *optional*, defaults to 3): - Output channels - down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): - DownEncoder block type - block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): - Tuple containing the number of output channels for each block - layers_per_block (:obj:`int`, *optional*, defaults to `2`): - Number of Resnet layer for each block - norm_num_groups (:obj:`int`, *optional*, defaults to `32`): - norm num group - act_fn (:obj:`str`, *optional*, defaults to `silu`): - Activation function - double_z (:obj:`bool`, *optional*, defaults to `False`): - Whether to double the last output channels - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` - """ - in_channels: int = 3 - out_channels: int = 3 - down_block_types: Tuple[str] = ("DownEncoderBlock2D",) - block_out_channels: Tuple[int] = (64,) - layers_per_block: int = 2 - norm_num_groups: int = 32 - act_fn: str = "silu" - double_z: bool = False - dtype: jnp.dtype = jnp.float32 - - def setup(self): - block_out_channels = self.block_out_channels - # in - self.conv_in = nn.Conv( - block_out_channels[0], - kernel_size=(3, 3), - strides=(1, 1), - padding=((1, 1), (1, 1)), - dtype=self.dtype, - ) - - # downsampling - down_blocks = [] - output_channel = block_out_channels[0] - for i, _ in enumerate(self.down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = FlaxDownEncoderBlock2D( - in_channels=input_channel, - out_channels=output_channel, - num_layers=self.layers_per_block, - resnet_groups=self.norm_num_groups, - add_downsample=not is_final_block, - dtype=self.dtype, - ) - down_blocks.append(down_block) - self.down_blocks = down_blocks - - # middle - self.mid_block = FlaxUNetMidBlock2D( - in_channels=block_out_channels[-1], - resnet_groups=self.norm_num_groups, - attn_num_head_channels=None, - dtype=self.dtype, - ) - - # end - conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels - self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) - self.conv_out = nn.Conv( - conv_out_channels, - kernel_size=(3, 3), - strides=(1, 1), - padding=((1, 1), (1, 1)), - dtype=self.dtype, - ) - - def __call__(self, sample, deterministic: bool = True): - # in - sample = self.conv_in(sample) - - # downsampling - for block in self.down_blocks: - sample = block(sample, deterministic=deterministic) - - # middle - sample = self.mid_block(sample, deterministic=deterministic) - - # end - sample = self.conv_norm_out(sample) - sample = nn.swish(sample) - sample = self.conv_out(sample) - - return sample - - -class FlaxDecoder(nn.Module): - r""" - Flax Implementation of VAE Decoder. - - This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) - subclass. 
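When `double_z` is set, the `conv_out` of `FlaxEncoder` above emits `2 * out_channels` precisely so the result can later be split into a mean and a log-variance and sampled, as the `FlaxDiagonalGaussianDistribution` further down in this file does. Here is a hedged sketch of that split-and-sample step in isolation; the layout is channels-last and the sizes are made up.

```python
# Split doubled channels into mean/logvar and draw a reparameterized sample.
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
moments = jax.random.normal(rng, (1, 8, 8, 2 * 4))     # stand-in for the encoder output

mean, logvar = jnp.split(moments, 2, axis=-1)
logvar = jnp.clip(logvar, -30.0, 20.0)                 # same clamp range as the class
std = jnp.exp(0.5 * logvar)

latent = mean + std * jax.random.normal(jax.random.PRNGKey(1), mean.shape)
print(latent.shape)                                    # (1, 8, 8, 4)
```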
Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to - general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - in_channels (:obj:`int`, *optional*, defaults to 3): - Input channels - out_channels (:obj:`int`, *optional*, defaults to 3): - Output channels - up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): - UpDecoder block type - block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): - Tuple containing the number of output channels for each block - layers_per_block (:obj:`int`, *optional*, defaults to `2`): - Number of Resnet layer for each block - norm_num_groups (:obj:`int`, *optional*, defaults to `32`): - norm num group - act_fn (:obj:`str`, *optional*, defaults to `silu`): - Activation function - double_z (:obj:`bool`, *optional*, defaults to `False`): - Whether to double the last output channels - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - parameters `dtype` - """ - in_channels: int = 3 - out_channels: int = 3 - up_block_types: Tuple[str] = ("UpDecoderBlock2D",) - block_out_channels: int = (64,) - layers_per_block: int = 2 - norm_num_groups: int = 32 - act_fn: str = "silu" - dtype: jnp.dtype = jnp.float32 - - def setup(self): - block_out_channels = self.block_out_channels - - # z to block_in - self.conv_in = nn.Conv( - block_out_channels[-1], - kernel_size=(3, 3), - strides=(1, 1), - padding=((1, 1), (1, 1)), - dtype=self.dtype, - ) - - # middle - self.mid_block = FlaxUNetMidBlock2D( - in_channels=block_out_channels[-1], - resnet_groups=self.norm_num_groups, - attn_num_head_channels=None, - dtype=self.dtype, - ) - - # upsampling - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - up_blocks = [] - for i, _ in enumerate(self.up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - - up_block = FlaxUpDecoderBlock2D( - in_channels=prev_output_channel, - out_channels=output_channel, - num_layers=self.layers_per_block + 1, - resnet_groups=self.norm_num_groups, - add_upsample=not is_final_block, - dtype=self.dtype, - ) - up_blocks.append(up_block) - prev_output_channel = output_channel - - self.up_blocks = up_blocks - - # end - self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) - self.conv_out = nn.Conv( - self.out_channels, - kernel_size=(3, 3), - strides=(1, 1), - padding=((1, 1), (1, 1)), - dtype=self.dtype, - ) - - def __call__(self, sample, deterministic: bool = True): - # z to block_in - sample = self.conv_in(sample) - - # middle - sample = self.mid_block(sample, deterministic=deterministic) - - # upsampling - for block in self.up_blocks: - sample = block(sample, deterministic=deterministic) - - sample = self.conv_norm_out(sample) - sample = nn.swish(sample) - sample = self.conv_out(sample) - - return sample - - -class FlaxDiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - # Last axis to 
account for channels-last - self.mean, self.logvar = jnp.split(parameters, 2, axis=-1) - self.logvar = jnp.clip(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = jnp.exp(0.5 * self.logvar) - self.var = jnp.exp(self.logvar) - if self.deterministic: - self.var = self.std = jnp.zeros_like(self.mean) - - def sample(self, key): - return self.mean + self.std * jax.random.normal(key, self.mean.shape) - - def kl(self, other=None): - if self.deterministic: - return jnp.array([0.0]) - - if other is None: - return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3]) - - return 0.5 * jnp.sum( - jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, - axis=[1, 2, 3], - ) - - def nll(self, sample, axis=[1, 2, 3]): - if self.deterministic: - return jnp.array([0.0]) - - logtwopi = jnp.log(2.0 * jnp.pi) - return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis) - - def mode(self): - return self.mean - - -@flax_register_to_config -class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): - r""" - Flax Implementation of Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational - Bayes by Diederik P. Kingma and Max Welling. - - This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) - subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to - general usage and behavior. - - Finally, this model supports inherent JAX features such as: - - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) - - Parameters: - in_channels (:obj:`int`, *optional*, defaults to 3): - Input channels - out_channels (:obj:`int`, *optional*, defaults to 3): - Output channels - down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): - DownEncoder block type - up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): - UpDecoder block type - block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): - Tuple containing the number of output channels for each block - layers_per_block (:obj:`int`, *optional*, defaults to `2`): - Number of Resnet layer for each block - act_fn (:obj:`str`, *optional*, defaults to `silu`): - Activation function - latent_channels (:obj:`int`, *optional*, defaults to `4`): - Latent space channels - norm_num_groups (:obj:`int`, *optional*, defaults to `32`): - Norm num group - sample_size (:obj:`int`, *optional*, defaults to 32): - Sample input size - scaling_factor (`float`, *optional*, defaults to 0.18215): - The component-wise standard deviation of the trained latent space computed using the first batch of the - training set. This is used to scale the latent space to have unit variance when training the diffusion - model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the - diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 - / scaling_factor * z`. 
For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image - Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - parameters `dtype` - """ - in_channels: int = 3 - out_channels: int = 3 - down_block_types: Tuple[str] = ("DownEncoderBlock2D",) - up_block_types: Tuple[str] = ("UpDecoderBlock2D",) - block_out_channels: Tuple[int] = (64,) - layers_per_block: int = 1 - act_fn: str = "silu" - latent_channels: int = 4 - norm_num_groups: int = 32 - sample_size: int = 32 - scaling_factor: float = 0.18215 - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.encoder = FlaxEncoder( - in_channels=self.config.in_channels, - out_channels=self.config.latent_channels, - down_block_types=self.config.down_block_types, - block_out_channels=self.config.block_out_channels, - layers_per_block=self.config.layers_per_block, - act_fn=self.config.act_fn, - norm_num_groups=self.config.norm_num_groups, - double_z=True, - dtype=self.dtype, - ) - self.decoder = FlaxDecoder( - in_channels=self.config.latent_channels, - out_channels=self.config.out_channels, - up_block_types=self.config.up_block_types, - block_out_channels=self.config.block_out_channels, - layers_per_block=self.config.layers_per_block, - norm_num_groups=self.config.norm_num_groups, - act_fn=self.config.act_fn, - dtype=self.dtype, - ) - self.quant_conv = nn.Conv( - 2 * self.config.latent_channels, - kernel_size=(1, 1), - strides=(1, 1), - padding="VALID", - dtype=self.dtype, - ) - self.post_quant_conv = nn.Conv( - self.config.latent_channels, - kernel_size=(1, 1), - strides=(1, 1), - padding="VALID", - dtype=self.dtype, - ) - - def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: - # init input tensors - sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) - sample = jnp.zeros(sample_shape, dtype=jnp.float32) - - params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3) - rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng} - - return self.init(rngs, sample)["params"] - - def encode(self, sample, deterministic: bool = True, return_dict: bool = True): - sample = jnp.transpose(sample, (0, 2, 3, 1)) - - hidden_states = self.encoder(sample, deterministic=deterministic) - moments = self.quant_conv(hidden_states) - posterior = FlaxDiagonalGaussianDistribution(moments) - - if not return_dict: - return (posterior,) - - return FlaxAutoencoderKLOutput(latent_dist=posterior) - - def decode(self, latents, deterministic: bool = True, return_dict: bool = True): - if latents.shape[-1] != self.config.latent_channels: - latents = jnp.transpose(latents, (0, 2, 3, 1)) - - hidden_states = self.post_quant_conv(latents) - hidden_states = self.decoder(hidden_states, deterministic=deterministic) - - hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2)) - - if not return_dict: - return (hidden_states,) - - return FlaxDecoderOutput(sample=hidden_states) - - def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True): - posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict) - if sample_posterior: - rng = self.make_rng("gaussian") - hidden_states = posterior.latent_dist.sample(rng) - else: - hidden_states = posterior.latent_dist.mode() - - sample = self.decode(hidden_states, return_dict=return_dict).sample - - if not return_dict: - return (sample,) - - return FlaxDecoderOutput(sample=sample) diff --git 
a/my_diffusers/models/vq_model.py b/my_diffusers/models/vq_model.py deleted file mode 100644 index 65f734dccb2dd48174a48134294b597a2c0b8ea4..0000000000000000000000000000000000000000 --- a/my_diffusers/models/vq_model.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from .modeling_utils import ModelMixin -from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer - - -@dataclass -class VQEncoderOutput(BaseOutput): - """ - Output of VQModel encoding method. - - Args: - latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Encoded output sample of the model. Output of the last layer of the model. - """ - - latents: torch.FloatTensor - - -class VQModel(ModelMixin, ConfigMixin): - r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray - Kavukcuoglu. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) - - Parameters: - in_channels (int, *optional*, defaults to 3): Number of channels in the input image. - out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(64,)`): Tuple of block output channels. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. - sample_size (`int`, *optional*, defaults to `32`): TODO - num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. - vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. - scaling_factor (`float`, *optional*, defaults to `0.18215`): - The component-wise standard deviation of the trained latent space computed using the first batch of the - training set. This is used to scale the latent space to have unit variance when training the diffusion - model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the - diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 - / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image - Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. 
- """ - - @register_to_config - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock2D",), - up_block_types: Tuple[str] = ("UpDecoderBlock2D",), - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 3, - sample_size: int = 32, - num_vq_embeddings: int = 256, - norm_num_groups: int = 32, - vq_embed_dim: Optional[int] = None, - scaling_factor: float = 0.18215, - ): - super().__init__() - - # pass init params to Encoder - self.encoder = Encoder( - in_channels=in_channels, - out_channels=latent_channels, - down_block_types=down_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=False, - ) - - vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels - - self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1) - self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) - self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1) - - # pass init params to Decoder - self.decoder = Decoder( - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - ) - - def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: - h = self.encoder(x) - h = self.quant_conv(h) - - if not return_dict: - return (h,) - - return VQEncoderOutput(latents=h) - - def decode( - self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True - ) -> Union[DecoderOutput, torch.FloatTensor]: - # also go through quantization layer - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - r""" - Args: - sample (`torch.FloatTensor`): Input sample. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. - """ - x = sample - h = self.encode(x).latents - dec = self.decode(h).sample - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) diff --git a/my_diffusers/optimization.py b/my_diffusers/optimization.py deleted file mode 100644 index d7f923b49690da5df91f2ac219dc8bac2130af1d..0000000000000000000000000000000000000000 --- a/my_diffusers/optimization.py +++ /dev/null @@ -1,293 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
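A hedged usage sketch for the `VQModel` defined above. It assumes the upstream `diffusers` package, from which this vendored copy was taken, is still installed; with the default single-block config the latents keep the input's spatial size, and quantization to the nearest codebook entry happens inside `decode`.

```python
# Round-trip an image through the VQ-VAE; sizes are arbitrary.
import torch
from diffusers import VQModel

model = VQModel()                        # default config: one block, 3 latent channels
model.eval()

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    latents = model.encode(x).latents          # (1, 3, 32, 32) with the defaults
    recon = model.decode(latents).sample       # quantize -> post_quant_conv -> decoder
print(latents.shape, recon.shape)
```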
-"""PyTorch optimization for diffusion models.""" - -import math -from enum import Enum -from typing import Optional, Union - -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LambdaLR - -from .utils import logging - - -logger = logging.get_logger(__name__) - - -class SchedulerType(Enum): - LINEAR = "linear" - COSINE = "cosine" - COSINE_WITH_RESTARTS = "cosine_with_restarts" - POLYNOMIAL = "polynomial" - CONSTANT = "constant" - CONSTANT_WITH_WARMUP = "constant_with_warmup" - - -def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): - """ - Create a schedule with a constant learning rate, using the learning rate set in optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. - """ - return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) - - -def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): - """ - Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate - increases linearly between 0 and the initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_warmup_steps (`int`): - The number of steps for the warmup phase. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. - """ - - def lr_lambda(current_step: int): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1.0, num_warmup_steps)) - return 1.0 - - return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) - - -def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): - """ - Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after - a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. - """ - - def lr_lambda(current_step: int): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - return max( - 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) - ) - - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -def get_cosine_schedule_with_warmup( - optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 -): - """ - Create a schedule with a learning rate that decreases following the values of the cosine function between the - initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the - initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. 
- num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - num_periods (`float`, *optional*, defaults to 0.5): - The number of periods of the cosine function in a schedule (the default is to just decrease from the max - value to 0 following a half-cosine). - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. - """ - - def lr_lambda(current_step): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) - return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) - - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -def get_cosine_with_hard_restarts_schedule_with_warmup( - optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 -): - """ - Create a schedule with a learning rate that decreases following the values of the cosine function between the - initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases - linearly between 0 and the initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - num_cycles (`int`, *optional*, defaults to 1): - The number of hard restarts to use. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. - """ - - def lr_lambda(current_step): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) - if progress >= 1.0: - return 0.0 - return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) - - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -def get_polynomial_decay_schedule_with_warmup( - optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 -): - """ - Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the - optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the - initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - lr_end (`float`, *optional*, defaults to 1e-7): - The end LR. - power (`float`, *optional*, defaults to 1.0): - Power factor. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT - implementation at - https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
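The warmup-then-cosine multiplier in `get_cosine_schedule_with_warmup` above rises linearly to 1 over the warmup steps and then follows half a cosine (with the default `num_cycles=0.5`) down to 0. A quick check of a few values, with made-up step counts:

```python
# Re-evaluates the lr_lambda formula from get_cosine_schedule_with_warmup above.
import math

num_warmup_steps, num_training_steps, num_cycles = 10, 100, 0.5

def multiplier(step):
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

for step in (0, 5, 10, 55, 100):
    print(step, round(multiplier(step), 4))
# 0 -> 0.0 (start of warmup), 10 -> 1.0 (warmup done), 100 -> 0.0 (fully decayed)
```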
- - """ - - lr_init = optimizer.defaults["lr"] - if not (lr_init > lr_end): - raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})") - - def lr_lambda(current_step: int): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - elif current_step > num_training_steps: - return lr_end / lr_init # as LambdaLR multiplies by lr_init - else: - lr_range = lr_init - lr_end - decay_steps = num_training_steps - num_warmup_steps - pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps - decay = lr_range * pct_remaining**power + lr_end - return decay / lr_init # as LambdaLR multiplies by lr_init - - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -TYPE_TO_SCHEDULER_FUNCTION = { - SchedulerType.LINEAR: get_linear_schedule_with_warmup, - SchedulerType.COSINE: get_cosine_schedule_with_warmup, - SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, - SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, - SchedulerType.CONSTANT: get_constant_schedule, - SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, -} - - -def get_scheduler( - name: Union[str, SchedulerType], - optimizer: Optimizer, - num_warmup_steps: Optional[int] = None, - num_training_steps: Optional[int] = None, - num_cycles: int = 1, - power: float = 1.0, -): - """ - Unified API to get any scheduler from its name. - - Args: - name (`str` or `SchedulerType`): - The name of the scheduler to use. - optimizer (`torch.optim.Optimizer`): - The optimizer that will be used during training. - num_warmup_steps (`int`, *optional*): - The number of warmup steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - num_training_steps (`int``, *optional*): - The number of training steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - num_cycles (`int`, *optional*): - The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. - power (`float`, *optional*, defaults to 1.0): - Power factor. See `POLYNOMIAL` scheduler - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. 
- """ - name = SchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] - if name == SchedulerType.CONSTANT: - return schedule_func(optimizer) - - # All other schedulers require `num_warmup_steps` - if num_warmup_steps is None: - raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") - - if name == SchedulerType.CONSTANT_WITH_WARMUP: - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) - - # All other schedulers require `num_training_steps` - if num_training_steps is None: - raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") - - if name == SchedulerType.COSINE_WITH_RESTARTS: - return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles - ) - - if name == SchedulerType.POLYNOMIAL: - return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power - ) - - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) diff --git a/my_diffusers/pipeline_utils.py b/my_diffusers/pipeline_utils.py deleted file mode 100644 index 5c0c2337dc048dd9ef164ac5cb92e4bf5e62d764..0000000000000000000000000000000000000000 --- a/my_diffusers/pipeline_utils.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -# limitations under the License. - -# NOTE: This file is deprecated and will be removed in a future version. 
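A hedged usage sketch for the `get_scheduler` dispatcher above, assuming the upstream `diffusers` package is installed; the model, learning rate, and step counts are placeholders.

```python
import torch
from diffusers.optimization import get_scheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_scheduler = get_scheduler(
    "cosine",                      # any SchedulerType value ("linear", "constant", ...)
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
)

for _ in range(3):                 # a training loop would call these once per step
    optimizer.step()
    lr_scheduler.step()
print(lr_scheduler.get_last_lr())
```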
-# It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works - -from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401 diff --git a/my_diffusers/pipelines/__init__.py b/my_diffusers/pipelines/__init__.py deleted file mode 100644 index f5b7026fdc854d7e4e35dc6b7e26501042870550..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/__init__.py +++ /dev/null @@ -1,126 +0,0 @@ -from ..utils import ( - OptionalDependencyNotAvailable, - is_flax_available, - is_k_diffusion_available, - is_librosa_available, - is_onnx_available, - is_torch_available, - is_transformers_available, -) - - -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils.dummy_pt_objects import * # noqa F403 -else: - from .dance_diffusion import DanceDiffusionPipeline - from .ddim import DDIMPipeline - from .ddpm import DDPMPipeline - from .dit import DiTPipeline - from .latent_diffusion import LDMSuperResolutionPipeline - from .latent_diffusion_uncond import LDMPipeline - from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput - from .pndm import PNDMPipeline - from .repaint import RePaintPipeline - from .score_sde_ve import ScoreSdeVePipeline - from .stochastic_karras_ve import KarrasVePipeline - -try: - if not (is_torch_available() and is_librosa_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_librosa_objects import * # noqa F403 -else: - from .audio_diffusion import AudioDiffusionPipeline, Mel - -try: - if not (is_torch_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_transformers_objects import * # noqa F403 -else: - from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline - from .latent_diffusion import LDMTextToImagePipeline - from .paint_by_example import PaintByExamplePipeline - from .semantic_stable_diffusion import SemanticStableDiffusionPipeline - from .stable_diffusion import ( - CycleDiffusionPipeline, - StableDiffusionAttendAndExcitePipeline, - StableDiffusionControlNetPipeline, - StableDiffusionDepth2ImgPipeline, - StableDiffusionImageVariationPipeline, - StableDiffusionImg2ImgPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionInpaintPipelineLegacy, - StableDiffusionInstructPix2PixPipeline, - StableDiffusionLatentUpscalePipeline, - StableDiffusionPanoramaPipeline, - StableDiffusionPipeline, - StableDiffusionPix2PixZeroPipeline, - StableDiffusionSAGPipeline, - StableDiffusionUpscalePipeline, - StableUnCLIPImg2ImgPipeline, - StableUnCLIPPipeline, - ) - from .stable_diffusion_safe import StableDiffusionPipelineSafe - from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline - from .versatile_diffusion import ( - VersatileDiffusionDualGuidedPipeline, - VersatileDiffusionImageVariationPipeline, - VersatileDiffusionPipeline, - VersatileDiffusionTextToImagePipeline, - ) - from .vq_diffusion import VQDiffusionPipeline - -try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils.dummy_onnx_objects import * # noqa F403 -else: - from .onnx_utils import OnnxRuntimeModel - -try: - if not (is_torch_available() and is_transformers_available() and is_onnx_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from 
..utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 -else: - from .stable_diffusion import ( - OnnxStableDiffusionImg2ImgPipeline, - OnnxStableDiffusionInpaintPipeline, - OnnxStableDiffusionInpaintPipelineLegacy, - OnnxStableDiffusionPipeline, - StableDiffusionOnnxPipeline, - ) - -try: - if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 -else: - from .stable_diffusion import StableDiffusionKDiffusionPipeline - -try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils.dummy_flax_objects import * # noqa F403 -else: - from .pipeline_flax_utils import FlaxDiffusionPipeline - - -try: - if not (is_flax_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils.dummy_flax_and_transformers_objects import * # noqa F403 -else: - from .stable_diffusion import ( - FlaxStableDiffusionImg2ImgPipeline, - FlaxStableDiffusionInpaintPipeline, - FlaxStableDiffusionPipeline, - ) diff --git a/my_diffusers/pipelines/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 4d20312383a357f43455c382cf670b67075c0154..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/__pycache__/onnx_utils.cpython-311.pyc b/my_diffusers/pipelines/__pycache__/onnx_utils.cpython-311.pyc deleted file mode 100644 index b2aec0f60953ba06b8ac240a962f6ee3359d65ff..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/__pycache__/onnx_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/__pycache__/pipeline_flax_utils.cpython-311.pyc b/my_diffusers/pipelines/__pycache__/pipeline_flax_utils.cpython-311.pyc deleted file mode 100644 index c0b4398c9b50b5494eee379d13c3a580790744d1..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/__pycache__/pipeline_flax_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/__pycache__/pipeline_utils.cpython-311.pyc b/my_diffusers/pipelines/__pycache__/pipeline_utils.cpython-311.pyc deleted file mode 100644 index f506137b4e7d294f7b5d702c65c693711618bd2a..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/__pycache__/pipeline_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/alt_diffusion/__init__.py b/my_diffusers/pipelines/alt_diffusion/__init__.py deleted file mode 100644 index dab2d8db1045ef27ff5d2234951c1488f547401b..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/alt_diffusion/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Union - -import numpy as np -import PIL -from PIL import Image - -from ...utils import BaseOutput, is_torch_available, is_transformers_available - - -@dataclass -# Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt -class AltDiffusionPipelineOutput(BaseOutput): - """ - Output class for Alt Diffusion pipelines. 
- - Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. - nsfw_content_detected (`List[bool]`) - List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, or `None` if safety checking could not be performed. - """ - - images: Union[List[PIL.Image.Image], np.ndarray] - nsfw_content_detected: Optional[List[bool]] - - -if is_transformers_available() and is_torch_available(): - from .modeling_roberta_series import RobertaSeriesModelWithTransformation - from .pipeline_alt_diffusion import AltDiffusionPipeline - from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline diff --git a/my_diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index f3f84d33972d53caee5d0df4a35c82e39527df1f..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/alt_diffusion/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-311.pyc b/my_diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-311.pyc deleted file mode 100644 index 0d63339c25147aacf297d5f08db89303512d4798..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/alt_diffusion/__pycache__/modeling_roberta_series.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-311.pyc b/my_diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-311.pyc deleted file mode 100644 index 8deaed2f2bbca3c16ca95ddbbcd4a5abd275fc8a..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-311.pyc b/my_diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-311.pyc deleted file mode 100644 index c5d7d7a6eb629667404cd51ee98cebb1ec8b7846..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/my_diffusers/pipelines/alt_diffusion/modeling_roberta_series.py deleted file mode 100644 index 637d6dd18698f3c6f1787c5e4d4514e4fc254908..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +++ /dev/null @@ -1,109 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Tuple - -import torch -from torch import nn -from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel -from transformers.utils import ModelOutput - - -@dataclass -class TransformationModelOutput(ModelOutput): - """ - Base class for text model's outputs that also contains a pooling of the last hidden states. 
- - Args: - text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The text embeddings obtained by applying the projection layer to the pooler_output. - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - projection_state: Optional[torch.FloatTensor] = None - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -class RobertaSeriesConfig(XLMRobertaConfig): - def __init__( - self, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - project_dim=512, - pooler_fn="cls", - learn_encoder=False, - use_attention_mask=True, - **kwargs, - ): - super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) - self.project_dim = project_dim - self.pooler_fn = pooler_fn - self.learn_encoder = learn_encoder - self.use_attention_mask = use_attention_mask - - -class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] - _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - base_model_prefix = "roberta" - config_class = RobertaSeriesConfig - - def __init__(self, config): - super().__init__(config) - self.roberta = XLMRobertaModel(config) - self.transformation = nn.Linear(config.hidden_size, config.project_dim) - self.post_init() - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - ): - r""" """ - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.base_model( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - projection_state = 
self.transformation(outputs.last_hidden_state) - - return TransformationModelOutput( - projection_state=projection_state, - last_hidden_state=outputs.last_hidden_state, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/my_diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/my_diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py deleted file mode 100644 index 71e98480ed2df04d174b771846e586193abfdb18..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ /dev/null @@ -1,711 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Callable, Dict, List, Optional, Union - -import torch -from packaging import version -from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer - -from diffusers.utils import is_accelerate_available, is_accelerate_version - -from ...configuration_utils import FrozenDict -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import deprecate, logging, randn_tensor, replace_example_docstring -from ..pipeline_utils import DiffusionPipeline -from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import AltDiffusionPipeline - - >>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16) - >>> pipe = pipe.to("cuda") - - >>> # "dark elf princess, highly detailed, d & d, fantasy, highly detailed, digital painting, trending on artstation, concept art, sharp focus, illustration, art by artgerm and greg rutkowski and fuji choko and viktoria gavrilenko and hoang lap" - >>> prompt = "黑暗精灵公主,非常详细,幻想,非常详细,数字绘画,概念艺术,敏锐的焦点,插图" - >>> image = pipe(prompt).images[0] - ``` -""" - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker -class AltDiffusionPipeline(DiffusionPipeline): - r""" - Pipeline for text-to-image generation using Alt Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`RobertaSeriesModelWithTransformation`]): - Frozen text-encoder. 
Alt Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation), - specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`XLMRobertaTokenizer`): - Tokenizer of class - [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: RobertaSeriesModelWithTransformation, - tokenizer: XLMRobertaTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure" - " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. - - When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in - several steps. This is useful to save a large amount of memory and to allow the processing of larger images. - """ - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. 
If `enable_vae_tiling` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. 
- """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds 
is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. 
- prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - - Examples: - - Returns: - [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. 
Prepare latent variables - num_channels_latents = self.unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - if output_type == "latent": - image = latents - has_nsfw_concept = None - elif output_type == "pil": - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 10. Convert to PIL - image = self.numpy_to_pil(image) - else: - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/my_diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py deleted file mode 100644 index 1e7872e3b0819fd9186caba5508fe1bff6163ff3..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ /dev/null @@ -1,723 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
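Both Alt Diffusion pipelines deleted in this diff apply classifier-free guidance the same way inside their denoising loops: the latents are duplicated, the UNet is run once on the doubled batch, and the two noise predictions are recombined with `guidance_scale`. A minimal sketch of that recombination step, with a random tensor standing in for a real UNet output (the shapes and the `guidance_scale` value are illustrative, not taken from the code above):

```py
import torch


def apply_classifier_free_guidance(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # `noise_pred` stacks the unconditional and text-conditioned predictions
    # along the batch dimension, as produced by the single batched UNet call.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


# Stand-in for one UNet forward pass on a doubled latent batch of size 1.
fake_noise_pred = torch.randn(2, 4, 64, 64)
guided = apply_classifier_free_guidance(fake_noise_pred, guidance_scale=7.5)
print(guided.shape)  # torch.Size([1, 4, 64, 64])
```

With `guidance_scale = 1.0` the expression collapses to the text-conditioned prediction alone, which is why the pipelines only double the latent batch when `guidance_scale > 1.0`.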
- -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL -import torch -from packaging import version -from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer - -from diffusers.utils import is_accelerate_available, is_accelerate_version - -from ...configuration_utils import FrozenDict -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring -from ..pipeline_utils import DiffusionPipeline -from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import requests - >>> import torch - >>> from PIL import Image - >>> from io import BytesIO - - >>> from diffusers import AltDiffusionImg2ImgPipeline - - >>> device = "cuda" - >>> model_id_or_path = "BAAI/AltDiffusion-m9" - >>> pipe = AltDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) - >>> pipe = pipe.to(device) - - >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" - - >>> response = requests.get(url) - >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") - >>> init_image = init_image.resize((768, 512)) - - >>> # "A fantasy landscape, trending on artstation" - >>> prompt = "幻想风景, artstation" - - >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images - >>> images[0].save("幻想风景.png") - ``` -""" - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess -def preprocess(image): - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 - - image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker -class AltDiffusionImg2ImgPipeline(DiffusionPipeline): - r""" - Pipeline for text-guided image to image generation using Alt Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`RobertaSeriesModelWithTransformation`]): - Frozen text-encoder. 
Alt Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation), - specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`XLMRobertaTokenizer`): - Tokenizer of class - [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: RobertaSeriesModelWithTransformation, - tokenizer: XLMRobertaTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure" - " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. 
- """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
- prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None - ): - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." 
- ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - def get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] - - return timesteps, num_inference_steps - t_start - - def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) - - init_latents = self.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - deprecation_message = ( - f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." 
- ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = torch.cat([init_latents], dim=0) - - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - - # get latents - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) - latents = init_latents - - return latents - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` - will be used as a starting point, adding more noise to it the larger the `strength`. The number of - denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will - be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). 
- num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Examples: - - Returns: - [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Preprocess image - image = preprocess(image) - - # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - # 6. 
Prepare latent variables - latents = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 9. Post-processing - image = self.decode_latents(latents) - - # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/audio_diffusion/__init__.py b/my_diffusers/pipelines/audio_diffusion/__init__.py deleted file mode 100644 index 58554c45ea52b9897293217652db36fdace7549f..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/audio_diffusion/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .mel import Mel -from .pipeline_audio_diffusion import AudioDiffusionPipeline diff --git a/my_diffusers/pipelines/audio_diffusion/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/audio_diffusion/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 3cb220bbaaa8edbd2e710efa970464c943755fec..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/audio_diffusion/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/audio_diffusion/__pycache__/mel.cpython-311.pyc b/my_diffusers/pipelines/audio_diffusion/__pycache__/mel.cpython-311.pyc deleted file mode 100644 index 30677e8ea36cc77351ddac12522adee026857742..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/audio_diffusion/__pycache__/mel.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/audio_diffusion/__pycache__/pipeline_audio_diffusion.cpython-311.pyc b/my_diffusers/pipelines/audio_diffusion/__pycache__/pipeline_audio_diffusion.cpython-311.pyc deleted file mode 100644 index 685722471b3ec247f1d89ae3c3f563367d18fde7..0000000000000000000000000000000000000000 Binary 
files a/my_diffusers/pipelines/audio_diffusion/__pycache__/pipeline_audio_diffusion.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/audio_diffusion/mel.py b/my_diffusers/pipelines/audio_diffusion/mel.py deleted file mode 100644 index 1bf28fd25a5a5d39416eaf6bfd76b7f6945f4b19..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/audio_diffusion/mel.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import numpy as np # noqa: E402 - -from ...configuration_utils import ConfigMixin, register_to_config -from ...schedulers.scheduling_utils import SchedulerMixin - - -try: - import librosa # noqa: E402 - - _librosa_can_be_imported = True - _import_error = "" -except Exception as e: - _librosa_can_be_imported = False - _import_error = ( - f"Cannot import librosa because {e}. Make sure to correctly install librosa to be able to install it." - ) - - -from PIL import Image # noqa: E402 - - -class Mel(ConfigMixin, SchedulerMixin): - """ - Parameters: - x_res (`int`): x resolution of spectrogram (time) - y_res (`int`): y resolution of spectrogram (frequency bins) - sample_rate (`int`): sample rate of audio - n_fft (`int`): number of Fast Fourier Transforms - hop_length (`int`): hop length (a higher number is recommended for lower than 256 y_res) - top_db (`int`): loudest in decibels - n_iter (`int`): number of iterations for Griffin Linn mel inversion - """ - - config_name = "mel_config.json" - - @register_to_config - def __init__( - self, - x_res: int = 256, - y_res: int = 256, - sample_rate: int = 22050, - n_fft: int = 2048, - hop_length: int = 512, - top_db: int = 80, - n_iter: int = 32, - ): - self.hop_length = hop_length - self.sr = sample_rate - self.n_fft = n_fft - self.top_db = top_db - self.n_iter = n_iter - self.set_resolution(x_res, y_res) - self.audio = None - - if not _librosa_can_be_imported: - raise ValueError(_import_error) - - def set_resolution(self, x_res: int, y_res: int): - """Set resolution. - - Args: - x_res (`int`): x resolution of spectrogram (time) - y_res (`int`): y resolution of spectrogram (frequency bins) - """ - self.x_res = x_res - self.y_res = y_res - self.n_mels = self.y_res - self.slice_size = self.x_res * self.hop_length - 1 - - def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None): - """Load audio. - - Args: - audio_file (`str`): must be a file on disk due to Librosa limitation or - raw_audio (`np.ndarray`): audio as numpy array - """ - if audio_file is not None: - self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr) - else: - self.audio = raw_audio - - # Pad with silence if necessary. - if len(self.audio) < self.x_res * self.hop_length: - self.audio = np.concatenate([self.audio, np.zeros((self.x_res * self.hop_length - len(self.audio),))]) - - def get_number_of_slices(self) -> int: - """Get number of slices in audio. 
- - Returns: - `int`: number of spectograms audio can be sliced into - """ - return len(self.audio) // self.slice_size - - def get_audio_slice(self, slice: int = 0) -> np.ndarray: - """Get slice of audio. - - Args: - slice (`int`): slice number of audio (out of get_number_of_slices()) - - Returns: - `np.ndarray`: audio as numpy array - """ - return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)] - - def get_sample_rate(self) -> int: - """Get sample rate: - - Returns: - `int`: sample rate of audio - """ - return self.sr - - def audio_slice_to_image(self, slice: int) -> Image.Image: - """Convert slice of audio to spectrogram. - - Args: - slice (`int`): slice number of audio to convert (out of get_number_of_slices()) - - Returns: - `PIL Image`: grayscale image of x_res x y_res - """ - S = librosa.feature.melspectrogram( - y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels - ) - log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db) - bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8) - image = Image.fromarray(bytedata) - return image - - def image_to_audio(self, image: Image.Image) -> np.ndarray: - """Converts spectrogram to audio. - - Args: - image (`PIL Image`): x_res x y_res grayscale image - - Returns: - audio (`np.ndarray`): raw audio - """ - bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width)) - log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db - S = librosa.db_to_power(log_S) - audio = librosa.feature.inverse.mel_to_audio( - S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter - ) - return audio diff --git a/my_diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py b/my_diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py deleted file mode 100644 index 8f0925ac4aaa46323b3599220c4fb7a386607298..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py +++ /dev/null @@ -1,266 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from math import acos, sin -from typing import List, Tuple, Union - -import numpy as np -import torch -from PIL import Image - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDIMScheduler, DDPMScheduler -from ...utils import randn_tensor -from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput -from .mel import Mel - - -class AudioDiffusionPipeline(DiffusionPipeline): - """ - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
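The dB-to-grayscale quantization used by `audio_slice_to_image` / `image_to_audio` above can be checked in isolation; a minimal NumPy sketch (the dB values are made up for illustration):

import numpy as np

top_db = 80                                          # default top_db of the Mel helper above
log_S = np.array([-80.0, -40.0, -3.2, 0.0])          # example power values in dB relative to peak

# forward: dB -> uint8 grayscale pixel, as in audio_slice_to_image
bytedata = (((log_S + top_db) * 255 / top_db).clip(0, 255) + 0.5).astype(np.uint8)

# inverse: uint8 -> dB, as in image_to_audio; the round-trip error is at most ~top_db/255 dB
log_S_rec = bytedata.astype("float") * top_db / 255 - top_db
print(bytedata, log_S_rec)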
- - Parameters: - vqae ([`AutoencoderKL`]): Variational AutoEncoder for Latent Audio Diffusion or None - unet ([`UNet2DConditionModel`]): UNET model - mel ([`Mel`]): transform audio <-> spectrogram - scheduler ([`DDIMScheduler` or `DDPMScheduler`]): de-noising scheduler - """ - - _optional_components = ["vqvae"] - - def __init__( - self, - vqvae: AutoencoderKL, - unet: UNet2DConditionModel, - mel: Mel, - scheduler: Union[DDIMScheduler, DDPMScheduler], - ): - super().__init__() - self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae) - - def get_input_dims(self) -> Tuple: - """Returns dimension of input image - - Returns: - `Tuple`: (height, width) - """ - input_module = self.vqvae if self.vqvae is not None else self.unet - # For backwards compatibility - sample_size = ( - (input_module.sample_size, input_module.sample_size) - if type(input_module.sample_size) == int - else input_module.sample_size - ) - return sample_size - - def get_default_steps(self) -> int: - """Returns default number of steps recommended for inference - - Returns: - `int`: number of steps - """ - return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000 - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - audio_file: str = None, - raw_audio: np.ndarray = None, - slice: int = 0, - start_step: int = 0, - steps: int = None, - generator: torch.Generator = None, - mask_start_secs: float = 0, - mask_end_secs: float = 0, - step_generator: torch.Generator = None, - eta: float = 0, - noise: torch.Tensor = None, - encoding: torch.Tensor = None, - return_dict=True, - ) -> Union[ - Union[AudioPipelineOutput, ImagePipelineOutput], - Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]], - ]: - """Generate random mel spectrogram from audio input and convert to audio. 
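A rough end-to-end usage sketch for this pipeline (the checkpoint name is an illustrative assumption; with return_dict=False the call returns the spectrogram images plus the sample rate and raw audio, as implemented further below):

import torch
from diffusers import AudioDiffusionPipeline

# Illustrative checkpoint name; any audio-diffusion checkpoint that ships a Mel config should do.
pipe = AudioDiffusionPipeline.from_pretrained("teticio/audio-diffusion-256")

generator = torch.Generator(device="cpu").manual_seed(42)
images, (sample_rate, audios) = pipe(batch_size=1, generator=generator, return_dict=False)

images[0].save("spectrogram.png")      # grayscale mel spectrogram
raw_audio = audios[0]                  # np.ndarray at `sample_rate` Hz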
- - Args: - batch_size (`int`): number of samples to generate - audio_file (`str`): must be a file on disk due to Librosa limitation or - raw_audio (`np.ndarray`): audio as numpy array - slice (`int`): slice number of audio to convert - start_step (int): step to start from - steps (`int`): number of de-noising steps (defaults to 50 for DDIM, 1000 for DDPM) - generator (`torch.Generator`): random number generator or None - mask_start_secs (`float`): number of seconds of audio to mask (not generate) at start - mask_end_secs (`float`): number of seconds of audio to mask (not generate) at end - step_generator (`torch.Generator`): random number generator used to de-noise or None - eta (`float`): parameter between 0 and 1 used with DDIM scheduler - noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None - encoding (`torch.Tensor`): for UNet2DConditionModel shape (batch_size, seq_length, cross_attention_dim) - return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple - - Returns: - `List[PIL Image]`: mel spectrograms (`float`, `List[np.ndarray]`): sample rate and raw audios - """ - - steps = steps or self.get_default_steps() - self.scheduler.set_timesteps(steps) - step_generator = step_generator or generator - # For backwards compatibility - if type(self.unet.sample_size) == int: - self.unet.sample_size = (self.unet.sample_size, self.unet.sample_size) - input_dims = self.get_input_dims() - self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0]) - if noise is None: - noise = randn_tensor( - ( - batch_size, - self.unet.in_channels, - self.unet.sample_size[0], - self.unet.sample_size[1], - ), - generator=generator, - device=self.device, - ) - images = noise - mask = None - - if audio_file is not None or raw_audio is not None: - self.mel.load_audio(audio_file, raw_audio) - input_image = self.mel.audio_slice_to_image(slice) - input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape( - (input_image.height, input_image.width) - ) - input_image = (input_image / 255) * 2 - 1 - input_images = torch.tensor(input_image[np.newaxis, :, :], dtype=torch.float).to(self.device) - - if self.vqvae is not None: - input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample( - generator=generator - )[0] - input_images = self.vqvae.config.scaling_factor * input_images - - if start_step > 0: - images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1]) - - pixels_per_second = ( - self.unet.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length - ) - mask_start = int(mask_start_secs * pixels_per_second) - mask_end = int(mask_end_secs * pixels_per_second) - mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:])) - - for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])): - if isinstance(self.unet, UNet2DConditionModel): - model_output = self.unet(images, t, encoding)["sample"] - else: - model_output = self.unet(images, t)["sample"] - - if isinstance(self.scheduler, DDIMScheduler): - images = self.scheduler.step( - model_output=model_output, - timestep=t, - sample=images, - eta=eta, - generator=step_generator, - )["prev_sample"] - else: - images = self.scheduler.step( - model_output=model_output, - timestep=t, - sample=images, - generator=step_generator, - )["prev_sample"] - - if mask is not None: - if mask_start > 0: - images[:, :, :, :mask_start] = mask[:, step, :, 
:mask_start] - if mask_end > 0: - images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:] - - if self.vqvae is not None: - # 0.18215 was scaling factor used in training to ensure unit variance - images = 1 / self.vqvae.config.scaling_factor * images - images = self.vqvae.decode(images)["sample"] - - images = (images / 2 + 0.5).clamp(0, 1) - images = images.cpu().permute(0, 2, 3, 1).numpy() - images = (images * 255).round().astype("uint8") - images = list( - map(lambda _: Image.fromarray(_[:, :, 0]), images) - if images.shape[3] == 1 - else map(lambda _: Image.fromarray(_, mode="RGB").convert("L"), images) - ) - - audios = list(map(lambda _: self.mel.image_to_audio(_), images)) - if not return_dict: - return images, (self.mel.get_sample_rate(), audios) - - return BaseOutput(**AudioPipelineOutput(np.array(audios)[:, np.newaxis, :]), **ImagePipelineOutput(images)) - - @torch.no_grad() - def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray: - """Reverse step process: recover noisy image from generated image. - - Args: - images (`List[PIL Image]`): list of images to encode - steps (`int`): number of encoding steps to perform (defaults to 50) - - Returns: - `np.ndarray`: noise tensor of shape (batch_size, 1, height, width) - """ - - # Only works with DDIM as this method is deterministic - assert isinstance(self.scheduler, DDIMScheduler) - self.scheduler.set_timesteps(steps) - sample = np.array( - [np.frombuffer(image.tobytes(), dtype="uint8").reshape((1, image.height, image.width)) for image in images] - ) - sample = (sample / 255) * 2 - 1 - sample = torch.Tensor(sample).to(self.device) - - for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))): - prev_timestep = t - self.scheduler.num_train_timesteps // self.scheduler.num_inference_steps - alpha_prod_t = self.scheduler.alphas_cumprod[t] - alpha_prod_t_prev = ( - self.scheduler.alphas_cumprod[prev_timestep] - if prev_timestep >= 0 - else self.scheduler.final_alpha_cumprod - ) - beta_prod_t = 1 - alpha_prod_t - model_output = self.unet(sample, t)["sample"] - pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output - sample = (sample - pred_sample_direction) * alpha_prod_t_prev ** (-0.5) - sample = sample * alpha_prod_t ** (0.5) + beta_prod_t ** (0.5) * model_output - - return sample - - @staticmethod - def slerp(x0: torch.Tensor, x1: torch.Tensor, alpha: float) -> torch.Tensor: - """Spherical Linear intERPolation - - Args: - x0 (`torch.Tensor`): first tensor to interpolate between - x1 (`torch.Tensor`): seconds tensor to interpolate between - alpha (`float`): interpolation between 0 and 1 - - Returns: - `torch.Tensor`: interpolated tensor - """ - - theta = acos(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1)) - return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(alpha * theta) * x1 / sin(theta) diff --git a/my_diffusers/pipelines/dance_diffusion/__init__.py b/my_diffusers/pipelines/dance_diffusion/__init__.py deleted file mode 100644 index 55d7f8ff9807083a10c844f7003cf0696d8258a3..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/dance_diffusion/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .pipeline_dance_diffusion import DanceDiffusionPipeline diff --git a/my_diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 63592855295a01040f89b88f3450d06ff26c2838..0000000000000000000000000000000000000000 Binary files 
a/my_diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-311.pyc b/my_diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-311.pyc deleted file mode 100644 index 00e7300b622f1e336bf8a4e2798ea930249744fa..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/my_diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py deleted file mode 100644 index 018e020491ce3711117f9afe13547f12b8ddf48e..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Optional, Tuple, Union - -import torch - -from ...utils import logging, randn_tensor -from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class DanceDiffusionPipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded image. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of - [`IPNDMScheduler`]. - """ - - def __init__(self, unet, scheduler): - super().__init__() - self.register_modules(unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - num_inference_steps: int = 100, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - audio_length_in_s: Optional[float] = None, - return_dict: bool = True, - ) -> Union[AudioPipelineOutput, Tuple]: - r""" - Args: - batch_size (`int`, *optional*, defaults to 1): - The number of audio samples to generate. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at - the expense of slower inference. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`): - The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.* - `sample_size`, will be `audio_length_in_s` * `self.unet.sample_rate`. 
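How `audio_length_in_s` becomes a sample count (rounded up so the U-Net can downsample it) is sketched below; the sample rate and number of up-blocks are assumptions for illustration, while the rounding itself mirrors the checks further down:

sample_rate = 16000                              # assumed unet.config.sample_rate
num_up_blocks = 4                                # assumed len(unet.up_blocks)
down_scale_factor = 2 ** num_up_blocks           # 16

audio_length_in_s = 4.5005
sample_size = audio_length_in_s * sample_rate    # 72008.0 samples, not divisible by 16
if sample_size % down_scale_factor != 0:
    sample_size = (sample_size // down_scale_factor + 1) * down_scale_factor
print(int(sample_size))                          # 72016; trimmed back to 72008 samples after denoising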
- return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipelines.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if `return_dict` is - True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. - """ - - if audio_length_in_s is None: - audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate - - sample_size = audio_length_in_s * self.unet.sample_rate - - down_scale_factor = 2 ** len(self.unet.up_blocks) - if sample_size < 3 * down_scale_factor: - raise ValueError( - f"{audio_length_in_s} is too small. Make sure it's bigger or equal to" - f" {3 * down_scale_factor / self.unet.sample_rate}." - ) - - original_sample_size = int(sample_size) - if sample_size % down_scale_factor != 0: - sample_size = ((audio_length_in_s * self.unet.sample_rate) // down_scale_factor + 1) * down_scale_factor - logger.info( - f"{audio_length_in_s} is increased to {sample_size / self.unet.sample_rate} so that it can be handled" - f" by the model. It will be cut to {original_sample_size / self.unet.sample_rate} after the denoising" - " process." - ) - sample_size = int(sample_size) - - dtype = next(iter(self.unet.parameters())).dtype - shape = (batch_size, self.unet.in_channels, sample_size) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - audio = randn_tensor(shape, generator=generator, device=self.device, dtype=dtype) - - # set step values - self.scheduler.set_timesteps(num_inference_steps, device=audio.device) - self.scheduler.timesteps = self.scheduler.timesteps.to(dtype) - - for t in self.progress_bar(self.scheduler.timesteps): - # 1. predict noise model_output - model_output = self.unet(audio, t).sample - - # 2. 
compute previous image: x_t -> t_t-1 - audio = self.scheduler.step(model_output, t, audio).prev_sample - - audio = audio.clamp(-1, 1).float().cpu().numpy() - - audio = audio[:, :, :original_sample_size] - - if not return_dict: - return (audio,) - - return AudioPipelineOutput(audios=audio) diff --git a/my_diffusers/pipelines/ddim/__init__.py b/my_diffusers/pipelines/ddim/__init__.py deleted file mode 100644 index 85e8118e75e7e4352f8efb12552ba9fff4bf491c..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/ddim/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .pipeline_ddim import DDIMPipeline diff --git a/my_diffusers/pipelines/ddim/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/ddim/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 824a860a363f050068ae401f5804e76e100b90db..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/ddim/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-311.pyc b/my_diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-311.pyc deleted file mode 100644 index 11c3db10c0e7e5ccabcc316791cc5acb5175af1b..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/ddim/pipeline_ddim.py b/my_diffusers/pipelines/ddim/pipeline_ddim.py deleted file mode 100644 index 0e7f2258fa999cc4cdd999a63c287f38eb7ac9a6..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/ddim/pipeline_ddim.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Optional, Tuple, Union - -import torch - -from ...schedulers import DDIMScheduler -from ...utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -class DDIMPipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of - [`DDPMScheduler`], or [`DDIMScheduler`]. 
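A short usage sketch for this pipeline (the checkpoint name is illustrative; note in `__init__` below that whatever scheduler the checkpoint ships with is converted to a `DDIMScheduler`):

import torch
from diffusers import DDIMPipeline

pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")   # illustrative checkpoint

generator = torch.Generator(device="cpu").manual_seed(0)
# eta=0.0 is deterministic DDIM sampling; eta=1.0 re-introduces DDPM-like variance.
image = pipe(batch_size=1, num_inference_steps=50, eta=0.0, generator=generator).images[0]
image.save("ddim_sample.png")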
- """ - - def __init__(self, unet, scheduler): - super().__init__() - - # make sure scheduler can always be converted to DDIM - scheduler = DDIMScheduler.from_config(scheduler.config) - - self.register_modules(unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - eta: float = 0.0, - num_inference_steps: int = 50, - use_clipped_model_output: Optional[bool] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - ) -> Union[ImagePipelineOutput, Tuple]: - r""" - Args: - batch_size (`int`, *optional*, defaults to 1): - The number of images to generate. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - eta (`float`, *optional*, defaults to 0.0): - The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - use_clipped_model_output (`bool`, *optional*, defaults to `None`): - if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed - downstream to the scheduler. So use `None` for schedulers which don't support this argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is - True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. - """ - - # Sample gaussian noise to begin loop - if isinstance(self.unet.sample_size, int): - image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) - else: - image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype) - - # set step values - self.scheduler.set_timesteps(num_inference_steps) - - for t in self.progress_bar(self.scheduler.timesteps): - # 1. predict noise model_output - model_output = self.unet(image, t).sample - - # 2. 
predict previous mean of image x_t-1 and add variance depending on eta - # eta corresponds to η in paper and should be between [0, 1] - # do x_t -> x_t-1 - image = self.scheduler.step( - model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator - ).prev_sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/ddpm/__init__.py b/my_diffusers/pipelines/ddpm/__init__.py deleted file mode 100644 index bb228ee012e80493b617b314c867ecadba7ca1ce..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/ddpm/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .pipeline_ddpm import DDPMPipeline diff --git a/my_diffusers/pipelines/ddpm/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/ddpm/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 4f6820e2a0960795f0ebfb933c80566003c505f0..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/ddpm/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-311.pyc b/my_diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-311.pyc deleted file mode 100644 index 82952966aa6e9c2b217fb2e74773c4ecff046f11..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/ddpm/__pycache__/pipeline_ddpm.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/ddpm/pipeline_ddpm.py b/my_diffusers/pipelines/ddpm/pipeline_ddpm.py deleted file mode 100644 index 549dbb29d5e7898af4e8883c608c897bb5021cbe..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/ddpm/pipeline_ddpm.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Optional, Tuple, Union - -import torch - -from ...utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -class DDPMPipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of - [`DDPMScheduler`], or [`DDIMScheduler`]. 
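For contrast with the DDIM pipeline above, a minimal sketch of this DDPM pipeline (illustrative checkpoint name; the full 1000-step schedule is the default here):

import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")      # illustrative checkpoint

# A CPU generator keeps sampling reproducible; on "mps" the noise is drawn on the CPU
# and moved to the device anyway, as the __call__ below shows.
generator = torch.Generator(device="cpu").manual_seed(1234)
image = pipe(generator=generator, num_inference_steps=1000).images[0]
image.save("ddpm_sample.png")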
- """ - - def __init__(self, unet, scheduler): - super().__init__() - self.register_modules(unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - num_inference_steps: int = 1000, - output_type: Optional[str] = "pil", - return_dict: bool = True, - ) -> Union[ImagePipelineOutput, Tuple]: - r""" - Args: - batch_size (`int`, *optional*, defaults to 1): - The number of images to generate. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - num_inference_steps (`int`, *optional*, defaults to 1000): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is - True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. - """ - # Sample gaussian noise to begin loop - if isinstance(self.unet.sample_size, int): - image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) - else: - image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size) - - if self.device.type == "mps": - # randn does not work reproducibly on mps - image = randn_tensor(image_shape, generator=generator) - image = image.to(self.device) - else: - image = randn_tensor(image_shape, generator=generator, device=self.device) - - # set step values - self.scheduler.set_timesteps(num_inference_steps) - - for t in self.progress_bar(self.scheduler.timesteps): - # 1. predict noise model_output - model_output = self.unet(image, t).sample - - # 2. 
compute previous image: x_t -> x_t-1 - image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/dit/__init__.py b/my_diffusers/pipelines/dit/__init__.py deleted file mode 100644 index 4ef0729cb4905d5e177ba15533375fce50084406..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/dit/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .pipeline_dit import DiTPipeline diff --git a/my_diffusers/pipelines/dit/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/dit/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index a67a7c591a06e2f23e17c878e122bef35b9481d2..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/dit/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-311.pyc b/my_diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-311.pyc deleted file mode 100644 index 6cdf2626ede14372854efa72a68255e89a9af984..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/dit/pipeline_dit.py b/my_diffusers/pipelines/dit/pipeline_dit.py deleted file mode 100644 index f0d30697af43ca0781e3df8df801bd150078952f..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/dit/pipeline_dit.py +++ /dev/null @@ -1,199 +0,0 @@ -# Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) -# William Peebles and Saining Xie -# -# Copyright (c) 2021 OpenAI -# MIT License -# -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, List, Optional, Tuple, Union - -import torch - -from ...models import AutoencoderKL, Transformer2DModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -class DiTPipeline(DiffusionPipeline): - r""" - This pipeline inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - transformer ([`Transformer2DModel`]): - Class conditioned Transformer in Diffusion model to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - scheduler ([`DDIMScheduler`]): - A scheduler to be used in combination with `dit` to denoise the encoded image latents. 
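A usage sketch for this class-conditional pipeline (the checkpoint name and labels are illustrative; `get_label_ids`, defined below, maps human-readable ImageNet labels to class ids):

import torch
from diffusers import DiTPipeline

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)  # illustrative
pipe = pipe.to("cuda")

class_ids = pipe.get_label_ids(["white shark", "umbrella"])
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipe(class_labels=class_ids, guidance_scale=4.0, num_inference_steps=25, generator=generator).images
images[0].save("dit_sample.png")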
- """ - - def __init__( - self, - transformer: Transformer2DModel, - vae: AutoencoderKL, - scheduler: KarrasDiffusionSchedulers, - id2label: Optional[Dict[int, str]] = None, - ): - super().__init__() - self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler) - - # create a imagenet -> id dictionary for easier use - self.labels = {} - if id2label is not None: - for key, value in id2label.items(): - for label in value.split(","): - self.labels[label.lstrip().rstrip()] = int(key) - self.labels = dict(sorted(self.labels.items())) - - def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: - r""" - - Map label strings, *e.g.* from ImageNet, to corresponding class ids. - - Parameters: - label (`str` or `dict` of `str`): label strings to be mapped to class ids. - - Returns: - `list` of `int`: Class ids to be processed by pipeline. - """ - - if not isinstance(label, list): - label = list(label) - - for l in label: - if l not in self.labels: - raise ValueError( - f"{l} does not exist. Please make sure to select one of the following labels: \n {self.labels}." - ) - - return [self.labels[l] for l in label] - - @torch.no_grad() - def __call__( - self, - class_labels: List[int], - guidance_scale: float = 4.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - num_inference_steps: int = 50, - output_type: Optional[str] = "pil", - return_dict: bool = True, - ) -> Union[ImagePipelineOutput, Tuple]: - r""" - Function invoked when calling the pipeline for generation. - - Args: - class_labels (List[int]): - List of imagenet class labels for the images to be generated. - guidance_scale (`float`, *optional*, defaults to 4.0): - Scale of the guidance signal. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - num_inference_steps (`int`, *optional*, defaults to 250): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. - """ - - batch_size = len(class_labels) - latent_size = self.transformer.config.sample_size - latent_channels = self.transformer.config.in_channels - - latents = randn_tensor( - shape=(batch_size, latent_channels, latent_size, latent_size), - generator=generator, - device=self.device, - dtype=self.transformer.dtype, - ) - latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents - - class_labels = torch.tensor(class_labels, device=self.device).reshape(-1) - class_null = torch.tensor([1000] * batch_size, device=self.device) - class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels - - # set step values - self.scheduler.set_timesteps(num_inference_steps) - - for t in self.progress_bar(self.scheduler.timesteps): - if guidance_scale > 1: - half = latent_model_input[: len(latent_model_input) // 2] - latent_model_input = torch.cat([half, half], dim=0) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - timesteps = t - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = latent_model_input.device.type == "mps" - if isinstance(timesteps, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(latent_model_input.device) - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(latent_model_input.shape[0]) - # predict noise model_output - noise_pred = self.transformer( - latent_model_input, timestep=timesteps, class_labels=class_labels_input - ).sample - - # perform guidance - if guidance_scale > 1: - eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:] - cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) - - half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) - eps = torch.cat([half_eps, half_eps], dim=0) - - noise_pred = torch.cat([eps, rest], dim=1) - - # learned sigma - if self.transformer.config.out_channels // 2 == latent_channels: - model_output, _ = torch.split(noise_pred, latent_channels, dim=1) - else: - model_output = noise_pred - - # compute previous image: x_t -> x_t-1 - latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample - - if guidance_scale > 1: - latents, _ = latent_model_input.chunk(2, dim=0) - else: - latents = latent_model_input - - latents = 1 / self.vae.config.scaling_factor * latents - samples = self.vae.decode(latents).sample - - samples = (samples / 2 + 0.5).clamp(0, 1) - - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - samples = samples.cpu().permute(0, 2, 3, 1).float().numpy() - - if output_type == "pil": - samples = self.numpy_to_pil(samples) - - if not return_dict: - return (samples,) - - return ImagePipelineOutput(images=samples) diff --git a/my_diffusers/pipelines/latent_diffusion/__init__.py b/my_diffusers/pipelines/latent_diffusion/__init__.py deleted file mode 100644 index 0cce9a89bcbeaac8468d75e9d16c9d3731f738c7..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/latent_diffusion/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from ...utils import is_transformers_available -from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline - - -if is_transformers_available(): - from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline diff --git a/my_diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index c1e29b0b5a84e1902ff0e79311fba419460ac122..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-311.pyc b/my_diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-311.pyc deleted file mode 100644 index 3edda160eb8418d55a96fc1ce4fc3922eb803bbd..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-311.pyc and /dev/null differ diff --git 
a/my_diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-311.pyc b/my_diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-311.pyc deleted file mode 100644 index 499be8a7952280cc2f41ecb0fc81576305f21f29..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/my_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py deleted file mode 100644 index 623b456e52b5bb0c283f943af9bc825863846afe..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ /dev/null @@ -1,724 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.utils.checkpoint -from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutput -from transformers.utils import logging - -from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -class LDMTextToImagePipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - vqvae ([`VQModel`]): - Vector-quantized (VQ) Model to encode and decode images to and from latent representations. - bert ([`LDMBertModel`]): - Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. - tokenizer (`transformers.BertTokenizer`): - Tokenizer of class - [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
- """ - - def __init__( - self, - vqvae: Union[VQModel, AutoencoderKL], - bert: PreTrainedModel, - tokenizer: PreTrainedTokenizer, - unet: Union[UNet2DModel, UNet2DConditionModel], - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - ): - super().__init__() - self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) - self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 1.0, - eta: Optional[float] = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ) -> Union[Tuple, ImagePipelineOutput]: - r""" - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 1.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt` at - the, usually at the expense of lower image quality. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*): - Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is - True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. - """ - # 0. 
Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - # get unconditional embeddings for classifier free guidance - if guidance_scale != 1.0: - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt" - ) - negative_prompt_embeds = self.bert(uncond_input.input_ids.to(self.device))[0] - - # get prompt text embeddings - text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt") - prompt_embeds = self.bert(text_input.input_ids.to(self.device))[0] - - # get the initial random noise unless the user supplied it - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=prompt_embeds.dtype) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) - - self.scheduler.set_timesteps(num_inference_steps) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - - extra_kwargs = {} - if accepts_eta: - extra_kwargs["eta"] = eta - - for t in self.progress_bar(self.scheduler.timesteps): - if guidance_scale == 1.0: - # guidance_scale of 1 means no guidance - latents_input = latents - context = prompt_embeds - else: - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - latents_input = torch.cat([latents] * 2) - context = torch.cat([negative_prompt_embeds, prompt_embeds]) - - # predict the noise residual - noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample - # perform guidance - if guidance_scale != 1.0: - noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample - - # scale and decode the image latents with vae - latents = 1 / self.vqvae.config.scaling_factor * latents - image = self.vqvae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) - - -################################################################################ -# Code for the text transformer model -################################################################################ -""" PyTorch LDMBERT model.""" - - -logger = logging.get_logger(__name__) - -LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "ldm-bert", - # See all LDMBert models at https://huggingface.co/models?filter=ldmbert -] - - -LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "ldm-bert": "https://huggingface.co/valhalla/ldm-bert/blob/main/config.json", -} - - -""" LDMBERT model configuration""" - - -class LDMBertConfig(PretrainedConfig): - model_type = "ldmbert" - keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} - - def __init__( - self, - vocab_size=30522, - max_position_embeddings=77, - encoder_layers=32, - encoder_ffn_dim=5120, - encoder_attention_heads=8, - head_dim=64, - encoder_layerdrop=0.0, - activation_function="gelu", - d_model=1280, - dropout=0.1, - attention_dropout=0.0, - activation_dropout=0.0, - init_std=0.02, - classifier_dropout=0.0, - scale_embedding=False, - use_cache=True, - pad_token_id=0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.d_model = d_model - self.encoder_ffn_dim = encoder_ffn_dim - self.encoder_layers = encoder_layers - self.encoder_attention_heads = encoder_attention_heads - self.head_dim = head_dim - self.dropout = dropout - self.attention_dropout = attention_dropout - self.activation_dropout = activation_dropout - self.activation_function = activation_function - self.init_std = init_std - self.encoder_layerdrop = encoder_layerdrop - self.classifier_dropout = classifier_dropout - self.use_cache = use_cache - self.num_hidden_layers = encoder_layers - self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True - - super().__init__(pad_token_id=pad_token_id, **kwargs) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
- """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert -class LDMBertAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - head_dim: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = False, - ): - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = head_dim - self.inner_dim = head_dim * num_heads - - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - - self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) - self.out_proj = nn.Linear(self.inner_dim, embed_dim) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention and past_key_value is not None: - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. 
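# Note (descriptive, not in the original file): unlike stock BART attention, LDMBert's
# inner_dim = num_heads * head_dim (8 * 64 = 512 with the default config) need not equal
# embed_dim (d_model = 1280), so the reshape below uses self.inner_dim and out_proj
# maps it back to embed_dim.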
- attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class LDMBertEncoderLayer(nn.Module): - def __init__(self, config: LDMBertConfig): - super().__init__() - self.embed_dim = config.d_model - self.self_attn = LDMBertAttention( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - head_dim=config.head_dim, - dropout=config.attention_dropout, - ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: torch.FloatTensor, - layer_head_mask: torch.FloatTensor, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` - attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - """ - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() - ): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert -class LDMBertPreTrainedModel(PreTrainedModel): - config_class = LDMBertConfig - base_model_prefix = "model" - _supports_gradient_checkpointing = True - _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] - - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if 
module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (LDMBertEncoder,)): - module.gradient_checkpointing = value - - @property - def dummy_inputs(self): - pad_token = self.config.pad_token_id - input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) - dummy_inputs = { - "attention_mask": input_ids.ne(pad_token), - "input_ids": input_ids, - } - return dummy_inputs - - -class LDMBertEncoder(LDMBertPreTrainedModel): - """ - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a - [`LDMBertEncoderLayer`]. - - Args: - config: LDMBertConfig - embed_tokens (nn.Embedding): output embedding - """ - - def __init__(self, config: LDMBertConfig): - super().__init__(config) - - self.dropout = config.dropout - - embed_dim = config.d_model - self.padding_idx = config.pad_token_id - self.max_source_positions = config.max_position_embeddings - - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim) - self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim) - self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)]) - self.layer_norm = nn.LayerNorm(embed_dim) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. 
- output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - seq_len = input_shape[1] - if position_ids is None: - position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1)) - embed_pos = self.embed_positions(position_ids) - - hidden_states = inputs_embeds + embed_pos - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - # check if head_mask has a correct number of layers specified if desired - if head_mask is not None: - if head_mask.size()[0] != (len(self.layers)): - raise ValueError( - f"The head_mask should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}." 
- ) - - for idx, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - hidden_states = self.layer_norm(hidden_states) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - - -class LDMBertModel(LDMBertPreTrainedModel): - _no_split_modules = [] - - def __init__(self, config: LDMBertConfig): - super().__init__(config) - self.model = LDMBertEncoder(config) - self.to_logits = nn.Linear(config.hidden_size, config.vocab_size) - - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - return outputs diff --git a/my_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/my_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py deleted file mode 100644 index 2ecf5f24a4a72fa98ef7b15a7f8100416958af5d..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +++ /dev/null @@ -1,159 +0,0 @@ -import inspect -from typing import List, Optional, Tuple, Union - -import numpy as np -import PIL -import torch -import torch.utils.checkpoint - -from ...models import UNet2DModel, VQModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) -from ...utils import PIL_INTERPOLATION, randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -def preprocess(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 - - -class LDMSuperResolutionPipeline(DiffusionPipeline): - r""" - A pipeline for image super-resolution using Latent - - This class inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - vqvae ([`VQModel`]): - Vector-quantized (VQ) VAE Model to encode and decode images to and from latent representations. - unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], - [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`]. - """ - - def __init__( - self, - vqvae: VQModel, - unet: UNet2DModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], - ): - super().__init__() - self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - image: Union[torch.Tensor, PIL.Image.Image] = None, - batch_size: Optional[int] = 1, - num_inference_steps: Optional[int] = 100, - eta: Optional[float] = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - ) -> Union[Tuple, ImagePipelineOutput]: - r""" - Args: - image (`torch.Tensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - batch_size (`int`, *optional*, defaults to 1): - Number of images to generate. - num_inference_steps (`int`, *optional*, defaults to 100): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*): - Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is - True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 
- """ - if isinstance(image, PIL.Image.Image): - batch_size = 1 - elif isinstance(image, torch.Tensor): - batch_size = image.shape[0] - else: - raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}") - - if isinstance(image, PIL.Image.Image): - image = preprocess(image) - - height, width = image.shape[-2:] - - # in_channels should be 6: 3 for latents, 3 for low resolution image - latents_shape = (batch_size, self.unet.in_channels // 2, height, width) - latents_dtype = next(self.unet.parameters()).dtype - - latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) - - image = image.to(device=self.device, dtype=latents_dtype) - - # set timesteps and move to the correct device - self.scheduler.set_timesteps(num_inference_steps, device=self.device) - timesteps_tensor = self.scheduler.timesteps - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature. - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_kwargs = {} - if accepts_eta: - extra_kwargs["eta"] = eta - - for t in self.progress_bar(timesteps_tensor): - # concat latents and low resolution image in the channel dimension. - latents_input = torch.cat([latents, image], dim=1) - latents_input = self.scheduler.scale_model_input(latents_input, t) - # predict the noise residual - noise_pred = self.unet(latents_input, t).sample - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample - - # decode the image latents with the VQVAE - image = self.vqvae.decode(latents).sample - image = torch.clamp(image, -1.0, 1.0) - image = image / 2 + 0.5 - image = image.cpu().permute(0, 2, 3, 1).numpy() - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/latent_diffusion_uncond/__init__.py b/my_diffusers/pipelines/latent_diffusion_uncond/__init__.py deleted file mode 100644 index 1b9fc5270a62bbb18d1393263101d4b9f73b7511..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/latent_diffusion_uncond/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .pipeline_latent_diffusion_uncond import LDMPipeline diff --git a/my_diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index bc7399f8de8bee1a8d527edf35051bbe3b4fff07..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/latent_diffusion_uncond/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-311.pyc b/my_diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-311.pyc deleted file mode 100644 index ad916c9ffc0c4c3a3c486bb112407b45c1bea8fb..0000000000000000000000000000000000000000 Binary files 
a/my_diffusers/pipelines/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/my_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py deleted file mode 100644 index dc0200feedb114a8f2258d72c3f46036d00cd4cb..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import List, Optional, Tuple, Union - -import torch - -from ...models import UNet2DModel, VQModel -from ...schedulers import DDIMScheduler -from ...utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -class LDMPipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - vqvae ([`VQModel`]): - Vector-quantized (VQ) Model to encode and decode images to and from latent representations. - unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents. - """ - - def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler): - super().__init__() - self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - eta: float = 0.0, - num_inference_steps: int = 50, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ) -> Union[Tuple, ImagePipelineOutput]: - r""" - Args: - batch_size (`int`, *optional*, defaults to 1): - Number of images to generate. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is - True, otherwise a `tuple. 
When returning a tuple, the first element is a list with the generated images. - """ - - latents = randn_tensor( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), - generator=generator, - ) - latents = latents.to(self.device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - - self.scheduler.set_timesteps(num_inference_steps) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - - extra_kwargs = {} - if accepts_eta: - extra_kwargs["eta"] = eta - - for t in self.progress_bar(self.scheduler.timesteps): - latent_model_input = self.scheduler.scale_model_input(latents, t) - # predict the noise residual - noise_prediction = self.unet(latent_model_input, t).sample - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample - - # decode the image latents with the VAE - image = self.vqvae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/onnx_utils.py b/my_diffusers/pipelines/onnx_utils.py deleted file mode 100644 index 07c32e4e84bfee0241733a077fef9c0dec06905e..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/onnx_utils.py +++ /dev/null @@ -1,212 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
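For orientation, a minimal usage sketch of the unconditional LDMPipeline deleted just above. The sketch is not part of the removed files; it assumes the upstream diffusers package is installed and uses CompVis/ldm-celebahq-256 only as an example checkpoint id.

# Sketch only: run the unconditional latent diffusion pipeline end to end.
import torch
from diffusers import LDMPipeline  # upstream counterpart of the class removed above

pipe = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")  # example checkpoint, assumed to exist
generator = torch.Generator().manual_seed(0)  # optional: makes sampling deterministic
result = pipe(batch_size=1, num_inference_steps=50, eta=0.0, generator=generator)
result.images[0].save("ldm_sample.png")  # PIL output with the default output_type="pil"

As in the deleted __call__, eta is only forwarded to scheduler.step() when that method's signature accepts it, so schedulers without an eta parameter never see it.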
- - -import os -import shutil -from pathlib import Path -from typing import Optional, Union - -import numpy as np -from huggingface_hub import hf_hub_download - -from ..utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging - - -if is_onnx_available(): - import onnxruntime as ort - - -logger = logging.get_logger(__name__) - -ORT_TO_NP_TYPE = { - "tensor(bool)": np.bool_, - "tensor(int8)": np.int8, - "tensor(uint8)": np.uint8, - "tensor(int16)": np.int16, - "tensor(uint16)": np.uint16, - "tensor(int32)": np.int32, - "tensor(uint32)": np.uint32, - "tensor(int64)": np.int64, - "tensor(uint64)": np.uint64, - "tensor(float16)": np.float16, - "tensor(float)": np.float32, - "tensor(double)": np.float64, -} - - -class OnnxRuntimeModel: - def __init__(self, model=None, **kwargs): - logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.") - self.model = model - self.model_save_dir = kwargs.get("model_save_dir", None) - self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME) - - def __call__(self, **kwargs): - inputs = {k: np.array(v) for k, v in kwargs.items()} - return self.model.run(None, inputs) - - @staticmethod - def load_model(path: Union[str, Path], provider=None, sess_options=None): - """ - Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider` - - Arguments: - path (`str` or `Path`): - Directory from which to load - provider(`str`, *optional*): - Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider` - """ - if provider is None: - logger.info("No onnxruntime provider specified, using CPUExecutionProvider") - provider = "CPUExecutionProvider" - - return ort.InferenceSession(path, providers=[provider], sess_options=sess_options) - - def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): - """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the - [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the - latest_model_name. - - Arguments: - save_directory (`str` or `Path`): - Directory where to save the model file. - file_name(`str`, *optional*): - Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the - model with a different name. - """ - model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME - - src_path = self.model_save_dir.joinpath(self.latest_model_name) - dst_path = Path(save_directory).joinpath(model_file_name) - try: - shutil.copyfile(src_path, dst_path) - except shutil.SameFileError: - pass - - # copy external weights (for models >2GB) - src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME) - if src_path.exists(): - dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME) - try: - shutil.copyfile(src_path, dst_path) - except shutil.SameFileError: - pass - - def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - **kwargs, - ): - """ - Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class - method.: - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. 
- """ - if os.path.isfile(save_directory): - logger.error(f"Provided path ({save_directory}) should be a directory, not a file") - return - - os.makedirs(save_directory, exist_ok=True) - - # saving model weights/files - self._save_pretrained(save_directory, **kwargs) - - @classmethod - def _from_pretrained( - cls, - model_id: Union[str, Path], - use_auth_token: Optional[Union[bool, str, None]] = None, - revision: Optional[Union[str, None]] = None, - force_download: bool = False, - cache_dir: Optional[str] = None, - file_name: Optional[str] = None, - provider: Optional[str] = None, - sess_options: Optional["ort.SessionOptions"] = None, - **kwargs, - ): - """ - Load a model from a directory or the HF Hub. - - Arguments: - model_id (`str` or `Path`): - Directory from which to load - use_auth_token (`str` or `bool`): - Is needed to load models from a private or gated repository - revision (`str`): - Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id - cache_dir (`Union[str, Path]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - file_name(`str`): - Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load - different model files from the same repository or directory. - provider(`str`): - The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`. - kwargs (`Dict`, *optional*): - kwargs will be passed to the model during initialization - """ - model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME - # load model from local directory - if os.path.isdir(model_id): - model = OnnxRuntimeModel.load_model( - os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options - ) - kwargs["model_save_dir"] = Path(model_id) - # load model from hub - else: - # download model - model_cache_path = hf_hub_download( - repo_id=model_id, - filename=model_file_name, - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - ) - kwargs["model_save_dir"] = Path(model_cache_path).parent - kwargs["latest_model_name"] = Path(model_cache_path).name - model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options) - return cls(model=model, **kwargs) - - @classmethod - def from_pretrained( - cls, - model_id: Union[str, Path], - force_download: bool = True, - use_auth_token: Optional[str] = None, - cache_dir: Optional[str] = None, - **model_kwargs, - ): - revision = None - if len(str(model_id).split("@")) == 2: - model_id, revision = model_id.split("@") - - return cls._from_pretrained( - model_id=model_id, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - use_auth_token=use_auth_token, - **model_kwargs, - ) diff --git a/my_diffusers/pipelines/paint_by_example/__init__.py b/my_diffusers/pipelines/paint_by_example/__init__.py deleted file mode 100644 index f0fc8cb71e3f4e1e8baf16c7143658ca64934306..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/paint_by_example/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Union - -import numpy as np -import PIL -from PIL import Image - 
-from ...utils import is_torch_available, is_transformers_available - - -if is_transformers_available() and is_torch_available(): - from .image_encoder import PaintByExampleImageEncoder - from .pipeline_paint_by_example import PaintByExamplePipeline diff --git a/my_diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 1a714efc4f3da4e2de0b72b5b52b9c25c9fa6557..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/paint_by_example/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-311.pyc b/my_diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-311.pyc deleted file mode 100644 index db056b926baa9bc0fe69ebe2b84f10b6f519d469..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/paint_by_example/__pycache__/image_encoder.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-311.pyc b/my_diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-311.pyc deleted file mode 100644 index a89487431edeae11381b65b81eeacb3b7cbc780f..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/paint_by_example/__pycache__/pipeline_paint_by_example.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/paint_by_example/image_encoder.py b/my_diffusers/pipelines/paint_by_example/image_encoder.py deleted file mode 100644 index 831489eefed167264c8fd8f57e1ed59610ebb858..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/paint_by_example/image_encoder.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
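Similarly, a hedged usage sketch for the OnnxRuntimeModel wrapper deleted a little above. It is not taken from the removed code; the repository id and the input name "sample" are hypothetical placeholders, and onnxruntime must be installed for the class to be exported.

# Sketch only: load an exported ONNX model through the wrapper and run one forward pass.
import numpy as np
from diffusers import OnnxRuntimeModel  # upstream counterpart of the deleted class

model = OnnxRuntimeModel.from_pretrained(
    "some-user/some-onnx-unet@main",   # hypothetical repo id; "repo@revision" is split on "@"
    provider="CPUExecutionProvider",   # forwarded to onnxruntime.InferenceSession
)
# __call__ turns every keyword argument into a numpy array and feeds it to the session,
# so input names and shapes must match the exported graph.
outputs = model(sample=np.zeros((1, 4, 64, 64), dtype=np.float32))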
-import torch -from torch import nn -from transformers import CLIPPreTrainedModel, CLIPVisionModel - -from ...models.attention import BasicTransformerBlock -from ...utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class PaintByExampleImageEncoder(CLIPPreTrainedModel): - def __init__(self, config, proj_size=768): - super().__init__(config) - self.proj_size = proj_size - - self.model = CLIPVisionModel(config) - self.mapper = PaintByExampleMapper(config) - self.final_layer_norm = nn.LayerNorm(config.hidden_size) - self.proj_out = nn.Linear(config.hidden_size, self.proj_size) - - # uncondition for scaling - self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size))) - - def forward(self, pixel_values, return_uncond_vector=False): - clip_output = self.model(pixel_values=pixel_values) - latent_states = clip_output.pooler_output - latent_states = self.mapper(latent_states[:, None]) - latent_states = self.final_layer_norm(latent_states) - latent_states = self.proj_out(latent_states) - if return_uncond_vector: - return latent_states, self.uncond_vector - - return latent_states - - -class PaintByExampleMapper(nn.Module): - def __init__(self, config): - super().__init__() - num_layers = (config.num_hidden_layers + 1) // 5 - hid_size = config.hidden_size - num_heads = 1 - self.blocks = nn.ModuleList( - [ - BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True) - for _ in range(num_layers) - ] - ) - - def forward(self, hidden_states): - for block in self.blocks: - hidden_states = block(hidden_states) - - return hidden_states diff --git a/my_diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/my_diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py deleted file mode 100644 index 59205e4c24b04d893ebec01cf7844feed5f252b3..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ /dev/null @@ -1,578 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL -import torch -from transformers import CLIPFeatureExtractor - -from diffusers.utils import is_accelerate_available - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline -from ..stable_diffusion import StableDiffusionPipelineOutput -from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from .image_encoder import PaintByExampleImageEncoder - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def prepare_mask_and_masked_image(image, mask): - """ - Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline. 
This means that those inputs will be - converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the - ``image`` and ``1`` for the ``mask``. - - The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be - binarized (``mask > 0.5``) and cast to ``torch.float32`` too. - - Args: - image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. - It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` - ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. - mask (_type_): The mask to apply to the image, i.e. regions to inpaint. - It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` - ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. - - - Raises: - ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask - should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. - TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not - (ot the other way around). - - Returns: - tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 - dimensions: ``batch x channels x height x width``. - """ - if isinstance(image, torch.Tensor): - if not isinstance(mask, torch.Tensor): - raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") - - # Batch single image - if image.ndim == 3: - assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" - image = image.unsqueeze(0) - - # Batch and add channel dim for single mask - if mask.ndim == 2: - mask = mask.unsqueeze(0).unsqueeze(0) - - # Batch single mask or add channel dim - if mask.ndim == 3: - # Batched mask - if mask.shape[0] == image.shape[0]: - mask = mask.unsqueeze(1) - else: - mask = mask.unsqueeze(0) - - assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" - assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" - assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" - assert mask.shape[1] == 1, "Mask image must have a single channel" - - # Check image is in [-1, 1] - if image.min() < -1 or image.max() > 1: - raise ValueError("Image should be in [-1, 1] range") - - # Check mask is in [0, 1] - if mask.min() < 0 or mask.max() > 1: - raise ValueError("Mask should be in [0, 1] range") - - # paint-by-example inverses the mask - mask = 1 - mask - - # Binarize mask - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - - # Image as float32 - image = image.to(dtype=torch.float32) - elif isinstance(mask, torch.Tensor): - raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") - else: - if isinstance(image, PIL.Image.Image): - image = [image] - - image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0) - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 - - # preprocess mask - if isinstance(mask, PIL.Image.Image): - mask = [mask] - - mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) - mask = mask.astype(np.float32) / 255.0 - - # paint-by-example inverses the mask - mask = 1 - mask - - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - mask = torch.from_numpy(mask) - - masked_image = image * mask - - return 
mask, masked_image - - -class PaintByExamplePipeline(DiffusionPipeline): - r""" - Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - # TODO: feature_extractor is required to encode initial images (if they are in PIL format), - # we should give a descriptive message if the pipeline doesn't have one. - _optional_components = ["safety_checker"] - - def __init__( - self, - vae: AutoencoderKL, - image_encoder: PaintByExampleImageEncoder, - unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = False, - ): - super().__init__() - - self.register_modules( - vae=vae, - image_encoder=image_encoder, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
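A standalone sketch (not the deleted helper itself) of what the tensor branch of `prepare_mask_and_masked_image` above produces. The shapes and value ranges follow its docstring (image in `[-1, 1]`, mask in `[0, 1]`); the concrete sizes below are assumptions for illustration only.

```py
import torch

image = torch.rand(1, 3, 64, 64) * 2 - 1          # B x 3 x H x W, already in [-1, 1]
mask = (torch.rand(1, 1, 64, 64) > 0.7).float()   # B x 1 x H x W, 1 = region to repaint

# Paint-by-Example inverts the mask relative to the usual inpainting convention,
# then binarizes it at 0.5, so after this step 1 marks pixels that are kept.
mask = 1.0 - mask
mask = (mask >= 0.5).float()

# masked_image keeps the preserved pixels and zeroes the region to be repainted.
masked_image = image.to(torch.float32) * mask

print(mask.shape, masked_image.shape)
# torch.Size([1, 1, 64, 64]) torch.Size([1, 3, 64, 64])
```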
- """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - for cpu_offloaded_model in [self.unet, self.vae, self.image_encoder]: - cpu_offload(cpu_offloaded_model, execution_device=device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs - def check_inputs(self, image, height, width, callback_steps): - if ( - not isinstance(image, torch.Tensor) - and not isinstance(image, PIL.Image.Image) - and not isinstance(image, list) - ): - raise ValueError( - "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" - f" {type(image)}" - ) - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents - def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - masked_image = masked_image.to(device=device, dtype=dtype) - - # encode the mask image into latents space so we can concatenate it to the latents - if isinstance(generator, list): - masked_image_latents = [ - self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) - for i in range(batch_size) - ] - masked_image_latents = torch.cat(masked_image_latents, dim=0) - else: - masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) - masked_image_latents = self.vae.config.scaling_factor * masked_image_latents - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." 
- ) - masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) - - mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask - masked_image_latents = ( - torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - return mask, masked_image_latents - - def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(images=image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True) - - # duplicate image embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = image_embeddings.shape - image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) - image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if do_classifier_free_guidance: - negative_prompt_embeds = negative_prompt_embeds.repeat(1, image_embeddings.shape[0], 1) - negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) - - return image_embeddings - - @torch.no_grad() - def __call__( - self, - example_image: Union[torch.FloatTensor, PIL.Image.Image], - image: Union[torch.FloatTensor, PIL.Image.Image], - mask_image: Union[torch.FloatTensor, PIL.Image.Image], - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - example_image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): - The exemplar image to guide the image generation. - image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): - `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will - be masked out with `mask_image` and repainted according to `prompt`. - mask_image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`): - `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be - repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted - to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) - instead of 3, so the expected shape would be `(B, H, W, 1)`. 
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 1. 
Define call parameters - if isinstance(image, PIL.Image.Image): - batch_size = 1 - elif isinstance(image, list): - batch_size = len(image) - else: - batch_size = image.shape[0] - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 2. Preprocess mask and image - mask, masked_image = prepare_mask_and_masked_image(image, mask_image) - height, width = masked_image.shape[-2:] - - # 3. Check inputs - self.check_inputs(example_image, height, width, callback_steps) - - # 4. Encode input image - image_embeddings = self._encode_image( - example_image, device, num_images_per_prompt, do_classifier_free_guidance - ) - - # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 6. Prepare latent variables - num_channels_latents = self.vae.config.latent_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - image_embeddings.dtype, - device, - generator, - latents, - ) - - # 7. Prepare mask latent variables - mask, masked_image_latents = self.prepare_mask_latents( - mask, - masked_image, - batch_size * num_images_per_prompt, - height, - width, - image_embeddings.dtype, - device, - generator, - do_classifier_free_guidance, - ) - - # 8. Check that sizes of mask, masked image and latents match - num_channels_mask = mask.shape[1] - num_channels_masked_image = masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" - f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." - ) - - # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 10. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - - # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - latent_model_input = torch.cat([latent_model_input, masked_image_latents, mask], dim=1) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 11. Post-processing - image = self.decode_latents(latents) - - # 12. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) - - # 13. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/pipeline_flax_utils.py b/my_diffusers/pipelines/pipeline_flax_utils.py deleted file mode 100644 index 30e32c3d66e9e5a0f26dc1a0c30b9e61ca8ac6a8..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/pipeline_flax_utils.py +++ /dev/null @@ -1,549 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
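Before the Flax pipeline utilities below, a hedged end-to-end sketch of how the `PaintByExamplePipeline` deleted above is typically invoked. The checkpoint id `Fantasy-Studio/Paint-by-Example` and the local file names are assumptions, not taken from this diff.

```py
import PIL.Image
import torch
from diffusers import PaintByExamplePipeline

# Load the pipeline (checkpoint id is an assumption) and move it to GPU.
pipe = PaintByExamplePipeline.from_pretrained(
    "Fantasy-Studio/Paint-by-Example", torch_dtype=torch.float16
).to("cuda")

init_image = PIL.Image.open("scene.png").convert("RGB").resize((512, 512))
mask_image = PIL.Image.open("mask.png").convert("L").resize((512, 512))   # white = region to repaint
example_image = PIL.Image.open("reference.png").convert("RGB")            # exemplar guiding the fill

result = pipe(
    example_image=example_image,
    image=init_image,
    mask_image=mask_image,
    num_inference_steps=50,
    guidance_scale=5.0,
)
result.images[0].save("paint_by_example_out.png")
```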
- -import importlib -import inspect -import os -from typing import Any, Dict, List, Optional, Union - -import flax -import numpy as np -import PIL -from flax.core.frozen_dict import FrozenDict -from huggingface_hub import snapshot_download -from PIL import Image -from tqdm.auto import tqdm - -from ..configuration_utils import ConfigMixin -from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin -from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin -from ..utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, http_user_agent, is_transformers_available, logging - - -if is_transformers_available(): - from transformers import FlaxPreTrainedModel - -INDEX_FILE = "diffusion_flax_model.bin" - - -logger = logging.get_logger(__name__) - - -LOADABLE_CLASSES = { - "diffusers": { - "FlaxModelMixin": ["save_pretrained", "from_pretrained"], - "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"], - "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"], - }, - "transformers": { - "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], - "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], - "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"], - "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], - "ProcessorMixin": ["save_pretrained", "from_pretrained"], - "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], - }, -} - -ALL_IMPORTABLE_CLASSES = {} -for library in LOADABLE_CLASSES: - ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) - - -def import_flax_or_no_model(module, class_name): - try: - # 1. First make sure that if a Flax object is present, import this one - class_obj = getattr(module, "Flax" + class_name) - except AttributeError: - # 2. If this doesn't work, it's not a model and we don't append "Flax" - class_obj = getattr(module, class_name) - except AttributeError: - raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}") - - return class_obj - - -@flax.struct.dataclass -class FlaxImagePipelineOutput(BaseOutput): - """ - Output class for image pipelines. - - Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. - """ - - images: Union[List[PIL.Image.Image], np.ndarray] - - -class FlaxDiffusionPipeline(ConfigMixin): - r""" - Base class for all models. - - [`FlaxDiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion - pipelines and handles methods for loading, downloading and saving models as well as a few methods common to all - pipelines to: - - - enabling/disabling the progress bar for the denoising iteration - - Class attributes: - - - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all - components of the diffusion pipeline. 
- """ - config_name = "model_index.json" - - def register_modules(self, **kwargs): - # import it here to avoid circular import - from diffusers import pipelines - - for name, module in kwargs.items(): - if module is None: - register_dict = {name: (None, None)} - else: - # retrieve library - library = module.__module__.split(".")[0] - - # check if the module is a pipeline module - pipeline_dir = module.__module__.split(".")[-2] - path = module.__module__.split(".") - is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) - - # if library is not in LOADABLE_CLASSES, then it is a custom module. - # Or if it's a pipeline module, then the module is inside the pipeline - # folder so we set the library to module name. - if library not in LOADABLE_CLASSES or is_pipeline_module: - library = pipeline_dir - - # retrieve class_name - class_name = module.__class__.__name__ - - register_dict = {name: (library, class_name)} - - # save model index config - self.register_to_config(**register_dict) - - # set models - setattr(self, name, module) - - def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]): - # TODO: handle inference_state - """ - Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to - a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading - method. The pipeline can easily be re-loaded using the `[`~FlaxDiffusionPipeline.from_pretrained`]` class - method. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - """ - self.save_config(save_directory) - - model_index_dict = dict(self.config) - model_index_dict.pop("_class_name") - model_index_dict.pop("_diffusers_version") - model_index_dict.pop("_module", None) - - for pipeline_component_name in model_index_dict.keys(): - sub_model = getattr(self, pipeline_component_name) - if sub_model is None: - # edge case for saving a pipeline with safety_checker=None - continue - - model_cls = sub_model.__class__ - - save_method_name = None - # search for the model's base class in LOADABLE_CLASSES - for library_name, library_classes in LOADABLE_CLASSES.items(): - library = importlib.import_module(library_name) - for base_class, save_load_methods in library_classes.items(): - class_candidate = getattr(library, base_class, None) - if class_candidate is not None and issubclass(model_cls, class_candidate): - # if we found a suitable base class in LOADABLE_CLASSES then grab its save method - save_method_name = save_load_methods[0] - break - if save_method_name is not None: - break - - save_method = getattr(sub_model, save_method_name) - expects_params = "params" in set(inspect.signature(save_method).parameters.keys()) - - if expects_params: - save_method( - os.path.join(save_directory, pipeline_component_name), params=params[pipeline_component_name] - ) - else: - save_method(os.path.join(save_directory, pipeline_component_name)) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): - r""" - Instantiate a Flax diffusion pipeline from pre-trained pipeline weights. - - The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. 
It is up to you to train those weights with a downstream fine-tuning - task. - - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on - https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like - `CompVis/ldm-text2im-large-256`. - - A path to a *directory* containing pipeline weights saved using - [`~FlaxDiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. - dtype (`str` or `jnp.dtype`, *optional*): - Override the default `jnp.dtype` and load the model under this dtype. If `"auto"` is passed the dtype - will be automatically derived from the model's weights. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `huggingface-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. specify the folder name here. - - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the - specific pipeline class. The overwritten components are then directly passed to the pipelines - `__init__` method. See example below for more information. - - - - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"` - - - - - - Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use - this method in a firewalled environment. - - - - Examples: - - ```py - >>> from diffusers import FlaxDiffusionPipeline - - >>> # Download pipeline from huggingface.co and cache. 
- >>> # Requires to be logged in to Hugging Face hub, - >>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens) - >>> pipeline, params = FlaxDiffusionPipeline.from_pretrained( - ... "runwayml/stable-diffusion-v1-5", - ... revision="bf16", - ... dtype=jnp.bfloat16, - ... ) - - >>> # Download pipeline, but use a different scheduler - >>> from diffusers import FlaxDPMSolverMultistepScheduler - - >>> model_id = "runwayml/stable-diffusion-v1-5" - >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained( - ... model_id, - ... subfolder="scheduler", - ... ) - - >>> dpm_pipe, dpm_params = FlaxStableDiffusionPipeline.from_pretrained( - ... model_id, revision="bf16", dtype=jnp.bfloat16, scheduler=dpmpp - ... ) - >>> dpm_params["scheduler"] = dpmpp_state - ``` - """ - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - from_pt = kwargs.pop("from_pt", False) - dtype = kwargs.pop("dtype", None) - - # 1. Download the checkpoints and configs - # use snapshot download here to get it working from from_pretrained - if not os.path.isdir(pretrained_model_name_or_path): - config_dict = cls.load_config( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - ) - # make sure we only download sub-folders and `diffusers` filenames - folder_names = [k for k in config_dict.keys() if not k.startswith("_")] - allow_patterns = [os.path.join(k, "*") for k in folder_names] - allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] - - # make sure we don't download PyTorch weights, unless when using from_pt - ignore_patterns = "*.bin" if not from_pt else [] - - if cls != FlaxDiffusionPipeline: - requested_pipeline_class = cls.__name__ - else: - requested_pipeline_class = config_dict.get("_class_name", cls.__name__) - requested_pipeline_class = ( - requested_pipeline_class - if requested_pipeline_class.startswith("Flax") - else "Flax" + requested_pipeline_class - ) - - user_agent = {"pipeline_class": requested_pipeline_class} - user_agent = http_user_agent(user_agent) - - # download all allow_patterns - cached_folder = snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - user_agent=user_agent, - ) - else: - cached_folder = pretrained_model_name_or_path - - config_dict = cls.load_config(cached_folder) - - # 2. 
Load the pipeline class, if using custom module then load it from the hub - # if we load from explicit class, let's use it - if cls != FlaxDiffusionPipeline: - pipeline_class = cls - else: - diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) - class_name = ( - config_dict["_class_name"] - if config_dict["_class_name"].startswith("Flax") - else "Flax" + config_dict["_class_name"] - ) - pipeline_class = getattr(diffusers_module, class_name) - - # some modules can be passed directly to the init - # in this case they are already instantiated in `kwargs` - # extract them here - expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} - - init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) - - init_kwargs = {} - - # inference_params - params = {} - - # import it here to avoid circular import - from diffusers import pipelines - - # 3. Load each module in the pipeline - for name, (library_name, class_name) in init_dict.items(): - if class_name is None: - # edge case for when the pipeline was saved with safety_checker=None - init_kwargs[name] = None - continue - - is_pipeline_module = hasattr(pipelines, library_name) - loaded_sub_model = None - sub_model_should_be_defined = True - - # if the model is in a pipeline module, then we load it from the pipeline - if name in passed_class_obj: - # 1. check that passed_class_obj has correct parent class - if not is_pipeline_module: - library = importlib.import_module(library_name) - class_obj = getattr(library, class_name) - importable_classes = LOADABLE_CLASSES[library_name] - class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} - - expected_class_obj = None - for class_name, class_candidate in class_candidates.items(): - if class_candidate is not None and issubclass(class_obj, class_candidate): - expected_class_obj = class_candidate - - if not issubclass(passed_class_obj[name].__class__, expected_class_obj): - raise ValueError( - f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" - f" {expected_class_obj}" - ) - elif passed_class_obj[name] is None: - logger.warning( - f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" - f" that this might lead to problems when using {pipeline_class} and is not recommended." - ) - sub_model_should_be_defined = False - else: - logger.warning( - f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" - " has the correct type" - ) - - # set passed class object - loaded_sub_model = passed_class_obj[name] - elif is_pipeline_module: - pipeline_module = getattr(pipelines, library_name) - class_obj = import_flax_or_no_model(pipeline_module, class_name) - - importable_classes = ALL_IMPORTABLE_CLASSES - class_candidates = {c: class_obj for c in importable_classes.keys()} - else: - # else we just import it from the library. 
- library = importlib.import_module(library_name) - class_obj = import_flax_or_no_model(library, class_name) - - importable_classes = LOADABLE_CLASSES[library_name] - class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} - - if loaded_sub_model is None and sub_model_should_be_defined: - load_method_name = None - for class_name, class_candidate in class_candidates.items(): - if class_candidate is not None and issubclass(class_obj, class_candidate): - load_method_name = importable_classes[class_name][1] - - load_method = getattr(class_obj, load_method_name) - - # check if the module is in a subdirectory - if os.path.isdir(os.path.join(cached_folder, name)): - loadable_folder = os.path.join(cached_folder, name) - else: - loaded_sub_model = cached_folder - - if issubclass(class_obj, FlaxModelMixin): - loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype) - params[name] = loaded_params - elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): - if from_pt: - # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here - loaded_sub_model = load_method(loadable_folder, from_pt=from_pt) - loaded_params = loaded_sub_model.params - del loaded_sub_model._params - else: - loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False) - params[name] = loaded_params - elif issubclass(class_obj, FlaxSchedulerMixin): - loaded_sub_model, scheduler_state = load_method(loadable_folder) - params[name] = scheduler_state - else: - loaded_sub_model = load_method(loadable_folder) - - init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) - - model = pipeline_class(**init_kwargs, dtype=dtype) - return model, params - - @staticmethod - def _get_signature_keys(obj): - parameters = inspect.signature(obj.__init__).parameters - required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} - optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) - expected_modules = set(required_parameters.keys()) - set(["self"]) - return expected_modules, optional_parameters - - @property - def components(self) -> Dict[str, Any]: - r""" - - The `self.components` property can be useful to run different pipelines with the same weights and - configurations to not have to re-allocate memory. - - Examples: - - ```py - >>> from diffusers import ( - ... FlaxStableDiffusionPipeline, - ... FlaxStableDiffusionImg2ImgPipeline, - ... ) - - >>> text2img = FlaxStableDiffusionPipeline.from_pretrained( - ... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16 - ... ) - >>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components) - ``` - - Returns: - A dictionary containing all the modules needed to initialize the pipeline. - """ - expected_modules, optional_parameters = self._get_signature_keys(self) - components = { - k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters - } - - if set(components.keys()) != expected_modules: - raise ValueError( - f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" - f" {expected_modules} to be defined, but {components} are defined." - ) - - return components - - @staticmethod - def numpy_to_pil(images): - """ - Convert a numpy image or a batch of images to a PIL image. - """ - if images.ndim == 3: - images = images[None, ...] 
- images = (images * 255).round().astype("uint8") - if images.shape[-1] == 1: - # special case for grayscale (single channel) images - pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] - else: - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - - # TODO: make it compatible with jax.lax - def progress_bar(self, iterable): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." - ) - - return tqdm(iterable, **self._progress_bar_config) - - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs diff --git a/my_diffusers/pipelines/pipeline_utils.py b/my_diffusers/pipelines/pipeline_utils.py deleted file mode 100644 index 65b348d2e7d35aed77e354be1b7b289f1efd20d3..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/pipeline_utils.py +++ /dev/null @@ -1,1136 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import importlib -import inspect -import os -import re -import warnings -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union - -import numpy as np -import PIL -import torch -from huggingface_hub import model_info, snapshot_download -from packaging import version -from PIL import Image -from tqdm.auto import tqdm - -import diffusers - -from .. 
import __version__ -from ..configuration_utils import ConfigMixin -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT -from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from ..utils import ( - CONFIG_NAME, - DEPRECATED_REVISION_ARGS, - DIFFUSERS_CACHE, - HF_HUB_OFFLINE, - SAFETENSORS_WEIGHTS_NAME, - WEIGHTS_NAME, - BaseOutput, - deprecate, - get_class_from_dynamic_module, - http_user_agent, - is_accelerate_available, - is_accelerate_version, - is_safetensors_available, - is_torch_version, - is_transformers_available, - logging, -) - - -if is_transformers_available(): - import transformers - from transformers import PreTrainedModel - from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME - from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME - from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME - -from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME - - -if is_accelerate_available(): - import accelerate - - -INDEX_FILE = "diffusion_pytorch_model.bin" -CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" -DUMMY_MODULES_FOLDER = "diffusers.utils" -TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils" - - -logger = logging.get_logger(__name__) - - -LOADABLE_CLASSES = { - "diffusers": { - "ModelMixin": ["save_pretrained", "from_pretrained"], - "SchedulerMixin": ["save_pretrained", "from_pretrained"], - "DiffusionPipeline": ["save_pretrained", "from_pretrained"], - "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], - }, - "transformers": { - "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], - "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], - "PreTrainedModel": ["save_pretrained", "from_pretrained"], - "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], - "ProcessorMixin": ["save_pretrained", "from_pretrained"], - "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], - }, - "onnxruntime.training": { - "ORTModule": ["save_pretrained", "from_pretrained"], - }, -} - -ALL_IMPORTABLE_CLASSES = {} -for library in LOADABLE_CLASSES: - ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) - - -@dataclass -class ImagePipelineOutput(BaseOutput): - """ - Output class for image pipelines. - - Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. - """ - - images: Union[List[PIL.Image.Image], np.ndarray] - - -@dataclass -class AudioPipelineOutput(BaseOutput): - """ - Output class for audio pipelines. - - Args: - audios (`np.ndarray`) - List of denoised samples of shape `(batch_size, num_channels, sample_rate)`. Numpy array present the - denoised audio samples of the diffusion pipeline. - """ - - audios: np.ndarray - - -def is_safetensors_compatible(filenames, variant=None) -> bool: - """ - Checking for safetensors compatibility: - - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch - files to know which safetensors files are needed. - - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file. 
- - Converting default pytorch serialized filenames to safetensors serialized filenames: - - For models from the diffusers library, just replace the ".bin" extension with ".safetensors" - - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin" - extension is replaced with ".safetensors" - """ - pt_filenames = [] - - sf_filenames = set() - - for filename in filenames: - _, extension = os.path.splitext(filename) - - if extension == ".bin": - pt_filenames.append(filename) - elif extension == ".safetensors": - sf_filenames.add(filename) - - for filename in pt_filenames: - # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam' - path, filename = os.path.split(filename) - filename, extension = os.path.splitext(filename) - - if filename == "pytorch_model": - filename = "model" - elif filename == f"pytorch_model.{variant}": - filename = f"model.{variant}" - else: - filename = filename - - expected_sf_filename = os.path.join(path, filename) - expected_sf_filename = f"{expected_sf_filename}.safetensors" - - if expected_sf_filename not in sf_filenames: - logger.warning(f"{expected_sf_filename} not found") - return False - - return True - - -def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]: - filenames = set(sibling.rfilename for sibling in info.siblings) - weight_names = [ - WEIGHTS_NAME, - SAFETENSORS_WEIGHTS_NAME, - FLAX_WEIGHTS_NAME, - ONNX_WEIGHTS_NAME, - ONNX_EXTERNAL_WEIGHTS_NAME, - ] - - if is_transformers_available(): - weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME] - - # model_pytorch, diffusion_model_pytorch, ... - weight_prefixes = [w.split(".")[0] for w in weight_names] - # .bin, .safetensors, ... - weight_suffixs = [w.split(".")[-1] for w in weight_names] - - variant_file_regex = ( - re.compile(f"({'|'.join(weight_prefixes)})(.{variant}.)({'|'.join(weight_suffixs)})") - if variant is not None - else None - ) - non_variant_file_regex = re.compile(f"{'|'.join(weight_names)}") - - if variant is not None: - variant_filenames = set(f for f in filenames if variant_file_regex.match(f.split("/")[-1]) is not None) - else: - variant_filenames = set() - - non_variant_filenames = set(f for f in filenames if non_variant_file_regex.match(f.split("/")[-1]) is not None) - - usable_filenames = set(variant_filenames) - for f in non_variant_filenames: - variant_filename = f"{f.split('.')[0]}.{variant}.{f.split('.')[1]}" - if variant_filename not in usable_filenames: - usable_filenames.add(f) - - return usable_filenames, variant_filenames - - -class DiffusionPipeline(ConfigMixin): - r""" - Base class for all models. - - [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines - and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to: - - - move all PyTorch modules to the device of your choice - - enabling/disabling the progress bar for the denoising iteration - - Class attributes: - - - **config_name** (`str`) -- name of the config file that will store the class and module names of all - components of the diffusion pipeline. - - **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be - passed for the pipeline to function (should be overridden by subclasses). 
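To make the filename convention checked by `is_safetensors_compatible` above concrete, here is a small standalone sketch (not the deleted helper) of the `.bin` to `.safetensors` mapping it expects a repository to satisfy; the example paths are hypothetical.

```py
import os
from typing import Optional

def expected_safetensors_name(pt_filename: str, variant: Optional[str] = None) -> str:
    # Mirror the mapping described above: diffusers weights just swap the extension,
    # transformers weights also rename "pytorch_model" to "model".
    path, name = os.path.split(pt_filename)
    stem, _ = os.path.splitext(name)
    if stem == "pytorch_model":
        stem = "model"
    elif variant is not None and stem == f"pytorch_model.{variant}":
        stem = f"model.{variant}"
    return os.path.join(path, f"{stem}.safetensors")

print(expected_safetensors_name("unet/diffusion_pytorch_model.bin"))
# unet/diffusion_pytorch_model.safetensors
print(expected_safetensors_name("text_encoder/pytorch_model.bin"))
# text_encoder/model.safetensors
print(expected_safetensors_name("text_encoder/pytorch_model.fp16.bin", variant="fp16"))
# text_encoder/model.fp16.safetensors
```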
- """ - config_name = "model_index.json" - _optional_components = [] - - def register_modules(self, **kwargs): - # import it here to avoid circular import - from diffusers import pipelines - - for name, module in kwargs.items(): - # retrieve library - if module is None: - register_dict = {name: (None, None)} - else: - library = module.__module__.split(".")[0] - - # check if the module is a pipeline module - pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None - path = module.__module__.split(".") - is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) - - # if library is not in LOADABLE_CLASSES, then it is a custom module. - # Or if it's a pipeline module, then the module is inside the pipeline - # folder so we set the library to module name. - if library not in LOADABLE_CLASSES or is_pipeline_module: - library = pipeline_dir - - # retrieve class_name - class_name = module.__class__.__name__ - - register_dict = {name: (library, class_name)} - - # save model index config - self.register_to_config(**register_dict) - - # set models - setattr(self, name, module) - - def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - safe_serialization: bool = False, - variant: Optional[str] = None, - ): - """ - Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to - a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading - method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - safe_serialization (`bool`, *optional*, defaults to `False`): - Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). - variant (`str`, *optional*): - If specified, weights are saved in the format pytorch_model..bin. 
- """ - self.save_config(save_directory) - - model_index_dict = dict(self.config) - model_index_dict.pop("_class_name") - model_index_dict.pop("_diffusers_version") - model_index_dict.pop("_module", None) - - expected_modules, optional_kwargs = self._get_signature_keys(self) - - def is_saveable_module(name, value): - if name not in expected_modules: - return False - if name in self._optional_components and value[0] is None: - return False - return True - - model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)} - - for pipeline_component_name in model_index_dict.keys(): - sub_model = getattr(self, pipeline_component_name) - model_cls = sub_model.__class__ - - save_method_name = None - # search for the model's base class in LOADABLE_CLASSES - for library_name, library_classes in LOADABLE_CLASSES.items(): - library = importlib.import_module(library_name) - for base_class, save_load_methods in library_classes.items(): - class_candidate = getattr(library, base_class, None) - if class_candidate is not None and issubclass(model_cls, class_candidate): - # if we found a suitable base class in LOADABLE_CLASSES then grab its save method - save_method_name = save_load_methods[0] - break - if save_method_name is not None: - break - - save_method = getattr(sub_model, save_method_name) - - # Call the save method with the argument safe_serialization only if it's supported - save_method_signature = inspect.signature(save_method) - save_method_accept_safe = "safe_serialization" in save_method_signature.parameters - save_method_accept_variant = "variant" in save_method_signature.parameters - - save_kwargs = {} - if save_method_accept_safe: - save_kwargs["safe_serialization"] = safe_serialization - if save_method_accept_variant: - save_kwargs["variant"] = variant - - save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs) - - def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False): - if torch_device is None: - return self - - # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. - def module_is_sequentially_offloaded(module): - if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): - return False - - return hasattr(module, "_hf_hook") and not isinstance(module._hf_hook, accelerate.hooks.CpuOffload) - - def module_is_offloaded(module): - if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): - return False - - return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) - - # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer - pipeline_is_sequentially_offloaded = any( - module_is_sequentially_offloaded(module) for _, module in self.components.items() - ) - if pipeline_is_sequentially_offloaded and torch.device(torch_device).type == "cuda": - raise ValueError( - "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." 
- ) - - # Display a warning in this case (the operation succeeds but the benefits are lost) - pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) - if pipeline_is_offloaded and torch.device(torch_device).type == "cuda": - logger.warning( - f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." - ) - - module_names, _, _ = self.extract_init_dict(dict(self.config)) - is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded - for name in module_names.keys(): - module = getattr(self, name) - if isinstance(module, torch.nn.Module): - if ( - module.dtype == torch.float16 - and str(torch_device) in ["cpu"] - and not silence_dtype_warnings - and not is_offloaded - ): - logger.warning( - "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It" - " is not recommended to move them to `cpu` as running them will fail. Please make" - " sure to use an accelerator to run the pipeline in inference, due to the lack of" - " support for`float16` operations on this device in PyTorch. Please, remove the" - " `torch_dtype=torch.float16` argument, or use another device for inference." - ) - module.to(torch_device) - return self - - @property - def device(self) -> torch.device: - r""" - Returns: - `torch.device`: The torch device on which the pipeline is located. - """ - module_names, _, _ = self.extract_init_dict(dict(self.config)) - for name in module_names.keys(): - module = getattr(self, name) - if isinstance(module, torch.nn.Module): - return module.device - return torch.device("cpu") - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): - r""" - Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights. - - The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. - - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on - https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like - `CompVis/ldm-text2im-large-256`. - - A path to a *directory* containing pipeline weights saved using - [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype - will be automatically derived from the model's weights. - custom_pipeline (`str`, *optional*): - - - - This is an experimental feature and is likely to change in the future. 
- - - - Can be either: - - - A string, the *repo id* of a custom pipeline hosted inside a model repo on - https://huggingface.co/. Valid repo ids have to be located under a user or organization name, - like `hf-internal-testing/diffusers-dummy-pipeline`. - - - - It is required that the model repo has a file, called `pipeline.py` that defines the custom - pipeline. - - - - - A string, the *file name* of a community pipeline hosted on GitHub under - https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to - match exactly the file name without `.py` located under the above link, *e.g.* - `clip_guided_stable_diffusion`. - - - - Community pipelines are always loaded from the current `main` branch of GitHub. - - - - - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`. - - - - It is required that the directory has a file, called `pipeline.py` that defines the custom - pipeline. - - - - For more information on how to load and create custom pipelines, please have a look at [Loading and - Adding Custom - Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) - - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `huggingface-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - custom_revision (`str`, *optional*, defaults to `"main"` when loading from the Hub and to local version of `diffusers` when loading from GitHub): - The specific model version to use. It can be a branch name, a tag name, or a commit id similar to - `revision` when loading a custom pipeline from the Hub. It can be a diffusers version when loading a - custom pipeline from GitHub. - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. specify the folder name here. - device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): - A map that specifies where each submodule should go. 
It doesn't need to be refined to each - parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the - same device. - - To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For - more information about each option see [designing a device - map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading by not initializing the weights and only loading the pre-trained weights. This - also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the - model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, - setting this argument to `True` will raise an error. - return_cached_folder (`bool`, *optional*, defaults to `False`): - If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the - specific pipeline class. The overwritten components are then directly passed to the pipelines - `__init__` method. See example below for more information. - variant (`str`, *optional*): - If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is - ignored when using `from_flax`. - - - - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"` - - - - - - Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use - this method in a firewalled environment. - - - - Examples: - - ```py - >>> from diffusers import DiffusionPipeline - - >>> # Download pipeline from huggingface.co and cache. 
- >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") - - >>> # Download pipeline that requires an authorization token - >>> # For more information on access tokens, please refer to this section - >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) - >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") - - >>> # Use a different scheduler - >>> from diffusers import LMSDiscreteScheduler - - >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) - >>> pipeline.scheduler = scheduler - ``` - """ - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - resume_download = kwargs.pop("resume_download", False) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - from_flax = kwargs.pop("from_flax", False) - torch_dtype = kwargs.pop("torch_dtype", None) - custom_pipeline = kwargs.pop("custom_pipeline", None) - custom_revision = kwargs.pop("custom_revision", None) - provider = kwargs.pop("provider", None) - sess_options = kwargs.pop("sess_options", None) - device_map = kwargs.pop("device_map", None) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - return_cached_folder = kwargs.pop("return_cached_folder", False) - variant = kwargs.pop("variant", None) - - # 1. Download the checkpoints and configs - # use snapshot download here to get it working from from_pretrained - if not os.path.isdir(pretrained_model_name_or_path): - config_dict = cls.load_config( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - ) - - # retrieve all folder_names that contain relevant files - folder_names = [k for k, v in config_dict.items() if isinstance(v, list)] - - if not local_files_only: - info = model_info( - pretrained_model_name_or_path, - use_auth_token=use_auth_token, - revision=revision, - ) - model_filenames, variant_filenames = variant_compatible_siblings(info, variant=variant) - model_folder_names = set([os.path.split(f)[0] for f in model_filenames]) - - if revision in DEPRECATED_REVISION_ARGS and version.parse( - version.parse(__version__).base_version - ) >= version.parse("0.15.0"): - info = model_info( - pretrained_model_name_or_path, - use_auth_token=use_auth_token, - revision=None, - ) - comp_model_filenames, _ = variant_compatible_siblings(info, variant=revision) - comp_model_filenames = [ - ".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames - ] - - if set(comp_model_filenames) == set(model_filenames): - warnings.warn( - f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{variant}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.", - FutureWarning, - ) - else: - warnings.warn( - f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. 
However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.", - FutureWarning, - ) - - # all filenames compatible with variant will be added - allow_patterns = list(model_filenames) - - # allow all patterns from non-model folders - # this enables downloading schedulers, tokenizers, ... - allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names] - # also allow downloading config.jsons with the model - allow_patterns += [os.path.join(k, "*.json") for k in model_folder_names] - - allow_patterns += [ - SCHEDULER_CONFIG_NAME, - CONFIG_NAME, - cls.config_name, - CUSTOM_PIPELINE_FILE_NAME, - ] - - if from_flax: - ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] - elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant): - ignore_patterns = ["*.bin", "*.msgpack"] - - safetensors_variant_filenames = set([f for f in variant_filenames if f.endswith(".safetensors")]) - safetensors_model_filenames = set([f for f in model_filenames if f.endswith(".safetensors")]) - if ( - len(safetensors_variant_filenames) > 0 - and safetensors_model_filenames != safetensors_variant_filenames - ): - logger.warn( - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." - ) - - else: - ignore_patterns = ["*.safetensors", "*.msgpack"] - - bin_variant_filenames = set([f for f in variant_filenames if f.endswith(".bin")]) - bin_model_filenames = set([f for f in model_filenames if f.endswith(".bin")]) - if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: - logger.warn( - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." 
- ) - - else: - # allow everything since it has to be downloaded anyways - ignore_patterns = allow_patterns = None - - if cls != DiffusionPipeline: - requested_pipeline_class = cls.__name__ - else: - requested_pipeline_class = config_dict.get("_class_name", cls.__name__) - user_agent = {"pipeline_class": requested_pipeline_class} - if custom_pipeline is not None and not custom_pipeline.endswith(".py"): - user_agent["custom_pipeline"] = custom_pipeline - - user_agent = http_user_agent(user_agent) - - # download all allow_patterns - cached_folder = snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - user_agent=user_agent, - ) - else: - cached_folder = pretrained_model_name_or_path - config_dict = cls.load_config(cached_folder) - - # retrieve which subfolders should load variants - model_variants = {} - if variant is not None: - for folder in os.listdir(cached_folder): - folder_path = os.path.join(cached_folder, folder) - is_folder = os.path.isdir(folder_path) and folder in config_dict - variant_exists = is_folder and any(path.split(".")[1] == variant for path in os.listdir(folder_path)) - if variant_exists: - model_variants[folder] = variant - - # 2. Load the pipeline class, if using custom module then load it from the hub - # if we load from explicit class, let's use it - if custom_pipeline is not None: - if custom_pipeline.endswith(".py"): - path = Path(custom_pipeline) - # decompose into folder & file - file_name = path.name - custom_pipeline = path.parent.absolute() - else: - file_name = CUSTOM_PIPELINE_FILE_NAME - - pipeline_class = get_class_from_dynamic_module( - custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=custom_revision - ) - elif cls != DiffusionPipeline: - pipeline_class = cls - else: - diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) - pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) - - # To be removed in 1.0.0 - if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( - version.parse(config_dict["_diffusers_version"]).base_version - ) <= version.parse("0.5.1"): - from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy - - pipeline_class = StableDiffusionInpaintPipelineLegacy - - deprecation_message = ( - "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the" - f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For" - " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting" - " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your" - f" checkpoint {pretrained_model_name_or_path} to the format of" - " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain" - " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0." 
- ) - deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) - - # some modules can be passed directly to the init - # in this case they are already instantiated in `kwargs` - # extract them here - expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) - passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} - passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} - - init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) - - # define init kwargs - init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict} - init_kwargs = {**init_kwargs, **passed_pipe_kwargs} - - # remove `null` components - def load_module(name, value): - if value[0] is None: - return False - if name in passed_class_obj and passed_class_obj[name] is None: - return False - return True - - init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} - - # Special case: safety_checker must be loaded separately when using `from_flax` - if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: - raise NotImplementedError( - "The safety checker cannot be automatically loaded when loading weights `from_flax`." - " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker" - " separately if you need it." - ) - - if len(unused_kwargs) > 0: - logger.warning( - f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." - ) - - if low_cpu_mem_usage and not is_accelerate_available(): - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" - " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" - " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" - " install accelerate\n```\n." - ) - - if device_map is not None and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `device_map=None`." - ) - - if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `low_cpu_mem_usage=False`." - ) - - if low_cpu_mem_usage is False and device_map is not None: - raise ValueError( - f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" - " dispatching. Please make sure to set `low_cpu_mem_usage=True`." - ) - - # import it here to avoid circular import - from diffusers import pipelines - - # 3. Load each module in the pipeline - for name, (library_name, class_name) in init_dict.items(): - # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names - if class_name.startswith("Flax"): - class_name = class_name[4:] - - is_pipeline_module = hasattr(pipelines, library_name) - loaded_sub_model = None - - # if the model is in a pipeline module, then we load it from the pipeline - if name in passed_class_obj: - # 1. 
check that passed_class_obj has correct parent class - if not is_pipeline_module: - library = importlib.import_module(library_name) - class_obj = getattr(library, class_name) - importable_classes = LOADABLE_CLASSES[library_name] - class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} - - expected_class_obj = None - for class_name, class_candidate in class_candidates.items(): - if class_candidate is not None and issubclass(class_obj, class_candidate): - expected_class_obj = class_candidate - - if not issubclass(passed_class_obj[name].__class__, expected_class_obj): - raise ValueError( - f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" - f" {expected_class_obj}" - ) - else: - logger.warning( - f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" - " has the correct type" - ) - - # set passed class object - loaded_sub_model = passed_class_obj[name] - elif is_pipeline_module: - pipeline_module = getattr(pipelines, library_name) - class_obj = getattr(pipeline_module, class_name) - importable_classes = ALL_IMPORTABLE_CLASSES - class_candidates = {c: class_obj for c in importable_classes.keys()} - else: - # else we just import it from the library. - library = importlib.import_module(library_name) - - class_obj = getattr(library, class_name) - importable_classes = LOADABLE_CLASSES[library_name] - class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} - - if loaded_sub_model is None: - load_method_name = None - for class_name, class_candidate in class_candidates.items(): - if class_candidate is not None and issubclass(class_obj, class_candidate): - load_method_name = importable_classes[class_name][1] - - if load_method_name is None: - none_module = class_obj.__module__ - is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( - TRANSFORMERS_DUMMY_MODULES_FOLDER - ) - if is_dummy_path and "dummy" in none_module: - # call class_obj for nice error message of missing requirements - class_obj() - - raise ValueError( - f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" - f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." - ) - - load_method = getattr(class_obj, load_method_name) - loading_kwargs = {} - - if issubclass(class_obj, torch.nn.Module): - loading_kwargs["torch_dtype"] = torch_dtype - if issubclass(class_obj, diffusers.OnnxRuntimeModel): - loading_kwargs["provider"] = provider - loading_kwargs["sess_options"] = sess_options - - is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) - - if is_transformers_available(): - transformers_version = version.parse(version.parse(transformers.__version__).base_version) - else: - transformers_version = "N/A" - - is_transformers_model = ( - is_transformers_available() - and issubclass(class_obj, PreTrainedModel) - and transformers_version >= version.parse("4.20.0") - ) - - # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. - # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. - # This makes sure that the weights won't be initialized which significantly speeds up loading. 
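As a usage sketch for the loading flags discussed above, under the assumptions the earlier checks enforce (`accelerate` installed, torch >= 1.9.0; the repo id is illustrative):

```py
import torch
from diffusers import DiffusionPipeline

# `low_cpu_mem_usage=True` is already the default on torch >= 1.9.0;
# `device_map="auto"` asks Accelerate to place each submodule for you.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
```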
- if is_diffusers_model or is_transformers_model: - loading_kwargs["device_map"] = device_map - loading_kwargs["variant"] = model_variants.pop(name, None) - if from_flax: - loading_kwargs["from_flax"] = True - - # the following can be deleted once the minimum required `transformers` version - # is higher than 4.27 - if ( - is_transformers_model - and loading_kwargs["variant"] is not None - and transformers_version < version.parse("4.27.0") - ): - raise ImportError( - f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0" - ) - elif is_transformers_model and loading_kwargs["variant"] is None: - loading_kwargs.pop("variant") - - # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage` - if not (from_flax and is_transformers_model): - loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - else: - loading_kwargs["low_cpu_mem_usage"] = False - - # check if the module is in a subdirectory - if os.path.isdir(os.path.join(cached_folder, name)): - loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) - else: - # else load from the root directory - loaded_sub_model = load_method(cached_folder, **loading_kwargs) - - init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) - - # 4. Potentially add passed objects if expected - missing_modules = set(expected_modules) - set(init_kwargs.keys()) - passed_modules = list(passed_class_obj.keys()) - optional_modules = pipeline_class._optional_components - if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): - for module in missing_modules: - init_kwargs[module] = passed_class_obj.get(module, None) - elif len(missing_modules) > 0: - passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs - raise ValueError( - f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." - ) - - # 5. Instantiate the pipeline - model = pipeline_class(**init_kwargs) - - if return_cached_folder: - return model, cached_folder - return model - - @staticmethod - def _get_signature_keys(obj): - parameters = inspect.signature(obj.__init__).parameters - required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} - optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) - expected_modules = set(required_parameters.keys()) - set(["self"]) - return expected_modules, optional_parameters - - @property - def components(self) -> Dict[str, Any]: - r""" - - The `self.components` property can be useful to run different pipelines with the same weights and - configurations to not have to re-allocate memory. - - Examples: - - ```py - >>> from diffusers import ( - ... StableDiffusionPipeline, - ... StableDiffusionImg2ImgPipeline, - ... StableDiffusionInpaintPipeline, - ... ) - - >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") - >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components) - >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components) - ``` - - Returns: - A dictionary containing all the modules needed to initialize the pipeline. 
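Step 4 above merges components handed directly to `from_pretrained` (via `passed_class_obj`) with the ones loaded from the checkpoint, and the `components` property exposes the result. A minimal sketch, assuming the illustrative repo id below:

```py
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# A component passed as a keyword argument replaces the one stored in the checkpoint.
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", scheduler=scheduler
)
assert pipe.components["scheduler"] is scheduler
```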
- """ - expected_modules, optional_parameters = self._get_signature_keys(self) - components = { - k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters - } - - if set(components.keys()) != expected_modules: - raise ValueError( - f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" - f" {expected_modules} to be defined, but {components.keys()} are defined." - ) - - return components - - @staticmethod - def numpy_to_pil(images): - """ - Convert a numpy image or a batch of images to a PIL image. - """ - if images.ndim == 3: - images = images[None, ...] - images = (images * 255).round().astype("uint8") - if images.shape[-1] == 1: - # special case for grayscale (single channel) images - pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] - else: - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - - def progress_bar(self, iterable=None, total=None): - if not hasattr(self, "_progress_bar_config"): - self._progress_bar_config = {} - elif not isinstance(self._progress_bar_config, dict): - raise ValueError( - f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." - ) - - if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) - elif total is not None: - return tqdm(total=total, **self._progress_bar_config) - else: - raise ValueError("Either `total` or `iterable` has to be defined.") - - def set_progress_bar_config(self, **kwargs): - self._progress_bar_config = kwargs - - def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): - r""" - Enable memory efficient attention as implemented in xformers. - - When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference - time. Speed up at training time is not guaranteed. - - Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention - is used. - - Parameters: - attention_op (`Callable`, *optional*): - Override the default `None` operator for use as `op` argument to the - [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention) - function of xFormers. - - Examples: - - ```py - >>> import torch - >>> from diffusers import DiffusionPipeline - >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp - - >>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16) - >>> pipe = pipe.to("cuda") - >>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) - >>> # Workaround for not accepting attention shape using VAE for Flash Attention - >>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None) - ``` - """ - self.set_use_memory_efficient_attention_xformers(True, attention_op) - - def disable_xformers_memory_efficient_attention(self): - r""" - Disable memory efficient attention as implemented in xformers. - """ - self.set_use_memory_efficient_attention_xformers(False) - - def set_use_memory_efficient_attention_xformers( - self, valid: bool, attention_op: Optional[Callable] = None - ) -> None: - # Recursively walk through all the children. 
- # Any children which exposes the set_use_memory_efficient_attention_xformers method - # gets the message - def fn_recursive_set_mem_eff(module: torch.nn.Module): - if hasattr(module, "set_use_memory_efficient_attention_xformers"): - module.set_use_memory_efficient_attention_xformers(valid, attention_op) - - for child in module.children(): - fn_recursive_set_mem_eff(child) - - module_names, _, _ = self.extract_init_dict(dict(self.config)) - for module_name in module_names: - module = getattr(self, module_name) - if isinstance(module, torch.nn.Module): - fn_recursive_set_mem_eff(module) - - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - self.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - - def set_attention_slice(self, slice_size: Optional[int]): - module_names, _, _ = self.extract_init_dict(dict(self.config)) - for module_name in module_names: - module = getattr(self, module_name) - if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size) diff --git a/my_diffusers/pipelines/pndm/__init__.py b/my_diffusers/pipelines/pndm/__init__.py deleted file mode 100644 index 488eb4f5f2b29c071fdc044ef282bc2838148c1e..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/pndm/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .pipeline_pndm import PNDMPipeline diff --git a/my_diffusers/pipelines/pndm/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/pndm/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 23937f32c507c6242aae16dd5799f2967ffc6c1e..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/pndm/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-311.pyc b/my_diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-311.pyc deleted file mode 100644 index d2a9d18538faf40972c461119d03157ca1955019..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/pndm/__pycache__/pipeline_pndm.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/pndm/pipeline_pndm.py b/my_diffusers/pipelines/pndm/pipeline_pndm.py deleted file mode 100644 index 56fb72d3f4ff9827da4b35e2a1ef9095fa741f01..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/pndm/pipeline_pndm.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. 
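A short sketch of the attention-slicing switches documented above (checkpoint id illustrative):

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.enable_attention_slicing()          # same as slice_size="auto": attention in two steps
# pipe.enable_attention_slicing("max")   # most memory-frugal: one slice at a time
# pipe.disable_attention_slicing()       # back to single-step attention
```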
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Optional, Tuple, Union - -import torch - -from ...models import UNet2DModel -from ...schedulers import PNDMScheduler -from ...utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -class PNDMPipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image. - """ - - unet: UNet2DModel - scheduler: PNDMScheduler - - def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): - super().__init__() - - scheduler = PNDMScheduler.from_config(scheduler.config) - - self.register_modules(unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - num_inference_steps: int = 50, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ) -> Union[ImagePipelineOutput, Tuple]: - r""" - Args: - batch_size (`int`, `optional`, defaults to 1): The number of images to generate. - num_inference_steps (`int`, `optional`, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - generator (`torch.Generator`, `optional`): A [torch - generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generate image. Choose - between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a - [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is - True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 
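A minimal usage sketch for `PNDMPipeline`; the checkpoint id is only an example of an unconditional `UNet2DModel` repo, and `__init__` rebuilds a `PNDMScheduler` from whatever scheduler config is stored there:

```py
from diffusers import PNDMPipeline

pipe = PNDMPipeline.from_pretrained("google/ddpm-cifar10-32")
image = pipe(batch_size=1, num_inference_steps=50).images[0]
image.save("pndm_sample.png")
```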
- """ - # For more information on the sampling method you can take a look at Algorithm 2 of - # the official paper: https://arxiv.org/pdf/2202.09778.pdf - - # Sample gaussian noise to begin loop - image = randn_tensor( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), - generator=generator, - device=self.device, - ) - - self.scheduler.set_timesteps(num_inference_steps) - for t in self.progress_bar(self.scheduler.timesteps): - model_output = self.unet(image, t).sample - - image = self.scheduler.step(model_output, t, image).prev_sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/repaint/__init__.py b/my_diffusers/pipelines/repaint/__init__.py deleted file mode 100644 index 16bc86d1cedf6243fb92f7ba331b5a6188133298..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/repaint/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .pipeline_repaint import RePaintPipeline diff --git a/my_diffusers/pipelines/repaint/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/repaint/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 02058785d98f01dbb4eab6f3ffee9f84d4e12daf..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/repaint/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-311.pyc b/my_diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-311.pyc deleted file mode 100644 index ca4fb1955eed28979f4fded54db6b0ea214df41c..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/repaint/__pycache__/pipeline_repaint.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/repaint/pipeline_repaint.py b/my_diffusers/pipelines/repaint/pipeline_repaint.py deleted file mode 100644 index 2376b5cf55df0f1baf7a24eb57e12c88652e9e57..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/repaint/pipeline_repaint.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright 2023 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from typing import List, Optional, Tuple, Union - -import numpy as np -import PIL -import torch - -from ...models import UNet2DModel -from ...schedulers import RePaintScheduler -from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess -def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]): - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 - - image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]): - if isinstance(mask, torch.Tensor): - return mask - elif isinstance(mask, PIL.Image.Image): - mask = [mask] - - if isinstance(mask[0], PIL.Image.Image): - w, h = mask[0].size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask] - mask = np.concatenate(mask, axis=0) - mask = mask.astype(np.float32) / 255.0 - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - mask = torch.from_numpy(mask) - elif isinstance(mask[0], torch.Tensor): - mask = torch.cat(mask, dim=0) - return mask - - -class RePaintPipeline(DiffusionPipeline): - unet: UNet2DModel - scheduler: RePaintScheduler - - def __init__(self, unet, scheduler): - super().__init__() - self.register_modules(unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - image: Union[torch.Tensor, PIL.Image.Image], - mask_image: Union[torch.Tensor, PIL.Image.Image], - num_inference_steps: int = 250, - eta: float = 0.0, - jump_length: int = 10, - jump_n_sample: int = 10, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ) -> Union[ImagePipelineOutput, Tuple]: - r""" - Args: - image (`torch.FloatTensor` or `PIL.Image.Image`): - The original image to inpaint on. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - The mask_image where 0.0 values define which part of the original image to inpaint (change). - num_inference_steps (`int`, *optional*, defaults to 1000): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - eta (`float`): - The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 - 0.0 is DDIM - and 1.0 is DDPM scheduler respectively. - jump_length (`int`, *optional*, defaults to 10): - The number of steps taken forward in time before going backward in time for a single jump ("j" in - RePaint paper). Take a look at Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf. 
- jump_n_sample (`int`, *optional*, defaults to 10): - The number of times we will make forward time jump for a given chosen time sample. Take a look at - Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is - True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. - """ - - message = "Please use `image` instead of `original_image`." - original_image = deprecate("original_image", "0.15.0", message, take_from=kwargs) - original_image = original_image or image - - original_image = _preprocess_image(original_image) - original_image = original_image.to(device=self.device, dtype=self.unet.dtype) - mask_image = _preprocess_mask(mask_image) - mask_image = mask_image.to(device=self.device, dtype=self.unet.dtype) - - batch_size = original_image.shape[0] - - # sample gaussian noise to begin the loop - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - image_shape = original_image.shape - image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype) - - # set step values - self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device) - self.scheduler.eta = eta - - t_last = self.scheduler.timesteps[0] + 1 - generator = generator[0] if isinstance(generator, list) else generator - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): - if t < t_last: - # predict the noise residual - model_output = self.unet(image, t).sample - # compute previous image: x_t -> x_t-1 - image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample - - else: - # compute the reverse: x_t-1 -> x_t - image = self.scheduler.undo_step(image, t_last, generator) - t_last = t - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/score_sde_ve/__init__.py b/my_diffusers/pipelines/score_sde_ve/__init__.py deleted file mode 100644 index c7c2a85c067b707c155e78a3c8b84562999134e7..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/score_sde_ve/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .pipeline_score_sde_ve import ScoreSdeVePipeline diff --git a/my_diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 2ffbd2a8f7c1f5b8fb9732fcea543162c5cc15e3..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/score_sde_ve/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-311.pyc b/my_diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-311.pyc deleted file mode 100644 index df1d0a26c98d584b078bd03bfafecf02750ac3ad..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/my_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py deleted file mode 100644 index 60a6f1e70f4a4b51cec74a315904a4e8e7cf6bfa..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
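A usage sketch for the `RePaintPipeline` deleted above; the file names and checkpoint id are placeholders, assuming a 256x256 unconditional face model and an image/mask pair of the same size (white mask pixels are kept, black ones are repainted):

```py
import PIL.Image
from diffusers import RePaintPipeline

original = PIL.Image.open("face_256.png")  # placeholder file names
mask = PIL.Image.open("mask_256.png")

pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256").to("cuda")
result = pipe(
    image=original,
    mask_image=mask,
    num_inference_steps=250,
    jump_length=10,
    jump_n_sample=10,
).images[0]
result.save("inpainted.png")
```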
- -from typing import List, Optional, Tuple, Union - -import torch - -from ...models import UNet2DModel -from ...schedulers import ScoreSdeVeScheduler -from ...utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -class ScoreSdeVePipeline(DiffusionPipeline): - r""" - Parameters: - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`SchedulerMixin`]): - The [`ScoreSdeVeScheduler`] scheduler to be used in combination with `unet` to denoise the encoded image. - """ - unet: UNet2DModel - scheduler: ScoreSdeVeScheduler - - def __init__(self, unet: UNet2DModel, scheduler: DiffusionPipeline): - super().__init__() - self.register_modules(unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - num_inference_steps: int = 2000, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ) -> Union[ImagePipelineOutput, Tuple]: - r""" - Args: - batch_size (`int`, *optional*, defaults to 1): - The number of images to generate. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is - True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 
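A usage sketch for `ScoreSdeVePipeline`; the checkpoint id is illustrative (an NCSN++ repo pairing a `UNet2DModel` with a `ScoreSdeVeScheduler`), and 2000 steps is the default shown above:

```py
from diffusers import ScoreSdeVePipeline

pipe = ScoreSdeVePipeline.from_pretrained("google/ncsnpp-celebahq-256").to("cuda")
image = pipe(batch_size=1, num_inference_steps=2000).images[0]
image.save("sde_ve_sample.png")
```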
- """ - - img_size = self.unet.config.sample_size - shape = (batch_size, 3, img_size, img_size) - - model = self.unet - - sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma - sample = sample.to(self.device) - - self.scheduler.set_timesteps(num_inference_steps) - self.scheduler.set_sigmas(num_inference_steps) - - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): - sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) - - # correction step - for _ in range(self.scheduler.config.correct_steps): - model_output = self.unet(sample, sigma_t).sample - sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample - - # prediction step - model_output = model(sample, sigma_t).sample - output = self.scheduler.step_pred(model_output, t, sample, generator=generator) - - sample, sample_mean = output.prev_sample, output.prev_sample_mean - - sample = sample_mean.clamp(0, 1) - sample = sample.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - sample = self.numpy_to_pil(sample) - - if not return_dict: - return (sample,) - - return ImagePipelineOutput(images=sample) diff --git a/my_diffusers/pipelines/semantic_stable_diffusion/__init__.py b/my_diffusers/pipelines/semantic_stable_diffusion/__init__.py deleted file mode 100644 index 0e312c5e30138e106930421ad8c55c23f01e60e7..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/semantic_stable_diffusion/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from typing import List, Optional, Union - -import numpy as np -import PIL -from PIL import Image - -from ...utils import BaseOutput, is_torch_available, is_transformers_available - - -@dataclass -class SemanticStableDiffusionPipelineOutput(BaseOutput): - """ - Output class for Stable Diffusion pipelines. - - Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. - nsfw_content_detected (`List[bool]`) - List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, or `None` if safety checking could not be performed. 
- """ - - images: Union[List[PIL.Image.Image], np.ndarray] - nsfw_content_detected: Optional[List[bool]] - - -if is_transformers_available() and is_torch_available(): - from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline diff --git a/my_diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 8233f1ce3626e0803cacdb095df0222606c5cacf..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-311.pyc b/my_diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-311.pyc deleted file mode 100644 index b18e2af98ce3e591d4a86385f2f1707784eb21f1..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/my_diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py deleted file mode 100644 index a421a844c3296f2a6a57a1aec7f83761193d718a..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ /dev/null @@ -1,702 +0,0 @@ -import inspect -from itertools import repeat -from typing import Callable, List, Optional, Union - -import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline -from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import logging, randn_tensor -from . import SemanticStableDiffusionPipelineOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import SemanticStableDiffusionPipeline - - >>> pipe = SemanticStableDiffusionPipeline.from_pretrained( - ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> out = pipe( - ... prompt="a photo of the face of a woman", - ... num_images_per_prompt=1, - ... guidance_scale=7, - ... editing_prompt=[ - ... "smiling, smile", # Concepts to apply - ... "glasses, wearing glasses", - ... "curls, wavy hair, curly hair", - ... "beard, full beard, mustache", - ... ], - ... reverse_editing_direction=[ - ... False, - ... False, - ... False, - ... False, - ... ], # Direction of guidance i.e. increase all concepts - ... edit_warmup_steps=[10, 10, 10, 10], # Warmup period for each concept - ... edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept - ... edit_threshold=[ - ... 0.99, - ... 0.975, - ... 0.925, - ... 0.96, - ... ], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions - ... edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance - ... edit_mom_beta=0.6, # Momentum beta - ... 
edit_weights=[1, 1, 1, 1, 1], # Weights of the individual concepts against each other - ... ) - >>> image = out.images[0] - ``` -""" - - -class SemanticStableDiffusionPipeline(DiffusionPipeline): - r""" - Pipeline for text-to-image generation with latent editing. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - This model builds on the implementation of ['StableDiffusionPipeline'] - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`Q16SafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
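The constructor above warns whenever `safety_checker=None` is passed while `requires_safety_checker` is still `True`. A sketch of opting out explicitly (repo id illustrative):

```py
import torch
from diffusers import SemanticStableDiffusionPipeline

pipe = SemanticStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False,  # records the choice and silences the warning
)
```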
- ) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs - def check_inputs( - self, - prompt, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
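`prepare_extra_step_kwargs` above only forwards `eta` and `generator` when the scheduler's `step()` signature accepts them. The same introspection pattern in isolation (a standalone sketch, not part of the original module):

```py
import inspect

def build_extra_step_kwargs(scheduler, generator=None, eta=0.0):
    # Forward `eta`/`generator` only if the scheduler's `step()` accepts them.
    params = inspect.signature(scheduler.step).parameters
    extra = {}
    if "eta" in params:
        extra["eta"] = eta
    if "generator" in params:
        extra["generator"] = generator
    return extra
```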
- ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - editing_prompt: Optional[Union[str, List[str]]] = None, - editing_prompt_embeddings: Optional[torch.Tensor] = None, - reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, - edit_guidance_scale: Optional[Union[float, List[float]]] = 5, - edit_warmup_steps: Optional[Union[int, List[int]]] = 10, - edit_cooldown_steps: Optional[Union[int, List[int]]] = None, - edit_threshold: Optional[Union[float, List[float]]] = 0.9, - edit_momentum_scale: Optional[float] = 0.1, - edit_mom_beta: Optional[float] = 0.4, - edit_weights: Optional[List[float]] = None, - sem_guidance: Optional[List[torch.Tensor]] = None, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - editing_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to use for Semantic guidance. Semantic guidance is disabled by setting - `editing_prompt = None`. Guidance direction of prompt should be specified via - `reverse_editing_direction`. - editing_prompt_embeddings (`torch.Tensor>`, *optional*): - Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be - specified via `reverse_editing_direction`. - reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): - Whether the corresponding prompt in `editing_prompt` should be increased or decreased. - edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): - Guidance scale for semantic guidance. If provided as list values should correspond to `editing_prompt`. - `edit_guidance_scale` is defined as `s_e` of equation 6 of [SEGA - Paper](https://arxiv.org/pdf/2301.12247.pdf). - edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10): - Number of diffusion steps (for each prompt) for which semantic guidance will not be applied. Momentum - will still be calculated for those steps and applied once all warmup periods are over. - `edit_warmup_steps` is defined as `delta` (δ) of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). - edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`): - Number of diffusion steps (for each prompt) after which semantic guidance will no longer be applied. 
- edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): - Threshold of semantic guidance. - edit_momentum_scale (`float`, *optional*, defaults to 0.1): - Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0 - momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller - than `sld_warmup_steps`. Momentum will only be added to latent guidance once all warmup periods are - finished. `edit_momentum_scale` is defined as `s_m` of equation 7 of [SEGA - Paper](https://arxiv.org/pdf/2301.12247.pdf). - edit_mom_beta (`float`, *optional*, defaults to 0.4): - Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous - momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller - than `edit_warmup_steps`. `edit_mom_beta` is defined as `beta_m` (β) of equation 8 of [SEGA - Paper](https://arxiv.org/pdf/2301.12247.pdf). - edit_weights (`List[float]`, *optional*, defaults to `None`): - Indicates how much each individual concept should influence the overall guidance. If no weights are - provided all concepts are applied equally. `edit_mom_beta` is defined as `g_i` of equation 9 of [SEGA - Paper](https://arxiv.org/pdf/2301.12247.pdf). - sem_guidance (`List[torch.Tensor]`, *optional*): - List of pre-generated guidance vectors to be applied at generation. Length of the list has to - correspond to `num_inference_steps`. - - Returns: - [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] if `return_dict` is True, - otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the - second element is a list of `bool`s denoting whether the corresponding generated image likely represents - "not-safe-for-work" (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, callback_steps) - - # 2. 
Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) - - if editing_prompt: - enable_edit_guidance = True - if isinstance(editing_prompt, str): - editing_prompt = [editing_prompt] - enabled_editing_prompts = len(editing_prompt) - elif editing_prompt_embeddings is not None: - enable_edit_guidance = True - enabled_editing_prompts = editing_prompt_embeddings.shape[0] - else: - enabled_editing_prompts = 0 - enable_edit_guidance = False - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - - if text_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] - - # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if enable_edit_guidance: - # get safety text embeddings - if editing_prompt_embeddings is None: - edit_concepts_input = self.tokenizer( - [x for item in editing_prompt for x in repeat(item, batch_size)], - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - - edit_concepts_input_ids = edit_concepts_input.input_ids - - if edit_concepts_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode( - edit_concepts_input_ids[:, self.tokenizer.model_max_length :] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length] - edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0] - else: - edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed_edit, seq_len_edit, _ = edit_concepts.shape - edit_concepts = edit_concepts.repeat(1, num_images_per_prompt, 1) - edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1) - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." 
- ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if enable_edit_guidance: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts]) - else: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - # get the initial random noise unless the user supplied it - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=self.device) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - text_embeddings.dtype, - self.device, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. 
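# A minimal sketch of the batch layout assembled above and consumed in the denoising
# loop below, assuming one prompt, one image per prompt and two editing prompts; all
# shapes (77 tokens, 768-dim embeddings, 64x64 latents) are illustrative assumptions.
import torch

uncond = torch.randn(1, 77, 768)   # unconditional / negative-prompt embedding
cond = torch.randn(1, 77, 768)     # prompt embedding
edits = torch.randn(2, 77, 768)    # two editing-concept embeddings
enabled_editing_prompts = edits.shape[0]

embeddings = torch.cat([uncond, cond, edits])                               # (4, 77, 768)
latents = torch.randn(1, 4, 64, 64)
latent_model_input = torch.cat([latents] * (2 + enabled_editing_prompts))   # (4, 4, 64, 64)

# stand-in for unet(latent_model_input, t, encoder_hidden_states=embeddings).sample
noise_pred = torch.randn_like(latent_model_input)
noise_pred_uncond, noise_pred_text, *noise_pred_edit_concepts = noise_pred.chunk(
    2 + enabled_editing_prompts
)
# The loop then forms guidance_scale * (noise_pred_text - noise_pred_uncond) and adds the
# per-concept semantic guidance computed from noise_pred_edit_concepts.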
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # Initialize edit_momentum to None - edit_momentum = None - - self.uncond_estimates = None - self.text_estimates = None - self.edit_estimates = None - self.sem_guidance = None - - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else latents - ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_out = noise_pred.chunk(2 + enabled_editing_prompts) # [b,4, 64, 64] - noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] - noise_pred_edit_concepts = noise_pred_out[2:] - - # default text guidance - noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond) - # noise_guidance = (noise_pred_text - noise_pred_edit_concepts[0]) - - if self.uncond_estimates is None: - self.uncond_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_uncond.shape)) - self.uncond_estimates[i] = noise_pred_uncond.detach().cpu() - - if self.text_estimates is None: - self.text_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) - self.text_estimates[i] = noise_pred_text.detach().cpu() - - if self.edit_estimates is None and enable_edit_guidance: - self.edit_estimates = torch.zeros( - (num_inference_steps + 1, len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape) - ) - - if self.sem_guidance is None: - self.sem_guidance = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) - - if edit_momentum is None: - edit_momentum = torch.zeros_like(noise_guidance) - - if enable_edit_guidance: - concept_weights = torch.zeros( - (len(noise_pred_edit_concepts), noise_guidance.shape[0]), - device=self.device, - dtype=noise_guidance.dtype, - ) - noise_guidance_edit = torch.zeros( - (len(noise_pred_edit_concepts), *noise_guidance.shape), - device=self.device, - dtype=noise_guidance.dtype, - ) - # noise_guidance_edit = torch.zeros_like(noise_guidance) - warmup_inds = [] - for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): - self.edit_estimates[i, c] = noise_pred_edit_concept - if isinstance(edit_guidance_scale, list): - edit_guidance_scale_c = edit_guidance_scale[c] - else: - edit_guidance_scale_c = edit_guidance_scale - - if isinstance(edit_threshold, list): - edit_threshold_c = edit_threshold[c] - else: - edit_threshold_c = edit_threshold - if isinstance(reverse_editing_direction, list): - reverse_editing_direction_c = reverse_editing_direction[c] - else: - reverse_editing_direction_c = reverse_editing_direction - if edit_weights: - edit_weight_c = edit_weights[c] - else: - edit_weight_c = 1.0 - if isinstance(edit_warmup_steps, list): - edit_warmup_steps_c = edit_warmup_steps[c] - else: - edit_warmup_steps_c = edit_warmup_steps - - if isinstance(edit_cooldown_steps, list): - edit_cooldown_steps_c = edit_cooldown_steps[c] - elif edit_cooldown_steps is None: - edit_cooldown_steps_c = i + 1 - else: - edit_cooldown_steps_c = edit_cooldown_steps - if i >= edit_warmup_steps_c: - warmup_inds.append(c) - if i >= edit_cooldown_steps_c: - noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept) - continue - - noise_guidance_edit_tmp = 
noise_pred_edit_concept - noise_pred_uncond - # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3)) - tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3)) - - tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts) - if reverse_editing_direction_c: - noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 - concept_weights[c, :] = tmp_weights - - noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c - - # torch.quantile function expects float32 - if noise_guidance_edit_tmp.dtype == torch.float32: - tmp = torch.quantile( - torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2), - edit_threshold_c, - dim=2, - keepdim=False, - ) - else: - tmp = torch.quantile( - torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32), - edit_threshold_c, - dim=2, - keepdim=False, - ).to(noise_guidance_edit_tmp.dtype) - - noise_guidance_edit_tmp = torch.where( - torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None, None], - noise_guidance_edit_tmp, - torch.zeros_like(noise_guidance_edit_tmp), - ) - noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp - - # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp - - warmup_inds = torch.tensor(warmup_inds).to(self.device) - if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: - concept_weights = concept_weights.to("cpu") # Offload to cpu - noise_guidance_edit = noise_guidance_edit.to("cpu") - - concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds) - concept_weights_tmp = torch.where( - concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp - ) - concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) - # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) - - noise_guidance_edit_tmp = torch.index_select( - noise_guidance_edit.to(self.device), 0, warmup_inds - ) - noise_guidance_edit_tmp = torch.einsum( - "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp - ) - noise_guidance_edit_tmp = noise_guidance_edit_tmp - noise_guidance = noise_guidance + noise_guidance_edit_tmp - - self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu() - - del noise_guidance_edit_tmp - del concept_weights_tmp - concept_weights = concept_weights.to(self.device) - noise_guidance_edit = noise_guidance_edit.to(self.device) - - concept_weights = torch.where( - concept_weights < 0, torch.zeros_like(concept_weights), concept_weights - ) - - concept_weights = torch.nan_to_num(concept_weights) - - noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit) - - noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum - - edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit - - if warmup_inds.shape[0] == len(noise_pred_edit_concepts): - noise_guidance = noise_guidance + noise_guidance_edit - self.sem_guidance[i] = noise_guidance_edit.detach().cpu() - - if sem_guidance is not None: - edit_guidance = sem_guidance[i].to(self.device) - noise_guidance = noise_guidance + edit_guidance - - noise_pred = noise_pred_uncond + noise_guidance - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 8. 
Post-processing - image = self.decode_latents(latents) - - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( - self.device - ) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) - ) - else: - has_nsfw_concept = None - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/stable_diffusion/__init__.py b/my_diffusers/pipelines/stable_diffusion/__init__.py deleted file mode 100644 index c20394845e86bc43529704398e56160877755e7e..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/__init__.py +++ /dev/null @@ -1,131 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Union - -import numpy as np -import PIL -from PIL import Image - -from ...utils import ( - BaseOutput, - OptionalDependencyNotAvailable, - is_flax_available, - is_k_diffusion_available, - is_k_diffusion_version, - is_onnx_available, - is_torch_available, - is_transformers_available, - is_transformers_version, -) - - -@dataclass -class StableDiffusionPipelineOutput(BaseOutput): - """ - Output class for Stable Diffusion pipelines. - - Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. - nsfw_content_detected (`List[bool]`) - List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, or `None` if safety checking could not be performed. 
- """ - - images: Union[List[PIL.Image.Image], np.ndarray] - nsfw_content_detected: Optional[List[bool]] - - -try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 -else: - from .pipeline_cycle_diffusion import CycleDiffusionPipeline - from .pipeline_stable_diffusion import StableDiffusionPipeline - from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline - from .pipeline_stable_diffusion_controlnet import StableDiffusionControlNetPipeline - from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline - from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline - from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy - from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline - from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline - from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline - from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline - from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline - from .pipeline_stable_unclip import StableUnCLIPPipeline - from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline - from .safety_checker import StableDiffusionSafetyChecker - from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer - -try: - if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline -else: - from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline - - -try: - if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import ( - StableDiffusionDepth2ImgPipeline, - StableDiffusionPix2PixZeroPipeline, - ) -else: - from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline - from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline - - -try: - if not ( - is_torch_available() - and is_transformers_available() - and is_k_diffusion_available() - and is_k_diffusion_version(">=", "0.0.12") - ): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 -else: - from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline - -try: - if not (is_transformers_available() and is_onnx_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils.dummy_onnx_objects import * # noqa F403 -else: - from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline - from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline - from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline - from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy - -if 
is_transformers_available() and is_flax_available(): - import flax - - @flax.struct.dataclass - class FlaxStableDiffusionPipelineOutput(BaseOutput): - """ - Output class for Stable Diffusion pipelines. - - Args: - images (`np.ndarray`) - Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline. - nsfw_content_detected (`List[bool]`) - List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content. - """ - - images: np.ndarray - nsfw_content_detected: List[bool] - - from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState - from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline - from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline - from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline - from .safety_checker_flax import FlaxStableDiffusionSafetyChecker diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 7725f36ea8c73f6f5d4372b13ca568b1cc969ec4..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/convert_from_ckpt.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/convert_from_ckpt.cpython-311.pyc deleted file mode 100644 index e02b36c6b66ca86af6814ebba88ebef7e2c6c6a0..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/convert_from_ckpt.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-311.pyc deleted file mode 100644 index 3920dc82b57e1691b53e3542b2555bc5a0e027f6..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_cycle_diffusion.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion.cpython-311.pyc deleted file mode 100644 index 154eb2515c033bafe840b36c92c0d35ae8742b37..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion_img2img.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion_img2img.cpython-311.pyc deleted file mode 100644 index b8556cd280b3728f1f0f3ea9e2170ce7ffee72e2..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion_img2img.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion_inpaint.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion_inpaint.cpython-311.pyc deleted file mode 100644 index d27481f04030434d9b47438ff5f3879c9b3f49c7..0000000000000000000000000000000000000000 Binary files 
a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_flax_stable_diffusion_inpaint.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion.cpython-311.pyc deleted file mode 100644 index c9fa0f6a8547d37ac66a8d89e7d5dd2953241583..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_img2img.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_img2img.cpython-311.pyc deleted file mode 100644 index 311eaeffce56a9351b1f6fa648e31529b128df42..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_img2img.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint.cpython-311.pyc deleted file mode 100644 index 9b9dfbf62f77f1e4a7951b768167da20e014a3ff..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint_legacy.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint_legacy.cpython-311.pyc deleted file mode 100644 index 9bf893c7a058c3df3948991d3c3004bc8c4eb32f..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_onnx_stable_diffusion_inpaint_legacy.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-311.pyc deleted file mode 100644 index 3f5c250abb5218c8db07a9b0b5255ac018010f64..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-311.pyc deleted file mode 100644 index fffe3a18e7fe1134fc20860734cc2a95f675f37b..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_controlnet.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_controlnet.cpython-311.pyc deleted file mode 100644 index 36c2aee196374e814a5e1c8e1bad2d1c0f14b994..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_controlnet.cpython-311.pyc and /dev/null differ diff --git 
a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_depth2img.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_depth2img.cpython-311.pyc deleted file mode 100644 index b93d9c3db74ea35fcee4f70c91bfb2ffa5a169fe..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_depth2img.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-311.pyc deleted file mode 100644 index a3481021a5c3245562d5f5f9e0de76e5754710b3..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_image_variation.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-311.pyc deleted file mode 100644 index 7b61fcaec685eeabfe61a1523ded443c21a8f754..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_img2img.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-311.pyc deleted file mode 100644 index f9f5574043e3ea4fa9be52fc444953f8cc1fd307..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-311.pyc deleted file mode 100644 index 56892c587001cba6a07376556703160d36a124ab..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_instruct_pix2pix.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_instruct_pix2pix.cpython-311.pyc deleted file mode 100644 index 062c9b943c89034852c22181739468ae7432b11a..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_instruct_pix2pix.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_k_diffusion.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_k_diffusion.cpython-311.pyc deleted file mode 100644 index 841fe76a56f69e9811abca4638d8c2c46c306c3a..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_k_diffusion.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_latent_upscale.cpython-311.pyc 
b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_latent_upscale.cpython-311.pyc deleted file mode 100644 index ba8014c725b99ca92bf8557fa1535572858bae42..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_latent_upscale.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_panorama.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_panorama.cpython-311.pyc deleted file mode 100644 index 3293a7eb21796750c48371c8fdbf8847c8c7cb2f..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_panorama.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-311.pyc deleted file mode 100644 index 32d8c2027abb23fe0a086d984b18bca814a6f87b..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_sag.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_sag.cpython-311.pyc deleted file mode 100644 index ab6d6dd33fb6707378528cd566f55ca2a6e56eaf..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_sag.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-311.pyc deleted file mode 100644 index 03fa980db4355d8c7fc84fbc5a9997402783e38c..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_diffusion_upscale.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip.cpython-311.pyc deleted file mode 100644 index 523190888d5f31ff79c06f06ac6787513e68d206..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip_img2img.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip_img2img.cpython-311.pyc deleted file mode 100644 index cdcf54f6c0466701606a0ed6534fe320feffaf36..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/pipeline_stable_unclip_img2img.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-311.pyc deleted file mode 100644 index 106048b151f84d34bf68a913c22fdde7308b3b05..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/safety_checker.cpython-311.pyc 
and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/safety_checker_flax.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/safety_checker_flax.cpython-311.pyc deleted file mode 100644 index fc9740058a54d04ec7cb1940d90459b9da0b36f8..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/safety_checker_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/__pycache__/stable_unclip_image_normalizer.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion/__pycache__/stable_unclip_image_normalizer.cpython-311.pyc deleted file mode 100644 index 7559623349b56a7b4d7d3ae4a98020a70e0e8e62..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion/__pycache__/stable_unclip_image_normalizer.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/my_diffusers/pipelines/stable_diffusion/convert_from_ckpt.py deleted file mode 100644 index 81bbbdeea72c37992a728029e19fd30e93006009..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ /dev/null @@ -1,1289 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Conversion script for the Stable Diffusion checkpoints.""" - -import os -import re -import tempfile -from typing import Optional - -import requests -import torch -from transformers import ( - AutoFeatureExtractor, - BertTokenizerFast, - CLIPImageProcessor, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionConfig, - CLIPVisionModelWithProjection, -) - -from diffusers import ( - AutoencoderKL, - ControlNetModel, - DDIMScheduler, - DDPMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - LDMTextToImagePipeline, - LMSDiscreteScheduler, - PNDMScheduler, - PriorTransformer, - StableDiffusionControlNetPipeline, - StableDiffusionPipeline, - StableUnCLIPImg2ImgPipeline, - StableUnCLIPPipeline, - UnCLIPScheduler, - UNet2DConditionModel, -) -from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel -from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline -from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker -from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer - -from ...utils import is_omegaconf_available, is_safetensors_available, logging -from ...utils.import_utils import BACKENDS_MAPPING - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. 
- """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) - - -def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") - - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") - - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") - - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") - - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits - attention layers, and takes into account additional replacements that may arise. - - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - - # Splits the attention layers into three variables. 
- if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 - - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) - - for path in paths: - new_path = path["new"] - - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - -def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): - """ - Creates a config for the diffusers based on the config of the LDM model. 
- """ - if controlnet: - unet_params = original_config.model.params.control_stage_config.params - else: - unet_params = original_config.model.params.unet_config.params - - vae_params = original_config.model.params.first_stage_config.params.ddconfig - - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) - - head_dim = unet_params.num_heads if "num_heads" in unet_params else None - use_linear_projection = ( - unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False - ) - if use_linear_projection: - # stable diffusion 2-base-512 and 2-768 - if head_dim is None: - head_dim = [5, 10, 20, 20] - - class_embed_type = None - projection_class_embeddings_input_dim = None - - if "num_classes" in unet_params: - if unet_params.num_classes == "sequential": - class_embed_type = "projection" - assert "adm_in_channels" in unet_params - projection_class_embeddings_input_dim = unet_params.adm_in_channels - else: - raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") - - config = dict( - sample_size=image_size // vae_scale_factor, - in_channels=unet_params.in_channels, - down_block_types=tuple(down_block_types), - block_out_channels=tuple(block_out_channels), - layers_per_block=unet_params.num_res_blocks, - cross_attention_dim=unet_params.context_dim, - attention_head_dim=head_dim, - use_linear_projection=use_linear_projection, - class_embed_type=class_embed_type, - projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, - ) - - if not controlnet: - config["out_channels"] = unet_params.out_channels - config["up_block_types"] = tuple(up_block_types) - - return config - - -def create_vae_diffusers_config(original_config, image_size: int): - """ - Creates a config for the diffusers based on the config of the LDM model. 
- """ - vae_params = original_config.model.params.first_stage_config.params.ddconfig - _ = original_config.model.params.first_stage_config.params.embed_dim - - block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = dict( - sample_size=image_size, - in_channels=vae_params.in_channels, - out_channels=vae_params.out_ch, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - latent_channels=vae_params.z_channels, - layers_per_block=vae_params.num_res_blocks, - ) - return config - - -def create_diffusers_schedular(original_config): - schedular = DDIMScheduler( - num_train_timesteps=original_config.model.params.timesteps, - beta_start=original_config.model.params.linear_start, - beta_end=original_config.model.params.linear_end, - beta_schedule="scaled_linear", - ) - return schedular - - -def create_ldm_bert_config(original_config): - bert_params = original_config.model.parms.cond_stage_config.params - config = LDMBertConfig( - d_model=bert_params.n_embed, - encoder_layers=bert_params.n_layer, - encoder_ffn_dim=bert_params.n_embed * 4, - ) - return config - - -def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ - - # extract state_dict for UNet - unet_state_dict = {} - keys = list(checkpoint.keys()) - - if controlnet: - unet_key = "control_model." - else: - unet_key = "model.diffusion_model." - - # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA - if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: - print(f"Checkpoint {path} has both EMA and non-EMA weights.") - print( - "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" - " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." - ) - for key in keys: - if key.startswith("model.diffusion_model"): - flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) - else: - if sum(k.startswith("model_ema") for k in keys) > 100: - print( - "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" - " weights (usually better for inference), please make sure to add the `--extract_ema` flag." - ) - - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - - if config["class_embed_type"] is None: - # No parameters to port - ... 
- elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": - new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - else: - raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - if not controlnet: - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) - - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, 
new_checkpoint, unet_state_dict, config=config) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - output_block_list = {k: sorted(v) for k, v in output_block_list.items()} - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - - # Clear attentions as they have been attributed above. 
- if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - if controlnet: - # conditioning embedding - - orig_index = 0 - - new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - orig_index += 2 - - diffusers_index = 0 - - while diffusers_index < 6: - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - diffusers_index += 1 - orig_index += 2 - - new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - # down blocks - for i in range(num_input_blocks): - new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") - new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") - - # mid block - new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") - new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") - - return new_checkpoint - - -def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - vae_key = "first_stage_model." 
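To make the prefix filtering that `convert_ldm_vae_checkpoint` starts here easier to follow, a minimal, self-contained sketch is shown below. The checkpoint entries are hypothetical and exist only for illustration; the `first_stage_model.` prefix is the one used in the function above.

import torch

# Hypothetical CompVis-style checkpoint with one VAE key and one UNet key.
checkpoint = {
    "first_stage_model.encoder.conv_in.weight": torch.zeros(1),
    "model.diffusion_model.input_blocks.0.0.weight": torch.zeros(1),
}
vae_key = "first_stage_model."
# Keep only keys under the VAE prefix and strip it (equivalent to the `replace` used above).
vae_state_dict = {k[len(vae_key):]: v for k, v in checkpoint.items() if k.startswith(vae_key)}
assert list(vae_state_dict) == ["encoder.conv_in.weight"]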
- keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, 
new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint - - -def convert_ldm_bert_checkpoint(checkpoint, config): - def _copy_attn_layer(hf_attn_layer, pt_attn_layer): - hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight - hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight - hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight - - hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight - hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias - - def _copy_linear(hf_linear, pt_linear): - hf_linear.weight = pt_linear.weight - hf_linear.bias = pt_linear.bias - - def _copy_layer(hf_layer, pt_layer): - # copy layer norms - _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) - _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) - - # copy attn - _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) - - # copy MLP - pt_mlp = pt_layer[1][1] - _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) - _copy_linear(hf_layer.fc2, pt_mlp.net[2]) - - def _copy_layers(hf_layers, pt_layers): - for i, hf_layer in enumerate(hf_layers): - if i != 0: - i += i - pt_layer = pt_layers[i : i + 2] - _copy_layer(hf_layer, pt_layer) - - hf_model = LDMBertModel(config).eval() - - # copy embeds - hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight - hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight - - # copy layer norm - _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) - - # copy hidden layers - _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) - - _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) - - return hf_model - - -def convert_ldm_clip_checkpoint(checkpoint): - text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - for 
key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - - text_model.load_state_dict(text_model_dict) - - return text_model - - -textenc_conversion_lst = [ - ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"), - ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"), - ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"), - ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"), -] -textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} - -textenc_transformer_conversion_lst = [ - # (stable-diffusion, HF Diffusers) - ("resblocks.", "text_model.encoder.layers."), - ("ln_1", "layer_norm1"), - ("ln_2", "layer_norm2"), - (".c_fc.", ".fc1."), - (".c_proj.", ".fc2."), - (".attn", ".self_attn"), - ("ln_final.", "transformer.text_model.final_layer_norm."), - ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), - ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), -] -protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} -textenc_pattern = re.compile("|".join(protected.keys())) - - -def convert_paint_by_example_checkpoint(checkpoint): - config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14") - model = PaintByExampleImageEncoder(config) - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - - # load clip vision - model.model.load_state_dict(text_model_dict) - - # load mapper - keys_mapper = { - k[len("cond_stage_model.mapper.res") :]: v - for k, v in checkpoint.items() - if k.startswith("cond_stage_model.mapper") - } - - MAPPING = { - "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], - "attn.c_proj": ["attn1.to_out.0"], - "ln_1": ["norm1"], - "ln_2": ["norm3"], - "mlp.c_fc": ["ff.net.0.proj"], - "mlp.c_proj": ["ff.net.2"], - } - - mapped_weights = {} - for key, value in keys_mapper.items(): - prefix = key[: len("blocks.i")] - suffix = key.split(prefix)[-1].split(".")[-1] - name = key.split(prefix)[-1].split(suffix)[0][1:-1] - mapped_names = MAPPING[name] - - num_splits = len(mapped_names) - for i, mapped_name in enumerate(mapped_names): - new_name = ".".join([prefix, mapped_name, suffix]) - shape = value.shape[0] // num_splits - mapped_weights[new_name] = value[i * shape : (i + 1) * shape] - - model.mapper.load_state_dict(mapped_weights) - - # load final layer norm - model.final_layer_norm.load_state_dict( - { - "bias": checkpoint["cond_stage_model.final_ln.bias"], - "weight": checkpoint["cond_stage_model.final_ln.weight"], - } - ) - - # load final proj - model.proj_out.load_state_dict( - { - "bias": checkpoint["proj_out.bias"], - "weight": checkpoint["proj_out.weight"], - } - ) - - # load uncond vector - model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) - return model - - -def convert_open_clip_checkpoint(checkpoint): - text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - if "cond_stage_model.model.text_projection" in checkpoint: - d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0]) - else: - d_model = 1024 - 
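A small worked example may help clarify how the `textenc_pattern` substitution defined above rewrites open-CLIP key names into the Hugging Face layout used in the loop that follows. The key below is hypothetical, and only a subset of the mapping pairs is reproduced.

import re

# Subset of textenc_transformer_conversion_lst, copied from above.
pairs = [
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    (".attn", ".self_attn"),
]
protected = {re.escape(old): new for old, new in pairs}
pattern = re.compile("|".join(protected.keys()))

old_key = "resblocks.3.ln_1.weight"  # hypothetical open-CLIP key
new_key = pattern.sub(lambda m: protected[re.escape(m.group(0))], old_key)
print(new_key)  # text_model.encoder.layers.3.layer_norm1.weight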
- text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") - - for key in keys: - if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer - continue - if key in textenc_conversion_map: - text_model_dict[textenc_conversion_map[key]] = checkpoint[key] - if key.startswith("cond_stage_model.model.transformer."): - new_key = key[len("cond_stage_model.model.transformer.") :] - if new_key.endswith(".in_proj_weight"): - new_key = new_key[: -len(".in_proj_weight")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] - text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] - text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] - elif new_key.endswith(".in_proj_bias"): - new_key = new_key[: -len(".in_proj_bias")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] - text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] - text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] - else: - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - - text_model_dict[new_key] = checkpoint[key] - - text_model.load_state_dict(text_model_dict) - - return text_model - - -def stable_unclip_image_encoder(original_config): - """ - Returns the image processor and clip image encoder for the img2img unclip pipeline. - - We currently know of two types of stable unclip models which separately use the clip and the openclip image - encoders. - """ - - image_embedder_config = original_config.model.params.embedder_config - - sd_clip_image_embedder_class = image_embedder_config.target - sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] - - if sd_clip_image_embedder_class == "ClipImageEmbedder": - clip_model_name = image_embedder_config.params.model - - if clip_model_name == "ViT-L/14": - feature_extractor = CLIPImageProcessor() - image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") - else: - raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") - - elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": - feature_extractor = CLIPImageProcessor() - image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - else: - raise NotImplementedError( - f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" - ) - - return feature_extractor, image_encoder - - -def stable_unclip_image_noising_components( - original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None -): - """ - Returns the noising components for the img2img and txt2img unclip pipelines. - - Converts the stability noise augmentor into - 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats - 2. a `DDPMScheduler` for holding the noise schedule - - If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. 
- """ - noise_aug_config = original_config.model.params.noise_aug_config - noise_aug_class = noise_aug_config.target - noise_aug_class = noise_aug_class.split(".")[-1] - - if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": - noise_aug_config = noise_aug_config.params - embedding_dim = noise_aug_config.timestep_dim - max_noise_level = noise_aug_config.noise_schedule_config.timesteps - beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule - - image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) - image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) - - if "clip_stats_path" in noise_aug_config: - if clip_stats_path is None: - raise ValueError("This stable unclip config requires a `clip_stats_path`") - - clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) - clip_mean = clip_mean[None, :] - clip_std = clip_std[None, :] - - clip_stats_state_dict = { - "mean": clip_mean, - "std": clip_std, - } - - image_normalizer.load_state_dict(clip_stats_state_dict) - else: - raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") - - return image_normalizer, image_noising_scheduler - - -def load_pipeline_from_original_stable_diffusion_ckpt( - checkpoint_path: str, - original_config_file: str = None, - image_size: int = 512, - prediction_type: str = None, - model_type: str = None, - extract_ema: bool = False, - scheduler_type: str = "pndm", - num_in_channels: Optional[int] = None, - upcast_attention: Optional[bool] = None, - device: str = None, - from_safetensors: bool = False, - stable_unclip: Optional[str] = None, - stable_unclip_prior: Optional[str] = None, - clip_stats_path: Optional[str] = None, - controlnet: Optional[bool] = None, -) -> StableDiffusionPipeline: - """ - Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` - config file. - - Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the - global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is - recommended that you override the default values and/or supply an `original_config_file` wherever possible. - - Args: - checkpoint_path (`str`): Path to `.ckpt` file. - original_config_file (`str`): - Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically - inferred by looking for a key that only exists in SD2.0 models. - image_size (`int`, *optional*, defaults to 512): - The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 - Base. Use 768 for Stable Diffusion v2. - prediction_type (`str`, *optional*): - The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable - Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. - num_in_channels (`int`, *optional*, defaults to None): - The number of input channels. If `None`, it will be automatically inferred. - scheduler_type (`str`, *optional*, defaults to 'pndm'): - Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", - "ddim"]`. - model_type (`str`, *optional*, defaults to `None`): - The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", - "FrozenCLIPEmbedder", "PaintByExample"]`. 
- extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for - checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to - `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for - inference. Non-EMA weights are usually better to continue fine-tuning. - upcast_attention (`bool`, *optional*, defaults to `None`): - Whether the attention computation should always be upcasted. This is necessary when running stable - diffusion 2.1. - device (`str`, *optional*, defaults to `None`): - The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is - in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A - StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. - """ - if prediction_type == "v-prediction": - prediction_type = "v_prediction" - - if not is_omegaconf_available(): - raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) - - from omegaconf import OmegaConf - - if from_safetensors: - if not is_safetensors_available(): - raise ValueError(BACKENDS_MAPPING["safetensors"][1]) - - from safetensors import safe_open - - checkpoint = {} - with safe_open(checkpoint_path, framework="pt", device="cpu") as f: - for key in f.keys(): - checkpoint[key] = f.get_tensor(key) - else: - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - checkpoint = torch.load(checkpoint_path, map_location=device) - else: - checkpoint = torch.load(checkpoint_path, map_location=device) - - # Sometimes models don't have the global_step item - if "global_step" in checkpoint: - global_step = checkpoint["global_step"] - else: - print("global_step key not found in model") - global_step = None - - if "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - with tempfile.TemporaryDirectory() as tmpdir: - if original_config_file is None: - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - - original_config_file = os.path.join(tmpdir, "inference.yaml") - if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: - if not os.path.isfile("v2-inference-v.yaml"): - # model_type = "v2" - r = requests.get( - " https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" - ) - open(original_config_file, "wb").write(r.content) - - if global_step == 110000: - # v2.1 needs to upcast attention - upcast_attention = True - else: - if not os.path.isfile("v1-inference.yaml"): - # model_type = "v1" - r = requests.get( - " https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" - ) - open(original_config_file, "wb").write(r.content) - - original_config = OmegaConf.load(original_config_file) - - if num_in_channels is not None: - original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels - - if ( - "parameterization" in original_config["model"]["params"] - and original_config["model"]["params"]["parameterization"] == "v" - ): - if prediction_type is None: - # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` - # as it relies on a brittle global step parameter here - prediction_type = "epsilon" if global_step == 875000 else "v_prediction" - if image_size is None: - # NOTE: For stable diffusion 2 base one has to pass `image_size==512` - # as it relies on a brittle global 
step parameter here - image_size = 512 if global_step == 875000 else 768 - else: - if prediction_type is None: - prediction_type = "epsilon" - if image_size is None: - image_size = 512 - - num_train_timesteps = original_config.model.params.timesteps - beta_start = original_config.model.params.linear_start - beta_end = original_config.model.params.linear_end - - scheduler = DDIMScheduler( - beta_end=beta_end, - beta_schedule="scaled_linear", - beta_start=beta_start, - num_train_timesteps=num_train_timesteps, - steps_offset=1, - clip_sample=False, - set_alpha_to_one=False, - prediction_type=prediction_type, - ) - # make sure scheduler works correctly with DDIM - scheduler.register_to_config(clip_sample=False) - - if scheduler_type == "pndm": - config = dict(scheduler.config) - config["skip_prk_steps"] = True - scheduler = PNDMScheduler.from_config(config) - elif scheduler_type == "lms": - scheduler = LMSDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "heun": - scheduler = HeunDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "euler": - scheduler = EulerDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "euler-ancestral": - scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) - elif scheduler_type == "dpm": - scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) - elif scheduler_type == "ddim": - scheduler = scheduler - else: - raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") - - # Convert the UNet2DConditionModel model. - unet_config = create_unet_diffusers_config(original_config, image_size=image_size) - unet_config["upcast_attention"] = upcast_attention - unet = UNet2DConditionModel(**unet_config) - - converted_unet_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema - ) - - unet.load_state_dict(converted_unet_checkpoint) - - # Convert the VAE model. - vae_config = create_vae_diffusers_config(original_config, image_size=image_size) - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) - - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) - - # Convert the text model. 
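Since the text-model branch that follows dispatches on the class name at the end of `cond_stage_config.target`, a condensed sketch of that dispatch is included here; the target string is hypothetical, and the repo mapping merely summarizes the branches implemented below.

# Hypothetical target value as it would appear in the original .yaml config.
target = "ldm.modules.encoders.modules.FrozenCLIPEmbedder"
model_type = target.split(".")[-1]  # same inference as below when model_type is None

# Rough summary of which pretrained text/image encoder each branch below loads.
encoder_repo = {
    "FrozenOpenCLIPEmbedder": "stabilityai/stable-diffusion-2 (subfolder: text_encoder)",
    "FrozenCLIPEmbedder": "openai/clip-vit-large-patch14",
    "PaintByExample": "openai/clip-vit-large-patch14 (vision config)",
}[model_type]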
- if model_type is None: - model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] - logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") - - if controlnet is None: - controlnet = "control_stage_config" in original_config.model.params - - if controlnet and model_type != "FrozenCLIPEmbedder": - raise ValueError("`controlnet`=True only supports `model_type`='FrozenCLIPEmbedder'") - - if model_type == "FrozenOpenCLIPEmbedder": - text_model = convert_open_clip_checkpoint(checkpoint) - tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer") - - if stable_unclip is None: - pipe = StableDiffusionPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=False, - ) - else: - image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( - original_config, clip_stats_path=clip_stats_path, device=device - ) - - if stable_unclip == "img2img": - feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) - - pipe = StableUnCLIPImg2ImgPipeline( - # image encoding components - feature_extractor=feature_extractor, - image_encoder=image_encoder, - # image noising components - image_normalizer=image_normalizer, - image_noising_scheduler=image_noising_scheduler, - # regular denoising components - tokenizer=tokenizer, - text_encoder=text_model, - unet=unet, - scheduler=scheduler, - # vae - vae=vae, - ) - elif stable_unclip == "txt2img": - if stable_unclip_prior is None or stable_unclip_prior == "karlo": - karlo_model = "kakaobrain/karlo-v1-alpha" - prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior") - - prior_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - prior_text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") - - prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler") - prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) - else: - raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") - - pipe = StableUnCLIPPipeline( - # prior components - prior_tokenizer=prior_tokenizer, - prior_text_encoder=prior_text_model, - prior=prior, - prior_scheduler=prior_scheduler, - # image noising components - image_normalizer=image_normalizer, - image_noising_scheduler=image_noising_scheduler, - # regular denoising components - tokenizer=tokenizer, - text_encoder=text_model, - unet=unet, - scheduler=scheduler, - # vae - vae=vae, - ) - else: - raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") - elif model_type == "PaintByExample": - vision_model = convert_paint_by_example_checkpoint(checkpoint) - tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") - pipe = PaintByExamplePipeline( - vae=vae, - image_encoder=vision_model, - unet=unet, - scheduler=scheduler, - safety_checker=None, - feature_extractor=feature_extractor, - ) - elif model_type == "FrozenCLIPEmbedder": - text_model = convert_ldm_clip_checkpoint(checkpoint) - tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") - feature_extractor = 
AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") - - if controlnet: - # Convert the ControlNetModel model. - ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) - ctrlnet_config["upcast_attention"] = upcast_attention - - ctrlnet_config.pop("sample_size") - - controlnet_model = ControlNetModel(**ctrlnet_config) - - converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True - ) - controlnet_model.load_state_dict(converted_ctrl_checkpoint) - - pipe = StableDiffusionControlNetPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - controlnet=controlnet_model, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - else: - pipe = StableDiffusionPipeline( - vae=vae, - text_encoder=text_model, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - else: - text_config = create_ldm_bert_config(original_config) - text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) - tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") - pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) - - return pipe diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/my_diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py deleted file mode 100644 index e977071b9c6ce7b78625490c467d4a318070adee..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ /dev/null @@ -1,776 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL -import torch -from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from diffusers.utils import is_accelerate_available, is_accelerate_version - -from ...configuration_utils import FrozenDict -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDIMScheduler -from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline -from . 
import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess -def preprocess(image): - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 - - image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta): - # 1. get previous step value (=t-1) - prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps - - if prev_timestep <= 0: - return clean_latents - - # 2. compute alphas, betas - alpha_prod_t = scheduler.alphas_cumprod[timestep] - alpha_prod_t_prev = ( - scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod - ) - - variance = scheduler._get_variance(timestep, prev_timestep) - std_dev_t = eta * variance ** (0.5) - - # direction pointing to x_t - e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5) - dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t - noise = std_dev_t * randn_tensor( - clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator - ) - prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise - - return prev_latents - - -def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta): - # 1. get previous step value (=t-1) - prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps - - # 2. compute alphas, betas - alpha_prod_t = scheduler.alphas_cumprod[timestep] - alpha_prod_t_prev = ( - scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod - ) - - beta_prod_t = 1 - alpha_prod_t - - # 3. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) - - # 4. Clip "predicted x_0" - if scheduler.config.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) - - # 5. compute variance: "sigma_t(η)" -> see formula (16) - # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - variance = scheduler._get_variance(timestep, prev_timestep) - std_dev_t = eta * variance ** (0.5) - - # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred - - noise = (prev_latents - (alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction)) / ( - variance ** (0.5) * eta - ) - return noise - - -class CycleDiffusionPipeline(DiffusionPipeline): - r""" - Pipeline for text-guided image to image generation using Stable Diffusion. 
- - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: DDIMScheduler, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
- ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower.
- """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. 
- - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs - def check_inputs( - self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None - ): - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." 
- ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] - - return timesteps, num_inference_steps - t_start - - def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): - image = image.to(device=device, dtype=dtype) - - batch_size = image.shape[0] - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) - - init_latents = self.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - deprecation_message = ( - f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. 
Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." - ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) - - # add noise to latents using the timestep - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - - # get latents - clean_latents = init_latents - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) - latents = init_latents - - return latents, clean_latents - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - source_prompt: Union[str, List[str]], - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - source_guidance_scale: Optional[float] = 1, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` - will be used as a starting point, adding more noise to it the larger the `strength`. The number of - denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will - be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - source_guidance_scale (`float`, *optional*, defaults to 1): - Guidance scale for the source prompt. 
This is useful to control the amount of influence the source - prompt for encoding. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.1): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 1. Check inputs - self.check_inputs(prompt, strength, callback_steps) - - # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - prompt_embeds=prompt_embeds, - ) - source_prompt_embeds = self._encode_prompt( - source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None - ) - - # 4. Preprocess image - image = preprocess(image) - - # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - # 6. 
Prepare latent variables - latents, clean_latents = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator - ) - source_latents = latents - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - generator = extra_step_kwargs.pop("generator", None) - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) - source_latent_model_input = torch.cat([source_latents] * 2) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t) - - # predict the noise residual - concat_latent_model_input = torch.stack( - [ - source_latent_model_input[0], - latent_model_input[0], - source_latent_model_input[1], - latent_model_input[1], - ], - dim=0, - ) - concat_prompt_embeds = torch.stack( - [ - source_prompt_embeds[0], - prompt_embeds[0], - source_prompt_embeds[1], - prompt_embeds[1], - ], - dim=0, - ) - concat_noise_pred = self.unet( - concat_latent_model_input, t, encoder_hidden_states=concat_prompt_embeds - ).sample - - # perform guidance - ( - source_noise_pred_uncond, - noise_pred_uncond, - source_noise_pred_text, - noise_pred_text, - ) = concat_noise_pred.chunk(4, dim=0) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - source_noise_pred = source_noise_pred_uncond + source_guidance_scale * ( - source_noise_pred_text - source_noise_pred_uncond - ) - - # Sample source_latents from the posterior distribution. - prev_source_latents = posterior_sample( - self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs - ) - # Compute noise. - noise = compute_noise( - self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs - ) - source_latents = prev_source_latents - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs - ).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 9. Post-processing - image = self.decode_latents(latents) - - # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/my_diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py deleted file mode 100644 index d23893b15777ee3e53017e4d5d77c35a4df3f641..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ /dev/null @@ -1,431 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from functools import partial -from typing import Dict, List, Optional, Union - -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict -from flax.jax_utils import unreplicate -from flax.training.common_utils import shard -from packaging import version -from PIL import Image -from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel - -from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel -from ...schedulers import ( - FlaxDDIMScheduler, - FlaxDPMSolverMultistepScheduler, - FlaxLMSDiscreteScheduler, - FlaxPNDMScheduler, -) -from ...utils import deprecate, logging -from ..pipeline_flax_utils import FlaxDiffusionPipeline -from . import FlaxStableDiffusionPipelineOutput -from .safety_checker_flax import FlaxStableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -# Set to True to use python for loop instead of jax.fori_loop for easier debugging -DEBUG = False - - -class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): - r""" - Pipeline for text-to-image generation using Stable Diffusion. - - This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`FlaxAutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`FlaxCLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), - specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or - [`FlaxDPMSolverMultistepScheduler`]. - safety_checker ([`FlaxStableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
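        Examples:
            The snippet below is an illustrative usage sketch rather than an example taken from this
            file; the checkpoint id, the prompt, and the `replicate`/`shard` helpers used for
            multi-device execution are assumptions.

            ```py
            import jax
            import jax.numpy as jnp
            from flax.jax_utils import replicate
            from flax.training.common_utils import shard

            # load the Flax pipeline; Flax `from_pretrained` returns the pipeline and its params separately
            pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5", dtype=jnp.float32
            )
            params = replicate(params)  # copy the weights to every device

            prompts = ["a photo of an astronaut riding a horse"] * jax.device_count()
            prompt_ids = shard(pipeline.prepare_inputs(prompts))  # (num_devices, 1, seq_len)
            rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

            # run the pmapped generation path and convert the (devices, batch, H, W, 3) output to PIL
            images = pipeline(prompt_ids, params, rng, num_inference_steps=50, jit=True).images
            images = pipeline.numpy_to_pil(images.reshape((-1,) + images.shape[-3:]))
            ```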
- """ - - def __init__( - self, - vae: FlaxAutoencoderKL, - text_encoder: FlaxCLIPTextModel, - tokenizer: CLIPTokenizer, - unet: FlaxUNet2DConditionModel, - scheduler: Union[ - FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler - ], - safety_checker: FlaxStableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - dtype: jnp.dtype = jnp.float32, - ): - super().__init__() - self.dtype = dtype - - if safety_checker is None: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. 
If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - def prepare_inputs(self, prompt: Union[str, List[str]]): - if not isinstance(prompt, (str, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - return text_input.input_ids - - def _get_has_nsfw_concepts(self, features, params): - has_nsfw_concepts = self.safety_checker(features, params) - return has_nsfw_concepts - - def _run_safety_checker(self, images, safety_model_params, jit=False): - # safety_model_params should already be replicated when jit is True - pil_images = [Image.fromarray(image) for image in images] - features = self.feature_extractor(pil_images, return_tensors="np").pixel_values - - if jit: - features = shard(features) - has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) - has_nsfw_concepts = unshard(has_nsfw_concepts) - safety_model_params = unreplicate(safety_model_params) - else: - has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) - - images_was_copied = False - for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): - if has_nsfw_concept: - if not images_was_copied: - images_was_copied = True - images = images.copy() - - #images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image - - if any(has_nsfw_concepts): - warnings.warn( - "Potential NSFW content was detected in one or more images. A black image 3 will be returned" - " instead. Try again with a different prompt and/or seed." 
- ) - - return images, has_nsfw_concepts - - def _generate( - self, - prompt_ids: jnp.array, - params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, - num_inference_steps: int, - height: int, - width: int, - guidance_scale: float, - latents: Optional[jnp.array] = None, - neg_prompt_ids: Optional[jnp.array] = None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - # get prompt text embeddings - prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] - - # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` - # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` - batch_size = prompt_ids.shape[0] - - max_length = prompt_ids.shape[-1] - - if neg_prompt_ids is None: - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" - ).input_ids - else: - uncond_input = neg_prompt_ids - negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] - context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) - - latents_shape = ( - batch_size, - self.unet.in_channels, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - if latents is None: - latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - def loop_body(step, args): - latents, scheduler_state = args - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - latents_input = jnp.concatenate([latents] * 2) - - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - timestep = jnp.broadcast_to(t, latents_input.shape[0]) - - latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) - - # predict the noise residual - noise_pred = self.unet.apply( - {"params": params["unet"]}, - jnp.array(latents_input), - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=context, - ).sample - # perform guidance - noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) - noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents, scheduler_state - - scheduler_state = self.scheduler.set_timesteps( - params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape - ) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * params["scheduler"].init_noise_sigma - - if DEBUG: - # run with python for loop - for i in range(num_inference_steps): - latents, scheduler_state = loop_body(i, (latents, scheduler_state)) - else: - latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) - - # scale and decode the image latents with vae - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample - - image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) - return image - - def __call__( - self, - prompt_ids: jnp.array, - 
params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, - num_inference_steps: int = 50, - height: Optional[int] = None, - width: Optional[int] = None, - guidance_scale: Union[float, jnp.array] = 7.5, - latents: jnp.array = None, - neg_prompt_ids: jnp.array = None, - return_dict: bool = True, - jit: bool = False, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - latents (`jnp.array`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. tensor will ge generated - by sampling using the supplied random `generator`. - jit (`bool`, defaults to `False`): - Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument - exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of - a plain tuple. - - Returns: - [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a - `tuple. When returning a tuple, the first element is a list with the generated images, and the second - element is a list of `bool`s denoting whether the corresponding generated image likely represents - "not-safe-for-work" (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - if isinstance(guidance_scale, float): - # Convert to a tensor so each device gets a copy. Follow the prompt_ids for - # shape information, as they may be sharded (when `jit` is `True`), or not. 
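        # Illustrative note (added commentary, not from the original file): with `jit=True` and, say,
        # one prompt per device on 8 devices, `prompt_ids` arrives sharded with shape (8, 1, 77), so
        # `prompt_ids.shape[0]` is the device count. The float below then becomes an array of shape
        # (8, 1) whose leading axis is mapped by `jax.pmap` (`in_axes=0` in `_p_generate`), so every
        # device receives its own copy of the guidance weight to broadcast against its noise predictions.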
- guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) - if len(prompt_ids.shape) > 2: - # Assume sharded - guidance_scale = guidance_scale[:, None] - - if jit: - images = _p_generate( - self, - prompt_ids, - params, - prng_seed, - num_inference_steps, - height, - width, - guidance_scale, - latents, - neg_prompt_ids, - ) - else: - images = self._generate( - prompt_ids, - params, - prng_seed, - num_inference_steps, - height, - width, - guidance_scale, - latents, - neg_prompt_ids, - ) - - if self.safety_checker is not None: - safety_params = params["safety_checker"] - images_uint8_casted = (images * 255).round().astype("uint8") - num_devices, batch_size = images.shape[:2] - - images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) - images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) - images = np.asarray(images) - - # block images - if any(has_nsfw_concept): - for i, is_nsfw in enumerate(has_nsfw_concept): - if is_nsfw: - images[i] = np.asarray(images_uint8_casted[i]) - - images = images.reshape(num_devices, batch_size, height, width, 3) - else: - images = np.asarray(images) - has_nsfw_concept = False - - if not return_dict: - return (images, has_nsfw_concept) - - return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) - - -# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation. -# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). -@partial( - jax.pmap, - in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0), - static_broadcasted_argnums=(0, 4, 5, 6), -) -def _p_generate( - pipe, - prompt_ids, - params, - prng_seed, - num_inference_steps, - height, - width, - guidance_scale, - latents, - neg_prompt_ids, -): - return pipe._generate( - prompt_ids, - params, - prng_seed, - num_inference_steps, - height, - width, - guidance_scale, - latents, - neg_prompt_ids, - ) - - -@partial(jax.pmap, static_broadcasted_argnums=(0,)) -def _p_get_has_nsfw_concepts(pipe, features, params): - return pipe._get_has_nsfw_concepts(features, params) - - -def unshard(x: jnp.ndarray): - # einops.rearrange(x, 'd b ... -> (d b) ...') - num_devices, batch_size = x.shape[:2] - rest = x.shape[2:] - return x.reshape(num_devices * batch_size, *rest) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/my_diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py deleted file mode 100644 index 736e273244c7982489552a2990660ddce50d10fd..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ /dev/null @@ -1,466 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import warnings -from functools import partial -from typing import Dict, List, Optional, Union - -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict -from flax.jax_utils import unreplicate -from flax.training.common_utils import shard -from PIL import Image -from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel - -from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel -from ...schedulers import ( - FlaxDDIMScheduler, - FlaxDPMSolverMultistepScheduler, - FlaxLMSDiscreteScheduler, - FlaxPNDMScheduler, -) -from ...utils import PIL_INTERPOLATION, logging -from ..pipeline_flax_utils import FlaxDiffusionPipeline -from . import FlaxStableDiffusionPipelineOutput -from .safety_checker_flax import FlaxStableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -# Set to True to use python for loop instead of jax.fori_loop for easier debugging -DEBUG = False - - -class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): - r""" - Pipeline for image-to-image generation using Stable Diffusion. - - This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`FlaxAutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`FlaxCLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), - specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or - [`FlaxDPMSolverMultistepScheduler`]. - safety_checker ([`FlaxStableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - - def __init__( - self, - vae: FlaxAutoencoderKL, - text_encoder: FlaxCLIPTextModel, - tokenizer: CLIPTokenizer, - unet: FlaxUNet2DConditionModel, - scheduler: Union[ - FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler - ], - safety_checker: FlaxStableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - dtype: jnp.dtype = jnp.float32, - ): - super().__init__() - self.dtype = dtype - - if safety_checker is None: - logger.warn( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. 
Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]): - if not isinstance(prompt, (str, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if not isinstance(image, (Image.Image, list)): - raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") - - if isinstance(image, Image.Image): - image = [image] - - processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) - - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - return text_input.input_ids, processed_images - - def _get_has_nsfw_concepts(self, features, params): - has_nsfw_concepts = self.safety_checker(features, params) - return has_nsfw_concepts - - def _run_safety_checker(self, images, safety_model_params, jit=False): - # safety_model_params should already be replicated when jit is True - pil_images = [Image.fromarray(image) for image in images] - features = self.feature_extractor(pil_images, return_tensors="np").pixel_values - - if jit: - features = shard(features) - has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) - has_nsfw_concepts = unshard(has_nsfw_concepts) - safety_model_params = unreplicate(safety_model_params) - else: - has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) - - images_was_copied = False - for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): - if has_nsfw_concept: - if not images_was_copied: - images_was_copied = True - images = images.copy() - - #images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image - - if any(has_nsfw_concepts): - warnings.warn( - "Potential NSFW content was detected in one or more images. A black image 1 will be returned" - " instead. Try again with a different prompt and/or seed." 
- ) - - return images, has_nsfw_concepts - - def get_timestep_start(self, num_inference_steps, strength): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - - t_start = max(num_inference_steps - init_timestep, 0) - - return t_start - - def _generate( - self, - prompt_ids: jnp.array, - image: jnp.array, - params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, - start_timestep: int, - num_inference_steps: int, - height: int, - width: int, - guidance_scale: float, - noise: Optional[jnp.array] = None, - neg_prompt_ids: Optional[jnp.array] = None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - # get prompt text embeddings - prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] - - # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` - # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` - batch_size = prompt_ids.shape[0] - - max_length = prompt_ids.shape[-1] - - if neg_prompt_ids is None: - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" - ).input_ids - else: - uncond_input = neg_prompt_ids - negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] - context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) - - latents_shape = ( - batch_size, - self.unet.in_channels, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - if noise is None: - noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) - else: - if noise.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {noise.shape}, expected {latents_shape}") - - # Create init_latents - init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist - init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2)) - init_latents = self.vae.config.scaling_factor * init_latents - - def loop_body(step, args): - latents, scheduler_state = args - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - latents_input = jnp.concatenate([latents] * 2) - - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - timestep = jnp.broadcast_to(t, latents_input.shape[0]) - - latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) - - # predict the noise residual - noise_pred = self.unet.apply( - {"params": params["unet"]}, - jnp.array(latents_input), - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=context, - ).sample - # perform guidance - noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) - noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents, scheduler_state - - scheduler_state = self.scheduler.set_timesteps( - params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape - ) - - latent_timestep = scheduler_state.timesteps[start_timestep : start_timestep + 1].repeat(batch_size) - - latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * params["scheduler"].init_noise_sigma - - if DEBUG: - # run with python for loop - for i in range(start_timestep, num_inference_steps): - latents, scheduler_state = loop_body(i, (latents, scheduler_state)) - else: - latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state)) - - # scale and decode the image latents with vae - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample - - image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) - return image - - def __call__( - self, - prompt_ids: jnp.array, - image: jnp.array, - params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, - strength: float = 0.8, - num_inference_steps: int = 50, - height: Optional[int] = None, - width: Optional[int] = None, - guidance_scale: Union[float, jnp.array] = 7.5, - noise: jnp.array = None, - neg_prompt_ids: jnp.array = None, - return_dict: bool = True, - jit: bool = False, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt_ids (`jnp.array`): - The prompt or prompts to guide the image generation. - image (`jnp.array`): - Array representing an image batch, that will be used as the starting point for the process. - params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights - prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` - will be used as a starting point, adding more noise to it the larger the `strength`. The number of - denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will - be maximum and the denoising process will run for the full number of iterations specified in - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. 
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - noise (`jnp.array`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. tensor will ge generated - by sampling using the supplied random `generator`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of - a plain tuple. - jit (`bool`, defaults to `False`): - Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument - exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. - - Returns: - [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a - `tuple. When returning a tuple, the first element is a list with the generated images, and the second - element is a list of `bool`s denoting whether the corresponding generated image likely represents - "not-safe-for-work" (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - if isinstance(guidance_scale, float): - # Convert to a tensor so each device gets a copy. Follow the prompt_ids for - # shape information, as they may be sharded (when `jit` is `True`), or not. 
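        # Illustrative note (added commentary, not from the original file): a few lines below,
        # `get_timestep_start(num_inference_steps, strength)` decides where denoising starts. For
        # example, with `num_inference_steps=50` and `strength=0.8`, `init_timestep = min(int(50 * 0.8), 50) = 40`,
        # so `start_timestep = 10` and `_generate` runs only the last 40 scheduler steps on the noised
        # image latents; smaller `strength` values keep the result closer to the input image.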
- guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) - if len(prompt_ids.shape) > 2: - # Assume sharded - guidance_scale = guidance_scale[:, None] - - start_timestep = self.get_timestep_start(num_inference_steps, strength) - - if jit: - images = _p_generate( - self, - prompt_ids, - image, - params, - prng_seed, - start_timestep, - num_inference_steps, - height, - width, - guidance_scale, - noise, - neg_prompt_ids, - ) - else: - images = self._generate( - prompt_ids, - image, - params, - prng_seed, - start_timestep, - num_inference_steps, - height, - width, - guidance_scale, - noise, - neg_prompt_ids, - ) - - if self.safety_checker is not None: - safety_params = params["safety_checker"] - images_uint8_casted = (images * 255).round().astype("uint8") - num_devices, batch_size = images.shape[:2] - - images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) - images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) - images = np.asarray(images) - - # block images - if any(has_nsfw_concept): - for i, is_nsfw in enumerate(has_nsfw_concept): - if is_nsfw: - images[i] = np.asarray(images_uint8_casted[i]) - - images = images.reshape(num_devices, batch_size, height, width, 3) - else: - images = np.asarray(images) - has_nsfw_concept = False - - if not return_dict: - return (images, has_nsfw_concept) - - return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) - - -# Static argnums are pipe, start_timestep, num_inference_steps, height, width. A change would trigger recompilation. -# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). -@partial( - jax.pmap, - in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0), - static_broadcasted_argnums=(0, 5, 6, 7, 8), -) -def _p_generate( - pipe, - prompt_ids, - image, - params, - prng_seed, - start_timestep, - num_inference_steps, - height, - width, - guidance_scale, - noise, - neg_prompt_ids, -): - return pipe._generate( - prompt_ids, - image, - params, - prng_seed, - start_timestep, - num_inference_steps, - height, - width, - guidance_scale, - noise, - neg_prompt_ids, - ) - - -@partial(jax.pmap, static_broadcasted_argnums=(0,)) -def _p_get_has_nsfw_concepts(pipe, features, params): - return pipe._get_has_nsfw_concepts(features, params) - - -def unshard(x: jnp.ndarray): - # einops.rearrange(x, 'd b ... -> (d b) ...') - num_devices, batch_size = x.shape[:2] - rest = x.shape[2:] - return x.reshape(num_devices * batch_size, *rest) - - -def preprocess(image, dtype): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) - image = jnp.array(image).astype(dtype) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - return 2.0 * image - 1.0 diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/my_diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py deleted file mode 100644 index 7629f33c21e4f6a2fe792caa0dc8c77de179e143..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +++ /dev/null @@ -1,523 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from functools import partial -from typing import Dict, List, Optional, Union - -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict -from flax.jax_utils import unreplicate -from flax.training.common_utils import shard -from packaging import version -from PIL import Image -from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel - -from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel -from ...schedulers import ( - FlaxDDIMScheduler, - FlaxDPMSolverMultistepScheduler, - FlaxLMSDiscreteScheduler, - FlaxPNDMScheduler, -) -from ...utils import PIL_INTERPOLATION, deprecate, logging -from ..pipeline_flax_utils import FlaxDiffusionPipeline -from . import FlaxStableDiffusionPipelineOutput -from .safety_checker_flax import FlaxStableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -# Set to True to use python for loop instead of jax.fori_loop for easier debugging -DEBUG = False - - -class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): - r""" - Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. - - This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`FlaxAutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`FlaxCLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), - specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or - [`FlaxDPMSolverMultistepScheduler`]. - safety_checker ([`FlaxStableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
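        Examples:
            The snippet below is an illustrative usage sketch, not an example from this file; the
            checkpoint id, the local image and mask paths, and the `replicate`/`shard` helpers are
            assumptions.

            ```py
            import jax
            import jax.numpy as jnp
            from flax.jax_utils import replicate
            from flax.training.common_utils import shard
            from PIL import Image

            pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting", dtype=jnp.float32
            )
            params = replicate(params)

            num_samples = jax.device_count()
            init_image = Image.open("input.png").convert("RGB")  # assumed local files
            mask_image = Image.open("mask.png")
            prompts = ["a white cat sitting on a park bench"] * num_samples

            # prepare_inputs tokenizes the prompts and turns the PIL inputs into
            # masked-image and binary-mask arrays
            prompt_ids, masked_images, masks = pipeline.prepare_inputs(
                prompts, [init_image] * num_samples, [mask_image] * num_samples
            )

            rng = jax.random.split(jax.random.PRNGKey(0), num_samples)
            images = pipeline(
                shard(prompt_ids),
                shard(masks),
                shard(masked_images),
                params,
                rng,
                num_inference_steps=50,
                jit=True,
            ).images
            ```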
- """ - - def __init__( - self, - vae: FlaxAutoencoderKL, - text_encoder: FlaxCLIPTextModel, - tokenizer: CLIPTokenizer, - unet: FlaxUNet2DConditionModel, - scheduler: Union[ - FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler - ], - safety_checker: FlaxStableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - dtype: jnp.dtype = jnp.float32, - ): - super().__init__() - self.dtype = dtype - - if safety_checker is None: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. 
If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - def prepare_inputs( - self, - prompt: Union[str, List[str]], - image: Union[Image.Image, List[Image.Image]], - mask: Union[Image.Image, List[Image.Image]], - ): - if not isinstance(prompt, (str, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if not isinstance(image, (Image.Image, list)): - raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") - - if isinstance(image, Image.Image): - image = [image] - - if not isinstance(mask, (Image.Image, list)): - raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") - - if isinstance(mask, Image.Image): - mask = [mask] - - processed_images = jnp.concatenate([preprocess_image(img, jnp.float32) for img in image]) - processed_masks = jnp.concatenate([preprocess_mask(m, jnp.float32) for m in mask]) - # processed_masks[processed_masks < 0.5] = 0 - processed_masks = processed_masks.at[processed_masks < 0.5].set(0) - # processed_masks[processed_masks >= 0.5] = 1 - processed_masks = processed_masks.at[processed_masks >= 0.5].set(1) - - processed_masked_images = processed_images * (processed_masks < 0.5) - - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - return text_input.input_ids, processed_masked_images, processed_masks - - def _get_has_nsfw_concepts(self, features, params): - has_nsfw_concepts = self.safety_checker(features, params) - return has_nsfw_concepts - - def _run_safety_checker(self, images, safety_model_params, jit=False): - # safety_model_params should already be replicated when jit is True - pil_images = [Image.fromarray(image) for image in images] - features = self.feature_extractor(pil_images, return_tensors="np").pixel_values - - if jit: - features = shard(features) - has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) - has_nsfw_concepts = unshard(has_nsfw_concepts) - safety_model_params = unreplicate(safety_model_params) - else: - has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) - - images_was_copied = False - for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): - if has_nsfw_concept: - if not images_was_copied: - images_was_copied = True - images = images.copy() - - #images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image - - if any(has_nsfw_concepts): - warnings.warn( - "Potential NSFW content was detected in one or more images. A black image 2 will be returned" - " instead. Try again with a different prompt and/or seed." 
- ) - - return images, has_nsfw_concepts - - def _generate( - self, - prompt_ids: jnp.array, - mask: jnp.array, - masked_image: jnp.array, - params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, - num_inference_steps: int, - height: int, - width: int, - guidance_scale: float, - latents: Optional[jnp.array] = None, - neg_prompt_ids: Optional[jnp.array] = None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - # get prompt text embeddings - prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] - - # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` - # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` - batch_size = prompt_ids.shape[0] - - max_length = prompt_ids.shape[-1] - - if neg_prompt_ids is None: - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" - ).input_ids - else: - uncond_input = neg_prompt_ids - negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] - context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) - - latents_shape = ( - batch_size, - self.vae.config.latent_channels, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - if latents is None: - latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - prng_seed, mask_prng_seed = jax.random.split(prng_seed) - - masked_image_latent_dist = self.vae.apply( - {"params": params["vae"]}, masked_image, method=self.vae.encode - ).latent_dist - masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2)) - masked_image_latents = self.vae.config.scaling_factor * masked_image_latents - del mask_prng_seed - - mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), method="nearest") - - # 8. Check that sizes of mask, masked image and latents match - num_channels_latents = self.vae.config.latent_channels - num_channels_mask = mask.shape[1] - num_channels_masked_image = masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" - f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." - ) - - def loop_body(step, args): - latents, mask, masked_image_latents, scheduler_state = args - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - latents_input = jnp.concatenate([latents] * 2) - mask_input = jnp.concatenate([mask] * 2) - masked_image_latents_input = jnp.concatenate([masked_image_latents] * 2) - - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - timestep = jnp.broadcast_to(t, latents_input.shape[0]) - - latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) - # concat latents, mask, masked_image_latents in the channel dimension - latents_input = jnp.concatenate([latents_input, mask_input, masked_image_latents_input], axis=1) - - # predict the noise residual - noise_pred = self.unet.apply( - {"params": params["unet"]}, - jnp.array(latents_input), - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=context, - ).sample - # perform guidance - noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) - noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents, mask, masked_image_latents, scheduler_state - - scheduler_state = self.scheduler.set_timesteps( - params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape - ) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * params["scheduler"].init_noise_sigma - - if DEBUG: - # run with python for loop - for i in range(num_inference_steps): - latents, mask, masked_image_latents, scheduler_state = loop_body( - i, (latents, mask, masked_image_latents, scheduler_state) - ) - else: - latents, _, _, _ = jax.lax.fori_loop( - 0, num_inference_steps, loop_body, (latents, mask, masked_image_latents, scheduler_state) - ) - - # scale and decode the image latents with vae - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample - - image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) - return image - - def __call__( - self, - prompt_ids: jnp.array, - mask: jnp.array, - masked_image: jnp.array, - params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, - num_inference_steps: int = 50, - height: Optional[int] = None, - width: Optional[int] = None, - guidance_scale: Union[float, jnp.array] = 7.5, - latents: jnp.array = None, - neg_prompt_ids: jnp.array = None, - return_dict: bool = True, - jit: bool = False, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). 
Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - latents (`jnp.array`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. tensor will ge generated - by sampling using the supplied random `generator`. - jit (`bool`, defaults to `False`): - Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument - exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of - a plain tuple. - - Returns: - [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a - `tuple. When returning a tuple, the first element is a list with the generated images, and the second - element is a list of `bool`s denoting whether the corresponding generated image likely represents - "not-safe-for-work" (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - masked_image = jax.image.resize(masked_image, (*masked_image.shape[:-2], height, width), method="bicubic") - mask = jax.image.resize(mask, (*mask.shape[:-2], height, width), method="nearest") - - if isinstance(guidance_scale, float): - # Convert to a tensor so each device gets a copy. Follow the prompt_ids for - # shape information, as they may be sharded (when `jit` is `True`), or not. - guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) - if len(prompt_ids.shape) > 2: - # Assume sharded - guidance_scale = guidance_scale[:, None] - - if jit: - images = _p_generate( - self, - prompt_ids, - mask, - masked_image, - params, - prng_seed, - num_inference_steps, - height, - width, - guidance_scale, - latents, - neg_prompt_ids, - ) - else: - images = self._generate( - prompt_ids, - mask, - masked_image, - params, - prng_seed, - num_inference_steps, - height, - width, - guidance_scale, - latents, - neg_prompt_ids, - ) - - if self.safety_checker is not None: - safety_params = params["safety_checker"] - images_uint8_casted = (images * 255).round().astype("uint8") - num_devices, batch_size = images.shape[:2] - - images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) - images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) - images = np.asarray(images) - - # block images - if any(has_nsfw_concept): - for i, is_nsfw in enumerate(has_nsfw_concept): - if is_nsfw: - images[i] = np.asarray(images_uint8_casted[i]) - - images = images.reshape(num_devices, batch_size, height, width, 3) - else: - images = np.asarray(images) - has_nsfw_concept = False - - if not return_dict: - return (images, has_nsfw_concept) - - return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) - - -# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation. 
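A toy illustration of the pmap pattern this comment describes: the pipeline object and shape-like arguments are broadcast as static values (changing them triggers recompilation), while tensor arguments carry a leading device axis that pmap shards over, and the per-device outputs can be flattened back with the same reshape that unshard() uses below. ToyPipe and p_run are hypothetical stand-ins, not part of the library:

from functools import partial

import jax
import jax.numpy as jnp


class ToyPipe:
    # hypothetical stand-in for the pipeline object passed as `pipe`
    def run(self, x, steps):
        return x * steps


@partial(jax.pmap, in_axes=(None, 0, None), static_broadcasted_argnums=(0, 2))
def p_run(pipe, x, steps):
    return pipe.run(x, steps)


n_dev = jax.local_device_count()
x = jnp.ones((n_dev, 2, 4))    # leading axis is sharded across devices
out = p_run(ToyPipe(), x, 3)   # `steps` is static: a new value recompiles
flat = out.reshape(n_dev * out.shape[1], *out.shape[2:])  # same reshape as unshard()
print(out.shape, flat.shape)   # (n_dev, 2, 4) (n_dev * 2, 4)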
-# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). -@partial( - jax.pmap, - in_axes=(None, 0, 0, 0, 0, 0, None, None, None, 0, 0, 0), - static_broadcasted_argnums=(0, 6, 7, 8), -) -def _p_generate( - pipe, - prompt_ids, - mask, - masked_image, - params, - prng_seed, - num_inference_steps, - height, - width, - guidance_scale, - latents, - neg_prompt_ids, -): - return pipe._generate( - prompt_ids, - mask, - masked_image, - params, - prng_seed, - num_inference_steps, - height, - width, - guidance_scale, - latents, - neg_prompt_ids, - ) - - -@partial(jax.pmap, static_broadcasted_argnums=(0,)) -def _p_get_has_nsfw_concepts(pipe, features, params): - return pipe._get_has_nsfw_concepts(features, params) - - -def unshard(x: jnp.ndarray): - # einops.rearrange(x, 'd b ... -> (d b) ...') - num_devices, batch_size = x.shape[:2] - rest = x.shape[2:] - return x.reshape(num_devices * batch_size, *rest) - - -def preprocess_image(image, dtype): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) - image = jnp.array(image).astype(dtype) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - return 2.0 * image - 1.0 - - -def preprocess_mask(mask, dtype): - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w, h)) - mask = jnp.array(mask.convert("L")).astype(dtype) / 255.0 - mask = jnp.expand_dims(mask, axis=(0, 1)) - - return mask diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/my_diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py deleted file mode 100644 index 55b996e56bb319fd31296b6bb8864906c5eda282..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +++ /dev/null @@ -1,349 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import torch -from transformers import CLIPFeatureExtractor, CLIPTokenizer - -from ...configuration_utils import FrozenDict -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import deprecate, logging -from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel -from ..pipeline_utils import DiffusionPipeline -from . 
import StableDiffusionPipelineOutput - - -logger = logging.get_logger(__name__) - - -class OnnxStableDiffusionPipeline(DiffusionPipeline): - vae_encoder: OnnxRuntimeModel - vae_decoder: OnnxRuntimeModel - text_encoder: OnnxRuntimeModel - tokenizer: CLIPTokenizer - unet: OnnxRuntimeModel - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] - safety_checker: OnnxRuntimeModel - feature_extractor: CLIPFeatureExtractor - - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae_encoder: OnnxRuntimeModel, - vae_decoder: OnnxRuntimeModel, - text_encoder: OnnxRuntimeModel, - tokenizer: CLIPTokenizer, - unet: OnnxRuntimeModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: OnnxRuntimeModel, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
- ) - - self.register_modules( - vae_encoder=vae_encoder, - vae_decoder=vae_decoder, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`): - prompt to be encoded - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids - - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] * batch_size - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="np", - ) - negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] - negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) - - # For classifier free guidance, we need to do two forward passes. 
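A schematic NumPy round trip of this single-batch classifier-free guidance trick, assuming toy shapes and a hypothetical fake_unet in place of the ONNX UNet. The only load-bearing detail is the ordering: the unconditional embeddings come first in the batch, so splitting the stacked prediction in half maps the first half back to the unconditional branch.

import numpy as np

seq_len, hidden = 77, 768   # toy CLIP-like embedding shape
prompt_embeds = np.random.randn(1, seq_len, hidden).astype(np.float32)
negative_prompt_embeds = np.zeros((1, seq_len, hidden), dtype=np.float32)

# unconditional embeddings first; the split below relies on this ordering
context = np.concatenate([negative_prompt_embeds, prompt_embeds])


def fake_unet(latents: np.ndarray, context: np.ndarray) -> np.ndarray:
    # hypothetical stand-in for the UNet: one "noise prediction" per batch item
    return latents + context.mean(axis=(1, 2))[:, None, None, None]


latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
latent_model_input = np.concatenate([latents] * 2)   # one copy per guidance branch
noise_pred = fake_unet(latent_model_input, context)

guidance_scale = 7.5
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)   # (1, 4, 64, 64)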
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - def __call__( - self, - prompt: Union[str, List[str]], - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[np.random.RandomState] = None, - latents: Optional[np.ndarray] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, - callback_steps: int = 1, - ): - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if generator is None: - generator = np.random - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - prompt_embeds = self._encode_prompt( - prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt - ) - - # get the initial random noise unless the user supplied it - latents_dtype = prompt_embeds.dtype - latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) - if latents is None: - latents = generator.randn(*latents_shape).astype(latents_dtype) - elif latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) - - latents = latents * np.float64(self.scheduler.init_noise_sigma) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
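The comment above motivates forwarding eta only to schedulers whose step() actually accepts it. A sketch of that signature check with stub step functions (ddim_like_step and pndm_like_step are hypothetical, standing in for real scheduler.step methods):

import inspect

def ddim_like_step(noise_pred, t, sample, eta=0.0):
    return sample            # stand-in for a DDIM-style step that accepts eta

def pndm_like_step(noise_pred, t, sample):
    return sample            # stand-in for a scheduler without eta

for step_fn in (ddim_like_step, pndm_like_step):
    extra_step_kwargs = {}
    if "eta" in set(inspect.signature(step_fn).parameters.keys()):
        extra_step_kwargs["eta"] = 0.0
    print(step_fn.__name__, extra_step_kwargs)
# ddim_like_step {'eta': 0.0}
# pndm_like_step {}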
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - timestep_dtype = next( - (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" - ) - timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] - - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) - latent_model_input = latent_model_input.cpu().numpy() - - # predict the noise residual - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) - noise_pred = noise_pred[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - scheduler_output = self.scheduler.step( - torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs - ) - latents = scheduler_output.prev_sample.numpy() - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - latents = 1 / 0.18215 * latents - # image = self.vae_decoder(latent_sample=latents)[0] - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate( - [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] - ) - - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) - - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor( - self.numpy_to_pil(image), return_tensors="np" - ).pixel_values.astype(image.dtype) - - images, has_nsfw_concept = [], [] - for i in range(image.shape[0]): - image_i, has_nsfw_concept_i = self.safety_checker( - clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] - ) - images.append(image_i) - has_nsfw_concept.append(has_nsfw_concept_i[0]) - image = np.concatenate(images) - else: - has_nsfw_concept = None - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - -class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline): - def __init__( - self, - vae_encoder: OnnxRuntimeModel, - vae_decoder: OnnxRuntimeModel, - text_encoder: OnnxRuntimeModel, - tokenizer: CLIPTokenizer, - unet: OnnxRuntimeModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: OnnxRuntimeModel, - feature_extractor: CLIPFeatureExtractor, - ): - deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`." 
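Stepping back to the decoding loop above: the latents are decoded one sample at a time to sidestep the half-precision batching issue the comment mentions, then rescaled from [-1, 1] to [0, 1] and transposed to NHWC. A minimal sketch with a hypothetical stub_vae_decoder in place of the ONNX decoder:

import numpy as np

def stub_vae_decoder(latent_sample: np.ndarray) -> np.ndarray:
    # hypothetical stand-in for vae_decoder(latent_sample=...): (1, 4, h, w) -> (1, 3, 8h, 8w)
    return latent_sample[:, :3].repeat(8, axis=2).repeat(8, axis=3)

latents = np.random.randn(4, 4, 64, 64).astype(np.float32)
latents = 1 / 0.18215 * latents   # undo the SD latent scaling before decoding

# decode one sample at a time, as in the loop above
image = np.concatenate(
    [stub_vae_decoder(latents[i : i + 1]) for i in range(latents.shape[0])]
)
image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1))   # NCHW -> NHWC in [0, 1]
print(image.shape)   # (4, 512, 512, 3)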
- deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message) - super().__init__( - vae_encoder=vae_encoder, - vae_decoder=vae_decoder, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/my_diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py deleted file mode 100644 index 9123e5f3296d948eeb0183c670886495563d608f..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +++ /dev/null @@ -1,465 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL -import torch -from transformers import CLIPFeatureExtractor, CLIPTokenizer - -from ...configuration_utils import FrozenDict -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import PIL_INTERPOLATION, deprecate, logging -from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64 -def preprocess(image): - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 - - image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): - r""" - Pipeline for text-guided image to image generation using Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 
- tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - vae_encoder: OnnxRuntimeModel - vae_decoder: OnnxRuntimeModel - text_encoder: OnnxRuntimeModel - tokenizer: CLIPTokenizer - unet: OnnxRuntimeModel - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] - safety_checker: OnnxRuntimeModel - feature_extractor: CLIPFeatureExtractor - - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae_encoder: OnnxRuntimeModel, - vae_decoder: OnnxRuntimeModel, - text_encoder: OnnxRuntimeModel, - tokenizer: CLIPTokenizer, - unet: OnnxRuntimeModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: OnnxRuntimeModel, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - self.register_modules( - vae_encoder=vae_encoder, - vae_decoder=vae_decoder, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`): - prompt to be encoded - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids - - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] * batch_size - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="np", - ) - negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] - negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - def __call__( - self, - prompt: Union[str, List[str]], - image: Union[np.ndarray, PIL.Image.Image] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[np.random.RandomState] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - image (`np.ndarray` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` - will be used as a starting point, adding more noise to it the larger the `strength`. The number of - denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will - be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`np.random.RandomState`, *optional*): - A np.random.RandomState to make generation deterministic. 
- output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if generator is None: - generator = np.random - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) - - image = preprocess(image).cpu().numpy() - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - prompt_embeds = self._encode_prompt( - prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt - ) - - latents_dtype = prompt_embeds.dtype - image = image.astype(latents_dtype) - # encode the init image into latents and scale the latents - init_latents = self.vae_encoder(sample=image)[0] - init_latents = 0.18215 * init_latents - - if isinstance(prompt, str): - prompt = [prompt] - if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0: - # expand init_latents for batch_size - deprecation_message = ( - f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." 
- ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = len(prompt) // init_latents.shape[0] - init_latents = np.concatenate([init_latents] * additional_image_per_prompt * num_images_per_prompt, axis=0) - elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts." - ) - else: - init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps.numpy()[-init_timestep] - timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) - - # add noise to latents using the timesteps - noise = generator.randn(*init_latents.shape).astype(latents_dtype) - init_latents = self.scheduler.add_noise( - torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) - ) - init_latents = init_latents.numpy() - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - latents = init_latents - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].numpy() - - timestep_dtype = next( - (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" - ) - timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] - - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) - latent_model_input = latent_model_input.cpu().numpy() - - # predict the noise residual - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ - 0 - ] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - scheduler_output = self.scheduler.step( - torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs - ) - latents = scheduler_output.prev_sample.numpy() - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - latents = 1 / 0.18215 * latents - # image = self.vae_decoder(latent_sample=latents)[0] - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate( - [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] - ) - - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) - - if 
self.safety_checker is not None: - safety_checker_input = self.feature_extractor( - self.numpy_to_pil(image), return_tensors="np" - ).pixel_values.astype(image.dtype) - # safety_checker does not support batched inputs yet - images, has_nsfw_concept = [], [] - for i in range(image.shape[0]): - image_i, has_nsfw_concept_i = self.safety_checker( - clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] - ) - images.append(image_i) - has_nsfw_concept.append(has_nsfw_concept_i[0]) - image = np.concatenate(images) - else: - has_nsfw_concept = None - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/my_diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py deleted file mode 100644 index 46b5ce5ad6e49b4c04d51267851e39afaac590c6..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ /dev/null @@ -1,477 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL -import torch -from transformers import CLIPFeatureExtractor, CLIPTokenizer - -from ...configuration_utils import FrozenDict -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import PIL_INTERPOLATION, deprecate, logging -from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -NUM_UNET_INPUT_CHANNELS = 9 -NUM_LATENT_CHANNELS = 4 - - -def prepare_mask_and_masked_image(image, mask, latents_shape): - image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8))) - image = image[None].transpose(0, 3, 1, 2) - image = image.astype(np.float32) / 127.5 - 1.0 - - image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8))) - masked_image = image * (image_mask < 127.5) - - mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"]) - mask = np.array(mask.convert("L")) - mask = mask.astype(np.float32) / 255.0 - mask = mask[None, None] - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - - return mask, masked_image - - -class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): - r""" - Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
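A purely numeric sketch of prepare_mask_and_masked_image() above, with synthetic arrays standing in for the PIL inputs: the full-resolution mask zeroes out the region to be repainted (white pixels, i.e. values >= 127.5), and the latent-resolution mask is binarized at 0.5. The strided downsample here is an illustrative shortcut, not the pipeline's exact nearest-neighbour resize:

import numpy as np

h, w = 64, 64                                                   # latent-space size
image = np.random.rand(1, 3, h * 8, w * 8).astype(np.float32) * 2 - 1   # pixels in [-1, 1]
image_mask = (np.random.rand(h * 8, w * 8) * 255).astype(np.float32)    # "L"-mode mask, 0..255

# keep only the pixels the mask leaves untouched
masked_image = image * (image_mask < 127.5)

mask = image_mask[::8, ::8] / 255.0    # crude downsample to latent resolution
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
print(masked_image.shape, mask.shape)  # (1, 3, 512, 512) (1, 1, 64, 64)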
- - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - vae_encoder: OnnxRuntimeModel - vae_decoder: OnnxRuntimeModel - text_encoder: OnnxRuntimeModel - tokenizer: CLIPTokenizer - unet: OnnxRuntimeModel - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] - safety_checker: OnnxRuntimeModel - feature_extractor: CLIPFeatureExtractor - - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae_encoder: OnnxRuntimeModel, - vae_decoder: OnnxRuntimeModel, - text_encoder: OnnxRuntimeModel, - tokenizer: CLIPTokenizer, - unet: OnnxRuntimeModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: OnnxRuntimeModel, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - self.register_modules( - vae_encoder=vae_encoder, - vae_decoder=vae_decoder, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`): - prompt to be encoded - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). 
- """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids - - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] * batch_size - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="np", - ) - negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] - negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - image: PIL.Image.Image, - mask_image: PIL.Image.Image, - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[np.random.RandomState] = None, - latents: Optional[np.ndarray] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - image (`PIL.Image.Image`): - `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will - be masked out with `mask_image` and repainted according to `prompt`. - mask_image (`PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `image`. 
White pixels in the mask will be - repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted - to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) - instead of 3, so the expected shape would be `(B, H, W, 1)`. - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`np.random.RandomState`, *optional*): - A np.random.RandomState to make generation deterministic. - latents (`np.ndarray`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. 
- """ - - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if generator is None: - generator = np.random - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - prompt_embeds = self._encode_prompt( - prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt - ) - - num_channels_latents = NUM_LATENT_CHANNELS - latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) - latents_dtype = prompt_embeds.dtype - if latents is None: - latents = generator.randn(*latents_shape).astype(latents_dtype) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - # prepare mask and masked_image - mask, masked_image = prepare_mask_and_masked_image(image, mask_image, latents_shape[-2:]) - mask = mask.astype(latents.dtype) - masked_image = masked_image.astype(latents.dtype) - - masked_image_latents = self.vae_encoder(sample=masked_image)[0] - masked_image_latents = 0.18215 * masked_image_latents - - # duplicate mask and masked_image_latents for each generation per prompt - mask = mask.repeat(batch_size * num_images_per_prompt, 0) - masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0) - - mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask - masked_image_latents = ( - np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents - ) - - num_channels_mask = mask.shape[1] - num_channels_masked_image = masked_image_latents.shape[1] - - unet_input_channels = NUM_UNET_INPUT_CHANNELS - if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels: - raise ValueError( - "Incorrect configuration settings! The config of `pipeline.unet` expects" - f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." - ) - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * np.float64(self.scheduler.init_noise_sigma) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - timestep_dtype = next( - (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" - ) - timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] - - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - # concat latents, mask, masked_image_latnets in the channel dimension - latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) - latent_model_input = latent_model_input.cpu().numpy() - latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1) - - # predict the noise residual - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ - 0 - ] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - scheduler_output = self.scheduler.step( - torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs - ) - latents = scheduler_output.prev_sample.numpy() - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - latents = 1 / 0.18215 * latents - # image = self.vae_decoder(latent_sample=latents)[0] - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate( - [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] - ) - - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) - - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor( - self.numpy_to_pil(image), return_tensors="np" - ).pixel_values.astype(image.dtype) - # safety_checker does not support batched inputs yet - images, has_nsfw_concept = [], [] - for i in range(image.shape[0]): - image_i, has_nsfw_concept_i = self.safety_checker( - clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] - ) - images.append(image_i) - has_nsfw_concept.append(has_nsfw_concept_i[0]) - image = np.concatenate(images) - else: - has_nsfw_concept = None - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/my_diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py deleted file mode 100644 index 84e5f6aaab01f231025fdebd44cfc2932bbf1b91..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +++ /dev/null @@ -1,460 +0,0 @@ -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL -import torch -from transformers import CLIPFeatureExtractor, CLIPTokenizer - 
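The guidance step inside the denoising loop above splits the doubled batch back into its unconditional and text-conditioned halves and blends them with `guidance_scale`. A small NumPy illustration of just that arithmetic (toy tensors, not the pipeline's own code):

```py
import numpy as np

# noise_pred comes out of the UNet with the unconditional half stacked on top
# of the text-conditioned half, because the inputs were concatenated that way.
guidance_scale = 7.5
noise_pred = np.random.randn(2, 4, 64, 64).astype(np.float32)  # toy stand-in

noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# guidance_scale = 1 reproduces the plain text-conditioned prediction,
# which is why guidance is only enabled when guidance_scale > 1.
assert np.allclose(noise_pred_uncond + 1.0 * (noise_pred_text - noise_pred_uncond), noise_pred_text)
print(guided.shape)  # (1, 4, 64, 64)
```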
-from ...configuration_utils import FrozenDict -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import deprecate, logging -from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def preprocess(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - return 2.0 * image - 1.0 - - -def preprocess_mask(mask, scale_factor=8): - mask = mask.convert("L") - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? - mask = 1 - mask # repaint white, keep black - return mask - - -class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): - r""" - Pipeline for text-guided image inpainting using Stable Diffusion. This is a *legacy feature* for Onnx pipelines to - provide compatibility with StableDiffusionInpaintPipelineLegacy and may be removed in the future. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
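The two helpers above prepare the conditioning for this legacy pipeline: the image is mapped to `[-1, 1]` in NCHW layout, and the mask is downsampled to latent resolution, tiled to four channels, and inverted so that white pixels mark the region to repaint. A self-contained shape check of the same steps (PIL and NumPy only; dummy inputs):

```py
import numpy as np
import PIL.Image

# Stand-ins mirroring preprocess / preprocess_mask above, trimmed to the essentials.
image = PIL.Image.new("RGB", (512, 512), "white")
mask_image = PIL.Image.new("L", (512, 512), 255)  # all white: repaint everything

w, h = (x - x % 32 for x in image.size)            # snap to a multiple of 32
arr = np.array(image.resize((w, h), resample=PIL.Image.LANCZOS), dtype=np.float32) / 255.0
arr = 2.0 * arr[None].transpose(0, 3, 1, 2) - 1.0  # NCHW in [-1, 1]

mask = np.array(
    mask_image.convert("L").resize((w // 8, h // 8), resample=PIL.Image.NEAREST),
    dtype=np.float32,
) / 255.0
mask = 1 - np.tile(mask, (4, 1, 1))[None]          # invert: white -> 0 (repaint), black -> 1 (keep)

print(arr.shape, mask.shape)  # (1, 3, 512, 512) (1, 4, 64, 64)
```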
- """ - _optional_components = ["safety_checker", "feature_extractor"] - - vae_encoder: OnnxRuntimeModel - vae_decoder: OnnxRuntimeModel - text_encoder: OnnxRuntimeModel - tokenizer: CLIPTokenizer - unet: OnnxRuntimeModel - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] - safety_checker: OnnxRuntimeModel - feature_extractor: CLIPFeatureExtractor - - def __init__( - self, - vae_encoder: OnnxRuntimeModel, - vae_decoder: OnnxRuntimeModel, - text_encoder: OnnxRuntimeModel, - tokenizer: CLIPTokenizer, - unet: OnnxRuntimeModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: OnnxRuntimeModel, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
- ) - - self.register_modules( - vae_encoder=vae_encoder, - vae_decoder=vae_decoder, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`): - prompt to be encoded - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids - - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] * batch_size - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="np", - ) - negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] - negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) - - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - def __call__( - self, - prompt: Union[str, List[str]], - image: Union[np.ndarray, PIL.Image.Image] = None, - mask_image: Union[np.ndarray, PIL.Image.Image] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[np.random.RandomState] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - image (`nd.ndarray` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. This is the image whose masked region will be inpainted. - mask_image (`nd.ndarray` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should - contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.uu - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` - will be used as a starting point, adding more noise to it the larger the `strength`. The number of - denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will - be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (?) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. 
- generator (`np.random.RandomState`, *optional*): - A np.random.RandomState to make generation deterministic. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if generator is None: - generator = np.random - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps) - - if isinstance(image, PIL.Image.Image): - image = preprocess(image) - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. 
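The `strength` argument documented above decides how much of the schedule actually runs: an initial timestep is derived from it, the init latents are noised up to that point, and only the remaining steps are denoised. A plain-Python sketch mirroring the `init_timestep` / `t_start` arithmetic used further below (with `steps_offset` normally forced to 1 by the `__init__` check above):

```py
# How `strength` trims the schedule (pure arithmetic, no scheduler needed).
num_inference_steps = 50
offset = 1  # scheduler.config.steps_offset after the deprecation fix-up in __init__

for strength in (0.25, 0.8, 1.0):
    init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    steps_run = num_inference_steps - t_start
    print(f"strength={strength}: start at index {t_start}, run {steps_run} of {num_inference_steps} steps")
```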
- do_classifier_free_guidance = guidance_scale > 1.0 - - prompt_embeds = self._encode_prompt( - prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt - ) - - latents_dtype = prompt_embeds.dtype - image = image.astype(latents_dtype) - - # encode the init image into latents and scale the latents - init_latents = self.vae_encoder(sample=image)[0] - init_latents = 0.18215 * init_latents - - # Expand init_latents for batch_size and num_images_per_prompt - init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) - init_latents_orig = init_latents - - # preprocess mask - if not isinstance(mask_image, np.ndarray): - mask_image = preprocess_mask(mask_image, 8) - mask_image = mask_image.astype(latents_dtype) - mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0) - - # check sizes - if not mask.shape == init_latents.shape: - raise ValueError("The mask and image should be the same size!") - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps.numpy()[-init_timestep] - timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) - - # add noise to latents using the timesteps - noise = generator.randn(*init_latents.shape).astype(latents_dtype) - init_latents = self.scheduler.add_noise( - torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) - ) - init_latents = init_latents.numpy() - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (?) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to ? 
in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - latents = init_latents - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].numpy() - timestep_dtype = next( - (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" - ) - timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] - - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ - 0 - ] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs - ).prev_sample - - latents = latents.numpy() - - init_latents_proper = self.scheduler.add_noise( - torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.from_numpy(np.array([t])) - ) - - init_latents_proper = init_latents_proper.numpy() - - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - latents = 1 / 0.18215 * latents - # image = self.vae_decoder(latent_sample=latents)[0] - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate( - [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] - ) - - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) - - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor( - self.numpy_to_pil(image), return_tensors="np" - ).pixel_values.astype(image.dtype) - # There will throw an error if use safety_checker batchsize>1 - images, has_nsfw_concept = [], [] - for i in range(image.shape[0]): - image_i, has_nsfw_concept_i = self.safety_checker( - clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] - ) - images.append(image_i) - has_nsfw_concept.append(has_nsfw_concept_i[0]) - image = np.concatenate(images) - else: - has_nsfw_concept = None - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py deleted file mode 100644 index 5044797986178ce898ba3fced9f3320316f04ef5..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ /dev/null @@ -1,714 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. 
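What makes the pipeline above "legacy" inpainting is the per-step blend: after every scheduler step the known region is overwritten with the original latents re-noised to the current timestep, so only the masked region is actually generated. A toy NumPy sketch of that blend, with a simple stand-in for `scheduler.add_noise`:

```py
import numpy as np

rng = np.random.default_rng(0)
init_latents_orig = rng.standard_normal((1, 4, 64, 64)).astype(np.float32)
latents = rng.standard_normal((1, 4, 64, 64)).astype(np.float32)   # current denoised latents

# Latent-resolution mask as produced by preprocess_mask: 1 = keep original, 0 = repaint.
mask = np.ones((1, 4, 64, 64), dtype=np.float32)
mask[..., 16:48, 16:48] = 0.0  # square hole to be inpainted

def toy_add_noise(x, sigma):
    # Stand-in for scheduler.add_noise(original, noise, t): re-noise to the current step.
    return x + sigma * rng.standard_normal(x.shape).astype(np.float32)

init_latents_proper = toy_add_noise(init_latents_orig, sigma=0.1)
latents = init_latents_proper * mask + latents * (1 - mask)

# Outside the hole the latents now track the (re-noised) original image.
assert np.allclose(latents[..., :16, :16], init_latents_proper[..., :16, :16])
```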
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Callable, Dict, List, Optional, Union - -import torch -from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...configuration_utils import FrozenDict -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - deprecate, - is_accelerate_available, - is_accelerate_version, - logging, - randn_tensor, - replace_example_docstring, -) -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import StableDiffusionPipeline - - >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) - >>> pipe = pipe.to("cuda") - - >>> prompt = "a photo of an astronaut riding a horse on mars" - >>> image = pipe(prompt).images[0] - ``` -""" - - -class StableDiffusionPipeline(DiffusionPipeline): - r""" - Pipeline for text-to-image generation using Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
- """ - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. - - When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in - several steps. This is useful to save a large amount of memory and to allow the processing of larger images. - """ - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. 
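The memory helpers here (VAE slicing/tiling plus the two CPU-offload variants, the second of which is defined just below) are opt-in switches on an assembled pipeline. A short usage sketch, assuming a CUDA device, a sufficiently recent `accelerate` as checked inside these methods, and the fp16 checkpoint from the example docstring above:

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Cheapest first: slice/tile the VAE to cut decode memory for large batches or images.
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

# Either move whole sub-models to the GPU on demand (fast) ...
pipe.enable_model_cpu_offload()
# ... or offload at submodule granularity for maximum savings (slower):
# pipe.enable_sequential_cpu_offload()

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```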
- """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
- prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." 
- ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 
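`prepare_latents` above derives the noise shape from `vae_scale_factor` (`2 ** (len(block_out_channels) - 1)`, i.e. 8 for the standard SD VAE) and then scales the noise by the scheduler's `init_noise_sigma`. A small torch sketch of that shape and scaling logic, with illustrative values standing in for the real config:

```py
import torch

batch_size, num_images_per_prompt = 1, 2
num_channels_latents = 4          # UNet latent channels for Stable Diffusion
vae_scale_factor = 8              # 2 ** (len(block_out_channels) - 1) for the standard VAE
height = width = 512
init_noise_sigma = 1.0            # e.g. DDIM/PNDM; Karras-style schedulers use larger values

generator = torch.Generator().manual_seed(0)
shape = (
    batch_size * num_images_per_prompt,
    num_channels_latents,
    height // vae_scale_factor,
    width // vae_scale_factor,
)
latents = torch.randn(shape, generator=generator) * init_noise_sigma
print(latents.shape)  # torch.Size([2, 4, 64, 64])
```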
- - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - if output_type == "latent": - image = latents - has_nsfw_concept = None - elif output_type == "pil": - # 8. 
Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 10. Convert to PIL - image = self.numpy_to_pil(image) - else: - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py deleted file mode 100644 index e5550f1d0d450386b3a932b8a74a3396eb8a127d..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ /dev/null @@ -1,1016 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import math -from typing import Any, Callable, Dict, List, Optional, Union - -import numpy as np -import torch -from torch.nn import functional as F -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.cross_attention import CrossAttention -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import StableDiffusionAttendAndExcitePipeline - - >>> pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained( - ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 - ... ).to("cuda") - - - >>> prompt = "a cat and a frog" - - >>> # use get_indices function to find out indices of the tokens you want to alter - >>> pipe.get_indices(prompt) - {0: '<|startoftext|>', 1: 'a', 2: 'cat', 3: 'and', 4: 'a', 5: 'frog', 6: '<|endoftext|>'} - - >>> token_indices = [2, 5] - >>> seed = 6141 - >>> generator = torch.Generator("cuda").manual_seed(seed) - - >>> images = pipe( - ... prompt=prompt, - ... token_indices=token_indices, - ... guidance_scale=7.5, - ... generator=generator, - ... num_inference_steps=50, - ... max_iter_to_alter=25, - ... 
).images - - >>> image = images[0] - >>> image.save(f"../images/{prompt}_{seed}.png") - ``` -""" - - -class AttentionStore: - @staticmethod - def get_empty_store(): - return {"down": [], "mid": [], "up": []} - - def __call__(self, attn, is_cross: bool, place_in_unet: str): - if self.cur_att_layer >= 0 and is_cross: - if attn.shape[1] == self.attn_res**2: - self.step_store[place_in_unet].append(attn) - - self.cur_att_layer += 1 - if self.cur_att_layer == self.num_att_layers: - self.cur_att_layer = 0 - self.between_steps() - - def between_steps(self): - self.attention_store = self.step_store - self.step_store = self.get_empty_store() - - def get_average_attention(self): - average_attention = self.attention_store - return average_attention - - def aggregate_attention(self, from_where: List[str]) -> torch.Tensor: - """Aggregates the attention across the different layers and heads at the specified resolution.""" - out = [] - attention_maps = self.get_average_attention() - for location in from_where: - for item in attention_maps[location]: - cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1]) - out.append(cross_maps) - out = torch.cat(out, dim=0) - out = out.sum(0) / out.shape[0] - return out - - def reset(self): - self.cur_att_layer = 0 - self.step_store = self.get_empty_store() - self.attention_store = {} - - def __init__(self, attn_res=16): - """ - Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion - process - """ - self.num_att_layers = -1 - self.cur_att_layer = 0 - self.step_store = self.get_empty_store() - self.attention_store = {} - self.curr_step_index = 0 - self.attn_res = attn_res - - -class AttendExciteCrossAttnProcessor: - def __init__(self, attnstore, place_in_unet): - super().__init__() - self.attnstore = attnstore - self.place_in_unet = place_in_unet - - def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - query = attn.to_q(hidden_states) - - is_cross = encoder_hidden_states is not None - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - - # only need to store attention maps during the Attend and Excite process - if attention_probs.requires_grad: - self.attnstore(attention_probs, is_cross, self.place_in_unet) - - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline): - r""" - Pipeline for text-to-image generation using Stable Diffusion and Attend and Excite. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
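A rough, self-contained sketch of what `AttentionStore.aggregate_attention` above does with the stored cross-attention maps; the random tensors, head count, and 16x16 resolution are stand-ins rather than values read from a real UNet:

```py
import torch

attn_res, num_tokens = 16, 77
# Pretend these maps were collected from "down", "mid" and "up" cross-attention
# layers; each entry has shape (heads, res*res, num_tokens).
stored = {
    "down": [torch.rand(8, attn_res**2, num_tokens) for _ in range(2)],
    "mid": [torch.rand(8, attn_res**2, num_tokens)],
    "up": [torch.rand(8, attn_res**2, num_tokens) for _ in range(2)],
}

out = []
for location in ("up", "down", "mid"):
    for item in stored[location]:
        # (heads, res*res, tokens) -> (heads, res, res, tokens)
        out.append(item.reshape(-1, attn_res, attn_res, item.shape[-1]))
maps = torch.cat(out, dim=0)
averaged = maps.sum(0) / maps.shape[0]  # average over layers and heads
print(averaged.shape)  # torch.Size([16, 16, 77])
```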
- - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. 
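As a side note on the `vae_scale_factor` computed in `__init__` above: with the four `block_out_channels` typical of the SD 1.x VAE (an assumption here, not read from any checkpoint), it evaluates to 8, which is what ties the default `height`/`width` to the UNet's `sample_size`:

```py
# Hypothetical VAE config values, chosen to match the common SD 1.x setup.
block_out_channels = (128, 256, 512, 512)
vae_scale_factor = 2 ** (len(block_out_channels) - 1)  # -> 8

unet_sample_size = 64  # latent resolution the UNet works at
default_height = default_width = unet_sample_size * vae_scale_factor  # -> 512

latent_shape = (1, 4, default_height // vae_scale_factor, default_width // vae_scale_factor)
print(vae_scale_factor, default_height, latent_shape)  # 8 512 (1, 4, 64, 64)
```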
- """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
- prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
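The kwarg filtering implemented just below can be shown in isolation: probe the step function's signature and only pass `eta`/`generator` when they are accepted. The `fake_step` function here is a stand-in for a scheduler's `step`, not a real one:

```py
import inspect

# Stand-in step function that happens to accept `eta` and `generator`;
# a scheduler without these parameters would simply not receive them.
def fake_step(model_output, timestep, sample, eta=0.0, generator=None):
    return sample

accepted = set(inspect.signature(fake_step).parameters.keys())
extra_step_kwargs = {}
if "eta" in accepted:
    extra_step_kwargs["eta"] = 0.0
if "generator" in accepted:
    extra_step_kwargs["generator"] = None
print(extra_step_kwargs)  # {'eta': 0.0, 'generator': None}
```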
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - indices, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - if (indices is None) or (indices is not None and not isinstance(indices, List)): - raise ValueError(f"`indices` has to be a list but is {type(indices)}") - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - @staticmethod - def _compute_max_attention_per_index( - attention_maps: torch.Tensor, - indices: List[int], - ) -> List[torch.Tensor]: - """Computes the maximum attention value for each of the tokens we wish to alter.""" - attention_for_text = attention_maps[:, :, 1:-1] - attention_for_text *= 100 - attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1) - - # Shift indices since we removed the first token - indices = [index - 1 for index in indices] - - # Extract the maximum values - max_indices_list = [] - for i in indices: - image = attention_for_text[:, :, i] - smoothing = GaussianSmoothing().to(attention_maps.device) - input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect") - image = smoothing(input).squeeze(0).squeeze(0) - max_indices_list.append(image.max()) - return max_indices_list - - def _aggregate_and_get_max_attention_per_token( - self, - indices: List[int], - ): - """Aggregates the attention for each token and computes the max activation value for each token to alter.""" - attention_maps = self.attention_store.aggregate_attention( - from_where=("up", "down", "mid"), - ) - max_attention_per_index = self._compute_max_attention_per_index( - attention_maps=attention_maps, - indices=indices, - ) - return max_attention_per_index - - @staticmethod - def _compute_loss(max_attention_per_index: List[torch.Tensor]) -> torch.Tensor: - """Computes the attend-and-excite loss using the maximum attention value for each token.""" - losses = [max(0, 1.0 - curr_max) for curr_max in max_attention_per_index] - loss = max(losses) - return loss - - @staticmethod - def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor: - """Update the latent according to the computed loss.""" - grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0] - latents = latents - step_size * grad_cond - return latents - - def _perform_iterative_refinement_step( - self, - latents: torch.Tensor, - indices: List[int], - loss: torch.Tensor, - threshold: float, - text_embeddings: torch.Tensor, - step_size: float, - t: int, - max_refinement_steps: int = 20, - ): - """ - Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent code - according to our loss objective until the given threshold is reached for all tokens. - """ - iteration = 0 - target_loss = max(0, 1.0 - threshold) - while loss > target_loss: - iteration += 1 - - latents = latents.clone().detach().requires_grad_(True) - self.unet(latents, t, encoder_hidden_states=text_embeddings).sample - self.unet.zero_grad() - - # Get max activation value for each subject token - max_attention_per_index = self._aggregate_and_get_max_attention_per_token( - indices=indices, - ) - - loss = self._compute_loss(max_attention_per_index) - - if loss != 0: - latents = self._update_latent(latents, loss, step_size) - - logger.info(f"\t Try {iteration}. loss: {loss}") - - if iteration >= max_refinement_steps: - logger.info(f"\t Exceeded max number of iterations ({max_refinement_steps})! ") - break - - # Run one more time but don't compute gradients and update the latents. 
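Taken together, `_compute_loss` and `_update_latent` above amount to a hinge loss on the weakest token's maximum attention, followed by a plain gradient step on the latents. A toy, self-contained version that uses differentiable stand-ins instead of real attention maps:

```py
import torch

# Toy latents and toy differentiable "attention" scores in place of the
# maximum cross-attention values the pipeline actually computes.
latents = torch.randn(1, 4, 64, 64, requires_grad=True)
max_attention_per_token = [latents.sigmoid().mean(), latents.tanh().abs().mean()]

# Hinge loss: only the token with the weakest maximum attention matters.
losses = [torch.clamp(1.0 - m, min=0.0) for m in max_attention_per_token]
loss = max(losses)

# Gradient step on the latents themselves, mirroring `_update_latent`.
step_size = 20.0
grad = torch.autograd.grad(loss, [latents], retain_graph=True)[0]
updated_latents = latents - step_size * grad
print(loss.item(), updated_latents.shape)
```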
- # We just need to compute the new loss - the grad update will occur below - latents = latents.clone().detach().requires_grad_(True) - _ = self.unet(latents, t, encoder_hidden_states=text_embeddings).sample - self.unet.zero_grad() - - # Get max activation value for each subject token - max_attention_per_index = self._aggregate_and_get_max_attention_per_token( - indices=indices, - ) - loss = self._compute_loss(max_attention_per_index) - logger.info(f"\t Finished with loss of: {loss}") - return loss, latents, max_attention_per_index - - def register_attention_control(self): - attn_procs = {} - cross_att_count = 0 - for name in self.unet.attn_processors.keys(): - if name.startswith("mid_block"): - place_in_unet = "mid" - elif name.startswith("up_blocks"): - place_in_unet = "up" - elif name.startswith("down_blocks"): - place_in_unet = "down" - else: - continue - - cross_att_count += 1 - attn_procs[name] = AttendExciteCrossAttnProcessor( - attnstore=self.attention_store, place_in_unet=place_in_unet - ) - - self.unet.set_attn_processor(attn_procs) - self.attention_store.num_att_layers = cross_att_count - - def get_indices(self, prompt: str) -> Dict[str, int]: - """Utility function to list the indices of the tokens you wish to alte""" - ids = self.tokenizer(prompt).input_ids - indices = {i: tok for tok, i in zip(self.tokenizer.convert_ids_to_tokens(ids), range(len(ids)))} - return indices - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]], - token_indices: List[int], - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: int = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - max_iter_to_alter: int = 25, - thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8}, - scale_factor: int = 20, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - token_indices (`List[int]`): - The token indices to alter with attend-and-excite. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - max_iter_to_alter (`int`, *optional*, defaults to `25`): - Number of denoising steps to apply attend-and-excite. The first denoising steps are - where the attend-and-excite is applied. I.e. if `max_iter_to_alter` is 25 and there are a total of `30` - denoising steps, the first 25 denoising steps will apply attend-and-excite and the last 5 will not - apply attend-and-excite. - thresholds (`dict`, *optional*, defaults to `{0: 0.05, 10: 0.5, 20: 0.8}`): - Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in. 
- scale_factor (`int`, *optional*, default to 20): - Scale factor that controls the step size of each Attend and Excite update. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. :type attention_store: object - """ - - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - token_indices, - height, - width, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - self.attention_store = AttentionStore() - self.register_attention_control() - - # default config for step size from original repo - scale_range = np.linspace(1.0, 0.5, len(self.scheduler.timesteps)) - step_size = scale_factor * np.sqrt(scale_range) - - text_embeddings = ( - prompt_embeds[batch_size * num_images_per_prompt :] if do_classifier_free_guidance else prompt_embeds - ) - - if isinstance(token_indices[0], int): - token_indices = [token_indices] - indices = [] - for ind in token_indices: - indices = indices + [ind] * num_images_per_prompt - - # 7. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # Attend and excite process - with torch.enable_grad(): - latents = latents.clone().detach().requires_grad_(True) - updated_latents = [] - for latent, index, text_embedding in zip(latents, indices, text_embeddings): - # Forward pass of denoising with text conditioning - latent = latent.unsqueeze(0) - text_embedding = text_embedding.unsqueeze(0) - - self.unet( - latent, - t, - encoder_hidden_states=text_embedding, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - self.unet.zero_grad() - - # Get max activation value for each subject token - max_attention_per_index = self._aggregate_and_get_max_attention_per_token( - indices=index, - ) - - loss = self._compute_loss(max_attention_per_index=max_attention_per_index) - - # If this is an iterative refinement step, verify we have reached the desired threshold for all - if i in thresholds.keys() and loss > 1.0 - thresholds[i]: - loss, latent, max_attention_per_index = self._perform_iterative_refinement_step( - latents=latent, - indices=index, - loss=loss, - threshold=thresholds[i], - text_embeddings=text_embedding, - step_size=step_size[i], - t=t, - ) - - # Perform gradient update - if i < max_iter_to_alter: - if loss != 0: - latent = self._update_latent( - latents=latent, - loss=loss, - step_size=step_size[i], - ) - logger.info(f"Iteration {i} | Loss: {loss:0.4f}") - - updated_latents.append(latent) - - latents = torch.cat(updated_latents, dim=0) - - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - -class GaussianSmoothing(torch.nn.Module): - """ - Arguments: - Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input - using a depthwise convolution. - channels (int, sequence): Number of channels of the input tensors. Output will - have this number of channels as well. - kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the - gaussian kernel. 
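Before the smoothing module below, one more note on the denoising loop above: the update magnitude follows a decaying square-root schedule, and iterative refinement only triggers at the iterations listed in `thresholds`. A small sketch of those scalars, assuming 50 inference steps and the default arguments:

```py
import numpy as np

num_inference_steps = 50
scale_factor = 20
thresholds = {0: 0.05, 10: 0.5, 20: 0.8}

scale_range = np.linspace(1.0, 0.5, num_inference_steps)
step_size = scale_factor * np.sqrt(scale_range)  # ~20.0 decaying to ~14.1

for i in (0, 10, 20, 30):
    refine = i in thresholds  # iterative refinement only at these iterations
    print(f"step {i}: step_size={step_size[i]:.2f}, iterative_refinement={refine}")
```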
dim (int, optional): The number of dimensions of the data. - Default value is 2 (spatial). - """ - - # channels=1, kernel_size=kernel_size, sigma=sigma, dim=2 - def __init__( - self, - channels: int = 1, - kernel_size: int = 3, - sigma: float = 0.5, - dim: int = 2, - ): - super().__init__() - - if isinstance(kernel_size, int): - kernel_size = [kernel_size] * dim - if isinstance(sigma, float): - sigma = [sigma] * dim - - # The gaussian kernel is the product of the - # gaussian function of each dimension. - kernel = 1 - meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size]) - for size, std, mgrid in zip(kernel_size, sigma, meshgrids): - mean = (size - 1) / 2 - kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) - - # Make sure sum of values in gaussian kernel equals 1. - kernel = kernel / torch.sum(kernel) - - # Reshape to depthwise convolutional weight - kernel = kernel.view(1, 1, *kernel.size()) - kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) - - self.register_buffer("weight", kernel) - self.groups = channels - - if dim == 1: - self.conv = F.conv1d - elif dim == 2: - self.conv = F.conv2d - elif dim == 3: - self.conv = F.conv3d - else: - raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim)) - - def forward(self, input): - """ - Arguments: - Apply gaussian filter to input. - input (torch.Tensor): Input to apply gaussian filter on. - Returns: - filtered (torch.Tensor): Filtered output. - """ - return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py deleted file mode 100644 index 00f35ee4d02180887d4e76fd1dba7210df6a69ca..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ /dev/null @@ -1,820 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import inspect -from typing import Any, Callable, Dict, List, Optional, Union - -import numpy as np -import PIL.Image -import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - PIL_INTERPOLATION, - is_accelerate_available, - is_accelerate_version, - logging, - randn_tensor, - replace_example_docstring, -) -from ..pipeline_utils import DiffusionPipeline -from . 
import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler - >>> from diffusers.utils import load_image - >>> import numpy as np - >>> import torch - - >>> import cv2 - >>> from PIL import Image - - >>> # download an image - >>> image = load_image( - ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" - ... ) - >>> image = np.array(image) - - >>> # get canny image - >>> image = cv2.Canny(image, 100, 200) - >>> image = image[:, :, None] - >>> image = np.concatenate([image, image, image], axis=2) - >>> canny_image = Image.fromarray(image) - - >>> # load control net and stable diffusion v1-5 - >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) - >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( - ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 - ... ) - - >>> # speed up diffusion process with faster scheduler and memory optimization - >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) - >>> # remove following line if xformers is not installed - >>> pipe.enable_xformers_memory_efficient_attention() - - >>> pipe.enable_model_cpu_offload() - - >>> # generate image - >>> generator = torch.manual_seed(0) - >>> image = pipe( - ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image - ... ).images[0] - ``` -""" - - -class StableDiffusionControlNetPipeline(DiffusionPipeline): - r""" - Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - controlnet ([`ControlNetModel`]): - Provides additional conditioning to the unet during the denoising process - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. 
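For orientation, this is the shape and normalization convention the conditioning image ends up in (a simplified sketch, not the pipeline's actual `prepare_image`, which appears further down; the all-black 512x512 "canny" image is a placeholder):

```py
import numpy as np
import torch
from PIL import Image

# Placeholder conditioning image (e.g. a Canny edge map) as a PIL image.
canny_image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))

width, height = 512, 512
image = np.array(canny_image.resize((width, height)))[None, :]  # (1, H, W, 3)
image = image.astype(np.float32) / 255.0                        # scale to [0, 1]
image = torch.from_numpy(image.transpose(0, 3, 1, 2))           # NCHW layout
print(image.shape, image.dtype)  # torch.Size([1, 3, 512, 512]) torch.float32
```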
- feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet: ControlNetModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. 
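For completeness, a hedged usage sketch of the two offloading modes on this pipeline; the model IDs are taken from the example docstring above, and in practice you would pick one mode, not both:

```py
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)

# Lowest memory, slowest: parameters are streamed to the GPU per submodule.
# pipe.enable_sequential_cpu_offload()

# Higher memory than the sequential variant, but much faster: whole models
# are moved to the GPU one at a time as they are needed.
pipe.enable_model_cpu_offload()
```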
- """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - # the safety checker can offload the vae again - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # control net hook has be manually offloaded as it alternates with unet - cpu_offload_with_hook(self.controlnet, device) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. 
- Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - image, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - image_is_pil = isinstance(image, PIL.Image.Image) - image_is_tensor = isinstance(image, torch.Tensor) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) - image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) - - if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: - raise TypeError( - "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" - ) - - if image_is_pil: - image_batch_size = 1 - elif image_is_tensor: - image_batch_size = image.shape[0] - elif image_is_pil_list: - image_batch_size = len(image) - elif image_is_tensor_list: - image_batch_size = len(image) - - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] - - if image_batch_size != 1 and image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" - ) - - def prepare_image(self, image, width, height, batch_size, num_images_per_prompt, device, dtype): - if not isinstance(image, torch.Tensor): - if isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - image = [ - np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image - ] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def _default_height_width(self, height, width, image): - if isinstance(image, list): - image = image[0] - - if height is None: - if isinstance(image, PIL.Image.Image): - height = image.height - elif isinstance(image, torch.Tensor): - height = image.shape[3] - - height = (height // 8) * 8 # round down to nearest multiple of 8 - - if width is None: - if isinstance(image, PIL.Image.Image): - width = image.width - elif isinstance(image, torch.Tensor): - width = image.shape[2] - - width = (width // 8) * 8 # round down to nearest multiple of 8 - - return height, width - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: float = 1.0, - ): - r""" - Function invoked when calling 
the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`): - The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If - the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can - also be accepted as an image. The control image is automatically resized to fit the output image. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
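The `EXAMPLE_DOC_STRING` injected by the `replace_example_docstring` decorator above is not shown in this hunk, so here is a hedged usage sketch for this ControlNet pipeline. The checkpoint names and the image URL are illustrative assumptions, and the control image is assumed to already be a prepared conditioning map (e.g. canny edges).

```py
# Hedged usage sketch; checkpoint names and the URL are assumptions, not taken from this file.
import requests
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.to("cuda")

# `control_image` should already be the conditioning map the ControlNet expects
# (here, a canny edge image); `prepare_image` resizes it to the output resolution.
url = "https://example.com/canny_edges.png"  # placeholder URL
control_image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

image = pipe(
    "a futuristic city at night",
    image=control_image,
    num_inference_steps=30,
    controlnet_conditioning_scale=1.0,
).images[0]
image.save("controlnet_out.png")
```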
- return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): - The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original unet. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height, width = self._default_height_width(height, width, image) - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Prepare image - image = self.prepare_image( - image, - width, - height, - batch_size * num_images_per_prompt, - num_images_per_prompt, - device, - self.controlnet.dtype, - ) - - if do_classifier_free_guidance: - image = torch.cat([image] * 2) - - # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 6. Prepare latent variables - num_channels_latents = self.unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - down_block_res_samples, mid_block_res_sample = self.controlnet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - controlnet_cond=image, - return_dict=False, - ) - - down_block_res_samples = [ - down_block_res_sample * controlnet_conditioning_scale - for down_block_res_sample in down_block_res_samples - ] - mid_block_res_sample *= controlnet_conditioning_scale - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # If we do sequential model offloading, let's offload unet and controlnet - # manually for max memory savings - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.unet.to("cpu") - self.controlnet.to("cpu") - torch.cuda.empty_cache() - - if output_type == "latent": - image = latents - has_nsfw_concept = None - elif output_type == "pil": - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 10. Convert to PIL - image = self.numpy_to_pil(image) - else: - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py deleted file mode 100644 index 6c02e06a6523c3c6bd6d3a4898844297aec2244b..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ /dev/null @@ -1,689 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL -import torch -from packaging import version -from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation - -from ...configuration_utils import FrozenDict -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess -def preprocess(image): - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 - - image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): - r""" - Pipeline for text-guided image to image generation using Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
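One detail worth flagging about the `preprocess` helper above: it maps pixel values into the [-1, 1] range the VAE encoder expects, whereas the ControlNet pipeline's `prepare_image` earlier in this diff keeps its conditioning image in [0, 1]. A small sketch of the two conventions (values are arbitrary):

```py
# Sketch of the two normalization conventions used in this diff; values are arbitrary.
import numpy as np

pixels = np.array([0.0, 128.0, 255.0], dtype=np.float32)

vae_input = 2.0 * (pixels / 255.0) - 1.0   # img2img / depth2img `preprocess`: roughly [-1, 1]
controlnet_cond = pixels / 255.0           # ControlNet `prepare_image`: [0, 1]
print(vae_input, controlnet_cond)
```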
- """ - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - depth_estimator: DPTForDepthEstimation, - feature_extractor: DPTFeatureExtractor, - ): - super().__init__() - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - depth_estimator=depth_estimator, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.depth_estimator]: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. 
- """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. 
- """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
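The `prepare_extra_step_kwargs` helper that begins here only forwards `eta` and `generator` when the scheduler's `step()` signature actually declares them. A standalone sketch of that introspection, using `DDIMScheduler` purely as an example of a scheduler that accepts both:

```py
# Standalone sketch of the signature check used below; DDIMScheduler is just an example
# of a scheduler whose `step()` accepts `eta` and `generator`.
import inspect
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
step_params = set(inspect.signature(scheduler.step).parameters.keys())

extra_step_kwargs = {}
if "eta" in step_params:
    extra_step_kwargs["eta"] = 0.0
if "generator" in step_params:
    extra_step_kwargs["generator"] = None
print(extra_step_kwargs)  # {'eta': 0.0, 'generator': None}
```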
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs - def check_inputs( - self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None - ): - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." 
- ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] - - return timesteps, num_inference_steps - t_start - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents - def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) - - init_latents = self.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - deprecation_message = ( - f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." - ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = torch.cat([init_latents], dim=0) - - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - - # get latents - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) - latents = init_latents - - return latents - - def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device): - if isinstance(image, PIL.Image.Image): - image = [image] - else: - image = [img for img in image] - - if isinstance(image[0], PIL.Image.Image): - width, height = image[0].size - else: - width, height = image[0].shape[-2:] - - if depth_map is None: - pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values - pixel_values = pixel_values.to(device=device) - # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16. 
- # So we use `torch.autocast` here for half precision inference. - context_manger = torch.autocast("cuda", dtype=dtype) if device.type == "cuda" else contextlib.nullcontext() - with context_manger: - depth_map = self.depth_estimator(pixel_values).predicted_depth - else: - depth_map = depth_map.to(device=device, dtype=dtype) - - depth_map = torch.nn.functional.interpolate( - depth_map.unsqueeze(1), - size=(height // self.vae_scale_factor, width // self.vae_scale_factor), - mode="bicubic", - align_corners=False, - ) - - depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) - depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) - depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0 - depth_map = depth_map.to(dtype) - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if depth_map.shape[0] < batch_size: - depth_map = depth_map.repeat(batch_size, 1, 1, 1) - - depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map - return depth_map - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - depth_map: Optional[torch.FloatTensor] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` - will be used as a starting point, adding more noise to it the larger the `strength`. The number of - denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will - be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. 
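The `strength` argument documented above interacts with `num_inference_steps` through `get_timesteps`, defined earlier: only the last `int(num_inference_steps * strength)` timesteps are actually denoised. A worked sketch of the arithmetic:

```py
# Worked sketch of how `strength` trims the schedule in `get_timesteps`:
# with 50 steps and strength=0.7, only the last 35 timesteps are denoised.
num_inference_steps = 50
strength = 0.7

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 35
t_start = max(num_inference_steps - init_timestep, 0)                          # 15
print(t_start, num_inference_steps - t_start)  # 15 35 -> skip 15 steps, run 35
```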
- negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Examples: - - ```py - >>> import torch - >>> import requests - >>> from PIL import Image - - >>> from diffusers import StableDiffusionDepth2ImgPipeline - - >>> pipe = StableDiffusionDepth2ImgPipeline.from_pretrained( - ... "stabilityai/stable-diffusion-2-depth", - ... torch_dtype=torch.float16, - ... ) - >>> pipe.to("cuda") - - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> init_image = Image.open(requests.get(url, stream=True).raw) - >>> prompt = "two tigers" - >>> n_propmt = "bad, deformed, ugly, bad anotomy" - >>> image = pipe(prompt=prompt, image=init_image, negative_prompt=n_propmt, strength=0.7).images[0] - ``` - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 1. Check inputs - self.check_inputs( - prompt, - strength, - callback_steps, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - if image is None: - raise ValueError("`image` input cannot be undefined.") - - # 2. 
Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Prepare depth mask - depth_mask = self.prepare_depth_map( - image, - depth_map, - batch_size * num_images_per_prompt, - do_classifier_free_guidance, - prompt_embeds.dtype, - device, - ) - - # 5. Preprocess image - image = preprocess(image) - - # 6. Set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - # 7. Prepare latent variables - latents = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator - ) - - # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 9. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 10. Post-processing - image = self.decode_latents(latents) - - # 11. 
Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py deleted file mode 100644 index a7165457c67cccfecbc36d0e3fc3e7b4a6e228d5..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ /dev/null @@ -1,414 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Union - -import PIL -import torch -from packaging import version -from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection - -from ...configuration_utils import FrozenDict -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import deprecate, is_accelerate_available, logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class StableDiffusionImageVariationPipeline(DiffusionPipeline): - r""" - Pipeline to generate variations from an input image using Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - image_encoder ([`CLIPVisionModelWithProjection`]): - Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), - specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
- """ - # TODO: feature_extractor is required to encode images (if they are in PIL format), - # we should give a descriptive message if the pipeline doesn't have one. - _optional_components = ["safety_checker"] - - def __init__( - self, - vae: AutoencoderKL, - image_encoder: CLIPVisionModelWithProjection, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if safety_checker is None and requires_safety_checker: - logger.warn( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - image_encoder=image_encoder, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. 
When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - for cpu_offloaded_model in [self.unet, self.image_encoder, self.vae, self.safety_checker]: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(images=image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - image_embeddings = self.image_encoder(image).image_embeds - image_embeddings = image_embeddings.unsqueeze(1) - - # duplicate image embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = image_embeddings.shape - image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) - image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - - if do_classifier_free_guidance: - negative_prompt_embeds = torch.zeros_like(image_embeddings) - - # For classifier free guidance, we need to do two forward passes. 
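One difference from the text-conditioned pipelines is visible in `_encode_image` above: the unconditional branch for classifier-free guidance is a plain zero tensor rather than the embedding of an empty prompt. A tiny sketch with placeholder shapes:

```py
# Sketch: the "negative" embedding here is simply zeros with the same shape as the
# CLIP image embeddings (shapes below are illustrative).
import torch

image_embeddings = torch.randn(2, 1, 768)
negative_prompt_embeds = torch.zeros_like(image_embeddings)
image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])  # doubled batch for CFG
```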
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) - - return image_embeddings - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs(self, image, height, width, callback_steps): - if ( - not isinstance(image, torch.Tensor) - and not isinstance(image, PIL.Image.Image) - and not isinstance(image, list) - ): - raise ValueError( - "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" - f" {type(image)}" - ) - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - @torch.no_grad() - def __call__( - self, - image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): - The image or images to guide the image generation. If you provide a tensor, it needs to comply with the - configuration of - [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) - `CLIPFeatureExtractor` - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple.
- callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs(image, height, width, callback_steps) - - # 2. Define call parameters - if isinstance(image, PIL.Image.Image): - batch_size = 1 - elif isinstance(image, list): - batch_size = len(image) - else: - batch_size = image.shape[0] - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input image - image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - image_embeddings.dtype, - device, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) - - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py deleted file mode 100644 index 172ab15a757e725164a5257f105baa0d2deef115..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ /dev/null @@ -1,734 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL -import torch -from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...configuration_utils import FrozenDict -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - PIL_INTERPOLATION, - deprecate, - is_accelerate_available, - is_accelerate_version, - logging, - randn_tensor, - replace_example_docstring, -) -from ..pipeline_utils import DiffusionPipeline -from . 
import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import requests - >>> import torch - >>> from PIL import Image - >>> from io import BytesIO - - >>> from diffusers import StableDiffusionImg2ImgPipeline - - >>> device = "cuda" - >>> model_id_or_path = "runwayml/stable-diffusion-v1-5" - >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) - >>> pipe = pipe.to(device) - - >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" - - >>> response = requests.get(url) - >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") - >>> init_image = init_image.resize((768, 512)) - - >>> prompt = "A fantasy landscape, trending on artstation" - - >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images - >>> images[0].save("fantasy_landscape.png") - ``` -""" - - -def preprocess(image): - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 - - image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -class StableDiffusionImg2ImgPipeline(DiffusionPipeline): - r""" - Pipeline for text-guided image to image generation using Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
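        As a rough illustration of how these components fit together (the `block_out_channels` values below are an
        assumed, typical Stable Diffusion v1 VAE configuration, not read from this class), the factor between pixel
        space and the latent space that `unet` denoises is derived from the VAE config:

        ```py
        >>> block_out_channels = (128, 256, 512, 512)         # assumed SD v1 VAE config
        >>> vae_scale_factor = 2 ** (len(block_out_channels) - 1)
        >>> vae_scale_factor                                   # a 512x512 image -> 64x64 latents
        8
        ```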
- """ - _optional_components = ["safety_checker", "feature_extractor"] - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
- ) - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. 
Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument.
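        A minimal shape sketch of the per-prompt duplication performed below (the 77-token sequence length and
        768-dim hidden size are assumptions for a CLIP ViT-L/14 text encoder, not values read from this pipeline):

        ```py
        >>> import torch
        >>> prompt_embeds = torch.randn(1, 77, 768)                 # embeddings for one prompt
        >>> num_images_per_prompt = 3
        >>> prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        >>> prompt_embeds.view(1 * num_images_per_prompt, 77, -1).shape
        torch.Size([3, 77, 768])
        ```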
- """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
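        # A hedged, standalone illustration of the signature check performed here; the two
        # scheduler classes are only examples of ones with and without an `eta` argument:
        #
        #     import inspect
        #     from diffusers import DDIMScheduler, EulerDiscreteScheduler
        #     "eta" in inspect.signature(DDIMScheduler().step).parameters           # True
        #     "eta" in inspect.signature(EulerDiscreteScheduler().step).parameters  # False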
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None - ): - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - def get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] - - return timesteps, num_inference_steps - t_start - - def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) - - init_latents = self.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - deprecation_message = ( - f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." - ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = torch.cat([init_latents], dim=0) - - shape = init_latents.shape - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - - # get latents - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) - latents = init_latents - - return latents - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` - will be used as a starting point, adding more noise to it the larger the `strength`. The number of - denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will - be maximum and the denoising process will run for the full number of iterations specified in - `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the - expense of slower inference. This parameter will be modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) - - # 2. 
Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Preprocess image - image = preprocess(image) - - # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - # 6. Prepare latent variables - latents = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 9. Post-processing - image = self.decode_latents(latents) - - # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 11. 
Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py deleted file mode 100644 index b645ba667f779e79d6d84427d07196042ee41ad2..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ /dev/null @@ -1,890 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL -import torch -from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...configuration_utils import FrozenDict -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def prepare_mask_and_masked_image(image, mask): - """ - Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be - converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the - ``image`` and ``1`` for the ``mask``. - - The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be - binarized (``mask > 0.5``) and cast to ``torch.float32`` too. - - Args: - image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. - It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` - ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. - mask (_type_): The mask to apply to the image, i.e. regions to inpaint. - It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` - ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. - - - Raises: - ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask - should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. - TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not - (ot the other way around). 
- - Returns: - tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 - dimensions: ``batch x channels x height x width``. - """ - if isinstance(image, torch.Tensor): - if not isinstance(mask, torch.Tensor): - raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") - - # Batch single image - if image.ndim == 3: - assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" - image = image.unsqueeze(0) - - # Batch and add channel dim for single mask - if mask.ndim == 2: - mask = mask.unsqueeze(0).unsqueeze(0) - - # Batch single mask or add channel dim - if mask.ndim == 3: - # Single batched mask, no channel dim or single mask not batched but channel dim - if mask.shape[0] == 1: - mask = mask.unsqueeze(0) - - # Batched masks no channel dim - else: - mask = mask.unsqueeze(1) - - assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" - assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" - assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" - - # Check image is in [-1, 1] - if image.min() < -1 or image.max() > 1: - raise ValueError("Image should be in [-1, 1] range") - - # Check mask is in [0, 1] - if mask.min() < 0 or mask.max() > 1: - raise ValueError("Mask should be in [0, 1] range") - - # Binarize mask - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - - # Image as float32 - image = image.to(dtype=torch.float32) - elif isinstance(mask, torch.Tensor): - raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") - else: - # preprocess image - if isinstance(image, (PIL.Image.Image, np.ndarray)): - image = [image] - - if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): - image = [np.array(i.convert("RGB"))[None, :] for i in image] - image = np.concatenate(image, axis=0) - elif isinstance(image, list) and isinstance(image[0], np.ndarray): - image = np.concatenate([i[None, :] for i in image], axis=0) - - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 - - # preprocess mask - if isinstance(mask, (PIL.Image.Image, np.ndarray)): - mask = [mask] - - if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): - mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) - mask = mask.astype(np.float32) / 255.0 - elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): - mask = np.concatenate([m[None, None, :] for m in mask], axis=0) - - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - mask = torch.from_numpy(mask) - - masked_image = image * (mask < 0.5) - - return mask, masked_image - - -class StableDiffusionInpaintPipeline(DiffusionPipeline): - r""" - Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. 
Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration" - " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" - " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" - " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" - " Hub, it would be very nice if you could open a Pull request for the" - " `scheduler/scheduler_config.json` file" - ) - deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["skip_prk_steps"] = True - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. 
Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower.
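        A hedged usage sketch (the checkpoint name is taken from the deprecation message above; whether it suits
        your use case is an assumption):

        ```py
        >>> import torch
        >>> from diffusers import StableDiffusionInpaintPipeline
        >>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
        ... )
        >>> pipe.enable_sequential_cpu_offload(gpu_id=0)
        >>> # submodules are now moved to cuda:0 only while their forward() runs
        ```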
- """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. 
- - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}.
Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs - def check_inputs( - self, - prompt, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate( - mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) - ) - mask = mask.to(device=device, dtype=dtype) - - masked_image = masked_image.to(device=device, dtype=dtype) - - # encode the mask image into latents space so we can concatenate it to the latents - if isinstance(generator, list): - masked_image_latents = [ - self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) - for i in range(batch_size) - ] - masked_image_latents = torch.cat(masked_image_latents, dim=0) - else: - masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) - masked_image_latents = self.vae.config.scaling_factor * masked_image_latents - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." 
- ) - masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) - - mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask - masked_image_latents = ( - torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents - ) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - return mask, masked_image_latents - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`PIL.Image.Image`): - `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will - be masked out with `mask_image` and repainted according to `prompt`. - mask_image (`PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be - repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted - to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) - instead of 3, so the expected shape would be `(B, H, W, 1)`. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). 
- num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Examples: - - ```py - >>> import PIL - >>> import requests - >>> import torch - >>> from io import BytesIO - - >>> from diffusers import StableDiffusionInpaintPipeline - - - >>> def download_image(url): - ... response = requests.get(url) - ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") - - - >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" - >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" - - >>> init_image = download_image(img_url).resize((512, 512)) - >>> mask_image = download_image(mask_url).resize((512, 512)) - - >>> pipe = StableDiffusionInpaintPipeline.from_pretrained( - ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" - >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] - ``` - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
- When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs - self.check_inputs( - prompt, - height, - width, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - ) - - if image is None: - raise ValueError("`image` input cannot be undefined.") - - if mask_image is None: - raise ValueError("`mask_image` input cannot be undefined.") - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Preprocess mask and image - mask, masked_image = prepare_mask_and_masked_image(image, mask_image) - - # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 6. Prepare latent variables - num_channels_latents = self.vae.config.latent_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 7. Prepare mask latent variables - mask, masked_image_latents = self.prepare_mask_latents( - mask, - masked_image, - batch_size * num_images_per_prompt, - height, - width, - prompt_embeds.dtype, - device, - generator, - do_classifier_free_guidance, - ) - - # 8. Check that sizes of mask, masked image and latents match - num_channels_mask = mask.shape[1] - num_channels_masked_image = masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" - f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." - ) - - # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 10. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - - # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 11. Post-processing - image = self.decode_latents(latents) - - # 12. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 13. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py deleted file mode 100644 index 1e84efa4163cab4b456c2c29ed24199724559e4b..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ /dev/null @@ -1,713 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL -import torch -from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...configuration_utils import FrozenDict -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - PIL_INTERPOLATION, - deprecate, - is_accelerate_available, - is_accelerate_version, - logging, - randn_tensor, -) -from ..pipeline_utils import DiffusionPipeline -from . 
import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) - - -def preprocess_image(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 - - -def preprocess_mask(mask, scale_factor=8): - if not isinstance(mask, torch.FloatTensor): - mask = mask.convert("L") - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? - mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask - - else: - valid_mask_channel_sizes = [1, 3] - # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W) - if mask.shape[3] in valid_mask_channel_sizes: - mask = mask.permute(0, 3, 1, 2) - elif mask.shape[1] not in valid_mask_channel_sizes: - raise ValueError( - f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension," - f" but received mask of shape {tuple(mask.shape)}" - ) - # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape - mask = mask.mean(dim=1, keepdim=True) - h, w = mask.shape[-2:] - h, w = map(lambda x: x - x % 32, (h, w)) # resize to integer multiple of 32 - mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor)) - return mask - - -class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): - r""" - Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. 
- feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _optional_components = ["feature_extractor"] - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
- ) - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. 
Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument.
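When `do_classifier_free_guidance` is enabled, the returned tensor stacks the unconditional and the text embeddings along the batch dimension so that a single UNet forward pass serves both halves of classifier-free guidance. A rough sketch of how the caller consumes that stacked batch further down in this file (the tensor shape is an illustrative assumption, not a value taken from the code):

```py
>>> import torch

>>> guidance_scale = 7.5
>>> noise_pred = torch.randn(2, 4, 64, 64)  # UNet output for the stacked [uncond, text] input
>>> noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
>>> noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```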
- """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs - def check_inputs( - self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None - ): - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." 
- ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] - - return timesteps, num_inference_steps - t_start - - def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator): - image = image.to(device=self.device, dtype=dtype) - init_latent_dist = self.vae.encode(image).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = self.vae.config.scaling_factor * init_latents - - # Expand init_latents for batch_size and num_images_per_prompt - init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) - init_latents_orig = init_latents - - # add noise to latents using the timesteps - noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype) - init_latents = self.scheduler.add_noise(init_latents, noise, timestep) - latents = init_latents - return latents, init_latents_orig, noise - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, - strength: float = 0.8, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - add_predicted_noise: Optional[bool] = False, - eta: Optional[float] = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. This is the image whose masked region will be inpainted. - mask_image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a - PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the - expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3. - strength (`float`, *optional*, defaults to 0.8): - Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` - is 1, the denoising process will be run on the masked area for the full number of iterations specified - in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to - that region the larger the `strength`. If `strength` is 0, no inpainting will occur. 
- num_inference_steps (`int`, *optional*, defaults to 50): - The reference number of denoising steps. More denoising steps usually lead to a higher quality image at - the expense of slower inference. This parameter will be modulated by `strength`, as explained above. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - add_predicted_noise (`bool`, *optional*, defaults to False): - Use predicted noise instead of random noise when constructing noisy versions of the original image in - the reverse diffusion process - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 1.
Check inputs - self.check_inputs(prompt, strength, callback_steps) - - # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Preprocess image and mask - if not isinstance(image, torch.FloatTensor): - image = preprocess_image(image) - - mask_image = preprocess_mask(mask_image, self.vae_scale_factor) - - # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - - # 6. Prepare latent variables - # encode the init image into latents and scale the latents - latents, init_latents_orig, noise = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator - ) - - # 7. Prepare mask latent - mask = mask_image.to(device=self.device, dtype=latents.dtype) - mask = torch.cat([mask] * batch_size * num_images_per_prompt) - - # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 9. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # masking - if add_predicted_noise: - init_latents_proper = self.scheduler.add_noise( - init_latents_orig, noise_pred_uncond, torch.tensor([t]) - ) - else: - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) - - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # use original latents corresponding to unmasked portions of the image - latents = (init_latents_orig * mask) + (latents * (1 - mask)) - - # 10. Post-processing - image = self.decode_latents(latents) - - # 11. 
Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 12. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py deleted file mode 100644 index 9efa91e161b9b7ab53cbf0b4fbfdc3f9bad590fb..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ /dev/null @@ -1,718 +0,0 @@ -# Copyright 2023 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL -import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - PIL_INTERPOLATION, - deprecate, - is_accelerate_available, - is_accelerate_version, - logging, - randn_tensor, -) -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess -def preprocess(image): - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 - - image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): - r""" - Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
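Before the InstructPix2Pix arguments below, a quick recap of the legacy inpainting loop that just ended: its core is the per-step masked compositing, where the original image latents are re-noised to the current timestep (optionally with the model's predicted noise when `add_predicted_noise` is set) and blended with the freshly denoised latents so unmasked regions stay faithful to the input. A minimal sketch of that step, assuming a DDPM-style re-noising formula and hypothetical tensor shapes (the real pipeline calls `scheduler.add_noise` and uses latent-resolution masks):

```py
import torch

def blend_with_original(init_latents_orig, denoised_latents, noise, mask, alpha_prod_t):
    # Re-noise the original latents to the current timestep, standing in for
    # `scheduler.add_noise(init_latents_orig, noise, t)` in the loop above.
    init_latents_proper = alpha_prod_t ** 0.5 * init_latents_orig + (1 - alpha_prod_t) ** 0.5 * noise
    # Keep the original content where mask == 1 and the denoised content elsewhere,
    # matching `latents = (init_latents_proper * mask) + (latents * (1 - mask))`.
    return init_latents_proper * mask + denoised_latents * (1 - mask)

# Hypothetical shapes: 1 sample, 4 latent channels, 64x64 latents.
latents = torch.randn(1, 4, 64, 64)
init_latents_orig = torch.randn(1, 4, 64, 64)
noise = torch.randn_like(init_latents_orig)
mask = torch.zeros(1, 4, 64, 64)
mask[..., :32, :] = 1.0  # keep the top half of the image untouched
latents = blend_with_original(init_latents_orig, latents, noise, mask, alpha_prod_t=0.5)
```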
- - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
- ) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - num_inference_steps: int = 100, - guidance_scale: float = 7.5, - image_guidance_scale: float = 1.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`PIL.Image.Image`): - `Image`, or tensor representing an image batch which will be repainted according to `prompt`. - num_inference_steps (`int`, *optional*, defaults to 100): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. This pipeline requires a value of at least `1`. - image_guidance_scale (`float`, *optional*, defaults to 1.5): - Image guidance scale is to push the generated image towards the inital image `image`. Image guidance - scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to - generate images that are closely linked to the source image `image`, usually at the expense of lower - image quality. This pipeline requires a value of at least `1`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. 
- latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Examples: - - ```py - >>> import PIL - >>> import requests - >>> import torch - >>> from io import BytesIO - - >>> from diffusers import StableDiffusionInstructPix2PixPipeline - - - >>> def download_image(url): - ... response = requests.get(url) - ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") - - - >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" - - >>> image = download_image(img_url).resize((512, 512)) - - >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( - ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> prompt = "make the mountains snowy" - >>> image = pipe(prompt=prompt, image=image).images[0] - ``` - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Check inputs - self.check_inputs(prompt, callback_steps) - - if image is None: - raise ValueError("`image` input cannot be undefined.") - - # 1. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. 
- do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0 - # check if scheduler is in sigmas space - scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") - - # 2. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 3. Preprocess image - image = preprocess(image) - height, width = image.shape[-2:] - - # 4. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Prepare Image latents - image_latents = self.prepare_image_latents( - image, - batch_size, - num_images_per_prompt, - prompt_embeds.dtype, - device, - do_classifier_free_guidance, - generator, - ) - - # 6. Prepare latent variables - num_channels_latents = self.vae.config.latent_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 7. Check that shapes of latents and image match the UNet channels - num_channels_image = image_latents.shape[1] - if num_channels_latents + num_channels_image != self.unet.config.in_channels: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" - f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_image`: {num_channels_image} " - f" = {num_channels_latents+num_channels_image}. Please verify the config of" - " `pipeline.unet` or your `image` input." - ) - - # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 9. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # Expand the latents if we are doing classifier free guidance. - # The latents are expanded 3 times because for pix2pix the guidance\ - # is applied for both the text and the input image. - latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents - - # concat latents, image_latents in the channel dimension - scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) - - # predict the noise residual - noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds).sample - - # Hack: - # For karras style schedulers the model does classifer free guidance using the - # predicted_original_sample instead of the noise_pred. So we need to compute the - # predicted_original_sample here if we are using a karras style scheduler. 
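Just below, after the optional sigma-space conversion, the three chunks of the UNet output are recombined into a single noise prediction: `guidance_scale` pulls the result from the image-only prediction toward the text-conditioned one, while `image_guidance_scale` pulls it from the unconditional prediction toward the image-conditioned one. A minimal sketch of that arithmetic, with random tensors standing in for the UNet output:

```py
import torch

# Random stand-ins for the three chunks produced when the latents are
# expanded 3x: text+image conditioned, image-only conditioned, unconditional.
noise_pred_text = torch.randn(1, 4, 64, 64)
noise_pred_image = torch.randn(1, 4, 64, 64)
noise_pred_uncond = torch.randn(1, 4, 64, 64)

guidance_scale = 7.5        # text guidance weight
image_guidance_scale = 1.5  # image guidance weight

# Same combination as in the denoising loop below.
noise_pred = (
    noise_pred_uncond
    + guidance_scale * (noise_pred_text - noise_pred_image)
    + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
)
```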
- if scheduler_is_in_sigma_space: - step_index = (self.scheduler.timesteps == t).nonzero().item() - sigma = self.scheduler.sigmas[step_index] - noise_pred = latent_model_input - sigma * noise_pred - - # perform guidance - if do_classifier_free_guidance: - noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_image) - + image_guidance_scale * (noise_pred_image - noise_pred_uncond) - ) - - # Hack: - # For karras style schedulers the model does classifer free guidance using the - # predicted_original_sample instead of the noise_pred. But the scheduler.step function - # expects the noise_pred and computes the predicted_original_sample internally. So we - # need to overwrite the noise_pred here such that the value of the computed - # predicted_original_sample is correct. - if scheduler_is_in_sigma_space: - noise_pred = (noise_pred - latents) / (-sigma) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 10. Post-processing - image = self.decode_latents(latents) - - # 11. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 12. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. 
Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_ prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. 
- """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] - prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def check_inputs(self, prompt, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def prepare_image_latents( - self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None - ): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(device=device, dtype=dtype) - - batch_size = batch_size * num_images_per_prompt - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if isinstance(generator, list): - image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = self.vae.encode(image).latent_dist.mode() - - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - # expand image_latents for batch_size - deprecation_message = ( - f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." - ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." - ) - else: - image_latents = torch.cat([image_latents], dim=0) - - if do_classifier_free_guidance: - uncond_image_latents = torch.zeros_like(image_latents) - image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) - - return image_latents diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py deleted file mode 100644 index f3db54caa34224e6710002e823e207827e8409d4..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ /dev/null @@ -1,547 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import importlib -from typing import Callable, List, Optional, Union - -import torch -from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser - -from ...pipelines import DiffusionPipeline -from ...schedulers import LMSDiscreteScheduler -from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor -from . 
import StableDiffusionPipelineOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class ModelWrapper: - def __init__(self, model, alphas_cumprod): - self.model = model - self.alphas_cumprod = alphas_cumprod - - def apply_model(self, *args, **kwargs): - if len(args) == 3: - encoder_hidden_states = args[-1] - args = args[:2] - if kwargs.get("cond", None) is not None: - encoder_hidden_states = kwargs.pop("cond") - return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample - - -class StableDiffusionKDiffusionPipeline(DiffusionPipeline): - r""" - Pipeline for text-to-image generation using Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - - - This is an experimental pipeline and is likely to change in the future. - - - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae, - text_encoder, - tokenizer, - unet, - scheduler, - safety_checker, - feature_extractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - logger.info( - f"{self.__class__} is an experimntal pipeline and is likely to change in the future. We recommend to use" - " this pipeline for fast experimentation / iteration if needed, but advice to rely on existing pipelines" - " as defined in https://huggingface.co/docs/diffusers/api/schedulers#implemented-schedulers for" - " production settings." 
- ) - - # get correct sigmas from LMS - scheduler = LMSDiscreteScheduler.from_config(scheduler.config) - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.register_to_config(requires_safety_checker=requires_safety_checker) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - model = ModelWrapper(unet, scheduler.alphas_cumprod) - if scheduler.prediction_type == "v_prediction": - self.k_diffusion_model = CompVisVDenoiser(model) - else: - self.k_diffusion_model = CompVisDenoiser(model) - - def set_scheduler(self, scheduler_type: str): - library = importlib.import_module("k_diffusion") - sampling = getattr(library, "sampling") - self.sampler = getattr(sampling, scheduler_type) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
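For context on the `set_scheduler` helper defined above: it simply resolves a sampler by name from `k_diffusion.sampling`, so any of k-diffusion's `sample_*` functions can be selected at runtime. A hedged usage sketch (the checkpoint id and sampler name are illustrative, and the pipeline requires the `k-diffusion` package to be installed):

```py
import torch
from diffusers import StableDiffusionKDiffusionPipeline

# Illustrative checkpoint and sampler name; any function from
# `k_diffusion.sampling` (e.g. "sample_euler", "sample_lms") should resolve here.
pipe = StableDiffusionKDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.set_scheduler("sample_euler")  # resolved via getattr(k_diffusion.sampling, ...)

image = pipe("a photograph of an astronaut riding a horse", num_inference_steps=25).images[0]
```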
- """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. 
- """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def check_inputs(self, prompt, height, width, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." 
- ) - - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - return latents - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, callback_steps) - - # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = True - if guidance_scale <= 1.0: - raise ValueError("has to use guidance_scale") - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device) - sigmas = self.scheduler.sigmas - sigmas = sigmas.to(prompt_embeds.dtype) - - # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - latents = latents * sigmas[0] - self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) - self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device) - - # 6. 
Define model function - def model_fn(x, t): - latent_model_input = torch.cat([x] * 2) - t = torch.cat([t] * 2) - - noise_pred = self.k_diffusion_model(latent_model_input, t, cond=prompt_embeds) - - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - return noise_pred - - # 7. Run k-diffusion solver - latents = self.sampler(model_fn, latents, sigmas) - - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py deleted file mode 100644 index 624d0e62582804434ad8b04a52728e63b0fce348..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ /dev/null @@ -1,518 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL -import torch -import torch.nn.functional as F -from transformers import CLIPTextModel, CLIPTokenizer - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import EulerDiscreteScheduler -from ...utils import is_accelerate_available, logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess -def preprocess(image): - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 - - image = [np.array(i.resize((w, h)))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -class StableDiffusionLatentUpscalePipeline(DiffusionPipeline): - r""" - Pipeline to upscale the resolution of Stable Diffusion output images by a factor of 2. - - This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`EulerDiscreteScheduler`]. - """ - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: EulerDiscreteScheduler, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - ) - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `list(int)`): - prompt to be encoded - device: (`torch.device`): - torch device - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). 
- """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_length=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - text_encoder_out = self.text_encoder( - text_input_ids.to(device), - output_hidden_states=True, - ) - text_embeddings = text_encoder_out.hidden_states[-1] - text_pooler_out = text_encoder_out.pooler_output - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_length=True, - return_tensors="pt", - ) - - uncond_encoder_out = self.text_encoder( - uncond_input.input_ids.to(device), - output_hidden_states=True, - ) - - uncond_embeddings = uncond_encoder_out.hidden_states[-1] - uncond_pooler_out = uncond_encoder_out.pooler_output - - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out]) - - return text_embeddings, text_pooler_out - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def check_inputs(self, prompt, image, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if ( - not isinstance(image, torch.Tensor) - and not isinstance(image, PIL.Image.Image) - and not isinstance(image, list) - ): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" - ) - - # verify batch size of prompt and image are same if image is a list or tensor - if isinstance(image, list) or isinstance(image, torch.Tensor): - if isinstance(prompt, str): - batch_size = 1 - else: - batch_size = len(prompt) - if isinstance(image, list): - image_batch_size = len(image) - else: - image_batch_size = image.shape[0] if image.ndim == 4 else 1 - if batch_size != image_batch_size: - raise ValueError( - f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." - " Please make sure that passed `prompt` matches the batch size of `image`." - ) - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height, width) - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]], - num_inference_steps: int = 75, - guidance_scale: float = 9.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. 
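The `decode_latents` helper above maps the VAE decoder output from roughly [-1, 1] back to a [0, 1] NHWC float32 array for PIL/numpy consumers. A minimal sketch of just that post-decode normalization, with a random tensor standing in for `self.vae.decode(latents).sample` (the `1 / scaling_factor` rescale and the VAE call are omitted):

```py
import numpy as np
import torch


def vae_sample_to_numpy(decoded: torch.Tensor) -> np.ndarray:
    # Shift/scale from [-1, 1] to [0, 1], clamp, and move channels last for image libraries.
    image = (decoded / 2 + 0.5).clamp(0, 1)
    return image.cpu().permute(0, 2, 3, 1).float().numpy()


decoded = torch.rand(1, 3, 512, 512) * 2 - 1  # stand-in for the VAE decoder output
array = vae_sample_to_numpy(decoded)
assert array.shape == (1, 512, 512, 3) and 0.0 <= array.min() and array.max() <= 1.0
```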
- - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image upscaling. - image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): - `Image`, or tensor representing an image batch which will be upscaled. If it's a tensor, it can be - either a latent output from a stable diffusion model, or an image tensor in the range `[-1, 1]`. It - will be considered a `latent` if `image.shape[1]` is `4`; otherwise, it will be considered to be an - image representation and encoded using this pipeline's `vae` encoder. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Examples: - ```py - >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline - >>> import torch - - - >>> pipeline = StableDiffusionPipeline.from_pretrained( - ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 - ... 
) - >>> pipeline.to("cuda") - - >>> model_id = "stabilityai/sd-x2-latent-upscaler" - >>> upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16) - >>> upscaler.to("cuda") - - >>> prompt = "a photo of an astronaut high resolution, unreal engine, ultra realistic" - >>> generator = torch.manual_seed(33) - - >>> low_res_latents = pipeline(prompt, generator=generator, output_type="latent").images - - >>> with torch.no_grad(): - ... image = pipeline.decode_latents(low_res_latents) - >>> image = pipeline.numpy_to_pil(image)[0] - - >>> image.save("../images/a1.png") - - >>> upscaled_image = upscaler( - ... prompt=prompt, - ... image=low_res_latents, - ... num_inference_steps=20, - ... guidance_scale=0, - ... generator=generator, - ... ).images[0] - - >>> upscaled_image.save("../images/a2.png") - ``` - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - - # 1. Check inputs - self.check_inputs(prompt, image, callback_steps) - - # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) - device = self._execution_device - # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - if guidance_scale == 0: - prompt = [""] * batch_size - - # 3. Encode input prompt - text_embeddings, text_pooler_out = self._encode_prompt( - prompt, device, do_classifier_free_guidance, negative_prompt - ) - - # 4. Preprocess image - image = preprocess(image) - image = image.to(dtype=text_embeddings.dtype, device=device) - if image.shape[1] == 3: - # encode image if not in latent-space yet - image = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor - - # 5. Set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - batch_multiplier = 2 if do_classifier_free_guidance else 1 - image = image[None, :] if image.ndim == 3 else image - image = torch.cat([image] * batch_multiplier) - - # 5. Add noise to image (noise level is set to 0): - # (see the note from the author below): - # "This step theoretically can make the model work better on out-of-distribution inputs, but mostly just seems to make it match the input less, so it's turned off by default." - noise_level = torch.tensor([0.0], dtype=torch.float32, device=device) - noise_level = torch.cat([noise_level] * image.shape[0]) - inv_noise_level = (noise_level**2 + 1) ** (-0.5) - - image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None] - image_cond = image_cond.to(text_embeddings.dtype) - - noise_level_embed = torch.cat( - [ - torch.ones(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device), - torch.zeros(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device), - ], - dim=1, - ) - - timestep_condition = torch.cat([noise_level_embed, text_pooler_out], dim=1) - - # 6. 
Prepare latent variables - height, width = image.shape[2:] - num_channels_latents = self.vae.config.latent_channels - latents = self.prepare_latents( - batch_size, - num_channels_latents, - height * 2, # 2x upscale - width * 2, - text_embeddings.dtype, - device, - generator, - latents, - ) - - # 7. Check that sizes of image and latents match - num_channels_image = image.shape[1] - if num_channels_latents + num_channels_image != self.unet.config.in_channels: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" - f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_image`: {num_channels_image} " - f" = {num_channels_latents+num_channels_image}. Please verify the config of" - " `pipeline.unet` or your `image` input." - ) - - # 9. Denoising loop - num_warmup_steps = 0 - - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - sigma = self.scheduler.sigmas[i] - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - scaled_model_input = torch.cat([scaled_model_input, image_cond], dim=1) - # preconditioning parameter based on Karras et al. (2022) (table 1) - timestep = torch.log(sigma) * 0.25 - - noise_pred = self.unet( - scaled_model_input, - timestep, - encoder_hidden_states=text_embeddings, - timestep_cond=timestep_condition, - ).sample - - # in original repo, the output contains a variance channel that's not used - noise_pred = noise_pred[:, :-1] - - # apply preconditioning, based on table 1 in Karras et al. (2022) - inv_sigma = 1 / (sigma**2 + 1) - noise_pred = inv_sigma * latent_model_input + self.scheduler.scale_model_input(sigma, t) * noise_pred - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 10. Post-processing - image = self.decode_latents(latents) - - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py deleted file mode 100644 index b1f29fbef12bcc66904bc3376ecd75860b7f50e5..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ /dev/null @@ -1,664 +0,0 @@ -# Copyright 2023 MultiDiffusion Authors and The HuggingFace Team. All rights reserved." -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Callable, Dict, List, Optional, Union - -import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDIMScheduler, PNDMScheduler -from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler - - >>> model_ckpt = "stabilityai/stable-diffusion-2-base" - >>> scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler") - >>> pipe = StableDiffusionPanoramaPipeline.from_pretrained( - ... model_ckpt, scheduler=scheduler, torch_dtype=torch.float16 - ... ) - - >>> pipe = pipe.to("cuda") - - >>> prompt = "a photo of the dolomites" - >>> image = pipe(prompt).images[0] - ``` -""" - - -class StableDiffusionPanoramaPipeline(DiffusionPipeline): - r""" - Pipeline for text-to-image generation using "MultiDiffusion: Fusing Diffusion Paths for Controlled Image - Generation". - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.). - - To generate panorama-like images, be sure to pass the `width` parameter accordingly when using the pipeline. Our - recommendation for the `width` value is 2048. This is the default value of the `width` parameter for this pipeline. - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. The original work - on Multi Diffsion used the [`DDIMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. 
- feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: DDIMScheduler, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if isinstance(scheduler, PNDMScheduler): - logger.error("PNDMScheduler for this pipeline is currently not supported.") - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. 
- """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. 
- """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
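The `_encode_prompt` method above duplicates embeddings per prompt with a `repeat` followed by `view` rather than `repeat_interleave`, which the comments describe as the mps-friendly route. A small sketch with made-up dimensions, showing that the reshape yields one contiguous block of copies per prompt:

```py
import torch

batch_size, seq_len, dim, num_images_per_prompt = 2, 77, 768, 3
prompt_embeds = torch.randn(batch_size, seq_len, dim)

# Same duplication pattern as in `_encode_prompt`: repeat along the sequence axis, then fold into the batch axis.
duplicated = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, seq_len, -1)

assert duplicated.shape == (6, 77, 768)
assert torch.equal(duplicated[0], duplicated[2])  # rows 0-2 are copies of the first prompt's embeddings
assert torch.equal(duplicated[3], duplicated[5])  # rows 3-5 are copies of the second prompt's embeddings
```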
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs - def check_inputs( - self, - prompt, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def get_views(self, panorama_height, panorama_width, window_size=64, stride=8): - # Here, we define the mappings F_i (see Eq. 
7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113) - panorama_height /= 8 - panorama_width /= 8 - num_blocks_height = (panorama_height - window_size) // stride + 1 - num_blocks_width = (panorama_width - window_size) // stride + 1 - total_num_blocks = int(num_blocks_height * num_blocks_width) - views = [] - for i in range(total_num_blocks): - h_start = int((i // num_blocks_width) * stride) - h_end = h_start + window_size - w_start = int((i % num_blocks_width) * stride) - w_end = w_start + window_size - views.append((h_start, h_end, w_start, w_end)) - return views - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - height: Optional[int] = 512, - width: Optional[int] = 2048, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - height (`int`, *optional*, defaults to 512: - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 2048): - The width in pixels of the generated image. The width is kept to a high number because the - pipeline is supposed to be used for generating panorama-like images. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. 
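The `get_views` method above lays a sliding 64x64 window (in latent space, after the divide-by-8) over the panorama with stride 8; each window is denoised separately in the loop further down and the overlapping results are averaged. A standalone sketch of the same mapping, worked through for the default 512x2048 panorama (integer division is used here since the default sizes are multiples of 8):

```py
def get_views(panorama_height, panorama_width, window_size=64, stride=8):
    # Pixel sizes are divided by 8 to move into latent space, then a window grid is laid out.
    panorama_height //= 8
    panorama_width //= 8
    num_blocks_height = (panorama_height - window_size) // stride + 1
    num_blocks_width = (panorama_width - window_size) // stride + 1
    views = []
    for i in range(int(num_blocks_height * num_blocks_width)):
        h_start = (i // num_blocks_width) * stride
        w_start = (i % num_blocks_width) * stride
        views.append((h_start, h_start + window_size, w_start, w_start + window_size))
    return views


views = get_views(512, 2048)
assert len(views) == 25                  # one row of 25 overlapping 64x64 latent windows
assert views[0] == (0, 64, 0, 64)
assert views[-1] == (0, 64, 192, 256)    # the last window ends at 2048 / 8 = 256
```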
- latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. 
Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 6. Define panorama grid and initialize views for synthesis. - views = self.get_views(height, width) - count = torch.zeros_like(latents) - value = torch.zeros_like(latents) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. Denoising loop - # Each denoising step also includes refinement of the latents with respect to the - # views. - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - count.zero_() - value.zero_() - - # generate views - # Here, we iterate through different spatial crops of the latents and denoise them. These - # denoised (latent) crops are then averaged to produce the final latent - # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the - # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113 - for h_start, h_end, w_start, w_end in views: - # get the latents corresponding to the current view coordinates - latents_for_view = latents[:, :, h_start:h_end, w_start:w_end] - - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents_view_denoised = self.scheduler.step( - noise_pred, t, latents_for_view, **extra_step_kwargs - ).prev_sample - value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised - count[:, :, h_start:h_end, w_start:w_end] += 1 - - # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 - latents = torch.where(count > 0, value / count, value) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 10. 
Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py deleted file mode 100644 index c78d327f69a9729b732e02f600d5b0954d7b7ac0..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ /dev/null @@ -1,1228 +0,0 @@ -# Copyright 2023 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Union - -import numpy as np -import PIL -import torch -import torch.nn.functional as F -from transformers import ( - BlipForConditionalGeneration, - BlipProcessor, - CLIPFeatureExtractor, - CLIPTextModel, - CLIPTokenizer, -) - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.cross_attention import CrossAttention -from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler -from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler -from ...utils import ( - PIL_INTERPOLATION, - BaseOutput, - is_accelerate_available, - is_accelerate_version, - logging, - randn_tensor, - replace_example_docstring, -) -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class Pix2PixInversionPipelineOutput(BaseOutput): - """ - Output class for Stable Diffusion pipelines. - - Args: - latents (`torch.FloatTensor`) - inverted latents tensor - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. - """ - - latents: torch.FloatTensor - images: Union[List[PIL.Image.Image], np.ndarray] - - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import requests - >>> import torch - - >>> from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline - - - >>> def download(embedding_url, local_filepath): - ... r = requests.get(embedding_url) - ... with open(local_filepath, "wb") as f: - ... 
f.write(r.content) - - - >>> model_ckpt = "CompVis/stable-diffusion-v1-4" - >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16) - >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) - >>> pipeline.to("cuda") - - >>> prompt = "a high resolution painting of a cat in the style of van gough" - >>> source_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt" - >>> target_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt" - - >>> for url in [source_emb_url, target_emb_url]: - ... download(url, url.split("/")[-1]) - - >>> src_embeds = torch.load(source_emb_url.split("/")[-1]) - >>> target_embeds = torch.load(target_emb_url.split("/")[-1]) - >>> images = pipeline( - ... prompt, - ... source_embeds=src_embeds, - ... target_embeds=target_embeds, - ... num_inference_steps=50, - ... cross_attention_guidance_amount=0.15, - ... ).images - - >>> images[0].save("edited_image_dog.png") - ``` -""" - -EXAMPLE_INVERT_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from transformers import BlipForConditionalGeneration, BlipProcessor - >>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline - - >>> import requests - >>> from PIL import Image - - >>> captioner_id = "Salesforce/blip-image-captioning-base" - >>> processor = BlipProcessor.from_pretrained(captioner_id) - >>> model = BlipForConditionalGeneration.from_pretrained( - ... captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True - ... ) - - >>> sd_model_ckpt = "CompVis/stable-diffusion-v1-4" - >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( - ... sd_model_ckpt, - ... caption_generator=model, - ... caption_processor=processor, - ... torch_dtype=torch.float16, - ... safety_checker=None, - ... ) - - >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) - >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) - >>> pipeline.enable_model_cpu_offload() - - >>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" - - >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) - >>> # generate caption - >>> caption = pipeline.generate_caption(raw_image) - - >>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii" - >>> inv_latents = pipeline.invert(caption, image=raw_image).latents - >>> # we need to generate source and target embeds - - >>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] - - >>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] - - >>> source_embeds = pipeline.get_embeds(source_prompts) - >>> target_embeds = pipeline.get_embeds(target_prompts) - >>> # the latents can then be used to edit a real image - - >>> image = pipeline( - ... caption, - ... source_embeds=source_embeds, - ... target_embeds=target_embeds, - ... num_inference_steps=50, - ... cross_attention_guidance_amount=0.15, - ... generator=generator, - ... latents=inv_latents, - ... negative_prompt=caption, - ... 
).images[0] - >>> image.save("edited_image.png") - ``` -""" - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess -def preprocess(image): - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 - - image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -def prepare_unet(unet: UNet2DConditionModel): - """Modifies the UNet (`unet`) to perform Pix2Pix Zero optimizations.""" - pix2pix_zero_attn_procs = {} - for name in unet.attn_processors.keys(): - module_name = name.replace(".processor", "") - module = unet.get_submodule(module_name) - if "attn2" in name: - pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=True) - module.requires_grad_(True) - else: - pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=False) - module.requires_grad_(False) - - unet.set_attn_processor(pix2pix_zero_attn_procs) - return unet - - -class Pix2PixZeroL2Loss: - def __init__(self): - self.loss = 0.0 - - def compute_loss(self, predictions, targets): - self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0) - - -class Pix2PixZeroCrossAttnProcessor: - """An attention processor class to store the attention weights. - In Pix2Pix Zero, it happens during computations in the cross-attention blocks.""" - - def __init__(self, is_pix2pix_zero=False): - self.is_pix2pix_zero = is_pix2pix_zero - if self.is_pix2pix_zero: - self.reference_cross_attn_map = {} - - def __call__( - self, - attn: CrossAttention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - timestep=None, - loss=None, - ): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - if self.is_pix2pix_zero and timestep is not None: - # new bookkeeping to save the attention weights. 
- if loss is None: - self.reference_cross_attn_map[timestep.item()] = attention_probs.detach().cpu() - # compute loss - elif loss is not None: - prev_attn_probs = self.reference_cross_attn_map.pop(timestep.item()) - loss.compute_loss(attention_probs, prev_attn_probs.to(attention_probs.device)) - - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): - r""" - Pipeline for pixel-levl image editing using Pix2Pix Zero. Based on Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], or [`DDPMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - requires_safety_checker (bool): - Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the - pipeline publicly. - """ - _optional_components = [ - "safety_checker", - "feature_extractor", - "caption_generator", - "caption_processor", - "inverse_scheduler", - ] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler], - feature_extractor: CLIPFeatureExtractor, - safety_checker: StableDiffusionSafetyChecker, - inverse_scheduler: DDIMInverseScheduler, - caption_generator: BlipForConditionalGeneration, - caption_processor: BlipProcessor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. 
Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - caption_processor=caption_processor, - caption_generator=caption_generator, - inverse_scheduler=inverse_scheduler, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
- """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - hook = None - for cpu_offloaded_model in [self.vae, self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. 
- """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
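The signature check that follows is what lets one pipeline serve many schedulers: `eta` and `generator` are only forwarded when `scheduler.step` actually declares them. A quick way to see this for yourself (the two schedulers below are picked purely for illustration):

```py
import inspect

from diffusers import DDIMScheduler, PNDMScheduler

for scheduler_cls in (DDIMScheduler, PNDMScheduler):
    step_params = set(inspect.signature(scheduler_cls.step).parameters)
    # DDIM declares both `eta` and `generator`; PNDM declares neither,
    # so the pipeline silently drops them for it.
    print(scheduler_cls.__name__, "eta" in step_params, "generator" in step_params)
```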
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - image, - source_embeds, - target_embeds, - callback_steps, - prompt_embeds=None, - ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - if source_embeds is None and target_embeds is None: - raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.") - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - @torch.no_grad() - def generate_caption(self, images): - """Generates caption for a given image.""" - text = "a photography of" - - prev_device = self.caption_generator.device - - device = self._execution_device - inputs = self.caption_processor(images, text, return_tensors="pt").to( - device=device, dtype=self.caption_generator.dtype - ) - self.caption_generator.to(device) - outputs = self.caption_generator.generate(**inputs, max_new_tokens=128) - - # offload caption generator - self.caption_generator.to(prev_device) - - caption = self.caption_processor.batch_decode(outputs, skip_special_tokens=True)[0] - return caption - - def construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor): - """Constructs the edit direction to steer the image generation process semantically.""" - return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) - - @torch.no_grad() - def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.FloatTensor: - num_prompts = len(prompt) - embeds = [] - for i in range(0, num_prompts, batch_size): - prompt_slice = prompt[i : i + batch_size] - - input_ids = self.tokenizer( - prompt_slice, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ).input_ids - - input_ids = input_ids.to(self.text_encoder.device) - embeds.append(self.text_encoder(input_ids)[0]) - - return torch.cat(embeds, dim=0).mean(0)[None] - - def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): - if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" - ) - - image = image.to(device=device, dtype=dtype) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) - else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) - - init_latents = self.vae.config.scaling_factor * init_latents - - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
- ) - else: - init_latents = torch.cat([init_latents], dim=0) - - latents = init_latents - - return latents - - def auto_corr_loss(self, hidden_states, generator=None): - batch_size, channel, height, width = hidden_states.shape - if batch_size > 1: - raise ValueError("Only batch_size 1 is supported for now") - - hidden_states = hidden_states.squeeze(0) - # hidden_states must be shape [C,H,W] now - reg_loss = 0.0 - for i in range(hidden_states.shape[0]): - noise = hidden_states[i][None, None, :, :] - while True: - roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() - reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 - reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 - - if noise.shape[2] <= 8: - break - noise = F.avg_pool2d(noise, kernel_size=2) - return reg_loss - - def kl_divergence(self, hidden_states): - mean = hidden_states.mean() - var = hidden_states.var() - return var + mean**2 - 1 - torch.log(var + 1e-7) - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, - source_embeds: torch.Tensor = None, - target_embeds: torch.Tensor = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - cross_attention_guidance_amount: float = 0.1, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - source_embeds (`torch.Tensor`): - Source concept embeddings. Generation of the embeddings as per the [original - paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. - target_embeds (`torch.Tensor`): - Target concept embeddings. Generation of the embeddings as per the [original - paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - cross_attention_guidance_amount (`float`, defaults to 0.1): - Amount of guidance needed from the reference cross-attention maps. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Define the spatial resolutions. - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. 
Raise error if not correct - self.check_inputs( - prompt, - image, - source_embeds, - target_embeds, - callback_steps, - prompt_embeds, - ) - - # 3. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - if cross_attention_kwargs is None: - cross_attention_kwargs = {} - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Generate the inverted noise from the input image or any other image - # generated from the input prompt. - num_channels_latents = self.unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - latents_init = latents.clone() - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. Rejig the UNet so that we can obtain the cross-attenion maps and - # use them for guiding the subsequent image generation. - self.unet = prepare_unet(self.unet) - - # 7. Denoising loop where we obtain the cross-attention maps. - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs={"timestep": t}, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 8. Compute the edit directions. - edit_direction = self.construct_direction(source_embeds, target_embeds).to(prompt_embeds.device) - - # 9. Edit the prompt embeddings as per the edit directions discovered. - prompt_embeds_edit = prompt_embeds.clone() - prompt_embeds_edit[1:2] += edit_direction - - # 10. Second denoising loop to generate the edited image. 
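Steps 8–9 above are where the zero-shot edit actually happens: the direction is nothing more than the difference between the mean target and mean source embeddings, added to the conditional half of the classifier-free-guidance embeddings before this second loop re-runs denoising from the same initial latents. A toy sketch of that arithmetic (random tensors stand in for real CLIP text embeddings):

```py
import torch

source_embeds = torch.randn(3, 77, 768)   # e.g. three phrasings of "a cat ..."
target_embeds = torch.randn(3, 77, 768)   # e.g. three phrasings of "a dog ..."

# Same rule as `construct_direction`: difference of the per-concept means.
edit_direction = (target_embeds.mean(0) - source_embeds.mean(0)).unsqueeze(0)

# With classifier-free guidance the embeddings are stacked as [uncond, cond];
# only the conditional half is shifted, exactly as in step 9 above.
prompt_embeds = torch.randn(2, 77, 768)
prompt_embeds_edit = prompt_embeds.clone()
prompt_embeds_edit[1:2] += edit_direction
```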
- latents = latents_init - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # we want to learn the latent such that it steers the generation - # process towards the edited direction, so make the make initial - # noise learnable - x_in = latent_model_input.detach().clone() - x_in.requires_grad = True - - # optimizer - opt = torch.optim.SGD([x_in], lr=cross_attention_guidance_amount) - - with torch.enable_grad(): - # initialize loss - loss = Pix2PixZeroL2Loss() - - # predict the noise residual - noise_pred = self.unet( - x_in, - t, - encoder_hidden_states=prompt_embeds_edit.detach(), - cross_attention_kwargs={"timestep": t, "loss": loss}, - ).sample - - loss.loss.backward(retain_graph=False) - opt.step() - - # recompute the noise - noise_pred = self.unet( - x_in.detach(), - t, - encoder_hidden_states=prompt_embeds_edit, - cross_attention_kwargs={"timestep": None}, - ).sample - - latents = x_in.detach().chunk(2)[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - # 11. Post-process the latents. - edited_image = self.decode_latents(latents) - - # 12. Run the safety checker. - edited_image, has_nsfw_concept = self.run_safety_checker(edited_image, device, prompt_embeds.dtype) - - # 13. Convert to PIL. - if output_type == "pil": - edited_image = self.numpy_to_pil(edited_image) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (edited_image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=edited_image, nsfw_content_detected=has_nsfw_concept) - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) - def invert( - self, - prompt: Optional[str] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - num_inference_steps: int = 50, - guidance_scale: float = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - cross_attention_guidance_amount: float = 0.1, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - lambda_auto_corr: float = 20.0, - lambda_kl: float = 20.0, - num_reg_steps: int = 5, - num_auto_corr_rolls: int = 5, - ): - r""" - Function used to generate inverted latents given a prompt and image. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. 
- image (`PIL.Image.Image`, *optional*): - `Image`, or tensor representing an image batch which will be used for conditioning. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - cross_attention_guidance_amount (`float`, defaults to 0.1): - Amount of guidance needed from the reference cross-attention maps. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - lambda_auto_corr (`float`, *optional*, defaults to 20.0): - Lambda parameter to control auto correction - lambda_kl (`float`, *optional*, defaults to 20.0): - Lambda parameter to control Kullback–Leibler divergence output - num_reg_steps (`int`, *optional*, defaults to 5): - Number of regularization loss steps - num_auto_corr_rolls (`int`, *optional*, defaults to 5): - Number of auto correction roll steps - - Examples: - - Returns: - [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] or - `tuple`: - [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] if - `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is the inverted - latents tensor and then second is the corresponding decoded image. - """ - # 1. 
Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - if cross_attention_kwargs is None: - cross_attention_kwargs = {} - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Preprocess image - image = preprocess(image) - - # 4. Prepare latent variables - latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator) - - # 5. Encode input prompt - num_images_per_prompt = 1 - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - prompt_embeds=prompt_embeds, - ) - - # 4. Prepare timesteps - self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.inverse_scheduler.timesteps - - # 6. Rejig the UNet so that we can obtain the cross-attenion maps and - # use them for guiding the subsequent image generation. - self.unet = prepare_unet(self.unet) - - # 7. Denoising loop where we obtain the cross-attention maps. - num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order - with self.progress_bar(total=num_inference_steps - 2) as progress_bar: - for i, t in enumerate(timesteps[1:-1]): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs={"timestep": t}, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # regularization of the noise prediction - with torch.enable_grad(): - for _ in range(num_reg_steps): - if lambda_auto_corr > 0: - for _ in range(num_auto_corr_rolls): - var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) - l_ac = self.auto_corr_loss(var, generator=generator) - l_ac.backward() - - grad = var.grad.detach() / num_auto_corr_rolls - noise_pred = noise_pred - lambda_auto_corr * grad - - if lambda_kl > 0: - var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) - - l_kld = self.kl_divergence(var) - l_kld.backward() - - grad = var.grad.detach() - noise_pred = noise_pred - lambda_kl * grad - - noise_pred = noise_pred.detach() - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 - ): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - inverted_latents = latents.detach().clone() - - # 8. 
Post-processing - image = self.decode_latents(latents.detach()) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - # 9. Convert to PIL. - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (inverted_latents, image) - - return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py deleted file mode 100644 index 0a34499e090ca9d434614b16ed30c97f2875c445..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ /dev/null @@ -1,769 +0,0 @@ -# Copyright 2023 Susung Hong and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import math -from typing import Any, Callable, Dict, List, Optional, Union - -import torch -import torch.nn.functional as F -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import StableDiffusionSAGPipeline - - >>> pipe = StableDiffusionSAGPipeline.from_pretrained( - ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 - ... 
) - >>> pipe = pipe.to("cuda") - - >>> prompt = "a photo of an astronaut riding a horse on mars" - >>> image = pipe(prompt, sag_scale=0.75).images[0] - ``` -""" - - -# processes and stores attention probabilities -class CrossAttnStoreProcessor: - def __init__(self): - self.attention_probs = None - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - ): - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.cross_attention_norm: - encoder_hidden_states = attn.norm_cross(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - self.attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(self.attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -# Modified to get self-attention guidance scale in this paper (https://arxiv.org/pdf/2210.00939.pdf) as an input -class StableDiffusionSAGPipeline(DiffusionPipeline): - r""" - Pipeline for text-to-image generation using Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
- """ - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): - from accelerate import cpu_offload - else: - raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") - - device = torch.device(f"cuda:{gpu_id}") - - if self.device.type != "cpu": - self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. 
- """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. 
- """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
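As in the pix2pix-zero pipeline earlier in this diff, `_encode_prompt` above stacks the unconditional and text embeddings into a single batch so that one UNet forward pass produces both predictions, which `__call__` later splits and recombines. The arithmetic in isolation (toy tensors, illustrative shapes only):

```py
import torch

guidance_scale = 7.5

negative_prompt_embeds = torch.randn(1, 77, 768)
prompt_embeds = torch.randn(1, 77, 768)

# One batched conditioning tensor -> one UNet call instead of two.
embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])   # (2, 77, 768)

# Stand-in for the UNet output on that doubled batch.
noise_pred = torch.randn(2, 4, 64, 64)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```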
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs - def check_inputs( - self, - prompt, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - sag_scale: float = 0.75, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` - instead. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2 of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - sag_scale (`float`, *optional*, defaults to 0.75): - SAG scale as defined in [Improving Sample Quality of Diffusion Models Using Self-Attention Guidance] - (https://arxiv.org/abs/2210.00939). `sag_scale` is defined as `s_s` of equation (24) of SAG paper: - https://arxiv.org/pdf/2210.00939.pdf. Typically chosen between [0, 1.0] for better quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others.
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that, if specified, is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance.
- do_classifier_free_guidance = guidance_scale > 1.0 - # and `sag_scale` is `s` of equation (15) - # of the self-attention guidance paper: https://arxiv.org/pdf/2210.00939.pdf - # `sag_scale = 0` means no self-attention guidance - do_self_attention_guidance = sag_scale > 0.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7. Denoising loop - store_processor = CrossAttnStoreProcessor() - self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # perform self-attention guidance with the stored self-attention map - if do_self_attention_guidance: - # classifier-free guidance produces two chunks of attention map - # and we only use the unconditional one according to equation (24) - # in https://arxiv.org/pdf/2210.00939.pdf - if do_classifier_free_guidance: - # DDIM-like prediction of x0 - pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) - # get the stored attention maps - uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) - # self-attention-based degrading of latents - degraded_latents = self.sag_masking( - pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t) - ) - uncond_emb, _ = prompt_embeds.chunk(2) - # forward and give guidance - degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample - noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) - else: - # DDIM-like prediction of x0 - pred_x0 = self.pred_x0(latents, noise_pred, t) - # get the stored attention maps - cond_attn = store_processor.attention_probs - # self-attention-based degrading of latents - degraded_latents = self.sag_masking( - pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t) - ) - # forward and give guidance - degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample - noise_pred += sag_scale * (noise_pred - degraded_pred) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t,
latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - def sag_masking(self, original_latents, attn_map, t, eps): - # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf - bh, hw1, hw2 = attn_map.shape - b, latent_channel, latent_h, latent_w = original_latents.shape - h = self.unet.attention_head_dim - if isinstance(h, list): - h = h[-1] - map_size = math.isqrt(hw1) - - # Produce attention mask - attn_map = attn_map.reshape(b, h, hw1, hw2) - attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 - attn_mask = ( - attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype) - ) - attn_mask = F.interpolate(attn_mask, (latent_h, latent_w)) - - # Blur according to the self-attention mask - degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) - degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) - - # Noise it again to match the noise level - degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t) - - return degraded_latents - - # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step - # Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.) 
- def pred_x0(self, sample, model_output, timestep): - alpha_prod_t = self.scheduler.alphas_cumprod[timestep] - - beta_prod_t = 1 - alpha_prod_t - if self.scheduler.config.prediction_type == "epsilon": - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif self.scheduler.config.prediction_type == "sample": - pred_original_sample = model_output - elif self.scheduler.config.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - # predict V - model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample - else: - raise ValueError( - f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," - " or `v_prediction`" - ) - - return pred_original_sample - - def pred_epsilon(self, sample, model_output, timestep): - alpha_prod_t = self.scheduler.alphas_cumprod[timestep] - - beta_prod_t = 1 - alpha_prod_t - if self.scheduler.config.prediction_type == "epsilon": - pred_eps = model_output - elif self.scheduler.config.prediction_type == "sample": - pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5) - elif self.scheduler.config.prediction_type == "v_prediction": - pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output - else: - raise ValueError( - f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," - " or `v_prediction`" - ) - - return pred_eps - - -# Gaussian blur -def gaussian_blur_2d(img, kernel_size, sigma): - ksize_half = (kernel_size - 1) * 0.5 - - x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) - - pdf = torch.exp(-0.5 * (x / sigma).pow(2)) - - x_kernel = pdf / pdf.sum() - x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) - - kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) - kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) - - padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] - - img = F.pad(img, padding, mode="reflect") - img = F.conv2d(img, kernel2d, groups=img.shape[-3]) - - return img diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py deleted file mode 100644 index 36eff33c59b037c28661648952b7d78a8ef7e063..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ /dev/null @@ -1,593 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
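The `pred_x0` and `pred_epsilon` helpers above invert the forward relation x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps used by `epsilon`-prediction schedulers. A minimal, self-contained sketch of that round trip follows; the tensor shapes and the 0.7 value for alpha_bar_t are arbitrary assumptions for illustration, not values taken from the pipeline.

```py
import torch

# Forward relation assumed by the helpers above (epsilon prediction):
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
x0 = torch.randn(1, 4, 64, 64)      # arbitrary latent-shaped tensor
eps = torch.randn_like(x0)          # Gaussian noise
alpha_bar_t = torch.tensor(0.7)     # assumed cumulative alpha for some timestep t

x_t = alpha_bar_t**0.5 * x0 + (1 - alpha_bar_t) ** 0.5 * eps

# pred_x0-style inversion: recover x_0 from (x_t, eps)
pred_x0 = (x_t - (1 - alpha_bar_t) ** 0.5 * eps) / alpha_bar_t**0.5
# pred_epsilon-style inversion: recover eps from (x_t, x_0)
pred_eps = (x_t - alpha_bar_t**0.5 * pred_x0) / (1 - alpha_bar_t) ** 0.5

assert torch.allclose(pred_x0, x0, atol=1e-5)
assert torch.allclose(pred_eps, eps, atol=1e-5)
```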
- -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL -import torch -from transformers import CLIPTextModel, CLIPTokenizer - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers -from ...utils import deprecate, is_accelerate_available, logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def preprocess(image): - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 - - image = [np.array(i.resize((w, h)))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -class StableDiffusionUpscalePipeline(DiffusionPipeline): - r""" - Pipeline for text-guided image super-resolution using Stable Diffusion 2. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - low_res_scheduler ([`SchedulerMixin`]): - A scheduler used to add initial noise to the low res conditioning image. It must be an instance of - [`DDPMScheduler`]. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - """ - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - low_res_scheduler: DDPMScheduler, - scheduler: KarrasDiffusionSchedulers, - max_noise_level: int = 350, - ): - super().__init__() - - # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate - is_vae_scaling_factor_set_to_0_08333 = ( - hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333 - ) - if not is_vae_scaling_factor_set_to_0_08333: - deprecation_message = ( - "The configuration file of the vae does not contain `scaling_factor` or it is set to" - f" {vae.config.scaling_factor}, which seems highly unlikely. 
If your checkpoint is a fine-tuned" - " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to 0.08333." - " Please make sure to update the config accordingly, as not doing so might lead to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be" - " very nice if you could open a Pull Request for the `vae/config.json` file" - ) - deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False) - vae.register_to_config(scaling_factor=0.08333) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - low_res_scheduler=low_res_scheduler, - scheduler=scheduler, - ) - self.register_to_config(max_noise_level=max_noise_level) - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet and - text_encoder have their state dicts saved to CPU and then are moved to a - `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - for cpu_offloaded_model in [self.unet, self.text_encoder]: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def check_inputs(self, prompt, image, noise_level, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if ( - not isinstance(image, torch.Tensor) - and not isinstance(image, PIL.Image.Image) - and not isinstance(image, list) - ): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" - ) - - # verify batch size of prompt and image are same if image is a list or tensor - if isinstance(image, list) or isinstance(image, torch.Tensor): - if isinstance(prompt, str): - batch_size = 1 - else: - batch_size = len(prompt) - if isinstance(image, list): - image_batch_size = len(image) - else: - image_batch_size = image.shape[0] - if batch_size 
!= image_batch_size: - raise ValueError( - f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." - " Please make sure that passed `prompt` matches the batch size of `image`." - ) - - # check noise level - if noise_level > self.config.max_noise_level: - raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height, width) - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - num_inference_steps: int = 75, - guidance_scale: float = 9.0, - noise_level: int = 20, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` - instead. - image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): - `Image` or tensor representing an image batch which will be upscaled. - num_inference_steps (`int`, *optional*, defaults to 75): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 9.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2 of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` - is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt.
- eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Examples: - ```py - >>> import requests - >>> from PIL import Image - >>> from io import BytesIO - >>> from diffusers import StableDiffusionUpscalePipeline - >>> import torch - - >>> # load model and scheduler - >>> model_id = "stabilityai/stable-diffusion-x4-upscaler" - >>> pipeline = StableDiffusionUpscalePipeline.from_pretrained( - ... model_id, revision="fp16", torch_dtype=torch.float16 - ... ) - >>> pipeline = pipeline.to("cuda") - - >>> # let's download an image - >>> url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png" - >>> response = requests.get(url) - >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB") - >>> low_res_img = low_res_img.resize((128, 128)) - >>> prompt = "a white cat" - - >>> upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0] - >>> upscaled_image.save("upsampled_cat.png") - ``` - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: - [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is a list with the generated images. - """ - - # 1.
Check inputs - self.check_inputs(prompt, image, noise_level, callback_steps) - - if image is None: - raise ValueError("`image` input cannot be undefined.") - - # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Preprocess image - image = preprocess(image) - image = image.to(dtype=prompt_embeds.dtype, device=device) - - # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Add noise to image - noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) - noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) - image = self.low_res_scheduler.add_noise(image, noise, noise_level) - - batch_multiplier = 2 if do_classifier_free_guidance else 1 - image = torch.cat([image] * batch_multiplier * num_images_per_prompt) - noise_level = torch.cat([noise_level] * image.shape[0]) - - # 6. Prepare latent variables - height, width = image.shape[2:] - num_channels_latents = self.vae.config.latent_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 7. Check that sizes of image and latents match - num_channels_image = image.shape[1] - if num_channels_latents + num_channels_image != self.unet.config.in_channels: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" - f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_image`: {num_channels_image} " - f" = {num_channels_latents+num_channels_image}. Please verify the config of" - " `pipeline.unet` or your `image` input." - ) - - # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 9. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - - # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - latent_model_input = torch.cat([latent_model_input, image], dim=1) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, t, encoder_hidden_states=prompt_embeds, class_labels=noise_level - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 10. Post-processing - # make sure the VAE is in float32 mode, as it overflows in float16 - self.vae.to(dtype=torch.float32) - image = self.decode_latents(latents.float()) - - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py deleted file mode 100644 index 94780c9eb2603b77a2307554618e16767db92a98..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ /dev/null @@ -1,894 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
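The upscale pipeline above concatenates the noised low-resolution image to the latents along the channel dimension before every UNet call, which is why it checks that `num_channels_latents + num_channels_image` equals `self.unet.config.in_channels`. A minimal sketch of that shape bookkeeping follows, assuming a 4-channel latent, a 3-channel RGB conditioning image, and a UNet expecting 7 input channels (illustrative values, not read from a real checkpoint).

```py
import torch

num_channels_latents = 4   # assumed VAE latent channels
num_channels_image = 3     # assumed RGB conditioning-image channels
unet_in_channels = 7       # assumed UNet input channels for this sketch

latents = torch.randn(2, num_channels_latents, 128, 128)
low_res_image = torch.randn(2, num_channels_image, 128, 128)  # already noised to `noise_level`

# the same consistency check the pipeline performs before denoising
assert num_channels_latents + num_channels_image == unet_in_channels

# per-step model input: latents concatenated with the conditioning image along dim=1
model_input = torch.cat([latents, low_res_image], dim=1)
assert model_input.shape == (2, unet_in_channels, 128, 128)
```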
- -import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import torch -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer -from transformers.models.clip.modeling_clip import CLIPTextModelOutput - -from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel -from ...models.embeddings import get_timestep_embedding -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import StableUnCLIPPipeline - - >>> pipe = StableUnCLIPPipeline.from_pretrained( - ... "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16 - ... ) # TODO update model path - >>> pipe = pipe.to("cuda") - - >>> prompt = "a photo of an astronaut riding a horse on mars" - >>> images = pipe(prompt).images - >>> images[0].save("astronaut_horse.png") - ``` -""" - - -class StableUnCLIPPipeline(DiffusionPipeline): - """ - Pipeline for text-to-image generation using stable unCLIP. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - prior_tokenizer ([`CLIPTokenizer`]): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - prior_text_encoder ([`CLIPTextModelWithProjection`]): - Frozen text-encoder. - prior ([`PriorTransformer`]): - The canonincal unCLIP prior to approximate the image embedding from the text embedding. - prior_scheduler ([`KarrasDiffusionSchedulers`]): - Scheduler used in the prior denoising process. - image_normalizer ([`StableUnCLIPImageNormalizer`]): - Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image - embeddings after the noise has been applied. - image_noising_scheduler ([`KarrasDiffusionSchedulers`]): - Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined - by `noise_level` in `StableUnCLIPPipeline.__call__`. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`KarrasDiffusionSchedulers`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
- """ - - # prior components - prior_tokenizer: CLIPTokenizer - prior_text_encoder: CLIPTextModelWithProjection - prior: PriorTransformer - prior_scheduler: KarrasDiffusionSchedulers - - # image noising components - image_normalizer: StableUnCLIPImageNormalizer - image_noising_scheduler: KarrasDiffusionSchedulers - - # regular denoising components - tokenizer: CLIPTokenizer - text_encoder: CLIPTextModel - unet: UNet2DConditionModel - scheduler: KarrasDiffusionSchedulers - - vae: AutoencoderKL - - def __init__( - self, - # prior components - prior_tokenizer: CLIPTokenizer, - prior_text_encoder: CLIPTextModelWithProjection, - prior: PriorTransformer, - prior_scheduler: KarrasDiffusionSchedulers, - # image noising components - image_normalizer: StableUnCLIPImageNormalizer, - image_noising_scheduler: KarrasDiffusionSchedulers, - # regular denoising components - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModelWithProjection, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - # vae - vae: AutoencoderKL, - ): - super().__init__() - - self.register_modules( - prior_tokenizer=prior_tokenizer, - prior_text_encoder=prior_text_encoder, - prior=prior, - prior_scheduler=prior_scheduler, - image_normalizer=image_normalizer, - image_noising_scheduler=image_noising_scheduler, - tokenizer=tokenizer, - text_encoder=text_encoder, - unet=unet, - scheduler=scheduler, - vae=vae, - ) - - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's - models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only - when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - # TODO: self.prior.post_process_latents and self.image_noiser.{scale,unscale} are not covered by the offload hooks, so they fails if added to the list - models = [ - self.prior_text_encoder, - self.text_encoder, - self.unet, - self.vae, - ] - for cpu_offloaded_model in models: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. 
- """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt with _encode_prompt->_encode_prior_prompt, tokenizer->prior_tokenizer, text_encoder->prior_text_encoder - def _encode_prior_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, - text_attention_mask: Optional[torch.Tensor] = None, - ): - if text_model_output is None: - batch_size = len(prompt) if isinstance(prompt, list) else 1 - # get prompt text embeddings - text_inputs = self.prior_tokenizer( - prompt, - padding="max_length", - max_length=self.prior_tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - text_mask = text_inputs.attention_mask.bool().to(device) - - untruncated_ids = self.prior_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.prior_tokenizer.batch_decode( - untruncated_ids[:, self.prior_tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.prior_tokenizer.model_max_length} tokens: {removed_text}" - ) - text_input_ids = text_input_ids[:, : self.prior_tokenizer.model_max_length] - - prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device)) - - prompt_embeds = prior_text_encoder_output.text_embeds - prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state - - else: - batch_size = text_model_output[0].shape[0] - prompt_embeds, prior_text_encoder_hidden_states = text_model_output[0], text_model_output[1] - text_mask = text_attention_mask - - prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) - prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) - - if do_classifier_free_guidance: - uncond_tokens = [""] * batch_size - - uncond_input = self.prior_tokenizer( - uncond_tokens, - padding="max_length", - max_length=self.prior_tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - uncond_text_mask = uncond_input.attention_mask.bool().to(device) - negative_prompt_embeds_prior_text_encoder_output = self.prior_text_encoder( - uncond_input.input_ids.to(device) - ) - - negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds - uncond_prior_text_encoder_hidden_states = ( - negative_prompt_embeds_prior_text_encoder_output.last_hidden_state - ) - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - - seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) - - seq_len = uncond_prior_text_encoder_hidden_states.shape[1] - uncond_prior_text_encoder_hidden_states = 
uncond_prior_text_encoder_hidden_states.repeat( - 1, num_images_per_prompt, 1 - ) - uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) - uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) - - # done duplicates - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - prior_text_encoder_hidden_states = torch.cat( - [uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states] - ) - - text_mask = torch.cat([uncond_text_mask, text_mask]) - - return prompt_embeds, prior_text_encoder_hidden_states, text_mask - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument.
- """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with prepare_extra_step_kwargs->prepare_prior_extra_step_kwargs, scheduler->prior_scheduler - def prepare_prior_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the prior_scheduler step, since not all prior_schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other prior_schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.prior_scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the prior_scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.prior_scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - height, - width, - callback_steps, - noise_level, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two." - ) - - if prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - - if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined." - ) - - if prompt is not None and negative_prompt is not None: - if type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: - raise ValueError( - f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." - ) - - # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents - def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) - - latents = latents * scheduler.init_noise_sigma - return latents - - def noise_image_embeddings( - self, - image_embeds: torch.Tensor, - noise_level: int, - noise: Optional[torch.FloatTensor] = None, - generator: Optional[torch.Generator] = None, - ): - """ - Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. 
A higher - `noise_level` increases the variance in the final un-noised images. - - The noise is applied in two ways - 1. A noise schedule is applied directly to the embeddings - 2. A vector of sinusoidal time embeddings are appended to the output. - - In both cases, the amount of noise is controlled by the same `noise_level`. - - The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. - """ - if noise is None: - noise = randn_tensor( - image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype - ) - - noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) - - image_embeds = self.image_normalizer.scale(image_embeds) - - image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) - - image_embeds = self.image_normalizer.unscale(image_embeds) - - noise_level = get_timestep_embedding( - timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 - ) - - # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, - # but we might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - noise_level = noise_level.to(image_embeds.dtype) - - image_embeds = torch.cat((image_embeds, noise_level), 1) - - return image_embeds - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - # regular denoising process args - prompt: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 20, - guidance_scale: float = 10.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - noise_level: int = 0, - # prior args - prior_num_inference_steps: int = 25, - prior_guidance_scale: float = 4.0, - prior_latents: Optional[torch.FloatTensor] = None, - ): - """ - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 20): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 10.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - noise_level (`int`, *optional*, defaults to `0`): - The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in - the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details. - prior_num_inference_steps (`int`, *optional*, defaults to 25): - The number of denoising steps in the prior denoising process. More denoising steps usually lead to a - higher quality image at the expense of slower inference. 
- prior_guidance_scale (`float`, *optional*, defaults to 4.0): - Guidance scale for the prior denoising process as defined in [Classifier-Free Diffusion - Guidance](https://arxiv.org/abs/2207.12598). `prior_guidance_scale` is defined as `w` of equation 2. of - [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting - `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. - prior_latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - embedding generation in the prior denoising process. Can be used to tweak the same generation with - different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied - random `generator`. - - Examples: - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. - """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt=prompt, - height=height, - width=width, - callback_steps=callback_steps, - noise_level=noise_level, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - batch_size = batch_size * num_images_per_prompt - - device = self._execution_device - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - prior_do_classifier_free_guidance = prior_guidance_scale > 1.0 - - # 3. Encode input prompt - prior_prompt_embeds, prior_text_encoder_hidden_states, prior_text_mask = self._encode_prior_prompt( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=prior_do_classifier_free_guidance, - ) - - # 4. Prepare prior timesteps - self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) - prior_timesteps_tensor = self.prior_scheduler.timesteps - - # 5. Prepare prior latent variables - embedding_dim = self.prior.config.embedding_dim - prior_latents = self.prepare_latents( - (batch_size, embedding_dim), - prior_prompt_embeds.dtype, - device, - generator, - prior_latents, - self.prior_scheduler, - ) - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - prior_extra_step_kwargs = self.prepare_prior_extra_step_kwargs(generator, eta) - - # 7. 
Prior denoising loop - for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([prior_latents] * 2) if prior_do_classifier_free_guidance else prior_latents - latent_model_input = self.prior_scheduler.scale_model_input(latent_model_input, t) - - predicted_image_embedding = self.prior( - latent_model_input, - timestep=t, - proj_embedding=prior_prompt_embeds, - encoder_hidden_states=prior_text_encoder_hidden_states, - attention_mask=prior_text_mask, - ).predicted_image_embedding - - if prior_do_classifier_free_guidance: - predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) - predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( - predicted_image_embedding_text - predicted_image_embedding_uncond - ) - - prior_latents = self.prior_scheduler.step( - predicted_image_embedding, - timestep=t, - sample=prior_latents, - **prior_extra_step_kwargs, - ).prev_sample - - if callback is not None and i % callback_steps == 0: - callback(i, t, prior_latents) - - prior_latents = self.prior.post_process_latents(prior_latents) - - image_embeds = prior_latents - - # done prior - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 8. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 9. Prepare image embeddings - image_embeds = self.noise_image_embeddings( - image_embeds=image_embeds, - noise_level=noise_level, - generator=generator, - ) - - if do_classifier_free_guidance: - negative_prompt_embeds = torch.zeros_like(image_embeds) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) - - # 10. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 11. Prepare latent variables - num_channels_latents = self.unet.in_channels - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - latents = self.prepare_latents( - shape=shape, - dtype=prompt_embeds.dtype, - device=device, - generator=generator, - latents=latents, - scheduler=self.scheduler, - ) - - # 12. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 13. 
Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - class_labels=image_embeds, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 14. Post-processing - image = self.decode_latents(latents) - - # 15. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/my_diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py deleted file mode 100644 index e98dfd6f0d3a0b08c3c0172217ede89681aff3c0..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ /dev/null @@ -1,787 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Any, Callable, Dict, List, Optional, Union - -import PIL -import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection - -from diffusers.utils.import_utils import is_accelerate_available - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...models.embeddings import get_timestep_embedding -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import logging, randn_tensor, replace_example_docstring -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import requests - >>> import torch - >>> from PIL import Image - >>> from io import BytesIO - - >>> from diffusers import StableUnCLIPImg2ImgPipeline - - >>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( - ... "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16 - ... 
) # TODO update model path - >>> pipe = pipe.to("cuda") - - >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" - - >>> response = requests.get(url) - >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") - >>> init_image = init_image.resize((768, 512)) - - >>> prompt = "A fantasy landscape, trending on artstation" - - >>> images = pipe(prompt, init_image).images - >>> images[0].save("fantasy_landscape.png") - ``` -""" - - -class StableUnCLIPImg2ImgPipeline(DiffusionPipeline): - """ - Pipeline for text-guided image to image generation using stable unCLIP. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - feature_extractor ([`CLIPFeatureExtractor`]): - Feature extractor for image pre-processing before being encoded. - image_encoder ([`CLIPVisionModelWithProjection`]): - CLIP vision model for encoding images. - image_normalizer ([`StableUnCLIPImageNormalizer`]): - Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image - embeddings after the noise has been applied. - image_noising_scheduler ([`KarrasDiffusionSchedulers`]): - Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined - by `noise_level` in `StableUnCLIPPipeline.__call__`. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`KarrasDiffusionSchedulers`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
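The example docstring above covers the basic text-guided img2img call. As a hedged extension (not part of the original file), the sketch below exercises two of the documented call arguments, `noise_level` and `num_images_per_prompt`; it reuses the placeholder checkpoint name from the example (`fusing/stable-unclip-2-1-l-img2img`, which the docstring itself flags with a TODO), so treat the model path and the CUDA device as assumptions.

```py
import requests
import torch
from io import BytesIO
from PIL import Image

from diffusers import StableUnCLIPImg2ImgPipeline

# placeholder checkpoint name taken from the example docstring above (marked TODO there)
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16
).to("cuda")

url = (
    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/"
    "stable-samples/img2img/sketch-mountains-input.jpg"
)
init_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB").resize((768, 512))

# noise_level > 0 perturbs the CLIP image embedding before it conditions the UNet,
# loosening the dependence on the input image (see noise_image_embeddings below);
# num_images_per_prompt repeats the embeddings to draw several samples per prompt.
images = pipe(
    "A fantasy landscape, trending on artstation",
    init_image,
    noise_level=100,
    num_images_per_prompt=2,
).images
images[0].save("fantasy_landscape_0.png")
```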
- """ - - # image encoding components - feature_extractor: CLIPFeatureExtractor - image_encoder: CLIPVisionModelWithProjection - - # image noising components - image_normalizer: StableUnCLIPImageNormalizer - image_noising_scheduler: KarrasDiffusionSchedulers - - # regular denoising components - tokenizer: CLIPTokenizer - text_encoder: CLIPTextModel - unet: UNet2DConditionModel - scheduler: KarrasDiffusionSchedulers - - vae: AutoencoderKL - - def __init__( - self, - # image encoding components - feature_extractor: CLIPFeatureExtractor, - image_encoder: CLIPVisionModelWithProjection, - # image noising components - image_normalizer: StableUnCLIPImageNormalizer, - image_noising_scheduler: KarrasDiffusionSchedulers, - # regular denoising components - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - # vae - vae: AutoencoderKL, - ): - super().__init__() - - self.register_modules( - feature_extractor=feature_extractor, - image_encoder=image_encoder, - image_normalizer=image_normalizer, - image_noising_scheduler=image_noising_scheduler, - tokenizer=tokenizer, - text_encoder=text_encoder, - unet=unet, - scheduler=scheduler, - vae=vae, - ) - - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's - models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only - when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - # TODO: self.image_normalizer.{scale,unscale} are not covered by the offload hooks, so they fails if added to the list - models = [ - self.image_encoder, - self.text_encoder, - self.unet, - self.vae, - ] - for cpu_offloaded_model in models: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. 
- """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. 
- """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - def _encode_image( - self, - image, - device, - batch_size, - num_images_per_prompt, - do_classifier_free_guidance, - noise_level, - generator, - image_embeds, - ): - dtype = next(self.image_encoder.parameters()).dtype - - if isinstance(image, PIL.Image.Image): - # the image embedding should repeated so it matches the total batch size of the prompt - repeat_by = batch_size - else: - # assume the image input is already properly batched and just needs to be repeated so - # it matches the num_images_per_prompt. - # - # NOTE(will) this is probably missing a few number of side cases. I.e. batched/non-batched - # `image_embeds`. If those happen to be common use cases, let's think harder about - # what the expected dimensions of inputs should be and how we handle the encoding. - repeat_by = num_images_per_prompt - - if not image_embeds: - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(images=image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - - image_embeds = self.noise_image_embeddings( - image_embeds=image_embeds, - noise_level=noise_level, - generator=generator, - ) - - # duplicate image embeddings for each generation per prompt, using mps friendly method - image_embeds = image_embeds.unsqueeze(1) - bs_embed, seq_len, _ = image_embeds.shape - image_embeds = image_embeds.repeat(1, repeat_by, 1) - image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1) - image_embeds = image_embeds.squeeze(1) - - if do_classifier_free_guidance: - negative_prompt_embeds = torch.zeros_like(image_embeds) - - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) - - return image_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - image, - height, - width, - callback_steps, - noise_level, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - image_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two." - ) - - if prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - - if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined." - ) - - if prompt is not None and negative_prompt is not None: - if type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." 
- ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: - raise ValueError( - f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." - ) - - if image is not None and image_embeds is not None: - raise ValueError( - "Provide either `image` or `image_embeds`. Please make sure to define only one of the two." - ) - - if image is None and image_embeds is None: - raise ValueError( - "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined." - ) - - if image is not None: - if ( - not isinstance(image, torch.Tensor) - and not isinstance(image, PIL.Image.Image) - and not isinstance(image, list) - ): - raise ValueError( - "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" - f" {type(image)}" - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings - def noise_image_embeddings( - self, - image_embeds: torch.Tensor, - noise_level: int, - noise: Optional[torch.FloatTensor] = None, - generator: Optional[torch.Generator] = None, - ): - """ - Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher - `noise_level` increases the variance in the final un-noised images. - - The noise is applied in two ways - 1. A noise schedule is applied directly to the embeddings - 2. A vector of sinusoidal time embeddings are appended to the output. - - In both cases, the amount of noise is controlled by the same `noise_level`. - - The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. 
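The docstring above names the two ways noise enters the CLIP image embedding. The sketch below replays the same flow outside the pipeline: a plain mean/std pair stands in for the `StableUnCLIPImageNormalizer` and a fresh `DDPMScheduler` stands in for the configured `image_noising_scheduler`, so treat both as assumptions; `get_timestep_embedding` is imported from `diffusers.models.embeddings`, matching the relative import used in this file.

```py
import torch
from diffusers import DDPMScheduler
from diffusers.models.embeddings import get_timestep_embedding

batch, dim = 2, 768
image_embeds = torch.randn(batch, dim)                # dummy CLIP image embeddings
mean, std = torch.zeros(1, dim), torch.ones(1, dim)   # stand-in for the normalizer's statistics

noise_level = 100                                      # must stay below num_train_timesteps
scheduler = DDPMScheduler(num_train_timesteps=1000)    # stand-in image_noising_scheduler
timesteps = torch.tensor([noise_level] * batch)

# 1. normalize, add scheduled noise, un-normalize
normed = (image_embeds - mean) / std
noised = scheduler.add_noise(normed, noise=torch.randn_like(normed), timesteps=timesteps)
noised = noised * std + mean

# 2. append a sinusoidal embedding of the noise level itself
level_emb = get_timestep_embedding(
    timesteps, embedding_dim=dim, flip_sin_to_cos=True, downscale_freq_shift=0
)
conditioned = torch.cat([noised, level_emb.to(noised.dtype)], dim=1)
print(conditioned.shape)   # torch.Size([2, 1536]) -- the last dimension doubles
```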
- """ - if noise is None: - noise = randn_tensor( - image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype - ) - - noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) - - image_embeds = self.image_normalizer.scale(image_embeds) - - image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) - - image_embeds = self.image_normalizer.unscale(image_embeds) - - noise_level = get_timestep_embedding( - timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 - ) - - # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, - # but we might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - noise_level = noise_level.to(image_embeds.dtype) - - image_embeds = torch.cat((image_embeds, noise_level), 1) - - return image_embeds - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 20, - guidance_scale: float = 10, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - noise_level: int = 0, - image_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`torch.FloatTensor` or `PIL.Image.Image`): - `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which - the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the - latents in the denoising process such as in the standard stable diffusion text guided image variation - process. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 20): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 10.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. 
- negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - noise_level (`int`, *optional*, defaults to `0`): - The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in - the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details. - image_embeds (`torch.FloatTensor`, *optional*): - Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in - the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as - `latents`. - - Examples: - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. - """ - # 0. 
Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt=prompt, - image=image, - height=height, - width=width, - callback_steps=callback_steps, - noise_level=noise_level, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - image_embeds=image_embeds, - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - batch_size = batch_size * num_images_per_prompt - - device = self._execution_device - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Encoder input image - noise_level = torch.tensor([noise_level], device=device) - image_embeds = self._encode_image( - image=image, - device=device, - batch_size=batch_size, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - noise_level=noise_level, - generator=generator, - image_embeds=image_embeds, - ) - - # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 6. Prepare latent variables - num_channels_latents = self.unet.in_channels - latents = self.prepare_latents( - batch_size=batch_size, - num_channels_latents=num_channels_latents, - height=height, - width=width, - dtype=prompt_embeds.dtype, - device=device, - generator=generator, - latents=latents, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - class_labels=image_embeds, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 9. Post-processing - image = self.decode_latents(latents) - - # 10. 
Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/stable_diffusion/safety_checker.py b/my_diffusers/pipelines/stable_diffusion/safety_checker.py deleted file mode 100644 index 04ceb5a3517bc43c9ed81b78a8c1db161d2b2d8c..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/safety_checker.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import torch -import torch.nn as nn -from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel - -from ...utils import logging - - -logger = logging.get_logger(__name__) - - -def cosine_distance(image_embeds, text_embeds): - normalized_image_embeds = nn.functional.normalize(image_embeds) - normalized_text_embeds = nn.functional.normalize(text_embeds) - return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) - - -class StableDiffusionSafetyChecker(PreTrainedModel): - config_class = CLIPConfig - - _no_split_modules = ["CLIPEncoderLayer"] - - def __init__(self, config: CLIPConfig): - super().__init__(config) - - self.vision_model = CLIPVisionModel(config.vision_config) - self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) - - self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) - self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) - - self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) - self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) - - @torch.no_grad() - def forward(self, clip_input, images): - pooled_output = self.vision_model(clip_input)[1] # pooled_output - image_embeds = self.visual_projection(pooled_output) - - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() - cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() - - result = [] - batch_size = image_embeds.shape[0] - for i in range(batch_size): - result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} - - # increase this value to create a stronger `nfsw` filter - # at the cost of increasing the possibility of filtering benign images - adjustment = 0.0 - - for concept_idx in range(len(special_cos_dist[0])): - concept_cos = special_cos_dist[i][concept_idx] - concept_threshold = self.special_care_embeds_weights[concept_idx].item() - result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["special_scores"][concept_idx] > 0: - result_img["special_care"].append({concept_idx, 
result_img["special_scores"][concept_idx]}) - adjustment = 0.01 - - for concept_idx in range(len(cos_dist[0])): - concept_cos = cos_dist[i][concept_idx] - concept_threshold = self.concept_embeds_weights[concept_idx].item() - result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["concept_scores"][concept_idx] > 0: - result_img["bad_concepts"].append(concept_idx) - - result.append(result_img) - - has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] - - for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): - if has_nsfw_concept and False: - images[idx] = np.zeros(images[idx].shape) # black image - - if any(has_nsfw_concepts): - logger.warning( - "Potential NSFW content was detected in one or more images. A black image 4 will be returned instead." - " Try again with a different prompt and/or seed." - ) - - return images, has_nsfw_concepts - - @torch.no_grad() - def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): - pooled_output = self.vision_model(clip_input)[1] # pooled_output - image_embeds = self.visual_projection(pooled_output) - - special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) - cos_dist = cosine_distance(image_embeds, self.concept_embeds) - - # increase this value to create a stronger `nsfw` filter - # at the cost of increasing the possibility of filtering benign images - adjustment = 0.0 - - special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment - # special_scores = special_scores.round(decimals=3) - special_care = torch.any(special_scores > 0, dim=1) - special_adjustment = special_care * 0.01 - special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) - - concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment - # concept_scores = concept_scores.round(decimals=3) - has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) - - images[has_nsfw_concepts] = 0.0 # black image - - return images, has_nsfw_concepts diff --git a/my_diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/my_diffusers/pipelines/stable_diffusion/safety_checker_flax.py deleted file mode 100644 index 3a8c3167954016b3b89f16caf8348661cd3a27ef..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Optional, Tuple - -import jax -import jax.numpy as jnp -from flax import linen as nn -from flax.core.frozen_dict import FrozenDict -from transformers import CLIPConfig, FlaxPreTrainedModel -from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule - - -def jax_cosine_distance(emb_1, emb_2, eps=1e-12): - norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T - norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T - return jnp.matmul(norm_emb_1, norm_emb_2.T) - - -class FlaxStableDiffusionSafetyCheckerModule(nn.Module): - config: CLIPConfig - dtype: jnp.dtype = jnp.float32 - - def setup(self): - self.vision_model = FlaxCLIPVisionModule(self.config.vision_config) - self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype) - - self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim)) - self.special_care_embeds = self.param( - "special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim) - ) - - self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,)) - self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,)) - - def __call__(self, clip_input): - pooled_output = self.vision_model(clip_input)[1] - image_embeds = self.visual_projection(pooled_output) - - special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds) - cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds) - - # increase this value to create a stronger `nfsw` filter - # at the cost of increasing the possibility of filtering benign image inputs - adjustment = 0.0 - - special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment - special_scores = jnp.round(special_scores, 3) - is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True) - # Use a lower threshold if an image has any special care concept - special_adjustment = is_special_care * 0.01 - - concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment - concept_scores = jnp.round(concept_scores, 3) - has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1) - - return has_nsfw_concepts - - -class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): - config_class = CLIPConfig - main_input_name = "clip_input" - module_class = FlaxStableDiffusionSafetyCheckerModule - - def __init__( - self, - config: CLIPConfig, - input_shape: Optional[Tuple] = None, - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - if input_shape is None: - input_shape = (1, 224, 224, 3) - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - - def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: - # init input tensor - clip_input = jax.random.normal(rng, input_shape) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - random_params = self.module.init(rngs, clip_input)["params"] - - return random_params - - def __call__( - self, - clip_input, - params: dict = None, - ): - clip_input = jnp.transpose(clip_input, (0, 2, 3, 1)) - - return self.module.apply( - {"params": params or self.params}, - jnp.array(clip_input, dtype=jnp.float32), - 
rngs={}, - ) diff --git a/my_diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py b/my_diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py deleted file mode 100644 index 9c7f190d0505ba8a1427f48438f7a04ffd611eed..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from torch import nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.modeling_utils import ModelMixin - - -class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin): - """ - This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP. - - It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image - embeddings. - """ - - @register_to_config - def __init__( - self, - embedding_dim: int = 768, - ): - super().__init__() - - self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) - self.std = nn.Parameter(torch.ones(1, embedding_dim)) - - def scale(self, embeds): - embeds = (embeds - self.mean) * 1.0 / self.std - return embeds - - def unscale(self, embeds): - embeds = (embeds * self.std) + self.mean - return embeds diff --git a/my_diffusers/pipelines/stable_diffusion_safe/__init__.py b/my_diffusers/pipelines/stable_diffusion_safe/__init__.py deleted file mode 100644 index 5aecfeac112e53b2fc49278c1acaa95a6c0c7257..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion_safe/__init__.py +++ /dev/null @@ -1,71 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from typing import List, Optional, Union - -import numpy as np -import PIL -from PIL import Image - -from ...utils import BaseOutput, is_torch_available, is_transformers_available - - -@dataclass -class SafetyConfig(object): - WEAK = { - "sld_warmup_steps": 15, - "sld_guidance_scale": 20, - "sld_threshold": 0.0, - "sld_momentum_scale": 0.0, - "sld_mom_beta": 0.0, - } - MEDIUM = { - "sld_warmup_steps": 10, - "sld_guidance_scale": 1000, - "sld_threshold": 0.01, - "sld_momentum_scale": 0.3, - "sld_mom_beta": 0.4, - } - STRONG = { - "sld_warmup_steps": 7, - "sld_guidance_scale": 2000, - "sld_threshold": 0.025, - "sld_momentum_scale": 0.5, - "sld_mom_beta": 0.7, - } - MAX = { - "sld_warmup_steps": 0, - "sld_guidance_scale": 5000, - "sld_threshold": 1.0, - "sld_momentum_scale": 0.5, - "sld_mom_beta": 0.7, - } - - -@dataclass -class StableDiffusionSafePipelineOutput(BaseOutput): - """ - Output class for Safe Stable Diffusion pipelines. - - Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
- nsfw_content_detected (`List[bool]`) - List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, or `None` if safety checking could not be performed. - unsafe_images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images that were flagged by the safety checker and may contain "not-safe-for-work" - (nsfw) content, or `None` if no safety check was performed or no images were flagged. - applied_safety_concept (`str`) - The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled - """ - - images: Union[List[PIL.Image.Image], np.ndarray] - nsfw_content_detected: Optional[List[bool]] - unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]] - applied_safety_concept: Optional[str] - - -if is_transformers_available() and is_torch_available(): - from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe - from .safety_checker import SafeStableDiffusionSafetyChecker diff --git a/my_diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 4e0637f9d6cb544374abe554ee8ba632c2e1a029..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-311.pyc deleted file mode 100644 index 6ee49fca665cf1877f75a892ded1a549537a1522..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-311.pyc b/my_diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-311.pyc deleted file mode 100644 index 36832fdaf0471c9de73b4c4bbece4e2e2b0a759b..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/my_diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py deleted file mode 100644 index ae66171d7573b2fb29acf9ebac6539732929806d..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ /dev/null @@ -1,736 +0,0 @@ -import inspect -import warnings -from typing import Callable, List, Optional, Union - -import numpy as np -import torch -from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - -from ...configuration_utils import FrozenDict -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import deprecate, is_accelerate_available, logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline -from .
import StableDiffusionSafePipelineOutput -from .safety_checker import SafeStableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class StableDiffusionPipelineSafe(DiffusionPipeline): - r""" - Pipeline for text-to-image generation using Safe Latent Diffusion. - - The implementation is based on the [`StableDiffusionPipeline`] - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: SafeStableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - requires_safety_checker: bool = True, - ): - super().__init__() - safety_concept: Optional[str] = ( - "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity," - " bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child" - " abuse, brutality, cruelty" - ) - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions.
If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self._safety_text_concept = safety_concept - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - @property - def safety_concept(self): - r""" - Getter method for the safety concept used with SLD - - Returns: - `str`: The text describing the safety concept - """ - return self._safety_text_concept - - @safety_concept.setter - def safety_concept(self, concept): - r""" - Setter method for the safety concept used with SLD - - Args: - concept (`str`): - The text of the new safety concept - """ - self._safety_text_concept = concept - - def enable_sequential_cpu_offload(self): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device("cuda") - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - enable_safety_guidance, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). 
- """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids - - if not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # Encode the safety concept text - if enable_safety_guidance: - safety_concept_input = self.tokenizer( - [self._safety_text_concept], - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0] - - # duplicate safety embeddings for each generation per prompt, using mps friendly method - seq_len = safety_embeddings.shape[1] - safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1) - safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance + sld, we need to do three forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing three forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, safety_embeddings]) - - else: - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - def run_safety_checker(self, image, device, dtype, enable_safety_guidance): - if self.safety_checker is not None: - images = image.copy() - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - flagged_images = np.zeros((2, *image.shape[1:])) - if any(has_nsfw_concept): - logger.warning( - "Potential NSFW content was detected in one or more images. A black image will be returned" - " instead." - f"{'You may look at these images in the `unsafe_images` variable of the output at your own discretion.'
if enable_safety_guidance else 'Try again with a different prompt and/or seed.'}" - ) - for idx, has_nsfw_concept in enumerate(has_nsfw_concept): - if has_nsfw_concept: - flagged_images[idx] = images[idx] - #image[idx] = np.zeros(image[idx].shape) # black image - else: - has_nsfw_concept = None - flagged_images = None - return image, has_nsfw_concept, flagged_images - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs - def check_inputs( - self, - prompt, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
- ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def perform_safety_guidance( - self, - enable_safety_guidance, - safety_momentum, - noise_guidance, - noise_pred_out, - i, - sld_guidance_scale, - sld_warmup_steps, - sld_threshold, - sld_momentum_scale, - sld_mom_beta, - ): - # Perform SLD guidance - if enable_safety_guidance: - if safety_momentum is None: - safety_momentum = torch.zeros_like(noise_guidance) - noise_pred_text, noise_pred_uncond = noise_pred_out[0], noise_pred_out[1] - noise_pred_safety_concept = noise_pred_out[2] - - # Equation 6 - scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0) - - # Equation 6 - safety_concept_scale = torch.where( - (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale - ) - - # Equation 4 - noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale) - - # Equation 7 - noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum - - # Equation 8 - safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety - - if i >= sld_warmup_steps: # Warmup - # Equation 3 - noise_guidance = noise_guidance - noise_guidance_safety - return noise_guidance, safety_momentum - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - sld_guidance_scale: Optional[float] = 1000, - sld_warmup_steps: Optional[int] = 10, - sld_threshold: Optional[float] = 0.01, - sld_momentum_scale: Optional[float] = 0.3, - sld_mom_beta: Optional[float] = 0.4, - ): - r""" - Function invoked when calling the pipeline for generation. 
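The `sld_*` arguments documented below correspond one-to-one to the presets defined in this package's `SafetyConfig` (WEAK, MEDIUM, STRONG, MAX), so a preset can simply be unpacked into the call. A brief usage sketch — the checkpoint id is taken from the model card linked in the class docstring and is only an illustration, not a requirement:

from my_diffusers import StableDiffusionPipelineSafe
from my_diffusers.pipelines.stable_diffusion_safe import SafetyConfig

pipe = StableDiffusionPipelineSafe.from_pretrained("runwayml/stable-diffusion-v1-5")
out = pipe(prompt="a portrait photo of a scientist", **SafetyConfig.MEDIUM)  # sld_guidance_scale=1000, sld_warmup_steps=10, ...
image = out.images[0]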
- - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - sld_guidance_scale (`float`, *optional*, defaults to 1000): - Safe latent guidance as defined in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). - `sld_guidance_scale` is defined as sS of Eq. 6. If set to be less than 1, safety guidance will be - disabled. - sld_warmup_steps (`int`, *optional*, defaults to 10): - Number of warmup steps for safety guidance. SLD will only be applied for diffusion steps greater than - `sld_warmup_steps`. `sld_warmup_steps` is defined as `delta` of [Safe Latent - Diffusion](https://arxiv.org/abs/2211.05105). 
- sld_threshold (`float`, *optional*, defaults to 0.01): - Threshold that separates the hyperplane between appropriate and inappropriate images. `sld_threshold` - is defined as `lamda` of Eq. 5 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). - sld_momentum_scale (`float`, *optional*, defaults to 0.3): - Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0 - momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller - than `sld_warmup_steps`. `sld_momentum_scale` is defined as `sm` of Eq. 7 in [Safe Latent - Diffusion](https://arxiv.org/abs/2211.05105). - sld_mom_beta (`float`, *optional*, defaults to 0.4): - Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous - momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller - than `sld_warmup_steps`. `sld_mom_beta` is defined as `beta m` of Eq. 8 in [Safe Latent - Diffusion](https://arxiv.org/abs/2211.05105). - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, callback_steps) - - # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) - device = self._execution_device - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance - if not enable_safety_guidance: - warnings.warn("Safety checker disabled!") - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance - ) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. 
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - safety_momentum = None - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * (3 if enable_safety_guidance else 2)) - if do_classifier_free_guidance - else latents - ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2)) - noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] - - # default classifier free guidance - noise_guidance = noise_pred_text - noise_pred_uncond - - # Perform SLD guidance - if enable_safety_guidance: - if safety_momentum is None: - safety_momentum = torch.zeros_like(noise_guidance) - noise_pred_safety_concept = noise_pred_out[2] - - # Equation 6 - scale = torch.clamp( - torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0 - ) - - # Equation 6 - safety_concept_scale = torch.where( - (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, - torch.zeros_like(scale), - scale, - ) - - # Equation 4 - noise_guidance_safety = torch.mul( - (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale - ) - - # Equation 7 - noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum - - # Equation 8 - safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety - - if i >= sld_warmup_steps: # Warmup - # Equation 3 - noise_guidance = noise_guidance - noise_guidance_safety - - noise_pred = noise_pred_uncond + guidance_scale * noise_guidance - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept, flagged_images = self.run_safety_checker( - image, device, prompt_embeds.dtype, enable_safety_guidance - ) - - # 10. 
Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - if flagged_images is not None: - flagged_images = self.numpy_to_pil(flagged_images) - - if not return_dict: - return ( - image, - has_nsfw_concept, - self._safety_text_concept if enable_safety_guidance else None, - flagged_images, - ) - - return StableDiffusionSafePipelineOutput( - images=image, - nsfw_content_detected=has_nsfw_concept, - applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None, - unsafe_images=flagged_images, - ) diff --git a/my_diffusers/pipelines/stable_diffusion_safe/safety_checker.py b/my_diffusers/pipelines/stable_diffusion_safe/safety_checker.py deleted file mode 100644 index b8368544fd817879f335f9d72ca78bfe7ebe45a1..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stable_diffusion_safe/safety_checker.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel - -from ...utils import logging - - -logger = logging.get_logger(__name__) - - -def cosine_distance(image_embeds, text_embeds): - normalized_image_embeds = nn.functional.normalize(image_embeds) - normalized_text_embeds = nn.functional.normalize(text_embeds) - return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) - - -class SafeStableDiffusionSafetyChecker(PreTrainedModel): - config_class = CLIPConfig - - _no_split_modules = ["CLIPEncoderLayer"] - - def __init__(self, config: CLIPConfig): - super().__init__(config) - - self.vision_model = CLIPVisionModel(config.vision_config) - self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) - - self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) - self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) - - self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) - self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) - - @torch.no_grad() - def forward(self, clip_input, images): - pooled_output = self.vision_model(clip_input)[1] # pooled_output - image_embeds = self.visual_projection(pooled_output) - - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() - cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() - - result = [] - batch_size = image_embeds.shape[0] - for i in range(batch_size): - result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} - - # increase this value to create a stronger `nfsw` filter - # at the cost of increasing the possibility of filtering benign images - adjustment = 0.0 - - for concept_idx in 
range(len(special_cos_dist[0])): - concept_cos = special_cos_dist[i][concept_idx] - concept_threshold = self.special_care_embeds_weights[concept_idx].item() - result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["special_scores"][concept_idx] > float('inf'): - result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) - adjustment = 0.01 - - for concept_idx in range(len(cos_dist[0])): - concept_cos = cos_dist[i][concept_idx] - concept_threshold = self.concept_embeds_weights[concept_idx].item() - result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["concept_scores"][concept_idx] > float('inf'): - result_img["bad_concepts"].append(concept_idx) - - result.append(result_img) - - has_nsfw_concepts = [len(res["bad_concepts"]) > float('inf') for res in result] - - return images, has_nsfw_concepts - - @torch.no_grad() - def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): - pooled_output = self.vision_model(clip_input)[1] # pooled_output - image_embeds = self.visual_projection(pooled_output) - - special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) - cos_dist = cosine_distance(image_embeds, self.concept_embeds) - - # increase this value to create a stronger `nsfw` filter - # at the cost of increasing the possibility of filtering benign images - adjustment = 0.0 - - special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment - # special_scores = special_scores.round(decimals=3) - special_care = torch.any(special_scores > 0, dim=1) - special_adjustment = special_care * 0.01 - special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) - - concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment - # concept_scores = concept_scores.round(decimals=3) - has_nsfw_concepts = torch.any(concept_scores > float('inf'), dim=1) - - return images, has_nsfw_concepts diff --git a/my_diffusers/pipelines/stochastic_karras_ve/__init__.py b/my_diffusers/pipelines/stochastic_karras_ve/__init__.py deleted file mode 100644 index 5a63c1d24afb2c4f36b0e284f0985a3ff508f4c7..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stochastic_karras_ve/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .pipeline_stochastic_karras_ve import KarrasVePipeline diff --git a/my_diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 413c66efd174bbceca51a1af6d3282af0a36e6ba..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stochastic_karras_ve/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-311.pyc b/my_diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-311.pyc deleted file mode 100644 index dea35e564c81ab27cfbe68b9193cd0aae4a0c54b..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/my_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py deleted file mode 100644 index 
4535500e25924c6f3f8ecdd5ccbcba0b6be5c6c7..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Optional, Tuple, Union - -import torch - -from ...models import UNet2DModel -from ...schedulers import KarrasVeScheduler -from ...utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -class KarrasVePipeline(DiffusionPipeline): - r""" - Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and - the VE column of Table 1 from [1] for reference. - - [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." - https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic - differential equations." https://arxiv.org/abs/2011.13456 - - Parameters: - unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. - scheduler ([`KarrasVeScheduler`]): - Scheduler for the diffusion process to be used in combination with `unet` to denoise the encoded image. - """ - - # add type hints for linting - unet: UNet2DModel - scheduler: KarrasVeScheduler - - def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): - super().__init__() - self.register_modules(unet=unet, scheduler=scheduler) - - @torch.no_grad() - def __call__( - self, - batch_size: int = 1, - num_inference_steps: int = 50, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - **kwargs, - ) -> Union[Tuple, ImagePipelineOutput]: - r""" - Args: - batch_size (`int`, *optional*, defaults to 1): - The number of images to generate. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is - True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 
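A minimal usage sketch, shown here with a small randomly initialised UNet purely to exercise the pipeline; a trained `UNet2DModel` checkpoint would normally be loaded with `from_pretrained` instead, and no particular checkpoint is implied:

from my_diffusers import KarrasVePipeline, UNet2DModel
from my_diffusers.schedulers import KarrasVeScheduler

unet = UNet2DModel(sample_size=32)  # random weights; 3-channel 32x32 samples by default
pipe = KarrasVePipeline(unet=unet, scheduler=KarrasVeScheduler())
image = pipe(batch_size=1, num_inference_steps=5).images[0]  # a PIL.Image (pure noise here, since the UNet is untrained)
image.save("karras_ve_sample.png")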
- """ - - img_size = self.unet.config.sample_size - shape = (batch_size, 3, img_size, img_size) - - model = self.unet - - # sample x_0 ~ N(0, sigma_0^2 * I) - sample = randn_tensor(shape, generator=generator, device=self.device) * self.scheduler.init_noise_sigma - - self.scheduler.set_timesteps(num_inference_steps) - - for t in self.progress_bar(self.scheduler.timesteps): - # here sigma_t == t_i from the paper - sigma = self.scheduler.schedule[t] - sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0 - - # 1. Select temporarily increased noise level sigma_hat - # 2. Add new noise to move from sample_i to sample_hat - sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator) - - # 3. Predict the noise residual given the noise magnitude `sigma_hat` - # The model inputs and output are adjusted by following eq. (213) in [1]. - model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample - - # 4. Evaluate dx/dt at sigma_hat - # 5. Take Euler step from sigma to sigma_prev - step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat) - - if sigma_prev != 0: - # 6. Apply 2nd order correction - # The model inputs and output are adjusted by following eq. (213) in [1]. - model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample - step_output = self.scheduler.step_correct( - model_output, - sigma_hat, - sigma_prev, - sample_hat, - step_output.prev_sample, - step_output["derivative"], - ) - sample = step_output.prev_sample - - sample = (sample / 2 + 0.5).clamp(0, 1) - image = sample.cpu().permute(0, 2, 3, 1).numpy() - if output_type == "pil": - image = self.numpy_to_pil(sample) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/unclip/__init__.py b/my_diffusers/pipelines/unclip/__init__.py deleted file mode 100644 index 075e66bb680aca294b36aa7ad0abb8d0f651cd92..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/unclip/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from ...utils import ( - OptionalDependencyNotAvailable, - is_torch_available, - is_transformers_available, - is_transformers_version, -) - - -try: - if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import UnCLIPImageVariationPipeline, UnCLIPPipeline -else: - from .pipeline_unclip import UnCLIPPipeline - from .pipeline_unclip_image_variation import UnCLIPImageVariationPipeline - from .text_proj import UnCLIPTextProjModel diff --git a/my_diffusers/pipelines/unclip/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/unclip/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 1010eeb5e114d2fd25b3ab91cb24f8912373919b..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/unclip/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-311.pyc b/my_diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-311.pyc deleted file mode 100644 index 7af286ddc7768af77e79f2f3be5afff85664edc8..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/unclip/__pycache__/pipeline_unclip.cpython-311.pyc and /dev/null differ diff --git 
a/my_diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-311.pyc b/my_diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-311.pyc deleted file mode 100644 index 328edcfb73e3ea4e5191e9c1995244b6ed9e55f5..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/unclip/__pycache__/pipeline_unclip_image_variation.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/unclip/__pycache__/text_proj.cpython-311.pyc b/my_diffusers/pipelines/unclip/__pycache__/text_proj.cpython-311.pyc deleted file mode 100644 index 149701cff0fbd03df9bb2face2f02b2ef74224af..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/unclip/__pycache__/text_proj.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/unclip/pipeline_unclip.py b/my_diffusers/pipelines/unclip/pipeline_unclip.py deleted file mode 100644 index 3aac39b3a3b022724c67cfa5d387a08f03e5bbe7..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/unclip/pipeline_unclip.py +++ /dev/null @@ -1,534 +0,0 @@ -# Copyright 2023 Kakao Brain and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import List, Optional, Tuple, Union - -import torch -from torch.nn import functional as F -from transformers import CLIPTextModelWithProjection, CLIPTokenizer -from transformers.models.clip.modeling_clip import CLIPTextModelOutput - -from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel -from ...pipelines import DiffusionPipeline -from ...pipelines.pipeline_utils import ImagePipelineOutput -from ...schedulers import UnCLIPScheduler -from ...utils import is_accelerate_available, logging, randn_tensor -from .text_proj import UnCLIPTextProjModel - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class UnCLIPPipeline(DiffusionPipeline): - """ - Pipeline for text-to-image generation using unCLIP - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - text_encoder ([`CLIPTextModelWithProjection`]): - Frozen text-encoder. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - prior ([`PriorTransformer`]): - The canonincal unCLIP prior to approximate the image embedding from the text embedding. - text_proj ([`UnCLIPTextProjModel`]): - Utility class to prepare and combine the embeddings before they are passed to the decoder. - decoder ([`UNet2DConditionModel`]): - The decoder to invert the image embedding into an image. - super_res_first ([`UNet2DModel`]): - Super resolution unet. Used in all but the last step of the super resolution diffusion process. - super_res_last ([`UNet2DModel`]): - Super resolution unet. 
Used in the last step of the super resolution diffusion process. - prior_scheduler ([`UnCLIPScheduler`]): - Scheduler used in the prior denoising process. Just a modified DDPMScheduler. - decoder_scheduler ([`UnCLIPScheduler`]): - Scheduler used in the decoder denoising process. Just a modified DDPMScheduler. - super_res_scheduler ([`UnCLIPScheduler`]): - Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler. - - """ - - prior: PriorTransformer - decoder: UNet2DConditionModel - text_proj: UnCLIPTextProjModel - text_encoder: CLIPTextModelWithProjection - tokenizer: CLIPTokenizer - super_res_first: UNet2DModel - super_res_last: UNet2DModel - - prior_scheduler: UnCLIPScheduler - decoder_scheduler: UnCLIPScheduler - super_res_scheduler: UnCLIPScheduler - - def __init__( - self, - prior: PriorTransformer, - decoder: UNet2DConditionModel, - text_encoder: CLIPTextModelWithProjection, - tokenizer: CLIPTokenizer, - text_proj: UnCLIPTextProjModel, - super_res_first: UNet2DModel, - super_res_last: UNet2DModel, - prior_scheduler: UnCLIPScheduler, - decoder_scheduler: UnCLIPScheduler, - super_res_scheduler: UnCLIPScheduler, - ): - super().__init__() - - self.register_modules( - prior=prior, - decoder=decoder, - text_encoder=text_encoder, - tokenizer=tokenizer, - text_proj=text_proj, - super_res_first=super_res_first, - super_res_last=super_res_last, - prior_scheduler=prior_scheduler, - decoder_scheduler=decoder_scheduler, - super_res_scheduler=super_res_scheduler, - ) - - def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) - - latents = latents * scheduler.init_noise_sigma - return latents - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, - text_attention_mask: Optional[torch.Tensor] = None, - ): - if text_model_output is None: - batch_size = len(prompt) if isinstance(prompt, list) else 1 - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - text_mask = text_inputs.attention_mask.bool().to(device) - - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - - text_encoder_output = self.text_encoder(text_input_ids.to(device)) - - prompt_embeds = text_encoder_output.text_embeds - text_encoder_hidden_states = text_encoder_output.last_hidden_state - - else: - batch_size = text_model_output[0].shape[0] - prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1] - text_mask = text_attention_mask - - prompt_embeds = 
prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) - text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) - - if do_classifier_free_guidance: - uncond_tokens = [""] * batch_size - - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - uncond_text_mask = uncond_input.attention_mask.bool().to(device) - negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) - - negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds - uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - - seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) - - seq_len = uncond_text_encoder_hidden_states.shape[1] - uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) - uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) - uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) - - # done duplicates - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) - - text_mask = torch.cat([uncond_text_mask, text_mask]) - - return prompt_embeds, text_encoder_hidden_states, text_mask - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's - models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only - when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list - models = [ - self.decoder, - self.text_proj, - self.text_encoder, - self.super_res_first, - self.super_res_last, - ] - for cpu_offloaded_model in models: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. 
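The `_encode_prompt` logic above concatenates the unconditional ("" prompt) and conditional embeddings into one batch so classifier-free guidance needs only a single forward pass per step; the prediction is later split back apart and recombined with the guidance scale. A toy sketch of that pattern with made-up shapes (not the real CLIP dimensions):

import torch

cond = torch.randn(2, 77, 768)           # conditional text hidden states (toy shape)
uncond = torch.randn(2, 77, 768)         # embeddings of the empty prompt ""
model_input = torch.cat([uncond, cond])  # one batch, one forward pass

noise_pred = model_input.mean(dim=(1, 2), keepdim=True)  # stand-in for a UNet/prior prediction
pred_uncond, pred_text = noise_pred.chunk(2)
guidance_scale = 4.0
guided = pred_uncond + guidance_scale * (pred_text - pred_uncond)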
- """ - if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"): - return self.device - for module in self.decoder.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - @torch.no_grad() - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: int = 1, - prior_num_inference_steps: int = 25, - decoder_num_inference_steps: int = 25, - super_res_num_inference_steps: int = 7, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - prior_latents: Optional[torch.FloatTensor] = None, - decoder_latents: Optional[torch.FloatTensor] = None, - super_res_latents: Optional[torch.FloatTensor] = None, - text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, - text_attention_mask: Optional[torch.Tensor] = None, - prior_guidance_scale: float = 4.0, - decoder_guidance_scale: float = 8.0, - output_type: Optional[str] = "pil", - return_dict: bool = True, - ): - """ - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. This can only be left undefined if - `text_model_output` and `text_attention_mask` is passed. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - prior_num_inference_steps (`int`, *optional*, defaults to 25): - The number of denoising steps for the prior. More denoising steps usually lead to a higher quality - image at the expense of slower inference. - decoder_num_inference_steps (`int`, *optional*, defaults to 25): - The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality - image at the expense of slower inference. - super_res_num_inference_steps (`int`, *optional*, defaults to 7): - The number of denoising steps for super resolution. More denoising steps usually lead to a higher - quality image at the expense of slower inference. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*): - Pre-generated noisy latents to be used as inputs for the prior. - decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): - Pre-generated noisy latents to be used as inputs for the decoder. - super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): - Pre-generated noisy latents to be used as inputs for the decoder. - prior_guidance_scale (`float`, *optional*, defaults to 4.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - decoder_guidance_scale (`float`, *optional*, defaults to 4.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 
- `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - text_model_output (`CLIPTextModelOutput`, *optional*): - Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs - can be passed for tasks like text embedding interpolations. Make sure to also pass - `text_attention_mask` in this case. `prompt` can the be left to `None`. - text_attention_mask (`torch.Tensor`, *optional*): - Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention - masks are necessary when passing `text_model_output`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. - """ - if prompt is not None: - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - else: - batch_size = text_model_output[0].shape[0] - - device = self._execution_device - - batch_size = batch_size * num_images_per_prompt - - do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 - - prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask - ) - - # prior - - self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) - prior_timesteps_tensor = self.prior_scheduler.timesteps - - embedding_dim = self.prior.config.embedding_dim - - prior_latents = self.prepare_latents( - (batch_size, embedding_dim), - prompt_embeds.dtype, - device, - generator, - prior_latents, - self.prior_scheduler, - ) - - for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents - - predicted_image_embedding = self.prior( - latent_model_input, - timestep=t, - proj_embedding=prompt_embeds, - encoder_hidden_states=text_encoder_hidden_states, - attention_mask=text_mask, - ).predicted_image_embedding - - if do_classifier_free_guidance: - predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) - predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( - predicted_image_embedding_text - predicted_image_embedding_uncond - ) - - if i + 1 == prior_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = prior_timesteps_tensor[i + 1] - - prior_latents = self.prior_scheduler.step( - predicted_image_embedding, - timestep=t, - sample=prior_latents, - generator=generator, - prev_timestep=prev_timestep, - ).prev_sample - - prior_latents = self.prior.post_process_latents(prior_latents) - - image_embeddings = prior_latents - - # done prior - - # decoder - - text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( - image_embeddings=image_embeddings, - 
prompt_embeds=prompt_embeds, - text_encoder_hidden_states=text_encoder_hidden_states, - do_classifier_free_guidance=do_classifier_free_guidance, - ) - - if device.type == "mps": - # HACK: MPS: There is a panic when padding bool tensors, - # so cast to int tensor for the pad and back to bool afterwards - text_mask = text_mask.type(torch.int) - decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) - decoder_text_mask = decoder_text_mask.type(torch.bool) - else: - decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) - - self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) - decoder_timesteps_tensor = self.decoder_scheduler.timesteps - - num_channels_latents = self.decoder.in_channels - height = self.decoder.sample_size - width = self.decoder.sample_size - - decoder_latents = self.prepare_latents( - (batch_size, num_channels_latents, height, width), - text_encoder_hidden_states.dtype, - device, - generator, - decoder_latents, - self.decoder_scheduler, - ) - - for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents - - noise_pred = self.decoder( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=text_encoder_hidden_states, - class_labels=additive_clip_time_embeddings, - attention_mask=decoder_text_mask, - ).sample - - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) - noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) - noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) - noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) - - if i + 1 == decoder_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = decoder_timesteps_tensor[i + 1] - - # compute the previous noisy sample x_t -> x_t-1 - decoder_latents = self.decoder_scheduler.step( - noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator - ).prev_sample - - decoder_latents = decoder_latents.clamp(-1, 1) - - image_small = decoder_latents - - # done decoder - - # super res - - self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) - super_res_timesteps_tensor = self.super_res_scheduler.timesteps - - channels = self.super_res_first.in_channels // 2 - height = self.super_res_first.sample_size - width = self.super_res_first.sample_size - - super_res_latents = self.prepare_latents( - (batch_size, channels, height, width), - image_small.dtype, - device, - generator, - super_res_latents, - self.super_res_scheduler, - ) - - if device.type == "mps": - # MPS does not support many interpolations - image_upscaled = F.interpolate(image_small, size=[height, width]) - else: - interpolate_antialias = {} - if "antialias" in inspect.signature(F.interpolate).parameters: - interpolate_antialias["antialias"] = True - - image_upscaled = F.interpolate( - image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias - ) - - for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): - # no classifier free guidance - - if i == super_res_timesteps_tensor.shape[0] - 1: - unet = self.super_res_last - else: - unet = 
self.super_res_first - - latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) - - noise_pred = unet( - sample=latent_model_input, - timestep=t, - ).sample - - if i + 1 == super_res_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = super_res_timesteps_tensor[i + 1] - - # compute the previous noisy sample x_t -> x_t-1 - super_res_latents = self.super_res_scheduler.step( - noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator - ).prev_sample - - image = super_res_latents - # done super res - - # post processing - - image = image * 0.5 + 0.5 - image = image.clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/my_diffusers/pipelines/unclip/pipeline_unclip_image_variation.py deleted file mode 100644 index e5e7668468412a70f4708a6fe4eafb323e14e359..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +++ /dev/null @@ -1,463 +0,0 @@ -# Copyright 2023 Kakao Brain and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import List, Optional, Union - -import PIL -import torch -from torch.nn import functional as F -from transformers import ( - CLIPFeatureExtractor, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionModelWithProjection, -) - -from ...models import UNet2DConditionModel, UNet2DModel -from ...pipelines import DiffusionPipeline, ImagePipelineOutput -from ...schedulers import UnCLIPScheduler -from ...utils import is_accelerate_available, logging, randn_tensor -from .text_proj import UnCLIPTextProjModel - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class UnCLIPImageVariationPipeline(DiffusionPipeline): - """ - Pipeline to generate variations from an input image using unCLIP - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - text_encoder ([`CLIPTextModelWithProjection`]): - Frozen text-encoder. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `image_encoder`. - image_encoder ([`CLIPVisionModelWithProjection`]): - Frozen CLIP image-encoder. 
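The post-processing at the end of `UnCLIPPipeline.__call__` above maps the super-resolution latents from [-1, 1] into [0, 1] and then to uint8 PIL images (the final conversion is what `numpy_to_pil` performs). A standalone sketch with a random tensor standing in for the super-resolution output:

import numpy as np
import torch
from PIL import Image

latents = torch.rand(1, 3, 256, 256) * 2 - 1             # stand-in for super_res_latents
image = (latents * 0.5 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()  # NCHW -> NHWC
pil_images = [Image.fromarray((arr * 255).round().astype(np.uint8)) for arr in image]
pil_images[0].save("postprocessed.png")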
unCLIP Image Variation uses the vision portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), - specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - text_proj ([`UnCLIPTextProjModel`]): - Utility class to prepare and combine the embeddings before they are passed to the decoder. - decoder ([`UNet2DConditionModel`]): - The decoder to invert the image embedding into an image. - super_res_first ([`UNet2DModel`]): - Super resolution unet. Used in all but the last step of the super resolution diffusion process. - super_res_last ([`UNet2DModel`]): - Super resolution unet. Used in the last step of the super resolution diffusion process. - decoder_scheduler ([`UnCLIPScheduler`]): - Scheduler used in the decoder denoising process. Just a modified DDPMScheduler. - super_res_scheduler ([`UnCLIPScheduler`]): - Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler. - - """ - - decoder: UNet2DConditionModel - text_proj: UnCLIPTextProjModel - text_encoder: CLIPTextModelWithProjection - tokenizer: CLIPTokenizer - feature_extractor: CLIPFeatureExtractor - image_encoder: CLIPVisionModelWithProjection - super_res_first: UNet2DModel - super_res_last: UNet2DModel - - decoder_scheduler: UnCLIPScheduler - super_res_scheduler: UnCLIPScheduler - - def __init__( - self, - decoder: UNet2DConditionModel, - text_encoder: CLIPTextModelWithProjection, - tokenizer: CLIPTokenizer, - text_proj: UnCLIPTextProjModel, - feature_extractor: CLIPFeatureExtractor, - image_encoder: CLIPVisionModelWithProjection, - super_res_first: UNet2DModel, - super_res_last: UNet2DModel, - decoder_scheduler: UnCLIPScheduler, - super_res_scheduler: UnCLIPScheduler, - ): - super().__init__() - - self.register_modules( - decoder=decoder, - text_encoder=text_encoder, - tokenizer=tokenizer, - text_proj=text_proj, - feature_extractor=feature_extractor, - image_encoder=image_encoder, - super_res_first=super_res_first, - super_res_last=super_res_last, - decoder_scheduler=decoder_scheduler, - super_res_scheduler=super_res_scheduler, - ) - - # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents - def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) - - latents = latents * scheduler.init_noise_sigma - return latents - - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - text_mask = text_inputs.attention_mask.bool().to(device) - text_encoder_output = self.text_encoder(text_input_ids.to(device)) - - prompt_embeds = text_encoder_output.text_embeds - text_encoder_hidden_states = text_encoder_output.last_hidden_state - - prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) - text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) - - if 
do_classifier_free_guidance: - uncond_tokens = [""] * batch_size - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - uncond_text_mask = uncond_input.attention_mask.bool().to(device) - negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) - - negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds - uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - - seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) - - seq_len = uncond_text_encoder_hidden_states.shape[1] - uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) - uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) - uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) - - # done duplicates - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) - - text_mask = torch.cat([uncond_text_mask, text_mask]) - - return prompt_embeds, text_encoder_hidden_states, text_mask - - def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None): - dtype = next(self.image_encoder.parameters()).dtype - - if image_embeddings is None: - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(images=image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - image_embeddings = self.image_encoder(image).image_embeds - - image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0) - - return image_embeddings - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's - models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only - when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - models = [ - self.decoder, - self.text_proj, - self.text_encoder, - self.super_res_first, - self.super_res_last, - ] - for cpu_offloaded_model in models: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - @property - # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. 
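The `_encode_image` helper above reduces the conditioning image to a single CLIP image embedding and repeats it per requested output. A hedged, standalone sketch of the same steps using the transformers CLIP classes this pipeline builds on (the checkpoint is the clip-vit-large-patch14 variant named in the class docstring; the input path is a placeholder):

import torch
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection

feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

image = Image.open("input.png").convert("RGB")  # placeholder path
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
with torch.no_grad():
    image_embeddings = image_encoder(pixel_values).image_embeds  # projected CLIP image embedding

num_images_per_prompt = 2
image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)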
- """ - if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"): - return self.device - for module in self.decoder.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - @torch.no_grad() - def __call__( - self, - image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None, - num_images_per_prompt: int = 1, - decoder_num_inference_steps: int = 25, - super_res_num_inference_steps: int = 7, - generator: Optional[torch.Generator] = None, - decoder_latents: Optional[torch.FloatTensor] = None, - super_res_latents: Optional[torch.FloatTensor] = None, - image_embeddings: Optional[torch.Tensor] = None, - decoder_guidance_scale: float = 8.0, - output_type: Optional[str] = "pil", - return_dict: bool = True, - ): - """ - Function invoked when calling the pipeline for generation. - - Args: - image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): - The image or images to guide the image generation. If you provide a tensor, it needs to comply with the - configuration of - [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) - `CLIPFeatureExtractor`. Can be left to `None` only when `image_embeddings` are passed. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - decoder_num_inference_steps (`int`, *optional*, defaults to 25): - The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality - image at the expense of slower inference. - super_res_num_inference_steps (`int`, *optional*, defaults to 7): - The number of denoising steps for super resolution. More denoising steps usually lead to a higher - quality image at the expense of slower inference. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): - Pre-generated noisy latents to be used as inputs for the decoder. - super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): - Pre-generated noisy latents to be used as inputs for the decoder. - decoder_guidance_scale (`float`, *optional*, defaults to 4.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - image_embeddings (`torch.Tensor`, *optional*): - Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings - can be passed for tasks like image interpolations. `image` can the be left to `None`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
- return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. - """ - if image is not None: - if isinstance(image, PIL.Image.Image): - batch_size = 1 - elif isinstance(image, list): - batch_size = len(image) - else: - batch_size = image.shape[0] - else: - batch_size = image_embeddings.shape[0] - - prompt = [""] * batch_size - - device = self._execution_device - - batch_size = batch_size * num_images_per_prompt - - do_classifier_free_guidance = decoder_guidance_scale > 1.0 - - prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance - ) - - image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings) - - # decoder - text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( - image_embeddings=image_embeddings, - prompt_embeds=prompt_embeds, - text_encoder_hidden_states=text_encoder_hidden_states, - do_classifier_free_guidance=do_classifier_free_guidance, - ) - - if device.type == "mps": - # HACK: MPS: There is a panic when padding bool tensors, - # so cast to int tensor for the pad and back to bool afterwards - text_mask = text_mask.type(torch.int) - decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) - decoder_text_mask = decoder_text_mask.type(torch.bool) - else: - decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) - - self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) - decoder_timesteps_tensor = self.decoder_scheduler.timesteps - - num_channels_latents = self.decoder.in_channels - height = self.decoder.sample_size - width = self.decoder.sample_size - - if decoder_latents is None: - decoder_latents = self.prepare_latents( - (batch_size, num_channels_latents, height, width), - text_encoder_hidden_states.dtype, - device, - generator, - decoder_latents, - self.decoder_scheduler, - ) - - for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents - - noise_pred = self.decoder( - sample=latent_model_input, - timestep=t, - encoder_hidden_states=text_encoder_hidden_states, - class_labels=additive_clip_time_embeddings, - attention_mask=decoder_text_mask, - ).sample - - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) - noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) - noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) - noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) - - if i + 1 == decoder_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = decoder_timesteps_tensor[i + 1] - - # compute the previous noisy sample x_t -> x_t-1 - decoder_latents = self.decoder_scheduler.step( - noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator - ).prev_sample - - decoder_latents = decoder_latents.clamp(-1, 1) - - image_small = decoder_latents - - # done decoder - - # super res - - self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) - super_res_timesteps_tensor = 
self.super_res_scheduler.timesteps - - channels = self.super_res_first.in_channels // 2 - height = self.super_res_first.sample_size - width = self.super_res_first.sample_size - - if super_res_latents is None: - super_res_latents = self.prepare_latents( - (batch_size, channels, height, width), - image_small.dtype, - device, - generator, - super_res_latents, - self.super_res_scheduler, - ) - - if device.type == "mps": - # MPS does not support many interpolations - image_upscaled = F.interpolate(image_small, size=[height, width]) - else: - interpolate_antialias = {} - if "antialias" in inspect.signature(F.interpolate).parameters: - interpolate_antialias["antialias"] = True - - image_upscaled = F.interpolate( - image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias - ) - - for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): - # no classifier free guidance - - if i == super_res_timesteps_tensor.shape[0] - 1: - unet = self.super_res_last - else: - unet = self.super_res_first - - latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) - - noise_pred = unet( - sample=latent_model_input, - timestep=t, - ).sample - - if i + 1 == super_res_timesteps_tensor.shape[0]: - prev_timestep = None - else: - prev_timestep = super_res_timesteps_tensor[i + 1] - - # compute the previous noisy sample x_t -> x_t-1 - super_res_latents = self.super_res_scheduler.step( - noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator - ).prev_sample - - image = super_res_latents - - # done super res - - # post processing - - image = image * 0.5 + 0.5 - image = image.clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/unclip/text_proj.py b/my_diffusers/pipelines/unclip/text_proj.py deleted file mode 100644 index 0a54c3319f2850084f523922b713cdd8f04a6750..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/unclip/text_proj.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2023 Kakao Brain and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from torch import nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models import ModelMixin - - -class UnCLIPTextProjModel(ModelMixin, ConfigMixin): - """ - Utility class for CLIP embeddings. Used to combine the image and text embeddings into a format usable by the - decoder. 
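The super-resolution stage above conditions each denoising step on a bicubically upscaled copy of the decoder output, concatenated channel-wise with the noisy high-resolution latents (hence `in_channels // 2` when shaping the latents). A self-contained sketch of that conditioning, with toy tensors in place of real model outputs:

import inspect
import torch
from torch.nn import functional as F

image_small = torch.randn(1, 3, 64, 64)            # stand-in for the decoder output
super_res_latents = torch.randn(1, 3, 256, 256)    # stand-in for the noisy high-res latents

kwargs = {}
if "antialias" in inspect.signature(F.interpolate).parameters:
    kwargs["antialias"] = True  # only available on newer torch versions
image_upscaled = F.interpolate(
    image_small, size=[256, 256], mode="bicubic", align_corners=False, **kwargs
)

# Channel-wise concatenation: the super-res UNet sees noise and the upscaled image together.
unet_input = torch.cat([super_res_latents, image_upscaled], dim=1)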
- - For more details, see the original paper: https://arxiv.org/abs/2204.06125 section 2.1 - """ - - @register_to_config - def __init__( - self, - *, - clip_extra_context_tokens: int = 4, - clip_embeddings_dim: int = 768, - time_embed_dim: int, - cross_attention_dim, - ): - super().__init__() - - self.learned_classifier_free_guidance_embeddings = nn.Parameter(torch.zeros(clip_embeddings_dim)) - - # parameters for additional clip time embeddings - self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim) - self.clip_image_embeddings_project_to_time_embeddings = nn.Linear(clip_embeddings_dim, time_embed_dim) - - # parameters for encoder hidden states - self.clip_extra_context_tokens = clip_extra_context_tokens - self.clip_extra_context_tokens_proj = nn.Linear( - clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim - ) - self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim) - self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim) - - def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states, do_classifier_free_guidance): - if do_classifier_free_guidance: - # Add the classifier free guidance embeddings to the image embeddings - image_embeddings_batch_size = image_embeddings.shape[0] - classifier_free_guidance_embeddings = self.learned_classifier_free_guidance_embeddings.unsqueeze(0) - classifier_free_guidance_embeddings = classifier_free_guidance_embeddings.expand( - image_embeddings_batch_size, -1 - ) - image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0) - - # The image embeddings batch size and the text embeddings batch size are equal - assert image_embeddings.shape[0] == prompt_embeds.shape[0] - - batch_size = prompt_embeds.shape[0] - - # "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and - # adding CLIP embeddings to the existing timestep embedding, ... - time_projected_prompt_embeds = self.embedding_proj(prompt_embeds) - time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings) - additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_prompt_embeds - - # ... 
and by projecting CLIP embeddings into four - # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder" - clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings) - clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens) - - text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states) - text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states) - text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1) - text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=2) - - return text_encoder_hidden_states, additive_clip_time_embeddings diff --git a/my_diffusers/pipelines/versatile_diffusion/__init__.py b/my_diffusers/pipelines/versatile_diffusion/__init__.py deleted file mode 100644 index abf9dcff59dbc922dcc7063a1e73560679a23696..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/versatile_diffusion/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from ...utils import ( - OptionalDependencyNotAvailable, - is_torch_available, - is_transformers_available, - is_transformers_version, -) - - -try: - if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import ( - VersatileDiffusionDualGuidedPipeline, - VersatileDiffusionImageVariationPipeline, - VersatileDiffusionPipeline, - VersatileDiffusionTextToImagePipeline, - ) -else: - from .modeling_text_unet import UNetFlatConditionModel - from .pipeline_versatile_diffusion import VersatileDiffusionPipeline - from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline - from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline - from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline diff --git a/my_diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index d20b3c9bb547c3b4bd32771ee23928ce539c13f2..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/versatile_diffusion/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-311.pyc b/my_diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-311.pyc deleted file mode 100644 index acdab05c0241c74a898d94396b4c85d1e763962b..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/versatile_diffusion/__pycache__/modeling_text_unet.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-311.pyc b/my_diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-311.pyc deleted file mode 100644 index 24e17c3475b5841e2a5ea57b9c10e385b6f47164..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-311.pyc 
b/my_diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-311.pyc deleted file mode 100644 index b7f8d50c80240a0fbcfdfeb2ceef6495c946682b..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-311.pyc b/my_diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-311.pyc deleted file mode 100644 index d41811b6b6fb0fe56555209e4f4d838f91601c31..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-311.pyc b/my_diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-311.pyc deleted file mode 100644 index 5b6e557cc46fa56479fd168ca406edfeda8ccc61..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/my_diffusers/pipelines/versatile_diffusion/modeling_text_unet.py deleted file mode 100644 index 261135d33318f7516bd6fecabb07d9d13292fda9..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ /dev/null @@ -1,1474 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models import ModelMixin -from ...models.attention import CrossAttention -from ...models.cross_attention import AttnProcessor, CrossAttnAddedKVProcessor -from ...models.dual_transformer_2d import DualTransformer2DModel -from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from ...models.transformer_2d import Transformer2DModel -from ...models.unet_2d_condition import UNet2DConditionOutput -from ...utils import logging - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def get_down_block( - down_block_type, - num_layers, - in_channels, - out_channels, - temb_channels, - add_downsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - resnet_groups=None, - cross_attention_dim=None, - downsample_padding=None, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", -): - down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type - if down_block_type == "DownBlockFlat": - return DownBlockFlat( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "CrossAttnDownBlockFlat": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for 
CrossAttnDownBlockFlat") - return CrossAttnDownBlockFlat( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - raise ValueError(f"{down_block_type} is not supported.") - - -def get_up_block( - up_block_type, - num_layers, - in_channels, - out_channels, - prev_output_channel, - temb_channels, - add_upsample, - resnet_eps, - resnet_act_fn, - attn_num_head_channels, - resnet_groups=None, - cross_attention_dim=None, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - resnet_time_scale_shift="default", -): - up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type - if up_block_type == "UpBlockFlat": - return UpBlockFlat( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "CrossAttnUpBlockFlat": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat") - return CrossAttnUpBlockFlat( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - raise ValueError(f"{up_block_type} is not supported.") - - -# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat -class UNetFlatConditionModel(ModelMixin, ConfigMixin): - r""" - UNetFlatConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a - timestep and returns sample shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): - Whether to flip the sin to cos in the time embedding. 
- freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): - The tuple of downsample blocks to use. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): - The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`, will skip - the mid block layer if `None`. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`): - The tuple of upsample blocks to use. - only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): - Whether to include self-attention in the basic transformer blocks, see - [`~models.attention.BasicTransformerBlock`]. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - If `None`, it will skip the normalization and activation layers in post-processing - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. - resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately - summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`. - num_class_embeds (`int`, *optional*, defaults to None): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - time_embedding_type (`str`, *optional*, default to `positional`): - The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. - timestep_post_act (`str, *optional*, default to `None`): - The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. - time_cond_proj_dim (`int`, *optional*, default to `None`): - The dimension of `cond_proj` layer in timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. - conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. - projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when - using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. 
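A minimal, hedged instantiation sketch for the class documented above, using a deliberately tiny configuration (the values are chosen only to keep the example small, and the import path assumes the vendored module layout shown in this diff):

from my_diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel

unet = UNetFlatConditionModel(
    sample_size=64,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlockFlat", "DownBlockFlat"),
    up_block_types=("UpBlockFlat", "CrossAttnUpBlockFlat"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    cross_attention_dim=32,
    attention_head_dim=8,
    norm_num_groups=8,  # must divide every entry of block_out_channels
)
print(sum(p.numel() for p in unet.parameters()), "parameters")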
- """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlockFlat", - "CrossAttnDownBlockFlat", - "CrossAttnDownBlockFlat", - "DownBlockFlat", - ), - mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn", - up_block_types: Tuple[str] = ( - "UpBlockFlat", - "CrossAttnUpBlockFlat", - "CrossAttnUpBlockFlat", - "CrossAttnUpBlockFlat", - ), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, - attention_head_dim: Union[int, Tuple[int]] = 8, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - time_embedding_type: str = "positional", - timestep_post_act: Optional[str] = None, - time_cond_proj_dim: Optional[int] = None, - conv_in_kernel: int = 3, - conv_out_kernel: int = 3, - projection_class_embeddings_input_dim: Optional[int] = None, - ): - super().__init__() - - self.sample_size = sample_size - - # Check inputs - if len(down_block_types) != len(up_block_types): - raise ValueError( - "Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`:" - f" {down_block_types}. `up_block_types`: {up_block_types}." - ) - - if len(block_out_channels) != len(down_block_types): - raise ValueError( - "Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`:" - f" {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - "Must provide the same number of `only_cross_attention` as `down_block_types`." - f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): - raise ValueError( - "Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:" - f" {attention_head_dim}. `down_block_types`: {down_block_types}." - ) - - # input - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = LinearMultiDim( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) - - # time - if time_embedding_type == "fourier": - time_embed_dim = block_out_channels[0] * 2 - if time_embed_dim % 2 != 0: - raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") - self.time_proj = GaussianFourierProjection( - time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos - ) - timestep_input_dim = time_embed_dim - elif time_embedding_type == "positional": - time_embed_dim = block_out_channels[0] * 4 - - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - else: - raise ValueError( - f"{time_embedding_type} does not exist. 
Pleaes make sure to use one of `fourier` or `positional`." - ) - - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - post_act_fn=timestep_post_act, - cond_proj_dim=time_cond_proj_dim, - ) - - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - # The projection `class_embed_type` is the same as the timestep `class_embed_type` except - # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings - # 2. it projects from an arbitrary input dimension. - # - # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. - # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. - # As a result, `TimestepEmbedding` can be passed arbitrary vectors. - self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - else: - self.class_embedding = None - - self.down_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) - - if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - temb_channels=time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[i], - downsample_padding=downsample_padding, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.down_blocks.append(down_block) - - # mid - if mid_block_type == "UNetMidBlockFlatCrossAttn": - self.mid_block = UNetMidBlockFlatCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn": - self.mid_block = UNetMidBlockFlatSimpleCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - 
output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif mid_block_type is None: - self.mid_block = None - else: - raise ValueError(f"unknown mid_block_type : {mid_block_type}") - - # count how many layers upsample the images - self.num_upsamplers = 0 - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_attention_head_dim = list(reversed(attention_head_dim)) - only_cross_attention = list(reversed(only_cross_attention)) - - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 - - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - # add upsample block for all BUT final layer - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - - up_block = get_up_block( - up_block_type, - num_layers=layers_per_block + 1, - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=reversed_attention_head_dim[i], - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_num_groups is not None: - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps - ) - self.conv_act = nn.SiLU() - else: - self.conv_norm_out = None - self.conv_act = None - - conv_out_padding = (conv_out_kernel - 1) // 2 - self.conv_out = LinearMultiDim( - block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding - ) - - @property - def attn_processors(self) -> Dict[str, AttnProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]): - if hasattr(module, "set_processor"): - processors[f"{name}.processor"] = module.processor - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]): - r""" - Parameters: - `processor (`dict` of `AttnProcessor` or `AttnProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `CrossAttention` layers. - In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. 
This is strongly recommended when setting trainablae attention processors.: - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - def set_attention_slice(self, slice_size): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_slicable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_slicable_dims(module) - - num_slicable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_slicable_layers * [1] - - slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. 
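These three helpers are the runtime knobs for the attention layers: `attn_processors` returns every processor keyed by its dotted module path, `set_attn_processor` replaces them (either one instance for all layers or a path-keyed dict), and `set_attention_slice` chunks the attention computation to save memory. A minimal usage sketch, assuming the standard Stable Diffusion UNet checkpoint referenced elsewhere in this file is reachable; the flat UNet defined here exposes the same three methods:

```py
from diffusers import UNet2DConditionModel

# Any UNet2DConditionModel checkpoint works for this sketch.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Every cross-attention layer is listed under its dotted module path,
# e.g. "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor".
print(len(unet.attn_processors), "attention processors")

# Processors can be swapped in bulk: a single instance for all layers, or a dict
# keyed by the same paths as `attn_processors` (here a no-op round-trip).
unet.set_attn_processor(dict(unet.attn_processors))

# "auto" halves each sliceable head dimension so attention runs in two slices
# per layer, trading a little speed for lower peak memory.
unet.set_attention_slice("auto")
```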
- # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)): - module.gradient_checkpointing = value - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - mid_block_additional_residual: Optional[torch.Tensor] = None, - return_dict: bool = True, - ) -> Union[UNet2DConditionOutput, Tuple]: - r""" - Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - - Returns: - [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - default_overall_up_factor = 2**self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False - upsample_size = None - - if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - logger.info("Forward upsample size to force interpolation output size.") - forward_upsample_size = True - - # prepare attention_mask - if attention_mask is not None: - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) - - emb = self.time_embedding(t_emb, timestep_cond) - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb - - # 2. pre-process - sample = self.conv_in(sample) - - # 3. down - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - - down_block_res_samples += res_samples - - if down_block_additional_residuals is not None: - new_down_block_res_samples = () - - for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals - ): - down_block_res_sample += down_block_additional_residual - new_down_block_res_samples += (down_block_res_sample,) - - down_block_res_samples = new_down_block_res_samples - - # 4. mid - if self.mid_block is not None: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - ) - - if mid_block_additional_residual is not None: - sample += mid_block_additional_residual - - # 5. up - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - upsample_size=upsample_size, - attention_mask=attention_mask, - ) - else: - sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size - ) - - # 6. 
post-process - if self.conv_norm_out: - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) - - -class LinearMultiDim(nn.Linear): - def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs): - in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features) - if out_features is None: - out_features = in_features - out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features) - self.in_features_multidim = in_features - self.out_features_multidim = out_features - super().__init__(np.array(in_features).prod(), np.array(out_features).prod()) - - def forward(self, input_tensor, *args, **kwargs): - shape = input_tensor.shape - n_dim = len(self.in_features_multidim) - input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features) - output_tensor = super().forward(input_tensor) - output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim) - return output_tensor - - -class ResnetBlockFlat(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - dropout=0.0, - temb_channels=512, - groups=32, - groups_out=None, - pre_norm=True, - eps=1e-6, - time_embedding_norm="default", - use_in_shortcut=None, - second_dim=4, - **kwargs, - ): - super().__init__() - self.pre_norm = pre_norm - self.pre_norm = True - - in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels) - self.in_channels_prod = np.array(in_channels).prod() - self.channels_multidim = in_channels - - if out_channels is not None: - out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels) - out_channels_prod = np.array(out_channels).prod() - self.out_channels_multidim = out_channels - else: - out_channels_prod = self.in_channels_prod - self.out_channels_multidim = self.channels_multidim - self.time_embedding_norm = time_embedding_norm - - if groups_out is None: - groups_out = groups - - self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True) - self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0) - - if temb_channels is not None: - self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod) - else: - self.time_emb_proj = None - - self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0) - - self.nonlinearity = nn.SiLU() - - self.use_in_shortcut = ( - self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut - ) - - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = torch.nn.Conv2d( - self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, input_tensor, temb): - shape = input_tensor.shape - n_dim = len(self.channels_multidim) - input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1) - input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1) - - hidden_states = input_tensor - - hidden_states = self.norm1(hidden_states) - hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.conv1(hidden_states) - - if temb is not None: - temb = 
self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] - hidden_states = hidden_states + temb - - hidden_states = self.norm2(hidden_states) - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) - - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = input_tensor + hidden_states - - output_tensor = output_tensor.view(*shape[0:-n_dim], -1) - output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim) - - return output_tensor - - -# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim -class DownBlockFlat(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_downsample=True, - downsample_padding=1, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlockFlat( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - LinearMultiDim( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward(self, hidden_states, temb=None): - output_states = () - - for resnet in self.resnets: - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - else: - hidden_states = resnet(hidden_states, temb) - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim -class CrossAttnDownBlockFlat(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - for i in 
range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlockFlat( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - LinearMultiDim( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None - ): - # TODO(Patrick, William) - attention mask is not used - output_states = () - - for resnet, attn in zip(self.resnets, self.attentions): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - - -# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim -class UpBlockFlat(nn.Module): - def __init__( - self, - in_channels: int, - prev_output_channel: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlockFlat( - in_channels=resnet_in_channels + 
res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - else: - hidden_states = resnet(hidden_states, temb) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim -class CrossAttnUpBlockFlat(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - ): - super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlockFlat( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = 
nn.ModuleList(resnets) - - if add_upsample: - self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, - ): - # TODO(Patrick, William) - attention mask is not used - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - cross_attention_kwargs, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - -# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat -class UNetMidBlockFlatCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - dual_cross_attention=False, - use_linear_projection=False, - upcast_attention=False, - ): - super().__init__() - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - - # there is always at least one resnet - resnets = [ - ResnetBlockFlat( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) - ) - 
resnets.append( - ResnetBlockFlat( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None - ): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - ).sample - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat -class UNetMidBlockFlatSimpleCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - output_scale_factor=1.0, - cross_attention_dim=1280, - ): - super().__init__() - - self.has_cross_attention = True - - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - - self.num_heads = in_channels // self.attn_num_head_channels - - # there is always at least one resnet - resnets = [ - ResnetBlockFlat( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] - - for _ in range(num_layers): - attentions.append( - CrossAttention( - query_dim=in_channels, - cross_attention_dim=in_channels, - heads=self.num_heads, - dim_head=attn_num_head_channels, - added_kv_proj_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - bias=True, - upcast_softmax=True, - processor=CrossAttnAddedKVProcessor(), - ) - ) - resnets.append( - ResnetBlockFlat( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None - ): - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - # attn - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) - - # resnet - hidden_states = resnet(hidden_states, temb) - - return hidden_states diff --git 
a/my_diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/my_diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py deleted file mode 100644 index f482ef11940a3315665675720392b39bf76547b5..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py +++ /dev/null @@ -1,434 +0,0 @@ -import inspect -from typing import Callable, List, Optional, Union - -import PIL.Image -import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import logging -from ..pipeline_utils import DiffusionPipeline -from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline -from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline -from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class VersatileDiffusionPipeline(DiffusionPipeline): - r""" - Pipeline for text-to-image generation using Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionMegaSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
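        This wrapper exposes three entry points (`text_to_image`, `image_variation`, `dual_guided`) that share one set
        of registered modules. A condensed sketch of loading the published checkpoint and exercising each mode,
        mirroring the per-method examples further below:

        ```py
        import torch
        from diffusers import VersatileDiffusionPipeline

        pipe = VersatileDiffusionPipeline.from_pretrained(
            "shi-labs/versatile-diffusion", torch_dtype=torch.float16
        ).to("cuda")

        generator = torch.Generator(device="cuda").manual_seed(0)
        text = "a red car in the sun"

        image = pipe.text_to_image(text, generator=generator).images[0]
        variation = pipe.image_variation(image, generator=generator).images[0]
        blended = pipe.dual_guided(prompt=text, image=image, generator=generator).images[0]
        ```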
- """ - - tokenizer: CLIPTokenizer - image_feature_extractor: CLIPFeatureExtractor - text_encoder: CLIPTextModel - image_encoder: CLIPVisionModel - image_unet: UNet2DConditionModel - text_unet: UNet2DConditionModel - vae: AutoencoderKL - scheduler: KarrasDiffusionSchedulers - - def __init__( - self, - tokenizer: CLIPTokenizer, - image_feature_extractor: CLIPFeatureExtractor, - text_encoder: CLIPTextModel, - image_encoder: CLIPVisionModel, - image_unet: UNet2DConditionModel, - text_unet: UNet2DConditionModel, - vae: AutoencoderKL, - scheduler: KarrasDiffusionSchedulers, - ): - super().__init__() - - self.register_modules( - tokenizer=tokenizer, - image_feature_extractor=image_feature_extractor, - text_encoder=text_encoder, - image_encoder=image_encoder, - image_unet=image_unet, - text_unet=text_unet, - vae=vae, - scheduler=scheduler, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - @torch.no_grad() - def image_variation( - self, - image: Union[torch.FloatTensor, PIL.Image.Image], - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): - The image prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. 
- latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Examples: - - ```py - >>> from diffusers import VersatileDiffusionPipeline - >>> import torch - >>> import requests - >>> from io import BytesIO - >>> from PIL import Image - - >>> # let's download an initial image - >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" - - >>> response = requests.get(url) - >>> image = Image.open(BytesIO(response.content)).convert("RGB") - - >>> pipe = VersatileDiffusionPipeline.from_pretrained( - ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> generator = torch.Generator(device="cuda").manual_seed(0) - >>> image = pipe.image_variation(image, generator=generator).images[0] - >>> image.save("./car_variation.png") - ``` - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. 
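        Under the hood this method (like `text_to_image` and `dual_guided` below) forwards to the dedicated
        sub-pipeline, passing along only the registered components that the sub-pipeline's constructor accepts. A
        sketch of that hand-off, where `pipe` and `init_image` are placeholders for a loaded
        `VersatileDiffusionPipeline` and an input image:

        ```py
        import inspect

        from diffusers import VersatileDiffusionImageVariationPipeline

        # Keep only the components declared by the sub-pipeline's __init__.
        expected = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys()
        components = {name: module for name, module in pipe.components.items() if name in expected}

        # Build the sub-pipeline on the fly and call it like any other DiffusionPipeline.
        variation = VersatileDiffusionImageVariationPipeline(**components)(image=init_image).images[0]
        ```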
- """ - expected_components = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys() - components = {name: component for name, component in self.components.items() if name in expected_components} - return VersatileDiffusionImageVariationPipeline(**components)( - image=image, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - negative_prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - latents=latents, - output_type=output_type, - return_dict=return_dict, - callback=callback, - callback_steps=callback_steps, - ) - - @torch.no_grad() - def text_to_image( - self, - prompt: Union[str, List[str]], - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. 
- output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Examples: - - ```py - >>> from diffusers import VersatileDiffusionPipeline - >>> import torch - - >>> pipe = VersatileDiffusionPipeline.from_pretrained( - ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> generator = torch.Generator(device="cuda").manual_seed(0) - >>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0] - >>> image.save("./astronaut.png") - ``` - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - expected_components = inspect.signature(VersatileDiffusionTextToImagePipeline.__init__).parameters.keys() - components = {name: component for name, component in self.components.items() if name in expected_components} - temp_pipeline = VersatileDiffusionTextToImagePipeline(**components) - output = temp_pipeline( - prompt=prompt, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - negative_prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - latents=latents, - output_type=output_type, - return_dict=return_dict, - callback=callback, - callback_steps=callback_steps, - ) - # swap the attention blocks back to the original state - temp_pipeline._swap_unet_attention_blocks() - - return output - - @torch.no_grad() - def dual_guided( - self, - prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], - image: Union[str, List[str]], - text_to_image_strength: float = 0.5, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. 
- width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Examples: - - ```py - >>> from diffusers import VersatileDiffusionPipeline - >>> import torch - >>> import requests - >>> from io import BytesIO - >>> from PIL import Image - - >>> # let's download an initial image - >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" - - >>> response = requests.get(url) - >>> image = Image.open(BytesIO(response.content)).convert("RGB") - >>> text = "a red car in the sun" - - >>> pipe = VersatileDiffusionPipeline.from_pretrained( - ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> generator = torch.Generator(device="cuda").manual_seed(0) - >>> text_to_image_strength = 0.75 - - >>> image = pipe.dual_guided( - ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator - ... 
).images[0] - >>> image.save("./car_variation.png") - ``` - - Returns: - [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When - returning a tuple, the first element is a list with the generated images. - """ - - expected_components = inspect.signature(VersatileDiffusionDualGuidedPipeline.__init__).parameters.keys() - components = {name: component for name, component in self.components.items() if name in expected_components} - temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components) - output = temp_pipeline( - prompt=prompt, - image=image, - text_to_image_strength=text_to_image_strength, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - num_images_per_prompt=num_images_per_prompt, - eta=eta, - generator=generator, - latents=latents, - output_type=output_type, - return_dict=return_dict, - callback=callback, - callback_steps=callback_steps, - ) - temp_pipeline._revert_dual_attention() - - return output diff --git a/my_diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/my_diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py deleted file mode 100644 index 529d9a2ae9c09e775a131803888939e3bacbefd6..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ /dev/null @@ -1,585 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Tuple, Union - -import numpy as np -import PIL -import torch -import torch.utils.checkpoint -from transformers import ( - CLIPFeatureExtractor, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionModelWithProjection, -) - -from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .modeling_text_unet import UNetFlatConditionModel - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - vqvae ([`VQModel`]): - Vector-quantized (VQ) Model to encode and decode images to and from latent representations. - bert ([`LDMBertModel`]): - Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. 
- tokenizer (`transformers.BertTokenizer`): - Tokenizer of class - [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - """ - tokenizer: CLIPTokenizer - image_feature_extractor: CLIPFeatureExtractor - text_encoder: CLIPTextModelWithProjection - image_encoder: CLIPVisionModelWithProjection - image_unet: UNet2DConditionModel - text_unet: UNetFlatConditionModel - vae: AutoencoderKL - scheduler: KarrasDiffusionSchedulers - - _optional_components = ["text_unet"] - - def __init__( - self, - tokenizer: CLIPTokenizer, - image_feature_extractor: CLIPFeatureExtractor, - text_encoder: CLIPTextModelWithProjection, - image_encoder: CLIPVisionModelWithProjection, - image_unet: UNet2DConditionModel, - text_unet: UNetFlatConditionModel, - vae: AutoencoderKL, - scheduler: KarrasDiffusionSchedulers, - ): - super().__init__() - self.register_modules( - tokenizer=tokenizer, - image_feature_extractor=image_feature_extractor, - text_encoder=text_encoder, - image_encoder=image_encoder, - image_unet=image_unet, - text_unet=text_unet, - vae=vae, - scheduler=scheduler, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - if self.text_unet is not None and ( - "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention - ): - # if loading from a universal checkpoint rather than a saved dual-guided pipeline - self._convert_to_dual_attention() - - def remove_unused_weights(self): - self.register_modules(text_unet=None) - - def _convert_to_dual_attention(self): - """ - Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` that contains transformer blocks - from both `image_unet` and `text_unet` - """ - for name, module in self.image_unet.named_modules(): - if isinstance(module, Transformer2DModel): - parent_name, index = name.rsplit(".", 1) - index = int(index) - - image_transformer = self.image_unet.get_submodule(parent_name)[index] - text_transformer = self.text_unet.get_submodule(parent_name)[index] - - config = image_transformer.config - dual_transformer = DualTransformer2DModel( - num_attention_heads=config.num_attention_heads, - attention_head_dim=config.attention_head_dim, - in_channels=config.in_channels, - num_layers=config.num_layers, - dropout=config.dropout, - norm_num_groups=config.norm_num_groups, - cross_attention_dim=config.cross_attention_dim, - attention_bias=config.attention_bias, - sample_size=config.sample_size, - num_vector_embeds=config.num_vector_embeds, - activation_fn=config.activation_fn, - num_embeds_ada_norm=config.num_embeds_ada_norm, - ) - dual_transformer.transformers[0] = image_transformer - dual_transformer.transformers[1] = text_transformer - - self.image_unet.get_submodule(parent_name)[index] = dual_transformer - self.image_unet.register_to_config(dual_cross_attention=True) - - def _revert_dual_attention(self): - """ - Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights Call - this function if you reuse `image_unet` in another pipeline, e.g. 
`VersatileDiffusionPipeline` - """ - for name, module in self.image_unet.named_modules(): - if isinstance(module, DualTransformer2DModel): - parent_name, index = name.rsplit(".", 1) - index = int(index) - self.image_unet.get_submodule(parent_name)[index] = module.transformers[0] - - self.image_unet.register_to_config(dual_cross_attention=False) - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.image_unet, "_hf_hook"): - return self.device - for module in self.image_unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): - r""" - Encodes the prompt into text encoder hidden states. 
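A brief usage sketch for the sequential CPU offload defined above; this assumes `accelerate` is installed and reuses the checkpoint id from the examples elsewhere in this file.

```py
>>> import torch
>>> from diffusers import VersatileDiffusionDualGuidedPipeline

>>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(
...     "shi-labs/versatile-diffusion", torch_dtype=torch.float16
... )
>>> # note: do not call `pipe.to("cuda")` first; each submodule is moved to
>>> # `cuda:0` only for the duration of its forward pass, then offloaded again
>>> pipe.enable_sequential_cpu_offload(gpu_id=0)
```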
- - Args: - prompt (`str` or `List[str]`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - """ - - def normalize_embeddings(encoder_output): - embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) - embeds_pooled = encoder_output.text_embeds - embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) - return embeds - - batch_size = len(prompt) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids - - if not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = normalize_embeddings(prompt_embeds) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens = [""] * batch_size - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): - r""" - Encodes the prompt into text encoder hidden states. 
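For readability, the `normalize_embeddings` closure used by `_encode_text_prompt` above can be read as the following standalone sketch (the helper name is illustrative, not part of the pipeline API):

```py
import torch


def normalize_clip_text_embeds(encoder_output, text_projection):
    # Project every token embedding, then rescale by the norm of the pooled
    # (EOS) embedding rather than normalizing each token independently.
    embeds = text_projection(encoder_output.last_hidden_state)
    pooled = encoder_output.text_embeds
    return embeds / torch.norm(pooled.unsqueeze(1), dim=-1, keepdim=True)
```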
- - Args: - prompt (`str` or `List[str]`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - """ - - def normalize_embeddings(encoder_output): - embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) - embeds = self.image_encoder.visual_projection(embeds) - embeds_pooled = embeds[:, 0:1] - embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) - return embeds - - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") - pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) - image_embeddings = self.image_encoder(pixel_values) - image_embeddings = normalize_embeddings(image_embeddings) - - # duplicate image embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = image_embeddings.shape - image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) - image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size - uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") - pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) - negative_prompt_embeds = self.image_encoder(pixel_values) - negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and conditional embeddings into a single batch - # to avoid doing two forward passes - image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) - - return image_embeddings - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs(self, prompt, image, height, width, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` `PIL.Image` or `list` but is {type(prompt)}") - if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list): - raise ValueError(f"`image` has to be of type `str` `PIL.Image` or `list` but is {type(image)}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: Tuple = ("text", "image")): - for name, module in self.image_unet.named_modules(): - if isinstance(module, DualTransformer2DModel): - module.mix_ratio = mix_ratio - - for i, type in enumerate(condition_types): - if type == "text": - module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings - module.transformer_index_for_condition[i] = 1 # use the second (text) transformer - else: - module.condition_lengths[i] = 257 - module.transformer_index_for_condition[i] = 0 # use the first (image) transformer - - @torch.no_grad() - def __call__( - self, - prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], - image: Union[str, List[str]], - text_to_image_strength: float = 0.5, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - **kwargs, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. 
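A hedged sketch of the list-of-generators pattern mentioned for `generator` above; `pipe` is an instance of this dual-guided pipeline and `init_image` stands in for any PIL image.

```py
>>> import torch
>>> # one generator per generated image keeps every sample in the batch reproducible;
>>> # the list length must equal batch_size * num_images_per_prompt (see `prepare_latents`)
>>> generators = [torch.Generator(device="cuda").manual_seed(seed) for seed in (0, 1)]
>>> images = pipe(
...     prompt="a red car in the sun",
...     image=init_image,  # any PIL.Image.Image (placeholder)
...     text_to_image_strength=0.75,
...     num_images_per_prompt=2,
...     generator=generators,
... ).images
```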
- latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Examples: - - ```py - >>> from diffusers import VersatileDiffusionDualGuidedPipeline - >>> import torch - >>> import requests - >>> from io import BytesIO - >>> from PIL import Image - - >>> # let's download an initial image - >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" - - >>> response = requests.get(url) - >>> image = Image.open(BytesIO(response.content)).convert("RGB") - >>> text = "a red car in the sun" - - >>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained( - ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 - ... ) - >>> pipe.remove_unused_weights() - >>> pipe = pipe.to("cuda") - - >>> generator = torch.Generator(device="cuda").manual_seed(0) - >>> text_to_image_strength = 0.75 - - >>> image = pipe( - ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator - ... ).images[0] - >>> image.save("./car_variation.png") - ``` - - Returns: - [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When - returning a tuple, the first element is a list with the generated images. - """ - # 0. Default height and width to unet - height = height or self.image_unet.config.sample_size * self.vae_scale_factor - width = width or self.image_unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, image, height, width, callback_steps) - - # 2. Define call parameters - prompt = [prompt] if not isinstance(prompt, list) else prompt - image = [image] if not isinstance(image, list) else image - batch_size = len(prompt) - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompts - prompt_embeds = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) - image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance) - dual_prompt_embeddings = torch.cat([prompt_embeds, image_embeddings], dim=1) - prompt_types = ("text", "image") - - # 4. 
Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.image_unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - dual_prompt_embeddings.dtype, - device, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7. Combine the attention blocks of the image and text UNets - self.set_transformer_params(text_to_image_strength, prompt_types) - - # 8. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 9. Post-processing - image = self.decode_latents(latents) - - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/my_diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py deleted file mode 100644 index fd6855af3852d2624b3bebace2158dfa780adbf9..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ /dev/null @@ -1,427 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Union - -import numpy as np -import PIL -import torch -import torch.utils.checkpoint -from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection - -from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. 
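The guidance update inside the denoising loop above is the standard classifier-free guidance extrapolation; written as a standalone helper (name illustrative) it is simply:

```py
import torch


def apply_classifier_free_guidance(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # `noise_pred` stacks the unconditional and conditional predictions along the
    # batch dimension, exactly as produced by the single batched UNet call above.
    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
```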
Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - vqvae ([`VQModel`]): - Vector-quantized (VQ) Model to encode and decode images to and from latent representations. - bert ([`LDMBertModel`]): - Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. - tokenizer (`transformers.BertTokenizer`): - Tokenizer of class - [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - """ - image_feature_extractor: CLIPFeatureExtractor - image_encoder: CLIPVisionModelWithProjection - image_unet: UNet2DConditionModel - vae: AutoencoderKL - scheduler: KarrasDiffusionSchedulers - - def __init__( - self, - image_feature_extractor: CLIPFeatureExtractor, - image_encoder: CLIPVisionModelWithProjection, - image_unet: UNet2DConditionModel, - vae: AutoencoderKL, - scheduler: KarrasDiffusionSchedulers, - ): - super().__init__() - self.register_modules( - image_feature_extractor=image_feature_extractor, - image_encoder=image_encoder, - image_unet=image_unet, - vae=vae, - scheduler=scheduler, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.image_unet, "_hf_hook"): - return self.device - for module in self.image_unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. 
- - Args: - prompt (`str` or `List[str]`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - """ - - def normalize_embeddings(encoder_output): - embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) - embeds = self.image_encoder.visual_projection(embeds) - embeds_pooled = embeds[:, 0:1] - embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) - return embeds - - if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4: - prompt = [p for p in prompt] - - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") - pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) - image_embeddings = self.image_encoder(pixel_values) - image_embeddings = normalize_embeddings(image_embeddings) - - # duplicate image embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = image_embeddings.shape - image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) - image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_images: List[str] - if negative_prompt is None: - uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, PIL.Image.Image): - uncond_images = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_images = negative_prompt - - uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") - pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) - negative_prompt_embeds = self.image_encoder(pixel_values) - negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and conditional embeddings into a single batch - # to avoid doing two forward passes - image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) - - return image_embeddings - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs - def check_inputs(self, image, height, width, callback_steps): - if ( - not isinstance(image, torch.Tensor) - and not isinstance(image, PIL.Image.Image) - and not isinstance(image, list) - ): - raise ValueError( - "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" - f" {type(image)}" - ) - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - @torch.no_grad() - def __call__( - self, - image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor], - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - **kwargs, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): - The image prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
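As a hedged illustration of the `latents` argument documented above: reusing pre-generated latents fixes the starting noise across calls, for example to compare guidance scales. `pipe` is an instance of this image-variation pipeline, `init_image` any PIL image, and the latent shape is an assumption for the default 512x512 resolution (4 latent channels, VAE downsampling factor 8).

```py
>>> import torch
>>> generator = torch.Generator(device="cuda").manual_seed(0)
>>> # shape = (batch, latent_channels, height // 8, width // 8), assumed defaults
>>> latents = torch.randn((1, 4, 64, 64), generator=generator, device="cuda", dtype=torch.float16)
>>> low = pipe(image=init_image, latents=latents, guidance_scale=3.0).images[0]
>>> high = pipe(image=init_image, latents=latents, guidance_scale=9.0).images[0]
```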
- return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Examples: - - ```py - >>> from diffusers import VersatileDiffusionImageVariationPipeline - >>> import torch - >>> import requests - >>> from io import BytesIO - >>> from PIL import Image - - >>> # let's download an initial image - >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" - - >>> response = requests.get(url) - >>> image = Image.open(BytesIO(response.content)).convert("RGB") - - >>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained( - ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> generator = torch.Generator(device="cuda").manual_seed(0) - >>> image = pipe(image, generator=generator).images[0] - >>> image.save("./car_variation.png") - ``` - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height = height or self.image_unet.config.sample_size * self.vae_scale_factor - width = width or self.image_unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs(image, height, width, callback_steps) - - # 2. Define call parameters - batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image) - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - image_embeddings = self._encode_prompt( - image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt - ) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.image_unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - image_embeddings.dtype, - device, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7. 
Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/my_diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py deleted file mode 100644 index d1bb754c7b58d822393a5d202454ebf63dd39b35..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ /dev/null @@ -1,501 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, List, Optional, Union - -import torch -import torch.utils.checkpoint -from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer - -from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, logging, randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .modeling_text_unet import UNetFlatConditionModel - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): - r""" - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Parameters: - vqvae ([`VQModel`]): - Vector-quantized (VQ) Model to encode and decode images to and from latent representations. - bert ([`LDMBertModel`]): - Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. - tokenizer (`transformers.BertTokenizer`): - Tokenizer of class - [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). 
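For reference, the post-processing performed by `decode_latents` in the loop above amounts to the following standalone sketch (helper name illustrative):

```py
import torch


def decode_vae_latents(vae, latents: torch.Tensor):
    # Undo the scaling applied at encoding time, decode with the VAE, and map the
    # result from [-1, 1] to [0, 1] as a float32 NumPy array in NHWC layout.
    latents = latents / vae.config.scaling_factor
    image = vae.decode(latents).sample
    return (image / 2 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).float().numpy()
```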
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - """ - tokenizer: CLIPTokenizer - image_feature_extractor: CLIPFeatureExtractor - text_encoder: CLIPTextModelWithProjection - image_unet: UNet2DConditionModel - text_unet: UNetFlatConditionModel - vae: AutoencoderKL - scheduler: KarrasDiffusionSchedulers - - _optional_components = ["text_unet"] - - def __init__( - self, - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModelWithProjection, - image_unet: UNet2DConditionModel, - text_unet: UNetFlatConditionModel, - vae: AutoencoderKL, - scheduler: KarrasDiffusionSchedulers, - ): - super().__init__() - self.register_modules( - tokenizer=tokenizer, - text_encoder=text_encoder, - image_unet=image_unet, - text_unet=text_unet, - vae=vae, - scheduler=scheduler, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - if self.text_unet is not None: - self._swap_unet_attention_blocks() - - def _swap_unet_attention_blocks(self): - """ - Swap the `Transformer2DModel` blocks between the image and text UNets - """ - for name, module in self.image_unet.named_modules(): - if isinstance(module, Transformer2DModel): - parent_name, index = name.rsplit(".", 1) - index = int(index) - self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = ( - self.text_unet.get_submodule(parent_name)[index], - self.image_unet.get_submodule(parent_name)[index], - ) - - def remove_unused_weights(self): - self.register_modules(text_unet=None) - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.image_unet, "_hf_hook"): - return self.device - for module in self.image_unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. 
- - Args: - prompt (`str` or `List[str]`): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - """ - - def normalize_embeddings(encoder_output): - embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) - embeds_pooled = encoder_output.text_embeds - embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) - return embeds - - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids - - if not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = normalize_embeddings(prompt_embeds) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = negative_prompt - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs - def check_inputs( - self, - prompt, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." 
- ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - **kwargs, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). 
Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will be generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Examples: - - ```py - >>> from diffusers import VersatileDiffusionTextToImagePipeline - >>> import torch - - >>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained( - ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 - ... ) - >>> pipe.remove_unused_weights() - >>> pipe = pipe.to("cuda") - - >>> generator = torch.Generator(device="cuda").manual_seed(0) - >>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0] - >>> image.save("./astronaut.png") - ``` - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height = height or self.image_unet.config.sample_size * self.vae_scale_factor - width = width or self.image_unet.config.sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, callback_steps) - - # 2.
Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt - ) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.image_unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 9. Post-processing - image = self.decode_latents(latents) - - # 10. 
Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) diff --git a/my_diffusers/pipelines/vq_diffusion/__init__.py b/my_diffusers/pipelines/vq_diffusion/__init__.py deleted file mode 100644 index 8c9f14f000648347fe75a5bec0cb45d08c7d2ff9..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/vq_diffusion/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from ...utils import is_torch_available, is_transformers_available - - -if is_transformers_available() and is_torch_available(): - from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline diff --git a/my_diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-311.pyc b/my_diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 33ce1d9af760044d13e5397db6b92dd794694c13..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/vq_diffusion/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-311.pyc b/my_diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-311.pyc deleted file mode 100644 index fa52dbe5c0eb291b51a8aee58f4dafc2d1bf49ae..0000000000000000000000000000000000000000 Binary files a/my_diffusers/pipelines/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/my_diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py deleted file mode 100644 index 9147afe127e4b24366249c4a6e058abae9501050..0000000000000000000000000000000000000000 --- a/my_diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py +++ /dev/null @@ -1,330 +0,0 @@ -# Copyright 2023 Microsoft and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
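For orientation, the deleted `__call__` above boils down to a standard classifier-free-guidance denoising loop. The sketch below restates that loop in isolation, assuming only `torch` and the upstream `diffusers.DDIMScheduler`; the `dummy_unet` and the random embedding tensors are hypothetical stand-ins for `self.image_unet` and the CLIP prompt embeddings, not the real pipeline components.

```py
# Minimal sketch of the classifier-free-guidance loop from the pipeline above.
# All shapes and the dummy model are illustrative stand-ins, not the real pipeline.
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(10)

guidance_scale = 7.5
latents = torch.randn(1, 4, 8, 8) * scheduler.init_noise_sigma
cond = torch.randn(1, 77, 768)      # stand-in for the text prompt embeddings
uncond = torch.zeros_like(cond)     # stand-in for the negative-prompt embeddings
prompt_embeds = torch.cat([uncond, cond])

def dummy_unet(sample, timestep, encoder_hidden_states):
    # placeholder noise prediction; the real pipeline calls self.image_unet(...).sample
    return torch.randn_like(sample)

for t in scheduler.timesteps:
    latent_model_input = torch.cat([latents] * 2)                  # one copy per embedding batch
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)
    noise_pred = dummy_unet(latent_model_input, t, encoder_hidden_states=prompt_embeds)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    latents = scheduler.step(noise_pred, t, latents).prev_sample   # x_t -> x_{t-1}
```

The real method then decodes `latents` with the VAE and optionally converts the result to PIL, exactly as in steps 9 and 10 above.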
- -from typing import Callable, List, Optional, Tuple, Union - -import torch -from transformers import CLIPTextModel, CLIPTokenizer - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models import ModelMixin, Transformer2DModel, VQModel -from ...schedulers import VQDiffusionScheduler -from ...utils import logging -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin): - """ - Utility class for storing learned text embeddings for classifier free sampling - """ - - @register_to_config - def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None): - super().__init__() - - self.learnable = learnable - - if self.learnable: - assert hidden_size is not None, "learnable=True requires `hidden_size` to be set" - assert length is not None, "learnable=True requires `length` to be set" - - embeddings = torch.zeros(length, hidden_size) - else: - embeddings = None - - self.embeddings = torch.nn.Parameter(embeddings) - - -class VQDiffusionPipeline(DiffusionPipeline): - r""" - Pipeline for text-to-image generation using VQ Diffusion - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vqvae ([`VQModel`]): - Vector Quantized Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent - representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. VQ Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - transformer ([`Transformer2DModel`]): - Conditional transformer to denoise the encoded image latents. - scheduler ([`VQDiffusionScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 
- """ - - vqvae: VQModel - text_encoder: CLIPTextModel - tokenizer: CLIPTokenizer - transformer: Transformer2DModel - learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings - scheduler: VQDiffusionScheduler - - def __init__( - self, - vqvae: VQModel, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - transformer: Transformer2DModel, - scheduler: VQDiffusionScheduler, - learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings, - ): - super().__init__() - - self.register_modules( - vqvae=vqvae, - transformer=transformer, - text_encoder=text_encoder, - tokenizer=tokenizer, - scheduler=scheduler, - learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, - ) - - def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance): - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - - if text_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - prompt_embeds = self.text_encoder(text_input_ids.to(self.device))[0] - - # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. - # While CLIP does normalize the pooled output of the text transformer when combining - # the image and text embeddings, CLIP does not directly normalize the last hidden state. - # - # CLIP normalizing the pooled output. - # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 - prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True) - - # duplicate text embeddings for each generation per prompt - prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) - - if do_classifier_free_guidance: - if self.learned_classifier_free_sampling_embeddings.learnable: - negative_prompt_embeds = self.learned_classifier_free_sampling_embeddings.embeddings - negative_prompt_embeds = negative_prompt_embeds.unsqueeze(0).repeat(batch_size, 1, 1) - else: - uncond_tokens = [""] * batch_size - - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0] - # See comment for normalizing text embeddings - negative_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True) - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - num_inference_steps: int = 100, - guidance_scale: float = 5.0, - truncation_rate: float = 1.0, - num_images_per_prompt: int = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - ) -> Union[ImagePipelineOutput, Tuple]: - """ - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - num_inference_steps (`int`, *optional*, defaults to 100): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 5.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)): - Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at - most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above - `truncation_rate` are set to zero. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`torch.Generator`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor` of shape (batch), *optional*): - Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices. - Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will - be generated of completely masked latent pixels. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - - Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
- """ - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - batch_size = batch_size * num_images_per_prompt - - do_classifier_free_guidance = guidance_scale > 1.0 - - prompt_embeds = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance) - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - # get the initial completely masked latents unless the user supplied it - - latents_shape = (batch_size, self.transformer.num_latent_pixels) - if latents is None: - mask_class = self.transformer.num_vector_embeds - 1 - latents = torch.full(latents_shape, mask_class).to(self.device) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any(): - raise ValueError( - "Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0," - f" {self.transformer.num_vector_embeds - 1} (inclusive)." - ) - latents = latents.to(self.device) - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=self.device) - - timesteps_tensor = self.scheduler.timesteps.to(self.device) - - sample = latents - - for i, t in enumerate(self.progress_bar(timesteps_tensor)): - # expand the sample if we are doing classifier free guidance - latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample - - # predict the un-noised image - # model_output == `log_p_x_0` - model_output = self.transformer(latent_model_input, encoder_hidden_states=prompt_embeds, timestep=t).sample - - if do_classifier_free_guidance: - model_output_uncond, model_output_text = model_output.chunk(2) - model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond) - model_output -= torch.logsumexp(model_output, dim=1, keepdim=True) - - model_output = self.truncate(model_output, truncation_rate) - - # remove `log(0)`'s (`-inf`s) - model_output = model_output.clamp(-70) - - # compute the previous noisy sample x_t -> x_t-1 - sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, sample) - - embedding_channels = self.vqvae.config.vq_embed_dim - embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels) - embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape) - image = self.vqvae.decode(embeddings, force_not_quantize=True).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image) - - def truncate(self, log_p_x_0: torch.FloatTensor, truncation_rate: float) -> torch.FloatTensor: - """ - Truncates log_p_x_0 such that for each column vector, the total cumulative probability is `truncation_rate` The - lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to zero. 
- """ - sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True) - sorted_p_x_0 = torch.exp(sorted_log_p_x_0) - keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate - - # Ensure that at least the largest probability is not zeroed out - all_true = torch.full_like(keep_mask[:, 0:1, :], True) - keep_mask = torch.cat((all_true, keep_mask), dim=1) - keep_mask = keep_mask[:, :-1, :] - - keep_mask = keep_mask.gather(1, indices.argsort(1)) - - rv = log_p_x_0.clone() - - rv[~keep_mask] = -torch.inf # -inf = log(0) - - return rv diff --git a/my_diffusers/schedulers/__init__.py b/my_diffusers/schedulers/__init__.py deleted file mode 100644 index e5d5bb40633f39008090ae56c15b94a8bc378d07..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/__init__.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from ..utils import OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available - - -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils.dummy_pt_objects import * # noqa F403 -else: - from .scheduling_ddim import DDIMScheduler - from .scheduling_ddim_inverse import DDIMInverseScheduler - from .scheduling_ddpm import DDPMScheduler - from .scheduling_deis_multistep import DEISMultistepScheduler - from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler - from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler - from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler - from .scheduling_euler_discrete import EulerDiscreteScheduler - from .scheduling_heun_discrete import HeunDiscreteScheduler - from .scheduling_ipndm import IPNDMScheduler - from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler - from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler - from .scheduling_karras_ve import KarrasVeScheduler - from .scheduling_pndm import PNDMScheduler - from .scheduling_repaint import RePaintScheduler - from .scheduling_sde_ve import ScoreSdeVeScheduler - from .scheduling_sde_vp import ScoreSdeVpScheduler - from .scheduling_unclip import UnCLIPScheduler - from .scheduling_unipc_multistep import UniPCMultistepScheduler - from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin - from .scheduling_vq_diffusion import VQDiffusionScheduler - -try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils.dummy_flax_objects import * # noqa F403 -else: - from .scheduling_ddim_flax import FlaxDDIMScheduler - from .scheduling_ddpm_flax import FlaxDDPMScheduler - from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler - from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler - from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler - from .scheduling_pndm_flax import 
FlaxPNDMScheduler - from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler - from .scheduling_utils_flax import ( - FlaxKarrasDiffusionSchedulers, - FlaxSchedulerMixin, - FlaxSchedulerOutput, - broadcast_to_shape_from_left, - ) - - -try: - if not (is_torch_available() and is_scipy_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 -else: - from .scheduling_lms_discrete import LMSDiscreteScheduler diff --git a/my_diffusers/schedulers/__pycache__/__init__.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index d9215b5a8c838df2bf55de339c4595ad00020957..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_ddim.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_ddim.cpython-311.pyc deleted file mode 100644 index e2b5535e926e8fc9e9b865bfeb13ee82764e6436..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_ddim.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_ddim_flax.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_ddim_flax.cpython-311.pyc deleted file mode 100644 index d8894b1b6f492cd65b01c4e2811e651df9f86757..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_ddim_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-311.pyc deleted file mode 100644 index 1dfd703615d2ad87b186d952d6cb2a33fa3e8b80..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_ddim_inverse.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-311.pyc deleted file mode 100644 index 46b521296992a65c302280638829f79903b6c77f..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_ddpm.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_ddpm_flax.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_ddpm_flax.cpython-311.pyc deleted file mode 100644 index 64e5a521e4c2e696ca7d81271d63ec73789d7cac..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_ddpm_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-311.pyc deleted file mode 100644 index 2da03d3bb51fc3e5fc90f21d059ffabf84aeca8e..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_deis_multistep.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-311.pyc deleted file mode 100644 index c4ebd2dc6cd62744732a4ea43db0e01f346993bc..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep.cpython-311.pyc 
and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_flax.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_flax.cpython-311.pyc deleted file mode 100644 index 62b143f5e9babe1be888baebe1cae28e54f1e36f..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_dpmsolver_multistep_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-311.pyc deleted file mode 100644 index b170d690ce0bc6ce070661e68f76844a425af282..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_dpmsolver_singlestep.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-311.pyc deleted file mode 100644 index 4e9a5c64011d1c2acba34b7a77c73ee375bab34c..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_euler_ancestral_discrete.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-311.pyc deleted file mode 100644 index d9309e3440c5b0ce80806ef1f11aad20e72d2483..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_euler_discrete.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-311.pyc deleted file mode 100644 index 5ab7ffa88f52ad8743f87579262eacd3f9dbbdd1..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_heun_discrete.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-311.pyc deleted file mode 100644 index 2fcbd6a2a29ee71c86814bcd3997667decf8e570..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_ipndm.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-311.pyc deleted file mode 100644 index 3d0bae3a11309e54dc73d1565a8b7a001ed6f007..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_k_dpm_2_ancestral_discrete.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-311.pyc deleted file mode 100644 index fca3ad73d6952ff229002e9e9767e07d60af6dac..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_k_dpm_2_discrete.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-311.pyc deleted file mode 100644 index 
5d26b53ba6bcf60d1a15e7a3665137bde9f66591..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_karras_ve.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_karras_ve_flax.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_karras_ve_flax.cpython-311.pyc deleted file mode 100644 index 089d011a8d2f5d91e581aca200f6edf1338dede0..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_karras_ve_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-311.pyc deleted file mode 100644 index 23cf05b98b84afed5895eb1a087a5b699bf1d0ea..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_lms_discrete.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_lms_discrete_flax.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_lms_discrete_flax.cpython-311.pyc deleted file mode 100644 index b0d56e4cf884aafa3bc23a35e679888b7c7e7134..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_lms_discrete_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_pndm.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_pndm.cpython-311.pyc deleted file mode 100644 index b16c503c46701866203973883cbf8526a1e2ae31..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_pndm.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_pndm_flax.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_pndm_flax.cpython-311.pyc deleted file mode 100644 index 980ac93262c28a30cdc9bfcc3426279730bd6c91..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_pndm_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_repaint.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_repaint.cpython-311.pyc deleted file mode 100644 index c615270750d3480779a6210e154a0494e1f347cc..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_repaint.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-311.pyc deleted file mode 100644 index ac46dde0c5ccdc2f1c665eb7cf24fd03a8896ccb..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_sde_ve.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_sde_ve_flax.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_sde_ve_flax.cpython-311.pyc deleted file mode 100644 index 4979d5fdc69f5eb86230b97ed79f62ddc28a619b..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_sde_ve_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-311.pyc deleted file mode 100644 index 
2cf0c39f4cb228b0026c850c393d8176249020e0..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_sde_vp.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_unclip.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_unclip.cpython-311.pyc deleted file mode 100644 index f13c0624386a8e780888f28e91ed129ae8f1f2e4..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_unclip.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-311.pyc deleted file mode 100644 index bef3171360d13f99976bcc88fbfbe114034b3b67..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_unipc_multistep.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_utils.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_utils.cpython-311.pyc deleted file mode 100644 index eda435e6f7d79aae769d9ce1d9e718f05beb7b60..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_utils_flax.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_utils_flax.cpython-311.pyc deleted file mode 100644 index 3a32470be0cd6cfbabfc0ef2bdcf243659107220..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_utils_flax.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-311.pyc b/my_diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-311.pyc deleted file mode 100644 index 779252c16b682bb4e5a3ab0fa1f0c22afa2707a1..0000000000000000000000000000000000000000 Binary files a/my_diffusers/schedulers/__pycache__/scheduling_vq_diffusion.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/schedulers/scheduling_ddim.py b/my_diffusers/schedulers/scheduling_ddim.py deleted file mode 100644 index d107d34f45dce25bca50bca106ce4d27c27cdaeb..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_ddim.py +++ /dev/null @@ -1,373 +0,0 @@ -# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
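The schedulers `__init__` above wraps every import in a backend probe so that missing optional dependencies degrade to dummy placeholder objects instead of hard import errors. The snippet below shows that pattern in isolation; `is_scipy_available` and the placeholder class are simplified stand-ins for the helpers in `..utils`, and the success branch imports from the upstream `diffusers` package only so the example stays self-contained.

```py
# Simplified sketch of the optional-dependency guard pattern used above.
import importlib.util


class OptionalDependencyNotAvailable(ImportError):
    """Raised when an optional backend (scipy, flax, ...) is not installed."""


def is_scipy_available() -> bool:
    return importlib.util.find_spec("scipy") is not None


try:
    if not is_scipy_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    class LMSDiscreteScheduler:  # plays the role of the dummy_* placeholder objects
        def __init__(self, *args, **kwargs):
            raise ImportError("LMSDiscreteScheduler requires `scipy`: pip install scipy")
else:
    # inside the package itself this would be `from .scheduling_lms_discrete import LMSDiscreteScheduler`
    from diffusers.schedulers import LMSDiscreteScheduler

print(LMSDiscreteScheduler)  # either the real scheduler or the informative placeholder
```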
- -# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion -# and https://github.com/hojonathanho/diffusion - -import math -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin - - -@dataclass -# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM -class DDIMSchedulerOutput(BaseOutput): - """ - Output class for the scheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample (x_{0}) based on the model output from the current timestep. - `pred_original_sample` can be used to preview progress or for guidance. - """ - - prev_sample: torch.FloatTensor - pred_original_sample: Optional[torch.FloatTensor] = None - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class DDIMScheduler(SchedulerMixin, ConfigMixin): - """ - Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising - diffusion probabilistic models (DDPMs) with non-Markovian guidance. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - For more details, see the original paper: https://arxiv.org/abs/2010.02502 - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. 
- trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - clip_sample (`bool`, default `True`): - option to clip predicted sample between -1 and 1 for numerical stability. - set_alpha_to_one (`bool`, default `True`): - each diffusion step uses the value of alphas product at that step and at the previous one. For the final - step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, - otherwise it uses the value of alpha at step 0. - steps_offset (`int`, default `0`): - an offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in - stable diffusion. - prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - clip_sample: bool = True, - set_alpha_to_one: bool = True, - steps_offset: int = 0, - prediction_type: str = "epsilon", - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - # At every step in ddim, we are looking into the previous alphas_cumprod - # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this parameter simply to one or - # whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # setable values - self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) - - def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. 
- - Args: - sample (`torch.FloatTensor`): input sample - timestep (`int`, optional): current timestep - - Returns: - `torch.FloatTensor`: scaled input sample - """ - return sample - - def _get_variance(self, timestep, prev_timestep): - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) - - return variance - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - """ - - if num_inference_steps > self.config.num_train_timesteps: - raise ValueError( - f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" - f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" - f" maximal {self.config.num_train_timesteps} timesteps." - ) - - self.num_inference_steps = num_inference_steps - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) - self.timesteps = torch.from_numpy(timesteps).to(device) - self.timesteps += self.config.steps_offset - - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - eta: float = 0.0, - use_clipped_model_output: bool = False, - generator=None, - variance_noise: Optional[torch.FloatTensor] = None, - return_dict: bool = True, - ) -> Union[DDIMSchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - eta (`float`): weight of noise for added noise in diffusion step. - use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped - predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when - `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would - coincide with the one provided as input and `use_clipped_model_output` will have not effect. - generator: random number generator. - variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we - can directly provide the noise for the variance itself. This is useful for methods such as - CycleDiffusion. (https://arxiv.org/abs/2210.05559) - return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class - - Returns: - [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. 
When - returning a tuple, the first element is the sample tensor. - - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf - # Ideally, read DDIM paper in-detail understanding - - # Notation ( -> - # - pred_noise_t -> e_theta(x_t, t) - # - pred_original_sample -> f_theta(x_t, t) or x_0 - # - std_dev_t -> sigma_t - # - eta -> η - # - pred_sample_direction -> "direction pointing to x_t" - # - pred_prev_sample -> "x_t-1" - - # 1. get previous step value (=t-1) - prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps - - # 2. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - - beta_prod_t = 1 - alpha_prod_t - - # 3. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - if self.config.prediction_type == "epsilon": - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif self.config.prediction_type == "sample": - pred_original_sample = model_output - elif self.config.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - # predict V - model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction`" - ) - - # 4. Clip "predicted x_0" - if self.config.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) - - # 5. compute variance: "sigma_t(η)" -> see formula (16) - # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - variance = self._get_variance(timestep, prev_timestep) - std_dev_t = eta * variance ** (0.5) - - if use_clipped_model_output: - # the model_output is always re-derived from the clipped x_0 in Glide - model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - - # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output - - # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction - - if eta > 0: - device = model_output.device - if variance_noise is not None and generator is not None: - raise ValueError( - "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" - " `variance_noise` stays `None`." 
- ) - - if variance_noise is None: - variance_noise = randn_tensor( - model_output.shape, generator=generator, device=device, dtype=model_output.dtype - ) - variance = std_dev_t * variance_noise - - prev_sample = prev_sample + variance - - if not return_dict: - return (prev_sample,) - - return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def get_velocity( - self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor - ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as sample - self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) - timesteps = timesteps.to(sample.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(sample.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample - return velocity - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_ddim_flax.py b/my_diffusers/schedulers/scheduling_ddim_flax.py deleted file mode 100644 index 14a374d48b488dfc838b8c61939f431c842dbcac..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_ddim_flax.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
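As a quick sanity check on the `add_noise` method shown above, the snippet below reproduces its output by hand from `alphas_cumprod`, i.e. `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`. It relies on the upstream `diffusers.DDIMScheduler`, which implements the same formula, purely for illustration.

```py
# Verify the forward-noising formula behind DDIMScheduler.add_noise numerically.
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02)

x0 = torch.randn(2, 3, 8, 8)           # clean samples
noise = torch.randn_like(x0)           # epsilon
timesteps = torch.tensor([10, 500])    # one timestep per sample

noisy = scheduler.add_noise(x0, noise, timesteps)

# same computation by hand, broadcast over the image dimensions
alpha_bar = scheduler.alphas_cumprod[timesteps].view(-1, 1, 1, 1)
manual = alpha_bar.sqrt() * x0 + (1.0 - alpha_bar).sqrt() * noise

print(torch.allclose(noisy, manual, atol=1e-6))  # True
```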
- -# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion -# and https://github.com/hojonathanho/diffusion - -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import flax -import jax.numpy as jnp - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import ( - CommonSchedulerState, - FlaxKarrasDiffusionSchedulers, - FlaxSchedulerMixin, - FlaxSchedulerOutput, - add_noise_common, - get_velocity_common, -) - - -@flax.struct.dataclass -class DDIMSchedulerState: - common: CommonSchedulerState - final_alpha_cumprod: jnp.ndarray - - # setable values - init_noise_sigma: jnp.ndarray - timesteps: jnp.ndarray - num_inference_steps: Optional[int] = None - - @classmethod - def create( - cls, - common: CommonSchedulerState, - final_alpha_cumprod: jnp.ndarray, - init_noise_sigma: jnp.ndarray, - timesteps: jnp.ndarray, - ): - return cls( - common=common, - final_alpha_cumprod=final_alpha_cumprod, - init_noise_sigma=init_noise_sigma, - timesteps=timesteps, - ) - - -@dataclass -class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput): - state: DDIMSchedulerState - - -class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): - """ - Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising - diffusion probabilistic models (DDPMs) with non-Markovian guidance. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - For more details, see the original paper: https://arxiv.org/abs/2010.02502 - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`jnp.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - clip_sample (`bool`, default `True`): - option to clip predicted sample between -1 and 1 for numerical stability. - set_alpha_to_one (`bool`, default `True`): - each diffusion step uses the value of alphas product at that step and at the previous one. For the final - step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, - otherwise it uses the value of alpha at step 0. - steps_offset (`int`, default `0`): - an offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in - stable diffusion. - prediction_type (`str`, default `epsilon`): - indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. - `v-prediction` is not supported for this scheduler. - dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): - the `dtype` used for params and computation. 
- """ - - _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] - - dtype: jnp.dtype - - @property - def has_state(self): - return True - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[jnp.ndarray] = None, - set_alpha_to_one: bool = True, - steps_offset: int = 0, - prediction_type: str = "epsilon", - dtype: jnp.dtype = jnp.float32, - ): - self.dtype = dtype - - def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState: - if common is None: - common = CommonSchedulerState.create(self) - - # At every step in ddim, we are looking into the previous alphas_cumprod - # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this parameter simply to one or - # whether we use the final alpha of the "non-previous" one. - final_alpha_cumprod = ( - jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0] - ) - - # standard deviation of the initial noise distribution - init_noise_sigma = jnp.array(1.0, dtype=self.dtype) - - timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] - - return DDIMSchedulerState.create( - common=common, - final_alpha_cumprod=final_alpha_cumprod, - init_noise_sigma=init_noise_sigma, - timesteps=timesteps, - ) - - def scale_model_input( - self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None - ) -> jnp.ndarray: - """ - Args: - state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. - sample (`jnp.ndarray`): input sample - timestep (`int`, optional): current timestep - - Returns: - `jnp.ndarray`: scaled input sample - """ - return sample - - def set_timesteps( - self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = () - ) -> DDIMSchedulerState: - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - state (`DDIMSchedulerState`): - the `FlaxDDIMScheduler` state data class instance. - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - """ - step_ratio = self.config.num_train_timesteps // num_inference_steps - # creates integer timesteps by multiplying by ratio - # rounding to avoid issues when num_inference_step is power of 3 - timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset - - return state.replace( - num_inference_steps=num_inference_steps, - timesteps=timesteps, - ) - - def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep): - alpha_prod_t = state.common.alphas_cumprod[timestep] - alpha_prod_t_prev = jnp.where( - prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod - ) - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) - - return variance - - def step( - self, - state: DDIMSchedulerState, - model_output: jnp.ndarray, - timestep: int, - sample: jnp.ndarray, - eta: float = 0.0, - return_dict: bool = True, - ) -> Union[FlaxDDIMSchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. 
Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance. - model_output (`jnp.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class - - Returns: - [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is the sample tensor. - - """ - if state.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf - # Ideally, read DDIM paper in-detail understanding - - # Notation ( -> - # - pred_noise_t -> e_theta(x_t, t) - # - pred_original_sample -> f_theta(x_t, t) or x_0 - # - std_dev_t -> sigma_t - # - eta -> η - # - pred_sample_direction -> "direction pointing to x_t" - # - pred_prev_sample -> "x_t-1" - - # 1. get previous step value (=t-1) - prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps - - alphas_cumprod = state.common.alphas_cumprod - final_alpha_cumprod = state.final_alpha_cumprod - - # 2. compute alphas, betas - alpha_prod_t = alphas_cumprod[timestep] - alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod) - - beta_prod_t = 1 - alpha_prod_t - - # 3. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - if self.config.prediction_type == "epsilon": - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif self.config.prediction_type == "sample": - pred_original_sample = model_output - elif self.config.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - # predict V - model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction`" - ) - - # 4. compute variance: "sigma_t(η)" -> see formula (16) - # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - variance = self._get_variance(state, timestep, prev_timestep) - std_dev_t = eta * variance ** (0.5) - - # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output - - # 6. 
compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction - - if not return_dict: - return (prev_sample, state) - - return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state) - - def add_noise( - self, - state: DDIMSchedulerState, - original_samples: jnp.ndarray, - noise: jnp.ndarray, - timesteps: jnp.ndarray, - ) -> jnp.ndarray: - return add_noise_common(state.common, original_samples, noise, timesteps) - - def get_velocity( - self, - state: DDIMSchedulerState, - sample: jnp.ndarray, - noise: jnp.ndarray, - timesteps: jnp.ndarray, - ) -> jnp.ndarray: - return get_velocity_common(state.common, sample, noise, timesteps) - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_ddim_inverse.py b/my_diffusers/schedulers/scheduling_ddim_inverse.py deleted file mode 100644 index 7006bd133932ab073acf41c5e946eb0023a65f7c..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_ddim_inverse.py +++ /dev/null @@ -1,227 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion -# and https://github.com/hojonathanho/diffusion -import math -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import SchedulerMixin -from diffusers.utils import BaseOutput - - -@dataclass -# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM -class DDIMSchedulerOutput(BaseOutput): - """ - Output class for the scheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample (x_{0}) based on the model output from the current timestep. - `pred_original_sample` can be used to preview progress or for guidance. - """ - - prev_sample: torch.FloatTensor - pred_original_sample: Optional[torch.FloatTensor] = None - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. 
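Because the Flax scheduler is stateless, callers thread a DDIMSchedulerState value through create_state, set_timesteps, and step. A rough usage sketch built only from the methods shown above; the fake_model function, PRNG seed, and latent shape are placeholders, not anything defined in this diff:

import jax
import jax.numpy as jnp

scheduler = FlaxDDIMScheduler(num_train_timesteps=1000, prediction_type="epsilon")
state = scheduler.create_state()
state = scheduler.set_timesteps(state, num_inference_steps=50)

def fake_model(x, t):
    # stand-in for a trained epsilon-predicting Flax UNet
    return jnp.zeros_like(x)

sample = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 64, 64)) * state.init_noise_sigma
for t in state.timesteps:
    model_output = fake_model(sample, t)
    sample, state = scheduler.step(state, model_output, t, sample, return_dict=False)

Note that this Flax step never draws random noise (it stops after the deterministic update), so no PRNG key needs to be threaded through the loop.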
- - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): - """ - DDIMInverseScheduler is the reverse scheduler of [`DDIMScheduler`]. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - For more details, see the original paper: https://arxiv.org/abs/2010.02502 - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - clip_sample (`bool`, default `True`): - option to clip predicted sample between -1 and 1 for numerical stability. - set_alpha_to_one (`bool`, default `True`): - each diffusion step uses the value of alphas product at that step and at the previous one. For the final - step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, - otherwise it uses the value of alpha at step 0. - steps_offset (`int`, default `0`): - an offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in - stable diffusion. - prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) - """ - - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - clip_sample: bool = True, - set_alpha_to_one: bool = True, - steps_offset: int = 0, - prediction_type: str = "epsilon", - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
- self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - # At every step in ddim, we are looking into the previous alphas_cumprod - # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this parameter simply to one or - # whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # setable values - self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64)) - - def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.FloatTensor`): input sample - timestep (`int`, optional): current timestep - - Returns: - `torch.FloatTensor`: scaled input sample - """ - return sample - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - """ - - if num_inference_steps > self.config.num_train_timesteps: - raise ValueError( - f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" - f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" - f" maximal {self.config.num_train_timesteps} timesteps." 
- ) - - self.num_inference_steps = num_inference_steps - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64) - self.timesteps = torch.from_numpy(timesteps).to(device) - self.timesteps += self.config.steps_offset - - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - eta: float = 0.0, - use_clipped_model_output: bool = False, - variance_noise: Optional[torch.FloatTensor] = None, - return_dict: bool = True, - ) -> Union[DDIMSchedulerOutput, Tuple]: - e_t = model_output - - x = sample - prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps - - a_t = self.alphas_cumprod[timestep - 1] - a_prev = self.alphas_cumprod[prev_timestep - 1] if prev_timestep >= 0 else self.final_alpha_cumprod - - pred_x0 = (x - (1 - a_t) ** 0.5 * e_t) / a_t.sqrt() - - dir_xt = (1.0 - a_prev).sqrt() * e_t - - prev_sample = a_prev.sqrt() * pred_x0 + dir_xt - - if not return_dict: - return (prev_sample, pred_x0) - return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0) - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_ddpm.py b/my_diffusers/schedulers/scheduling_ddpm.py deleted file mode 100644 index 6fbeac6118fba4bb5ab06fe6d6521e9fc05d7088..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_ddpm.py +++ /dev/null @@ -1,367 +0,0 @@ -# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim - -import math -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin - - -@dataclass -class DDPMSchedulerOutput(BaseOutput): - """ - Output class for the scheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample (x_{0}) based on the model output from the current timestep. - `pred_original_sample` can be used to preview progress or for guidance. 
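The step above runs the deterministic DDIM update forward in time (from timestep to timestep + step_ratio), which is what makes it an inversion: it recovers a predicted x_0 from the current sample and then re-noises it to the next, higher noise level. Written as a standalone helper with illustrative names (not part of the scheduler's API), assuming an epsilon-predicting model:

import torch

def ddim_inversion_update(x_t: torch.Tensor, eps: torch.Tensor,
                          alpha_bar_t: torch.Tensor, alpha_bar_next: torch.Tensor) -> torch.Tensor:
    # recover pred_x0 from the current sample, then take the deterministic (eta = 0)
    # DDIM move to the next noise level
    pred_x0 = (x_t - (1 - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()
    direction = (1 - alpha_bar_next).sqrt() * eps
    return alpha_bar_next.sqrt() * pred_x0 + direction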
- """ - - prev_sample: torch.FloatTensor - pred_original_sample: Optional[torch.FloatTensor] = None - - -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class DDPMScheduler(SchedulerMixin, ConfigMixin): - """ - Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and - Langevin dynamics sampling. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - For more details, see the original paper: https://arxiv.org/abs/2006.11239 - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - variance_type (`str`): - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, - `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. - clip_sample (`bool`, default `True`): - option to clip predicted sample between -1 and 1 for numerical stability. 
- prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - variance_type: str = "fixed_small", - clip_sample: bool = True, - prediction_type: str = "epsilon", - clip_sample_range: Optional[float] = 1.0, - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - elif beta_schedule == "sigmoid": - # GeoDiff sigmoid schedule - betas = torch.linspace(-6, 6, num_train_timesteps) - self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.one = torch.tensor(1.0) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # setable values - self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) - - self.variance_type = variance_type - - def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.FloatTensor`): input sample - timestep (`int`, optional): current timestep - - Returns: - `torch.FloatTensor`: scaled input sample - """ - return sample - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - """ - - if num_inference_steps > self.config.num_train_timesteps: - raise ValueError( - f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" - f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" - f" maximal {self.config.num_train_timesteps} timesteps." 
- ) - - self.num_inference_steps = num_inference_steps - - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) - self.timesteps = torch.from_numpy(timesteps).to(device) - - def _get_variance(self, t, predicted_variance=None, variance_type=None): - num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps - prev_t = t - self.config.num_train_timesteps // num_inference_steps - alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one - current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev - - # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) - # and sample from it to get previous sample - # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample - variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t - - if variance_type is None: - variance_type = self.config.variance_type - - # hacks - were probably added for training stability - if variance_type == "fixed_small": - variance = torch.clamp(variance, min=1e-20) - # for rl-diffuser https://arxiv.org/abs/2205.09991 - elif variance_type == "fixed_small_log": - variance = torch.log(torch.clamp(variance, min=1e-20)) - variance = torch.exp(0.5 * variance) - elif variance_type == "fixed_large": - variance = current_beta_t - elif variance_type == "fixed_large_log": - # Glide max_log - variance = torch.log(current_beta_t) - elif variance_type == "learned": - return predicted_variance - elif variance_type == "learned_range": - min_log = torch.log(variance) - max_log = torch.log(self.betas[t]) - frac = (predicted_variance + 1) / 2 - variance = frac * max_log + (1 - frac) * min_log - - return variance - - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - generator=None, - return_dict: bool = True, - ) -> Union[DDPMSchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - generator: random number generator. - return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class - - Returns: - [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - - """ - t = timestep - num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps - prev_t = timestep - self.config.num_train_timesteps // num_inference_steps - - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: - model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) - else: - predicted_variance = None - - # 1. 
compute alphas, betas - alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - current_alpha_t = alpha_prod_t / alpha_prod_t_prev - current_beta_t = 1 - current_alpha_t - - # 2. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if self.config.prediction_type == "epsilon": - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif self.config.prediction_type == "sample": - pred_original_sample = model_output - elif self.config.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" - " `v_prediction` for the DDPMScheduler." - ) - - # 3. Clip "predicted x_0" - if self.config.clip_sample: - pred_original_sample = torch.clamp( - pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range - ) - - # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t - current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t - - # 5. Compute predicted previous sample µ_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample - - # 6. Add noise - variance = 0 - if t > 0: - device = model_output.device - variance_noise = randn_tensor( - model_output.shape, generator=generator, device=device, dtype=model_output.dtype - ) - if self.variance_type == "fixed_small_log": - variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise - elif self.variance_type == "learned_range": - variance = self._get_variance(t, predicted_variance=predicted_variance) - variance = torch.exp(0.5 * variance) * variance_noise - else: - variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise - - pred_prev_sample = pred_prev_sample + variance - - if not return_dict: - return (pred_prev_sample,) - - return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 
- return noisy_samples - - def get_velocity( - self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor - ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as sample - self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) - timesteps = timesteps.to(sample.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(sample.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample - return velocity - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_ddpm_flax.py b/my_diffusers/schedulers/scheduling_ddpm_flax.py deleted file mode 100644 index 529d2bd03a75403e298ec7a30808689a48cf5301..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_ddpm_flax.py +++ /dev/null @@ -1,299 +0,0 @@ -# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim - -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import flax -import jax -import jax.numpy as jnp - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import ( - CommonSchedulerState, - FlaxKarrasDiffusionSchedulers, - FlaxSchedulerMixin, - FlaxSchedulerOutput, - add_noise_common, - get_velocity_common, -) - - -@flax.struct.dataclass -class DDPMSchedulerState: - common: CommonSchedulerState - - # setable values - init_noise_sigma: jnp.ndarray - timesteps: jnp.ndarray - num_inference_steps: Optional[int] = None - - @classmethod - def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray): - return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps) - - -@dataclass -class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput): - state: DDPMSchedulerState - - -class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): - """ - Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and - Langevin dynamics sampling. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. 
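For orientation, a denoising loop driving the DDPMScheduler defined above might look like the following sketch; the fake_unet callable and the latent shape are placeholders rather than anything this diff defines:

import torch

scheduler = DDPMScheduler(num_train_timesteps=1000, variance_type="fixed_small")
scheduler.set_timesteps(num_inference_steps=50)

def fake_unet(x, t):
    # stand-in for a trained epsilon-predicting model
    return torch.zeros_like(x)

sample = torch.randn(1, 3, 64, 64)  # start from pure Gaussian noise
for t in scheduler.timesteps:
    model_output = fake_unet(sample, t)
    sample = scheduler.step(model_output, t, sample).prev_sample

Each step computes the posterior mean from formula (7) and, for t > 0, adds noise scaled according to the configured variance_type.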
- - For more details, see the original paper: https://arxiv.org/abs/2006.11239 - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - variance_type (`str`): - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, - `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. - clip_sample (`bool`, default `True`): - option to clip predicted sample between -1 and 1 for numerical stability. - prediction_type (`str`, default `epsilon`): - indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. - `v-prediction` is not supported for this scheduler. - dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): - the `dtype` used for params and computation. - """ - - _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] - - dtype: jnp.dtype - - @property - def has_state(self): - return True - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[jnp.ndarray] = None, - variance_type: str = "fixed_small", - clip_sample: bool = True, - prediction_type: str = "epsilon", - dtype: jnp.dtype = jnp.float32, - ): - self.dtype = dtype - - def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState: - if common is None: - common = CommonSchedulerState.create(self) - - # standard deviation of the initial noise distribution - init_noise_sigma = jnp.array(1.0, dtype=self.dtype) - - timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] - - return DDPMSchedulerState.create( - common=common, - init_noise_sigma=init_noise_sigma, - timesteps=timesteps, - ) - - def scale_model_input( - self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None - ) -> jnp.ndarray: - """ - Args: - state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. - sample (`jnp.ndarray`): input sample - timestep (`int`, optional): current timestep - - Returns: - `jnp.ndarray`: scaled input sample - """ - return sample - - def set_timesteps( - self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = () - ) -> DDPMSchedulerState: - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - state (`DDIMSchedulerState`): - the `FlaxDDPMScheduler` state data class instance. - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. 
- """ - - step_ratio = self.config.num_train_timesteps // num_inference_steps - # creates integer timesteps by multiplying by ratio - # rounding to avoid issues when num_inference_step is power of 3 - timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] - - return state.replace( - num_inference_steps=num_inference_steps, - timesteps=timesteps, - ) - - def _get_variance(self, state: DDPMSchedulerState, t, predicted_variance=None, variance_type=None): - alpha_prod_t = state.common.alphas_cumprod[t] - alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype)) - - # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) - # and sample from it to get previous sample - # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample - variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * state.common.betas[t] - - if variance_type is None: - variance_type = self.config.variance_type - - # hacks - were probably added for training stability - if variance_type == "fixed_small": - variance = jnp.clip(variance, a_min=1e-20) - # for rl-diffuser https://arxiv.org/abs/2205.09991 - elif variance_type == "fixed_small_log": - variance = jnp.log(jnp.clip(variance, a_min=1e-20)) - elif variance_type == "fixed_large": - variance = state.common.betas[t] - elif variance_type == "fixed_large_log": - # Glide max_log - variance = jnp.log(state.common.betas[t]) - elif variance_type == "learned": - return predicted_variance - elif variance_type == "learned_range": - min_log = variance - max_log = state.common.betas[t] - frac = (predicted_variance + 1) / 2 - variance = frac * max_log + (1 - frac) * min_log - - return variance - - def step( - self, - state: DDPMSchedulerState, - model_output: jnp.ndarray, - timestep: int, - sample: jnp.ndarray, - key: Optional[jax.random.KeyArray] = None, - return_dict: bool = True, - ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance. - model_output (`jnp.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - key (`jax.random.KeyArray`): a PRNG key. - return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class - - Returns: - [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is the sample tensor. - - """ - t = timestep - - if key is None: - key = jax.random.PRNGKey(0) - - if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]: - model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1) - else: - predicted_variance = None - - # 1. compute alphas, betas - alpha_prod_t = state.common.alphas_cumprod[t] - alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype)) - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - # 2. 
compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if self.config.prediction_type == "epsilon": - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif self.config.prediction_type == "sample": - pred_original_sample = model_output - elif self.config.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` " - " for the FlaxDDPMScheduler." - ) - - # 3. Clip "predicted x_0" - if self.config.clip_sample: - pred_original_sample = jnp.clip(pred_original_sample, -1, 1) - - # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * state.common.betas[t]) / beta_prod_t - current_sample_coeff = state.common.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t - - # 5. Compute predicted previous sample µ_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample - - # 6. Add noise - def random_variance(): - split_key = jax.random.split(key, num=1) - noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype) - return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise - - variance = jnp.where(t > 0, random_variance(), jnp.zeros(model_output.shape, dtype=self.dtype)) - - pred_prev_sample = pred_prev_sample + variance - - if not return_dict: - return (pred_prev_sample, state) - - return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state) - - def add_noise( - self, - state: DDPMSchedulerState, - original_samples: jnp.ndarray, - noise: jnp.ndarray, - timesteps: jnp.ndarray, - ) -> jnp.ndarray: - return add_noise_common(state.common, original_samples, noise, timesteps) - - def get_velocity( - self, - state: DDPMSchedulerState, - sample: jnp.ndarray, - noise: jnp.ndarray, - timesteps: jnp.ndarray, - ) -> jnp.ndarray: - return get_velocity_common(state.common, sample, noise, timesteps) - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_deis_multistep.py b/my_diffusers/schedulers/scheduling_deis_multistep.py deleted file mode 100644 index f58284005deda5e40098d415d06623fdd76860bf..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_deis_multistep.py +++ /dev/null @@ -1,481 +0,0 @@ -# Copyright 2023 FLAIR Lab and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
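Unlike the PyTorch version, the Flax step takes an explicit jax.random key for the ancestral noise, so callers typically split a fresh key per iteration. A rough sketch using only the methods shown above; fake_model and the latent shape are placeholders:

import jax
import jax.numpy as jnp

scheduler = FlaxDDPMScheduler(num_train_timesteps=1000)
state = scheduler.create_state()
state = scheduler.set_timesteps(state, num_inference_steps=50)

def fake_model(x, t):
    # stand-in for a trained epsilon-predicting model
    return jnp.zeros_like(x)

key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)
sample = jax.random.normal(init_key, (1, 3, 64, 64)) * state.init_noise_sigma

for t in state.timesteps:
    key, step_key = jax.random.split(key)
    model_output = fake_model(sample, t)
    sample, state = scheduler.step(state, model_output, t, sample, key=step_key, return_dict=False)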
- -# DISCLAIMER: check https://arxiv.org/abs/2204.13902 and https://github.com/qsh-zh/deis for more info -# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): - """ - DEIS (https://arxiv.org/abs/2204.13902) is a fast high order solver for diffusion ODEs. We slightly modify the - polynomial fitting formula in log-rho space instead of the original linear t space in DEIS paper. The modification - enjoys closed-form coefficients for exponential multistep update instead of replying on the numerical solver. More - variants of DEIS can be found in https://github.com/qsh-zh/deis. - - Currently, we support the log-rho multistep DEIS. We recommend to use `solver_order=2 / 3` while `solver_order=1` - reduces to DDIM. - - We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space - diffusion models, you can set `thresholding=True` to use the dynamic thresholding. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - solver_order (`int`, default `2`): - the order of DEIS; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and - `solver_order=3` for unconditional sampling. 
- prediction_type (`str`, default `epsilon`): - indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, - or `v-prediction`. - thresholding (`bool`, default `False`): - whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). - Note that the thresholding method is unsuitable for latent-space diffusion models (such as - stable-diffusion). - dynamic_thresholding_ratio (`float`, default `0.995`): - the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen - (https://arxiv.org/abs/2205.11487). - sample_max_value (`float`, default `1.0`): - the threshold value for dynamic thresholding. Valid woks when `thresholding=True` - algorithm_type (`str`, default `deis`): - the algorithm type for the solver. current we support multistep deis, we will add other variants of DEIS in - the future - lower_order_final (`bool`, default `True`): - whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically - find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10. - - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, - solver_order: int = 2, - prediction_type: str = "epsilon", - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - algorithm_type: str = "deis", - solver_type: str = "logrho", - lower_order_final: bool = True, - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
- self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Currently we only support VP-type noise schedule - self.alpha_t = torch.sqrt(self.alphas_cumprod) - self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) - self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # settings for DEIS - if algorithm_type not in ["deis"]: - if algorithm_type in ["dpmsolver", "dpmsolver++"]: - algorithm_type = "deis" - else: - raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") - - if solver_type not in ["logrho"]: - if solver_type in ["midpoint", "heun", "bh1", "bh2"]: - solver_type = "logrho" - else: - raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}") - - # setable values - self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) - self.model_outputs = [None] * solver_order - self.lower_order_nums = 0 - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, optional): - the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - self.num_inference_steps = num_inference_steps - timesteps = ( - np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) - self.timesteps = torch.from_numpy(timesteps).to(device) - self.model_outputs = [ - None, - ] * self.config.solver_order - self.lower_order_nums = 0 - - def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor - ) -> torch.FloatTensor: - """ - Convert the model output to the corresponding type that the algorithm DEIS needs. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the converted model output. - """ - if self.config.prediction_type == "epsilon": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - x0_pred = (sample - sigma_t * model_output) / alpha_t - elif self.config.prediction_type == "sample": - x0_pred = model_output - elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - x0_pred = alpha_t * sample - sigma_t * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the DEISMultistepScheduler." 
- ) - - if self.config.thresholding: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - orig_dtype = x0_pred.dtype - if orig_dtype not in [torch.float, torch.double]: - x0_pred = x0_pred.float() - dynamic_max_val = torch.quantile( - torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1 - ) - dynamic_max_val = torch.maximum( - dynamic_max_val, - self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device), - )[(...,) + (None,) * (x0_pred.ndim - 1)] - x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val - x0_pred = x0_pred.type(orig_dtype) - - if self.config.algorithm_type == "deis": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - return (sample - alpha_t * x0_pred) / sigma_t - else: - raise NotImplementedError("only support log-rho multistep deis now") - - def deis_first_order_update( - self, - model_output: torch.FloatTensor, - timestep: int, - prev_timestep: int, - sample: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - One step for the first-order DEIS (equivalent to DDIM). - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the sample tensor at the previous timestep. - """ - lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] - alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] - sigma_t, _ = self.sigma_t[prev_timestep], self.sigma_t[timestep] - h = lambda_t - lambda_s - if self.config.algorithm_type == "deis": - x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output - else: - raise NotImplementedError("only support log-rho multistep deis now") - return x_t - - def multistep_deis_second_order_update( - self, - model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - One step for the second-order multistep DEIS. - - Args: - model_output_list (`List[torch.FloatTensor]`): - direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the sample tensor at the previous timestep. 
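The thresholding branch in convert_model_output above is Imagen-style dynamic thresholding: each predicted x_0 is clamped to its own high quantile of absolute values and rescaled back into range. Isolated as a small helper (illustrative names, not the scheduler's API; assumes float32 input, whereas the scheduler also upcasts half precision):

import torch

def dynamic_threshold(x0: torch.Tensor, ratio: float = 0.995, max_value: float = 1.0) -> torch.Tensor:
    # per-sample quantile of |x0|, floored at the static threshold, then clamp and rescale
    flat = x0.reshape(x0.shape[0], -1).abs()
    s = torch.quantile(flat, ratio, dim=1)
    s = torch.clamp(s, min=max_value)
    s = s.view(-1, *([1] * (x0.dim() - 1)))
    return torch.clamp(x0, -s, s) / s

As the class docstring notes, this only makes sense for pixel-space models whose samples live in [-1, 1]; latent-space models should keep thresholding=False.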
- """ - t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] - m0, m1 = model_output_list[-1], model_output_list[-2] - alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1] - sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1] - - rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1 - - if self.config.algorithm_type == "deis": - - def ind_fn(t, b, c): - # Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}] - return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c)) - - coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1) - coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0) - - x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1) - return x_t - else: - raise NotImplementedError("only support log-rho multistep deis now") - - def multistep_deis_third_order_update( - self, - model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - One step for the third-order multistep DEIS. - - Args: - model_output_list (`List[torch.FloatTensor]`): - direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the sample tensor at the previous timestep. - """ - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2] - sigma_t, sigma_s0, sigma_s1, simga_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2] - rho_t, rho_s0, rho_s1, rho_s2 = ( - sigma_t / alpha_t, - sigma_s0 / alpha_s0, - sigma_s1 / alpha_s1, - simga_s2 / alpha_s2, - ) - - if self.config.algorithm_type == "deis": - - def ind_fn(t, b, c, d): - # Integrate[(log(t) - log(c))(log(t) - log(d)) / (log(b) - log(c))(log(b) - log(d)), {t}] - numerator = t * ( - np.log(c) * (np.log(d) - np.log(t) + 1) - - np.log(d) * np.log(t) - + np.log(d) - + np.log(t) ** 2 - - 2 * np.log(t) - + 2 - ) - denominator = (np.log(b) - np.log(c)) * (np.log(b) - np.log(d)) - return numerator / denominator - - coef1 = ind_fn(rho_t, rho_s0, rho_s1, rho_s2) - ind_fn(rho_s0, rho_s0, rho_s1, rho_s2) - coef2 = ind_fn(rho_t, rho_s1, rho_s2, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s2, rho_s0) - coef3 = ind_fn(rho_t, rho_s2, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s2, rho_s0, rho_s1) - - x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1 + coef3 * m2) - - return x_t - else: - raise NotImplementedError("only support log-rho multistep deis now") - - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Step function propagating the sample with the multistep DEIS. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. 
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: - step_index = len(self.timesteps) - 1 - else: - step_index = step_index.item() - prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] - lower_order_final = ( - (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 - ) - lower_order_second = ( - (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 - ) - - model_output = self.convert_model_output(model_output, timestep, sample) - for i in range(self.config.solver_order - 1): - self.model_outputs[i] = self.model_outputs[i + 1] - self.model_outputs[-1] = model_output - - if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample) - elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - timestep_list = [self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_deis_second_order_update( - self.model_outputs, timestep_list, prev_timestep, sample - ) - else: - timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_deis_third_order_update( - self.model_outputs, timestep_list, prev_timestep, sample - ) - - if self.lower_order_nums < self.config.solver_order: - self.lower_order_nums += 1 - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. 
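[Editor's note] Because `step` consumes one model output per timestep and keeps its short history internally, sampling is a single linear pass over `scheduler.timesteps`. A usage sketch, assuming `scheduler` is an instance of the multistep scheduler above and `model` is a hypothetical noise-prediction network with the call signature shown:

```python
import torch

scheduler.set_timesteps(num_inference_steps=20, device="cpu")
sample = torch.randn(1, 3, 32, 32) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    with torch.no_grad():
        noise_pred = model(sample, t)  # hypothetical model call
    sample = scheduler.step(noise_pred, t, sample).prev_sample
```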
- - Args: - sample (`torch.FloatTensor`): input sample - - Returns: - `torch.FloatTensor`: scaled input sample - """ - return sample - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_dpmsolver_multistep.py b/my_diffusers/schedulers/scheduling_dpmsolver_multistep.py deleted file mode 100644 index db6eb530a1d78a2353fa494d3c8a7aa5345f9d90..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ /dev/null @@ -1,529 +0,0 @@ -# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. 
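[Editor's note] The `add_noise` method above is the standard VP forward process q(x_t | x_0). A compact sketch of the same computation with explicit broadcasting (function and argument names illustrative):

```python
import torch

def add_noise(x0, noise, alphas_cumprod, timesteps):
    # Broadcast the per-sample coefficients over the trailing channel/spatial dims.
    shape = (-1,) + (1,) * (x0.ndim - 1)
    a = alphas_cumprod[timesteps].sqrt().reshape(shape)
    s = (1.0 - alphas_cumprod[timesteps]).sqrt().reshape(shape)
    return a * x0 + s * noise
```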
- - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): - """ - DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with - the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality - samples, and it can generate quite good samples even in only 10 steps. - - For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 - - Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We - recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. - - We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space - diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic - thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as - stable-diffusion). - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - solver_order (`int`, default `2`): - the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided - sampling, and `solver_order=3` for unconditional sampling. - prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) - thresholding (`bool`, default `False`): - whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). - For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to - use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion - models (such as stable-diffusion). - dynamic_thresholding_ratio (`float`, default `0.995`): - the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen - (https://arxiv.org/abs/2205.11487). 
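[Editor's note] `betas_for_alpha_bar` turns a cumulative alpha-bar curve into per-step betas via the ratio of consecutive values. A quick standalone check of that construction, assuming the same cosine curve as above (variable names illustrative):

```python
import math

def alpha_bar(t):
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

num_steps, max_beta = 8, 0.999
betas = [
    min(1 - alpha_bar((i + 1) / num_steps) / alpha_bar(i / num_steps), max_beta)
    for i in range(num_steps)
]
# Each beta is 1 - alpha_bar(t2)/alpha_bar(t1), so the cumulative product of
# (1 - beta) recovers the discretized alpha_bar curve (up to the max_beta clip).
```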
- sample_max_value (`float`, default `1.0`): - the threshold value for dynamic thresholding. Valid only when `thresholding=True` and - `algorithm_type="dpmsolver++`. - algorithm_type (`str`, default `dpmsolver++`): - the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the - algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in - https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided - sampling (e.g. stable-diffusion). - solver_type (`str`, default `midpoint`): - the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects - the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are - slightly better, so we recommend to use the `midpoint` type. - lower_order_final (`bool`, default `True`): - whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically - find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. - - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - solver_order: int = 2, - prediction_type: str = "epsilon", - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - algorithm_type: str = "dpmsolver++", - solver_type: str = "midpoint", - lower_order_final: bool = True, - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
- self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Currently we only support VP-type noise schedule - self.alpha_t = torch.sqrt(self.alphas_cumprod) - self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) - self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "dpmsolver++"]: - if algorithm_type == "deis": - algorithm_type = "dpmsolver++" - else: - raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") - if solver_type not in ["midpoint", "heun"]: - if solver_type in ["logrho", "bh1", "bh2"]: - solver_type = "midpoint" - else: - raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") - - # setable values - self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) - self.model_outputs = [None] * solver_order - self.lower_order_nums = 0 - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, optional): - the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - self.num_inference_steps = num_inference_steps - timesteps = ( - np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) - self.timesteps = torch.from_numpy(timesteps).to(device) - self.model_outputs = [ - None, - ] * self.config.solver_order - self.lower_order_nums = 0 - - def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor - ) -> torch.FloatTensor: - """ - Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. - - DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to - discretize an integral of the data prediction model. So we need to first convert the model output to the - corresponding type to match the algorithm. - - Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or - DPM-Solver++ for both noise prediction model and data prediction model. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the converted model output. - """ - # DPM-Solver++ needs to solve an integral of the data prediction model. 
- if self.config.algorithm_type == "dpmsolver++": - if self.config.prediction_type == "epsilon": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - x0_pred = (sample - sigma_t * model_output) / alpha_t - elif self.config.prediction_type == "sample": - x0_pred = model_output - elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - x0_pred = alpha_t * sample - sigma_t * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the DPMSolverMultistepScheduler." - ) - - if self.config.thresholding: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - orig_dtype = x0_pred.dtype - if orig_dtype not in [torch.float, torch.double]: - x0_pred = x0_pred.float() - dynamic_max_val = torch.quantile( - torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1 - ) - dynamic_max_val = torch.maximum( - dynamic_max_val, - self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device), - )[(...,) + (None,) * (x0_pred.ndim - 1)] - x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val - x0_pred = x0_pred.type(orig_dtype) - return x0_pred - # DPM-Solver needs to solve an integral of the noise prediction model. - elif self.config.algorithm_type == "dpmsolver": - if self.config.prediction_type == "epsilon": - return model_output - elif self.config.prediction_type == "sample": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - epsilon = (sample - alpha_t * model_output) / sigma_t - return epsilon - elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - epsilon = alpha_t * model_output + sigma_t * sample - return epsilon - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the DPMSolverMultistepScheduler." - ) - - def dpm_solver_first_order_update( - self, - model_output: torch.FloatTensor, - timestep: int, - prev_timestep: int, - sample: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - One step for the first-order DPM-Solver (equivalent to DDIM). - - See https://arxiv.org/abs/2206.00927 for the detailed derivation. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the sample tensor at the previous timestep. 
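[Editor's note] The thresholding branch above applies Imagen-style dynamic thresholding to the predicted x0. An isolated sketch of the same operation (function name illustrative, defaults copied from the config values above):

```python
import torch

def dynamic_threshold(x0_pred, ratio=0.995, sample_max_value=1.0):
    orig_dtype = x0_pred.dtype
    x = x0_pred.float()  # torch.quantile requires float or double
    # Per-sample quantile of |x0|, floored at the static threshold, then clip and rescale.
    s = torch.quantile(x.abs().reshape(x.shape[0], -1), ratio, dim=1)
    s = torch.clamp(s, min=sample_max_value)
    s = s.reshape(-1, *([1] * (x.ndim - 1)))
    return (torch.clamp(x, -s, s) / s).to(orig_dtype)
```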
- """ - lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] - alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] - sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] - h = lambda_t - lambda_s - if self.config.algorithm_type == "dpmsolver++": - x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output - elif self.config.algorithm_type == "dpmsolver": - x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output - return x_t - - def multistep_dpm_solver_second_order_update( - self, - model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - One step for the second-order multistep DPM-Solver. - - Args: - model_output_list (`List[torch.FloatTensor]`): - direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the sample tensor at the previous timestep. - """ - t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] - m0, m1 = model_output_list[-1], model_output_list[-2] - lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] - h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 - r0 = h_0 / h - D0, D1 = m0, (1.0 / r0) * (m0 - m1) - if self.config.algorithm_type == "dpmsolver++": - # See https://arxiv.org/abs/2211.01095 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - ) - return x_t - - def multistep_dpm_solver_third_order_update( - self, - model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - One step for the third-order multistep DPM-Solver. - - Args: - model_output_list (`List[torch.FloatTensor]`): - direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the sample tensor at the previous timestep. 
- """ - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( - self.lambda_t[t], - self.lambda_t[s0], - self.lambda_t[s1], - self.lambda_t[s2], - ) - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] - h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 - r0, r1 = h_0 / h, h_1 / h - D0 = m0 - D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) - D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) - D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) - if self.config.algorithm_type == "dpmsolver++": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 - ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 - ) - return x_t - - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Step function propagating the sample with the multistep DPM-Solver. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
- - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: - step_index = len(self.timesteps) - 1 - else: - step_index = step_index.item() - prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] - lower_order_final = ( - (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 - ) - lower_order_second = ( - (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 - ) - - model_output = self.convert_model_output(model_output, timestep, sample) - for i in range(self.config.solver_order - 1): - self.model_outputs[i] = self.model_outputs[i + 1] - self.model_outputs[-1] = model_output - - if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample) - elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - timestep_list = [self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_dpm_solver_second_order_update( - self.model_outputs, timestep_list, prev_timestep, sample - ) - else: - timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_dpm_solver_third_order_update( - self.model_outputs, timestep_list, prev_timestep, sample - ) - - if self.lower_order_nums < self.config.solver_order: - self.lower_order_nums += 1 - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. 
- - Args: - sample (`torch.FloatTensor`): input sample - - Returns: - `torch.FloatTensor`: scaled input sample - """ - return sample - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/my_diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py deleted file mode 100644 index 9b4ee67a7f5dbf8384eaedc0ede322284a413edd..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ /dev/null @@ -1,622 +0,0 @@ -# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver - -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import flax -import jax -import jax.numpy as jnp - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import ( - CommonSchedulerState, - FlaxKarrasDiffusionSchedulers, - FlaxSchedulerMixin, - FlaxSchedulerOutput, - add_noise_common, -) - - -@flax.struct.dataclass -class DPMSolverMultistepSchedulerState: - common: CommonSchedulerState - alpha_t: jnp.ndarray - sigma_t: jnp.ndarray - lambda_t: jnp.ndarray - - # setable values - init_noise_sigma: jnp.ndarray - timesteps: jnp.ndarray - num_inference_steps: Optional[int] = None - - # running values - model_outputs: Optional[jnp.ndarray] = None - lower_order_nums: Optional[jnp.int32] = None - prev_timestep: Optional[jnp.int32] = None - cur_sample: Optional[jnp.ndarray] = None - - @classmethod - def create( - cls, - common: CommonSchedulerState, - alpha_t: jnp.ndarray, - sigma_t: jnp.ndarray, - lambda_t: jnp.ndarray, - init_noise_sigma: jnp.ndarray, - timesteps: jnp.ndarray, - ): - return cls( - common=common, - alpha_t=alpha_t, - sigma_t=sigma_t, - lambda_t=lambda_t, - init_noise_sigma=init_noise_sigma, - timesteps=timesteps, - ) - - -@dataclass -class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput): - state: DPMSolverMultistepSchedulerState - - -class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): - """ - DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with - the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality - samples, and it can generate quite good samples even in only 10 steps. - - For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 - - Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We - recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. - - We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space - diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic - thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as - stable-diffusion). - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. 
- trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - solver_order (`int`, default `2`): - the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided - sampling, and `solver_order=3` for unconditional sampling. - prediction_type (`str`, default `epsilon`): - indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, - or `v-prediction`. - thresholding (`bool`, default `False`): - whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). - For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to - use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion - models (such as stable-diffusion). - dynamic_thresholding_ratio (`float`, default `0.995`): - the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen - (https://arxiv.org/abs/2205.11487). - sample_max_value (`float`, default `1.0`): - the threshold value for dynamic thresholding. Valid only when `thresholding=True` and - `algorithm_type="dpmsolver++`. - algorithm_type (`str`, default `dpmsolver++`): - the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the - algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in - https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided - sampling (e.g. stable-diffusion). - solver_type (`str`, default `midpoint`): - the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects - the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are - slightly better, so we recommend to use the `midpoint` type. - lower_order_final (`bool`, default `True`): - whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically - find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. - dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): - the `dtype` used for params and computation. 
- """ - - _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] - - dtype: jnp.dtype - - @property - def has_state(self): - return True - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[jnp.ndarray] = None, - solver_order: int = 2, - prediction_type: str = "epsilon", - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - algorithm_type: str = "dpmsolver++", - solver_type: str = "midpoint", - lower_order_final: bool = True, - dtype: jnp.dtype = jnp.float32, - ): - self.dtype = dtype - - def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState: - if common is None: - common = CommonSchedulerState.create(self) - - # Currently we only support VP-type noise schedule - alpha_t = jnp.sqrt(common.alphas_cumprod) - sigma_t = jnp.sqrt(1 - common.alphas_cumprod) - lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t) - - # settings for DPM-Solver - if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]: - raise NotImplementedError(f"{self.config.algorithm_type} does is not implemented for {self.__class__}") - if self.config.solver_type not in ["midpoint", "heun"]: - raise NotImplementedError(f"{self.config.solver_type} does is not implemented for {self.__class__}") - - # standard deviation of the initial noise distribution - init_noise_sigma = jnp.array(1.0, dtype=self.dtype) - - timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] - - return DPMSolverMultistepSchedulerState.create( - common=common, - alpha_t=alpha_t, - sigma_t=sigma_t, - lambda_t=lambda_t, - init_noise_sigma=init_noise_sigma, - timesteps=timesteps, - ) - - def set_timesteps( - self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple - ) -> DPMSolverMultistepSchedulerState: - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - state (`DPMSolverMultistepSchedulerState`): - the `FlaxDPMSolverMultistepScheduler` state data class instance. - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - shape (`Tuple`): - the shape of the samples to be generated. - """ - - timesteps = ( - jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .astype(jnp.int32) - ) - - # initial running values - - model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype) - lower_order_nums = jnp.int32(0) - prev_timestep = jnp.int32(-1) - cur_sample = jnp.zeros(shape, dtype=self.dtype) - - return state.replace( - num_inference_steps=num_inference_steps, - timesteps=timesteps, - model_outputs=model_outputs, - lower_order_nums=lower_order_nums, - prev_timestep=prev_timestep, - cur_sample=cur_sample, - ) - - def convert_model_output( - self, - state: DPMSolverMultistepSchedulerState, - model_output: jnp.ndarray, - timestep: int, - sample: jnp.ndarray, - ) -> jnp.ndarray: - """ - Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. - - DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to - discretize an integral of the data prediction model. So we need to first convert the model output to the - corresponding type to match the algorithm. 
- - Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or - DPM-Solver++ for both noise prediction model and data prediction model. - - Args: - model_output (`jnp.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - - Returns: - `jnp.ndarray`: the converted model output. - """ - # DPM-Solver++ needs to solve an integral of the data prediction model. - if self.config.algorithm_type == "dpmsolver++": - if self.config.prediction_type == "epsilon": - alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] - x0_pred = (sample - sigma_t * model_output) / alpha_t - elif self.config.prediction_type == "sample": - x0_pred = model_output - elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] - x0_pred = alpha_t * sample - sigma_t * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " - " or `v_prediction` for the FlaxDPMSolverMultistepScheduler." - ) - - if self.config.thresholding: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - dynamic_max_val = jnp.percentile( - jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim)) - ) - dynamic_max_val = jnp.maximum( - dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val) - ) - x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val - return x0_pred - # DPM-Solver needs to solve an integral of the noise prediction model. - elif self.config.algorithm_type == "dpmsolver": - if self.config.prediction_type == "epsilon": - return model_output - elif self.config.prediction_type == "sample": - alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] - epsilon = (sample - alpha_t * model_output) / sigma_t - return epsilon - elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] - epsilon = alpha_t * model_output + sigma_t * sample - return epsilon - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " - " or `v_prediction` for the FlaxDPMSolverMultistepScheduler." - ) - - def dpm_solver_first_order_update( - self, - state: DPMSolverMultistepSchedulerState, - model_output: jnp.ndarray, - timestep: int, - prev_timestep: int, - sample: jnp.ndarray, - ) -> jnp.ndarray: - """ - One step for the first-order DPM-Solver (equivalent to DDIM). - - See https://arxiv.org/abs/2206.00927 for the detailed derivation. - - Args: - model_output (`jnp.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - - Returns: - `jnp.ndarray`: the sample tensor at the previous timestep. 
- """ - t, s0 = prev_timestep, timestep - m0 = model_output - lambda_t, lambda_s = state.lambda_t[t], state.lambda_t[s0] - alpha_t, alpha_s = state.alpha_t[t], state.alpha_t[s0] - sigma_t, sigma_s = state.sigma_t[t], state.sigma_t[s0] - h = lambda_t - lambda_s - if self.config.algorithm_type == "dpmsolver++": - x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0 - elif self.config.algorithm_type == "dpmsolver": - x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0 - return x_t - - def multistep_dpm_solver_second_order_update( - self, - state: DPMSolverMultistepSchedulerState, - model_output_list: jnp.ndarray, - timestep_list: List[int], - prev_timestep: int, - sample: jnp.ndarray, - ) -> jnp.ndarray: - """ - One step for the second-order multistep DPM-Solver. - - Args: - model_output_list (`List[jnp.ndarray]`): - direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - - Returns: - `jnp.ndarray`: the sample tensor at the previous timestep. - """ - t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] - m0, m1 = model_output_list[-1], model_output_list[-2] - lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1] - alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] - sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] - h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 - r0 = h_0 / h - D0, D1 = m0, (1.0 / r0) * (m0 - m1) - if self.config.algorithm_type == "dpmsolver++": - # See https://arxiv.org/abs/2211.01095 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 - - 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 - + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 - ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (jnp.exp(h) - 1.0)) * D0 - - 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (jnp.exp(h) - 1.0)) * D0 - - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1 - ) - return x_t - - def multistep_dpm_solver_third_order_update( - self, - state: DPMSolverMultistepSchedulerState, - model_output_list: jnp.ndarray, - timestep_list: List[int], - prev_timestep: int, - sample: jnp.ndarray, - ) -> jnp.ndarray: - """ - One step for the third-order multistep DPM-Solver. - - Args: - model_output_list (`List[jnp.ndarray]`): - direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - - Returns: - `jnp.ndarray`: the sample tensor at the previous timestep. 
- """ - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( - state.lambda_t[t], - state.lambda_t[s0], - state.lambda_t[s1], - state.lambda_t[s2], - ) - alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] - sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] - h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 - r0, r1 = h_0 / h, h_1 / h - D0 = m0 - D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) - D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) - D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) - if self.config.algorithm_type == "dpmsolver++": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - x_t = ( - (sigma_t / sigma_s0) * sample - - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 - + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 - - (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 - ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - x_t = ( - (alpha_t / alpha_s0) * sample - - (sigma_t * (jnp.exp(h) - 1.0)) * D0 - - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1 - - (sigma_t * ((jnp.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 - ) - return x_t - - def step( - self, - state: DPMSolverMultistepSchedulerState, - model_output: jnp.ndarray, - timestep: int, - sample: jnp.ndarray, - return_dict: bool = True, - ) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process - from the learned model outputs (most often the predicted noise). - - Args: - state (`DPMSolverMultistepSchedulerState`): - the `FlaxDPMSolverMultistepScheduler` state data class instance. - model_output (`jnp.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverMultistepSchedulerOutput class - - Returns: - [`FlaxDPMSolverMultistepSchedulerOutput`] or `tuple`: [`FlaxDPMSolverMultistepSchedulerOutput`] if - `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
- - """ - if state.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - (step_index,) = jnp.where(state.timesteps == timestep, size=1) - step_index = step_index[0] - - prev_timestep = jax.lax.select(step_index == len(state.timesteps) - 1, 0, state.timesteps[step_index + 1]) - - model_output = self.convert_model_output(state, model_output, timestep, sample) - - model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0) - model_outputs_new = model_outputs_new.at[-1].set(model_output) - state = state.replace( - model_outputs=model_outputs_new, - prev_timestep=prev_timestep, - cur_sample=sample, - ) - - def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: - return self.dpm_solver_first_order_update( - state, - state.model_outputs[-1], - state.timesteps[step_index], - state.prev_timestep, - state.cur_sample, - ) - - def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: - def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: - timestep_list = jnp.array([state.timesteps[step_index - 1], state.timesteps[step_index]]) - return self.multistep_dpm_solver_second_order_update( - state, - state.model_outputs, - timestep_list, - state.prev_timestep, - state.cur_sample, - ) - - def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: - timestep_list = jnp.array( - [ - state.timesteps[step_index - 2], - state.timesteps[step_index - 1], - state.timesteps[step_index], - ] - ) - return self.multistep_dpm_solver_third_order_update( - state, - state.model_outputs, - timestep_list, - state.prev_timestep, - state.cur_sample, - ) - - step_2_output = step_2(state) - step_3_output = step_3(state) - - if self.config.solver_order == 2: - return step_2_output - elif self.config.lower_order_final and len(state.timesteps) < 15: - return jax.lax.select( - state.lower_order_nums < 2, - step_2_output, - jax.lax.select( - step_index == len(state.timesteps) - 2, - step_2_output, - step_3_output, - ), - ) - else: - return jax.lax.select( - state.lower_order_nums < 2, - step_2_output, - step_3_output, - ) - - step_1_output = step_1(state) - step_23_output = step_23(state) - - if self.config.solver_order == 1: - prev_sample = step_1_output - - elif self.config.lower_order_final and len(state.timesteps) < 15: - prev_sample = jax.lax.select( - state.lower_order_nums < 1, - step_1_output, - jax.lax.select( - step_index == len(state.timesteps) - 1, - step_1_output, - step_23_output, - ), - ) - - else: - prev_sample = jax.lax.select( - state.lower_order_nums < 1, - step_1_output, - step_23_output, - ) - - state = state.replace( - lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order), - ) - - if not return_dict: - return (prev_sample, state) - - return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state) - - def scale_model_input( - self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None - ) -> jnp.ndarray: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - state (`DPMSolverMultistepSchedulerState`): - the `FlaxDPMSolverMultistepScheduler` state data class instance. 
- sample (`jnp.ndarray`): input sample - timestep (`int`, optional): current timestep - - Returns: - `jnp.ndarray`: scaled input sample - """ - return sample - - def add_noise( - self, - state: DPMSolverMultistepSchedulerState, - original_samples: jnp.ndarray, - noise: jnp.ndarray, - timesteps: jnp.ndarray, - ) -> jnp.ndarray: - return add_noise_common(state.common, original_samples, noise, timesteps) - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/my_diffusers/schedulers/scheduling_dpmsolver_singlestep.py deleted file mode 100644 index 5b3951c7a7437301037763cace5b01ec59d4522a..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ /dev/null @@ -1,605 +0,0 @@ -# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): - """ - DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with - the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality - samples, and it can generate quite good samples even in only 10 steps. - - For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 - - Currently, we support the singlestep DPM-Solver for both noise prediction models and data prediction models. 
We - recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. - - We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space - diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic - thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as - stable-diffusion). - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - solver_order (`int`, default `2`): - the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided - sampling, and `solver_order=3` for unconditional sampling. - prediction_type (`str`, default `epsilon`): - indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, - or `v-prediction`. - thresholding (`bool`, default `False`): - whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). - For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to - use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion - models (such as stable-diffusion). - dynamic_thresholding_ratio (`float`, default `0.995`): - the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen - (https://arxiv.org/abs/2205.11487). - sample_max_value (`float`, default `1.0`): - the threshold value for dynamic thresholding. Valid only when `thresholding=True` and - `algorithm_type="dpmsolver++`. - algorithm_type (`str`, default `dpmsolver++`): - the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the - algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in - https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided - sampling (e.g. stable-diffusion). - solver_type (`str`, default `midpoint`): - the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects - the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are - slightly better, so we recommend to use the `midpoint` type. - lower_order_final (`bool`, default `True`): - whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend to enable - this to use up all the function evaluations. 
- - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, - solver_order: int = 2, - prediction_type: str = "epsilon", - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - algorithm_type: str = "dpmsolver++", - solver_type: str = "midpoint", - lower_order_final: bool = True, - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Currently we only support VP-type noise schedule - self.alpha_t = torch.sqrt(self.alphas_cumprod) - self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) - self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "dpmsolver++"]: - if algorithm_type == "deis": - algorithm_type = "dpmsolver++" - else: - raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") - if solver_type not in ["midpoint", "heun"]: - if solver_type in ["logrho", "bh1", "bh2"]: - solver_type = "midpoint" - else: - raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") - - # setable values - self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) - self.model_outputs = [None] * solver_order - self.sample = None - self.order_list = self.get_order_list(num_train_timesteps) - - def get_order_list(self, num_inference_steps: int) -> List[int]: - """ - Computes the solver order at each time step. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - """ - steps = num_inference_steps - order = self.solver_order - if self.lower_order_final: - if order == 3: - if steps % 3 == 0: - orders = [1, 2, 3] * (steps // 3 - 1) + [1, 2] + [1] - elif steps % 3 == 1: - orders = [1, 2, 3] * (steps // 3) + [1] - else: - orders = [1, 2, 3] * (steps // 3) + [1, 2] - elif order == 2: - if steps % 2 == 0: - orders = [1, 2] * (steps // 2) - else: - orders = [1, 2] * (steps // 2) + [1] - elif order == 1: - orders = [1] * steps - else: - if order == 3: - orders = [1, 2, 3] * (steps // 3) - elif order == 2: - orders = [1, 2] * (steps // 2) - elif order == 1: - orders = [1] * steps - return orders - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the timesteps used for the diffusion chain. 
Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, optional): - the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - self.num_inference_steps = num_inference_steps - timesteps = ( - np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) - self.timesteps = torch.from_numpy(timesteps).to(device) - self.model_outputs = [None] * self.config.solver_order - self.sample = None - self.orders = self.get_order_list(num_inference_steps) - - def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor - ) -> torch.FloatTensor: - """ - Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. - - DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to - discretize an integral of the data prediction model. So we need to first convert the model output to the - corresponding type to match the algorithm. - - Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or - DPM-Solver++ for both noise prediction model and data prediction model. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the converted model output. - """ - # DPM-Solver++ needs to solve an integral of the data prediction model. - if self.config.algorithm_type == "dpmsolver++": - if self.config.prediction_type == "epsilon": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - x0_pred = (sample - sigma_t * model_output) / alpha_t - elif self.config.prediction_type == "sample": - x0_pred = model_output - elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - x0_pred = alpha_t * sample - sigma_t * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the DPMSolverSinglestepScheduler." - ) - - if self.config.thresholding: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - dtype = x0_pred.dtype - dynamic_max_val = torch.quantile( - torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)).float(), - self.config.dynamic_thresholding_ratio, - dim=1, - ) - dynamic_max_val = torch.maximum( - dynamic_max_val, - self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device), - )[(...,) + (None,) * (x0_pred.ndim - 1)] - x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val - x0_pred = x0_pred.to(dtype) - return x0_pred - # DPM-Solver needs to solve an integral of the noise prediction model. 
- elif self.config.algorithm_type == "dpmsolver": - if self.config.prediction_type == "epsilon": - return model_output - elif self.config.prediction_type == "sample": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - epsilon = (sample - alpha_t * model_output) / sigma_t - return epsilon - elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - epsilon = alpha_t * model_output + sigma_t * sample - return epsilon - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the DPMSolverSinglestepScheduler." - ) - - def dpm_solver_first_order_update( - self, - model_output: torch.FloatTensor, - timestep: int, - prev_timestep: int, - sample: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - One step for the first-order DPM-Solver (equivalent to DDIM). - - See https://arxiv.org/abs/2206.00927 for the detailed derivation. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the sample tensor at the previous timestep. - """ - lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] - alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] - sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] - h = lambda_t - lambda_s - if self.config.algorithm_type == "dpmsolver++": - x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output - elif self.config.algorithm_type == "dpmsolver": - x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output - return x_t - - def singlestep_dpm_solver_second_order_update( - self, - model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - One step for the second-order singlestep DPM-Solver. - - It computes the solution at time `prev_timestep` from the time `timestep_list[-2]`. - - Args: - model_output_list (`List[torch.FloatTensor]`): - direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the sample tensor at the previous timestep. 
- """ - t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] - m0, m1 = model_output_list[-1], model_output_list[-2] - lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] - alpha_t, alpha_s1 = self.alpha_t[t], self.alpha_t[s1] - sigma_t, sigma_s1 = self.sigma_t[t], self.sigma_t[s1] - h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1 - r0 = h_0 / h - D0, D1 = m1, (1.0 / r0) * (m0 - m1) - if self.config.algorithm_type == "dpmsolver++": - # See https://arxiv.org/abs/2211.01095 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (sigma_t / sigma_s1) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (sigma_t / sigma_s1) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s1) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s1) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - ) - return x_t - - def singlestep_dpm_solver_third_order_update( - self, - model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - One step for the third-order singlestep DPM-Solver. - - It computes the solution at time `prev_timestep` from the time `timestep_list[-3]`. - - Args: - model_output_list (`List[torch.FloatTensor]`): - direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the sample tensor at the previous timestep. 
- """ - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( - self.lambda_t[t], - self.lambda_t[s0], - self.lambda_t[s1], - self.lambda_t[s2], - ) - alpha_t, alpha_s2 = self.alpha_t[t], self.alpha_t[s2] - sigma_t, sigma_s2 = self.sigma_t[t], self.sigma_t[s2] - h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2 - r0, r1 = h_0 / h, h_1 / h - D0 = m2 - D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2) - D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1) - D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1) - if self.config.algorithm_type == "dpmsolver++": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (sigma_t / sigma_s2) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1_1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (sigma_t / sigma_s2) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 - ) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ( - (alpha_t / alpha_s2) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1 - ) - elif self.config.solver_type == "heun": - x_t = ( - (alpha_t / alpha_s2) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 - ) - return x_t - - def singlestep_dpm_solver_update( - self, - model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, - order: int, - ) -> torch.FloatTensor: - """ - One step for the singlestep DPM-Solver. - - Args: - model_output_list (`List[torch.FloatTensor]`): - direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - order (`int`): - the solver order at this step. - - Returns: - `torch.FloatTensor`: the sample tensor at the previous timestep. - """ - if order == 1: - return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample) - elif order == 2: - return self.singlestep_dpm_solver_second_order_update( - model_output_list, timestep_list, prev_timestep, sample - ) - elif order == 3: - return self.singlestep_dpm_solver_third_order_update( - model_output_list, timestep_list, prev_timestep, sample - ) - else: - raise ValueError(f"Order must be 1, 2, 3, got {order}") - - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Step function propagating the sample with the singlestep DPM-Solver. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. 
- sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: - step_index = len(self.timesteps) - 1 - else: - step_index = step_index.item() - prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] - - model_output = self.convert_model_output(model_output, timestep, sample) - for i in range(self.config.solver_order - 1): - self.model_outputs[i] = self.model_outputs[i + 1] - self.model_outputs[-1] = model_output - - order = self.order_list[step_index] - # For single-step solvers, we use the initial value at each time with order = 1. - if order == 1: - self.sample = sample - - timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep] - prev_sample = self.singlestep_dpm_solver_update( - self.model_outputs, timestep_list, prev_timestep, self.sample, order - ) - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.FloatTensor`): input sample - - Returns: - `torch.FloatTensor`: scaled input sample - """ - return sample - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/my_diffusers/schedulers/scheduling_euler_ancestral_discrete.py deleted file mode 100644 index 1b517bdec5703495afeee26a1c8ed4cb98561d7c..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. 
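The `DPMSolverSinglestepScheduler` deleted above follows the usual diffusers scheduler protocol: `set_timesteps`, then per timestep `scale_model_input`, a model call, and `step`. A minimal sketch of that loop, assuming the upstream `diffusers` package is installed; `toy_model` and the latent shape are illustrative stand-ins for a real UNet, not part of this repository:

```python
import torch
from diffusers import DPMSolverSinglestepScheduler

# Stable-Diffusion-style betas; algorithm/order follow the docstring's
# recommendation for guided sampling (dpmsolver++ with solver_order=2).
scheduler = DPMSolverSinglestepScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    algorithm_type="dpmsolver++",
    solver_order=2,
)
scheduler.set_timesteps(num_inference_steps=20)


def toy_model(x, t):
    # Placeholder for `unet(x, t).sample`; returns a fake epsilon prediction.
    return torch.randn_like(x)


sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)  # no-op for this scheduler
    noise_pred = toy_model(model_input, t)
    sample = scheduler.step(noise_pred, t, sample).prev_sample
```

With `solver_order=2` and the default `lower_order_final=True`, `get_order_list` alternates first- and second-order updates and drops back to order 1 at the end, so every model evaluation is used.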
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging, randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete -class EulerAncestralDiscreteSchedulerOutput(BaseOutput): - """ - Output class for the scheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample (x_{0}) based on the model output from the current timestep. - `pred_original_sample` can be used to preview progress or for guidance. - """ - - prev_sample: torch.FloatTensor - pred_original_sample: Optional[torch.FloatTensor] = None - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): - """ - Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson: - https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 
- [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear` or `scaled_linear`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) - - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - prediction_type: str = "epsilon", - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - - # setable values - self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) - self.is_scale_input_called = False - - def scale_model_input( - self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] - ) -> torch.FloatTensor: - """ - Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. - - Args: - sample (`torch.FloatTensor`): input sample - timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain - - Returns: - `torch.FloatTensor`: scaled input sample - """ - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero().item() - sigma = self.sigmas[step_index] - sample = sample / ((sigma**2 + 1) ** 0.5) - self.is_scale_input_called = True - return sample - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the timesteps used for the diffusion chain. 
Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, optional): - the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - self.num_inference_steps = num_inference_steps - - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas).to(device=device) - if str(device).startswith("mps"): - # mps does not support float64 - self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) - else: - self.timesteps = torch.from_numpy(timesteps).to(device=device) - - def step( - self, - model_output: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - sample: torch.FloatTensor, - generator: Optional[torch.Generator] = None, - return_dict: bool = True, - ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`float`): current timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - generator (`torch.Generator`, optional): Random number generator. - return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class - - Returns: - [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise - a `tuple`. When returning a tuple, the first element is the sample tensor. - - """ - - if ( - isinstance(timestep, int) - or isinstance(timestep, torch.IntTensor) - or isinstance(timestep, torch.LongTensor) - ): - raise ValueError( - ( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep." - ), - ) - - if not self.is_scale_input_called: - logger.warning( - "The `scale_model_input` function should be called before `step` to ensure correct denoising. " - "See `StableDiffusionPipeline` for a usage example." - ) - - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - - step_index = (self.timesteps == timestep).nonzero().item() - sigma = self.sigmas[step_index] - - # 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise - if self.config.prediction_type == "epsilon": - pred_original_sample = sample - sigma * model_output - elif self.config.prediction_type == "v_prediction": - # * c_out + input * c_skip - pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) - elif self.config.prediction_type == "sample": - raise NotImplementedError("prediction_type not implemented yet: sample") - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" - ) - - sigma_from = self.sigmas[step_index] - sigma_to = self.sigmas[step_index + 1] - sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 - sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 - - # 2. Convert to an ODE derivative - derivative = (sample - pred_original_sample) / sigma - - dt = sigma_down - sigma - - prev_sample = sample + derivative * dt - - device = model_output.device - noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator) - - prev_sample = prev_sample + noise * sigma_up - - if not return_dict: - return (prev_sample,) - - return EulerAncestralDiscreteSchedulerOutput( - prev_sample=prev_sample, pred_original_sample=pred_original_sample - ) - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.FloatTensor, - ) -> torch.FloatTensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - self.timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - schedule_timesteps = self.timesteps - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = self.sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - noisy_samples = original_samples + noise * sigma - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_euler_discrete.py b/my_diffusers/schedulers/scheduling_euler_discrete.py deleted file mode 100644 index d6252904fd9ac250b555dfe00d08c2ce64e0936b..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_euler_discrete.py +++ /dev/null @@ -1,335 +0,0 @@ -# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
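The ancestral Euler scheduler deleted above re-injects fresh noise at every step (the `sigma_up` term), so `step` accepts an optional `generator` for reproducibility, and `scale_model_input` must be called before each model evaluation (the scheduler warns if it is not). A hedged usage sketch, assuming the upstream `diffusers` package; `toy_model` is an illustrative stand-in for a trained UNet:

```python
import torch
from diffusers import EulerAncestralDiscreteScheduler

scheduler = EulerAncestralDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
scheduler.set_timesteps(num_inference_steps=30)

generator = torch.Generator().manual_seed(0)
sample = torch.randn(1, 4, 64, 64, generator=generator) * scheduler.init_noise_sigma


def toy_model(x, t):
    # Placeholder for `unet(x, t).sample`.
    return torch.randn_like(x)


for t in scheduler.timesteps:  # pass scheduler timesteps, never bare loop indices
    model_input = scheduler.scale_model_input(sample, t)
    noise_pred = toy_model(model_input, t)
    sample = scheduler.step(noise_pred, t, sample, generator=generator).prev_sample
```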
- -import math -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging, randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete -class EulerDiscreteSchedulerOutput(BaseOutput): - """ - Output class for the scheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample (x_{0}) based on the model output from the current timestep. - `pred_original_sample` can be used to preview progress or for guidance. - """ - - prev_sample: torch.FloatTensor - pred_original_sample: Optional[torch.FloatTensor] = None - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): - """ - Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original - k-diffusion implementation by Katherine Crowson: - https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear` or `scaled_linear`. 
- trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - prediction_type (`str`, default `"epsilon"`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) - interpolation_type (`str`, default `"linear"`, optional): - interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of - [`"linear"`, `"log_linear"`]. - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - prediction_type: str = "epsilon", - interpolation_type: str = "linear", - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - - # setable values - self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) - self.is_scale_input_called = False - - def scale_model_input( - self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] - ) -> torch.FloatTensor: - """ - Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. - - Args: - sample (`torch.FloatTensor`): input sample - timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain - - Returns: - `torch.FloatTensor`: scaled input sample - """ - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero().item() - sigma = self.sigmas[step_index] - - sample = sample / ((sigma**2 + 1) ** 0.5) - - self.is_scale_input_called = True - return sample - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, optional): - the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
- """ - self.num_inference_steps = num_inference_steps - - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - - if self.config.interpolation_type == "linear": - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - elif self.config.interpolation_type == "log_linear": - sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp() - else: - raise ValueError( - f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either" - " 'linear' or 'log_linear'" - ) - - sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas).to(device=device) - if str(device).startswith("mps"): - # mps does not support float64 - self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) - else: - self.timesteps = torch.from_numpy(timesteps).to(device=device) - - def step( - self, - model_output: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - sample: torch.FloatTensor, - s_churn: float = 0.0, - s_tmin: float = 0.0, - s_tmax: float = float("inf"), - s_noise: float = 1.0, - generator: Optional[torch.Generator] = None, - return_dict: bool = True, - ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`float`): current timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - s_churn (`float`) - s_tmin (`float`) - s_tmax (`float`) - s_noise (`float`) - generator (`torch.Generator`, optional): Random number generator. - return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class - - Returns: - [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is the sample tensor. - - """ - - if ( - isinstance(timestep, int) - or isinstance(timestep, torch.IntTensor) - or isinstance(timestep, torch.LongTensor) - ): - raise ValueError( - ( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep." - ), - ) - - if not self.is_scale_input_called: - logger.warning( - "The `scale_model_input` function should be called before `step` to ensure correct denoising. " - "See `StableDiffusionPipeline` for a usage example." - ) - - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - - step_index = (self.timesteps == timestep).nonzero().item() - sigma = self.sigmas[step_index] - - gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 - - noise = randn_tensor( - model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator - ) - - eps = noise * s_noise - sigma_hat = sigma * (gamma + 1) - - if gamma > 0: - sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 - - # 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise - # NOTE: "original_sample" should not be an expected prediction_type but is left in for - # backwards compatibility - if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample": - pred_original_sample = model_output - elif self.config.prediction_type == "epsilon": - pred_original_sample = sample - sigma_hat * model_output - elif self.config.prediction_type == "v_prediction": - # * c_out + input * c_skip - pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" - ) - - # 2. Convert to an ODE derivative - derivative = (sample - pred_original_sample) / sigma_hat - - dt = self.sigmas[step_index + 1] - sigma_hat - - prev_sample = sample + derivative * dt - - if not return_dict: - return (prev_sample,) - - return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.FloatTensor, - ) -> torch.FloatTensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - self.timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - schedule_timesteps = self.timesteps - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = self.sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - noisy_samples = original_samples + noise * sigma - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_heun_discrete.py b/my_diffusers/schedulers/scheduling_heun_discrete.py deleted file mode 100644 index f7f1467fc53a7e27a7ffdb171a40791aa5b97134..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_heun_discrete.py +++ /dev/null @@ -1,299 +0,0 @@ -# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
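Like the other schedulers in this directory, the Euler scheduler deleted above lists itself in `_compatibles` via `KarrasDiffusionSchedulers`, so it can be built from another scheduler's saved config and dropped into an existing pipeline. A sketch of that swap, assuming the upstream `diffusers` package and an available Stable Diffusion checkpoint (the model id below is only an example):

```python
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Reuse the pipeline's stored scheduler config (betas, prediction_type, ...).
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
image = pipe("an astronaut riding a horse", num_inference_steps=30).images[0]
```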
- -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): - """ - Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original - k-diffusion implementation by Katherine Crowson: - https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the - starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear` or `scaled_linear`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, - `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. 
- prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 2 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.00085, # sensible defaults - beta_end: float = 0.012, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - prediction_type: str = "epsilon", - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - # set all values - self.set_timesteps(num_train_timesteps, None, num_train_timesteps) - - def index_for_timestep(self, timestep): - indices = (self.timesteps == timestep).nonzero() - if self.state_in_first_order: - pos = -1 - else: - pos = 0 - return indices[pos].item() - - def scale_model_input( - self, - sample: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - ) -> torch.FloatTensor: - """ - Args: - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep - Returns: - `torch.FloatTensor`: scaled input sample - """ - step_index = self.index_for_timestep(timestep) - - sigma = self.sigmas[step_index] - sample = sample / ((sigma**2 + 1) ** 0.5) - return sample - - def set_timesteps( - self, - num_inference_steps: int, - device: Union[str, torch.device] = None, - num_train_timesteps: Optional[int] = None, - ): - """ - Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, optional): - the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
- """ - self.num_inference_steps = num_inference_steps - - num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - sigmas = torch.from_numpy(sigmas).to(device=device) - self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - - timesteps = torch.from_numpy(timesteps) - timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) - - if str(device).startswith("mps"): - # mps does not support float64 - self.timesteps = timesteps.to(device, dtype=torch.float32) - else: - self.timesteps = timesteps.to(device=device) - - # empty dt and derivative - self.prev_derivative = None - self.dt = None - - @property - def state_in_first_order(self): - return self.dt is None - - def step( - self, - model_output: Union[torch.FloatTensor, np.ndarray], - timestep: Union[float, torch.FloatTensor], - sample: Union[torch.FloatTensor, np.ndarray], - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Args: - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep - (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - step_index = self.index_for_timestep(timestep) - - if self.state_in_first_order: - sigma = self.sigmas[step_index] - sigma_next = self.sigmas[step_index + 1] - else: - # 2nd order / Heun's method - sigma = self.sigmas[step_index - 1] - sigma_next = self.sigmas[step_index] - - # currently only gamma=0 is supported. This usually works best anyways. - # We can support gamma in the future but then need to scale the timestep before - # passing it to the model which requires a change in API - gamma = 0 - sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now - - # 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise - if self.config.prediction_type == "epsilon": - sigma_input = sigma_hat if self.state_in_first_order else sigma_next - pred_original_sample = sample - sigma_input * model_output - elif self.config.prediction_type == "v_prediction": - sigma_input = sigma_hat if self.state_in_first_order else sigma_next - pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( - sample / (sigma_input**2 + 1) - ) - elif self.config.prediction_type == "sample": - raise NotImplementedError("prediction_type not implemented yet: sample") - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" - ) - - if self.state_in_first_order: - # 2. Convert to an ODE derivative for 1st order - derivative = (sample - pred_original_sample) / sigma_hat - # 3. delta timestep - dt = sigma_next - sigma_hat - - # store for 2nd order step - self.prev_derivative = derivative - self.dt = dt - self.sample = sample - else: - # 2. 2nd order / Heun's method - derivative = (sample - pred_original_sample) / sigma_next - derivative = (self.prev_derivative + derivative) / 2 - - # 3. take prev timestep & sample - dt = self.dt - sample = self.sample - - # free dt and derivative - # Note, this puts the scheduler in "first order mode" - self.prev_derivative = None - self.dt = None - self.sample = None - - prev_sample = sample + derivative * dt - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.FloatTensor, - ) -> torch.FloatTensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - self.timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - step_indices = [self.index_for_timestep(t) for t in timesteps] - - sigma = self.sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - noisy_samples = original_samples + noise * sigma - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_ipndm.py b/my_diffusers/schedulers/scheduling_ipndm.py deleted file mode 100644 index 80e521590782de6bc14e9b8c29642c7595fafc93..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_ipndm.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright 2023 Zhejiang University Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
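The Heun scheduler deleted above is a second-order method (`order = 2`): `set_timesteps` repeats every timestep after the first (`repeat_interleave(2)`), and `step` alternates between a first-order predictor and the Heun corrector, so the model runs roughly twice per inference step. A minimal sketch, again with an illustrative `toy_model` in place of a real UNet and assuming the upstream `diffusers` package:

```python
import torch
from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
scheduler.set_timesteps(num_inference_steps=25)
# With this implementation len(scheduler.timesteps) == 2 * 25 - 1: every entry
# after the first appears twice (predictor stage, then corrector stage).


def toy_model(x, t):
    # Placeholder for `unet(x, t).sample`.
    return torch.randn_like(x)


sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    sample = scheduler.step(toy_model(model_input, t), t, sample).prev_sample
```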
-# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput - - -class IPNDMScheduler(SchedulerMixin, ConfigMixin): - """ - Improved Pseudo numerical methods for diffusion models (iPNDM) ported from @crowsonkb's amazing k-diffusion - [library](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296) - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - For more details, see the original paper: https://arxiv.org/abs/2202.09778 - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - """ - - order = 1 - - @register_to_config - def __init__( - self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None - ): - # set `betas`, `alphas`, `timesteps` - self.set_timesteps(num_train_timesteps) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # For now we only support F-PNDM, i.e. the runge-kutta method - # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf - # mainly at formula (9), (12), (13) and the Algorithm 2. - self.pndm_order = 4 - - # running values - self.ets = [] - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - """ - self.num_inference_steps = num_inference_steps - steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1] - steps = torch.cat([steps, torch.tensor([0.0])]) - - if self.config.trained_betas is not None: - self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32) - else: - self.betas = torch.sin(steps * math.pi / 2) ** 2 - - self.alphas = (1.0 - self.betas**2) ** 0.5 - - timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1] - self.timesteps = timesteps.to(device) - - self.ets = [] - - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple - times to approximate the solution. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is - True, otherwise a `tuple`. 
When returning a tuple, the first element is the sample tensor. - - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - timestep_index = (self.timesteps == timestep).nonzero().item() - prev_timestep_index = timestep_index + 1 - - ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index] - self.ets.append(ets) - - if len(self.ets) == 1: - ets = self.ets[-1] - elif len(self.ets) == 2: - ets = (3 * self.ets[-1] - self.ets[-2]) / 2 - elif len(self.ets) == 3: - ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 - else: - ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) - - prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets) - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.FloatTensor`): input sample - - Returns: - `torch.FloatTensor`: scaled input sample - """ - return sample - - def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets): - alpha = self.alphas[timestep_index] - sigma = self.betas[timestep_index] - - next_alpha = self.alphas[prev_timestep_index] - next_sigma = self.betas[prev_timestep_index] - - pred = (sample - sigma * ets) / max(alpha, 1e-8) - prev_sample = next_alpha * pred + ets * next_sigma - - return prev_sample - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/my_diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py deleted file mode 100644 index c8b1f2c3bedf899d55a253b2c8a07709f6b476d8..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ /dev/null @@ -1,352 +0,0 @@ -# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. 
- - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): - """ - Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see: - https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188 - - Scheduler inspired by DPM-Solver-2 and Algorthim 2 from Karras et al. (2022). - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the - starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear` or `scaled_linear`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, - `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. - prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 2 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.00085, # sensible defaults - beta_end: float = 0.012, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - prediction_type: str = "epsilon", - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
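- # "scaled_linear" squares an evenly spaced grid between sqrt(beta_start) and sqrt(beta_end), so the betas grow quadratically rather than linearly across the timesteps.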
- self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - # set all values - self.set_timesteps(num_train_timesteps, None, num_train_timesteps) - - def index_for_timestep(self, timestep): - indices = (self.timesteps == timestep).nonzero() - if self.state_in_first_order: - pos = -1 - else: - pos = 0 - return indices[pos].item() - - def scale_model_input( - self, - sample: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - ) -> torch.FloatTensor: - """ - Args: - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep - Returns: - `torch.FloatTensor`: scaled input sample - """ - step_index = self.index_for_timestep(timestep) - - if self.state_in_first_order: - sigma = self.sigmas[step_index] - else: - sigma = self.sigmas_interpol[step_index - 1] - - sample = sample / ((sigma**2 + 1) ** 0.5) - return sample - - def set_timesteps( - self, - num_inference_steps: int, - device: Union[str, torch.device] = None, - num_train_timesteps: Optional[int] = None, - ): - """ - Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, optional): - the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
- """ - self.num_inference_steps = num_inference_steps - - num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) - - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - sigmas = torch.from_numpy(sigmas).to(device=device) - - # compute up and down sigmas - sigmas_next = sigmas.roll(-1) - sigmas_next[-1] = 0.0 - sigmas_up = (sigmas_next**2 * (sigmas**2 - sigmas_next**2) / sigmas**2) ** 0.5 - sigmas_down = (sigmas_next**2 - sigmas_up**2) ** 0.5 - sigmas_down[-1] = 0.0 - - # compute interpolated sigmas - sigmas_interpol = sigmas.log().lerp(sigmas_down.log(), 0.5).exp() - sigmas_interpol[-2:] = 0.0 - - # set sigmas - self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]]) - self.sigmas_interpol = torch.cat( - [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]] - ) - self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]]) - self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]]) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - - if str(device).startswith("mps"): - # mps does not support float64 - timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) - else: - timesteps = torch.from_numpy(timesteps).to(device) - - timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) - interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten() - - self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) - - self.sample = None - - def sigma_to_t(self, sigma): - # get log sigma - log_sigma = sigma.log() - - # get distribution - dists = log_sigma - self.log_sigmas[:, None] - - # get sigmas range - low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - - low = self.log_sigmas[low_idx] - high = self.log_sigmas[high_idx] - - # interpolate sigmas - w = (low - log_sigma) / (low - high) - w = w.clamp(0, 1) - - # transform interpolation to time range - t = (1 - w) * low_idx + w * high_idx - t = t.view(sigma.shape) - return t - - @property - def state_in_first_order(self): - return self.sample is None - - def step( - self, - model_output: Union[torch.FloatTensor, np.ndarray], - timestep: Union[float, torch.FloatTensor], - sample: Union[torch.FloatTensor, np.ndarray], - generator: Optional[torch.Generator] = None, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Args: - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep - (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): - current instance of sample being created by diffusion process. 
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - """ - step_index = self.index_for_timestep(timestep) - - if self.state_in_first_order: - sigma = self.sigmas[step_index] - sigma_interpol = self.sigmas_interpol[step_index] - sigma_up = self.sigmas_up[step_index] - sigma_down = self.sigmas_down[step_index - 1] - else: - # 2nd order / KPDM2's method - sigma = self.sigmas[step_index - 1] - sigma_interpol = self.sigmas_interpol[step_index - 1] - sigma_up = self.sigmas_up[step_index - 1] - sigma_down = self.sigmas_down[step_index - 1] - - # currently only gamma=0 is supported. This usually works best anyways. - # We can support gamma in the future but then need to scale the timestep before - # passing it to the model which requires a change in API - gamma = 0 - sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now - - device = model_output.device - noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator) - - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - if self.config.prediction_type == "epsilon": - sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol - pred_original_sample = sample - sigma_input * model_output - elif self.config.prediction_type == "v_prediction": - sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol - pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( - sample / (sigma_input**2 + 1) - ) - elif self.config.prediction_type == "sample": - raise NotImplementedError("prediction_type not implemented yet: sample") - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" - ) - - if self.state_in_first_order: - # 2. Convert to an ODE derivative for 1st order - derivative = (sample - pred_original_sample) / sigma_hat - # 3. delta timestep - dt = sigma_interpol - sigma_hat - - # store for 2nd order step - self.sample = sample - self.dt = dt - prev_sample = sample + derivative * dt - else: - # DPM-Solver-2 - # 2. Convert to an ODE derivative for 2nd order - derivative = (sample - pred_original_sample) / sigma_interpol - # 3. 
delta timestep - dt = sigma_down - sigma_hat - - sample = self.sample - self.sample = None - - prev_sample = sample + derivative * dt - prev_sample = prev_sample + noise * sigma_up - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.FloatTensor, - ) -> torch.FloatTensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - self.timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - step_indices = [self.index_for_timestep(t) for t in timesteps] - - sigma = self.sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - noisy_samples = original_samples + noise * sigma - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/my_diffusers/schedulers/scheduling_k_dpm_2_discrete.py deleted file mode 100644 index 809da798f889ebe9d7788fd6c422918cd8e1b440..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ /dev/null @@ -1,333 +0,0 @@ -# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. 
- - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): - """ - Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see: - https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188 - - Scheduler inspired by DPM-Solver-2 and Algorthim 2 from Karras et al. (2022). - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the - starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear` or `scaled_linear`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, - `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. - prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 2 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.00085, # sensible defaults - beta_end: float = 0.012, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - prediction_type: str = "epsilon", - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
- self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - # set all values - self.set_timesteps(num_train_timesteps, None, num_train_timesteps) - - def index_for_timestep(self, timestep): - indices = (self.timesteps == timestep).nonzero() - if self.state_in_first_order: - pos = -1 - else: - pos = 0 - return indices[pos].item() - - def scale_model_input( - self, - sample: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - ) -> torch.FloatTensor: - """ - Args: - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep - Returns: - `torch.FloatTensor`: scaled input sample - """ - step_index = self.index_for_timestep(timestep) - - if self.state_in_first_order: - sigma = self.sigmas[step_index] - else: - sigma = self.sigmas_interpol[step_index] - - sample = sample / ((sigma**2 + 1) ** 0.5) - return sample - - def set_timesteps( - self, - num_inference_steps: int, - device: Union[str, torch.device] = None, - num_train_timesteps: Optional[int] = None, - ): - """ - Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, optional): - the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
- """ - self.num_inference_steps = num_inference_steps - - num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) - - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - sigmas = torch.from_numpy(sigmas).to(device=device) - - # interpolate sigmas - sigmas_interpol = sigmas.log().lerp(sigmas.roll(1).log(), 0.5).exp() - - self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]]) - self.sigmas_interpol = torch.cat( - [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]] - ) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - - if str(device).startswith("mps"): - # mps does not support float64 - timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) - else: - timesteps = torch.from_numpy(timesteps).to(device) - - # interpolate timesteps - timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) - interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten() - - self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) - - self.sample = None - - def sigma_to_t(self, sigma): - # get log sigma - log_sigma = sigma.log() - - # get distribution - dists = log_sigma - self.log_sigmas[:, None] - - # get sigmas range - low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - - low = self.log_sigmas[low_idx] - high = self.log_sigmas[high_idx] - - # interpolate sigmas - w = (low - log_sigma) / (low - high) - w = w.clamp(0, 1) - - # transform interpolation to time range - t = (1 - w) * low_idx + w * high_idx - t = t.view(sigma.shape) - return t - - @property - def state_in_first_order(self): - return self.sample is None - - def step( - self, - model_output: Union[torch.FloatTensor, np.ndarray], - timestep: Union[float, torch.FloatTensor], - sample: Union[torch.FloatTensor, np.ndarray], - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Args: - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep - (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. 
- """ - step_index = self.index_for_timestep(timestep) - - if self.state_in_first_order: - sigma = self.sigmas[step_index] - sigma_interpol = self.sigmas_interpol[step_index + 1] - sigma_next = self.sigmas[step_index + 1] - else: - # 2nd order / KDPM2's method - sigma = self.sigmas[step_index - 1] - sigma_interpol = self.sigmas_interpol[step_index] - sigma_next = self.sigmas[step_index] - - # currently only gamma=0 is supported. This usually works best anyways. - # We can support gamma in the future but then need to scale the timestep before - # passing it to the model which requires a change in API - gamma = 0 - sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now - - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - if self.config.prediction_type == "epsilon": - sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol - pred_original_sample = sample - sigma_input * model_output - elif self.config.prediction_type == "v_prediction": - sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol - pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( - sample / (sigma_input**2 + 1) - ) - elif self.config.prediction_type == "sample": - raise NotImplementedError("prediction_type not implemented yet: sample") - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" - ) - - if self.state_in_first_order: - # 2. Convert to an ODE derivative for 1st order - derivative = (sample - pred_original_sample) / sigma_hat - # 3. delta timestep - dt = sigma_interpol - sigma_hat - - # store for 2nd order step - self.sample = sample - else: - # DPM-Solver-2 - # 2. Convert to an ODE derivative for 2nd order - derivative = (sample - pred_original_sample) / sigma_interpol - - # 3. delta timestep - dt = sigma_next - sigma_hat - - sample = self.sample - self.sample = None - - prev_sample = sample + derivative * dt - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.FloatTensor, - ) -> torch.FloatTensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - self.timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - step_indices = [self.index_for_timestep(t) for t in timesteps] - - sigma = self.sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - noisy_samples = original_samples + noise * sigma - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_karras_ve.py b/my_diffusers/schedulers/scheduling_karras_ve.py deleted file mode 100644 index 87f6514a4e93e4a75bd6228ed852306b8c005c3d..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_karras_ve.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, randn_tensor -from .scheduling_utils import SchedulerMixin - - -@dataclass -class KarrasVeOutput(BaseOutput): - """ - Output class for the scheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Derivative of predicted original image sample (x_0). - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample (x_{0}) based on the model output from the current timestep. - `pred_original_sample` can be used to preview progress or for guidance. - """ - - prev_sample: torch.FloatTensor - derivative: torch.FloatTensor - pred_original_sample: Optional[torch.FloatTensor] = None - - -class KarrasVeScheduler(SchedulerMixin, ConfigMixin): - """ - Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and - the VE column of Table 1 from [1] for reference. - - [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." - https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic - differential equations." https://arxiv.org/abs/2011.13456 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of - Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the - optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. - - Args: - sigma_min (`float`): minimum noise magnitude - sigma_max (`float`): maximum noise magnitude - s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. - A reasonable range is [1.000, 1.011]. - s_churn (`float`): the parameter controlling the overall amount of stochasticity. - A reasonable range is [0, 100]. - s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). - A reasonable range is [0, 10]. - s_max (`float`): the end value of the sigma range where we add noise. 
- A reasonable range is [0.2, 80]. - - """ - - order = 2 - - @register_to_config - def __init__( - self, - sigma_min: float = 0.02, - sigma_max: float = 100, - s_noise: float = 1.007, - s_churn: float = 80, - s_min: float = 0.05, - s_max: float = 50, - ): - # standard deviation of the initial noise distribution - self.init_noise_sigma = sigma_max - - # setable values - self.num_inference_steps: int = None - self.timesteps: np.IntTensor = None - self.schedule: torch.FloatTensor = None # sigma(t_i) - - def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.FloatTensor`): input sample - timestep (`int`, optional): current timestep - - Returns: - `torch.FloatTensor`: scaled input sample - """ - return sample - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - - """ - self.num_inference_steps = num_inference_steps - timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps).to(device) - schedule = [ - ( - self.config.sigma_max**2 - * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) - ) - for i in self.timesteps - ] - self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device) - - def add_noise_to_input( - self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None - ) -> Tuple[torch.FloatTensor, float]: - """ - Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a - higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. - - TODO Args: - """ - if self.config.s_min <= sigma <= self.config.s_max: - gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1) - else: - gamma = 0 - - # sample eps ~ N(0, S_noise^2 * I) - eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(sample.device) - sigma_hat = sigma + gamma * sigma - sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) - - return sample_hat, sigma_hat - - def step( - self, - model_output: torch.FloatTensor, - sigma_hat: float, - sigma_prev: float, - sample_hat: torch.FloatTensor, - return_dict: bool = True, - ) -> Union[KarrasVeOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - sigma_hat (`float`): TODO - sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor`): TODO - return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class - - KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). - Returns: - [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`: - [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. 
- - """ - - pred_original_sample = sample_hat + sigma_hat * model_output - derivative = (sample_hat - pred_original_sample) / sigma_hat - sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative - - if not return_dict: - return (sample_prev, derivative) - - return KarrasVeOutput( - prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample - ) - - def step_correct( - self, - model_output: torch.FloatTensor, - sigma_hat: float, - sigma_prev: float, - sample_hat: torch.FloatTensor, - sample_prev: torch.FloatTensor, - derivative: torch.FloatTensor, - return_dict: bool = True, - ) -> Union[KarrasVeOutput, Tuple]: - """ - Correct the predicted sample based on the output model_output of the network. TODO complete description - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - sigma_hat (`float`): TODO - sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor`): TODO - sample_prev (`torch.FloatTensor`): TODO - derivative (`torch.FloatTensor`): TODO - return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class - - Returns: - prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO - - """ - pred_original_sample = sample_prev + sigma_prev * model_output - derivative_corr = (sample_prev - pred_original_sample) / sigma_prev - sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) - - if not return_dict: - return (sample_prev, derivative) - - return KarrasVeOutput( - prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample - ) - - def add_noise(self, original_samples, noise, timesteps): - raise NotImplementedError() diff --git a/my_diffusers/schedulers/scheduling_karras_ve_flax.py b/my_diffusers/schedulers/scheduling_karras_ve_flax.py deleted file mode 100644 index 45c0dbddf7efd22df21cc9859e68d62b54aa8609..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_karras_ve_flax.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import flax -import jax.numpy as jnp -from jax import random - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from .scheduling_utils_flax import FlaxSchedulerMixin - - -@flax.struct.dataclass -class KarrasVeSchedulerState: - # setable values - num_inference_steps: Optional[int] = None - timesteps: Optional[jnp.ndarray] = None - schedule: Optional[jnp.ndarray] = None # sigma(t_i) - - @classmethod - def create(cls): - return cls() - - -@dataclass -class FlaxKarrasVeOutput(BaseOutput): - """ - Output class for the scheduler's step function output. - - Args: - prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. 
`prev_sample` should be used as next model input in the - denoising loop. - derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): - Derivative of predicted original image sample (x_0). - state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. - """ - - prev_sample: jnp.ndarray - derivative: jnp.ndarray - state: KarrasVeSchedulerState - - -class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): - """ - Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and - the VE column of Table 1 from [1] for reference. - - [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." - https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic - differential equations." https://arxiv.org/abs/2011.13456 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of - Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the - optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. - - Args: - sigma_min (`float`): minimum noise magnitude - sigma_max (`float`): maximum noise magnitude - s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. - A reasonable range is [1.000, 1.011]. - s_churn (`float`): the parameter controlling the overall amount of stochasticity. - A reasonable range is [0, 100]. - s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). - A reasonable range is [0, 10]. - s_max (`float`): the end value of the sigma range where we add noise. - A reasonable range is [0.2, 80]. - """ - - @property - def has_state(self): - return True - - @register_to_config - def __init__( - self, - sigma_min: float = 0.02, - sigma_max: float = 100, - s_noise: float = 1.007, - s_churn: float = 80, - s_min: float = 0.05, - s_max: float = 50, - ): - pass - - def create_state(self): - return KarrasVeSchedulerState.create() - - def set_timesteps( - self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = () - ) -> KarrasVeSchedulerState: - """ - Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - state (`KarrasVeSchedulerState`): - the `FlaxKarrasVeScheduler` state data class. - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. 
- - """ - timesteps = jnp.arange(0, num_inference_steps)[::-1].copy() - schedule = [ - ( - self.config.sigma_max**2 - * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) - ) - for i in timesteps - ] - - return state.replace( - num_inference_steps=num_inference_steps, - schedule=jnp.array(schedule, dtype=jnp.float32), - timesteps=timesteps, - ) - - def add_noise_to_input( - self, - state: KarrasVeSchedulerState, - sample: jnp.ndarray, - sigma: float, - key: random.KeyArray, - ) -> Tuple[jnp.ndarray, float]: - """ - Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a - higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. - - TODO Args: - """ - if self.config.s_min <= sigma <= self.config.s_max: - gamma = min(self.config.s_churn / state.num_inference_steps, 2**0.5 - 1) - else: - gamma = 0 - - # sample eps ~ N(0, S_noise^2 * I) - key = random.split(key, num=1) - eps = self.config.s_noise * random.normal(key=key, shape=sample.shape) - sigma_hat = sigma + gamma * sigma - sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) - - return sample_hat, sigma_hat - - def step( - self, - state: KarrasVeSchedulerState, - model_output: jnp.ndarray, - sigma_hat: float, - sigma_prev: float, - sample_hat: jnp.ndarray, - return_dict: bool = True, - ) -> Union[FlaxKarrasVeOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. - sigma_hat (`float`): TODO - sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO - return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class - - Returns: - [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion - chain and derivative. [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - """ - - pred_original_sample = sample_hat + sigma_hat * model_output - derivative = (sample_hat - pred_original_sample) / sigma_hat - sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative - - if not return_dict: - return (sample_prev, derivative, state) - - return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state) - - def step_correct( - self, - state: KarrasVeSchedulerState, - model_output: jnp.ndarray, - sigma_hat: float, - sigma_prev: float, - sample_hat: jnp.ndarray, - sample_prev: jnp.ndarray, - derivative: jnp.ndarray, - return_dict: bool = True, - ) -> Union[FlaxKarrasVeOutput, Tuple]: - """ - Correct the predicted sample based on the output model_output of the network. TODO complete description - - Args: - state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. - model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. 
- sigma_hat (`float`): TODO - sigma_prev (`float`): TODO - sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO - sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO - derivative (`torch.FloatTensor` or `np.ndarray`): TODO - return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class - - Returns: - prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO - - """ - pred_original_sample = sample_prev + sigma_prev * model_output - derivative_corr = (sample_prev - pred_original_sample) / sigma_prev - sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) - - if not return_dict: - return (sample_prev, derivative, state) - - return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state) - - def add_noise(self, state: KarrasVeSchedulerState, original_samples, noise, timesteps): - raise NotImplementedError() diff --git a/my_diffusers/schedulers/scheduling_lms_discrete.py b/my_diffusers/schedulers/scheduling_lms_discrete.py deleted file mode 100644 index 0fe1f77f9b5c0c22c676dba122c085f63f0d84fa..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_lms_discrete.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math -import warnings -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch -from scipy import integrate - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin - - -@dataclass -# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete -class LMSDiscreteSchedulerOutput(BaseOutput): - """ - Output class for the scheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample (x_{0}) based on the model output from the current timestep. - `pred_original_sample` can be used to preview progress or for guidance. - """ - - prev_sample: torch.FloatTensor - pred_original_sample: Optional[torch.FloatTensor] = None - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. 
- - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): - """ - Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by - Katherine Crowson: - https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear` or `scaled_linear`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - prediction_type: str = "epsilon", - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
- self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - - # setable values - self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) - self.derivatives = [] - self.is_scale_input_called = False - - def scale_model_input( - self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] - ) -> torch.FloatTensor: - """ - Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. - - Args: - sample (`torch.FloatTensor`): input sample - timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain - - Returns: - `torch.FloatTensor`: scaled input sample - """ - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero().item() - sigma = self.sigmas[step_index] - sample = sample / ((sigma**2 + 1) ** 0.5) - self.is_scale_input_called = True - return sample - - def get_lms_coefficient(self, order, t, current_order): - """ - Compute a linear multistep coefficient. - - Args: - order (TODO): - t (TODO): - current_order (TODO): - """ - - def lms_derivative(tau): - prod = 1.0 - for k in range(order): - if current_order == k: - continue - prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k]) - return prod - - integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0] - - return integrated_coeff - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, optional): - the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
- """ - self.num_inference_steps = num_inference_steps - - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - - self.sigmas = torch.from_numpy(sigmas).to(device=device) - if str(device).startswith("mps"): - # mps does not support float64 - self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) - else: - self.timesteps = torch.from_numpy(timesteps).to(device=device) - - self.derivatives = [] - - def step( - self, - model_output: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - sample: torch.FloatTensor, - order: int = 4, - return_dict: bool = True, - ) -> Union[LMSDiscreteSchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`float`): current timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - order: coefficient for multi-step inference. - return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class - - Returns: - [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is the sample tensor. - - """ - if not self.is_scale_input_called: - warnings.warn( - "The `scale_model_input` function should be called before `step` to ensure correct denoising. " - "See `StableDiffusionPipeline` for a usage example." - ) - - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero().item() - sigma = self.sigmas[step_index] - - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - if self.config.prediction_type == "epsilon": - pred_original_sample = sample - sigma * model_output - elif self.config.prediction_type == "v_prediction": - # * c_out + input * c_skip - pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) - elif self.config.prediction_type == "sample": - pred_original_sample = model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" - ) - - # 2. Convert to an ODE derivative - derivative = (sample - pred_original_sample) / sigma - self.derivatives.append(derivative) - if len(self.derivatives) > order: - self.derivatives.pop(0) - - # 3. Compute linear multistep coefficients - order = min(step_index + 1, order) - lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)] - - # 4. 
Compute previous sample based on the derivatives path - prev_sample = sample + sum( - coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives)) - ) - - if not return_dict: - return (prev_sample,) - - return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.FloatTensor, - ) -> torch.FloatTensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - noisy_samples = original_samples + noise * sigma - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_lms_discrete_flax.py b/my_diffusers/schedulers/scheduling_lms_discrete_flax.py deleted file mode 100644 index f96e602afe121a09876b0ff7db1d3192e441e32a..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_lms_discrete_flax.py +++ /dev/null @@ -1,283 +0,0 @@ -# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
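For orientation before the Flax variant that follows: the K-LMS scheduler removed above is driven by the loop its `step` warning alludes to, i.e. set the timesteps, scale each model input by `1 / (sigma**2 + 1) ** 0.5`, then feed the predicted noise back through `step`. Below is a minimal sketch under two stated assumptions: `LMSDiscreteScheduler` is importable from the `my_diffusers` package root, and `dummy_epsilon` is a zero-valued placeholder standing in for a trained noise predictor.

import torch
from my_diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(num_train_timesteps=1000, beta_schedule="scaled_linear")
scheduler.set_timesteps(num_inference_steps=25)

# start from pure noise scaled to the largest sigma (init_noise_sigma)
sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

def dummy_epsilon(x, t):
    # placeholder for a trained denoiser, e.g. a UNet's predicted noise
    return torch.zeros_like(x)

for t in scheduler.timesteps:
    scaled = scheduler.scale_model_input(sample, t)             # divides by (sigma**2 + 1) ** 0.5
    noise_pred = dummy_epsilon(scaled, t)
    sample = scheduler.step(noise_pred, t, sample).prev_sample  # linear multistep update (up to order 4)

Each `step` call appends the ODE derivative to `self.derivatives` and combines up to four of them using the integrated coefficients returned by `get_lms_coefficient`.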
- -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import flax -import jax.numpy as jnp -from scipy import integrate - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import ( - CommonSchedulerState, - FlaxKarrasDiffusionSchedulers, - FlaxSchedulerMixin, - FlaxSchedulerOutput, - broadcast_to_shape_from_left, -) - - -@flax.struct.dataclass -class LMSDiscreteSchedulerState: - common: CommonSchedulerState - - # setable values - init_noise_sigma: jnp.ndarray - timesteps: jnp.ndarray - sigmas: jnp.ndarray - num_inference_steps: Optional[int] = None - - # running values - derivatives: Optional[jnp.ndarray] = None - - @classmethod - def create( - cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray - ): - return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas) - - -@dataclass -class FlaxLMSSchedulerOutput(FlaxSchedulerOutput): - state: LMSDiscreteSchedulerState - - -class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): - """ - Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by - Katherine Crowson: - https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear` or `scaled_linear`. - trained_betas (`jnp.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) - dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): - the `dtype` used for params and computation. 
- """ - - _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] - - dtype: jnp.dtype - - @property - def has_state(self): - return True - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[jnp.ndarray] = None, - prediction_type: str = "epsilon", - dtype: jnp.dtype = jnp.float32, - ): - self.dtype = dtype - - def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState: - if common is None: - common = CommonSchedulerState.create(self) - - timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] - sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5 - - # standard deviation of the initial noise distribution - init_noise_sigma = sigmas.max() - - return LMSDiscreteSchedulerState.create( - common=common, - init_noise_sigma=init_noise_sigma, - timesteps=timesteps, - sigmas=sigmas, - ) - - def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray: - """ - Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. - - Args: - state (`LMSDiscreteSchedulerState`): - the `FlaxLMSDiscreteScheduler` state data class instance. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - timestep (`int`): - current discrete timestep in the diffusion chain. - - Returns: - `jnp.ndarray`: scaled input sample - """ - (step_index,) = jnp.where(state.timesteps == timestep, size=1) - step_index = step_index[0] - - sigma = state.sigmas[step_index] - sample = sample / ((sigma**2 + 1) ** 0.5) - return sample - - def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order): - """ - Compute a linear multistep coefficient. - - Args: - order (TODO): - t (TODO): - current_order (TODO): - """ - - def lms_derivative(tau): - prod = 1.0 - for k in range(order): - if current_order == k: - continue - prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k]) - return prod - - integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0] - - return integrated_coeff - - def set_timesteps( - self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () - ) -> LMSDiscreteSchedulerState: - """ - Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - state (`LMSDiscreteSchedulerState`): - the `FlaxLMSDiscreteScheduler` state data class instance. - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. 
- """ - - timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) - - low_idx = jnp.floor(timesteps).astype(jnp.int32) - high_idx = jnp.ceil(timesteps).astype(jnp.int32) - - frac = jnp.mod(timesteps, 1.0) - - sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5 - sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] - sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)]) - - timesteps = timesteps.astype(jnp.int32) - - # initial running values - derivatives = jnp.zeros((0,) + shape, dtype=self.dtype) - - return state.replace( - timesteps=timesteps, - sigmas=sigmas, - num_inference_steps=num_inference_steps, - derivatives=derivatives, - ) - - def step( - self, - state: LMSDiscreteSchedulerState, - model_output: jnp.ndarray, - timestep: int, - sample: jnp.ndarray, - order: int = 4, - return_dict: bool = True, - ) -> Union[FlaxLMSSchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance. - model_output (`jnp.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - order: coefficient for multi-step inference. - return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class - - Returns: - [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is the sample tensor. - - """ - if state.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - sigma = state.sigmas[timestep] - - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - if self.config.prediction_type == "epsilon": - pred_original_sample = sample - sigma * model_output - elif self.config.prediction_type == "v_prediction": - # * c_out + input * c_skip - pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" - ) - - # 2. Convert to an ODE derivative - derivative = (sample - pred_original_sample) / sigma - state = state.replace(derivatives=jnp.append(state.derivatives, derivative)) - if len(state.derivatives) > order: - state = state.replace(derivatives=jnp.delete(state.derivatives, 0)) - - # 3. Compute linear multistep coefficients - order = min(timestep + 1, order) - lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)] - - # 4. 
Compute previous sample based on the derivatives path - prev_sample = sample + sum( - coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(state.derivatives)) - ) - - if not return_dict: - return (prev_sample, state) - - return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state) - - def add_noise( - self, - state: LMSDiscreteSchedulerState, - original_samples: jnp.ndarray, - noise: jnp.ndarray, - timesteps: jnp.ndarray, - ) -> jnp.ndarray: - sigma = state.sigmas[timesteps].flatten() - sigma = broadcast_to_shape_from_left(sigma, noise.shape) - - noisy_samples = original_samples + noise * sigma - - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_pndm.py b/my_diffusers/schedulers/scheduling_pndm.py deleted file mode 100644 index 562cefb178933e40deccb82ebed66e891ee5ab87..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_pndm.py +++ /dev/null @@ -1,425 +0,0 @@ -# Copyright 2023 Zhejiang University Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class PNDMScheduler(SchedulerMixin, ConfigMixin): - """ - Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, - namely Runge-Kutta method and a linear multi-step method. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 
- [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - For more details, see the original paper: https://arxiv.org/abs/2202.09778 - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - skip_prk_steps (`bool`): - allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required - before plms steps; defaults to `False`. - set_alpha_to_one (`bool`, default `False`): - each diffusion step uses the value of alphas product at that step and at the previous one. For the final - step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, - otherwise it uses the value of alpha at step 0. - prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process) - or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) - steps_offset (`int`, default `0`): - an offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in - stable diffusion. - - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - skip_prk_steps: bool = False, - set_alpha_to_one: bool = False, - prediction_type: str = "epsilon", - steps_offset: int = 0, - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # For now we only support F-PNDM, i.e. the runge-kutta method - # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf - # mainly at formula (9), (12), (13) and the Algorithm 2. 
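# Note on the choice below: `pndm_order = 4` matches the highest-order PLMS update in
# `step_plms`, model_output = (55 * e_t - 59 * e_{t-1} + 37 * e_{t-2} - 9 * e_{t-3}) / 24,
# which needs four stored model outputs in `self.ets`; the Runge-Kutta ("prk") phase is
# what fills that history when `skip_prk_steps` is False.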
- self.pndm_order = 4 - - # running values - self.cur_model_output = 0 - self.counter = 0 - self.cur_sample = None - self.ets = [] - - # setable values - self.num_inference_steps = None - self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - self.prk_timesteps = None - self.plms_timesteps = None - self.timesteps = None - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - """ - - self.num_inference_steps = num_inference_steps - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() - self._timesteps += self.config.steps_offset - - if self.config.skip_prk_steps: - # for some models like stable diffusion the prk steps can/should be skipped to - # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation - # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 - self.prk_timesteps = np.array([]) - self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[ - ::-1 - ].copy() - else: - prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( - np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order - ) - self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() - self.plms_timesteps = self._timesteps[:-3][ - ::-1 - ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy - - timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) - self.timesteps = torch.from_numpy(timesteps).to(device) - - self.ets = [] - self.counter = 0 - self.cur_model_output = 0 - - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. 
- - """ - if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: - return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) - else: - return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) - - def step_prk( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the - solution to the differential equation. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 - prev_timestep = timestep - diff_to_prev - timestep = self.prk_timesteps[self.counter // 4 * 4] - - if self.counter % 4 == 0: - self.cur_model_output += 1 / 6 * model_output - self.ets.append(model_output) - self.cur_sample = sample - elif (self.counter - 1) % 4 == 0: - self.cur_model_output += 1 / 3 * model_output - elif (self.counter - 2) % 4 == 0: - self.cur_model_output += 1 / 3 * model_output - elif (self.counter - 3) % 4 == 0: - model_output = self.cur_model_output + 1 / 6 * model_output - self.cur_model_output = 0 - - # cur_sample should not be `None` - cur_sample = self.cur_sample if self.cur_sample is not None else sample - - prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) - self.counter += 1 - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def step_plms( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple - times to approximate the solution. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
- - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if not self.config.skip_prk_steps and len(self.ets) < 3: - raise ValueError( - f"{self.__class__} can only be run AFTER scheduler has been run " - "in 'prk' mode for at least 12 iterations " - "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py " - "for more information." - ) - - prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps - - if self.counter != 1: - self.ets = self.ets[-3:] - self.ets.append(model_output) - else: - prev_timestep = timestep - timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps - - if len(self.ets) == 1 and self.counter == 0: - model_output = model_output - self.cur_sample = sample - elif len(self.ets) == 1 and self.counter == 1: - model_output = (model_output + self.ets[-1]) / 2 - sample = self.cur_sample - self.cur_sample = None - elif len(self.ets) == 2: - model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 - elif len(self.ets) == 3: - model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 - else: - model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) - - prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) - self.counter += 1 - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. 
- - Args: - sample (`torch.FloatTensor`): input sample - - Returns: - `torch.FloatTensor`: scaled input sample - """ - return sample - - def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): - # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf - # this function computes x_(t−δ) using the formula of (9) - # Note that x_t needs to be added to both sides of the equation - - # Notation ( -> - # alpha_prod_t -> α_t - # alpha_prod_t_prev -> α_(t−δ) - # beta_prod_t -> (1 - α_t) - # beta_prod_t_prev -> (1 - α_(t−δ)) - # sample -> x_t - # model_output -> e_θ(x_t, t) - # prev_sample -> x_(t−δ) - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - if self.config.prediction_type == "v_prediction": - model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample - elif self.config.prediction_type != "epsilon": - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`" - ) - - # corresponds to (α_(t−δ) - α_t) divided by - # denominator of x_t in formula (9) and plus 1 - # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = - # sqrt(α_(t−δ)) / sqrt(α_t)) - sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) - - # corresponds to denominator of e_θ(x_t, t) in formula (9) - model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( - alpha_prod_t * beta_prod_t * alpha_prod_t_prev - ) ** (0.5) - - # full formula (9) - prev_sample = ( - sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff - ) - - return prev_sample - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.Tensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_pndm_flax.py b/my_diffusers/schedulers/scheduling_pndm_flax.py deleted file mode 100644 index c654f2de8dd3e4f96403cce4b9db8f8b7b69861f..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_pndm_flax.py +++ /dev/null @@ -1,511 +0,0 @@ -# Copyright 2023 Zhejiang University Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim - -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import flax -import jax -import jax.numpy as jnp - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import ( - CommonSchedulerState, - FlaxKarrasDiffusionSchedulers, - FlaxSchedulerMixin, - FlaxSchedulerOutput, - add_noise_common, -) - - -@flax.struct.dataclass -class PNDMSchedulerState: - common: CommonSchedulerState - final_alpha_cumprod: jnp.ndarray - - # setable values - init_noise_sigma: jnp.ndarray - timesteps: jnp.ndarray - num_inference_steps: Optional[int] = None - prk_timesteps: Optional[jnp.ndarray] = None - plms_timesteps: Optional[jnp.ndarray] = None - - # running values - cur_model_output: Optional[jnp.ndarray] = None - counter: Optional[jnp.int32] = None - cur_sample: Optional[jnp.ndarray] = None - ets: Optional[jnp.ndarray] = None - - @classmethod - def create( - cls, - common: CommonSchedulerState, - final_alpha_cumprod: jnp.ndarray, - init_noise_sigma: jnp.ndarray, - timesteps: jnp.ndarray, - ): - return cls( - common=common, - final_alpha_cumprod=final_alpha_cumprod, - init_noise_sigma=init_noise_sigma, - timesteps=timesteps, - ) - - -@dataclass -class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput): - state: PNDMSchedulerState - - -class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): - """ - Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, - namely Runge-Kutta method and a linear multi-step method. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - For more details, see the original paper: https://arxiv.org/abs/2202.09778 - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`jnp.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - skip_prk_steps (`bool`): - allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required - before plms steps; defaults to `False`. - set_alpha_to_one (`bool`, default `False`): - each diffusion step uses the value of alphas product at that step and at the previous one. For the final - step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, - otherwise it uses the value of alpha at step 0. 
- steps_offset (`int`, default `0`): - an offset added to the inference steps. You can use a combination of `offset=1` and - `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in - stable diffusion. - prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) - dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): - the `dtype` used for params and computation. - """ - - _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] - - dtype: jnp.dtype - pndm_order: int - - @property - def has_state(self): - return True - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[jnp.ndarray] = None, - skip_prk_steps: bool = False, - set_alpha_to_one: bool = False, - steps_offset: int = 0, - prediction_type: str = "epsilon", - dtype: jnp.dtype = jnp.float32, - ): - self.dtype = dtype - - # For now we only support F-PNDM, i.e. the runge-kutta method - # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf - # mainly at formula (9), (12), (13) and the Algorithm 2. - self.pndm_order = 4 - - def create_state(self, common: Optional[CommonSchedulerState] = None) -> PNDMSchedulerState: - if common is None: - common = CommonSchedulerState.create(self) - - # At every step in ddim, we are looking into the previous alphas_cumprod - # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this parameter simply to one or - # whether we use the final alpha of the "non-previous" one. - final_alpha_cumprod = ( - jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0] - ) - - # standard deviation of the initial noise distribution - init_noise_sigma = jnp.array(1.0, dtype=self.dtype) - - timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] - - return PNDMSchedulerState.create( - common=common, - final_alpha_cumprod=final_alpha_cumprod, - init_noise_sigma=init_noise_sigma, - timesteps=timesteps, - ) - - def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState: - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - state (`PNDMSchedulerState`): - the `FlaxPNDMScheduler` state data class instance. - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - shape (`Tuple`): - the shape of the samples to be generated. - """ - - step_ratio = self.config.num_train_timesteps // num_inference_steps - # creates integer timesteps by multiplying by ratio - # rounding to avoid issues when num_inference_step is power of 3 - _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + self.config.steps_offset - - if self.config.skip_prk_steps: - # for some models like stable diffusion the prk steps can/should be skipped to - # produce better results. 
When using PNDM with `self.config.skip_prk_steps` the implementation - # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 - - prk_timesteps = jnp.array([], dtype=jnp.int32) - plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1] - - else: - prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile( - jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32), - self.pndm_order, - ) - - prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1] - plms_timesteps = _timesteps[:-3][::-1] - - timesteps = jnp.concatenate([prk_timesteps, plms_timesteps]) - - # initial running values - - cur_model_output = jnp.zeros(shape, dtype=self.dtype) - counter = jnp.int32(0) - cur_sample = jnp.zeros(shape, dtype=self.dtype) - ets = jnp.zeros((4,) + shape, dtype=self.dtype) - - return state.replace( - timesteps=timesteps, - num_inference_steps=num_inference_steps, - prk_timesteps=prk_timesteps, - plms_timesteps=plms_timesteps, - cur_model_output=cur_model_output, - counter=counter, - cur_sample=cur_sample, - ets=ets, - ) - - def scale_model_input( - self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None - ) -> jnp.ndarray: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. - sample (`jnp.ndarray`): input sample - timestep (`int`, optional): current timestep - - Returns: - `jnp.ndarray`: scaled input sample - """ - return sample - - def step( - self, - state: PNDMSchedulerState, - model_output: jnp.ndarray, - timestep: int, - sample: jnp.ndarray, - return_dict: bool = True, - ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. - - Args: - state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. - model_output (`jnp.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class - - Returns: - [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is the sample tensor. 
- - """ - - if state.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if self.config.skip_prk_steps: - prev_sample, state = self.step_plms(state, model_output, timestep, sample) - else: - prk_prev_sample, prk_state = self.step_prk(state, model_output, timestep, sample) - plms_prev_sample, plms_state = self.step_plms(state, model_output, timestep, sample) - - cond = state.counter < len(state.prk_timesteps) - - prev_sample = jax.lax.select(cond, prk_prev_sample, plms_prev_sample) - - state = state.replace( - cur_model_output=jax.lax.select(cond, prk_state.cur_model_output, plms_state.cur_model_output), - ets=jax.lax.select(cond, prk_state.ets, plms_state.ets), - cur_sample=jax.lax.select(cond, prk_state.cur_sample, plms_state.cur_sample), - counter=jax.lax.select(cond, prk_state.counter, plms_state.counter), - ) - - if not return_dict: - return (prev_sample, state) - - return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state) - - def step_prk( - self, - state: PNDMSchedulerState, - model_output: jnp.ndarray, - timestep: int, - sample: jnp.ndarray, - ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: - """ - Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the - solution to the differential equation. - - Args: - state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. - model_output (`jnp.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class - - Returns: - [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is the sample tensor. 
- - """ - - if state.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - diff_to_prev = jnp.where( - state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2 - ) - prev_timestep = timestep - diff_to_prev - timestep = state.prk_timesteps[state.counter // 4 * 4] - - model_output = jax.lax.select( - (state.counter % 4) != 3, - model_output, # remainder 0, 1, 2 - state.cur_model_output + 1 / 6 * model_output, # remainder 3 - ) - - state = state.replace( - cur_model_output=jax.lax.select_n( - state.counter % 4, - state.cur_model_output + 1 / 6 * model_output, # remainder 0 - state.cur_model_output + 1 / 3 * model_output, # remainder 1 - state.cur_model_output + 1 / 3 * model_output, # remainder 2 - jnp.zeros_like(state.cur_model_output), # remainder 3 - ), - ets=jax.lax.select( - (state.counter % 4) == 0, - state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # remainder 0 - state.ets, # remainder 1, 2, 3 - ), - cur_sample=jax.lax.select( - (state.counter % 4) == 0, - sample, # remainder 0 - state.cur_sample, # remainder 1, 2, 3 - ), - ) - - cur_sample = state.cur_sample - prev_sample = self._get_prev_sample(state, cur_sample, timestep, prev_timestep, model_output) - state = state.replace(counter=state.counter + 1) - - return (prev_sample, state) - - def step_plms( - self, - state: PNDMSchedulerState, - model_output: jnp.ndarray, - timestep: int, - sample: jnp.ndarray, - ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: - """ - Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple - times to approximate the solution. - - Args: - state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. - model_output (`jnp.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class - - Returns: - [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is the sample tensor. 
- - """ - - if state.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - # NOTE: There is no way to check in the jitted runtime if the prk mode was ran before - - prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps - prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0) - - # Reference: - # if state.counter != 1: - # state.ets.append(model_output) - # else: - # prev_timestep = timestep - # timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps - - prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep) - timestep = jnp.where( - state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep - ) - - # Reference: - # if len(state.ets) == 1 and state.counter == 0: - # model_output = model_output - # state.cur_sample = sample - # elif len(state.ets) == 1 and state.counter == 1: - # model_output = (model_output + state.ets[-1]) / 2 - # sample = state.cur_sample - # state.cur_sample = None - # elif len(state.ets) == 2: - # model_output = (3 * state.ets[-1] - state.ets[-2]) / 2 - # elif len(state.ets) == 3: - # model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12 - # else: - # model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]) - - state = state.replace( - ets=jax.lax.select( - state.counter != 1, - state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # counter != 1 - state.ets, # counter 1 - ), - cur_sample=jax.lax.select( - state.counter != 1, - sample, # counter != 1 - state.cur_sample, # counter 1 - ), - ) - - state = state.replace( - cur_model_output=jax.lax.select_n( - jnp.clip(state.counter, 0, 4), - model_output, # counter 0 - (model_output + state.ets[-1]) / 2, # counter 1 - (3 * state.ets[-1] - state.ets[-2]) / 2, # counter 2 - (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12, # counter 3 - (1 / 24) - * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]), # counter >= 4 - ), - ) - - sample = state.cur_sample - model_output = state.cur_model_output - prev_sample = self._get_prev_sample(state, sample, timestep, prev_timestep, model_output) - state = state.replace(counter=state.counter + 1) - - return (prev_sample, state) - - def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_timestep, model_output): - # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf - # this function computes x_(t−δ) using the formula of (9) - # Note that x_t needs to be added to both sides of the equation - - # Notation ( -> - # alpha_prod_t -> α_t - # alpha_prod_t_prev -> α_(t−δ) - # beta_prod_t -> (1 - α_t) - # beta_prod_t_prev -> (1 - α_(t−δ)) - # sample -> x_t - # model_output -> e_θ(x_t, t) - # prev_sample -> x_(t−δ) - alpha_prod_t = state.common.alphas_cumprod[timestep] - alpha_prod_t_prev = jnp.where( - prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod - ) - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - if self.config.prediction_type == "v_prediction": - model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample - elif self.config.prediction_type != "epsilon": - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`" - ) - - # corresponds to 
(α_(t−δ) - α_t) divided by - # denominator of x_t in formula (9) and plus 1 - # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = - # sqrt(α_(t−δ)) / sqrt(α_t)) - sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) - - # corresponds to denominator of e_θ(x_t, t) in formula (9) - model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( - alpha_prod_t * beta_prod_t * alpha_prod_t_prev - ) ** (0.5) - - # full formula (9) - prev_sample = ( - sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff - ) - - return prev_sample - - def add_noise( - self, - state: PNDMSchedulerState, - original_samples: jnp.ndarray, - noise: jnp.ndarray, - timesteps: jnp.ndarray, - ) -> jnp.ndarray: - return add_noise_common(state.common, original_samples, noise, timesteps) - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_repaint.py b/my_diffusers/schedulers/scheduling_repaint.py deleted file mode 100644 index 96af210f06b10513ec72277315c9c1a84c3a5bef..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_repaint.py +++ /dev/null @@ -1,329 +0,0 @@ -# Copyright 2023 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, randn_tensor -from .scheduling_utils import SchedulerMixin - - -@dataclass -class RePaintSchedulerOutput(BaseOutput): - """ - Output class for the scheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample (x_{0}) based on the model output from - the current timestep. `pred_original_sample` can be used to preview progress or for guidance. - """ - - prev_sample: torch.FloatTensor - pred_original_sample: torch.FloatTensor - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. 
- - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class RePaintScheduler(SchedulerMixin, ConfigMixin): - """ - RePaint is a schedule for DDPM inpainting inside a given mask. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - eta (`float`): - The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 -0.0 is DDIM and - 1.0 is DDPM scheduler respectively. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - variance_type (`str`): - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, - `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. - clip_sample (`bool`, default `True`): - option to clip predicted sample between -1 and 1 for numerical stability. - - """ - - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - eta: float = 0.0, - trained_betas: Optional[np.ndarray] = None, - clip_sample: bool = True, - ): - if trained_betas is not None: - self.betas = torch.from_numpy(trained_betas) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
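# Equivalently, the "scaled_linear" betas are spaced linearly in sqrt-space and then squared:
#     beta_i = (sqrt(beta_start) + i / (N - 1) * (sqrt(beta_end) - sqrt(beta_start))) ** 2,  with N = num_train_timesteps,
# so the intermediate betas sit below a plain "linear" schedule with the same endpoints.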
- self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - elif beta_schedule == "sigmoid": - # GeoDiff sigmoid schedule - betas = torch.linspace(-6, 6, num_train_timesteps) - self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.one = torch.tensor(1.0) - - self.final_alpha_cumprod = torch.tensor(1.0) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # setable values - self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) - - self.eta = eta - - def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.FloatTensor`): input sample - timestep (`int`, optional): current timestep - - Returns: - `torch.FloatTensor`: scaled input sample - """ - return sample - - def set_timesteps( - self, - num_inference_steps: int, - jump_length: int = 10, - jump_n_sample: int = 10, - device: Union[str, torch.device] = None, - ): - num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) - self.num_inference_steps = num_inference_steps - - timesteps = [] - - jumps = {} - for j in range(0, num_inference_steps - jump_length, jump_length): - jumps[j] = jump_n_sample - 1 - - t = num_inference_steps - while t >= 1: - t = t - 1 - timesteps.append(t) - - if jumps.get(t, 0) > 0: - jumps[t] = jumps[t] - 1 - for _ in range(jump_length): - t = t + 1 - timesteps.append(t) - - timesteps = np.array(timesteps) * (self.config.num_train_timesteps // self.num_inference_steps) - self.timesteps = torch.from_numpy(timesteps).to(device) - - def _get_variance(self, t): - prev_timestep = t - self.config.num_train_timesteps // self.num_inference_steps - - alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - # For t > 0, compute predicted variance βt (see formula (6) and (7) from - # https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get - # previous sample x_{t-1} ~ N(pred_prev_sample, variance) == add - # variance to pred_sample - # Is equivalent to formula (16) in https://arxiv.org/pdf/2010.02502.pdf - # without eta. - # variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] - variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) - - return variance - - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - original_image: torch.FloatTensor, - mask: torch.FloatTensor, - generator: Optional[torch.Generator] = None, - return_dict: bool = True, - ) -> Union[RePaintSchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). 
- - Args: - model_output (`torch.FloatTensor`): direct output from learned - diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - original_image (`torch.FloatTensor`): - the original image to inpaint on. - mask (`torch.FloatTensor`): - the mask where 0.0 values define which part of the original image to inpaint (change). - generator (`torch.Generator`, *optional*): random number generator. - return_dict (`bool`): option for returning tuple rather than - DDPMSchedulerOutput class - - Returns: - [`~schedulers.scheduling_utils.RePaintSchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.RePaintSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - - """ - t = timestep - prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps - - # 1. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - beta_prod_t = 1 - alpha_prod_t - - # 2. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 - - # 3. Clip "predicted x_0" - if self.config.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) - - # We choose to follow RePaint Algorithm 1 to get x_{t-1}, however we - # substitute formula (7) in the algorithm coming from DDPM paper - # (formula (4) Algorithm 2 - Sampling) with formula (12) from DDIM paper. - # DDIM schedule gives the same results as DDPM with eta = 1.0 - # Noise is being reused in 7. and 8., but no impact on quality has - # been observed. - - # 5. Add noise - device = model_output.device - noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype) - std_dev_t = self.eta * self._get_variance(timestep) ** 0.5 - - variance = 0 - if t > 0 and self.eta > 0: - variance = std_dev_t * noise - - # 6. compute "direction pointing to x_t" of formula (12) - # from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * model_output - - # 7. compute x_{t-1} of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance - - # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf - prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise - - # 9. 
Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf - pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part - - if not return_dict: - return ( - pred_prev_sample, - pred_original_sample, - ) - - return RePaintSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) - - def undo_step(self, sample, timestep, generator=None): - n = self.config.num_train_timesteps // self.num_inference_steps - - for i in range(n): - beta = self.betas[timestep + i] - if sample.device.type == "mps": - # randn does not work reproducibly on mps - noise = randn_tensor(sample.shape, dtype=sample.dtype, generator=generator) - noise = noise.to(sample.device) - else: - noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype) - - # 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf - sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise - - return sample - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.") - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_sde_ve.py b/my_diffusers/schedulers/scheduling_sde_ve.py deleted file mode 100644 index 0784f35a7826f97b562807d95f771f5edb476a67..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_sde_ve.py +++ /dev/null @@ -1,281 +0,0 @@ -# Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch - -import math -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, randn_tensor -from .scheduling_utils import SchedulerMixin, SchedulerOutput - - -@dataclass -class SdeVeOutput(BaseOutput): - """ - Output class for the ScoreSdeVeScheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps. - """ - - prev_sample: torch.FloatTensor - prev_sample_mean: torch.FloatTensor - - -class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): - """ - The variance exploding stochastic differential equation (SDE) scheduler. 
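The masking step above (Algorithm 1, line 8 of the RePaint paper) is a simple convex combination of the re-noised known region and the freshly denoised unknown region. A minimal sketch on dummy tensors, with made-up shapes, just to make the semantics of `mask` concrete:

import torch

mask = torch.zeros(1, 3, 8, 8)
mask[..., :, :4] = 1.0                        # left half is "known" (kept), right half gets inpainted
prev_known_part = torch.randn(1, 3, 8, 8)     # re-noised original image at t-1
prev_unknown_part = torch.randn(1, 3, 8, 8)   # model-denoised sample at t-1

# mask == 1 keeps the known content, mask == 0 takes the generated content
pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part
assert torch.equal(pred_prev_sample[..., :, :4], prev_known_part[..., :, :4])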
- - For more information, see the original paper: https://arxiv.org/abs/2011.13456 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - snr (`float`): - coefficient weighting the step from the model_output sample (from the network) to the random noise. - sigma_min (`float`): - initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the - distribution of the data. - sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model. - sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to - epsilon. - correct_steps (`int`): number of correction steps performed on a produced sample. - """ - - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 2000, - snr: float = 0.15, - sigma_min: float = 0.01, - sigma_max: float = 1348.0, - sampling_eps: float = 1e-5, - correct_steps: int = 1, - ): - # standard deviation of the initial noise distribution - self.init_noise_sigma = sigma_max - - # setable values - self.timesteps = None - - self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) - - def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.FloatTensor`): input sample - timestep (`int`, optional): current timestep - - Returns: - `torch.FloatTensor`: scaled input sample - """ - return sample - - def set_timesteps( - self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None - ): - """ - Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). - - """ - sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps - - self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device) - - def set_sigmas( - self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None - ): - """ - Sets the noise scales used for the diffusion chain. Supporting function to be run before inference. - - The sigmas control the weight of the `drift` and `diffusion` components of sample update. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - sigma_min (`float`, optional): - initial noise scale value (overrides value given at Scheduler instantiation). - sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation). - sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). 
- - """ - sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min - sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max - sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps - if self.timesteps is None: - self.set_timesteps(num_inference_steps, sampling_eps) - - self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps) - self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps)) - self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) - - def get_adjacent_sigma(self, timesteps, t): - return torch.where( - timesteps == 0, - torch.zeros_like(t.to(timesteps.device)), - self.discrete_sigmas[timesteps - 1].to(timesteps.device), - ) - - def step_pred( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - generator: Optional[torch.Generator] = None, - return_dict: bool = True, - ) -> Union[SdeVeOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - generator: random number generator. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if - `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - - """ - if self.timesteps is None: - raise ValueError( - "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" - ) - - timestep = timestep * torch.ones( - sample.shape[0], device=sample.device - ) # torch.repeat_interleave(timestep, sample.shape[0]) - timesteps = (timestep * (len(self.timesteps) - 1)).long() - - # mps requires indices to be in the same device, so we use cpu as is the default with cuda - timesteps = timesteps.to(self.discrete_sigmas.device) - - sigma = self.discrete_sigmas[timesteps].to(sample.device) - adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device) - drift = torch.zeros_like(sample) - diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 - - # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) - # also equation 47 shows the analog from SDE models to ancestral sampling methods - diffusion = diffusion.flatten() - while len(diffusion.shape) < len(sample.shape): - diffusion = diffusion.unsqueeze(-1) - drift = drift - diffusion**2 * model_output - - # equation 6: sample noise for the diffusion term of - noise = randn_tensor( - sample.shape, layout=sample.layout, generator=generator, device=sample.device, dtype=sample.dtype - ) - prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep - # TODO is the variable diffusion the correct scaling term for the noise? 
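The sigma sequences built in `set_sigmas` above are both geometric interpolations between `sigma_min` and `sigma_max`: one indexed by the continuous timesteps decreasing from 1 to `sampling_eps`, the other a discrete log-spaced grid used for indexing in `step_pred`. A small sketch with the default VE-SDE values:

import math
import torch

sigma_min, sigma_max, num_inference_steps, sampling_eps = 0.01, 1348.0, 100, 1e-5

# continuous timesteps decreasing from 1 to sampling_eps, as set_timesteps does
timesteps = torch.linspace(1, sampling_eps, num_inference_steps)

# sigma(t) = sigma_min * (sigma_max / sigma_min) ** t  (geometric in t)
sigmas = sigma_min * (sigma_max / sigma_min) ** timesteps

# log-spaced discrete grid used to look up adjacent sigmas inside step_pred
discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))

print(sigmas[0].item(), sigmas[-1].item())                      # ~sigma_max down to ~sigma_min
print(discrete_sigmas[0].item(), discrete_sigmas[-1].item())    # sigma_min up to sigma_max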
- prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g - - if not return_dict: - return (prev_sample, prev_sample_mean) - - return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean) - - def step_correct( - self, - model_output: torch.FloatTensor, - sample: torch.FloatTensor, - generator: Optional[torch.Generator] = None, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Correct the predicted sample based on the output model_output of the network. This is often run repeatedly - after making the prediction for the previous timestep. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - generator: random number generator. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if - `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - - """ - if self.timesteps is None: - raise ValueError( - "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" - ) - - # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z" - # sample noise for correction - noise = randn_tensor(sample.shape, layout=sample.layout, generator=generator).to(sample.device) - - # compute step size from the model_output, the noise, and the snr - grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean() - noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() - step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 - step_size = step_size * torch.ones(sample.shape[0]).to(sample.device) - # self.repeat_scalar(step_size, sample.shape[0]) - - # compute corrected sample: model_output term and noise term - step_size = step_size.flatten() - while len(step_size.shape) < len(sample.shape): - step_size = step_size.unsqueeze(-1) - prev_sample_mean = sample + step_size * model_output - prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.FloatTensor, - ) -> torch.FloatTensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - timesteps = timesteps.to(original_samples.device) - sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps] - noise = torch.randn_like(original_samples) * sigmas[:, None, None, None] - noisy_samples = noise + original_samples - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_sde_ve_flax.py b/my_diffusers/schedulers/scheduling_sde_ve_flax.py deleted file mode 100644 index 09f9c2cf29ddaa4626f7b055350d7a38adc02f80..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_sde_ve_flax.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch - -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import flax -import jax.numpy as jnp -from jax import random - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left - - -@flax.struct.dataclass -class ScoreSdeVeSchedulerState: - # setable values - timesteps: Optional[jnp.ndarray] = None - discrete_sigmas: Optional[jnp.ndarray] = None - sigmas: Optional[jnp.ndarray] = None - - @classmethod - def create(cls): - return cls() - - -@dataclass -class FlaxSdeVeOutput(FlaxSchedulerOutput): - """ - Output class for the ScoreSdeVeScheduler's step function output. - - Args: - state (`ScoreSdeVeSchedulerState`): - prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - prev_sample_mean (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): - Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps. - """ - - state: ScoreSdeVeSchedulerState - prev_sample: jnp.ndarray - prev_sample_mean: Optional[jnp.ndarray] = None - - -class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): - """ - The variance exploding stochastic differential equation (SDE) scheduler. - - For more information, see the original paper: https://arxiv.org/abs/2011.13456 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - snr (`float`): - coefficient weighting the step from the model_output sample (from the network) to the random noise. - sigma_min (`float`): - initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the - distribution of the data. - sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model. - sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to - epsilon. - correct_steps (`int`): number of correction steps performed on a produced sample. 
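Unlike the PyTorch scheduler above, the Flax variant keeps all setable values in an immutable `flax.struct.dataclass` and returns updated copies via `.replace()`. A generic sketch of that pattern, not tied to this scheduler's exact fields (`TinyState` is a hypothetical name):

import flax
import jax.numpy as jnp

@flax.struct.dataclass
class TinyState:
    timesteps: jnp.ndarray = None
    sigmas: jnp.ndarray = None

state = TinyState()
# "setting" timesteps produces a new state object; the old one is untouched
state = state.replace(timesteps=jnp.linspace(1.0, 1e-5, 10))
print(state.timesteps.shape, state.sigmas)

Keeping the state immutable is what makes the Flax scheduler compatible with jit/pmap, which is why every `set_*` method below returns a new state instead of mutating `self`.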
- """ - - @property - def has_state(self): - return True - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 2000, - snr: float = 0.15, - sigma_min: float = 0.01, - sigma_max: float = 1348.0, - sampling_eps: float = 1e-5, - correct_steps: int = 1, - ): - pass - - def create_state(self): - state = ScoreSdeVeSchedulerState.create() - return self.set_sigmas( - state, - self.config.num_train_timesteps, - self.config.sigma_min, - self.config.sigma_max, - self.config.sampling_eps, - ) - - def set_timesteps( - self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None - ) -> ScoreSdeVeSchedulerState: - """ - Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). - - """ - sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps - - timesteps = jnp.linspace(1, sampling_eps, num_inference_steps) - return state.replace(timesteps=timesteps) - - def set_sigmas( - self, - state: ScoreSdeVeSchedulerState, - num_inference_steps: int, - sigma_min: float = None, - sigma_max: float = None, - sampling_eps: float = None, - ) -> ScoreSdeVeSchedulerState: - """ - Sets the noise scales used for the diffusion chain. Supporting function to be run before inference. - - The sigmas control the weight of the `drift` and `diffusion` components of sample update. - - Args: - state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - sigma_min (`float`, optional): - initial noise scale value (overrides value given at Scheduler instantiation). - sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation). - sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). - """ - sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min - sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max - sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps - if state.timesteps is None: - state = self.set_timesteps(state, num_inference_steps, sampling_eps) - - discrete_sigmas = jnp.exp(jnp.linspace(jnp.log(sigma_min), jnp.log(sigma_max), num_inference_steps)) - sigmas = jnp.array([sigma_min * (sigma_max / sigma_min) ** t for t in state.timesteps]) - - return state.replace(discrete_sigmas=discrete_sigmas, sigmas=sigmas) - - def get_adjacent_sigma(self, state, timesteps, t): - return jnp.where(timesteps == 0, jnp.zeros_like(t), state.discrete_sigmas[timesteps - 1]) - - def step_pred( - self, - state: ScoreSdeVeSchedulerState, - model_output: jnp.ndarray, - timestep: int, - sample: jnp.ndarray, - key: random.KeyArray, - return_dict: bool = True, - ) -> Union[FlaxSdeVeOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). 
- - Args: - state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. - model_output (`jnp.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - generator: random number generator. - return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class - - Returns: - [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - - """ - if state.timesteps is None: - raise ValueError( - "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" - ) - - timestep = timestep * jnp.ones( - sample.shape[0], - ) - timesteps = (timestep * (len(state.timesteps) - 1)).long() - - sigma = state.discrete_sigmas[timesteps] - adjacent_sigma = self.get_adjacent_sigma(state, timesteps, timestep) - drift = jnp.zeros_like(sample) - diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 - - # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) - # also equation 47 shows the analog from SDE models to ancestral sampling methods - diffusion = diffusion.flatten() - diffusion = broadcast_to_shape_from_left(diffusion, sample.shape) - drift = drift - diffusion**2 * model_output - - # equation 6: sample noise for the diffusion term of - key = random.split(key, num=1) - noise = random.normal(key=key, shape=sample.shape) - prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep - # TODO is the variable diffusion the correct scaling term for the noise? - prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g - - if not return_dict: - return (prev_sample, prev_sample_mean, state) - - return FlaxSdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean, state=state) - - def step_correct( - self, - state: ScoreSdeVeSchedulerState, - model_output: jnp.ndarray, - sample: jnp.ndarray, - key: random.KeyArray, - return_dict: bool = True, - ) -> Union[FlaxSdeVeOutput, Tuple]: - """ - Correct the predicted sample based on the output model_output of the network. This is often run repeatedly - after making the prediction for the previous timestep. - - Args: - state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. - model_output (`jnp.ndarray`): direct output from learned diffusion model. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - generator: random number generator. - return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class - - Returns: - [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - - """ - if state.timesteps is None: - raise ValueError( - "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" - ) - - # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. 
of z" - # sample noise for correction - key = random.split(key, num=1) - noise = random.normal(key=key, shape=sample.shape) - - # compute step size from the model_output, the noise, and the snr - grad_norm = jnp.linalg.norm(model_output) - noise_norm = jnp.linalg.norm(noise) - step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 - step_size = step_size * jnp.ones(sample.shape[0]) - - # compute corrected sample: model_output term and noise term - step_size = step_size.flatten() - step_size = broadcast_to_shape_from_left(step_size, sample.shape) - prev_sample_mean = sample + step_size * model_output - prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise - - if not return_dict: - return (prev_sample, state) - - return FlaxSdeVeOutput(prev_sample=prev_sample, state=state) - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_sde_vp.py b/my_diffusers/schedulers/scheduling_sde_vp.py deleted file mode 100644 index 6e2ead90edb57cd1eb1d270695e222d404064180..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_sde_vp.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch - -import math -from typing import Union - -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import randn_tensor -from .scheduling_utils import SchedulerMixin - - -class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): - """ - The variance preserving stochastic differential equation (SDE) scheduler. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. 
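Both the PyTorch and Flax correctors above pick the Langevin step size from the signal-to-noise ratio between the score estimate and a fresh noise draw. A small numeric sketch of that computation, using the per-batch mean norms of the PyTorch `step_correct` shown earlier:

import torch

snr = 0.15                        # config.snr default above
score = torch.randn(4, 3, 8, 8)   # stand-in for the model output (score estimate)
noise = torch.randn_like(score)

grad_norm = torch.norm(score.reshape(score.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
step_size = (snr * noise_norm / grad_norm) ** 2 * 2

# one corrector update: x <- x + eps * score + sqrt(2 * eps) * z
sample = torch.randn_like(score)
corrected = sample + step_size * score + (2 * step_size) ** 0.5 * noise
print(float(step_size), corrected.shape)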
- - For more information, see the original paper: https://arxiv.org/abs/2011.13456 - - UNDER CONSTRUCTION - - """ - - order = 1 - - @register_to_config - def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3): - self.sigmas = None - self.discrete_sigmas = None - self.timesteps = None - - def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None): - self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device) - - def step_pred(self, score, x, t, generator=None): - if self.timesteps is None: - raise ValueError( - "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" - ) - - # TODO(Patrick) better comments + non-PyTorch - # postprocess model score - log_mean_coeff = ( - -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min - ) - std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) - std = std.flatten() - while len(std.shape) < len(score.shape): - std = std.unsqueeze(-1) - score = -score / std - - # compute - dt = -1.0 / len(self.timesteps) - - beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) - beta_t = beta_t.flatten() - while len(beta_t.shape) < len(x.shape): - beta_t = beta_t.unsqueeze(-1) - drift = -0.5 * beta_t * x - - diffusion = torch.sqrt(beta_t) - drift = drift - diffusion**2 * score - x_mean = x + drift * dt - - # add noise - noise = randn_tensor(x.shape, layout=x.layout, generator=generator, device=x.device, dtype=x.dtype) - x = x_mean + diffusion * math.sqrt(-dt) * noise - - return x, x_mean - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_unclip.py b/my_diffusers/schedulers/scheduling_unclip.py deleted file mode 100644 index 6403ee3f15187cf6af97dc03ceb3ad3d0b538597..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_unclip.py +++ /dev/null @@ -1,306 +0,0 @@ -# Copyright 2023 Kakao Brain and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, randn_tensor -from .scheduling_utils import SchedulerMixin - - -@dataclass -# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UnCLIP -class UnCLIPSchedulerOutput(BaseOutput): - """ - Output class for the scheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. 
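The VP-SDE `step_pred` above first rescales the network output by the marginal standard deviation of x_t, then applies one Euler–Maruyama step of the reverse SDE with a small negative dt. A scalar sketch of the coefficients involved, using the default `beta_min`/`beta_max` values:

import math

beta_min, beta_max, num_steps = 0.1, 20.0, 1000
t = 0.5                                   # continuous time in (0, 1]
dt = -1.0 / num_steps                     # small negative step: time runs backwards

log_mean_coeff = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
std = math.sqrt(1.0 - math.exp(2.0 * log_mean_coeff))   # marginal std of x_t given x_0
beta_t = beta_min + t * (beta_max - beta_min)

print(f"std(t)={std:.4f}  beta(t)={beta_t:.4f}  noise scale={math.sqrt(beta_t) * math.sqrt(-dt):.4f}")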
- pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample (x_{0}) based on the model output from the current timestep. - `pred_original_sample` can be used to preview progress or for guidance. - """ - - prev_sample: torch.FloatTensor - pred_original_sample: Optional[torch.FloatTensor] = None - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class UnCLIPScheduler(SchedulerMixin, ConfigMixin): - """ - This is a modified DDPM Scheduler specifically for the karlo unCLIP model. - - This scheduler has some minor variations in how it calculates the learned range variance and dynamically - re-calculates betas based off the timesteps it is skipping. - - The scheduler also uses a slightly different step ratio when computing timesteps to use for inference. - - See [`~DDPMScheduler`] for more information on DDPM scheduling - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - variance_type (`str`): - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small_log` - or `learned_range`. - clip_sample (`bool`, default `True`): - option to clip predicted sample between `-clip_sample_range` and `clip_sample_range` for numerical - stability. - clip_sample_range (`float`, default `1.0`): - The range to clip the sample between. See `clip_sample`. 
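For reference, the cosine ("squaredcos_cap_v2") schedule built by `betas_for_alpha_bar` above can be checked in a few lines; this is the same construction, just standalone (`cosine_betas` is a hypothetical name):

import math
import torch

def cosine_betas(num_diffusion_timesteps, max_beta=0.999):
    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)

betas = cosine_betas(1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
# alphas_cumprod[i] tracks alpha_bar((i + 1) / N) by construction
print(betas[0].item(), betas[-1].item())   # tiny at t=0, close to max_beta at the end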
- prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process) - or `sample` (directly predicting the noisy sample`) - """ - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - variance_type: str = "fixed_small_log", - clip_sample: bool = True, - clip_sample_range: Optional[float] = 1.0, - prediction_type: str = "epsilon", - beta_schedule: str = "squaredcos_cap_v2", - ): - if beta_schedule != "squaredcos_cap_v2": - raise ValueError("UnCLIPScheduler only supports `beta_schedule`: 'squaredcos_cap_v2'") - - self.betas = betas_for_alpha_bar(num_train_timesteps) - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.one = torch.tensor(1.0) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # setable values - self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) - - self.variance_type = variance_type - - def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.FloatTensor`): input sample - timestep (`int`, optional): current timestep - - Returns: - `torch.FloatTensor`: scaled input sample - """ - return sample - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Note that this scheduler uses a slightly different step ratio than the other diffusers schedulers. The - different step ratio is to mimic the original karlo implementation and does not affect the quality or accuracy - of the results. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. 
- """ - self.num_inference_steps = num_inference_steps - step_ratio = (self.config.num_train_timesteps - 1) / (self.num_inference_steps - 1) - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) - self.timesteps = torch.from_numpy(timesteps).to(device) - - def _get_variance(self, t, prev_timestep=None, predicted_variance=None, variance_type=None): - if prev_timestep is None: - prev_timestep = t - 1 - - alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - if prev_timestep == t - 1: - beta = self.betas[t] - else: - beta = 1 - alpha_prod_t / alpha_prod_t_prev - - # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) - # and sample from it to get previous sample - # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample - variance = beta_prod_t_prev / beta_prod_t * beta - - if variance_type is None: - variance_type = self.config.variance_type - - # hacks - were probably added for training stability - if variance_type == "fixed_small_log": - variance = torch.log(torch.clamp(variance, min=1e-20)) - variance = torch.exp(0.5 * variance) - elif variance_type == "learned_range": - # NOTE difference with DDPM scheduler - min_log = variance.log() - max_log = beta.log() - - frac = (predicted_variance + 1) / 2 - variance = frac * max_log + (1 - frac) * min_log - - return variance - - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - prev_timestep: Optional[int] = None, - generator=None, - return_dict: bool = True, - ) -> Union[UnCLIPSchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at. - Used to dynamically compute beta. If not given, `t-1` is used and the pre-computed beta is used. - generator: random number generator. - return_dict (`bool`): option for returning tuple rather than UnCLIPSchedulerOutput class - - Returns: - [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. - - """ - t = timestep - - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type == "learned_range": - model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) - else: - predicted_variance = None - - # 1. compute alphas, betas - if prev_timestep is None: - prev_timestep = t - 1 - - alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev - - if prev_timestep == t - 1: - beta = self.betas[t] - alpha = self.alphas[t] - else: - beta = 1 - alpha_prod_t / alpha_prod_t_prev - alpha = 1 - beta - - # 2. 
compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if self.config.prediction_type == "epsilon": - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif self.config.prediction_type == "sample": - pred_original_sample = model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `sample`" - " for the UnCLIPScheduler." - ) - - # 3. Clip "predicted x_0" - if self.config.clip_sample: - pred_original_sample = torch.clamp( - pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range - ) - - # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * beta) / beta_prod_t - current_sample_coeff = alpha ** (0.5) * beta_prod_t_prev / beta_prod_t - - # 5. Compute predicted previous sample µ_t - # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample - - # 6. Add noise - variance = 0 - if t > 0: - variance_noise = randn_tensor( - model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device - ) - - variance = self._get_variance( - t, - predicted_variance=predicted_variance, - prev_timestep=prev_timestep, - ) - - if self.variance_type == "fixed_small_log": - variance = variance - elif self.variance_type == "learned_range": - variance = (0.5 * variance).exp() - else: - raise ValueError( - f"variance_type given as {self.variance_type} must be one of `fixed_small_log` or `learned_range`" - " for the UnCLIPScheduler." - ) - - variance = variance * variance_noise - - pred_prev_sample = pred_prev_sample + variance - - if not return_dict: - return (pred_prev_sample,) - - return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) diff --git a/my_diffusers/schedulers/scheduling_unipc_multistep.py b/my_diffusers/schedulers/scheduling_unipc_multistep.py deleted file mode 100644 index 66a807fa1c10530f7c54e14768c2e90d041be467..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_unipc_multistep.py +++ /dev/null @@ -1,607 +0,0 @@ -# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
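The `learned_range` branch above interpolates in log-space between the posterior variance (the lower bound) and beta (the upper bound), driven by the second half of the model output. A scalar sketch of that interpolation with made-up numbers:

import torch

beta = torch.tensor(0.02)                # upper bound: beta_t
beta_tilde = torch.tensor(0.015)         # lower bound: posterior variance for this step
predicted_variance = torch.tensor(0.3)   # raw model output channel, roughly in [-1, 1]

min_log = beta_tilde.log()
max_log = beta.log()
frac = (predicted_variance + 1) / 2      # map [-1, 1] -> [0, 1]
log_variance = frac * max_log + (1 - frac) * min_log

std = (0.5 * log_variance).exp()         # what gets multiplied onto the noise in step()
print(float(log_variance.exp()), float(std))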
- -# DISCLAIMER: check https://arxiv.org/abs/2302.04867 and https://github.com/wl-zhao/UniPC for more info -# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput - - -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): - """ - UniPC is a training-free framework designed for the fast sampling of diffusion models, which consists of a - corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders. UniPC is - by desinged model-agnostic, supporting pixel-space/latent-space DPMs on unconditional/conditional sampling. It can - also be applied to both noise prediction model and data prediction model. The corrector UniC can be also applied - after any off-the-shelf solvers to increase the order of accuracy. - - For more details, see the original paper: https://arxiv.org/abs/2302.04867 - - Currently, we support the multistep UniPC for both noise prediction models and data prediction models. We recommend - to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. - - We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space - diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note - that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. 
- trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - solver_order (`int`, default `2`): - the order of UniPC, also the p in UniPC-p; can be any positive integer. Note that the effective order of - accuracy is `solver_order + 1` due to the UniC. We recommend to use `solver_order=2` for guided sampling, - and `solver_order=3` for unconditional sampling. - prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion - process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 - https://imagen.research.google/video/paper.pdf) - thresholding (`bool`, default `False`): - whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). - For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the - dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models - (such as stable-diffusion). - dynamic_thresholding_ratio (`float`, default `0.995`): - the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen - (https://arxiv.org/abs/2205.11487). - sample_max_value (`float`, default `1.0`): - the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. - predict_x0 (`bool`, default `True`): - whether to use the updating algrithm on the predicted x0. See https://arxiv.org/abs/2211.01095 for details - solver_type (`str`, default `bh2`): - the solver type of UniPC. We recommend use `bh1` for unconditional sampling when steps < 10, and use `bh2` - otherwise. - lower_order_final (`bool`, default `True`): - whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically - find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. - disable_corrector (`list`, default `[]`): - decide which step to disable the corrector. For large guidance scale, the misalignment between the - `epsilon_theta(x_t, c)`and `epsilon_theta(x_t^c, c)` might influence the convergence. This can be mitigated - by disable the corrector at the first few steps (e.g., disable_corrector=[0]) - solver_p (`SchedulerMixin`, default `None`): - can be any other scheduler. If specified, the algorithm will become solver_p + UniC. - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - solver_order: int = 2, - prediction_type: str = "epsilon", - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - predict_x0: bool = True, - solver_type: str = "bh2", - lower_order_final: bool = True, - disable_corrector: List[int] = [], - solver_p: SchedulerMixin = None, - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
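UniPC, like DPM-Solver, works in terms of alpha_t, sigma_t, and the half log-SNR lambda_t = log(alpha_t) - log(sigma_t), which the constructor derives from the beta schedule just below. A short sketch of that derivation for the default linear schedule:

import torch

num_train_timesteps, beta_start, beta_end = 1000, 0.0001, 0.02
betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)

alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
alpha_t = torch.sqrt(alphas_cumprod)                  # signal scale of x_t
sigma_t = torch.sqrt(1 - alphas_cumprod)              # noise scale of x_t
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)    # half log-SNR, decreasing in t

print(float(lambda_t[0]), float(lambda_t[-1]))        # positive early, negative late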
- self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Currently we only support VP-type noise schedule - self.alpha_t = torch.sqrt(self.alphas_cumprod) - self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) - self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - if solver_type not in ["bh1", "bh2"]: - if solver_type in ["midpoint", "heun", "logrho"]: - solver_type = "bh1" - else: - raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") - - self.predict_x0 = predict_x0 - # setable values - self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) - self.model_outputs = [None] * solver_order - self.timestep_list = [None] * solver_order - self.lower_order_nums = 0 - self.disable_corrector = disable_corrector - self.solver_p = solver_p - self.last_sample = None - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, optional): - the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - self.num_inference_steps = num_inference_steps - timesteps = ( - np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) - self.timesteps = torch.from_numpy(timesteps).to(device) - self.model_outputs = [ - None, - ] * self.config.solver_order - self.lower_order_nums = 0 - self.last_sample = None - if self.solver_p: - self.solver_p.set_timesteps(num_inference_steps, device=device) - - def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor - ) -> torch.FloatTensor: - r""" - Convert the model output to the corresponding type that the algorithm PC needs. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the converted model output. - """ - if self.predict_x0: - if self.config.prediction_type == "epsilon": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - x0_pred = (sample - sigma_t * model_output) / alpha_t - elif self.config.prediction_type == "sample": - x0_pred = model_output - elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - x0_pred = alpha_t * sample - sigma_t * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the UniPCMultistepScheduler." 
- ) - - if self.config.thresholding: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - orig_dtype = x0_pred.dtype - if orig_dtype not in [torch.float, torch.double]: - x0_pred = x0_pred.float() - dynamic_max_val = torch.quantile( - torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1 - ) - dynamic_max_val = torch.maximum( - dynamic_max_val, - self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device), - )[(...,) + (None,) * (x0_pred.ndim - 1)] - x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val - x0_pred = x0_pred.type(orig_dtype) - return x0_pred - else: - if self.config.prediction_type == "epsilon": - return model_output - elif self.config.prediction_type == "sample": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - epsilon = (sample - alpha_t * model_output) / sigma_t - return epsilon - elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - epsilon = alpha_t * model_output + sigma_t * sample - return epsilon - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the UniPCMultistepScheduler." - ) - - def multistep_uni_p_bh_update( - self, - model_output: torch.FloatTensor, - prev_timestep: int, - sample: torch.FloatTensor, - order: int, - ) -> torch.FloatTensor: - """ - One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. - - Args: - model_output (`torch.FloatTensor`): - direct outputs from learned diffusion model at the current timestep. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - order (`int`): the order of UniP at this step, also the p in UniPC-p. - - Returns: - `torch.FloatTensor`: the sample tensor at the previous timestep. 
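The dynamic thresholding branch above (Imagen, arXiv:2205.11487) clamps the predicted x_0 to a per-sample quantile of its absolute values and rescales it back into [-1, 1]. A minimal standalone sketch of the same operation:

import torch

x0_pred = torch.randn(2, 3, 16, 16) * 3.0   # deliberately out-of-range prediction
ratio, sample_max_value = 0.995, 1.0

# per-sample quantile of |x0|, but never below sample_max_value
s = torch.quantile(x0_pred.abs().reshape(x0_pred.shape[0], -1), ratio, dim=1)
s = torch.maximum(s, torch.full_like(s, sample_max_value))
s = s[(...,) + (None,) * (x0_pred.ndim - 1)]          # reshape to (B, 1, 1, 1) for broadcasting

x0_thresholded = torch.clamp(x0_pred, -s, s) / s      # now within [-1, 1]
print(float(x0_thresholded.abs().max()))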
- """ - timestep_list = self.timestep_list - model_output_list = self.model_outputs - - s0, t = self.timestep_list[-1], prev_timestep - m0 = model_output_list[-1] - x = sample - - if self.solver_p: - x_t = self.solver_p.step(model_output, s0, x).prev_sample - return x_t - - lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] - - h = lambda_t - lambda_s0 - device = sample.device - - rks = [] - D1s = [] - for i in range(1, order): - si = timestep_list[-(i + 1)] - mi = model_output_list[-(i + 1)] - lambda_si = self.lambda_t[si] - rk = (lambda_si - lambda_s0) / h - rks.append(rk) - D1s.append((mi - m0) / rk) - - rks.append(1.0) - rks = torch.tensor(rks, device=device) - - R = [] - b = [] - - hh = -h if self.predict_x0 else h - h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 - h_phi_k = h_phi_1 / hh - 1 - - factorial_i = 1 - - if self.config.solver_type == "bh1": - B_h = hh - elif self.config.solver_type == "bh2": - B_h = torch.expm1(hh) - else: - raise NotImplementedError() - - for i in range(1, order + 1): - R.append(torch.pow(rks, i - 1)) - b.append(h_phi_k * factorial_i / B_h) - factorial_i *= i + 1 - h_phi_k = h_phi_k / hh - 1 / factorial_i - - R = torch.stack(R) - b = torch.tensor(b, device=device) - - if len(D1s) > 0: - D1s = torch.stack(D1s, dim=1) # (B, K) - # for order 2, we use a simplified version - if order == 2: - rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) - else: - rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) - else: - D1s = None - - if self.predict_x0: - x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 - if D1s is not None: - pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) - else: - pred_res = 0 - x_t = x_t_ - alpha_t * B_h * pred_res - else: - x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 - if D1s is not None: - pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) - else: - pred_res = 0 - x_t = x_t_ - sigma_t * B_h * pred_res - - x_t = x_t.to(x.dtype) - return x_t - - def multistep_uni_c_bh_update( - self, - this_model_output: torch.FloatTensor, - this_timestep: int, - last_sample: torch.FloatTensor, - this_sample: torch.FloatTensor, - order: int, - ) -> torch.FloatTensor: - """ - One step for the UniC (B(h) version). - - Args: - this_model_output (`torch.FloatTensor`): the model outputs at `x_t` - this_timestep (`int`): the current timestep `t` - last_sample (`torch.FloatTensor`): the generated sample before the last predictor: `x_{t-1}` - this_sample (`torch.FloatTensor`): the generated sample after the last predictor: `x_{t}` - order (`int`): the `p` of UniC-p at this step. Note that the effective order of accuracy - should be order + 1 - - Returns: - `torch.FloatTensor`: the corrected sample tensor at the current timestep. 
- """ - timestep_list = self.timestep_list - model_output_list = self.model_outputs - - s0, t = timestep_list[-1], this_timestep - m0 = model_output_list[-1] - x = last_sample - x_t = this_sample - model_t = this_model_output - - lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] - alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] - sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] - - h = lambda_t - lambda_s0 - device = this_sample.device - - rks = [] - D1s = [] - for i in range(1, order): - si = timestep_list[-(i + 1)] - mi = model_output_list[-(i + 1)] - lambda_si = self.lambda_t[si] - rk = (lambda_si - lambda_s0) / h - rks.append(rk) - D1s.append((mi - m0) / rk) - - rks.append(1.0) - rks = torch.tensor(rks, device=device) - - R = [] - b = [] - - hh = -h if self.predict_x0 else h - h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 - h_phi_k = h_phi_1 / hh - 1 - - factorial_i = 1 - - if self.config.solver_type == "bh1": - B_h = hh - elif self.config.solver_type == "bh2": - B_h = torch.expm1(hh) - else: - raise NotImplementedError() - - for i in range(1, order + 1): - R.append(torch.pow(rks, i - 1)) - b.append(h_phi_k * factorial_i / B_h) - factorial_i *= i + 1 - h_phi_k = h_phi_k / hh - 1 / factorial_i - - R = torch.stack(R) - b = torch.tensor(b, device=device) - - if len(D1s) > 0: - D1s = torch.stack(D1s, dim=1) - else: - D1s = None - - # for order 1, we use a simplified version - if order == 1: - rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) - else: - rhos_c = torch.linalg.solve(R, b) - - if self.predict_x0: - x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 - if D1s is not None: - corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s) - else: - corr_res = 0 - D1_t = model_t - m0 - x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) - else: - x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 - if D1s is not None: - corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s) - else: - corr_res = 0 - D1_t = model_t - m0 - x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) - x_t = x_t.to(x.dtype) - return x_t - - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Step function propagating the sample with the multistep UniPC. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
- - """ - - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: - step_index = len(self.timesteps) - 1 - else: - step_index = step_index.item() - - use_corrector = ( - step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None - ) - - model_output_convert = self.convert_model_output(model_output, timestep, sample) - if use_corrector: - sample = self.multistep_uni_c_bh_update( - this_model_output=model_output_convert, - this_timestep=timestep, - last_sample=self.last_sample, - this_sample=sample, - order=self.this_order, - ) - - # now prepare to run the predictor - prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] - - for i in range(self.config.solver_order - 1): - self.model_outputs[i] = self.model_outputs[i + 1] - self.timestep_list[i] = self.timestep_list[i + 1] - - self.model_outputs[-1] = model_output_convert - self.timestep_list[-1] = timestep - - if self.config.lower_order_final: - this_order = min(self.config.solver_order, len(self.timesteps) - step_index) - else: - this_order = self.config.solver_order - - self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep - assert self.this_order > 0 - - self.last_sample = sample - prev_sample = self.multistep_uni_p_bh_update( - model_output=model_output, # pass the original non-converted model output, in case solver-p is used - prev_timestep=prev_timestep, - sample=sample, - order=self.this_order, - ) - - if self.lower_order_nums < self.config.solver_order: - self.lower_order_nums += 1 - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. 
- - Args: - sample (`torch.FloatTensor`): input sample - - Returns: - `torch.FloatTensor`: scaled input sample - """ - return sample - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/my_diffusers/schedulers/scheduling_utils.py b/my_diffusers/schedulers/scheduling_utils.py deleted file mode 100644 index 386d60b2eae79c8205ad22938e519e5876b303ad..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_utils.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import importlib -import os -from dataclasses import dataclass -from enum import Enum -from typing import Any, Dict, Optional, Union - -import torch - -from ..utils import BaseOutput - - -SCHEDULER_CONFIG_NAME = "scheduler_config.json" - - -# NOTE: We make this type an enum because it simplifies usage in docs and prevents -# circular imports when used for `_compatibles` within the schedulers module. -# When it's used as a type in pipelines, it really is a Union because the actual -# scheduler instance is passed in. -class KarrasDiffusionSchedulers(Enum): - DDIMScheduler = 1 - DDPMScheduler = 2 - PNDMScheduler = 3 - LMSDiscreteScheduler = 4 - EulerDiscreteScheduler = 5 - HeunDiscreteScheduler = 6 - EulerAncestralDiscreteScheduler = 7 - DPMSolverMultistepScheduler = 8 - DPMSolverSinglestepScheduler = 9 - KDPM2DiscreteScheduler = 10 - KDPM2AncestralDiscreteScheduler = 11 - DEISMultistepScheduler = 12 - UniPCMultistepScheduler = 13 - - -@dataclass -class SchedulerOutput(BaseOutput): - """ - Base class for the scheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. 
- """ - - prev_sample: torch.FloatTensor - - -class SchedulerMixin: - """ - Mixin containing common functions for the schedulers. - - Class attributes: - - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that - `from_config` can be used from a class different than the one used to save the config (should be overridden - by parent class). - """ - - config_name = SCHEDULER_CONFIG_NAME - _compatibles = [] - has_compatibles = True - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Dict[str, Any] = None, - subfolder: Optional[str] = None, - return_unused_kwargs=False, - **kwargs, - ): - r""" - Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an - organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing the schedluer configurations saved using - [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`. - subfolder (`str`, *optional*): - In case the relevant files are located inside a subfolder of the model repo (either remote in - huggingface.co or downloaded locally), you can specify the folder name here. - return_unused_kwargs (`bool`, *optional*, defaults to `False`): - Whether kwargs that are not consumed by the Python class should be returned or not. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - - - - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). - - - - - - Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to - use this method in a firewalled environment. 
- - - - """ - config, kwargs = cls.load_config( - pretrained_model_name_or_path=pretrained_model_name_or_path, - subfolder=subfolder, - return_unused_kwargs=True, - **kwargs, - ) - return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) - - def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): - """ - Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the - [`~SchedulerMixin.from_pretrained`] class method. - - Args: - save_directory (`str` or `os.PathLike`): - Directory where the configuration JSON file will be saved (will be created if it does not exist). - """ - self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - - @property - def compatibles(self): - """ - Returns all schedulers that are compatible with this scheduler - - Returns: - `List[SchedulerMixin]`: List of compatible schedulers - """ - return self._get_compatibles() - - @classmethod - def _get_compatibles(cls): - compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) - diffusers_library = importlib.import_module(__name__.split(".")[0]) - compatible_classes = [ - getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) - ] - return compatible_classes diff --git a/my_diffusers/schedulers/scheduling_utils_flax.py b/my_diffusers/schedulers/scheduling_utils_flax.py deleted file mode 100644 index 19ce5b8360b9be5bb4b4ec46fbeac0715d6b5869..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_utils_flax.py +++ /dev/null @@ -1,284 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import importlib -import math -import os -from dataclasses import dataclass -from enum import Enum -from typing import Any, Dict, Optional, Tuple, Union - -import flax -import jax.numpy as jnp - -from ..utils import BaseOutput - - -SCHEDULER_CONFIG_NAME = "scheduler_config.json" - - -# NOTE: We make this type an enum because it simplifies usage in docs and prevents -# circular imports when used for `_compatibles` within the schedulers module. -# When it's used as a type in pipelines, it really is a Union because the actual -# scheduler instance is passed in. -class FlaxKarrasDiffusionSchedulers(Enum): - FlaxDDIMScheduler = 1 - FlaxDDPMScheduler = 2 - FlaxPNDMScheduler = 3 - FlaxLMSDiscreteScheduler = 4 - FlaxDPMSolverMultistepScheduler = 5 - - -@dataclass -class FlaxSchedulerOutput(BaseOutput): - """ - Base class for the scheduler's step function output. - - Args: - prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - """ - - prev_sample: jnp.ndarray - - -class FlaxSchedulerMixin: - """ - Mixin containing common functions for the schedulers. 
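# A hedged sketch (illustrative, not part of the original file) of the
# SchedulerMixin.from_pretrained()/compatibles API documented above, using the upstream
# `diffusers` package and the example repo id from the docstring; the
# subfolder="scheduler" layout of that repo is an assumption.
from diffusers import DDIMScheduler, DDPMScheduler

scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", subfolder="scheduler")
print([cls.__name__ for cls in scheduler.compatibles])   # classes collected via _compatibles

# Any compatible scheduler class can be re-created from the same saved config.
ddim = DDIMScheduler.from_config(scheduler.config)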
- - Class attributes: - - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that - `from_config` can be used from a class different than the one used to save the config (should be overridden - by parent class). - """ - - config_name = SCHEDULER_CONFIG_NAME - ignore_for_config = ["dtype"] - _compatibles = [] - has_compatibles = True - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Dict[str, Any] = None, - subfolder: Optional[str] = None, - return_unused_kwargs=False, - **kwargs, - ): - r""" - Instantiate a Scheduler class from a pre-defined JSON-file. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an - organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`], - e.g., `./my_model_directory/`. - subfolder (`str`, *optional*): - In case the relevant files are located inside a subfolder of the model repo (either remote in - huggingface.co or downloaded locally), you can specify the folder name here. - return_unused_kwargs (`bool`, *optional*, defaults to `False`): - Whether kwargs that are not consumed by the Python class should be returned or not. - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - - - - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). - - - - - - Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to - use this method in a firewalled environment. 
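# A hedged usage sketch (illustrative, not part of the original file) of the Flax
# scheduler API: as the implementation below shows, from_pretrained() returns the
# scheduler together with an immutable state object that set_timesteps() and step()
# then thread through. The repo id, subfolder, and shape are assumptions.
from diffusers import FlaxPNDMScheduler

scheduler, state = FlaxPNDMScheduler.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="scheduler"
)
state = scheduler.set_timesteps(state, num_inference_steps=50, shape=(1, 4, 64, 64))
print(state.timesteps[:5])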
- - - - """ - config, kwargs = cls.load_config( - pretrained_model_name_or_path=pretrained_model_name_or_path, - subfolder=subfolder, - return_unused_kwargs=True, - **kwargs, - ) - scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs) - - if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False): - state = scheduler.create_state() - - if return_unused_kwargs: - return scheduler, state, unused_kwargs - - return scheduler, state - - def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): - """ - Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the - [`~FlaxSchedulerMixin.from_pretrained`] class method. - - Args: - save_directory (`str` or `os.PathLike`): - Directory where the configuration JSON file will be saved (will be created if it does not exist). - """ - self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - - @property - def compatibles(self): - """ - Returns all schedulers that are compatible with this scheduler - - Returns: - `List[SchedulerMixin]`: List of compatible schedulers - """ - return self._get_compatibles() - - @classmethod - def _get_compatibles(cls): - compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) - diffusers_library = importlib.import_module(__name__.split(".")[0]) - compatible_classes = [ - getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) - ] - return compatible_classes - - -def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray: - assert len(shape) >= x.ndim - return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape) - - -def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999, dtype=jnp.float32) -> jnp.ndarray: - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return jnp.array(betas, dtype=dtype) - - -@flax.struct.dataclass -class CommonSchedulerState: - alphas: jnp.ndarray - betas: jnp.ndarray - alphas_cumprod: jnp.ndarray - - @classmethod - def create(cls, scheduler): - config = scheduler.config - - if config.trained_betas is not None: - betas = jnp.asarray(config.trained_betas, dtype=scheduler.dtype) - elif config.beta_schedule == "linear": - betas = jnp.linspace(config.beta_start, config.beta_end, config.num_train_timesteps, dtype=scheduler.dtype) - elif config.beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. 
- betas = ( - jnp.linspace( - config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype - ) - ** 2 - ) - elif config.beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - betas = betas_for_alpha_bar(config.num_train_timesteps, dtype=scheduler.dtype) - else: - raise NotImplementedError( - f"beta_schedule {config.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}" - ) - - alphas = 1.0 - betas - - alphas_cumprod = jnp.cumprod(alphas, axis=0) - - return cls( - alphas=alphas, - betas=betas, - alphas_cumprod=alphas_cumprod, - ) - - -def get_sqrt_alpha_prod( - state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray -): - alphas_cumprod = state.alphas_cumprod - - sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) - - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) - - return sqrt_alpha_prod, sqrt_one_minus_alpha_prod - - -def add_noise_common( - state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray -): - sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, original_samples, noise, timesteps) - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - -def get_velocity_common(state: CommonSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray): - sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, sample, noise, timesteps) - velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample - return velocity diff --git a/my_diffusers/schedulers/scheduling_vq_diffusion.py b/my_diffusers/schedulers/scheduling_vq_diffusion.py deleted file mode 100644 index b92722e4d462ca675bbf11230c1c39810de48b6e..0000000000000000000000000000000000000000 --- a/my_diffusers/schedulers/scheduling_vq_diffusion.py +++ /dev/null @@ -1,496 +0,0 @@ -# Copyright 2023 Microsoft and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn.functional as F - -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from .scheduling_utils import SchedulerMixin - - -@dataclass -class VQDiffusionSchedulerOutput(BaseOutput): - """ - Output class for the scheduler's step function output. - - Args: - prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`): - Computed sample x_{t-1} of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. 
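# A hedged numeric sketch (illustrative, not part of the original file) of the
# closed-form forward process used by add_noise_common() and get_velocity_common()
# above: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, and the
# v-prediction target v_t = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0.
# The beta schedule, shapes, and timestep below are assumptions.
import jax
import jax.numpy as jnp

betas = jnp.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = jnp.cumprod(1.0 - betas)

x0 = jax.random.normal(jax.random.PRNGKey(0), (1, 3, 8, 8))
eps = jax.random.normal(jax.random.PRNGKey(1), x0.shape)
t = 500

sqrt_a = jnp.sqrt(alphas_cumprod[t])
sqrt_1ma = jnp.sqrt(1.0 - alphas_cumprod[t])
x_t = sqrt_a * x0 + sqrt_1ma * eps          # matches add_noise_common
v_t = sqrt_a * eps - sqrt_1ma * x0          # matches get_velocity_common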
- """ - - prev_sample: torch.LongTensor - - -def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor: - """ - Convert batch of vector of class indices into batch of log onehot vectors - - Args: - x (`torch.LongTensor` of shape `(batch size, vector length)`): - Batch of class indices - - num_classes (`int`): - number of classes to be used for the onehot vectors - - Returns: - `torch.FloatTensor` of shape `(batch size, num classes, vector length)`: - Log onehot vectors - """ - x_onehot = F.one_hot(x, num_classes) - x_onehot = x_onehot.permute(0, 2, 1) - log_x = torch.log(x_onehot.float().clamp(min=1e-30)) - return log_x - - -def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor: - """ - Apply gumbel noise to `logits` - """ - uniform = torch.rand(logits.shape, device=logits.device, generator=generator) - gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30) - noised = gumbel_noise + logits - return noised - - -def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009): - """ - Cumulative and non-cumulative alpha schedules. - - See section 4.1. - """ - att = ( - np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (alpha_cum_end - alpha_cum_start) - + alpha_cum_start - ) - att = np.concatenate(([1], att)) - at = att[1:] / att[:-1] - att = np.concatenate((att[1:], [1])) - return at, att - - -def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamma_cum_end=0.99999): - """ - Cumulative and non-cumulative gamma schedules. - - See section 4.1. - """ - ctt = ( - np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (gamma_cum_end - gamma_cum_start) - + gamma_cum_start - ) - ctt = np.concatenate(([0], ctt)) - one_minus_ctt = 1 - ctt - one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1] - ct = 1 - one_minus_ct - ctt = np.concatenate((ctt[1:], [0])) - return ct, ctt - - -class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): - """ - The VQ-diffusion transformer outputs predicted probabilities of the initial unnoised image. - - The VQ-diffusion scheduler converts the transformer's output into a sample for the unnoised image at the previous - diffusion timestep. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - For more details, see the original paper: https://arxiv.org/abs/2111.14822 - - Args: - num_vec_classes (`int`): - The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked - latent pixel. - - num_train_timesteps (`int`): - Number of diffusion steps used to train the model. - - alpha_cum_start (`float`): - The starting cumulative alpha value. - - alpha_cum_end (`float`): - The ending cumulative alpha value. - - gamma_cum_start (`float`): - The starting cumulative gamma value. - - gamma_cum_end (`float`): - The ending cumulative gamma value. 
- """ - - order = 1 - - @register_to_config - def __init__( - self, - num_vec_classes: int, - num_train_timesteps: int = 100, - alpha_cum_start: float = 0.99999, - alpha_cum_end: float = 0.000009, - gamma_cum_start: float = 0.000009, - gamma_cum_end: float = 0.99999, - ): - self.num_embed = num_vec_classes - - # By convention, the index for the mask class is the last class index - self.mask_class = self.num_embed - 1 - - at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end) - ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end) - - num_non_mask_classes = self.num_embed - 1 - bt = (1 - at - ct) / num_non_mask_classes - btt = (1 - att - ctt) / num_non_mask_classes - - at = torch.tensor(at.astype("float64")) - bt = torch.tensor(bt.astype("float64")) - ct = torch.tensor(ct.astype("float64")) - log_at = torch.log(at) - log_bt = torch.log(bt) - log_ct = torch.log(ct) - - att = torch.tensor(att.astype("float64")) - btt = torch.tensor(btt.astype("float64")) - ctt = torch.tensor(ctt.astype("float64")) - log_cumprod_at = torch.log(att) - log_cumprod_bt = torch.log(btt) - log_cumprod_ct = torch.log(ctt) - - self.log_at = log_at.float() - self.log_bt = log_bt.float() - self.log_ct = log_ct.float() - self.log_cumprod_at = log_cumprod_at.float() - self.log_cumprod_bt = log_cumprod_bt.float() - self.log_cumprod_ct = log_cumprod_ct.float() - - # setable values - self.num_inference_steps = None - self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - - device (`str` or `torch.device`): - device to place the timesteps and the diffusion process parameters (alpha, beta, gamma) on. - """ - self.num_inference_steps = num_inference_steps - timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps).to(device) - - self.log_at = self.log_at.to(device) - self.log_bt = self.log_bt.to(device) - self.log_ct = self.log_ct.to(device) - self.log_cumprod_at = self.log_cumprod_at.to(device) - self.log_cumprod_bt = self.log_cumprod_bt.to(device) - self.log_cumprod_ct = self.log_cumprod_ct.to(device) - - def step( - self, - model_output: torch.FloatTensor, - timestep: torch.long, - sample: torch.LongTensor, - generator: Optional[torch.Generator] = None, - return_dict: bool = True, - ) -> Union[VQDiffusionSchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep via the reverse transition distribution i.e. Equation (11). See the - docstring for `self.q_posterior` for more in depth docs on how Equation (11) is computed. - - Args: - log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`): - The log probabilities for the predicted classes of the initial latent pixels. Does not include a - prediction for the masked class as the initial unnoised image cannot be masked. - - t (`torch.long`): - The timestep that determines which transition matrices are used. 
- - x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`): - The classes of each latent pixel at time `t` - - generator: (`torch.Generator` or None): - RNG for the noise applied to p(x_{t-1} | x_t) before it is sampled from. - - return_dict (`bool`): - option for returning tuple rather than VQDiffusionSchedulerOutput class - - Returns: - [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is the sample tensor. - """ - if timestep == 0: - log_p_x_t_min_1 = model_output - else: - log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep) - - log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator) - - x_t_min_1 = log_p_x_t_min_1.argmax(dim=1) - - if not return_dict: - return (x_t_min_1,) - - return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1) - - def q_posterior(self, log_p_x_0, x_t, t): - """ - Calculates the log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11). - - Instead of directly computing equation (11), we use Equation (5) to restate Equation (11) in terms of only - forward probabilities. - - Equation (11) stated in terms of forward probabilities via Equation (5): - - Where: - - the sum is over x_0 = {C_0 ... C_{k-1}} (classes for x_0) - - p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) ) - - Args: - log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`): - The log probabilities for the predicted classes of the initial latent pixels. Does not include a - prediction for the masked class as the initial unnoised image cannot be masked. - - x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`): - The classes of each latent pixel at time `t` - - t (torch.Long): - The timestep that determines which transition matrix is used. - - Returns: - `torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`: - The log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11). - """ - log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed) - - log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class( - t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True - ) - - log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class( - t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False - ) - - # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) - # . . . - # . . . - # . . . - # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) - q = log_p_x_0 - log_q_x_t_given_x_0 - - # sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... , - # sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) - q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True) - - # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0 ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n - # . . . - # . . . - # . . . - # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0 ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n - q = q - q_log_sum_exp - - # (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} - # . . . - # . . . - # . . . 
- # (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} - # c_cumulative_{t-1} ... c_cumulative_{t-1} - q = self.apply_cumulative_transitions(q, t - 1) - - # ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0 ... ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n - # . . . - # . . . - # . . . - # ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0 ... ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n - # c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 ... c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 - log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp - - # For each column, there are two possible cases. - # - # Where: - # - sum(p_n(x_0))) is summing over all classes for x_0 - # - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's) - # - C_j is the class transitioning to - # - # 1. x_t is masked i.e. x_t = c_k - # - # Simplifying the expression, the column vector is: - # . - # . - # . - # (c_t / c_cumulative_t) * (a_cumulative_{t-1} * p_n(x_0 = C_i | x_t) + b_cumulative_{t-1} * sum(p_n(x_0))) - # . - # . - # . - # (c_cumulative_{t-1} / c_cumulative_t) * sum(p_n(x_0)) - # - # From equation (11) stated in terms of forward probabilities, the last row is trivially verified. - # - # For the other rows, we can state the equation as ... - # - # (c_t / c_cumulative_t) * [b_cumulative_{t-1} * p(x_0=c_0) + ... + (a_cumulative_{t-1} + b_cumulative_{t-1}) * p(x_0=C_i) + ... + b_cumulative_{k-1} * p(x_0=c_{k-1})] - # - # This verifies the other rows. - # - # 2. x_t is not masked - # - # Simplifying the expression, there are two cases for the rows of the column vector, where C_j = C_i and where C_j != C_i: - # . - # . - # . - # C_j != C_i: b_t * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / b_cumulative_t) * p_n(x_0 = C_i) + ... + (b_cumulative_{t-1} / (a_cumulative_t + b_cumulative_t)) * p_n(c_0=C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1})) - # . - # . - # . - # C_j = C_i: (a_t + b_t) * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_i = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1})) - # . - # . - # . - # 0 - # - # The last row is trivially verified. The other rows can be verified by directly expanding equation (11) stated in terms of forward probabilities. - return log_p_x_t_min_1 - - def log_Q_t_transitioning_to_known_class( - self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool - ): - """ - Returns the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each - latent pixel in `x_t`. - - See equation (7) for the complete non-cumulative transition matrix. The complete cumulative transition matrix - is the same structure except the parameters (alpha, beta, gamma) are the cumulative analogs. 
- - Args: - t (torch.Long): - The timestep that determines which transition matrix is used. - - x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`): - The classes of each latent pixel at time `t`. - - log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`): - The log one-hot vectors of `x_t` - - cumulative (`bool`): - If cumulative is `False`, we use the single step transition matrix `t-1`->`t`. If cumulative is `True`, - we use the cumulative transition matrix `0`->`t`. - - Returns: - `torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`: - Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability - transition matrix. - - When non cumulative, returns `self.num_classes - 1` rows because the initial latent pixel cannot be - masked. - - Where: - - `q_n` is the probability distribution for the forward process of the `n`th latent pixel. - - C_0 is a class of a latent pixel embedding - - C_k is the class of the masked latent pixel - - non-cumulative result (omitting logarithms): - ``` - q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0) - . . . - . . . - . . . - q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k) - ``` - - cumulative result (omitting logarithms): - ``` - q_0_cumulative(x_t | x_0 = C_0) ... q_n_cumulative(x_t | x_0 = C_0) - . . . - . . . - . . . - q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1}) - ``` - """ - if cumulative: - a = self.log_cumprod_at[t] - b = self.log_cumprod_bt[t] - c = self.log_cumprod_ct[t] - else: - a = self.log_at[t] - b = self.log_bt[t] - c = self.log_ct[t] - - if not cumulative: - # The values in the onehot vector can also be used as the logprobs for transitioning - # from masked latent pixels. If we are not calculating the cumulative transitions, - # we need to save these vectors to be re-appended to the final matrix so the values - # aren't overwritten. - # - # `P(x_t!=mask|x_{t-1=mask}) = 0` and 0 will be the value of the last row of the onehot vector - # if x_t is not masked - # - # `P(x_t=mask|x_{t-1=mask}) = 1` and 1 will be the value of the last row of the onehot vector - # if x_t is masked - log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1) - - # `index_to_log_onehot` will add onehot vectors for masked pixels, - # so the default one hot matrix has one too many rows. See the doc string - # for an explanation of the dimensionality of the returned matrix. - log_onehot_x_t = log_onehot_x_t[:, :-1, :] - - # this is a cheeky trick to produce the transition probabilities using log one-hot vectors. - # - # Don't worry about what values this sets in the columns that mark transitions - # to masked latent pixels. They are overwrote later with the `mask_class_mask`. - # - # Looking at the below logspace formula in non-logspace, each value will evaluate to either - # `1 * a + b = a + b` where `log_Q_t` has the one hot value in the column - # or - # `0 * a + b = b` where `log_Q_t` has the 0 values in the column. - # - # See equation 7 for more details. 
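# A self-contained numeric check (illustrative, not part of the original file) of the
# log-space trick described in the comment above: in non-log space the next line
# computes onehot * a + b, i.e. (log_onehot + log a).logaddexp(log b) equals
# log(onehot * a + b). The probabilities below are assumptions.
import torch

a, b = torch.tensor(0.9), torch.tensor(0.02)
onehot = torch.tensor([0.0, 1.0, 0.0])
log_onehot = torch.log(onehot.clamp(min=1e-30))          # same clamp as index_to_log_onehot

lhs = (log_onehot + torch.log(a)).logaddexp(torch.log(b))
rhs = torch.log(onehot * a + b)
assert torch.allclose(lhs, rhs, atol=1e-6)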
- log_Q_t = (log_onehot_x_t + a).logaddexp(b) - - # The whole column of each masked pixel is `c` - mask_class_mask = x_t == self.mask_class - mask_class_mask = mask_class_mask.unsqueeze(1).expand(-1, self.num_embed - 1, -1) - log_Q_t[mask_class_mask] = c - - if not cumulative: - log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1) - - return log_Q_t - - def apply_cumulative_transitions(self, q, t): - bsz = q.shape[0] - a = self.log_cumprod_at[t] - b = self.log_cumprod_bt[t] - c = self.log_cumprod_ct[t] - - num_latent_pixels = q.shape[2] - c = c.expand(bsz, 1, num_latent_pixels) - - q = (q + a).logaddexp(b) - q = torch.cat((q, c), dim=1) - - return q diff --git a/my_diffusers/training_utils.py b/my_diffusers/training_utils.py deleted file mode 100644 index 67a8e48d381f93c860c13704ed2968fce1147f0d..0000000000000000000000000000000000000000 --- a/my_diffusers/training_utils.py +++ /dev/null @@ -1,324 +0,0 @@ -import copy -import os -import random -from typing import Any, Dict, Iterable, Optional, Union - -import numpy as np -import torch - -from .utils import deprecate - - -def enable_full_determinism(seed: int): - """ - Helper function for reproducible behavior during distributed training. See - - https://pytorch.org/docs/stable/notes/randomness.html for pytorch - """ - # set seed first - set_seed(seed) - - # Enable PyTorch deterministic mode. This potentially requires either the environment - # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, - # depending on the CUDA version, so we set them both here - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" - torch.use_deterministic_algorithms(True) - - # Enable CUDNN deterministic mode - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - -def set_seed(seed: int): - """ - Args: - Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. - seed (`int`): The seed to set. - """ - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - # ^^ safe to call this function even if cuda is not available - - -# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 -class EMAModel: - """ - Exponential Moving Average of models weights - """ - - def __init__( - self, - parameters: Iterable[torch.nn.Parameter], - decay: float = 0.9999, - min_decay: float = 0.0, - update_after_step: int = 0, - use_ema_warmup: bool = False, - inv_gamma: Union[float, int] = 1.0, - power: Union[float, int] = 2 / 3, - model_cls: Optional[Any] = None, - model_config: Dict[str, Any] = None, - **kwargs, - ): - """ - Args: - parameters (Iterable[torch.nn.Parameter]): The parameters to track. - decay (float): The decay factor for the exponential moving average. - min_decay (float): The minimum decay factor for the exponential moving average. - update_after_step (int): The number of steps to wait before starting to update the EMA weights. - use_ema_warmup (bool): Whether to use EMA warmup. - inv_gamma (float): - Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. - power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. - device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA - weights will be stored on CPU. - - @crowsonkb's notes on EMA Warmup: - If gamma=1 and power=1, implements a simple average. 
gamma=1, power=2/3 are good values for models you plan - to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), - gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 - at 215.4k steps). - """ - - if isinstance(parameters, torch.nn.Module): - deprecation_message = ( - "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " - "Please pass the parameters of the module instead." - ) - deprecate( - "passing a `torch.nn.Module` to `ExponentialMovingAverage`", - "1.0.0", - deprecation_message, - standard_warn=False, - ) - parameters = parameters.parameters() - - # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility - use_ema_warmup = True - - if kwargs.get("max_value", None) is not None: - deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." - deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) - decay = kwargs["max_value"] - - if kwargs.get("min_value", None) is not None: - deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." - deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) - min_decay = kwargs["min_value"] - - parameters = list(parameters) - self.shadow_params = [p.clone().detach() for p in parameters] - - if kwargs.get("device", None) is not None: - deprecation_message = "The `device` argument is deprecated. Please use `to` instead." - deprecate("device", "1.0.0", deprecation_message, standard_warn=False) - self.to(device=kwargs["device"]) - - self.temp_stored_params = None - - self.decay = decay - self.min_decay = min_decay - self.update_after_step = update_after_step - self.use_ema_warmup = use_ema_warmup - self.inv_gamma = inv_gamma - self.power = power - self.optimization_step = 0 - self.cur_decay_value = None # set in `step()` - - self.model_cls = model_cls - self.model_config = model_config - - @classmethod - def from_pretrained(cls, path, model_cls) -> "EMAModel": - _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) - model = model_cls.from_pretrained(path) - - ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config) - - ema_model.load_state_dict(ema_kwargs) - return ema_model - - def save_pretrained(self, path): - if self.model_cls is None: - raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") - - if self.model_config is None: - raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.") - - model = self.model_cls.from_config(self.model_config) - state_dict = self.state_dict() - state_dict.pop("shadow_params", None) - - model.register_to_config(**state_dict) - self.copy_to(model.parameters()) - model.save_pretrained(path) - - def get_decay(self, optimization_step: int) -> float: - """ - Compute the decay factor for the exponential moving average. 
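# A hedged sketch (illustrative, not part of the original file) of the EMA warmup
# schedule referenced in the notes above and implemented in get_decay() below:
# decay = 1 - (1 + step / inv_gamma) ** -power, clamped to [min_decay, decay].
# The step counts are assumptions chosen to show the curve.
def ema_warmup_decay(step, inv_gamma=1.0, power=2 / 3, max_decay=0.9999, min_decay=0.0):
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min(value, max_decay), min_decay)

for step in (100, 31_600, 1_000_000):
    print(step, round(ema_warmup_decay(step), 6))   # reaches ~0.999 near 31.6K steps, capped at 0.9999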
- """ - step = max(0, optimization_step - self.update_after_step - 1) - - if step <= 0: - return 0.0 - - if self.use_ema_warmup: - cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power - else: - cur_decay_value = (1 + step) / (10 + step) - - cur_decay_value = min(cur_decay_value, self.decay) - # make sure decay is not smaller than min_decay - cur_decay_value = max(cur_decay_value, self.min_decay) - return cur_decay_value - - @torch.no_grad() - def step(self, parameters: Iterable[torch.nn.Parameter]): - if isinstance(parameters, torch.nn.Module): - deprecation_message = ( - "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. " - "Please pass the parameters of the module instead." - ) - deprecate( - "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`", - "1.0.0", - deprecation_message, - standard_warn=False, - ) - parameters = parameters.parameters() - - parameters = list(parameters) - - self.optimization_step += 1 - - # Compute the decay factor for the exponential moving average. - decay = self.get_decay(self.optimization_step) - self.cur_decay_value = decay - one_minus_decay = 1 - decay - - for s_param, param in zip(self.shadow_params, parameters): - if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) - else: - s_param.copy_(param) - - torch.cuda.empty_cache() - - def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: - """ - Copy current averaged parameters into given collection of parameters. - - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. If `None`, the parameters with which this - `ExponentialMovingAverage` was initialized will be used. - """ - parameters = list(parameters) - for s_param, param in zip(self.shadow_params, parameters): - param.data.copy_(s_param.to(param.device).data) - - def to(self, device=None, dtype=None) -> None: - r"""Move internal buffers of the ExponentialMovingAverage to `device`. - - Args: - device: like `device` argument to `torch.Tensor.to` - """ - # .to() on the tensors handles None correctly - self.shadow_params = [ - p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) - for p in self.shadow_params - ] - - def state_dict(self) -> dict: - r""" - Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during - checkpointing to save the ema state dict. - """ - # Following PyTorch conventions, references to tensors are returned: - # "returns a reference to the state and not its copy!" - - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict - return { - "decay": self.decay, - "min_decay": self.min_decay, - "optimization_step": self.optimization_step, - "update_after_step": self.update_after_step, - "use_ema_warmup": self.use_ema_warmup, - "inv_gamma": self.inv_gamma, - "power": self.power, - "shadow_params": self.shadow_params, - } - - def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: - r""" - Args: - Save the current parameters for restoring later. - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. - """ - self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] - - def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: - r""" - Args: - Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: - affecting the original optimization process. 
Store the parameters before the `copy_to()` method. After - validation (or model saving), use this to restore the former parameters. - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. If `None`, the parameters with which this - `ExponentialMovingAverage` was initialized will be used. - """ - if self.temp_stored_params is None: - raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") - for c_param, param in zip(self.temp_stored_params, parameters): - param.data.copy_(c_param.data) - - # Better memory-wise. - self.temp_stored_params = None - - def load_state_dict(self, state_dict: dict) -> None: - r""" - Args: - Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the - ema state dict. - state_dict (dict): EMA state. Should be an object returned - from a call to :meth:`state_dict`. - """ - # deepcopy, to be consistent with module API - state_dict = copy.deepcopy(state_dict) - - self.decay = state_dict.get("decay", self.decay) - if self.decay < 0.0 or self.decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.min_decay = state_dict.get("min_decay", self.min_decay) - if not isinstance(self.min_decay, float): - raise ValueError("Invalid min_decay") - - self.optimization_step = state_dict.get("optimization_step", self.optimization_step) - if not isinstance(self.optimization_step, int): - raise ValueError("Invalid optimization_step") - - self.update_after_step = state_dict.get("update_after_step", self.update_after_step) - if not isinstance(self.update_after_step, int): - raise ValueError("Invalid update_after_step") - - self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) - if not isinstance(self.use_ema_warmup, bool): - raise ValueError("Invalid use_ema_warmup") - - self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) - if not isinstance(self.inv_gamma, (float, int)): - raise ValueError("Invalid inv_gamma") - - self.power = state_dict.get("power", self.power) - if not isinstance(self.power, (float, int)): - raise ValueError("Invalid power") - - shadow_params = state_dict.get("shadow_params", None) - if shadow_params is not None: - self.shadow_params = shadow_params - if not isinstance(self.shadow_params, list): - raise ValueError("shadow_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): - raise ValueError("shadow_params must all be Tensors") diff --git a/my_diffusers/utils/__init__.py b/my_diffusers/utils/__init__.py deleted file mode 100644 index 64d5c695baecef1f9515c4b0cf31d65c2628f543..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/__init__.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import os - -from packaging import version - -from .. 
import __version__ -from .accelerate_utils import apply_forward_hook -from .constants import ( - CONFIG_NAME, - DEPRECATED_REVISION_ARGS, - DIFFUSERS_CACHE, - DIFFUSERS_DYNAMIC_MODULE_NAME, - FLAX_WEIGHTS_NAME, - HF_MODULES_CACHE, - HUGGINGFACE_CO_RESOLVE_ENDPOINT, - ONNX_EXTERNAL_WEIGHTS_NAME, - ONNX_WEIGHTS_NAME, - SAFETENSORS_WEIGHTS_NAME, - WEIGHTS_NAME, -) -from .deprecation_utils import deprecate -from .doc_utils import replace_example_docstring -from .dynamic_modules_utils import get_class_from_dynamic_module -from .hub_utils import HF_HUB_OFFLINE, http_user_agent -from .import_utils import ( - ENV_VARS_TRUE_AND_AUTO_VALUES, - ENV_VARS_TRUE_VALUES, - USE_JAX, - USE_TF, - USE_TORCH, - DummyObject, - OptionalDependencyNotAvailable, - is_accelerate_available, - is_accelerate_version, - is_flax_available, - is_inflect_available, - is_k_diffusion_available, - is_k_diffusion_version, - is_librosa_available, - is_omegaconf_available, - is_onnx_available, - is_safetensors_available, - is_scipy_available, - is_tensorboard_available, - is_tf_available, - is_torch_available, - is_torch_version, - is_transformers_available, - is_transformers_version, - is_unidecode_available, - is_wandb_available, - is_xformers_available, - requires_backends, -) -from .logging import get_logger -from .outputs import BaseOutput -from .pil_utils import PIL_INTERPOLATION -from .torch_utils import randn_tensor - - -if is_torch_available(): - from .testing_utils import ( - floats_tensor, - load_hf_numpy, - load_image, - load_numpy, - nightly, - parse_flag_from_env, - print_tensor_test, - require_torch_gpu, - skip_mps, - slow, - torch_all_close, - torch_device, - ) - - -logger = get_logger(__name__) - - -def check_min_version(min_version): - if version.parse(__version__) < version.parse(min_version): - if "dev" in min_version: - error_message = ( - "This example requires a source install from HuggingFace diffusers (see " - "`https://huggingface.co/docs/diffusers/installation#install-from-source`)," - ) - else: - error_message = f"This example requires a minimum version of {min_version}," - error_message += f" but the version found is {__version__}.\n" - raise ImportError(error_message) diff --git a/my_diffusers/utils/__pycache__/__init__.cpython-311.pyc b/my_diffusers/utils/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index b80fff1f97d2a8180b2d28cdc7d514ae688f9be6..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/__init__.cpython-39.pyc b/my_diffusers/utils/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 1f2bf1090783379b8bb608ba3929a14b86feb54a..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/accelerate_utils.cpython-311.pyc b/my_diffusers/utils/__pycache__/accelerate_utils.cpython-311.pyc deleted file mode 100644 index bc18de783e0642107a64c6e80f0139effe84eb43..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/accelerate_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/accelerate_utils.cpython-39.pyc b/my_diffusers/utils/__pycache__/accelerate_utils.cpython-39.pyc deleted file mode 100644 index 81292c306ae044b7757a910bb80b1fb1308c4eb7..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/accelerate_utils.cpython-39.pyc and 
/dev/null differ diff --git a/my_diffusers/utils/__pycache__/constants.cpython-311.pyc b/my_diffusers/utils/__pycache__/constants.cpython-311.pyc deleted file mode 100644 index c901c0fc87f7d2422ce372ce650bf8fdbbca9d94..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/constants.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/constants.cpython-39.pyc b/my_diffusers/utils/__pycache__/constants.cpython-39.pyc deleted file mode 100644 index e997986c0be8f1db8c243abf6a221a35189e71b7..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/constants.cpython-39.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/deprecation_utils.cpython-311.pyc b/my_diffusers/utils/__pycache__/deprecation_utils.cpython-311.pyc deleted file mode 100644 index e9d0bae0a14cc2cd72c0f30bcf72a2d232da8e37..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/deprecation_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/deprecation_utils.cpython-39.pyc b/my_diffusers/utils/__pycache__/deprecation_utils.cpython-39.pyc deleted file mode 100644 index 9534cdb46996c18e50561f0e61fae3c8708b8c60..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/deprecation_utils.cpython-39.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/doc_utils.cpython-311.pyc b/my_diffusers/utils/__pycache__/doc_utils.cpython-311.pyc deleted file mode 100644 index 554deb0afdaffa3ff5a1e03fef58442fa0992eeb..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/doc_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/doc_utils.cpython-39.pyc b/my_diffusers/utils/__pycache__/doc_utils.cpython-39.pyc deleted file mode 100644 index 469d2736de92fab55d034e77b249ac9e29aa9929..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/doc_utils.cpython-39.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-311.pyc b/my_diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-311.pyc deleted file mode 100644 index 0d8df817106efd51c373573dbef497c5f2e01af7..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/dummy_flax_and_transformers_objects.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/dummy_flax_objects.cpython-311.pyc b/my_diffusers/utils/__pycache__/dummy_flax_objects.cpython-311.pyc deleted file mode 100644 index c2ee42d71937d48182a4e900500ab26dd869f54c..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/dummy_flax_objects.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/dummy_onnx_objects.cpython-311.pyc b/my_diffusers/utils/__pycache__/dummy_onnx_objects.cpython-311.pyc deleted file mode 100644 index 91b06b659605f7ca1404fc8164d76febfab2106d..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/dummy_onnx_objects.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/dummy_pt_objects.cpython-311.pyc b/my_diffusers/utils/__pycache__/dummy_pt_objects.cpython-311.pyc deleted file mode 100644 index b73d5d5e295d4d7ef4d51df5e2b3f489930a9b36..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/dummy_pt_objects.cpython-311.pyc and 
/dev/null differ diff --git a/my_diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-311.pyc b/my_diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-311.pyc deleted file mode 100644 index 9fb6e4defb2be0c64a177897f57fa54e555cf6bb..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/dummy_torch_and_librosa_objects.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/dummy_torch_and_scipy_objects.cpython-311.pyc b/my_diffusers/utils/__pycache__/dummy_torch_and_scipy_objects.cpython-311.pyc deleted file mode 100644 index 46105af795f39fa7957459f0311df28a6a03ec41..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/dummy_torch_and_scipy_objects.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-311.pyc b/my_diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-311.pyc deleted file mode 100644 index e5ba9c24a1d692267b36b35b1b0f4fd1eab2b4b1..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/dummy_torch_and_transformers_and_k_diffusion_objects.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-311.pyc b/my_diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-311.pyc deleted file mode 100644 index 19ff4e552288a20dda960e42877fa94904965360..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/dummy_torch_and_transformers_and_onnx_objects.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/dummy_torch_and_transformers_objects.cpython-311.pyc b/my_diffusers/utils/__pycache__/dummy_torch_and_transformers_objects.cpython-311.pyc deleted file mode 100644 index 9afea8ea47fce0aa9f217baf3a26091104cf0943..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/dummy_torch_and_transformers_objects.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/dynamic_modules_utils.cpython-311.pyc b/my_diffusers/utils/__pycache__/dynamic_modules_utils.cpython-311.pyc deleted file mode 100644 index 07bd8edef802eb79aedd0b5e2fa3ff4bd59bed17..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/dynamic_modules_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/dynamic_modules_utils.cpython-39.pyc b/my_diffusers/utils/__pycache__/dynamic_modules_utils.cpython-39.pyc deleted file mode 100644 index b6a4775c8d5584a2573047ec4bd6179c936eaf87..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/dynamic_modules_utils.cpython-39.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/hub_utils.cpython-311.pyc b/my_diffusers/utils/__pycache__/hub_utils.cpython-311.pyc deleted file mode 100644 index 0660515b201dcf348c43c18c7c2562adddd5bc2e..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/hub_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/hub_utils.cpython-39.pyc b/my_diffusers/utils/__pycache__/hub_utils.cpython-39.pyc deleted file mode 100644 index a3fdf017a749caf0ccb07db90eb128768a97d94a..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/hub_utils.cpython-39.pyc and /dev/null differ 
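For reference, the check_min_version helper deleted from my_diffusers/utils/__init__.py above is the guard that example scripts call at import time. A minimal usage sketch, assuming the package is still importable under the my_diffusers name shown in this diff:

# Hypothetical usage of the deleted check_min_version helper (illustrative only).
from my_diffusers.utils import check_min_version

# Raises ImportError when the installed __version__ is older than the requirement;
# for a ".dev" requirement the error instead asks for a source install.
check_min_version("0.14.0")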
diff --git a/my_diffusers/utils/__pycache__/import_utils.cpython-311.pyc b/my_diffusers/utils/__pycache__/import_utils.cpython-311.pyc deleted file mode 100644 index b6d622851fae60b8fd2515f6032ea9e988559f81..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/import_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/import_utils.cpython-39.pyc b/my_diffusers/utils/__pycache__/import_utils.cpython-39.pyc deleted file mode 100644 index acd30f4e482142831dd194925cd1a04fb1df36b8..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/import_utils.cpython-39.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/logging.cpython-311.pyc b/my_diffusers/utils/__pycache__/logging.cpython-311.pyc deleted file mode 100644 index b494e7b7c35aa3766146c15fc9b7bd0d1834f253..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/logging.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/logging.cpython-39.pyc b/my_diffusers/utils/__pycache__/logging.cpython-39.pyc deleted file mode 100644 index b8faddcad5b0319f79ba7555a0b02af257673db2..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/logging.cpython-39.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/outputs.cpython-311.pyc b/my_diffusers/utils/__pycache__/outputs.cpython-311.pyc deleted file mode 100644 index dc8458c2b967c3a782fb83412b0492262846281a..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/outputs.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/outputs.cpython-39.pyc b/my_diffusers/utils/__pycache__/outputs.cpython-39.pyc deleted file mode 100644 index 0b97bef5f99c1b69cb2550af5ab47091031c23d5..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/outputs.cpython-39.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/pil_utils.cpython-311.pyc b/my_diffusers/utils/__pycache__/pil_utils.cpython-311.pyc deleted file mode 100644 index f3c340a604db38d7386c532ccd664163bdab6827..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/pil_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/pil_utils.cpython-39.pyc b/my_diffusers/utils/__pycache__/pil_utils.cpython-39.pyc deleted file mode 100644 index ecc43c374692235bd5a9a2bc62574f3401d661cd..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/pil_utils.cpython-39.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/testing_utils.cpython-311.pyc b/my_diffusers/utils/__pycache__/testing_utils.cpython-311.pyc deleted file mode 100644 index 2e36caa74264e40b1c44e781471daa828f396409..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/testing_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/__pycache__/torch_utils.cpython-311.pyc b/my_diffusers/utils/__pycache__/torch_utils.cpython-311.pyc deleted file mode 100644 index aaffd04a77bc0b903da1d696c2baf45f518357b0..0000000000000000000000000000000000000000 Binary files a/my_diffusers/utils/__pycache__/torch_utils.cpython-311.pyc and /dev/null differ diff --git a/my_diffusers/utils/accelerate_utils.py b/my_diffusers/utils/accelerate_utils.py deleted file mode 100644 index 
10a83e1dd209cca198f4038d0d7e7228f9671859..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/accelerate_utils.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Accelerate utilities: Utilities related to accelerate -""" - -from packaging import version - -from .import_utils import is_accelerate_available - - -if is_accelerate_available(): - import accelerate - - -def apply_forward_hook(method): - """ - Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful - for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the - appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`]. - - This decorator looks inside the internal `_hf_hook` property to find a registered offload hook. - - :param method: The method to decorate. This method should be a method of a PyTorch module. - """ - if not is_accelerate_available(): - return method - accelerate_version = version.parse(accelerate.__version__).base_version - if version.parse(accelerate_version) < version.parse("0.17.0"): - return method - - def wrapper(self, *args, **kwargs): - if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"): - self._hf_hook.pre_forward(self) - return method(self, *args, **kwargs) - - return wrapper diff --git a/my_diffusers/utils/constants.py b/my_diffusers/utils/constants.py deleted file mode 100644 index b9e60a2a873b29a7d3adffbd7179be1670b3b417..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/constants.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
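The apply_forward_hook decorator removed above targets methods other than forward that should still trigger an accelerate CPU-offload hook. A hedged sketch of how a module could use it; TinyAutoencoder and encode are illustrative names, not part of this diff, and the import path assumes the layout being deleted here:

# Illustrative only: if accelerate offloading has attached an `_hf_hook` to the
# module, the wrapper calls its pre_forward() before `encode` runs, mirroring
# how AutoencoderKL.encode/decode are wrapped in this codebase.
import torch

from my_diffusers.utils.accelerate_utils import apply_forward_hook


class TinyAutoencoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    @apply_forward_hook
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


latents = TinyAutoencoder().encode(torch.randn(1, 8))  # works with or without a hook attached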
-import os - -from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home - - -default_cache_path = HUGGINGFACE_HUB_CACHE - - -CONFIG_NAME = "config.json" -WEIGHTS_NAME = "diffusion_pytorch_model.bin" -FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" -ONNX_WEIGHTS_NAME = "model.onnx" -SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" -ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" -HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" -DIFFUSERS_CACHE = default_cache_path -DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" -HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) -DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] diff --git a/my_diffusers/utils/deprecation_utils.py b/my_diffusers/utils/deprecation_utils.py deleted file mode 100644 index 6bdda664e102ea9913503b9e169fa97225d52c78..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/deprecation_utils.py +++ /dev/null @@ -1,49 +0,0 @@ -import inspect -import warnings -from typing import Any, Dict, Optional, Union - -from packaging import version - - -def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True): - from .. import __version__ - - deprecated_kwargs = take_from - values = () - if not isinstance(args[0], tuple): - args = (args,) - - for attribute, version_name, message in args: - if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): - raise ValueError( - f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'" - f" version {__version__} is >= {version_name}" - ) - - warning = None - if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: - values += (deprecated_kwargs.pop(attribute),) - warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." - elif hasattr(deprecated_kwargs, attribute): - values += (getattr(deprecated_kwargs, attribute),) - warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." - elif deprecated_kwargs is None: - warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." - - if warning is not None: - warning = warning + " " if standard_warn else "" - warnings.warn(warning + message, FutureWarning, stacklevel=2) - - if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: - call_frame = inspect.getouterframes(inspect.currentframe())[1] - filename = call_frame.filename - line_number = call_frame.lineno - function = call_frame.function - key, value = next(iter(deprecated_kwargs.items())) - raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") - - if len(values) == 0: - return - elif len(values) == 1: - return values[0] - return values diff --git a/my_diffusers/utils/doc_utils.py b/my_diffusers/utils/doc_utils.py deleted file mode 100644 index f1f87743f99802931334bd51bf99985775116d59..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/doc_utils.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Doc utilities: Utilities related to documentation -""" -import re - - -def replace_example_docstring(example_docstring): - def docstring_decorator(fn): - func_doc = fn.__doc__ - lines = func_doc.split("\n") - i = 0 - while i < len(lines) and re.search(r"^\s*Examples?:\s*$", lines[i]) is None: - i += 1 - if i < len(lines): - lines[i] = example_docstring - func_doc = "\n".join(lines) - else: - raise ValueError( - f"The function {fn} should have an empty 'Examples:' in its docstring as placeholder, " - f"current docstring is:\n{func_doc}" - ) - fn.__doc__ = func_doc - return fn - - return docstring_decorator diff --git a/my_diffusers/utils/dummy_flax_and_transformers_objects.py b/my_diffusers/utils/dummy_flax_and_transformers_objects.py deleted file mode 100644 index 5db4c7d58d1e9c17c8824c1d24edf88e44799eba..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/dummy_flax_and_transformers_objects.py +++ /dev/null @@ -1,47 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -from ..utils import DummyObject, requires_backends - - -class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): - _backends = ["flax", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax", "transformers"]) - - -class FlaxStableDiffusionInpaintPipeline(metaclass=DummyObject): - _backends = ["flax", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax", "transformers"]) - - -class FlaxStableDiffusionPipeline(metaclass=DummyObject): - _backends = ["flax", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax", "transformers"]) diff --git a/my_diffusers/utils/dummy_flax_objects.py b/my_diffusers/utils/dummy_flax_objects.py deleted file mode 100644 index 7772c1a06b49dc970a82243295106c6c01595d72..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/dummy_flax_objects.py +++ /dev/null @@ -1,182 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. 
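The deprecate helper deleted above both warns about and extracts arguments that are scheduled for removal. A small sketch of the calling pattern; run_pipeline and old_seed are made-up names, and the removal version only has to be greater than the current __version__:

# Hypothetical caller of the deleted deprecate() helper: it pops the deprecated
# kwarg from `kwargs`, emits a FutureWarning, and hands back the old value so the
# caller can keep honouring it until the stated removal version.
from my_diffusers.utils import deprecate


def run_pipeline(prompt, **kwargs):
    old_seed = deprecate("old_seed", "9.9.9", "Pass `generator` instead.", take_from=kwargs)
    seed = old_seed if old_seed is not None else 0
    return prompt, seed


run_pipeline("an astronaut riding a horse", old_seed=42)  # emits FutureWarning, still works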
-from ..utils import DummyObject, requires_backends - - -class FlaxModelMixin(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - -class FlaxUNet2DConditionModel(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - -class FlaxAutoencoderKL(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - -class FlaxDiffusionPipeline(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - -class FlaxDDIMScheduler(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - -class FlaxDDPMScheduler(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - -class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - -class FlaxKarrasVeScheduler(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - -class FlaxLMSDiscreteScheduler(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - -class FlaxPNDMScheduler(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - -class 
FlaxSchedulerMixin(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - -class FlaxScoreSdeVeScheduler(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["flax"]) diff --git a/my_diffusers/utils/dummy_onnx_objects.py b/my_diffusers/utils/dummy_onnx_objects.py deleted file mode 100644 index bde5f6ad0793e2d81bc638600b46ff81748d09ee..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/dummy_onnx_objects.py +++ /dev/null @@ -1,17 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -from ..utils import DummyObject, requires_backends - - -class OnnxRuntimeModel(metaclass=DummyObject): - _backends = ["onnx"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["onnx"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["onnx"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["onnx"]) diff --git a/my_diffusers/utils/dummy_pt_objects.py b/my_diffusers/utils/dummy_pt_objects.py deleted file mode 100644 index c731a1f1ddf3bbebef88b4776564758b36fd11c3..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/dummy_pt_objects.py +++ /dev/null @@ -1,675 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. 
-from ..utils import DummyObject, requires_backends - - -class AutoencoderKL(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class ControlNetModel(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class ModelMixin(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class PriorTransformer(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class Transformer2DModel(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class UNet1DModel(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class UNet2DConditionModel(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class UNet2DModel(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class VQModel(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -def get_constant_schedule(*args, **kwargs): - requires_backends(get_constant_schedule, ["torch"]) - - -def get_constant_schedule_with_warmup(*args, **kwargs): - requires_backends(get_constant_schedule_with_warmup, ["torch"]) - - -def get_cosine_schedule_with_warmup(*args, **kwargs): - requires_backends(get_cosine_schedule_with_warmup, ["torch"]) - - -def 
get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): - requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"]) - - -def get_linear_schedule_with_warmup(*args, **kwargs): - requires_backends(get_linear_schedule_with_warmup, ["torch"]) - - -def get_polynomial_decay_schedule_with_warmup(*args, **kwargs): - requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"]) - - -def get_scheduler(*args, **kwargs): - requires_backends(get_scheduler, ["torch"]) - - -class AudioPipelineOutput(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class DanceDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class DDIMPipeline(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class DDPMPipeline(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class DiffusionPipeline(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class DiTPipeline(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class ImagePipelineOutput(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class KarrasVePipeline(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class LDMPipeline(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def 
from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class LDMSuperResolutionPipeline(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class PNDMPipeline(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class RePaintPipeline(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class ScoreSdeVePipeline(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class DDIMInverseScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class DDIMScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class DDPMScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class DEISMultistepScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class DPMSolverMultistepScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class DPMSolverSinglestepScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - 
requires_backends(cls, ["torch"]) - - -class EulerAncestralDiscreteScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class EulerDiscreteScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class HeunDiscreteScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class IPNDMScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class KarrasVeScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class KDPM2AncestralDiscreteScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class KDPM2DiscreteScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class PNDMScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class RePaintScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class SchedulerMixin(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class 
ScoreSdeVeScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class UnCLIPScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class UniPCMultistepScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class VQDiffusionScheduler(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - -class EMAModel(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) diff --git a/my_diffusers/utils/dummy_torch_and_librosa_objects.py b/my_diffusers/utils/dummy_torch_and_librosa_objects.py deleted file mode 100644 index 2088bc4a744198284f22fe54e6f1055cf3568566..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/dummy_torch_and_librosa_objects.py +++ /dev/null @@ -1,32 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -from ..utils import DummyObject, requires_backends - - -class AudioDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "librosa"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "librosa"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "librosa"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "librosa"]) - - -class Mel(metaclass=DummyObject): - _backends = ["torch", "librosa"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "librosa"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "librosa"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "librosa"]) diff --git a/my_diffusers/utils/dummy_torch_and_scipy_objects.py b/my_diffusers/utils/dummy_torch_and_scipy_objects.py deleted file mode 100644 index a1ff25863822b04971d2c6dfdc17f5b28774cf05..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/dummy_torch_and_scipy_objects.py +++ /dev/null @@ -1,17 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. 
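All of the dummy_*_objects modules above and below repeat one pattern: a placeholder class whose metaclass and constructor route every use through requires_backends, so importing the name always works but using it raises an error naming the missing backends. A simplified, self-contained sketch of that pattern (the real requires_backends lives in import_utils and only raises when a backend is actually unavailable):

# Simplified stand-in for the DummyObject/requires_backends pair (illustration only;
# the real implementation checks availability instead of always raising).
class DummyObject(type):
    def __getattr__(cls, name):
        if name.startswith("_"):
            raise AttributeError(name)
        requires_backends(cls, cls._backends)


def requires_backends(obj, backends):
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")


class LMSDiscreteScheduler(metaclass=DummyObject):
    _backends = ["torch", "scipy"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, self._backends)


try:
    LMSDiscreteScheduler()  # instantiation fails with a clear message
except ImportError as err:
    print(err)  # LMSDiscreteScheduler requires the following backends: torch, scipy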
-from ..utils import DummyObject, requires_backends - - -class LMSDiscreteScheduler(metaclass=DummyObject): - _backends = ["torch", "scipy"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "scipy"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "scipy"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "scipy"]) diff --git a/my_diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py b/my_diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py deleted file mode 100644 index 56836f0b6d77b8daa25e956101694863e418339f..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py +++ /dev/null @@ -1,17 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -from ..utils import DummyObject, requires_backends - - -class StableDiffusionKDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers", "k_diffusion"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers", "k_diffusion"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "k_diffusion"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "k_diffusion"]) diff --git a/my_diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py b/my_diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py deleted file mode 100644 index 204500a1f195790aabf4a0136de0f0900faec5c9..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py +++ /dev/null @@ -1,77 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. 
-from ..utils import DummyObject, requires_backends - - -class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers", "onnx"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers", "onnx"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "onnx"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "onnx"]) - - -class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers", "onnx"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers", "onnx"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "onnx"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "onnx"]) - - -class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): - _backends = ["torch", "transformers", "onnx"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers", "onnx"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "onnx"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "onnx"]) - - -class OnnxStableDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers", "onnx"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers", "onnx"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "onnx"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "onnx"]) - - -class StableDiffusionOnnxPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers", "onnx"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers", "onnx"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "onnx"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "onnx"]) diff --git a/my_diffusers/utils/dummy_torch_and_transformers_objects.py b/my_diffusers/utils/dummy_torch_and_transformers_objects.py deleted file mode 100644 index 1b0f812ad16cea87758a38d71c48ced83dc94294..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/dummy_torch_and_transformers_objects.py +++ /dev/null @@ -1,452 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. 
-from ..utils import DummyObject, requires_backends - - -class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class AltDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class CycleDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class LDMTextToImagePipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class PaintByExamplePipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class SemanticStableDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionControlNetPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, 
*args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionImageVariationPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionInpaintPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionInstructPix2PixPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionLatentUpscalePipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionPanoramaPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - 
requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionPipelineSafe(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionSAGPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableDiffusionUpscalePipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableUnCLIPImg2ImgPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class StableUnCLIPPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class UnCLIPImageVariationPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class UnCLIPPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class 
VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class VersatileDiffusionImageVariationPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class VersatileDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class VersatileDiffusionTextToImagePipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - -class VQDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) diff --git a/my_diffusers/utils/dynamic_modules_utils.py b/my_diffusers/utils/dynamic_modules_utils.py deleted file mode 100644 index 1951c4fa2623b6b14b85c035395a738cdd733eea..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/dynamic_modules_utils.py +++ /dev/null @@ -1,456 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities to dynamically load objects from the Hub.""" - -import importlib -import inspect -import json -import os -import re -import shutil -import sys -from distutils.version import StrictVersion -from pathlib import Path -from typing import Dict, Optional, Union -from urllib import request - -from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info - -from .. import __version__ -from . 
import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging - - -COMMUNITY_PIPELINES_URL = ( - "https://raw.githubusercontent.com/huggingface/diffusers/{revision}/examples/community/{pipeline}.py" -) - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def get_diffusers_versions(): - url = "https://pypi.org/pypi/diffusers/json" - releases = json.loads(request.urlopen(url).read())["releases"].keys() - return sorted(releases, key=StrictVersion) - - -def init_hf_modules(): - """ - Creates the cache directory for modules with an init, and adds it to the Python path. - """ - # This function has already been executed if HF_MODULES_CACHE already is in the Python path. - if HF_MODULES_CACHE in sys.path: - return - - sys.path.append(HF_MODULES_CACHE) - os.makedirs(HF_MODULES_CACHE, exist_ok=True) - init_path = Path(HF_MODULES_CACHE) / "__init__.py" - if not init_path.exists(): - init_path.touch() - - -def create_dynamic_module(name: Union[str, os.PathLike]): - """ - Creates a dynamic module in the cache directory for modules. - """ - init_hf_modules() - dynamic_module_path = Path(HF_MODULES_CACHE) / name - # If the parent module does not exist yet, recursively create it. - if not dynamic_module_path.parent.exists(): - create_dynamic_module(dynamic_module_path.parent) - os.makedirs(dynamic_module_path, exist_ok=True) - init_path = dynamic_module_path / "__init__.py" - if not init_path.exists(): - init_path.touch() - - -def get_relative_imports(module_file): - """ - Get the list of modules that are relatively imported in a module file. - - Args: - module_file (`str` or `os.PathLike`): The module file to inspect. - """ - with open(module_file, "r", encoding="utf-8") as f: - content = f.read() - - # Imports of the form `import .xxx` - relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) - # Imports of the form `from .xxx import yyy` - relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) - # Unique-ify - return list(set(relative_imports)) - - -def get_relative_import_files(module_file): - """ - Get the list of all files that are needed for a given module. Note that this function recurses through the relative - imports (if a imports b and b imports c, it will return module files for b and c). - - Args: - module_file (`str` or `os.PathLike`): The module file to inspect. - """ - no_change = False - files_to_check = [module_file] - all_relative_imports = [] - - # Let's recurse through all relative imports - while not no_change: - new_imports = [] - for f in files_to_check: - new_imports.extend(get_relative_imports(f)) - - module_path = Path(module_file).parent - new_import_files = [str(module_path / m) for m in new_imports] - new_import_files = [f for f in new_import_files if f not in all_relative_imports] - files_to_check = [f"{f}.py" for f in new_import_files] - - no_change = len(new_import_files) == 0 - all_relative_imports.extend(files_to_check) - - return all_relative_imports - - -def check_imports(filename): - """ - Check if the current Python environment contains all the libraries that are imported in a file. 
- """ - with open(filename, "r", encoding="utf-8") as f: - content = f.read() - - # Imports of the form `import xxx` - imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) - # Imports of the form `from xxx import yyy` - imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) - # Only keep the top-level module - imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] - - # Unique-ify and test we got them all - imports = list(set(imports)) - missing_packages = [] - for imp in imports: - try: - importlib.import_module(imp) - except ImportError: - missing_packages.append(imp) - - if len(missing_packages) > 0: - raise ImportError( - "This modeling file requires the following packages that were not found in your environment: " - f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" - ) - - return get_relative_imports(filename) - - -def get_class_in_module(class_name, module_path): - """ - Import a module on the cache directory for modules and extract a class from it. - """ - module_path = module_path.replace(os.path.sep, ".") - module = importlib.import_module(module_path) - - if class_name is None: - return find_pipeline_class(module) - return getattr(module, class_name) - - -def find_pipeline_class(loaded_module): - """ - Retrieve pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class - inheriting from `DiffusionPipeline`. - """ - from ..pipelines import DiffusionPipeline - - cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass)) - - pipeline_class = None - for cls_name, cls in cls_members.items(): - if ( - cls_name != DiffusionPipeline.__name__ - and issubclass(cls, DiffusionPipeline) - and cls.__module__.split(".")[0] != "diffusers" - ): - if pipeline_class is not None: - raise ValueError( - f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:" - f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in" - f" {loaded_module}." - ) - pipeline_class = cls - - return pipeline_class - - -def get_cached_module_file( - pretrained_model_name_or_path: Union[str, os.PathLike], - module_file: str, - cache_dir: Optional[Union[str, os.PathLike]] = None, - force_download: bool = False, - resume_download: bool = False, - proxies: Optional[Dict[str, str]] = None, - use_auth_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - local_files_only: bool = False, -): - """ - Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached - Transformers module. - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - This can be either: - - - a string, the *model id* of a pretrained model configuration hosted inside a model repo on - huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced - under a user or organization name, like `dbmdz/bert-base-german-cased`. - - a path to a *directory* containing a configuration file saved using the - [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. - - module_file (`str`): - The name of the module file containing the class to look for. - cache_dir (`str` or `os.PathLike`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the standard - cache should not be used. 
- force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force to (re-)download the configuration files and override the cached versions if they - exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - local_files_only (`bool`, *optional*, defaults to `False`): - If `True`, will only try to load the tokenizer configuration from local files. - - - - You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private - or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). - - - - Returns: - `str`: The path to the module inside the cache. - """ - # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - - module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) - - if os.path.isfile(module_file_or_url): - resolved_module_file = module_file_or_url - submodule = "local" - elif pretrained_model_name_or_path.count("/") == 0: - available_versions = get_diffusers_versions() - # cut ".dev0" - latest_version = "v" + ".".join(__version__.split(".")[:3]) - - # retrieve github version that matches - if revision is None: - revision = latest_version if latest_version in available_versions else "main" - logger.info(f"Defaulting to latest_version: {revision}.") - elif revision in available_versions: - revision = f"v{revision}" - elif revision == "main": - revision = revision - else: - raise ValueError( - f"`custom_revision`: {revision} does not exist. Please make sure to choose one of" - f" {', '.join(available_versions + ['main'])}." 
- ) - - # community pipeline on GitHub - github_url = COMMUNITY_PIPELINES_URL.format(revision=revision, pipeline=pretrained_model_name_or_path) - try: - resolved_module_file = cached_download( - github_url, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=False, - ) - submodule = "git" - module_file = pretrained_model_name_or_path + ".py" - except EnvironmentError: - logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") - raise - else: - try: - # Load from URL or cache if already cached - resolved_module_file = hf_hub_download( - pretrained_model_name_or_path, - module_file, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - ) - submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/"))) - except EnvironmentError: - logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") - raise - - # Check we have all the requirements in our environment - modules_needed = check_imports(resolved_module_file) - - # Now we move the module inside our cached dynamic modules. - full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule - create_dynamic_module(full_submodule) - submodule_path = Path(HF_MODULES_CACHE) / full_submodule - if submodule == "local" or submodule == "git": - # We always copy local files (we could hash the file to see if there was a change, and give them the name of - # that hash, to only copy when there is a modification but it seems overkill for now). - # The only reason we do the copy is to avoid putting too many folders in sys.path. - shutil.copy(resolved_module_file, submodule_path / module_file) - for module_needed in modules_needed: - module_needed = f"{module_needed}.py" - shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) - else: - # Get the commit hash - # TODO: we will get this info in the etag soon, so retrieve it from there and not here. - if isinstance(use_auth_token, str): - token = use_auth_token - elif use_auth_token is True: - token = HfFolder.get_token() - else: - token = None - - commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha - - # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the - # benefit of versioning. 
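# Editor's note (illustrative, not part of the original file): for a hypothetical Hub repo
# "user/custom-pipeline" at a commit hash "abc123", the cached module ends up under roughly
#   <HF_MODULES_CACHE>/<DIFFUSERS_DYNAMIC_MODULE_NAME>/local/user--custom-pipeline/abc123/<module_file>
# so re-loading the same revision reuses the copy already on disk instead of downloading it again.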
- submodule_path = submodule_path / commit_hash - full_submodule = full_submodule + os.path.sep + commit_hash - create_dynamic_module(full_submodule) - - if not (submodule_path / module_file).exists(): - shutil.copy(resolved_module_file, submodule_path / module_file) - # Make sure we also have every file with relative - for module_needed in modules_needed: - if not (submodule_path / module_needed).exists(): - get_cached_module_file( - pretrained_model_name_or_path, - f"{module_needed}.py", - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - use_auth_token=use_auth_token, - revision=revision, - local_files_only=local_files_only, - ) - return os.path.join(full_submodule, module_file) - - -def get_class_from_dynamic_module( - pretrained_model_name_or_path: Union[str, os.PathLike], - module_file: str, - class_name: Optional[str] = None, - cache_dir: Optional[Union[str, os.PathLike]] = None, - force_download: bool = False, - resume_download: bool = False, - proxies: Optional[Dict[str, str]] = None, - use_auth_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - local_files_only: bool = False, - **kwargs, -): - """ - Extracts a class from a module file, present in the local folder or repository of a model. - - - - Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should - therefore only be called on trusted repos. - - - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - This can be either: - - - a string, the *model id* of a pretrained model configuration hosted inside a model repo on - huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced - under a user or organization name, like `dbmdz/bert-base-german-cased`. - - a path to a *directory* containing a configuration file saved using the - [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. - - module_file (`str`): - The name of the module file containing the class to look for. - class_name (`str`): - The name of the class to import in the module. - cache_dir (`str` or `os.PathLike`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the standard - cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force to (re-)download the configuration files and override the cached versions if they - exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. - use_auth_token (`str` or `bool`, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. 
- local_files_only (`bool`, *optional*, defaults to `False`): - If `True`, will only try to load the tokenizer configuration from local files. - - - - You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private - or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). - - - - Returns: - `type`: The class, dynamically imported from the module. - - Examples: - - ```python - # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this - # module. - cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel") - ```""" - # And lastly we get the class inside our newly created module - final_module = get_cached_module_file( - pretrained_model_name_or_path, - module_file, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - use_auth_token=use_auth_token, - revision=revision, - local_files_only=local_files_only, - ) - return get_class_in_module(class_name, final_module.replace(".py", "")) diff --git a/my_diffusers/utils/hub_utils.py b/my_diffusers/utils/hub_utils.py deleted file mode 100644 index 6ddbac3669abddb89911cb437bf54bb19b26360e..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/hub_utils.py +++ /dev/null @@ -1,201 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import os -import sys -import traceback -from pathlib import Path -from typing import Dict, Optional, Union -from uuid import uuid4 - -from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami -from huggingface_hub.utils import is_jinja_available - -from .. import __version__ -from .constants import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT -from .import_utils import ( - ENV_VARS_TRUE_VALUES, - _flax_version, - _jax_version, - _onnxruntime_version, - _torch_version, - is_flax_available, - is_onnx_available, - is_torch_available, -) -from .logging import get_logger - - -logger = get_logger(__name__) - - -MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md" -SESSION_ID = uuid4().hex -HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES -DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES -HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/" - - -def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: - """ - Formats a user-agent string with basic info about a request. 
- """ - ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}" - if DISABLE_TELEMETRY or HF_HUB_OFFLINE: - return ua + "; telemetry/off" - if is_torch_available(): - ua += f"; torch/{_torch_version}" - if is_flax_available(): - ua += f"; jax/{_jax_version}" - ua += f"; flax/{_flax_version}" - if is_onnx_available(): - ua += f"; onnxruntime/{_onnxruntime_version}" - # CI will set this value to True - if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES: - ua += "; is_ci/true" - if isinstance(user_agent, dict): - ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) - elif isinstance(user_agent, str): - ua += "; " + user_agent - return ua - - -def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): - if token is None: - token = HfFolder.get_token() - if organization is None: - username = whoami(token)["name"] - return f"{username}/{model_id}" - else: - return f"{organization}/{model_id}" - - -def create_model_card(args, model_name): - if not is_jinja_available(): - raise ValueError( - "Modelcard rendering is based on Jinja templates." - " Please make sure to have `jinja` installed before using `create_model_card`." - " To install it, please run `pip install Jinja2`." - ) - - if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: - return - - hub_token = args.hub_token if hasattr(args, "hub_token") else None - repo_name = get_full_repo_name(model_name, token=hub_token) - - model_card = ModelCard.from_template( - card_data=ModelCardData( # Card metadata object that will be converted to YAML block - language="en", - license="apache-2.0", - library_name="diffusers", - tags=[], - datasets=args.dataset_name, - metrics=[], - ), - template_path=MODEL_CARD_TEMPLATE_PATH, - model_name=model_name, - repo_name=repo_name, - dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None, - learning_rate=args.learning_rate, - train_batch_size=args.train_batch_size, - eval_batch_size=args.eval_batch_size, - gradient_accumulation_steps=( - args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None - ), - adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None, - adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None, - adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None, - adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None, - lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None, - lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None, - ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None, - ema_power=args.ema_power if hasattr(args, "ema_power") else None, - ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None, - mixed_precision=args.mixed_precision, - ) - - card_path = os.path.join(args.output_dir, "README.md") - model_card.save(card_path) - - -# Old default cache path, potentially to be migrated. -# This logic was more or less taken from `transformers`, with the following differences: -# - Diffusers doesn't use custom environment variables to specify the cache path. -# - There is no need to migrate the cache format, just move the files to the new location. 
-hf_cache_home = os.path.expanduser( - os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) -) -old_diffusers_cache = os.path.join(hf_cache_home, "diffusers") - - -def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None: - if new_cache_dir is None: - new_cache_dir = DIFFUSERS_CACHE - if old_cache_dir is None: - old_cache_dir = old_diffusers_cache - - old_cache_dir = Path(old_cache_dir).expanduser() - new_cache_dir = Path(new_cache_dir).expanduser() - for old_blob_path in old_cache_dir.glob("**/blobs/*"): # move file blob by blob - if old_blob_path.is_file() and not old_blob_path.is_symlink(): - new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir) - new_blob_path.parent.mkdir(parents=True, exist_ok=True) - os.replace(old_blob_path, new_blob_path) - try: - os.symlink(new_blob_path, old_blob_path) - except OSError: - logger.warning( - "Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded." - ) - # At this point, old_cache_dir contains symlinks to the new cache (it can still be used). - - -cache_version_file = os.path.join(DIFFUSERS_CACHE, "version_diffusers_cache.txt") -if not os.path.isfile(cache_version_file): - cache_version = 0 -else: - with open(cache_version_file) as f: - cache_version = int(f.read()) - -if cache_version < 1: - old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and len(os.listdir(old_diffusers_cache)) > 0 - if old_cache_is_not_empty: - logger.warning( - "The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your " - "existing cached models. This is a one-time operation, you can interrupt it or run it " - "later by calling `diffusers.utils.hub_utils.move_cache()`." - ) - try: - move_cache() - except Exception as e: - trace = "\n".join(traceback.format_tb(e.__traceback__)) - logger.error( - f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease " - "file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole " - "message and we will do our best to help." - ) - -if cache_version < 1: - try: - os.makedirs(DIFFUSERS_CACHE, exist_ok=True) - with open(cache_version_file, "w") as f: - f.write("1") - except Exception: - logger.warning( - f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure " - "the directory exists and can be written to." - ) diff --git a/my_diffusers/utils/import_utils.py b/my_diffusers/utils/import_utils.py deleted file mode 100644 index 38cca035fb31092c578370bb76650072e0d6ca6f..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/import_utils.py +++ /dev/null @@ -1,508 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Import utilities: Utilities related to imports and our lazy inits. 
-""" -import importlib.util -import operator as op -import os -import sys -from collections import OrderedDict -from typing import Union - -from huggingface_hub.utils import is_jinja_available # noqa: F401 -from packaging import version -from packaging.version import Version, parse - -from . import logging - - -# The package importlib_metadata is in a different place, depending on the python version. -if sys.version_info < (3, 8): - import importlib_metadata -else: - import importlib.metadata as importlib_metadata - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} -ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) - -USE_TF = os.environ.get("USE_TF", "AUTO").upper() -USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() -USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() -USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper() - -STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} - -_torch_version = "N/A" -if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: - _torch_available = importlib.util.find_spec("torch") is not None - if _torch_available: - try: - _torch_version = importlib_metadata.version("torch") - logger.info(f"PyTorch version {_torch_version} available.") - except importlib_metadata.PackageNotFoundError: - _torch_available = False -else: - logger.info("Disabling PyTorch because USE_TORCH is set") - _torch_available = False - - -_tf_version = "N/A" -if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: - _tf_available = importlib.util.find_spec("tensorflow") is not None - if _tf_available: - candidates = ( - "tensorflow", - "tensorflow-cpu", - "tensorflow-gpu", - "tf-nightly", - "tf-nightly-cpu", - "tf-nightly-gpu", - "intel-tensorflow", - "intel-tensorflow-avx512", - "tensorflow-rocm", - "tensorflow-macos", - "tensorflow-aarch64", - ) - _tf_version = None - # For the metadata, we have to look for both tensorflow and tensorflow-cpu - for pkg in candidates: - try: - _tf_version = importlib_metadata.version(pkg) - break - except importlib_metadata.PackageNotFoundError: - pass - _tf_available = _tf_version is not None - if _tf_available: - if version.parse(_tf_version) < version.parse("2"): - logger.info(f"TensorFlow found but with version {_tf_version}. 
Diffusers requires version 2 minimum.") - _tf_available = False - else: - logger.info(f"TensorFlow version {_tf_version} available.") -else: - logger.info("Disabling Tensorflow because USE_TORCH is set") - _tf_available = False - -_jax_version = "N/A" -_flax_version = "N/A" -if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: - _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None - if _flax_available: - try: - _jax_version = importlib_metadata.version("jax") - _flax_version = importlib_metadata.version("flax") - logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") - except importlib_metadata.PackageNotFoundError: - _flax_available = False -else: - _flax_available = False - -if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES: - _safetensors_available = importlib.util.find_spec("safetensors") is not None - if _safetensors_available: - try: - _safetensors_version = importlib_metadata.version("safetensors") - logger.info(f"Safetensors version {_safetensors_version} available.") - except importlib_metadata.PackageNotFoundError: - _safetensors_available = False -else: - logger.info("Disabling Safetensors because USE_TF is set") - _safetensors_available = False - -_transformers_available = importlib.util.find_spec("transformers") is not None -try: - _transformers_version = importlib_metadata.version("transformers") - logger.debug(f"Successfully imported transformers version {_transformers_version}") -except importlib_metadata.PackageNotFoundError: - _transformers_available = False - - -_inflect_available = importlib.util.find_spec("inflect") is not None -try: - _inflect_version = importlib_metadata.version("inflect") - logger.debug(f"Successfully imported inflect version {_inflect_version}") -except importlib_metadata.PackageNotFoundError: - _inflect_available = False - - -_unidecode_available = importlib.util.find_spec("unidecode") is not None -try: - _unidecode_version = importlib_metadata.version("unidecode") - logger.debug(f"Successfully imported unidecode version {_unidecode_version}") -except importlib_metadata.PackageNotFoundError: - _unidecode_available = False - - -_onnxruntime_version = "N/A" -_onnx_available = importlib.util.find_spec("onnxruntime") is not None -if _onnx_available: - candidates = ( - "onnxruntime", - "onnxruntime-gpu", - "onnxruntime-directml", - "onnxruntime-openvino", - "ort_nightly_directml", - ) - _onnxruntime_version = None - # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu - for pkg in candidates: - try: - _onnxruntime_version = importlib_metadata.version(pkg) - break - except importlib_metadata.PackageNotFoundError: - pass - _onnx_available = _onnxruntime_version is not None - if _onnx_available: - logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") - - -_scipy_available = importlib.util.find_spec("scipy") is not None -try: - _scipy_version = importlib_metadata.version("scipy") - logger.debug(f"Successfully imported scipy version {_scipy_version}") -except importlib_metadata.PackageNotFoundError: - _scipy_available = False - -_librosa_available = importlib.util.find_spec("librosa") is not None -try: - _librosa_version = importlib_metadata.version("librosa") - logger.debug(f"Successfully imported librosa version {_librosa_version}") -except importlib_metadata.PackageNotFoundError: - _librosa_available = False - -_accelerate_available = importlib.util.find_spec("accelerate") is not None -try: - _accelerate_version = 
importlib_metadata.version("accelerate") - logger.debug(f"Successfully imported accelerate version {_accelerate_version}") -except importlib_metadata.PackageNotFoundError: - _accelerate_available = False - -_xformers_available = importlib.util.find_spec("xformers") is not None -try: - _xformers_version = importlib_metadata.version("xformers") - if _torch_available: - import torch - - if version.Version(torch.__version__) < version.Version("1.12"): - raise ValueError("PyTorch should be >= 1.12") - logger.debug(f"Successfully imported xformers version {_xformers_version}") -except importlib_metadata.PackageNotFoundError: - _xformers_available = False - -_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None -try: - _k_diffusion_version = importlib_metadata.version("k_diffusion") - logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}") -except importlib_metadata.PackageNotFoundError: - _k_diffusion_available = False - -_wandb_available = importlib.util.find_spec("wandb") is not None -try: - _wandb_version = importlib_metadata.version("wandb") - logger.debug(f"Successfully imported wandb version {_wandb_version }") -except importlib_metadata.PackageNotFoundError: - _wandb_available = False - -_omegaconf_available = importlib.util.find_spec("omegaconf") is not None -try: - _omegaconf_version = importlib_metadata.version("omegaconf") - logger.debug(f"Successfully imported omegaconf version {_omegaconf_version}") -except importlib_metadata.PackageNotFoundError: - _omegaconf_available = False - -_tensorboard_available = importlib.util.find_spec("tensorboard") -try: - _tensorboard_version = importlib_metadata.version("tensorboard") - logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}") -except importlib_metadata.PackageNotFoundError: - _tensorboard_available = False - - -def is_torch_available(): - return _torch_available - - -def is_safetensors_available(): - return _safetensors_available - - -def is_tf_available(): - return _tf_available - - -def is_flax_available(): - return _flax_available - - -def is_transformers_available(): - return _transformers_available - - -def is_inflect_available(): - return _inflect_available - - -def is_unidecode_available(): - return _unidecode_available - - -def is_onnx_available(): - return _onnx_available - - -def is_scipy_available(): - return _scipy_available - - -def is_librosa_available(): - return _librosa_available - - -def is_xformers_available(): - return _xformers_available - - -def is_accelerate_available(): - return _accelerate_available - - -def is_k_diffusion_available(): - return _k_diffusion_available - - -def is_wandb_available(): - return _wandb_available - - -def is_omegaconf_available(): - return _omegaconf_available - - -def is_tensorboard_available(): - return _tensorboard_available - - -# docstyle-ignore -FLAX_IMPORT_ERROR = """ -{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the -installation page: https://github.com/google/flax and follow the ones that match your environment. -""" - -# docstyle-ignore -INFLECT_IMPORT_ERROR = """ -{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install -inflect` -""" - -# docstyle-ignore -PYTORCH_IMPORT_ERROR = """ -{0} requires the PyTorch library but it was not found in your environment. 
Checkout the instructions on the -installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. -""" - -# docstyle-ignore -ONNX_IMPORT_ERROR = """ -{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip -install onnxruntime` -""" - -# docstyle-ignore -SCIPY_IMPORT_ERROR = """ -{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install -scipy` -""" - -# docstyle-ignore -LIBROSA_IMPORT_ERROR = """ -{0} requires the librosa library but it was not found in your environment. Checkout the instructions on the -installation page: https://librosa.org/doc/latest/install.html and follow the ones that match your environment. -""" - -# docstyle-ignore -TRANSFORMERS_IMPORT_ERROR = """ -{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip -install transformers` -""" - -# docstyle-ignore -UNIDECODE_IMPORT_ERROR = """ -{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install -Unidecode` -""" - -# docstyle-ignore -K_DIFFUSION_IMPORT_ERROR = """ -{0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip -install k-diffusion` -""" - -# docstyle-ignore -WANDB_IMPORT_ERROR = """ -{0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip -install wandb` -""" - -# docstyle-ignore -OMEGACONF_IMPORT_ERROR = """ -{0} requires the omegaconf library but it was not found in your environment. You can install it with pip: `pip -install omegaconf` -""" - -# docstyle-ignore -TENSORBOARD_IMPORT_ERROR = """ -{0} requires the tensorboard library but it was not found in your environment. 
You can install it with pip: `pip -install tensorboard` -""" - -BACKENDS_MAPPING = OrderedDict( - [ - ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), - ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), - ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), - ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), - ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), - ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), - ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), - ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), - ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), - ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), - ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)), - ("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), - ] -) - - -def requires_backends(obj, backends): - if not isinstance(backends, (list, tuple)): - backends = [backends] - - name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ - checks = (BACKENDS_MAPPING[backend] for backend in backends) - failed = [msg.format(name) for available, msg in checks if not available()] - if failed: - raise ImportError("".join(failed)) - - if name in [ - "VersatileDiffusionTextToImagePipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionDualGuidedPipeline", - "StableDiffusionImageVariationPipeline", - "UnCLIPPipeline", - ] and is_transformers_version("<", "4.25.0"): - raise ImportError( - f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install" - " --upgrade transformers \n```" - ) - - if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version( - "<", "4.26.0" - ): - raise ImportError( - f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install" - " --upgrade transformers \n```" - ) - - -class DummyObject(type): - """ - Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by - `requires_backend` each time a user tries to access any method of that class. - """ - - def __getattr__(cls, key): - if key.startswith("_"): - return super().__getattr__(cls, key) - requires_backends(cls, cls._backends) - - -# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 -def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str): - """ - Args: - Compares a library version to some requirement using a given operation. - library_or_version (`str` or `packaging.version.Version`): - A library name or a version to check. - operation (`str`): - A string representation of an operator, such as `">"` or `"<="`. 
- requirement_version (`str`): - The version to compare the library version against - """ - if operation not in STR_OPERATION_TO_FUNC.keys(): - raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}") - operation = STR_OPERATION_TO_FUNC[operation] - if isinstance(library_or_version, str): - library_or_version = parse(importlib_metadata.version(library_or_version)) - return operation(library_or_version, parse(requirement_version)) - - -# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338 -def is_torch_version(operation: str, version: str): - """ - Args: - Compares the current PyTorch version to a given reference with an operation. - operation (`str`): - A string representation of an operator, such as `">"` or `"<="` - version (`str`): - A string version of PyTorch - """ - return compare_versions(parse(_torch_version), operation, version) - - -def is_transformers_version(operation: str, version: str): - """ - Args: - Compares the current Transformers version to a given reference with an operation. - operation (`str`): - A string representation of an operator, such as `">"` or `"<="` - version (`str`): - A version string - """ - if not _transformers_available: - return False - return compare_versions(parse(_transformers_version), operation, version) - - -def is_accelerate_version(operation: str, version: str): - """ - Args: - Compares the current Accelerate version to a given reference with an operation. - operation (`str`): - A string representation of an operator, such as `">"` or `"<="` - version (`str`): - A version string - """ - if not _accelerate_available: - return False - return compare_versions(parse(_accelerate_version), operation, version) - - -def is_k_diffusion_version(operation: str, version: str): - """ - Args: - Compares the current k-diffusion version to a given reference with an operation. - operation (`str`): - A string representation of an operator, such as `">"` or `"<="` - version (`str`): - A version string - """ - if not _k_diffusion_available: - return False - return compare_versions(parse(_k_diffusion_version), operation, version) - - -class OptionalDependencyNotAvailable(BaseException): - """An error indicating that an optional dependency of Diffusers was not found in the environment.""" diff --git a/my_diffusers/utils/logging.py b/my_diffusers/utils/logging.py deleted file mode 100644 index 3308d117e994d95d5dd7cb494d88512a61847fd6..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/logging.py +++ /dev/null @@ -1,342 +0,0 @@ -# coding=utf-8 -# Copyright 2023 Optuna, Hugging Face -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
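Editor's aside, not part of the diff itself: the `DummyObject` metaclass and `requires_backends` helper deleted in import_utils.py above cooperate so that merely touching a pipeline class whose optional backends are missing raises a readable ImportError instead of an opaque AttributeError. Below is a minimal, self-contained sketch of that pattern under stated assumptions; `FancyPipeline` and the reduced `_requires_backends` helper are placeholders for illustration, not objects from the library.

import importlib.util


def _requires_backends(obj, backends):
    # Reduced stand-in for requires_backends(): raise an ImportError listing missing packages.
    name = getattr(obj, "__name__", obj.__class__.__name__)
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        raise ImportError(f"{name} requires the following packages: {', '.join(missing)}")


class _DummyObject(type):
    # Any attribute lookup on a class using this metaclass re-checks the backends,
    # so users get the descriptive ImportError at access time.
    def __getattr__(cls, key):
        if key.startswith("_"):
            raise AttributeError(key)
        _requires_backends(cls, cls._backends)


class FancyPipeline(metaclass=_DummyObject):
    # Placeholder mirroring the generated dummy classes in dummy_torch_and_transformers_objects.py.
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        _requires_backends(self, self._backends)


# FancyPipeline() or FancyPipeline.from_pretrained(...) now raises a descriptive ImportError
# whenever torch or transformers is missing from the environment.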
-""" Logging utilities.""" - -import logging -import os -import sys -import threading -from logging import ( - CRITICAL, # NOQA - DEBUG, # NOQA - ERROR, # NOQA - FATAL, # NOQA - INFO, # NOQA - NOTSET, # NOQA - WARN, # NOQA - WARNING, # NOQA -) -from typing import Optional - -from tqdm import auto as tqdm_lib - - -_lock = threading.Lock() -_default_handler: Optional[logging.Handler] = None - -log_levels = { - "debug": logging.DEBUG, - "info": logging.INFO, - "warning": logging.WARNING, - "error": logging.ERROR, - "critical": logging.CRITICAL, -} - -_default_log_level = logging.WARNING - -_tqdm_active = True - - -def _get_default_logging_level(): - """ - If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is - not - fall back to `_default_log_level` - """ - env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None) - if env_level_str: - if env_level_str in log_levels: - return log_levels[env_level_str] - else: - logging.getLogger().warning( - f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, " - f"has to be one of: { ', '.join(log_levels.keys()) }" - ) - return _default_log_level - - -def _get_library_name() -> str: - return __name__.split(".")[0] - - -def _get_library_root_logger() -> logging.Logger: - return logging.getLogger(_get_library_name()) - - -def _configure_library_root_logger() -> None: - global _default_handler - - with _lock: - if _default_handler: - # This library has already configured the library root logger. - return - _default_handler = logging.StreamHandler() # Set sys.stderr as stream. - _default_handler.flush = sys.stderr.flush - - # Apply our default configuration to the library root logger. - library_root_logger = _get_library_root_logger() - library_root_logger.addHandler(_default_handler) - library_root_logger.setLevel(_get_default_logging_level()) - library_root_logger.propagate = False - - -def _reset_library_root_logger() -> None: - global _default_handler - - with _lock: - if not _default_handler: - return - - library_root_logger = _get_library_root_logger() - library_root_logger.removeHandler(_default_handler) - library_root_logger.setLevel(logging.NOTSET) - _default_handler = None - - -def get_log_levels_dict(): - return log_levels - - -def get_logger(name: Optional[str] = None) -> logging.Logger: - """ - Return a logger with the specified name. - - This function is not supposed to be directly accessed unless you are writing a custom diffusers module. - """ - - if name is None: - name = _get_library_name() - - _configure_library_root_logger() - return logging.getLogger(name) - - -def get_verbosity() -> int: - """ - Return the current level for the 🤗 Diffusers' root logger as an int. - - Returns: - `int`: The logging level. - - - - 🤗 Diffusers has following logging levels: - - - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` - - 40: `diffusers.logging.ERROR` - - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN` - - 20: `diffusers.logging.INFO` - - 10: `diffusers.logging.DEBUG` - - """ - - _configure_library_root_logger() - return _get_library_root_logger().getEffectiveLevel() - - -def set_verbosity(verbosity: int) -> None: - """ - Set the verbosity level for the 🤗 Diffusers' root logger. 
- - Args: - verbosity (`int`): - Logging level, e.g., one of: - - - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` - - `diffusers.logging.ERROR` - - `diffusers.logging.WARNING` or `diffusers.logging.WARN` - - `diffusers.logging.INFO` - - `diffusers.logging.DEBUG` - """ - - _configure_library_root_logger() - _get_library_root_logger().setLevel(verbosity) - - -def set_verbosity_info(): - """Set the verbosity to the `INFO` level.""" - return set_verbosity(INFO) - - -def set_verbosity_warning(): - """Set the verbosity to the `WARNING` level.""" - return set_verbosity(WARNING) - - -def set_verbosity_debug(): - """Set the verbosity to the `DEBUG` level.""" - return set_verbosity(DEBUG) - - -def set_verbosity_error(): - """Set the verbosity to the `ERROR` level.""" - return set_verbosity(ERROR) - - -def disable_default_handler() -> None: - """Disable the default handler of the HuggingFace Diffusers' root logger.""" - - _configure_library_root_logger() - - assert _default_handler is not None - _get_library_root_logger().removeHandler(_default_handler) - - -def enable_default_handler() -> None: - """Enable the default handler of the HuggingFace Diffusers' root logger.""" - - _configure_library_root_logger() - - assert _default_handler is not None - _get_library_root_logger().addHandler(_default_handler) - - -def add_handler(handler: logging.Handler) -> None: - """adds a handler to the HuggingFace Diffusers' root logger.""" - - _configure_library_root_logger() - - assert handler is not None - _get_library_root_logger().addHandler(handler) - - -def remove_handler(handler: logging.Handler) -> None: - """removes given handler from the HuggingFace Diffusers' root logger.""" - - _configure_library_root_logger() - - assert handler is not None and handler not in _get_library_root_logger().handlers - _get_library_root_logger().removeHandler(handler) - - -def disable_propagation() -> None: - """ - Disable propagation of the library log outputs. Note that log propagation is disabled by default. - """ - - _configure_library_root_logger() - _get_library_root_logger().propagate = False - - -def enable_propagation() -> None: - """ - Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent - double logging if the root logger has been configured. - """ - - _configure_library_root_logger() - _get_library_root_logger().propagate = True - - -def enable_explicit_format() -> None: - """ - Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows: - ``` - [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE - ``` - All handlers currently bound to the root logger are affected by this method. - """ - handlers = _get_library_root_logger().handlers - - for handler in handlers: - formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") - handler.setFormatter(formatter) - - -def reset_format() -> None: - """ - Resets the formatting for HuggingFace Diffusers' loggers. - - All handlers currently bound to the root logger are affected by this method. 
- """ - handlers = _get_library_root_logger().handlers - - for handler in handlers: - handler.setFormatter(None) - - -def warning_advice(self, *args, **kwargs): - """ - This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this - warning will not be printed - """ - no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False) - if no_advisory_warnings: - return - self.warning(*args, **kwargs) - - -logging.Logger.warning_advice = warning_advice - - -class EmptyTqdm: - """Dummy tqdm which doesn't do anything.""" - - def __init__(self, *args, **kwargs): # pylint: disable=unused-argument - self._iterator = args[0] if args else None - - def __iter__(self): - return iter(self._iterator) - - def __getattr__(self, _): - """Return empty function.""" - - def empty_fn(*args, **kwargs): # pylint: disable=unused-argument - return - - return empty_fn - - def __enter__(self): - return self - - def __exit__(self, type_, value, traceback): - return - - -class _tqdm_cls: - def __call__(self, *args, **kwargs): - if _tqdm_active: - return tqdm_lib.tqdm(*args, **kwargs) - else: - return EmptyTqdm(*args, **kwargs) - - def set_lock(self, *args, **kwargs): - self._lock = None - if _tqdm_active: - return tqdm_lib.tqdm.set_lock(*args, **kwargs) - - def get_lock(self): - if _tqdm_active: - return tqdm_lib.tqdm.get_lock() - - -tqdm = _tqdm_cls() - - -def is_progress_bar_enabled() -> bool: - """Return a boolean indicating whether tqdm progress bars are enabled.""" - global _tqdm_active - return bool(_tqdm_active) - - -def enable_progress_bar(): - """Enable tqdm progress bar.""" - global _tqdm_active - _tqdm_active = True - - -def disable_progress_bar(): - """Disable tqdm progress bar.""" - global _tqdm_active - _tqdm_active = False diff --git a/my_diffusers/utils/model_card_template.md b/my_diffusers/utils/model_card_template.md deleted file mode 100644 index f19c85b0fcf2f7b07e9c3f950a9657b3f2053f21..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/model_card_template.md +++ /dev/null @@ -1,50 +0,0 @@ ---- -{{ card_data }} ---- - - - -# {{ model_name | default("Diffusion Model") }} - -## Model description - -This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library -on the `{{ dataset_name }}` dataset. 
- -## Intended uses & limitations - -#### How to use - -```python -# TODO: add an example code snippet for running this diffusion pipeline -``` - -#### Limitations and bias - -[TODO: provide examples of latent issues and potential remediations] - -## Training data - -[TODO: describe the data used to train the model] - -### Training hyperparameters - -The following hyperparameters were used during training: -- learning_rate: {{ learning_rate }} -- train_batch_size: {{ train_batch_size }} -- eval_batch_size: {{ eval_batch_size }} -- gradient_accumulation_steps: {{ gradient_accumulation_steps }} -- optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }} -- lr_scheduler: {{ lr_scheduler }} -- lr_warmup_steps: {{ lr_warmup_steps }} -- ema_inv_gamma: {{ ema_inv_gamma }} -- ema_inv_gamma: {{ ema_power }} -- ema_inv_gamma: {{ ema_max_decay }} -- mixed_precision: {{ mixed_precision }} - -### Training results - -📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars) - - diff --git a/my_diffusers/utils/outputs.py b/my_diffusers/utils/outputs.py deleted file mode 100644 index f91a49b7a8a7c681ab51b8750f19cdaaf1600d79..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/outputs.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Generic utilities -""" - -from collections import OrderedDict -from dataclasses import fields -from typing import Any, Tuple - -import numpy as np - -from .import_utils import is_torch_available - - -def is_tensor(x): - """ - Tests if `x` is a `torch.Tensor` or `np.ndarray`. - """ - if is_torch_available(): - import torch - - if isinstance(x, torch.Tensor): - return True - - return isinstance(x, np.ndarray) - - -class BaseOutput(OrderedDict): - """ - Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a - tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular - python dictionary. - - - - You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple - before. 
- - - """ - - def __post_init__(self): - class_fields = fields(self) - - # Safety and consistency checks - if not len(class_fields): - raise ValueError(f"{self.__class__.__name__} has no fields.") - - first_field = getattr(self, class_fields[0].name) - other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) - - if other_fields_are_none and isinstance(first_field, dict): - for key, value in first_field.items(): - self[key] = value - else: - for field in class_fields: - v = getattr(self, field.name) - if v is not None: - self[field.name] = v - - def __delitem__(self, *args, **kwargs): - raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") - - def setdefault(self, *args, **kwargs): - raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") - - def pop(self, *args, **kwargs): - raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") - - def update(self, *args, **kwargs): - raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") - - def __getitem__(self, k): - if isinstance(k, str): - inner_dict = {k: v for (k, v) in self.items()} - return inner_dict[k] - else: - return self.to_tuple()[k] - - def __setattr__(self, name, value): - if name in self.keys() and value is not None: - # Don't call self.__setitem__ to avoid recursion errors - super().__setitem__(name, value) - super().__setattr__(name, value) - - def __setitem__(self, key, value): - # Will raise a KeyException if needed - super().__setitem__(key, value) - # Don't call self.__setattr__ to avoid recursion errors - super().__setattr__(key, value) - - def to_tuple(self) -> Tuple[Any]: - """ - Convert self to a tuple containing all the attributes/keys that are not `None`. 
- """ - return tuple(self[k] for k in self.keys()) diff --git a/my_diffusers/utils/pil_utils.py b/my_diffusers/utils/pil_utils.py deleted file mode 100644 index 39d0a15a4e2fe39fecb01951b36c43368492f983..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/pil_utils.py +++ /dev/null @@ -1,21 +0,0 @@ -import PIL.Image -import PIL.ImageOps -from packaging import version - - -if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): - PIL_INTERPOLATION = { - "linear": PIL.Image.Resampling.BILINEAR, - "bilinear": PIL.Image.Resampling.BILINEAR, - "bicubic": PIL.Image.Resampling.BICUBIC, - "lanczos": PIL.Image.Resampling.LANCZOS, - "nearest": PIL.Image.Resampling.NEAREST, - } -else: - PIL_INTERPOLATION = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - "nearest": PIL.Image.NEAREST, - } diff --git a/my_diffusers/utils/testing_utils.py b/my_diffusers/utils/testing_utils.py deleted file mode 100644 index 3727d2dbf97eb0fd2c091b02f5468609b7819264..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/testing_utils.py +++ /dev/null @@ -1,447 +0,0 @@ -import inspect -import logging -import os -import random -import re -import unittest -import urllib.parse -from distutils.util import strtobool -from io import BytesIO, StringIO -from pathlib import Path -from typing import Optional, Union - -import numpy as np -import PIL.Image -import PIL.ImageOps -import requests -from packaging import version - -from .import_utils import is_flax_available, is_onnx_available, is_torch_available -from .logging import get_logger - - -global_rng = random.Random() - -logger = get_logger(__name__) - -if is_torch_available(): - import torch - - if "DIFFUSERS_TEST_DEVICE" in os.environ: - torch_device = os.environ["DIFFUSERS_TEST_DEVICE"] - - available_backends = ["cuda", "cpu", "mps"] - if torch_device not in available_backends: - raise ValueError( - f"unknown torch backend for diffusers tests: {torch_device}. Available backends are:" - f" {available_backends}" - ) - logger.info(f"torch_device overrode to {torch_device}") - else: - torch_device = "cuda" if torch.cuda.is_available() else "cpu" - is_torch_higher_equal_than_1_12 = version.parse( - version.parse(torch.__version__).base_version - ) >= version.parse("1.12") - - if is_torch_higher_equal_than_1_12: - # Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details - mps_backend_registered = hasattr(torch.backends, "mps") - torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device - - -def torch_all_close(a, b, *args, **kwargs): - if not is_torch_available(): - raise ValueError("PyTorch needs to be installed to use this function.") - if not torch.allclose(a, b, *args, **kwargs): - assert False, f"Max diff is absolute {(a - b).abs().max()}. Diff tensor is {(a - b).abs()}." 
- return True - - -def print_tensor_test(tensor, filename="test_corrections.txt", expected_tensor_name="expected_slice"): - test_name = os.environ.get("PYTEST_CURRENT_TEST") - if not torch.is_tensor(tensor): - tensor = torch.from_numpy(tensor) - - tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "") - # format is usually: - # expected_slice = np.array([-0.5713, -0.3018, -0.9814, 0.04663, -0.879, 0.76, -1.734, 0.1044, 1.161]) - output_str = tensor_str.replace("tensor", f"{expected_tensor_name} = np.array") - test_file, test_class, test_fn = test_name.split("::") - test_fn = test_fn.split()[0] - with open(filename, "a") as f: - print(";".join([test_file, test_class, test_fn, output_str]), file=f) - - -def get_tests_dir(append_path=None): - """ - Args: - append_path: optional path to append to the tests dir path - Returns: - The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is - joined after the `tests` dir if the former is provided. - """ - # this function caller's __file__ - caller__file__ = inspect.stack()[1][1] - tests_dir = os.path.abspath(os.path.dirname(caller__file__)) - - while not tests_dir.endswith("tests"): - tests_dir = os.path.dirname(tests_dir) - - if append_path: - return os.path.join(tests_dir, append_path) - else: - return tests_dir - - -def parse_flag_from_env(key, default=False): - try: - value = os.environ[key] - except KeyError: - # KEY isn't set, default to `default`. - _value = default - else: - # KEY is set, convert it to True or False. - try: - _value = strtobool(value) - except ValueError: - # More values are supported, but let's keep the message simple. - raise ValueError(f"If set, {key} must be yes or no.") - return _value - - -_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) -_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False) - - -def floats_tensor(shape, scale=1.0, rng=None, name=None): - """Creates a random float32 tensor""" - if rng is None: - rng = global_rng - - total_dims = 1 - for dim in shape: - total_dims *= dim - - values = [] - for _ in range(total_dims): - values.append(rng.random() * scale) - - return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() - - -def slow(test_case): - """ - Decorator marking a test as slow. - - Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. - - """ - return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) - - -def nightly(test_case): - """ - Decorator marking a test that runs nightly in the diffusers CI. - - Nightly tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them. - - """ - return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case) - - -def require_torch(test_case): - """ - Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
- """ - return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) - - -def require_torch_gpu(test_case): - """Decorator marking a test that requires CUDA and PyTorch.""" - return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")( - test_case - ) - - -def skip_mps(test_case): - """Decorator marking a test to skip if torch_device is 'mps'""" - return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case) - - -def require_flax(test_case): - """ - Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed - """ - return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case) - - -def require_onnxruntime(test_case): - """ - Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed. - """ - return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case) - - -def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray: - if isinstance(arry, str): - # local_path = "/home/patrick_huggingface_co/" - if local_path is not None: - # local_path can be passed to correct images of tests - return os.path.join(local_path, "/".join([arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]])) - elif arry.startswith("http://") or arry.startswith("https://"): - response = requests.get(arry) - response.raise_for_status() - arry = np.load(BytesIO(response.content)) - elif os.path.isfile(arry): - arry = np.load(arry) - else: - raise ValueError( - f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path" - ) - elif isinstance(arry, np.ndarray): - pass - else: - raise ValueError( - "Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a" - " ndarray." - ) - - return arry - - -def load_pt(url: str): - response = requests.get(url) - response.raise_for_status() - arry = torch.load(BytesIO(response.content)) - return arry - - -def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: - """ - Args: - Loads `image` to a PIL Image. - image (`str` or `PIL.Image.Image`): - The image to convert to the PIL Image format. - Returns: - `PIL.Image.Image`: A PIL Image. - """ - if isinstance(image, str): - if image.startswith("http://") or image.startswith("https://"): - image = PIL.Image.open(requests.get(image, stream=True).raw) - elif os.path.isfile(image): - image = PIL.Image.open(image) - else: - raise ValueError( - f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path" - ) - elif isinstance(image, PIL.Image.Image): - image = image - else: - raise ValueError( - "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image." 
- ) - image = PIL.ImageOps.exif_transpose(image) - image = image.convert("RGB") - return image - - -def load_hf_numpy(path) -> np.ndarray: - if not (path.startswith("http://") or path.startswith("https://")): - path = os.path.join( - "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path) - ) - - return load_numpy(path) - - -# --- pytest conf functions --- # - -# to avoid multiple invocations from tests/conftest.py and examples/conftest.py - make sure it's called only once -pytest_opt_registered = {} - - -def pytest_addoption_shared(parser): - """ - This function is to be called from `conftest.py` via a `pytest_addoption` wrapper that has to be defined there. - - It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest` - option. - - """ - option = "--make-reports" - if option not in pytest_opt_registered: - parser.addoption( - option, - action="store", - default=False, - help="generate report files. The value of this option is used as a prefix to report names", - ) - pytest_opt_registered[option] = 1 - - -def pytest_terminal_summary_main(tr, id): - """ - Generate multiple reports at the end of the test suite run - each report goes into a dedicated file in the current - directory. The report files are prefixed with the test suite name. - - This function emulates the --durations and -rA pytest arguments. - - This function is to be called from `conftest.py` via a `pytest_terminal_summary` wrapper that has to be defined - there. - - Args: - - tr: `terminalreporter` passed from `conftest.py` - - id: unique id like `tests` or `examples` that will be incorporated into the final report filenames - this is - needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other. - - NB: this function taps into a private _pytest API and, while unlikely, it could break should - pytest make internal changes - it also calls default internal methods of terminalreporter which - can be hijacked by various `pytest-` plugins and interfere.
- - """ - from _pytest.config import create_terminal_writer - - if not len(id): - id = "tests" - - config = tr.config - orig_writer = config.get_terminal_writer() - orig_tbstyle = config.option.tbstyle - orig_reportchars = tr.reportchars - - dir = "reports" - Path(dir).mkdir(parents=True, exist_ok=True) - report_files = { - k: f"{dir}/{id}_{k}.txt" - for k in [ - "durations", - "errors", - "failures_long", - "failures_short", - "failures_line", - "passes", - "stats", - "summary_short", - "warnings", - ] - } - - # custom durations report - # note: there is no need to call pytest --durations=XX to get this separate report - # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66 - dlist = [] - for replist in tr.stats.values(): - for rep in replist: - if hasattr(rep, "duration"): - dlist.append(rep) - if dlist: - dlist.sort(key=lambda x: x.duration, reverse=True) - with open(report_files["durations"], "w") as f: - durations_min = 0.05 # sec - f.write("slowest durations\n") - for i, rep in enumerate(dlist): - if rep.duration < durations_min: - f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted") - break - f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n") - - def summary_failures_short(tr): - # expecting that the reports were --tb=long (default) so we chop them off here to the last frame - reports = tr.getreports("failed") - if not reports: - return - tr.write_sep("=", "FAILURES SHORT STACK") - for rep in reports: - msg = tr._getfailureheadline(rep) - tr.write_sep("_", msg, red=True, bold=True) - # chop off the optional leading extra frames, leaving only the last one - longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S) - tr._tw.line(longrepr) - # note: not printing out any rep.sections to keep the report short - - # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each - # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814 - # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g. 
- pytest-instafail does that) - - # report failures with line/short/long styles - config.option.tbstyle = "auto" # full tb - with open(report_files["failures_long"], "w") as f: - tr._tw = create_terminal_writer(config, f) - tr.summary_failures() - - # config.option.tbstyle = "short" # short tb - with open(report_files["failures_short"], "w") as f: - tr._tw = create_terminal_writer(config, f) - summary_failures_short(tr) - - config.option.tbstyle = "line" # one line per error - with open(report_files["failures_line"], "w") as f: - tr._tw = create_terminal_writer(config, f) - tr.summary_failures() - - with open(report_files["errors"], "w") as f: - tr._tw = create_terminal_writer(config, f) - tr.summary_errors() - - with open(report_files["warnings"], "w") as f: - tr._tw = create_terminal_writer(config, f) - tr.summary_warnings() # normal warnings - tr.summary_warnings() # final warnings - - tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary()) - with open(report_files["passes"], "w") as f: - tr._tw = create_terminal_writer(config, f) - tr.summary_passes() - - with open(report_files["summary_short"], "w") as f: - tr._tw = create_terminal_writer(config, f) - tr.short_test_summary() - - with open(report_files["stats"], "w") as f: - tr._tw = create_terminal_writer(config, f) - tr.summary_stats() - - # restore: - tr._tw = orig_writer - tr.reportchars = orig_reportchars - config.option.tbstyle = orig_tbstyle - - -class CaptureLogger: - """ - Context manager to capture `logging` streams. - - Args: - logger: `logging` logger object - Returns: - The captured output is available via `self.out` - Example: - ```python - >>> from diffusers import logging - >>> from diffusers.testing_utils import CaptureLogger - - >>> msg = "Testing 1, 2, 3" - >>> logging.set_verbosity_info() - >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py") - >>> with CaptureLogger(logger) as cl: - ... logger.info(msg) - >>> assert cl.out == msg + "\n" - ``` - """ - - def __init__(self, logger): - self.logger = logger - self.io = StringIO() - self.sh = logging.StreamHandler(self.io) - self.out = "" - - def __enter__(self): - self.logger.addHandler(self.sh) - return self - - def __exit__(self, *exc): - self.logger.removeHandler(self.sh) - self.out = self.io.getvalue() - - def __repr__(self): - return f"captured: {self.out}\n" diff --git a/my_diffusers/utils/torch_utils.py b/my_diffusers/utils/torch_utils.py deleted file mode 100644 index 113e64c16bac62e86e72dd92017a5e96adea7b40..0000000000000000000000000000000000000000 --- a/my_diffusers/utils/torch_utils.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -PyTorch utilities: Utilities related to PyTorch -""" -from typing import List, Optional, Tuple, Union - -from . 
import logging -from .import_utils import is_torch_available - - -if is_torch_available(): - import torch - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def randn_tensor( - shape: Union[Tuple, List], - generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, - device: Optional["torch.device"] = None, - dtype: Optional["torch.dtype"] = None, - layout: Optional["torch.layout"] = None, -): - """A helper function to create random tensors on the desired `device` with the desired `dtype`. When passing - a list of generators, each batch element can be seeded individually. If CPU generators are passed, the tensors - are created on CPU and then moved to `device`. - """ - # device on which the tensor is created defaults to the requested device - rand_device = device - batch_size = shape[0] - - layout = layout or torch.strided - device = device or torch.device("cpu") - - if generator is not None: - gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type - if gen_device_type != device.type and gen_device_type == "cpu": - rand_device = "cpu" - if device != "mps": - logger.info( - f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." - f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" - f" slightly speed up this function by passing a generator that was created on the {device} device." - ) - elif gen_device_type != device.type and gen_device_type == "cuda": - raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") - - if isinstance(generator, list): - shape = (1,) + shape[1:] - latents = [ - torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) - for i in range(batch_size) - ] - latents = torch.cat(latents, dim=0).to(device) - else: - latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) - - return latents
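
For reference, a minimal usage sketch of the `randn_tensor` helper removed in the hunk above. It assumes the pre-deletion tree, where `my_diffusers/utils/torch_utils.py` still exists (upstream `diffusers` ships an equivalent helper); the shape, seeds, and device below are illustrative only, not taken from this repository.

```python
import torch

# Assumption: importing from the pre-deletion package layout shown in the diff above.
from my_diffusers.utils.torch_utils import randn_tensor

# One CPU generator per batch element, so every sample is seeded independently.
# This exercises the list-of-generators branch of randn_tensor, which draws each
# sample as a (1, ...) tensor and concatenates them along the batch dimension.
generators = [torch.Generator(device="cpu").manual_seed(seed) for seed in (0, 1, 2, 3)]

latents = randn_tensor(
    shape=(4, 4, 64, 64),  # (batch, channels, height, width); batch must match len(generators)
    generator=generators,
    device=torch.device("cpu"),
    dtype=torch.float32,
)
assert latents.shape == (4, 4, 64, 64)
```

With a single `torch.Generator` instead of a list, the tensor is drawn in one `torch.randn` call and then moved to the requested device, as in the `else` branch above.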