import itertools
from functools import partial
from typing import Any, Dict, Tuple, Callable
from typing import Union, Optional, List

import numpy as np
import torch
from diffusers import DPMSolverMultistepScheduler
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers import Transformer2DModel, ModelMixin, ConfigMixin, SchedulerMixin
from diffusers import UNet2DConditionModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.resnet import ResnetBlock2D, Downsample2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModelOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import replace_example_docstring
from torch import nn
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    Args:
        noise_cfg: The guided noise prediction (result of classifier-free guidance).
        noise_pred_text: The text-conditioned noise prediction.
        guidance_rescale: Mixing factor between the rescaled and the original guided prediction;
            `0.0` returns a tensor equal to `noise_cfg`.

    Returns:
        The mixed noise prediction, same shape as `noise_cfg`.
    """
    # Standard deviation over every dimension except the first one, kept for broadcasting.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg


def custom_sort_order(obj):
    """
    Key function for sorting order of execution in forward methods.

    Resnet blocks get key 0 and transformer blocks get key 1, so a (stable) sort places all resnets
    before all transformers. Any other class maps to `None`.
    """
    return {ResnetBlock2D: 0, Transformer2DModel: 1, FlexibleTransformer2DModel: 1}.get(obj.__class__)


def squeeze_to_len_n_starting_from_index_i(n, i, timestep_spacing):
    """
    Squeeze a timestep schedule to `n` entries, keeping its first `i` entries untouched.

    :param n: the size of the squeezed array
    :param i: the index we start squeezing from (must satisfy `0 < i < n`)
    :param timestep_spacing: the timestep_spacing array we want to squeeze (at least `i` entries)
    :return: squeezed timestep_spacing (numpy integer array of length `n`)
    :raises ValueError: if `i` is not strictly between 0 and `n`

    Example:
        timesteps = np.array([967, 907, 846, 786, 725, 665, 604, 544, 484, 423, 363, 302, 242, 181, 121, 60]) (len=16)
        n = 10, i = 6
        Expected:
        [967, 907, 846, 786, 725, 665, 4k, 3k, 2k, k], and if we define 665=5k => k = 133
    """
    # Explicit validation instead of `assert` (asserts are stripped under `python -O`).
    # `i <= 0` is rejected as well: `squeezed[i - 1]` below would otherwise wrap around to the tail.
    if not 0 < i < n:
        raise ValueError(f"Expected 0 < i < n, got i={i}, n={n}")

    squeezed = np.flip(np.arange(n)) + 1  # [n, n-1, ..., 2, 1]
    squeezed[:i] = timestep_spacing[:i]
    # Step size chosen so that the last kept timestep equals (n - i + 1) * k (integer division).
    k = squeezed[i - 1] // (n - i + 1)
    squeezed[i:] *= k
    return squeezed


PREDEFINED_TIMESTEP_SQUEEZERS = {
    # Tested with DPM 16-steps (reduced 16 -> 10 or 11 steps)
    "10,6": partial(squeeze_to_len_n_starting_from_index_i, 10, 6),
    "11,7": partial(squeeze_to_len_n_starting_from_index_i, 11, 7),
}

FlexibleUnetConfigurations = {
    # General parameters for all blocks
    "sample_size": 64,
    "temb_dim": 320 * 4,
    "resnet_eps": 1e-5,
    "resnet_act_fn": "silu",
    "num_attention_heads": 8,
    "cross_attention_dim": 768,
    # Controls modules execute order in unet's forward
    "mix_block_in_forward": True,
    # Down blocks parameters
    "down_blocks_in_channels": [320, 320, 640],
    "down_blocks_out_channels": [320, 640, 1280],
    "down_blocks_num_attentions": [0, 1, 3],
    "down_blocks_num_resnets": [2, 2, 1],
    "add_downsample": [True, True, False],
    # Middle block parameters
    "add_upsample_mid_block": None,
    "mid_num_resnets": 0,
    "mid_num_attentions": 0,
    # Up block parameters
    "prev_output_channels": [1280, 1280, 640],
    "up_blocks_num_attentions": [5, 3, 0],
    "up_blocks_num_resnets": [2, 3, 3],
    "add_upsample": [True, True, False],
}


class SqueezedDPMSolverMultistepScheduler(DPMSolverMultistepScheduler, SchedulerMixin):
    """
    This is a copy-paste from Diffuser's `DPMSolverMultistepScheduler`, with minor differences:
        * Defaults are modified to accommodate DeciDiffusion
        * It supports a squeezer to squeeze the number of inference steps to a smaller number
    //!\\ IMPORTANT: the actual number of inference steps is deduced by the squeezer, and not the pipeline!
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "squaredcos_cap_v2",  # NOTE THIS DEFAULT VALUE
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "v_prediction",  # NOTE THIS DEFAULT VALUE
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "heun",  # NOTE THIS DEFAULT VALUE
        lower_order_final: bool = True,
        use_karras_sigmas: Optional[bool] = False,
        lambda_min_clipped: float = -3.0,  # NOTE THIS DEFAULT VALUE
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 1,
        squeeze_mode: Optional[str] = None,  # NOTE THIS ADDITION. Supports keys from `PREDEFINED_TIMESTEP_SQUEEZERS` defined above
    ):
        """
        Build the scheduler and resolve the optional timestep squeezer.

        `squeeze_mode` is looked up in `PREDEFINED_TIMESTEP_SQUEEZERS`; `None` or an unknown key
        results in no squeezing. All other arguments are forwarded to the parent scheduler.

        Raises:
            NotImplementedError: if `use_karras_sigmas` is truthy.
        """
        # NOTE: `.get` means an unrecognised `squeeze_mode` silently disables squeezing.
        self._squeezer = PREDEFINED_TIMESTEP_SQUEEZERS.get(squeeze_mode)

        if use_karras_sigmas:
            raise NotImplementedError("Squeezing isn't tested with `use_karras_sigmas`. Please provide `use_karras_sigmas=False`")

        super().__init__(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            solver_order=solver_order,
            prediction_type=prediction_type,
            thresholding=thresholding,
            dynamic_thresholding_ratio=dynamic_thresholding_ratio,
            sample_max_value=sample_max_value,
            algorithm_type=algorithm_type,
            solver_type=solver_type,
            lower_order_final=lower_order_final,
            use_karras_sigmas=False,
            lambda_min_clipped=lambda_min_clipped,
            variance_type=variance_type,
            timestep_spacing=timestep_spacing,
            steps_offset=steps_offset,
        )

    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        When a squeezer is configured, the parent's timesteps are squeezed afterwards and
        `sigmas`, `timesteps` and `num_inference_steps` are overwritten accordingly.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        super().set_timesteps(num_inference_steps=num_inference_steps, device=device)
        if self._squeezer is not None:
            timesteps = self._squeezer(self.timesteps.cpu())
            # Recompute the sigmas for the squeezed timesteps by interpolating the full sigma table.
            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            # Trailing sigma is derived from the first entry of `alphas_cumprod`.
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
            self.sigmas = torch.from_numpy(sigmas)
            self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
            # The effective step count is dictated by the squeezer, not by the caller.
            self.num_inference_steps = len(timesteps)


