# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # Based on [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111). # Authors: Jiatao Gu, Shuangfei Zhai, Yizhe Zhang, Josh Susskind, Navdeep Jaitly # Code: https://github.com/apple/ml-mdm with MIT license # # Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz). import gc import inspect import math from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint from packaging import version from PIL import Image from torch import nn from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.configuration_utils import ConfigMixin, FrozenDict, LegacyConfigMixin, register_to_config from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import ( FromSingleFileMixin, IPAdapterMixin, PeftAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin, UNet2DConditionLoadersMixin, ) from diffusers.loaders.single_file_model import FromOriginalModelMixin from diffusers.models.activations import GELU, get_activation from diffusers.models.attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, FusedAttnProcessor2_0, ) from diffusers.models.downsampling import Downsample2D from diffusers.models.embeddings import ( GaussianFourierProjection, GLIGENTextBoundingboxProjection, ImageHintTimeEmbedding, ImageProjection, ImageTimeEmbedding, TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps, ) from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.models.modeling_utils import LegacyModelMixin, ModelMixin from diffusers.models.resnet import ResnetBlock2D from diffusers.models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D from diffusers.models.upsampling import Upsample2D from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils import ( USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from diffusers.utils.torch_utils import apply_freeu, randn_tensor if is_torch_xla_available(): import torch_xla.core.xla_model as xm # type: ignore XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import DiffusionPipeline >>> from diffusers.utils import make_image_grid >>> # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64 >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models", >>> custom_pipeline="matryoshka").to("cuda") >>> prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree" >>> prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed" >>> negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy" >>> image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images >>> make_image_grid(image, rows=1, cols=len(image)) >>> pipe.change_nesting_level() # 0, 1, or 2 >>> # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively. ``` """ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs, ): """ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. Args: scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. sigmas (`List[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps # Copied from diffusers.models.attention._chunked_feed_forward def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): # "feed_forward_chunk_size" can be used to save memory if hidden_states.shape[chunk_dim] % chunk_size != 0: raise ValueError( f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." ) num_chunks = hidden_states.shape[chunk_dim] // chunk_size ff_output = torch.cat( [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], dim=chunk_dim, ) return ff_output @dataclass class MatryoshkaDDIMSchedulerOutput(BaseOutput): """ Output class for the scheduler's `step` function output. Args: prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample `(x_{0})` based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: Union[torch.Tensor, List[torch.Tensor]] pred_original_sample: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) else: raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. Returns: `torch.Tensor`: rescaled betas with zero terminal SNR """ # Convert betas to alphas_bar_sqrt alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_bar_sqrt = alphas_cumprod.sqrt() # Store old values. alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() # Shift so the last timestep is zero. alphas_bar_sqrt -= alphas_bar_sqrt_T # Scale so the first timestep is back to the old value. alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) # Convert alphas_bar_sqrt to betas alphas_bar = alphas_bar_sqrt**2 # Revert sqrt alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod alphas = torch.cat([alphas_bar[0:1], alphas]) betas = 1 - alphas return betas class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin): """ `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. beta_start (`float`, defaults to 0.0001): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. beta_schedule (`str`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. clip_sample (`bool`, defaults to `True`): Clip the predicted sample for numerical stability. clip_sample_range (`float`, defaults to 1.0): The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. set_alpha_to_one (`bool`, defaults to `True`): Each diffusion step uses the alphas product value at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the alpha value at step 0. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper). thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. dynamic_thresholding_ratio (`float`, defaults to 0.995): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True`. timestep_spacing (`str`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, timestep_spacing: str = "leading", rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 elif beta_schedule == "squaredcos_cap_v2": if self.config.timestep_spacing == "matryoshka_style": self.betas = torch.cat((torch.tensor([0]), betas_for_alpha_bar(num_train_timesteps))) else: # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") # Rescale for zero SNR if rescale_betas_zero_snr: self.betas = rescale_zero_terminal_snr(self.betas) self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) self.scales = None self.schedule_shifted_power = 1.0 def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: sample (`torch.Tensor`): The input sample. timestep (`int`, *optional*): The current timestep in the diffusion chain. Returns: `torch.Tensor`: A scaled input sample. """ return sample def _get_variance(self, timestep, prev_timestep): alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) return variance # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing pixels from saturation at each step. We find that dynamic thresholding results in significantly better photorealism as well as better image-text alignment, especially when using very large guidance weights." https://arxiv.org/abs/2205.11487 """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) abs_sample = sample.abs() # "a certain percentile absolute pixel value" s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) s = torch.clamp( s, min=1, max=self.config.sample_max_value ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" sample = sample.reshape(batch_size, channels, *remaining_dims) sample = sample.to(dtype) return sample def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. """ if num_inference_steps > self.config.num_train_timesteps: raise ValueError( f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" f" maximal {self.config.num_train_timesteps} timesteps." ) self.num_inference_steps = num_inference_steps # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) .round()[::-1] .copy() .astype(np.int64) ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) timesteps -= 1 elif self.config.timestep_spacing == "matryoshka_style": step_ratio = (self.config.num_train_timesteps + 1) / (num_inference_steps + 1) timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1].copy().astype(np.int64) else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." ) self.timesteps = torch.from_numpy(timesteps).to(device) def get_schedule_shifted(self, alpha_prod, scale_factor=None): if (scale_factor is not None) and (scale_factor > 1): # rescale noise schedule scale_factor = scale_factor ** self.schedule_shifted_power snr = alpha_prod / (1 - alpha_prod) scaled_snr = snr / scale_factor alpha_prod = 1 / (1 + 1 / scaled_snr) return alpha_prod def step( self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, generator=None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[MatryoshkaDDIMSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. timestep (`float`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. eta (`float`): The weight of noise for added noise in diffusion step. use_clipped_model_output (`bool`, defaults to `False`): If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would coincide with the one provided as input and `use_clipped_model_output` has no effect. generator (`torch.Generator`, *optional*): A random number generator. variance_noise (`torch.Tensor`): Alternative to generating noise with `generator` by directly providing the noise for the variance itself. Useful for methods such as [`CycleDiffusion`]. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. Returns: [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding # Notation ( -> # - pred_noise_t -> e_theta(x_t, t) # - pred_original_sample -> f_theta(x_t, t) or x_0 # - std_dev_t -> sigma_t # - eta -> η # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" # 1. get previous step value (=t-1) if self.config.timestep_spacing != "matryoshka_style": prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps else: prev_timestep = self.timesteps[torch.nonzero(self.timesteps == timestep).item() + 1] # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod if self.config.timestep_spacing == "matryoshka_style" and len(model_output) > 1: alpha_prod_t = torch.tensor([self.get_schedule_shifted(alpha_prod_t, s) for s in self.scales]) alpha_prod_t_prev = torch.tensor([self.get_schedule_shifted(alpha_prod_t_prev, s) for s in self.scales]) beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_epsilon = model_output elif self.config.prediction_type == "sample": pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.config.prediction_type == "v_prediction": if len(model_output) > 1: pred_original_sample = [] pred_epsilon = [] for m_o, s, a_p_t, b_p_t in zip(model_output, sample, alpha_prod_t, beta_prod_t): pred_original_sample.append((a_p_t**0.5) * s - (b_p_t**0.5) * m_o) pred_epsilon.append((a_p_t**0.5) * m_o + (b_p_t**0.5) * s) else: pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" ) # 4. Clip or threshold "predicted x_0" if self.config.thresholding: if len(model_output) > 1: pred_original_sample = [ self._threshold_sample(p_o_s) for p_o_s in pred_original_sample ] else: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: if len(model_output) > 1: pred_original_sample = [ p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range) for p_o_s in pred_original_sample ] else: pred_original_sample = pred_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range ) # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) variance = self._get_variance(timestep, prev_timestep) std_dev_t = eta * variance ** (0.5) if use_clipped_model_output: # the pred_epsilon is always re-derived from the clipped x_0 in Glide if len(model_output) > 1: pred_epsilon = [] for s, a_p_t, p_o_s, b_p_t in zip(sample, alpha_prod_t, pred_original_sample, beta_prod_t): pred_epsilon.append((s - a_p_t ** (0.5) * p_o_s) / b_p_t ** (0.5)) else: pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if len(model_output) > 1: pred_sample_direction = [] for p_e, a_p_t_p in zip(pred_epsilon, alpha_prod_t_prev): pred_sample_direction.append((1 - a_p_t_p - std_dev_t**2) ** (0.5) * p_e) else: pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if len(model_output) > 1: prev_sample = [] for p_o_s, p_s_d, a_p_t_p in zip(pred_original_sample, pred_sample_direction, alpha_prod_t_prev): prev_sample.append(a_p_t_p ** (0.5) * p_o_s + p_s_d) else: prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction if eta > 0: if variance_noise is not None and generator is not None: raise ValueError( "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" " `variance_noise` stays `None`." ) if variance_noise is None: if len(model_output) > 1: variance_noise = [] for m_o in model_output: variance_noise.append( randn_tensor(m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype) ) else: variance_noise = randn_tensor( model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype ) if len(model_output) > 1: prev_sample = [p_s + std_dev_t * v_n for v_n, p_s in zip(variance_noise, prev_sample)] else: variance = std_dev_t * variance_noise prev_sample = prev_sample + variance if not return_dict: return (prev_sample,) return MatryoshkaDDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement # for the subsequent add_noise calls self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: # Make sure alphas_cumprod and timestep have same device and dtype as sample self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) timesteps = timesteps.to(sample.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(sample.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity def __len__(self): return self.config.num_train_timesteps class CrossAttnDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, norm_type: str = "layer_norm", num_attention_heads: int = 1, cross_attention_dim: int = 1280, cross_attention_norm: Optional[str] = None, output_scale_factor: float = 1.0, downsample_padding: int = 1, add_downsample: bool = True, dual_cross_attention: bool = False, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, attention_type: str = "default", attention_pre_only: bool = False, attention_bias: bool = False, use_attention_ffn: bool = True, ): super().__init__() resnets = [] attentions = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) attentions.append( MatryoshkaTransformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_attention_ffn=use_attention_ffn, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, additional_residuals: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") output_states = () blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: hidden_states = hidden_states + additional_residuals output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) return hidden_states, output_states class UNetMidBlock2DCrossAttn(nn.Module): def __init__( self, in_channels: int, temb_channels: int, out_channels: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_groups_out: Optional[int] = None, resnet_pre_norm: bool = True, norm_type: str = "layer_norm", num_attention_heads: int = 1, output_scale_factor: float = 1.0, cross_attention_dim: int = 1280, cross_attention_norm: Optional[str] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, upcast_attention: bool = False, attention_type: str = "default", attention_pre_only: bool = False, attention_bias: bool = False, use_attention_ffn: bool = True, ): super().__init__() out_channels = out_channels or in_channels self.in_channels = in_channels self.out_channels = out_channels self.has_cross_attention = True self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) # support for variable transformer layers per block if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * num_layers resnet_groups_out = resnet_groups_out or resnet_groups # there is always at least one resnet resnets = [ ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, groups_out=resnet_groups_out, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ] attentions = [] for i in range(num_layers): attentions.append( MatryoshkaTransformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_attention_ffn=use_attention_ffn, ) ) resnets.append( ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups_out, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) else: hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] hidden_states = resnet(hidden_states, temb) return hidden_states class CrossAttnUpBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int, resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, norm_type: str = "layer_norm", num_attention_heads: int = 1, cross_attention_dim: int = 1280, cross_attention_norm: Optional[str] = None, output_scale_factor: float = 1.0, add_upsample: bool = True, dual_cross_attention: bool = False, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, attention_type: str = "default", attention_pre_only: bool = False, attention_bias: bool = False, use_attention_ffn: bool = True, ): super().__init__() resnets = [] attentions = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * num_layers for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) attentions.append( MatryoshkaTransformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_attention_ffn=use_attention_ffn, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False self.resolution_idx = resolution_idx def forward( self, hidden_states: torch.Tensor, res_hidden_states_tuple: Tuple[torch.Tensor, ...], temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) and getattr(self, "b1", None) and getattr(self, "b2", None) ) for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] # FreeU: Only operate on the first two stages if is_freeu_enabled: hidden_states, res_hidden_states = apply_freeu( self.resolution_idx, hidden_states, res_hidden_states, s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2, ) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @dataclass class MatryoshkaTransformer2DModelOutput(BaseOutput): """ The output of [`MatryoshkaTransformer2DModel`]. Args: sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`MatryoshkaTransformer2DModel`] is discrete): The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability distributions for the unnoised latent pixels. """ sample: "torch.Tensor" # noqa: F821 class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin): _supports_gradient_checkpointing = True _no_split_modules = ["MatryoshkaTransformerBlock"] @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, num_layers: int = 1, cross_attention_dim: Optional[int] = None, upcast_attention: bool = False, use_attention_ffn: bool = True, ): super().__init__() self.in_channels = self.config.num_attention_heads * self.config.attention_head_dim self.gradient_checkpointing = False self.transformer_blocks = nn.ModuleList( [ MatryoshkaTransformerBlock( self.in_channels, self.config.num_attention_heads, self.config.attention_head_dim, cross_attention_dim=self.config.cross_attention_dim, upcast_attention=self.config.upcast_attention, use_attention_ffn=self.config.use_attention_ffn, ) for _ in range(self.config.num_layers) ] ) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, timestep: Optional[torch.LongTensor] = None, added_cond_kwargs: Dict[str, torch.Tensor] = None, class_labels: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ): """ The [`MatryoshkaTransformer2DModel`] forward method. Args: hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): Input `hidden_states`. encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep ( `torch.LongTensor`, *optional*): Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in `AdaLayerZeroNorm`. cross_attention_kwargs ( `Dict[str, Any]`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). attention_mask ( `torch.Tensor`, *optional*): An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. encoder_attention_mask ( `torch.Tensor`, *optional*): Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: * Mask `(batch, sequence_length)` True = keep, False = discard. * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format above. This bias will be added to the cross-attention scores. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~NestedUNet2DConditionOutput`] instead of a plain tuple. Returns: If `return_dict` is True, an [`~MatryoshkaTransformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None and attention_mask.ndim == 2: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # Blocks for block in self.transformer_blocks: if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, **ckpt_kwargs, ) else: hidden_states = block( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, ) # Output output = hidden_states if not return_dict: return (output,) return MatryoshkaTransformer2DModelOutput(sample=output) class MatryoshkaTransformerBlock(nn.Module): r""" Matryoshka Transformer block. Parameters: """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: Optional[int] = None, upcast_attention: bool = False, use_attention_ffn: bool = True, ): super().__init__() self.dim = dim self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim self.cross_attention_dim = cross_attention_dim # Define 3 blocks. # 1. Self-Attn self.attn1 = Attention( query_dim=dim, cross_attention_dim=None, heads=num_attention_heads, dim_head=attention_head_dim, norm_num_groups=32, bias=True, upcast_attention=upcast_attention, pre_only=True, processor=MatryoshkaFusedAttnProcessor2_0(), ) self.attn1.fuse_projections() del self.attn1.to_q del self.attn1.to_k del self.attn1.to_v # 2. Cross-Attn if cross_attention_dim is not None and cross_attention_dim > 0: self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim, cross_attention_norm="layer_norm", heads=num_attention_heads, dim_head=attention_head_dim, bias=True, upcast_attention=upcast_attention, pre_only=True, processor=MatryoshkaFusedAttnProcessor2_0(), ) self.attn2.fuse_projections() del self.attn2.to_q del self.attn2.to_k del self.attn2.to_v self.proj_out = nn.Linear(dim, dim) if use_attention_ffn: # 3. Feed-forward self.ff = MatryoshkaFeedForward(dim) else: self.ff = None # let chunk size default to None self._chunk_size = None self._chunk_dim = 0 # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): # Sets chunk feed-forward self._chunk_size = chunk_size self._chunk_dim = dim def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, timestep: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.Tensor: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") # 1. Self-Attention batch_size, channels, *spatial_dims = hidden_states.shape attn_output, query = self.attn1( hidden_states, # **cross_attention_kwargs, ) # 2. Cross-Attention if self.cross_attention_dim is not None and self.cross_attention_dim > 0: attn_output_cond = self.attn2( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, self_attention_output=attn_output, self_attention_query=query, # **cross_attention_kwargs, ) attn_output_cond = self.proj_out(attn_output_cond) attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims) hidden_states = hidden_states + attn_output_cond if self.ff is not None: # 3. Feed-forward if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory ff_output = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size) else: ff_output = self.ff(hidden_states) hidden_states = ff_output + hidden_states return hidden_states class MatryoshkaFusedAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. For cross-attention modules, key and value projection matrices are fused. This API is currently 🧪 experimental in nature and can change in future. """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "MatryoshkaFusedAttnProcessor2_0 requires PyTorch 2.x, to use it. Please upgrade PyTorch to > 2.x." ) def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None, self_attention_query: Optional[torch.Tensor] = None, self_attention_output: Optional[torch.Tensor] = None, *args, **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states) if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2).contiguous() if encoder_hidden_states is None: qkv = attn.to_qkv(hidden_states) split_size = qkv.shape[-1] // 3 query, key, value = torch.split(qkv, split_size, dim=-1) else: if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) if self_attention_query is not None: query = self_attention_query else: query = attn.to_q(hidden_states) kv = attn.to_kv(encoder_hidden_states) split_size = kv.shape[-1] // 2 key, value = torch.split(kv, split_size, dim=-1) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads if self_attention_output is None: query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.to(query.dtype) if self_attention_output is not None: hidden_states = hidden_states + self_attention_output hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states if self_attention_output is not None else (hidden_states, query) class MatryoshkaFeedForward(nn.Module): r""" A feed-forward layer for the Matryoshka models. Parameters:""" def __init__( self, dim: int, ): super().__init__() self.group_norm = nn.GroupNorm(32, dim) self.linear_gelu = GELU(dim, dim * 4) self.linear_out = nn.Linear(dim * 4, dim) def forward(self, x): batch_size, channels, *spatial_dims = x.shape x = self.group_norm(x) x = x.view(batch_size, channels, -1).permute(0, 2, 1) x = self.linear_out(self.linear_gelu(x)) x = x.permute(0, 2, 1).view(batch_size, channels, *spatial_dims) return x def get_down_block( down_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_downsample: bool, resnet_eps: float, resnet_act_fn: str, norm_type: str = "layer_norm", transformer_layers_per_block: int = 1, num_attention_heads: Optional[int] = None, resnet_groups: Optional[int] = None, cross_attention_dim: Optional[int] = None, downsample_padding: Optional[int] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", attention_type: str = "default", attention_pre_only: bool = False, resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, cross_attention_norm: Optional[str] = None, attention_head_dim: Optional[int] = None, use_attention_ffn: bool = True, downsample_type: Optional[str] = None, dropout: float = 0.0, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: logger.warning( f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." ) attention_head_dim = num_attention_heads down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": return DownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, dropout=dropout, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") return CrossAttnDownBlock2D( num_layers=num_layers, transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, dropout=dropout, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, norm_type=norm_type, resnet_groups=resnet_groups, downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, cross_attention_norm=cross_attention_norm, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, attention_pre_only=attention_pre_only, use_attention_ffn=use_attention_ffn, ) def get_mid_block( mid_block_type: str, temb_channels: int, in_channels: int, resnet_eps: float, resnet_act_fn: str, resnet_groups: int, norm_type: str = "layer_norm", output_scale_factor: float = 1.0, transformer_layers_per_block: int = 1, num_attention_heads: Optional[int] = None, cross_attention_dim: Optional[int] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, mid_block_only_cross_attention: bool = False, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", attention_type: str = "default", attention_pre_only: bool = False, resnet_skip_time_act: bool = False, cross_attention_norm: Optional[str] = None, attention_head_dim: Optional[int] = 1, dropout: float = 0.0, ): if mid_block_type == "UNetMidBlock2DCrossAttn": return UNetMidBlock2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, temb_channels=temb_channels, dropout=dropout, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, norm_type=norm_type, output_scale_factor=output_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, cross_attention_norm=cross_attention_norm, num_attention_heads=num_attention_heads, resnet_groups=resnet_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, attention_pre_only=attention_pre_only, ) def get_up_block( up_block_type: str, num_layers: int, in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int, add_upsample: bool, resnet_eps: float, resnet_act_fn: str, norm_type: str = "layer_norm", resolution_idx: Optional[int] = None, transformer_layers_per_block: int = 1, num_attention_heads: Optional[int] = None, resnet_groups: Optional[int] = None, cross_attention_dim: Optional[int] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", attention_type: str = "default", attention_pre_only: bool = False, resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, cross_attention_norm: Optional[str] = None, attention_head_dim: Optional[int] = None, use_attention_ffn: bool = True, upsample_type: Optional[str] = None, dropout: float = 0.0, ) -> nn.Module: # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: logger.warning( f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." ) attention_head_dim = num_attention_heads up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": return UpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, resolution_idx=resolution_idx, dropout=dropout, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "CrossAttnUpBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") return CrossAttnUpBlock2D( num_layers=num_layers, transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, resolution_idx=resolution_idx, dropout=dropout, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, norm_type=norm_type, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, cross_attention_norm=cross_attention_norm, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, attention_pre_only=attention_pre_only, use_attention_ffn=use_attention_ffn, ) class MatryoshkaCombinedTimestepTextEmbedding(nn.Module): def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim, type): super().__init__() if type == "unet": self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False) elif type == "nested_unet": self.cond_emb = None self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0) self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim) def forward(self, emb, encoder_hidden_states, added_cond_kwargs): conditioning_mask = added_cond_kwargs.get("conditioning_mask", None) masked_cross_attention = added_cond_kwargs.get("masked_cross_attention", False) if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False): if conditioning_mask is None: y = encoder_hidden_states.mean(dim=1) else: y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum( dim=1, keepdim=True ) cond_emb = self.cond_emb(y) else: cond_emb = None if not masked_cross_attention: conditioning_mask = None micro = added_cond_kwargs.get("micro_conditioning_scale", None) if micro is not None: temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype)) temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype)) # if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False): return temb_micro_conditioning, conditioning_mask, cond_emb return None, conditioning_mask, cond_emb @dataclass class MatryoshkaUNet2DConditionOutput(BaseOutput): """ The output of [`MatryoshkaUNet2DConditionOutput`]. Args: sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. """ sample: torch.Tensor = None sample_inner: torch.Tensor = None class MatryoshkaUNet2DConditionModel( ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin ): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): Height and width of input/output sample. in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): The tuple of upsample blocks to use. only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): Whether to include self-attention in the basic transformer blocks, see [`~models.attention.BasicTransformerBlock`]. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers is skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. If not defined, defaults to `attention_head_dim` resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. addition_time_embed_dim: (`int`, *optional*, defaults to `None`): Dimension for the timestep embeddings. num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` otherwise. """ _supports_gradient_checkpointing = True _no_split_modules = ["MatryoshkaTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 3, out_channels: int = 3, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, dropout: float = 0.0, act_fn: str = "silu", norm_type: str = "layer_norm", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, use_attention_ffn: bool = True, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, time_embedding_type: str = "positional", time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, attention_type: str = "default", attention_pre_only: bool = False, masked_cross_attention: bool = False, micro_conditioning_scale: int = None, class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads: int = 64, temporal_mode: bool = False, temporal_spatial_ds: bool = False, skip_cond_emb: bool = False, nesting: Optional[int] = False, ): super().__init__() self.sample_size = sample_size if num_attention_heads is not None: raise ValueError( "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." ) # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs self._check_config( down_block_types=down_block_types, up_block_types=up_block_types, only_cross_attention=only_cross_attention, block_out_channels=block_out_channels, layers_per_block=layers_per_block, cross_attention_dim=cross_attention_dim, transformer_layers_per_block=transformer_layers_per_block, reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, attention_head_dim=attention_head_dim, num_attention_heads=num_attention_heads, ) # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time time_embed_dim, timestep_input_dim = self._set_time_proj( time_embedding_type, block_out_channels=block_out_channels, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift, time_embedding_dim=time_embedding_dim, ) self.time_embedding = TimestepEmbedding( time_embedding_dim // 4 if time_embedding_dim is not None else timestep_input_dim, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, ) self._set_encoder_hid_proj( encoder_hid_dim_type, cross_attention_dim=cross_attention_dim, encoder_hid_dim=encoder_hid_dim, ) # class embedding self._set_class_embedding( class_embed_type, act_fn=act_fn, num_class_embeds=num_class_embeds, projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, time_embed_dim=time_embed_dim, timestep_input_dim=timestep_input_dim, ) self._set_add_embedding( addition_embed_type, addition_embed_type_num_heads=addition_embed_type_num_heads, addition_time_embed_dim=timestep_input_dim, cross_attention_dim=cross_attention_dim, encoder_hid_dim=encoder_hid_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift, projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, time_embed_dim=time_embed_dim, ) if time_embedding_act_fn is None: self.time_embed_act = None else: self.time_embed_act = get_activation(time_embedding_act_fn) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): if mid_block_only_cross_attention is None: mid_block_only_cross_attention = only_cross_attention only_cross_attention = [only_cross_attention] * len(down_block_types) if mid_block_only_cross_attention is None: mid_block_only_cross_attention = False if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. The dimension of the # time embeddings passed to the down, middle, and up blocks is twice the dimension of the # regular time embeddings blocks_time_embed_dim = time_embed_dim * 2 else: blocks_time_embed_dim = time_embed_dim # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block[i], transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, norm_type=norm_type, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, attention_pre_only=attention_pre_only, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, use_attention_ffn=use_attention_ffn, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, dropout=dropout, ) self.down_blocks.append(down_block) # mid self.mid_block = get_mid_block( mid_block_type, temb_channels=blocks_time_embed_dim, in_channels=block_out_channels[-1], resnet_eps=norm_eps, resnet_act_fn=act_fn, norm_type=norm_type, resnet_groups=norm_num_groups, output_scale_factor=mid_block_scale_factor, transformer_layers_per_block=1, num_attention_heads=num_attention_heads[-1], cross_attention_dim=cross_attention_dim[-1], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, mid_block_only_cross_attention=mid_block_only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, attention_pre_only=attention_pre_only, resnet_skip_time_act=resnet_skip_time_act, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[-1], dropout=dropout, ) # count how many layers upsample the images self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) reversed_transformer_layers_per_block = ( list(reversed(transformer_layers_per_block)) if reverse_transformer_layers_per_block is None else reverse_transformer_layers_per_block ) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, transformer_layers_per_block=reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, norm_type=norm_type, resolution_idx=i, resnet_groups=norm_num_groups, cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, attention_pre_only=attention_pre_only, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, use_attention_ffn=use_attention_ffn, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, dropout=dropout, ) self.up_blocks.append(up_block) # out if norm_num_groups is not None: self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) self.conv_act = get_activation(act_fn) else: self.conv_norm_out = None self.conv_act = None conv_out_padding = (conv_out_kernel - 1) // 2 self.conv_out = nn.Conv2d( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) self.is_temporal = [] def _check_config( self, down_block_types: Tuple[str], up_block_types: Tuple[str], only_cross_attention: Union[bool, Tuple[bool]], block_out_channels: Tuple[int], layers_per_block: Union[int, Tuple[int]], cross_attention_dim: Union[int, Tuple[int]], transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], reverse_transformer_layers_per_block: bool, attention_head_dim: int, num_attention_heads: Optional[Union[int, Tuple[int]]], ): if len(down_block_types) != len(up_block_types): raise ValueError( f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." ) if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: for layer_number_per_block in transformer_layers_per_block: if isinstance(layer_number_per_block, list): raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") def _set_time_proj( self, time_embedding_type: str, block_out_channels: int, flip_sin_to_cos: bool, freq_shift: float, time_embedding_dim: int, ) -> Tuple[int, int]: if time_embedding_type == "fourier": time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 if self.model_type == "unet": self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) elif self.model_type == "nested_unet" and self.config.micro_conditioning_scale == 256: self.time_proj = Timesteps(block_out_channels[0] * 4, flip_sin_to_cos, freq_shift) elif self.model_type == "nested_unet" and self.config.micro_conditioning_scale == 1024: self.time_proj = Timesteps(block_out_channels[0] * 4 * 2, flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." ) return time_embed_dim, timestep_input_dim def _set_encoder_hid_proj( self, encoder_hid_dim_type: Optional[str], cross_attention_dim: Union[int, Tuple[int]], encoder_hid_dim: Optional[int], ): if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." ) if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, image_embed_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 self.encoder_hid_proj = ImageProjection( image_embed_dim=encoder_hid_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'." ) else: self.encoder_hid_proj = None def _set_class_embedding( self, class_embed_type: Optional[str], act_fn: str, num_class_embeds: Optional[int], projection_class_embeddings_input_dim: Optional[int], time_embed_dim: int, timestep_input_dim: int, ): if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif class_embed_type == "simple_projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None def _set_add_embedding( self, addition_embed_type: str, addition_embed_type_num_heads: int, addition_time_embed_dim: Optional[int], flip_sin_to_cos: bool, freq_shift: float, cross_attention_dim: Optional[int], encoder_hid_dim: Optional[int], projection_class_embeddings_input_dim: Optional[int], time_embed_dim: int, ): if addition_embed_type == "text": if encoder_hid_dim is not None: text_time_embedding_from_dim = encoder_hid_dim else: text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "matryoshka": self.add_embedding = MatryoshkaCombinedTimestepTextEmbedding( self.config.time_embedding_dim // 4 if self.config.time_embedding_dim is not None else addition_time_embed_dim, cross_attention_dim, time_embed_dim, self.model_type, # if not self.config.nesting else "inner_" + self.model_type, ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) elif addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif addition_embed_type == "image": # Kandinsky 2.2 self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: raise ValueError( f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'." ) def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): if attention_type in ["gated", "gated-text-image"]: positive_len = 768 if isinstance(cross_attention_dim, int): positive_len = cross_attention_dim elif isinstance(cross_attention_dim, (list, tuple)): positive_len = cross_attention_dim[0] feature_type = "text-only" if attention_type == "gated" else "text-image" self.position_net = GLIGENTextBoundingboxProjection( positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type ) @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnAddedKVProcessor() elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnProcessor() else: raise ValueError( f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) self.set_attn_processor(processor) def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. The suffixes after the scaling factors represent the stage blocks where they are being applied. Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. Args: s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process. s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process. b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. """ for i, upsample_block in enumerate(self.up_blocks): setattr(upsample_block, "s1", s1) setattr(upsample_block, "s2", s2) setattr(upsample_block, "b1", b1) setattr(upsample_block, "b2", b2) def disable_freeu(self): """Disables the FreeU mechanism.""" freeu_keys = {"s1", "s2", "b1", "b2"} for i, upsample_block in enumerate(self.up_blocks): for k in freeu_keys: if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: setattr(upsample_block, k, None) def fuse_qkv_projections(self): """ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. For cross-attention modules, key and value projection matrices are fused. This API is 🧪 experimental. """ self.original_attn_processors = None for _, attn_processor in self.attn_processors.items(): if "Added" in str(attn_processor.__class__.__name__): raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") self.original_attn_processors = self.attn_processors for module in self.modules(): if isinstance(module, Attention): module.fuse_projections(fuse=True) self.set_attn_processor(FusedAttnProcessor2_0()) def unfuse_qkv_projections(self): """Disables the fused QKV projection if enabled. This API is 🧪 experimental. """ if self.original_attn_processors is not None: self.set_attn_processor(self.original_attn_processors) def get_time_embed( self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int] ) -> Optional[torch.Tensor]: timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) return t_emb def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: class_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) return class_emb def get_aug_embed( self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] ) -> Optional[torch.Tensor]: aug_emb = None if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "matryoshka": aug_emb = self.add_embedding(emb, encoder_hidden_states, added_cond_kwargs) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb = self.add_embedding(image_embs, hint) return aug_emb def process_encoder_hidden_states( self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] ) -> torch.Tensor: if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None: encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states) image_embeds = added_cond_kwargs.get("image_embeds") image_embeds = self.encoder_hid_proj(image_embeds) encoder_hidden_states = (encoder_hidden_states, image_embeds) return encoder_hidden_states @property def model_type(self) -> str: return "unet" def forward( self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, cond_emb: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, from_nested: bool = False, ) -> Union[MatryoshkaUNet2DConditionOutput, Tuple]: r""" The [`NestedUNet2DConditionModel`] forward method. Args: sample (`torch.Tensor`): The noisy input tensor with the following shape `(batch, channel, height, width)`. timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.Tensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. class_labels (`torch.Tensor`, *optional*, defaults to `None`): Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed through the `self.time_embedding` layer to obtain the timestep embeddings. attention_mask (`torch.Tensor`, *optional*, defaults to `None`): An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): A tuple of tensors that if specified are added to the residuals of down unet blocks. mid_block_additional_residual: (`torch.Tensor`, *optional*): A tensor that if specified is added to the residual of the middle unet block. down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~NestedUNet2DConditionOutput`] instead of a plain tuple. Returns: [`~NestedUNet2DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~NestedUNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if self.config.nesting: sample, sample_feat = sample if isinstance(sample, list) and len(sample) == 1: sample = sample[0] for dim in sample.shape[-2:]: if dim % default_overall_up_factor != 0: # Forward upsample size to force interpolation output size. forward_upsample_size = True break # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time t_emb = self.get_time_embed(sample=sample, timestep=timestep) emb = self.time_embedding(t_emb, timestep_cond) class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) if class_emb is not None: if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb added_cond_kwargs = added_cond_kwargs or {} added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale added_cond_kwargs["from_nested"] = from_nested added_cond_kwargs["conditioning_mask"] = encoder_attention_mask if not from_nested: encoder_hidden_states = self.process_encoder_hidden_states( encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) aug_emb, encoder_attention_mask, cond_emb = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) else: aug_emb, encoder_attention_mask, _ = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample[0][0].dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) if self.config.addition_embed_type == "image_hint": aug_emb, hint = aug_emb sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb + cond_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) # 2. pre-process sample = self.conv_in(sample) if self.config.nesting: sample = sample + sample_feat # 2.5 GLIGEN position net if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: cross_attention_kwargs = cross_attention_kwargs.copy() gligen_args = cross_attention_kwargs.pop("gligen") cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} # 3. down # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. if cross_attention_kwargs is not None: cross_attention_kwargs = cross_attention_kwargs.copy() lora_scale = cross_attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets is_adapter = down_intrablock_additional_residuals is not None # maintain backward compatibility for legacy usage, where # T2I-Adapter and ControlNet both use down_block_additional_residuals arg # but can only use one or the other if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: deprecate( "T2I should not use down_block_additional_residuals", "1.3.0", "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", standard_warn=False, ) down_intrablock_additional_residuals = down_block_additional_residuals is_adapter = True down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} if is_adapter and len(down_intrablock_additional_residuals) > 0: additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, **additional_residuals, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) if is_adapter and len(down_intrablock_additional_residuals) > 0: sample += down_intrablock_additional_residuals.pop(0) down_block_res_samples += res_samples if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) else: sample = self.mid_block(sample, emb) # To support T2I-Adapter-XL if ( is_adapter and len(down_intrablock_additional_residuals) > 0 and sample.shape == down_intrablock_additional_residuals[0].shape ): sample += down_intrablock_additional_residuals.pop(0) if is_controlnet: sample = sample + mid_block_additional_residual # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, ) sample_inner = sample # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample_inner) sample = self.conv_act(sample) sample = self.conv_out(sample) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) if not return_dict: return (sample,) if self.config.nesting: return MatryoshkaUNet2DConditionOutput(sample=sample, sample_inner=sample_inner) return MatryoshkaUNet2DConditionOutput(sample=sample) class NestedUNet2DConditionOutput(BaseOutput): """ Output type for the [`NestedUNet2DConditionModel`] model. """ sample: list = None sample_inner: torch.Tensor = None class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel): """ Nested UNet model with condition for image denoising. """ @register_to_config def __init__( self, in_channels=3, out_channels=3, block_out_channels=(64, 128, 256), cross_attention_dim=2048, resnet_time_scale_shift="scale_shift", down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D"), up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D"), mid_block_type=None, nesting=False, flip_sin_to_cos=False, transformer_layers_per_block=[0, 0, 0], layers_per_block=[2, 2, 1], masked_cross_attention=True, micro_conditioning_scale=256, addition_embed_type="matryoshka", skip_normalization=True, time_embedding_dim=1024, skip_inner_unet_input=False, temporal_mode=False, temporal_spatial_ds=False, initialize_inner_with_pretrained=None, use_attention_ffn=False, act_fn="silu", addition_embed_type_num_heads=64, addition_time_embed_dim=None, attention_head_dim=8, attention_pre_only=False, attention_type="default", center_input_sample=False, class_embed_type=None, class_embeddings_concat=False, conv_in_kernel=3, conv_out_kernel=3, cross_attention_norm=None, downsample_padding=1, dropout=0.0, dual_cross_attention=False, encoder_hid_dim=None, encoder_hid_dim_type=None, freq_shift=0, mid_block_only_cross_attention=None, mid_block_scale_factor=1, norm_eps=1e-05, norm_num_groups=32, norm_type="layer_norm", num_attention_heads=None, num_class_embeds=None, only_cross_attention=False, projection_class_embeddings_input_dim=None, resnet_out_scale_factor=1.0, resnet_skip_time_act=False, reverse_transformer_layers_per_block=None, sample_size=None, skip_cond_emb=False, time_cond_proj_dim=None, time_embedding_act_fn=None, time_embedding_type="positional", timestep_post_act=None, upcast_attention=False, use_linear_projection=False, is_temporal=None, inner_config={}, ): super().__init__( in_channels=in_channels, out_channels=out_channels, block_out_channels=block_out_channels, cross_attention_dim=cross_attention_dim, resnet_time_scale_shift=resnet_time_scale_shift, down_block_types=down_block_types, up_block_types=up_block_types, mid_block_type=mid_block_type, nesting=nesting, flip_sin_to_cos=flip_sin_to_cos, transformer_layers_per_block=transformer_layers_per_block, layers_per_block=layers_per_block, masked_cross_attention=masked_cross_attention, micro_conditioning_scale=micro_conditioning_scale, addition_embed_type=addition_embed_type, time_embedding_dim=time_embedding_dim, temporal_mode=temporal_mode, temporal_spatial_ds=temporal_spatial_ds, use_attention_ffn=use_attention_ffn, sample_size=sample_size, ) # self.config.inner_config.conditioning_feature_dim = self.config.conditioning_feature_dim if "inner_config" not in self.config.inner_config: self.inner_unet = MatryoshkaUNet2DConditionModel(**self.config.inner_config) else: self.inner_unet = NestedUNet2DConditionModel(**self.config.inner_config) if not self.config.skip_inner_unet_input: self.in_adapter = nn.Conv2d( self.config.block_out_channels[-1], self.config.inner_config["block_out_channels"][0], kernel_size=3, padding=1, ) else: self.in_adapter = None self.out_adapter = nn.Conv2d( self.config.inner_config["block_out_channels"][0], self.config.block_out_channels[-1], kernel_size=3, padding=1, ) self.is_temporal = [self.config.temporal_mode and (not self.config.temporal_spatial_ds)] if hasattr(self.inner_unet, "is_temporal"): self.is_temporal = self.is_temporal + self.inner_unet.is_temporal nest_ratio = int(2 ** (len(self.config.block_out_channels) - 1)) if self.is_temporal[0]: nest_ratio = int(np.sqrt(nest_ratio)) if self.inner_unet.config.nesting and self.inner_unet.model_type == "nested_unet": self.nest_ratio = [nest_ratio * self.inner_unet.nest_ratio[0]] + self.inner_unet.nest_ratio else: self.nest_ratio = [nest_ratio] # self.register_modules(inner_unet=self.inner_unet) @property def model_type(self): return "nested_unet" def forward( self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, cond_emb: Optional[torch.Tensor] = None, from_nested: bool = False, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[MatryoshkaUNet2DConditionOutput, Tuple]: r""" The [`NestedUNet2DConditionModel`] forward method. Args: sample (`torch.Tensor`): The noisy input tensor with the following shape `(batch, channel, height, width)`. timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.Tensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. class_labels (`torch.Tensor`, *optional*, defaults to `None`): Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed through the `self.time_embedding` layer to obtain the timestep embeddings. attention_mask (`torch.Tensor`, *optional*, defaults to `None`): An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): A tuple of tensors that if specified are added to the residuals of down unet blocks. mid_block_additional_residual: (`torch.Tensor`, *optional*): A tensor that if specified is added to the residual of the middle unet block. down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~NestedUNet2DConditionOutput`] instead of a plain tuple. Returns: [`~NestedUNet2DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~NestedUNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if self.config.nesting: sample, sample_feat = sample if isinstance(sample, list) and len(sample) == 1: sample = sample[0] # 2. input layer (normalize the input) bsz = [x.size(0) for x in sample] bh, bl = bsz[0], bsz[1] x_t_low, sample = sample[1:], sample[0] for dim in sample.shape[-2:]: if dim % default_overall_up_factor != 0: # Forward upsample size to force interpolation output size. forward_upsample_size = True break # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time t_emb = self.get_time_embed(sample=sample, timestep=timestep) emb = self.time_embedding(t_emb, timestep_cond) class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) if class_emb is not None: if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.inner_unet.model_type == "unet": added_cond_kwargs = added_cond_kwargs or {} added_cond_kwargs["masked_cross_attention"] = self.inner_unet.config.masked_cross_attention added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale added_cond_kwargs["conditioning_mask"] = encoder_attention_mask if not self.config.nesting: encoder_hidden_states = self.inner_unet.process_encoder_hidden_states( encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) aug_emb_inner_unet, cond_mask, cond_emb = self.inner_unet.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention aug_emb, __, _ = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) else: aug_emb, cond_mask, _ = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) elif self.inner_unet.model_type == "nested_unet": added_cond_kwargs = added_cond_kwargs or {} added_cond_kwargs["masked_cross_attention"] = self.inner_unet.inner_unet.config.masked_cross_attention added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale added_cond_kwargs["conditioning_mask"] = encoder_attention_mask encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states( encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) aug_emb_inner_unet, cond_mask, cond_emb = self.inner_unet.inner_unet.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) aug_emb, __, _ = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) if self.config.addition_embed_type == "image_hint": aug_emb, hint = aug_emb sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb + cond_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if not self.config.skip_normalization: sample = sample / sample.std((1, 2, 3), keepdims=True) if isinstance(sample, list) and len(sample) == 1: sample = sample[0] sample = self.conv_in(sample) if self.config.nesting: sample = sample + sample_feat # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. if cross_attention_kwargs is not None: cross_attention_kwargs = cross_attention_kwargs.copy() lora_scale = cross_attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets is_adapter = down_intrablock_additional_residuals is not None # maintain backward compatibility for legacy usage, where # T2I-Adapter and ControlNet both use down_block_additional_residuals arg # but can only use one or the other if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: deprecate( "T2I should not use down_block_additional_residuals", "1.3.0", "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", standard_warn=False, ) down_intrablock_additional_residuals = down_block_additional_residuals is_adapter = True # 3. downsample blocks in the outer layers down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} if is_adapter and len(down_intrablock_additional_residuals) > 0: additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, temb=emb[:bh], encoder_hidden_states=encoder_hidden_states[:bh], attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask, **additional_residuals, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) if is_adapter and len(down_intrablock_additional_residuals) > 0: sample += down_intrablock_additional_residuals.pop(0) down_block_res_samples += res_samples # 4. run inner unet x_inner = self.in_adapter(sample) if self.in_adapter is not None else None x_inner = ( torch.cat([x_inner, x_inner.new_zeros(bl - bh, *x_inner.size()[1:])], 0) if bh < bl else x_inner ) # pad zeros for low-resolutions inner_unet_output = self.inner_unet( (x_t_low, x_inner), timestep, cond_emb=cond_emb, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=cond_mask, from_nested=True, ) x_low, x_inner = inner_unet_output.sample, inner_unet_output.sample_inner x_inner = self.out_adapter(x_inner) sample = sample + x_inner[:bh] if bh < bl else sample + x_inner # 5. upsample blocks in the outer layers for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb[:bh], res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states[:bh], cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, ) # 6. post-process if self.conv_norm_out: sample_out = self.conv_norm_out(sample) sample_out = self.conv_act(sample_out) sample_out = self.conv_out(sample_out) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) # 7. output both low and high-res output if isinstance(x_low, list): out = [sample_out] + x_low else: out = [sample_out, x_low] if self.config.nesting: return NestedUNet2DConditionOutput(sample=out, sample_inner=sample) if not return_dict: return (out,) else: return NestedUNet2DConditionOutput(sample=out) @dataclass class MatryoshkaPipelineOutput(BaseOutput): """ Output class for Matryoshka pipelines. Args: images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. """ images: Union[List[Image.Image], List[List[Image.Image]], np.ndarray, List[np.ndarray]] class MatryoshkaPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin, ): r""" Pipeline for text-to-image generation using Matryoshka Diffusion Models. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: text_encoder ([`~transformers.T5EncoderModel`]): Frozen text-encoder ([flan-t5-xl](https://huggingface.co/google/flan-t5-xl)). tokenizer ([`~transformers.T5Tokenizer`]): A `T5Tokenizer` to tokenize text. unet ([`MatryoshkaUNet2DConditionModel`]): A `MatryoshkaUNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`MatryoshkaDDIMScheduler`] and other schedulers with proper modifications, see an example usage in README.md. feature_extractor ([`~transformers.`]): A `AnImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ model_cpu_offload_seq = "text_encoder->image_encoder->unet" _optional_components = ["unet", "feature_extractor", "image_encoder"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, scheduler: MatryoshkaDDIMScheduler, unet: MatryoshkaUNet2DConditionModel = None, feature_extractor: CLIPImageProcessor = None, image_encoder: CLIPVisionModelWithProjection = None, trust_remote_code: bool = False, nesting_level: int = 0, ): super().__init__() if nesting_level == 0: unet = MatryoshkaUNet2DConditionModel.from_pretrained( "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_0" ) elif nesting_level == 1: unet = NestedUNet2DConditionModel.from_pretrained( "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_1" ) elif nesting_level == 2: unet = NestedUNet2DConditionModel.from_pretrained( "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2" ) else: raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.") if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # deprecation_message = ( # f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." # " `clip_sample` should be set to False in the configuration file. Please make sure to update the" # " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" # " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" # " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" # ) # deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) # new_config = dict(scheduler.config) # new_config["clip_sample"] = False # scheduler._internal_dict = FrozenDict(new_config) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) if hasattr(unet, "nest_ratio"): scheduler.scales = unet.nest_ratio + [1] if nesting_level == 2: scheduler.schedule_shifted_power = 2.0 self.register_modules( text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, feature_extractor=feature_extractor, image_encoder=image_encoder, ) self.register_to_config(nesting_level=nesting_level) self.image_processor = VaeImageProcessor(do_resize=False) def change_nesting_level(self, nesting_level: int): if nesting_level == 0: if hasattr(self.unet, "nest_ratio"): self.scheduler.scales = None self.unet = MatryoshkaUNet2DConditionModel.from_pretrained( "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_0" ).to(self.device) self.config.nesting_level = 0 elif nesting_level == 1: self.unet = NestedUNet2DConditionModel.from_pretrained( "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_1" ).to(self.device) self.config.nesting_level = 1 self.scheduler.scales = self.unet.nest_ratio + [1] self.scheduler.schedule_shifted_power = 1.0 elif nesting_level == 2: self.unet = NestedUNet2DConditionModel.from_pretrained( "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2" ).to(self.device) self.config.nesting_level = 2 self.scheduler.scales = self.unet.nest_ratio + [1] self.scheduler.schedule_shifted_power = 2.0 else: raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.") gc.collect() torch.cuda.empty_cache() def encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because FLAN-T5-XL for this pipeline can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: prompt_attention_mask = text_inputs.attention_mask.to(device) else: prompt_attention_mask = None if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) uncond_input = self.tokenizer( uncond_tokens, return_tensors="pt", ) uncond_input_ids = uncond_input.input_ids if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: negative_prompt_attention_mask = uncond_input.attention_mask.to(device) else: negative_prompt_attention_mask = None if not do_classifier_free_guidance: if clip_skip is None: prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into # the tuple to access the hidden states from the desired layer. prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) else: max_len = max(len(text_input_ids[0]), len(uncond_input_ids[0])) if len(text_input_ids[0]) < max_len: text_input_ids = torch.cat( [text_input_ids, torch.zeros(batch_size, max_len - len(text_input_ids[0]), dtype=torch.long)], dim=1, ) prompt_attention_mask = torch.cat( [ prompt_attention_mask, torch.zeros( batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long, device=device ), ], dim=1, ) elif len(uncond_input_ids[0]) < max_len: uncond_input_ids = torch.cat( [uncond_input_ids, torch.zeros(batch_size, max_len - len(uncond_input_ids[0]), dtype=torch.long)], dim=1, ) negative_prompt_attention_mask = torch.cat( [ negative_prompt_attention_mask, torch.zeros( batch_size, max_len - len(negative_prompt_attention_mask[0]), dtype=torch.long, device=device, ), ], dim=1, ) cfg_input_ids = torch.cat([uncond_input_ids, text_input_ids], dim=0) cfg_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) prompt_embeds = self.text_encoder( cfg_input_ids.to(device), attention_mask=cfg_attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) if self.text_encoder is not None: if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if not do_classifier_free_guidance: return prompt_embeds, None, prompt_attention_mask, None return prompt_embeds[1], prompt_embeds[0], prompt_attention_mask, negative_prompt_attention_mask def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) if output_hidden_states: image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True ).hidden_states[-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) return image_enc_hidden_states, uncond_image_enc_hidden_states else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): image_embeds = [] if do_classifier_free_guidance: negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): raise ValueError( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): output_hidden_state = not isinstance(image_proj_layer, ImageProjection) single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) return ip_adapter_image_embeds def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ip_adapter_image=None, ip_adapter_image_embeds=None, callback_on_step_end_tensor_inputs=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." ) if ip_adapter_image_embeds is not None: if not isinstance(ip_adapter_image_embeds, list): raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) def prepare_latents( self, batch_size, num_channels_latents, height, width, dtype, device, generator, scales, latents=None ): shape = ( batch_size, num_channels_latents, int(height), int(width), ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) if scales is not None: out = [latents] for s in scales[1:]: ratio = scales[0] // s sample_low = F.avg_pool2d(latents, ratio) * ratio sample_low = sample_low.normal_(generator=generator) out += [sample_low] latents = out else: if scales is not None: latents = [latent.to(device=device) for latent in latents] else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler if scales is not None: latents = [latent * self.scheduler.init_noise_sigma for latent in latents] else: latents = latents * self.scheduler.init_noise_sigma return latents # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding( self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 ) -> torch.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: w (`torch.Tensor`): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Data type of the generated embeddings. Returns: `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. """ assert len(w.shape) == 1 w = w * 1000.0 half_dim = embedding_dim // 2 emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) emb = w.to(dtype)[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb @property def guidance_scale(self): return self._guidance_scale @property def guidance_rescale(self): return self._guidance_rescale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @property def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None @property def cross_attention_kwargs(self): return self._cross_attention_kwargs @property def num_timesteps(self): return self._num_timesteps @property def interrupt(self): return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, timesteps: List[int] = None, sigmas: List[float] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, clip_skip: Optional[int] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. height (`int`, *optional*, defaults to `self.unet.config.sample_size`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size`): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~MatryoshkaPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~MatryoshkaPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet height = height or self.unet.config.sample_size width = width or self.unet.config.sample_size # to deal with lora scaling and other possible forward hooks # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._guidance_rescale = guidance_rescale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. Encode input prompt lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask, ) = self.encode_prompt( prompt, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=lora_scale, clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds.unsqueeze(0), prompt_embeds.unsqueeze(0)]) attention_masks = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) else: attention_masks = prompt_attention_mask prompt_embeds = prompt_embeds * attention_masks.unsqueeze(-1) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) timesteps = timesteps[:-1] # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, self.scheduler.scales, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) extra_step_kwargs |= {"use_clipped_model_output": True} # 6.1 Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) else None ) # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance if self.do_classifier_free_guidance and isinstance(latents, list): latent_model_input = [latent.repeat(2, 1, 1, 1) for latent in latents] elif self.do_classifier_free_guidance: latent_model_input = latents.repeat(2, 1, 1, 1) else: latent_model_input = latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t - 1, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, encoder_attention_mask=attention_masks, return_dict=False, )[0] # perform guidance if isinstance(noise_pred, list) and self.do_classifier_free_guidance: for i, (noise_pred_uncond, noise_pred_text) in enumerate(noise_pred): noise_pred[i] = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if XLA_AVAILABLE: xm.mark_step() image = latents if self.scheduler.scales is not None: for i, (img, scale) in enumerate(zip(image, self.scheduler.scales)): image[i] = self.image_processor.postprocess(img * scale, output_type=output_type)[0] else: image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return MatryoshkaPipelineOutput(images=image)