# Copyright 2024 LLM-grounded Video Diffusion Models (LVD) Team and The HuggingFace Team. All rights reserved. # Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. # Copyright 2024 The ModelScope Team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import UNet2DConditionLoadersMixin from diffusers.models.activations import get_activation from diffusers.models.attention import (Attention, FeedForward, GatedSelfAttentionDense, _chunked_feed_forward) from diffusers.models.attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor) from diffusers.models.embeddings import (ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection, SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps) from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear from diffusers.models.modeling_utils import ModelMixin from diffusers.models.normalization import (AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero) from diffusers.models.resnet import (Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D) from diffusers.models.transformer_2d import Transformer2DModelOutput from diffusers.models.transformers.transformer_temporal import TransformerTemporalModelOutput from diffusers.models.unets.unet_3d_blocks import ( CrossAttnDownBlockMotion, CrossAttnDownBlockSpatioTemporal, CrossAttnUpBlockMotion, CrossAttnUpBlockSpatioTemporal, DownBlockMotion, DownBlockSpatioTemporal, UpBlock3D, UpBlockMotion, UpBlockSpatioTemporal) from diffusers.models.unets.unet_3d_condition import UNet3DConditionOutput from diffusers.utils import (USE_PEFT_BACKEND, deprecate, is_torch_version, logging) from diffusers.utils.torch_utils import apply_freeu, maybe_allow_in_graph from torch import nn logger = logging.get_logger(__name__) # pylint: disable=invalid-name class FourierEmbedder(nn.Module): def __init__(self, num_freqs=64, temperature=100): super().__init__() self.num_freqs = num_freqs self.temperature = temperature freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs) freq_bands = freq_bands[None, None, None] self.register_buffer("freq_bands", freq_bands, persistent=False) def __call__(self, x): x = self.freq_bands * x.unsqueeze(-1) return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1) class PositionNet(nn.Module): def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8): super().__init__() self.positive_len = positive_len self.out_dim = out_dim self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy if isinstance(out_dim, tuple): out_dim = out_dim[0] if feature_type == "text-only": self.linears = nn.Sequential( nn.Linear(self.positive_len + self.position_dim, 512), nn.SiLU(), nn.Linear(512, 512), nn.SiLU(), nn.Linear(512, out_dim), ) self.null_positive_feature = torch.nn.Parameter( torch.zeros([self.positive_len])) elif feature_type == "text-image": self.linears_text = nn.Sequential( nn.Linear(self.positive_len + self.position_dim, 512), nn.SiLU(), nn.Linear(512, 512), nn.SiLU(), nn.Linear(512, out_dim), ) self.linears_image = nn.Sequential( nn.Linear(self.positive_len + self.position_dim, 512), nn.SiLU(), nn.Linear(512, 512), nn.SiLU(), nn.Linear(512, out_dim), ) self.null_text_feature = torch.nn.Parameter( torch.zeros([self.positive_len])) self.null_image_feature = torch.nn.Parameter( torch.zeros([self.positive_len])) self.null_position_feature = torch.nn.Parameter( torch.zeros([self.position_dim])) def forward( self, boxes, masks, positive_embeddings=None, phrases_masks=None, image_masks=None, phrases_embeddings=None, image_embeddings=None, ): masks = masks.unsqueeze(-1) # embedding position (it may includes padding as placeholder) xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C # learnable null embedding xyxy_null = self.null_position_feature.view(1, 1, -1) # replace padding with learnable null embedding xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null # positionet with text only information if positive_embeddings is not None: # learnable null embedding positive_null = self.null_positive_feature.view(1, 1, -1) # replace padding with learnable null embedding positive_embeddings = positive_embeddings * \ masks + (1 - masks) * positive_null objs = self.linears( torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) # positionet with text and image infomation else: phrases_masks = phrases_masks.unsqueeze(-1) image_masks = image_masks.unsqueeze(-1) # learnable null embedding text_null = self.null_text_feature.view(1, 1, -1) image_null = self.null_image_feature.view(1, 1, -1) # replace padding with learnable null embedding phrases_embeddings = phrases_embeddings * \ phrases_masks + (1 - phrases_masks) * text_null image_embeddings = image_embeddings * \ image_masks + (1 - image_masks) * image_null objs_text = self.linears_text( torch.cat([phrases_embeddings, xyxy_embedding], dim=-1)) objs_image = self.linears_image( torch.cat([image_embeddings, xyxy_embedding], dim=-1)) objs = torch.cat([objs_text, objs_image], dim=1) return objs class Transformer2DModel(ModelMixin, ConfigMixin): """ A 2D Transformer model for image-like data. Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): The number of channels in the input and output (specify if the input is **continuous**). num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). This is fixed during training since it is used to learn a number of position embeddings. num_vector_embeds (`int`, *optional*): The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). Includes the class for the masked latent pixel. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. num_embeds_ada_norm ( `int`, *optional*): The number of diffusion steps used during training. Pass if at least one of the norm_layers is `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. attention_bias (`bool`, *optional*): Configure if the `TransformerBlocks` attention should contain a bias parameter. """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: Optional[int] = None, num_vector_embeds: Optional[int] = None, patch_size: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, use_linear_projection: bool = False, only_cross_attention: bool = False, double_self_attention: bool = False, upcast_attention: bool = False, norm_type: str = "layer_norm", norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, attention_type: str = "default", caption_channels: int = None, ): super().__init__() self.use_linear_projection = use_linear_projection self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` # Define whether input is continuous or discrete depending on configuration self.is_input_continuous = ( in_channels is not None) and (patch_size is None) self.is_input_vectorized = num_vector_embeds is not None self.is_input_patches = in_channels is not None and patch_size is not None if norm_type == "layer_norm" and num_embeds_ada_norm is not None: deprecation_message = ( f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" " would be very nice if you could open a Pull request for the `transformer/config.json` file" ) deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) norm_type = "ada_norm" if self.is_input_continuous and self.is_input_vectorized: raise ValueError( f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" " sure that either `in_channels` or `num_vector_embeds` is None." ) elif self.is_input_vectorized and self.is_input_patches: raise ValueError( f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" " sure that either `num_vector_embeds` or `num_patches` is None." ) elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: raise ValueError( f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." ) # 2. Define input layers if self.is_input_continuous: self.in_channels = in_channels self.norm = torch.nn.GroupNorm( num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) if use_linear_projection: self.proj_in = linear_cls(in_channels, inner_dim) else: self.proj_in = conv_cls( in_channels, inner_dim, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" self.height = sample_size self.width = sample_size self.num_vector_embeds = num_vector_embeds self.num_latent_pixels = self.height * self.width self.latent_image_embedding = ImagePositionalEmbeddings( num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width ) elif self.is_input_patches: assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" self.height = sample_size self.width = sample_size self.patch_size = patch_size # => 64 (= 512 pixart) has interpolation scale 1 interpolation_scale = self.config.sample_size // 64 interpolation_scale = max(interpolation_scale, 1) self.pos_embed = PatchEmbed( height=sample_size, width=sample_size, patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, interpolation_scale=interpolation_scale, ) # 3. Define transformers blocks self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, only_cross_attention=only_cross_attention, double_self_attention=double_self_attention, upcast_attention=upcast_attention, norm_type=norm_type, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, attention_type=attention_type, ) for d in range(num_layers) ] ) # 4. Define output layers self.out_channels = in_channels if out_channels is None else out_channels if self.is_input_continuous: # TODO: should use out_channels for continuous projections if use_linear_projection: self.proj_out = linear_cls(inner_dim, in_channels) else: self.proj_out = conv_cls( inner_dim, in_channels, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) elif self.is_input_patches and norm_type != "ada_norm_single": self.norm_out = nn.LayerNorm( inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) self.proj_out_2 = nn.Linear( inner_dim, patch_size * patch_size * self.out_channels) elif self.is_input_patches and norm_type == "ada_norm_single": self.norm_out = nn.LayerNorm( inner_dim, elementwise_affine=False, eps=1e-6) self.scale_shift_table = nn.Parameter( torch.randn(2, inner_dim) / inner_dim**0.5) self.proj_out = nn.Linear( inner_dim, patch_size * patch_size * self.out_channels) # 5. PixArt-Alpha blocks. self.adaln_single = None self.use_additional_conditions = False if norm_type == "ada_norm_single": self.use_additional_conditions = self.config.sample_size == 128 # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use # additional conditions until we find better name self.adaln_single = AdaLayerNormSingle( inner_dim, use_additional_conditions=self.use_additional_conditions) self.caption_projection = None if caption_channels is not None: self.caption_projection = PixArtAlphaTextProjection( in_features=caption_channels, hidden_size=inner_dim) self.gradient_checkpointing = False def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, timestep: Optional[torch.LongTensor] = None, added_cond_kwargs: Dict[str, torch.Tensor] = None, class_labels: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ): """ The [`Transformer2DModel`] forward method. Args: hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): Input `hidden_states`. encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep ( `torch.LongTensor`, *optional*): Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in `AdaLayerZeroNorm`. cross_attention_kwargs ( `Dict[str, Any]`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). attention_mask ( `torch.Tensor`, *optional*): An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. encoder_attention_mask ( `torch.Tensor`, *optional*): Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: * Mask `(batch, sequence_length)` True = keep, False = discard. * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format above. This bias will be added to the cross-attention scores. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None and attention_mask.ndim == 2: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = ( 1 - attention_mask.to(hidden_states.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: encoder_attention_mask = ( 1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # Retrieve lora scale. lora_scale = cross_attention_kwargs.get( "scale", 1.0) if cross_attention_kwargs is not None else 1.0 # 1. Input if self.is_input_continuous: batch, _, height, width = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) if not self.use_linear_projection: hidden_states = ( self.proj_in(hidden_states, scale=lora_scale) if not USE_PEFT_BACKEND else self.proj_in(hidden_states) ) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute( 0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute( 0, 2, 3, 1).reshape(batch, height * width, inner_dim) hidden_states = ( self.proj_in(hidden_states, scale=lora_scale) if not USE_PEFT_BACKEND else self.proj_in(hidden_states) ) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) elif self.is_input_patches: height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size hidden_states = self.pos_embed(hidden_states) if self.adaln_single is not None: if self.use_additional_conditions and added_cond_kwargs is None: raise ValueError( "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." ) batch_size = hidden_states.shape[0] timestep, embedded_timestep = self.adaln_single( timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype ) # 2. Blocks if self.caption_projection is not None: batch_size = hidden_states.shape[0] encoder_hidden_states = self.caption_projection( encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view( batch_size, -1, hidden_states.shape[-1]) for block in self.transformer_blocks: if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = { "use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels, **ckpt_kwargs, ) else: hidden_states = block( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, ) # 3. Output if self.is_input_continuous: if not self.use_linear_projection: hidden_states = hidden_states.reshape( batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() hidden_states = ( self.proj_out(hidden_states, scale=lora_scale) if not USE_PEFT_BACKEND else self.proj_out(hidden_states) ) else: hidden_states = ( self.proj_out(hidden_states, scale=lora_scale) if not USE_PEFT_BACKEND else self.proj_out(hidden_states) ) hidden_states = hidden_states.reshape( batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual elif self.is_input_vectorized: hidden_states = self.norm_out(hidden_states) logits = self.out(hidden_states) # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) logits = logits.permute(0, 2, 1) # log(p(x_0)) output = F.log_softmax(logits.double(), dim=1).float() if self.is_input_patches: if self.config.norm_type != "ada_norm_single": conditioning = self.transformer_blocks[0].norm1.emb( timestep, class_labels, hidden_dtype=hidden_states.dtype ) shift, scale = self.proj_out_1( F.silu(conditioning)).chunk(2, dim=1) hidden_states = self.norm_out( hidden_states) * (1 + scale[:, None]) + shift[:, None] hidden_states = self.proj_out_2(hidden_states) elif self.config.norm_type == "ada_norm_single": shift, scale = ( self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) hidden_states = self.norm_out(hidden_states) # Modulation hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.squeeze(1) # unpatchify if self.adaln_single is None: height = width = int(hidden_states.shape[1] ** 0.5) hidden_states = hidden_states.reshape( shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) ) hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) output = hidden_states.reshape( shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) class TransformerTemporalModel(ModelMixin, ConfigMixin): """ A Transformer model for video-like data. Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): The number of channels in the input and output (specify if the input is **continuous**). num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. attention_bias (`bool`, *optional*): Configure if the `TransformerBlock` attention should contain a bias parameter. sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). This is fixed during training since it is used to learn a number of position embeddings. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported activation functions. norm_elementwise_affine (`bool`, *optional*): Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. double_self_attention (`bool`, *optional*): Configure if each `TransformerBlock` should contain two self-attention layers. positional_embeddings: (`str`, *optional*): The type of positional embeddings to apply to the sequence input before passing use. num_positional_embeddings: (`int`, *optional*): The maximum length of the sequence over which to apply positional embeddings. """ @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: Optional[int] = None, activation_fn: str = "geglu", norm_elementwise_affine: bool = True, double_self_attention: bool = True, positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, ): super().__init__() self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim self.in_channels = in_channels self.norm = torch.nn.GroupNorm( num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) self.proj_in = nn.Linear(in_channels, inner_dim) # 3. Define transformers blocks self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, attention_bias=attention_bias, double_self_attention=double_self_attention, norm_elementwise_affine=norm_elementwise_affine, positional_embeddings=positional_embeddings, num_positional_embeddings=num_positional_embeddings, ) for d in range(num_layers) ] ) self.proj_out = nn.Linear(inner_dim, in_channels) def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.LongTensor] = None, timestep: Optional[torch.LongTensor] = None, class_labels: torch.LongTensor = None, num_frames: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> TransformerTemporalModelOutput: """ The [`TransformerTemporal`] forward method. Args: hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): Input hidden_states. encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep ( `torch.LongTensor`, *optional*): Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in `AdaLayerZeroNorm`. num_frames (`int`, *optional*, defaults to 1): The number of frames to be processed per batch. This is used to reshape the hidden states. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ # 1. Input batch_frames, channel, height, width = hidden_states.shape batch_size = batch_frames // num_frames residual = hidden_states hidden_states = hidden_states[None, :].reshape( batch_size, num_frames, channel, height, width) hidden_states = hidden_states.permute(0, 2, 1, 3, 4) hidden_states = self.norm(hidden_states) hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape( batch_size * height * width, num_frames, channel) hidden_states = self.proj_in(hidden_states) # 2. Blocks for block in self.transformer_blocks: hidden_states = block( hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, ) # 3. Output hidden_states = self.proj_out(hidden_states) hidden_states = ( hidden_states[None, None, :] .reshape(batch_size, height, width, num_frames, channel) .permute(0, 3, 4, 1, 2) .contiguous() ) hidden_states = hidden_states.reshape( batch_frames, channel, height, width) output = hidden_states + residual if not return_dict: return (output,) return TransformerTemporalModelOutput(sample=output) @maybe_allow_in_graph class BasicTransformerBlock(nn.Module): r""" A basic Transformer block. Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. attention_bias (: obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. only_cross_attention (`bool`, *optional*): Whether to use only cross-attention layers. In this case two cross attention layers are used. double_self_attention (`bool`, *optional*): Whether to use two self-attention layers. In this case no cross attention layers are used. upcast_attention (`bool`, *optional*): Whether to upcast the attention computation to float32. This is useful for mixed precision training. norm_elementwise_affine (`bool`, *optional*, defaults to `True`): Whether to use learnable elementwise affine parameters for normalization. norm_type (`str`, *optional*, defaults to `"layer_norm"`): The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. final_dropout (`bool` *optional*, defaults to False): Whether to apply a final dropout after the last feed-forward layer. attention_type (`str`, *optional*, defaults to `"default"`): The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. positional_embeddings (`str`, *optional*, defaults to `None`): The type of positional embeddings to apply to. num_positional_embeddings (`int`, *optional*, defaults to `None`): The maximum number of positional embeddings to apply. """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' norm_type: str = "layer_norm", norm_eps: float = 1e-5, final_dropout: bool = False, attention_type: str = "default", positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, ada_norm_bias: Optional[int] = None, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, ): super().__init__() self.only_cross_attention = only_cross_attention self.use_ada_layer_norm_zero = ( num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" self.use_ada_layer_norm = ( num_embeds_ada_norm is not None) and norm_type == "ada_norm" self.use_ada_layer_norm_single = norm_type == "ada_norm_single" self.use_layer_norm = norm_type == "layer_norm" self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: raise ValueError( f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." ) if positional_embeddings and (num_positional_embeddings is None): raise ValueError( "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." ) if positional_embeddings == "sinusoidal": self.pos_embed = SinusoidalPositionalEmbedding( dim, max_seq_length=num_positional_embeddings) else: self.pos_embed = None # Define 3 blocks. Each block has its own normalization layer. # 1. Self-Attn if self.use_ada_layer_norm: self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) elif self.use_ada_layer_norm_zero: self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) elif self.use_ada_layer_norm_continuous: self.norm1 = AdaLayerNormContinuous( dim, ada_norm_continous_conditioning_embedding_dim, norm_elementwise_affine, norm_eps, ada_norm_bias, "rms_norm", ) else: self.norm1 = nn.LayerNorm( dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, out_bias=attention_out_bias, ) # 2. Cross-Attn if cross_attention_dim is not None or double_self_attention: # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during # the second cross attention block. if self.use_ada_layer_norm: self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) elif self.use_ada_layer_norm_continuous: self.norm2 = AdaLayerNormContinuous( dim, ada_norm_continous_conditioning_embedding_dim, norm_elementwise_affine, norm_eps, ada_norm_bias, "rms_norm", ) else: self.norm2 = nn.LayerNorm( dim, norm_eps, norm_elementwise_affine) self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim if not double_self_attention else None, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, out_bias=attention_out_bias, ) # is self-attn if encoder_hidden_states is none else: self.norm2 = None self.attn2 = None # 3. Feed-forward if self.use_ada_layer_norm_continuous: self.norm3 = AdaLayerNormContinuous( dim, ada_norm_continous_conditioning_embedding_dim, norm_elementwise_affine, norm_eps, ada_norm_bias, "layer_norm", ) elif not self.use_ada_layer_norm_single: self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) # 4. Fuser if attention_type == "gated" or attention_type == "gated-text-image": self.fuser = GatedSelfAttentionDense( dim, cross_attention_dim, num_attention_heads, attention_head_dim) else: self.fuser = None # 5. Scale-shift for PixArt-Alpha. if self.use_ada_layer_norm_single: self.scale_shift_table = nn.Parameter( torch.randn(6, dim) / dim**0.5) # let chunk size default to None self._chunk_size = None self._chunk_dim = 0 def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): # Sets chunk feed-forward self._chunk_size = chunk_size self._chunk_dim = dim def forward( self, hidden_states: torch.FloatTensor, attention_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, timestep: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.FloatTensor: # Notice that normalization is always applied before the real computation in the following blocks. # 0. Self-Attention batch_size = hidden_states.shape[0] if self.use_ada_layer_norm: norm_hidden_states = self.norm1(hidden_states, timestep) elif self.use_ada_layer_norm_zero: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype ) elif self.use_layer_norm: norm_hidden_states = self.norm1(hidden_states) elif self.use_ada_layer_norm_continuous: norm_hidden_states = self.norm1( hidden_states, added_cond_kwargs["pooled_text_emb"]) elif self.use_ada_layer_norm_single: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) ).chunk(6, dim=1) norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = norm_hidden_states * \ (1 + scale_msa) + shift_msa norm_hidden_states = norm_hidden_states.squeeze(1) else: raise ValueError("Incorrect norm used") if self.pos_embed is not None: norm_hidden_states = self.pos_embed(norm_hidden_states) # 1. Retrieve lora scale. lora_scale = cross_attention_kwargs.get( "scale", 1.0) if cross_attention_kwargs is not None else 1.0 # 2. Prepare GLIGEN inputs cross_attention_kwargs = cross_attention_kwargs.copy( ) if cross_attention_kwargs is not None else {} lvd_gligen_kwargs = cross_attention_kwargs.pop("gligen", None) attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, **cross_attention_kwargs, ) if self.use_ada_layer_norm_zero: attn_output = gate_msa.unsqueeze(1) * attn_output elif self.use_ada_layer_norm_single: attn_output = gate_msa * attn_output hidden_states = attn_output + hidden_states if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) # 2.5 GLIGEN Control if lvd_gligen_kwargs is not None: if self.fuser is not None: hidden_states = self.fuser( hidden_states, lvd_gligen_kwargs["objs"]) # 3. Cross-Attention if self.attn2 is not None: if self.use_ada_layer_norm: norm_hidden_states = self.norm2(hidden_states, timestep) elif self.use_ada_layer_norm_zero or self.use_layer_norm: norm_hidden_states = self.norm2(hidden_states) elif self.use_ada_layer_norm_single: # For PixArt norm2 isn't applied here: # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 norm_hidden_states = hidden_states elif self.use_ada_layer_norm_continuous: norm_hidden_states = self.norm2( hidden_states, added_cond_kwargs["pooled_text_emb"]) else: raise ValueError("Incorrect norm") if self.pos_embed is not None and self.use_ada_layer_norm_single is False: norm_hidden_states = self.pos_embed(norm_hidden_states) attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, **cross_attention_kwargs, ) hidden_states = attn_output + hidden_states # 4. Feed-forward if self.use_ada_layer_norm_continuous: norm_hidden_states = self.norm3( hidden_states, added_cond_kwargs["pooled_text_emb"]) elif not self.use_ada_layer_norm_single: norm_hidden_states = self.norm3(hidden_states) if self.use_ada_layer_norm_zero: norm_hidden_states = norm_hidden_states * \ (1 + scale_mlp[:, None]) + shift_mlp[:, None] if self.use_ada_layer_norm_single: norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * \ (1 + scale_mlp) + shift_mlp if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory ff_output = _chunked_feed_forward( self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale ) else: ff_output = self.ff(norm_hidden_states, scale=lora_scale) if self.use_ada_layer_norm_zero: ff_output = gate_mlp.unsqueeze(1) * ff_output elif self.use_ada_layer_norm_single: ff_output = gate_mlp * ff_output hidden_states = ff_output + hidden_states if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) return hidden_states def get_down_block( down_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_downsample: bool, resnet_eps: float, resnet_act_fn: str, num_attention_heads: int, resnet_groups: Optional[int] = None, cross_attention_dim: Optional[int] = None, downsample_padding: Optional[int] = None, dual_cross_attention: bool = False, use_linear_projection: bool = True, only_cross_attention: bool = False, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", temporal_num_attention_heads: int = 8, temporal_max_seq_length: int = 32, transformer_layers_per_block: int = 1, attention_type: str = "default", ) -> Union[ "DownBlock3D", "CrossAttnDownBlock3D", "DownBlockMotion", "CrossAttnDownBlockMotion", "DownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", ]: if down_block_type == "DownBlock3D": return DownBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "CrossAttnDownBlock3D": if cross_attention_dim is None: raise ValueError( "cross_attention_dim must be specified for CrossAttnDownBlock3D") return CrossAttnDownBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, ) if down_block_type == "DownBlockMotion": return DownBlockMotion( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, temporal_num_attention_heads=temporal_num_attention_heads, temporal_max_seq_length=temporal_max_seq_length, ) elif down_block_type == "CrossAttnDownBlockMotion": if cross_attention_dim is None: raise ValueError( "cross_attention_dim must be specified for CrossAttnDownBlockMotion") return CrossAttnDownBlockMotion( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, temporal_num_attention_heads=temporal_num_attention_heads, temporal_max_seq_length=temporal_max_seq_length, ) elif down_block_type == "DownBlockSpatioTemporal": # added for SDV return DownBlockSpatioTemporal( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, ) elif down_block_type == "CrossAttnDownBlockSpatioTemporal": # added for SDV if cross_attention_dim is None: raise ValueError( "cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal") return CrossAttnDownBlockSpatioTemporal( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, num_layers=num_layers, transformer_layers_per_block=transformer_layers_per_block, add_downsample=add_downsample, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, ) raise ValueError(f"{down_block_type} does not exist.") def get_up_block( up_block_type: str, num_layers: int, in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int, add_upsample: bool, resnet_eps: float, resnet_act_fn: str, num_attention_heads: int, resolution_idx: Optional[int] = None, resnet_groups: Optional[int] = None, cross_attention_dim: Optional[int] = None, dual_cross_attention: bool = False, use_linear_projection: bool = True, only_cross_attention: bool = False, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", attention_type: str = "default", temporal_num_attention_heads: int = 8, temporal_cross_attention_dim: Optional[int] = None, temporal_max_seq_length: int = 32, transformer_layers_per_block: int = 1, dropout: float = 0.0, ) -> Union[ "UpBlock3D", "CrossAttnUpBlock3D", "UpBlockMotion", "CrossAttnUpBlockMotion", "UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", ]: if up_block_type == "UpBlock3D": return UpBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, resolution_idx=resolution_idx, ) elif up_block_type == "CrossAttnUpBlock3D": if cross_attention_dim is None: raise ValueError( "cross_attention_dim must be specified for CrossAttnUpBlock3D") return CrossAttnUpBlock3D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resolution_idx=resolution_idx, attention_type=attention_type, ) if up_block_type == "UpBlockMotion": return UpBlockMotion( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, resolution_idx=resolution_idx, temporal_num_attention_heads=temporal_num_attention_heads, temporal_max_seq_length=temporal_max_seq_length, ) elif up_block_type == "CrossAttnUpBlockMotion": if cross_attention_dim is None: raise ValueError( "cross_attention_dim must be specified for CrossAttnUpBlockMotion") return CrossAttnUpBlockMotion( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resolution_idx=resolution_idx, temporal_num_attention_heads=temporal_num_attention_heads, temporal_max_seq_length=temporal_max_seq_length, ) elif up_block_type == "UpBlockSpatioTemporal": # added for SDV return UpBlockSpatioTemporal( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, resolution_idx=resolution_idx, add_upsample=add_upsample, ) elif up_block_type == "CrossAttnUpBlockSpatioTemporal": # added for SDV if cross_attention_dim is None: raise ValueError( "cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal") return CrossAttnUpBlockSpatioTemporal( in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, num_layers=num_layers, transformer_layers_per_block=transformer_layers_per_block, add_upsample=add_upsample, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, resolution_idx=resolution_idx, ) raise ValueError(f"{up_block_type} does not exist.") class UNetMidBlock3DCrossAttn(nn.Module): def __init__( self, in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads: int = 1, output_scale_factor: float = 1.0, cross_attention_dim: int = 1280, dual_cross_attention: bool = False, use_linear_projection: bool = True, upcast_attention: bool = False, attention_type: str = "default", ): super().__init__() self.has_cross_attention = True self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min( in_channels // 4, 32) # there is always at least one resnet resnets = [ ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ] temp_convs = [ TemporalConvLayer( in_channels, in_channels, dropout=0.1, norm_num_groups=resnet_groups, ) ] attentions = [] temp_attentions = [] for _ in range(num_layers): attentions.append( Transformer2DModel( in_channels // num_attention_heads, num_attention_heads, in_channels=in_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, ) ) temp_attentions.append( TransformerTemporalModel( in_channels // num_attention_heads, num_attention_heads, in_channels=in_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) temp_convs.append( TemporalConvLayer( in_channels, in_channels, dropout=0.1, norm_num_groups=resnet_groups, ) ) self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) self.attentions = nn.ModuleList(attentions) self.temp_attentions = nn.ModuleList(temp_attentions) def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, num_frames: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.temp_convs[0]( hidden_states, num_frames=num_frames) for attn, temp_attn, resnet, temp_conv in zip( self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] ): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] hidden_states = temp_attn( hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) return hidden_states class CrossAttnDownBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads: int = 1, cross_attention_dim: int = 1280, output_scale_factor: float = 1.0, downsample_padding: int = 1, add_downsample: bool = True, dual_cross_attention: bool = False, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, attention_type: str = "default", ): super().__init__() resnets = [] attentions = [] temp_attentions = [] temp_convs = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) temp_convs.append( TemporalConvLayer( out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, ) ) attentions.append( Transformer2DModel( out_channels // num_attention_heads, num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, attention_type=attention_type, ) ) temp_attentions.append( TransformerTemporalModel( out_channels // num_attention_heads, num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) self.attentions = nn.ModuleList(attentions) self.temp_attentions = nn.ModuleList(temp_attentions) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op", ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, num_frames: int = 1, cross_attention_kwargs: Dict[str, Any] = None, ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: # TODO(Patrick, William) - attention mask is not used output_states = () for resnet, temp_conv, attn, temp_attn in zip( self.resnets, self.temp_convs, self.attentions, self.temp_attentions ): hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] hidden_states = temp_attn( hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] output_states += (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states += (hidden_states,) return hidden_states, output_states class DownBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor: float = 1.0, add_downsample: bool = True, downsample_padding: int = 1, ): super().__init__() resnets = [] temp_convs = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) temp_convs.append( TemporalConvLayer( out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, ) ) self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op", ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, num_frames: int = 1, ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () for resnet, temp_conv in zip(self.resnets, self.temp_convs): hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) output_states += (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states += (hidden_states,) return hidden_states, output_states class CrossAttnUpBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads: int = 1, cross_attention_dim: int = 1280, output_scale_factor: float = 1.0, add_upsample: bool = True, dual_cross_attention: bool = False, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, attention_type: str = "default", resolution_idx: Optional[int] = None, ): super().__init__() resnets = [] temp_convs = [] attentions = [] temp_attentions = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads for i in range(num_layers): res_skip_channels = in_channels if ( i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) temp_convs.append( TemporalConvLayer( out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, ) ) attentions.append( Transformer2DModel( out_channels // num_attention_heads, num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, attention_type=attention_type, ) ) temp_attentions.append( TransformerTemporalModel( out_channels // num_attention_heads, num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) self.attentions = nn.ModuleList(attentions) self.temp_attentions = nn.ModuleList(temp_attentions) if add_upsample: self.upsamplers = nn.ModuleList( [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False self.resolution_idx = resolution_idx def forward( self, hidden_states: torch.FloatTensor, res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, num_frames: int = 1, cross_attention_kwargs: Dict[str, Any] = None, ) -> torch.FloatTensor: is_freeu_enabled = ( getattr(self, "s1", None) and getattr(self, "s2", None) and getattr(self, "b1", None) and getattr(self, "b2", None) ) # TODO(Patrick, William) - attention mask is not used for resnet, temp_conv, attn, temp_attn in zip( self.resnets, self.temp_convs, self.attentions, self.temp_attentions ): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] # FreeU: Only operate on the first two stages if is_freeu_enabled: hidden_states, res_hidden_states = apply_freeu( self.resolution_idx, hidden_states, res_hidden_states, s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2, ) hidden_states = torch.cat( [hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] hidden_states = temp_attn( hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states class GroundedUNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): Height and width of input/output sample. in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): The tuple of upsample blocks to use. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers is skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. """ _supports_gradient_checkpointing = False @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, down_block_types: Tuple[str, ...] = ( "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ), up_block_types: Tuple[str, ...] = ( "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", ), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1024, attention_head_dim: Union[int, Tuple[int]] = 64, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, attention_type: str = "default", ): super().__init__() self.sample_size = sample_size if num_attention_heads is not None: raise NotImplementedError( "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." ) # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) # input conv_in_kernel = 3 conv_out_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time time_embed_dim = block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], True, 0) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, ) self.transformer_in = TransformerTemporalModel( num_attention_heads=8, attention_head_dim=attention_head_dim, in_channels=block_out_channels[0], num_layers=1, norm_num_groups=norm_num_groups, ) # class embedding self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * \ len(down_block_types) # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=False, attention_type=attention_type, ) self.down_blocks.append(down_block) # mid self.mid_block = UNetMidBlock3DCrossAttn( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=False, attention_type=attention_type, ) # count how many layers upsample the images self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min( i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=layers_per_block + 1, in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=False, resolution_idx=i, attention_type=attention_type, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out if norm_num_groups is not None: self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) self.conv_act = get_activation("silu") else: self.conv_norm_out = None self.conv_act = None conv_out_padding = (conv_out_kernel - 1) // 2 self.conv_out = nn.Conv2d( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) if attention_type in ["gated", "gated-text-image"]: positive_len = 768 if isinstance(cross_attention_dim, int): positive_len = cross_attention_dim elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): positive_len = cross_attention_dim[0] feature_type = "text-only" if attention_type == "gated" else "text-image" self.position_net = PositionNet( positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type ) @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor( return_deprecated_lora=True) for sub_name, child in module.named_children(): fn_recursive_add_processors( f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * \ [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError( f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor( self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False ): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor, _remove_lora=_remove_lora) else: module.set_processor(processor.pop( f"{name}.processor"), _remove_lora=_remove_lora) for sub_name, child in module.named_children(): fn_recursive_attn_processor( f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ Sets the attention processor to use [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). Parameters: chunk_size (`int`, *optional*): The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually over each tensor of dim=`dim`. dim (`int`, *optional*, defaults to `0`): The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) or dim=1 (sequence length). """ if dim not in [0, 1]: raise ValueError( f"Make sure to set `dim` to either 0 or 1, not {dim}") # By default chunk size is 1 chunk_size = chunk_size or 1 def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) for child in module.children(): fn_recursive_feed_forward(child, chunk_size, dim) for module in self.children(): fn_recursive_feed_forward(module, chunk_size, dim) def disable_forward_chunking(self): def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) for child in module.children(): fn_recursive_feed_forward(child, chunk_size, dim) for module in self.children(): fn_recursive_feed_forward(module, None, 0) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnAddedKVProcessor() elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnProcessor() else: raise ValueError( f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) self.set_attn_processor(processor, _remove_lora=True) def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): module.gradient_checkpointing = value # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu def enable_freeu(self, s1, s2, b1, b2): r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. The suffixes after the scaling factors represent the stage blocks where they are being applied. Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. Args: s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process. s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process. b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. """ for i, upsample_block in enumerate(self.up_blocks): setattr(upsample_block, "s1", s1) setattr(upsample_block, "s2", s2) setattr(upsample_block, "b1", b1) setattr(upsample_block, "b2", b2) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu def disable_freeu(self): """Disables the FreeU mechanism.""" freeu_keys = {"s1", "s2", "b1", "b2"} for i, upsample_block in enumerate(self.up_blocks): for k in freeu_keys: if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: setattr(upsample_block, k, None) def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet3DConditionOutput, Tuple[torch.FloatTensor]]: r""" The [`GroundedUNet3DConditionModel`] forward method. Args: sample (`torch.FloatTensor`): The noisy input tensor with the following shape `(batch, channel, num_frames, height, width`. timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.FloatTensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. class_labels (`torch.Tensor`, *optional*, defaults to `None`): Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed through the `self.time_embedding` layer to obtain the timestep embeddings. attention_mask (`torch.Tensor`, *optional*, defaults to `None`): An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): A tuple of tensors that if specified are added to the residuals of down unet blocks. mid_block_additional_residual: (`torch.Tensor`, *optional*): A tensor that if specified is added to the residual of the middle unet block. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. Returns: [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info( "Forward upsample size to force interpolation output size.") forward_upsample_size = True # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor( [timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML num_frames = sample.shape[2] timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) emb = emb.repeat_interleave(repeats=num_frames, dim=0) encoder_hidden_states = encoder_hidden_states.repeat_interleave( repeats=num_frames, dim=0) # 2. pre-process sample = sample.permute(0, 2, 1, 3, 4).reshape( (sample.shape[0] * num_frames, -1) + sample.shape[3:]) sample = self.conv_in(sample) sample = self.transformer_in( sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # 2.5 GLIGEN position net if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: cross_attention_kwargs = cross_attention_kwargs.copy() lvd_gligen_args = cross_attention_kwargs.pop("gligen") cross_attention_kwargs["gligen"] = { "objs": self.position_net(**lvd_gligen_args)} # 3. down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, ) else: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, num_frames=num_frames) down_block_res_samples += res_samples if down_block_additional_residuals is not None: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples += (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, ) if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets):] down_block_res_samples = down_block_res_samples[: -len( upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, upsample_size=upsample_size, attention_mask=attention_mask, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, num_frames=num_frames, ) # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) # reshape to (batch, channel, framerate, width, height) sample = sample[None, :].reshape( (-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) if not return_dict: return (sample,) return UNet3DConditionOutput(sample=sample)