class FlexibleIdentityBlock(nn.Module):
    """Pass-through block: accepts the same arguments as the other blocks and returns `hidden_states` unchanged."""

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        return hidden_states


class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
    """
    `UNet2DConditionModel` whose down / mid / up blocks are replaced by the flexible blocks of this file,
    built from the class-level `configurations` dictionary.
    """

    configurations = FlexibleUnetConfigurations

    @register_to_config
    def __init__(self):
        super().__init__(
            sample_size=self.configurations.get("sample_size", FlexibleUnetConfigurations["sample_size"]),
            cross_attention_dim=self.configurations.get("cross_attention_dim", FlexibleUnetConfigurations["cross_attention_dim"]),
        )

        num_attention_heads = self.configurations.get("num_attention_heads")
        cross_attention_dim = self.configurations.get("cross_attention_dim")
        mix_block_in_forward = self.configurations.get("mix_block_in_forward")
        resnet_act_fn = self.configurations.get("resnet_act_fn")
        resnet_eps = self.configurations.get("resnet_eps")
        temb_dim = self.configurations.get("temb_dim")

        ###############
        # Down blocks #
        ###############
        down_blocks_num_attentions = self.configurations.get("down_blocks_num_attentions")
        down_blocks_out_channels = self.configurations.get("down_blocks_out_channels")
        down_blocks_in_channels = self.configurations.get("down_blocks_in_channels")
        down_blocks_num_resnets = self.configurations.get("down_blocks_num_resnets")
        add_downsample = self.configurations.get("add_downsample")

        # Overrides the down blocks created by the parent constructor.
        self.down_blocks = nn.ModuleList()

        for i, (in_c, out_c, n_res, n_att, add_down) in enumerate(
            zip(down_blocks_in_channels, down_blocks_out_channels, down_blocks_num_resnets, down_blocks_num_attentions, add_downsample)
        ):
            last_block = i == len(down_blocks_in_channels) - 1
            self.down_blocks.append(
                FlexibleCrossAttnDownBlock2D(
                    in_channels=in_c,
                    out_channels=out_c,
                    temb_channels=temb_dim,
                    num_resnets=n_res,
                    num_attentions=n_att,
                    resnet_eps=resnet_eps,
                    resnet_act_fn=resnet_act_fn,
                    num_attention_heads=num_attention_heads,
                    cross_attention_dim=cross_attention_dim,
                    add_downsample=add_down,
                    last_block=last_block,
                    mix_block_in_forward=mix_block_in_forward,
                )
            )

        ###############
        # Mid blocks  #
        ###############
        mid_block_add_upsample = self.configurations.get("add_upsample_mid_block")
        mid_num_attentions = self.configurations.get("mid_num_attentions")
        mid_num_resnets = self.configurations.get("mid_num_resnets")

        # With neither resnets nor attentions configured, the mid block is a pure pass-through.
        if mid_num_resnets == mid_num_attentions == 0:
            self.mid_block = FlexibleIdentityBlock()
        else:
            self.mid_block = FlexibleUNetMidBlock2DCrossAttn(
                in_channels=down_blocks_out_channels[-1],
                temb_channels=temb_dim,
                resnet_act_fn=resnet_act_fn,
                resnet_eps=resnet_eps,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads,
                num_resnets=mid_num_resnets,
                num_attentions=mid_num_attentions,
                mix_block_in_forward=mix_block_in_forward,
                add_upsample=mid_block_add_upsample,
            )

        ###############
        # Up blocks   #
        ###############
        up_blocks_num_attentions = self.configurations.get("up_blocks_num_attentions")
        up_blocks_num_resnets = self.configurations.get("up_blocks_num_resnets")
        prev_output_channels = self.configurations.get("prev_output_channels")
        up_upsample = self.configurations.get("add_upsample")

        # Overrides the up blocks created by the parent constructor.
        self.up_blocks = nn.ModuleList()
        for in_c, out_c, prev_out, n_res, n_att, add_up in zip(
            reversed(down_blocks_in_channels),
            reversed(down_blocks_out_channels),
            prev_output_channels,
            up_blocks_num_resnets,
            up_blocks_num_attentions,
            up_upsample,
        ):
            self.up_blocks.append(
                FlexibleCrossAttnUpBlock2D(
                    in_channels=in_c,
                    out_channels=out_c,
                    prev_output_channel=prev_out,
                    temb_channels=temb_dim,
                    num_resnets=n_res,
                    num_attentions=n_att,
                    resnet_eps=resnet_eps,
                    resnet_act_fn=resnet_act_fn,
                    num_attention_heads=num_attention_heads,
                    cross_attention_dim=cross_attention_dim,
                    add_upsample=add_up,
                    mix_block_in_forward=mix_block_in_forward,
                )
            )


