# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import Optional

import paddle
import paddle.nn.functional as F
from paddle import nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput
from .cross_attention import CrossAttention


@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    Args:
        sample (`paddle.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
            for the unnoised latent pixels.
    """

    sample: paddle.Tensor


class Transformer2DModel(ModelMixin, ConfigMixin):
    """
    Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
    embeddings) inputs.

    When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
    transformer action. Finally, reshape to image.

    When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
    embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
    classes of unnoised image.

    Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
    image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups of the input `GroupNorm`. Only used when the input is continuous.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more than `num_embeds_ada_norm` steps.
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            For continuous input, use `nn.Linear` layers for the input/output projections instead of 1x1 `nn.Conv2D`
            layers.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Forwarded to every `BasicTransformerBlock`.
        upcast_attention (`bool`, *optional*, defaults to `False`):
            Forwarded to every `BasicTransformerBlock`.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.inner_dim = inner_dim = num_attention_heads * attention_head_dim

        # 1. Transformer2DModel can process both standard continuous images of shape
        # `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape
        # `(batch_size, num_image_vectors)`.
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = in_channels is not None
        self.is_input_vectorized = num_vector_embeds is not None

        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized:
            raise ValueError(
                f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, epsilon=1e-6)
            if use_linear_projection:
                self.proj_in = nn.Linear(in_channels, inner_dim)
            else:
                self.proj_in = nn.Conv2D(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )

        # 3. Define transformers blocks
        self.transformer_blocks = nn.LayerList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Define output layers
        if self.is_input_continuous:
            if use_linear_projection:
                # The output projection maps the transformer width back to the input channel count so the
                # residual connection in `forward` lines up (the arguments used to be swapped, which only
                # worked when `in_channels == inner_dim`).
                self.proj_out = nn.Linear(inner_dim, in_channels)
            else:
                self.proj_out = nn.Conv2D(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states ( When discrete, `paddle.Tensor` of shape `(batch size, num latent pixels)`.
                When continuous, `paddle.Tensor` of shape `(batch size, channel, height, width)`): Input
                hidden_states
            encoder_hidden_states ( `paddle.Tensor` of shape `(batch size, encoder_hidden_states)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `paddle.Tensor`, *optional*):
                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
            cross_attention_kwargs (`dict`, *optional*):
                Extra keyword arguments handed to every transformer block, which expands them into its attention
                calls.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`Transformer2DModelOutput`] instead of a plain tuple.

        Returns:
            [`~models.attention.Transformer2DModelOutput`] or `tuple`:
            [`~models.attention.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        # 1. Input
        if self.is_input_continuous:
            _, _, height, width = hidden_states.shape
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
            # The conv projection runs on the (b, c, h, w) layout, the linear one on the flattened
            # (b, h * w, c) layout, hence the different ordering around the transpose.
            if not self.use_linear_projection:
                hidden_states = self.proj_in(hidden_states)
            hidden_states = hidden_states.transpose([0, 2, 3, 1]).flatten(1, 2)
            if self.use_linear_projection:
                hidden_states = self.proj_in(hidden_states)
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states.cast("int64"))

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        # 3. Output
        if self.is_input_continuous:
            if self.use_linear_projection:
                # Project back to `in_channels` first, then restore the image layout with that channel count.
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.reshape([-1, height, width, self.in_channels]).transpose([0, 3, 1, 2])
            else:
                # Restore the image layout at transformer width, then let the 1x1 conv map to `in_channels`.
                hidden_states = hidden_states.reshape([-1, height, width, self.inner_dim]).transpose([0, 3, 1, 2])
                hidden_states = self.proj_out(hidden_states)
            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.transpose([0, 2, 1])

            # log(p(x_0))
            output = F.log_softmax(logits.cast("float64"), axis=1).cast("float32")

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)


class AttentionBlock(nn.Layer):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    Uses three q, k, v linear layers to compute attention.

    Parameters:
        channels (`int`): The number of channels in the input and output.
        num_head_channels (`int`, *optional*):
            The number of channels in each head. If None, then `num_heads` = 1.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        rescale_output_factor (`float`, *optional*, defaults to 1.0):
            The factor the output (after the residual connection) is divided by.
        eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: Optional[int] = None,
        norm_num_groups: int = 32,
        rescale_output_factor: float = 1.0,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.channels = channels

        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
        self.head_dim = self.channels // self.num_heads
        self.scale = 1 / math.sqrt(self.channels / self.num_heads)

        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, epsilon=eps)

        # define q,k,v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        self.proj_attn = nn.Linear(channels, channels)

    def reshape_heads_to_batch_dim(self, tensor):
        """Split the last axis into `(num_heads, head_dim)` and move the head axis in front of the sequence axis."""
        tensor = tensor.reshape([0, 0, self.num_heads, self.head_dim])
        tensor = tensor.transpose([0, 2, 1, 3])
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        """Inverse of `reshape_heads_to_batch_dim`: move the head axis back and merge it into the last axis."""
        tensor = tensor.transpose([0, 2, 1, 3])
        tensor = tensor.reshape([0, 0, tensor.shape[2] * tensor.shape[3]])
        return tensor

    def forward(self, hidden_states):
        """Apply group norm and multi-head self-attention over the spatial positions of a 4-D input, add the
        residual and divide by `rescale_output_factor`. The result has the same shape as the input."""
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        # flatten the spatial axes: (batch, channel, h, w) -> (batch, h * w, channel)
        hidden_states = hidden_states.reshape([batch, channel, height * width]).transpose([0, 2, 1])

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        query_proj = self.reshape_heads_to_batch_dim(query_proj)
        key_proj = self.reshape_heads_to_batch_dim(key_proj)
        value_proj = self.reshape_heads_to_batch_dim(value_proj)

        # get scores; the softmax is evaluated in float32 and cast back to the dtype of the scores
        attention_scores = paddle.matmul(query_proj, key_proj, transpose_y=True) * self.scale
        attention_probs = F.softmax(attention_scores.cast("float32"), axis=-1).cast(attention_scores.dtype)

        # compute attention output
        hidden_states = paddle.matmul(attention_probs, value_proj)

        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)
        hidden_states = hidden_states.transpose([0, 2, 1]).reshape([batch, channel, height, width])

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states


