# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

import numpy as np
import torch
from torch import nn

from .activations import get_activation


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    Args:
        timesteps (`torch.Tensor`):
            A 1-D tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (`int`):
            The dimension of the output. Odd values are supported: the last channel is zero-padded.
        flip_sin_to_cos (`bool`):
            If `True`, the output is ordered `[cos, sin]` instead of `[sin, cos]`.
        downscale_freq_shift (`float`):
            Subtracted from `embedding_dim // 2` in the denominator of the frequency exponent. It must differ from
            `embedding_dim // 2`, otherwise the exponent is divided by zero.
        scale (`float`):
            Multiplier applied to the arguments of sin/cos.
        max_period (`int`):
            Controls the minimum frequency of the embeddings.

    Returns:
        `torch.Tensor`: an [N x embedding_dim] tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    Build a 2D sine/cosine positional embedding table for a grid of positions.

    Args:
        embed_dim (`int`):
            Dimension of each positional embedding. Must be divisible by 4, because each of the two grid axes gets
            `embed_dim // 2` channels, which are in turn split evenly into sin and cos.
        grid_size (`int` or pair of `int`):
            Either a single int (square grid, height == width) or a `(height, width)` tuple/list.
        cls_token (`bool`):
            Together with `extra_tokens > 0`, enables prepending zero rows to the table.
        extra_tokens (`int`):
            Number of all-zero rows prepended when `cls_token` is `True`.

    Returns:
        `np.ndarray`: pos_embed of shape [height*width, embed_dim], or [extra_tokens + height*width, embed_dim] when
        `cls_token` is `True` and `extra_tokens > 0`.
    """
    # Anything that is not an explicit (height, width) pair is treated as the side of a square grid.
    if isinstance(grid_size, (tuple, list)):
        grid_height, grid_width = grid_size
    else:
        grid_height = grid_width = grid_size

    grid_h = np.arange(grid_height, dtype=np.float32)
    grid_w = np.arange(grid_width, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first; each array has shape (height, width)
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_height, grid_width])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a pair of coordinate grids into a single positional embedding table.

    Args:
        embed_dim (`int`): Output dimension for each position. Must be divisible by 2.
        grid (`np.ndarray`): Array whose first axis has (at least) two entries; `grid[0]` and `grid[1]` hold the two
            coordinates of every position and are flattened before encoding.

    Returns:
        `np.ndarray`: array of shape (M, embed_dim), where M is the number of positions in `grid[0]`.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # the first half of the channels encodes grid[0], the second half encodes grid[1]
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode a set of scalar positions with sine/cosine features.

    Args:
        embed_dim (`int`): Output dimension for each position. Must be divisible by 2.
        pos (`np.ndarray`): Positions to be encoded; flattened to size (M,).

    Returns:
        `np.ndarray`: array of shape (M, embed_dim), laid out as `[sin | cos]` along the last axis.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding

    Projects the input with a strided convolution (one output position per patch), optionally flattens the spatial
    axes and applies a non-affine LayerNorm, then adds a fixed 2D sine/cosine positional embedding.

    Args:
        height (`int`): Height of the input, used to size the positional embedding.
        width (`int`): Width of the input, used to size the positional embedding.
        patch_size (`int`): Side of the square patches (kernel size and stride of the projection).
        in_channels (`int`): Number of input channels.
        embed_dim (`int`): Number of output channels of the projection.
        layer_norm (`bool`): Whether to apply `nn.LayerNorm` (without learnable affine parameters) after projection.
        flatten (`bool`): Whether to flatten BCHW -> BNC after projection.
        bias (`bool`): Whether the projection convolution has a bias.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
    ):
        super().__init__()

        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        # Build the table from the real (rows, cols) patch grid rather than from sqrt(num_patches): for non-square
        # inputs the square-root shortcut produces a table whose length or layout does not match the tokens.
        grid_size = (height // patch_size, width // patch_size)
        pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size)
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        latent = self.proj(latent)
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)
        # NOTE(review): pos_embed is (1, N, C), which lines up with the flattened BNC layout; confirm whether
        # flatten=False is meant to be supported here.
        return latent + self.pos_embed


class TimestepEmbedding(nn.Module):
    """
    Two-layer MLP applied to a timestep projection: linear -> activation -> linear -> optional post-activation.

    Args:
        in_channels (`int`): Size of the input features.
        time_embed_dim (`int`): Size of the hidden layer.
        act_fn (`str`): Name of the activation, resolved through `get_activation`.
        out_dim (`int`, *optional*): Size of the output; defaults to `time_embed_dim`.
        post_act_fn (`str`, *optional*): Name of an activation applied to the output, if given.
        cond_proj_dim (`int`, *optional*): If given, a bias-free linear layer projecting a condition of this size to
            `in_channels` is created; its output is added to the input in `forward`.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        # `condition` is only usable when the module was built with `cond_proj_dim`; otherwise `cond_proj` is None.
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Timesteps(nn.Module):
    """
    Parameter-free module wrapping `get_timestep_embedding` with a fixed configuration.

    Args:
        num_channels (`int`): Dimension of the produced embedding.
        flip_sin_to_cos (`bool`): Forwarded to `get_timestep_embedding`.
        downscale_freq_shift (`float`): Forwarded to `get_timestep_embedding`.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb


class GaussianFourierProjection(nn.Module):
    """
    Gaussian Fourier embeddings for noise levels.

    Args:
        embedding_size (`int`): Number of random frequencies; the output has `2 * embedding_size` channels.
        scale (`float`): Multiplier applied to the normally distributed frequencies.
        set_W_to_weight (`bool`): If `True`, a second random parameter `W` is created and `weight` is made to refer
            to it.
        log (`bool`): Whether to take `torch.log` of the input before projecting.
        flip_sin_to_cos (`bool`): If `True`, the output is ordered `[cos, sin]` instead of `[sin, cos]`.
    """

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # to delete later
            # (the draw above is kept even though it is replaced here, so the RNG stream stays unchanged)
            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

            self.weight = self.W

    def forward(self, x):
        if self.log:
            x = torch.log(x)

        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi

        if self.flip_sin_to_cos:
            out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
        else:
            out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        return out


class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
    height and width of the latent space.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    Output vector embeddings are used as input for the transformer.

    Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixels embeddings.
        height (`int`):
            Height of the latent image i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)

    def forward(self, index):
        emb = self.emb(index)

        height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))

        # 1 x H x D -> 1 x H x 1 x D
        height_emb = height_emb.unsqueeze(2)

        width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))

        # 1 x W x D -> 1 x 1 x W x D
        width_emb = width_emb.unsqueeze(1)

        pos_emb = height_emb + width_emb

        # 1 x H x W x D -> 1 x L x D
        pos_emb = pos_emb.view(1, self.height * self.width, -1)

        # only the first emb.shape[1] positions are used, so shorter sequences are supported
        emb = emb + pos_emb[:, : emb.shape[1], :]

        return emb


class LabelEmbedding(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

    Args:
        num_classes (`int`): The number of classes.
        hidden_size (`int`): The size of the vector embeddings.
        dropout_prob (`float`): The probability of dropping a label.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # One extra row (index `num_classes`) is reserved for the "dropped label" embedding, but only when dropout
        # is enabled. With `dropout_prob == 0` that index does not exist in the table.
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.

        Args:
            labels (`torch.Tensor`): Class labels; the first dimension is the batch.
            force_drop_ids (tensor or array-like, *optional*): If given, entries equal to 1 mark the labels to drop
                (deterministically). If `None`, each label is dropped at random with probability `dropout_prob`.

        Returns:
            `torch.Tensor`: `labels` with dropped entries replaced by `num_classes`.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            # Convert first, then compare: this accepts tensors without copy-constructing them, compares lists and
            # arrays element-wise (rather than `list == 1` -> False), and keeps the mask on the labels' device.
            drop_ids = torch.as_tensor(force_drop_ids, device=labels.device) == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (self.training and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


class TextImageProjection(nn.Module):
    """
    Projects an image embedding to `num_image_text_embeds` tokens and a text embedding sequence to
    `cross_attention_dim`, then concatenates both along the sequence dimension (image tokens first).
    """

    def __init__(
        self,
        text_embed_dim: int = 1024,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 10,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
        batch_size = text_embeds.shape[0]

        # image
        image_text_embeds = self.image_embeds(image_embeds)
        image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)

        # text
        text_embeds = self.text_proj(text_embeds)

        return torch.cat([image_text_embeds, text_embeds], dim=1)


class ImageProjection(nn.Module):
    """
    Projects an image embedding to `num_image_text_embeds` tokens of size `cross_attention_dim` and layer-normalizes
    them.
    """

    def __init__(
        self,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 32,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.FloatTensor):
        batch_size = image_embeds.shape[0]

        # image
        image_embeds = self.image_embeds(image_embeds)
        image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
        image_embeds = self.norm(image_embeds)
        return image_embeds


class CombinedTimestepLabelEmbeddings(nn.Module):
    """
    Sums a timestep embedding (256-channel sinusoidal projection followed by `TimestepEmbedding`) with a class label
    embedding (`LabelEmbedding`).

    Args:
        num_classes (`int`): The number of classes.
        embedding_dim (`int`): Size of the produced conditioning vectors.
        class_dropout_prob (`float`): Label dropout probability passed to `LabelEmbedding`.
    """

    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        class_labels = self.class_embedder(class_labels)  # (N, D)

        conditioning = timesteps_emb + class_labels  # (N, D)

        return conditioning


class TextTimeEmbedding(nn.Module):
    """
    Pools a sequence of encoder hidden states into a single vector: LayerNorm -> `AttentionPooling` -> linear
    projection to `time_embed_dim` -> LayerNorm.
    """

    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
        super().__init__()
        self.norm1 = nn.LayerNorm(encoder_dim)
        self.pool = AttentionPooling(num_heads, encoder_dim)
        self.proj = nn.Linear(encoder_dim, time_embed_dim)
        self.norm2 = nn.LayerNorm(time_embed_dim)

    def forward(self, hidden_states):
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.pool(hidden_states)
        hidden_states = self.proj(hidden_states)
        hidden_states = self.norm2(hidden_states)
        return hidden_states


class TextImageTimeEmbedding(nn.Module):
    """
    Projects text and image embeddings to `time_embed_dim` and sums them. Only the text branch is layer-normalized.
    """

    def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
        self.text_norm = nn.LayerNorm(time_embed_dim)
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
        # text
        time_text_embeds = self.text_proj(text_embeds)
        time_text_embeds = self.text_norm(time_text_embeds)

        # image
        time_image_embeds = self.image_proj(image_embeds)

        return time_image_embeds + time_text_embeds


class ImageTimeEmbedding(nn.Module):
    """
    Projects an image embedding to `time_embed_dim` and layer-normalizes it.
    """

    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

    def forward(self, image_embeds: torch.FloatTensor):
        # image
        time_image_embeds = self.image_proj(image_embeds)
        time_image_embeds = self.image_norm(time_image_embeds)
        return time_image_embeds


class ImageHintTimeEmbedding(nn.Module):
    """
    Projects an image embedding to `time_embed_dim` (with LayerNorm) and, independently, encodes a 3-channel hint
    image with a small convolutional stack that has three stride-2 stages and ends with 4 output channels.
    """

    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)
        self.input_hint_block = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(256, 4, 3, padding=1),
        )

    def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
        # image
        time_image_embeds = self.image_proj(image_embeds)
        time_image_embeds = self.image_norm(time_image_embeds)
        hint = self.input_hint_block(hint)
        return time_image_embeds, hint


class AttentionPooling(nn.Module):
    """
    Multi-head attention pooling: a single query built from the mean token (plus a learned positional embedding)
    attends over that token and the full sequence, producing one vector per batch element.

    Args:
        num_heads (`int`): Number of attention heads. `embed_dim // num_heads` channels are used per head.
        embed_dim (`int`): Width of the input tokens and of the pooled output.
        dtype (`torch.dtype`, *optional*): dtype passed to the q/k/v linear layers.
    """

    # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54

    def __init__(self, num_heads, embed_dim, dtype=None):
        super().__init__()
        self.dtype = dtype
        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.num_heads = num_heads
        self.dim_per_head = embed_dim // self.num_heads

    def forward(self, x):
        # unpacking enforces a 3-D (bs, length, width) input; only the batch size is needed below
        bs, _, _ = x.size()

        def shape(x):
            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
            x = x.view(bs, -1, self.num_heads, self.dim_per_head)
            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
            x = x.transpose(1, 2)
            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
            x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
            # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
            x = x.transpose(1, 2)
            return x

        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
        x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)

        # (bs*n_heads, dim_per_head, class_token_length)
        q = shape(self.q_proj(class_token))
        # (bs*n_heads, dim_per_head, length+class_token_length)
        k = shape(self.k_proj(x))
        v = shape(self.v_proj(x))

        # (bs*n_heads, class_token_length, length+class_token_length):
        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

        # (bs*n_heads, dim_per_head, class_token_length)
        a = torch.einsum("bts,bcs->bct", weight, v)

        # (bs, class_token_length, width)
        a = a.reshape(bs, -1, 1).transpose(1, 2)

        return a[:, 0, :]  # cls_token