class FlexibleCrossAttnDownBlock2D(nn.Module):
    """
    Down block with independent numbers of resnets and cross-attention transformers.

    Modules are interleaved (resnet, transformer, resnet, ...) for as long as either kind remains;
    with `mix_block_in_forward=False` they are reordered so that all resnets run first.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,
        add_downsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        last_block: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()

        self.last_block = last_block
        self.mix_block_in_forward = mix_block_in_forward
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        modules = []

        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        # zip_longest pads the shorter list with False, so iteration continues until both kinds are exhausted.
        for i, (add_resnet, add_cross_attention) in enumerate(itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
            # Only the first iteration sees the block's input width; later ones operate on `out_channels`.
            in_channels = in_channels if i == 0 else out_channels

            if add_resnet:
                modules.append(
                    ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )

            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        if not mix_block_in_forward:
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList([Downsample2D(out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op")])
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        Run every module in order and collect the skip connections.

        Returns:
            `(hidden_states, output_states)` where `output_states` is a tuple holding the hidden states
            after each resnet and, when a downsampler exists and this is not the last block, after the
            downsampling.

        Raises:
            ValueError: if `modules_list` contains a module that is neither a resnet nor a transformer.
        """
        output_states = ()

        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)
            elif isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            else:
                raise ValueError(f"Got an unexpected module in modules list! {type(module)}")

            # Only resnet outputs are stored as residuals; the up blocks pop one residual per resnet.
            if isinstance(module, ResnetBlock2D):
                output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            if not self.last_block:
                output_states = output_states + (hidden_states,)

        return hidden_states, output_states