class BasicTransformerBlock(nn.Layer):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            If set, the first attention layer is also built with `cross_attention_dim` and receives
            `encoder_hidden_states`; otherwise it attends over `hidden_states` only.
        upcast_attention (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to both `CrossAttention` layers.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None

        # 1. Self-Attn
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)

        # 2. Cross-Attn
        if cross_attention_dim is not None:
            self.attn2 = CrossAttention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.attn2 = None

        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        if cross_attention_dim is not None:
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
        else:
            self.norm2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        attention_mask=None,
        cross_attention_kwargs=None,
    ):
        """
        Run the first attention layer, the optional cross-attention layer and the feed-forward layer, each applied
        to a normalized copy of `hidden_states` and added back as a residual.

        `timestep` is only consumed when the block was built with `num_embeds_ada_norm`. `attention_mask` and the
        entries of `cross_attention_kwargs` are passed through to the attention layers.
        """
        # 1. Self-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

        if self.attn2 is not None:
            # 2. Cross-Attention
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        return hidden_states


class FeedForward(nn.Layer):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward. One of `"gelu"`, `"geglu"` or `"geglu-approximate"`.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            # Fail with a clear message instead of an UnboundLocalError on `act_fn` below.
            raise ValueError(
                f"Unsupported `activation_fn`: {activation_fn!r}. Expected one of 'gelu', 'geglu' or"
                " 'geglu-approximate'."
            )

        self.net = nn.LayerList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))

    def forward(self, hidden_states):
        """Apply the projection/activation, dropout and output projection in order."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


class GELU(nn.Layer):
    r"""
    GELU activation function preceded by a linear projection from `dim_in` to `dim_out`.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = F.gelu(hidden_states)
        return hidden_states