class FlexibleCrossAttnUpBlock2D(nn.Module):
    """
    Up block with independent numbers of resnets and cross-attention transformers.

    Each resnet consumes one residual from the matching down path (concatenated on the channel axis).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()
        modules = []

        # WARNING: This parameter is filled with number of resnets and used within StableDiffusionPipeline
        self.resnets = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for i, (add_resnet, add_cross_attention) in enumerate(itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
            # The last resnet receives a skip of width `in_channels`; the others a skip of width `out_channels`.
            res_skip_channels = in_channels if (i == len(add_resnets) - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            if add_resnet:
                self.resnets += [True]
                modules.append(
                    ResnetBlock2D(
                        in_channels=resnet_in_channels + res_skip_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )

            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        if not mix_block_in_forward:
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        Run every module in order; each resnet first pops the last residual from
        `res_hidden_states_tuple` and concatenates it on dim 1. Returns the (optionally upsampled)
        hidden states.
        """
        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                # Consume residuals from the end of the tuple.
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                hidden_states = module(hidden_states, temb)
            if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states


class FlexibleUNetMidBlock2DCrossAttn(nn.Module):
    """
    Mid block: one leading resnet followed by interleaved (transformer, resnet) modules,
    and an optional upsampler (identity when `add_upsample` is falsy).
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
        add_upsample: bool = True,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # There is always at least one resnet
        modules = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]

        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for i, (add_resnet, add_cross_attention) in enumerate(itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                    )
                )

            if add_resnet:
                modules.append(
                    ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )

        if not mix_block_in_forward:
            # Stable sort: the leading resnet stays at index 0.
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        self.upsamplers = nn.ModuleList([nn.Identity()])
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Apply the leading resnet, then every remaining module in order, then the upsampler(s)."""
        hidden_states = self.modules_list[0](hidden_states, temb)

        # Start at index 1: the leading resnet has just been applied above, and iterating the whole
        # list would run it a second time.
        for module in itertools.islice(self.modules_list, 1, None):
            if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            elif isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)

        return hidden_states


class FlexibleTransformer2DModel(ModelMixin, ConfigMixin):
    """
    2D transformer: group-norm, input projection (linear or 1x1 conv), a stack of
    `BasicTransformerBlock`s, output projection back to `in_channels` and a residual connection.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        only_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels

        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.use_linear_projection = use_linear_projection

        if self.use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        # Define output layers
        # NOTE: `self.out_channels` is stored, but the output projection maps back to `in_channels`
        # (required by the residual addition in `forward`).
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ):
        """
        Apply the transformer to a 4D input (unpacked as batch, channels, height, width).

        Args:
            hidden_states: Input tensor; also used as the residual added to the output.
            encoder_hidden_states, timestep, class_labels, cross_attention_kwargs, attention_mask,
            encoder_attention_mask: Forwarded to every transformer block.
            return_dict: When `True`, return a `Transformer2DModelOutput`; otherwise a 1-tuple.

        Returns:
            `Transformer2DModelOutput(sample=output)` if `return_dict` else `(output,)`.
        """
        # 1. Input
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            # Linear projection works on the flattened sequence, so `inner_dim` here is the channel
            # count before `proj_in` (it is reused for the final reshape after `proj_out`).
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        # The condition used to be inverted (`return_dict=True` produced the tuple).
        # Callers in this file pass `return_dict=False` and index `[0]`, which is unaffected.
        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)


class DeciDiffusionPipeline(StableDiffusionPipeline):
    """
    Stable-Diffusion pipeline that always uses `FlexibleUNet2DConditionModel` and
    `SqueezedDPMSolverMultistepScheduler`, regardless of the `unet` / `scheduler` passed in.
    """

    deci_default_squeeze_mode = "10,6"
    deci_default_number_of_iterations = 16
    deci_default_guidance_rescale = 0.7

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        # Replace UNet with Deci`s unet
        del unet
        unet = FlexibleUNet2DConditionModel()

        # Replace with custom scheduler
        del scheduler
        scheduler = SqueezedDPMSolverMultistepScheduler(squeeze_mode=self.deci_default_squeeze_mode)

        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )

        # NOTE(review): the parent constructor presumably registers these modules already; kept as in the original.
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 16,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.7,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 16):
                The number of denoising steps requested from the scheduler. When the scheduler has a squeezer
                configured, the number of steps actually executed is the length of the squeezed schedule.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.7):
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        # NOTE: `len(timesteps)` can be smaller than `num_inference_steps` when the scheduler squeezes
        # the schedule, in which case this value is negative and every step counts as past warm-up.
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        # Images flagged by the safety checker are not denormalized.
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)