# feedforward
class GEGLU(nn.Layer):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # One projection produces both the value half and the gate half.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, axis=-1)
        return hidden_states * F.gelu(gate)


class ApproximateGELU(nn.Layer):
    """
    The approximate form of Gaussian Error Linear Unit (GELU)

    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = self.proj(x)
        return x * F.sigmoid(1.702 * x)


class AdaLayerNorm(nn.Layer):
    """
    Norm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of the normalized features and of each timestep embedding.
        num_embeddings (`int`): The number of timestep embeddings to learn.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.Silu()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        # NOTE(review): the original trailing comment read "elementwise_affine=False"; the layer is built with
        # its default arguments here - confirm whether the affine parameters are meant to be disabled.
        self.norm = nn.LayerNorm(embedding_dim)  # elementwise_affine=False

    def forward(self, x, timestep):
        # The embedded timestep yields a scale and a shift that modulate the normalized input.
        emb = self.linear(self.silu(self.emb(timestep)))
        scale, shift = paddle.chunk(emb, 2, axis=-1)
        x = self.norm(x) * (1 + scale) + shift
        return x


class DualTransformer2DModel(nn.Layer):
    """
    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            Forwarded to both `Transformer2DModel`s.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more than `num_embeds_ada_norm` steps.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
    ):
        super().__init__()
        self.transformers = nn.LayerList(
            [
                Transformer2DModel(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    in_channels=in_channels,
                    num_layers=num_layers,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    attention_bias=attention_bias,
                    sample_size=sample_size,
                    num_vector_embeds=num_vector_embeds,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                )
                for _ in range(2)
            ]
        )

        # Variables that can be set by a pipeline:

        # The ratio of transformer1 to transformer2's output states to be combined during inference
        self.mix_ratio = 0.5

        # The shape of `encoder_hidden_states` is expected to be
        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
        self.condition_lengths = [77, 257]

        # Which transformer to use to encode which condition.
        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
        self.transformer_index_for_condition = [1, 0]

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        timestep=None,
        attention_mask=None,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states ( When discrete, `paddle.Tensor` of shape `(batch size, num latent pixels)`.
                When continuous, `paddle.Tensor` of shape `(batch size, channel, height, width)`): Input
                hidden_states
            encoder_hidden_states ( `paddle.Tensor`):
                Conditional embeddings for the cross attention layers. The second axis is sliced into two
                consecutive segments of `condition_lengths[0]` and `condition_lengths[1]` tokens, one per
                transformer.
            timestep ( `paddle.Tensor`, *optional*):
                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
            attention_mask (`paddle.Tensor`, *optional*):
                Accepted for interface compatibility; currently not used.
            cross_attention_kwargs (`dict`, *optional*):
                Extra keyword arguments forwarded to both wrapped transformers.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`Transformer2DModelOutput`] instead of a plain tuple.

        Returns:
            [`~models.attention.Transformer2DModelOutput`] or `tuple`:
            [`~models.attention.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        input_states = hidden_states

        encoded_states = []
        tokens_start = 0
        # attention_mask is not used yet
        for i in range(2):
            # for each of the two transformers, pass the corresponding condition tokens
            condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
            transformer_index = self.transformer_index_for_condition[i]
            encoded_state = self.transformers[transformer_index](
                input_states,
                encoder_hidden_states=condition_state,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            # keep only the residual each transformer contributes so the two can be blended
            encoded_states.append(encoded_state - input_states)
            tokens_start += self.condition_lengths[i]

        output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
        output_states = output_states + input_states

        if not return_dict:
            return (output_states,)

        return Transformer2DModelOutput(sample=output_states)