# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

import numpy as np
import torch
from torch import nn

from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
from .attention_processor import Attention
from .lora import LoRACompatibleLinear

def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

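# Illustrative usage sketch (not part of the original module): a batch of four
# timesteps embedded into 16 dimensions; each row concatenates sin and cos terms.
#
#     timesteps = torch.tensor([0.0, 10.0, 100.0, 999.0])
#     emb = get_timestep_embedding(timesteps, embedding_dim=16, flip_sin_to_cos=True)
#     # emb.shape == (4, 16); with flip_sin_to_cos=True the cosine half comes first.
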
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    grid_size: int of the grid height and width
    return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

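# Illustrative usage sketch (not part of the original module): a fixed 2D sin/cos
# table for an 8x8 patch grid with 64-dim embeddings.
#
#     pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=8, base_size=8)
#     # pos_embed.shape == (64, 64): one row per grid position, with half of the
#     # channels devoted to each grid axis.
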
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        latent = self.proj(latent)
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)

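# Illustrative usage sketch (not part of the original module): patchifying a
# 4-channel 64x64 latent into a sequence of 1024-dim tokens.
#
#     patch_embed = PatchEmbed(height=64, width=64, patch_size=2, in_channels=4, embed_dim=1024)
#     latent = torch.randn(1, 4, 64, 64)
#     tokens = patch_embed(latent)
#     # tokens.shape == (1, 1024, 1024): (64/2) * (64/2) = 1024 patches, each projected to 1024 dims.
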
class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

        self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample

class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb

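# Illustrative usage sketch (not part of the original module): the typical UNet
# pattern of a sinusoidal projection followed by an MLP time embedding.
#
#     time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
#     time_embedding = TimestepEmbedding(in_channels=320, time_embed_dim=1280)
#     t = torch.tensor([999, 999])
#     t_emb = time_embedding(time_proj(t))
#     # t_emb.shape == (2, 1280)
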
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels."""

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # to delete later
            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

            self.weight = self.W

    def forward(self, x):
        if self.log:
            x = torch.log(x)

        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi

        if self.flip_sin_to_cos:
            out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
        else:
            out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        return out

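# Illustrative usage sketch (not part of the original module): Fourier features
# for a batch of continuous noise levels (sigmas); the scale value is an
# assumption for the sake of the example.
#
#     proj = GaussianFourierProjection(embedding_size=256, scale=16.0)
#     sigmas = torch.tensor([0.5, 1.0, 2.0])
#     emb = proj(sigmas)
#     # emb.shape == (3, 512): random frequencies times log(sigma), sin half then cos half.
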
class SinusoidalPositionalEmbedding(nn.Module):
    """Apply positional information to a sequence of embeddings.

    Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
    them

    Args:
        embed_dim: (int): Dimension of the positional embedding.
        max_seq_length: Maximum sequence length to apply positional embeddings
    """

    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        pe = torch.zeros(1, max_seq_length, embed_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        _, seq_length, _ = x.shape
        x = x + self.pe[:, :seq_length]
        return x

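# Illustrative usage sketch (not part of the original module): adding fixed
# sinusoidal positions to a short token sequence.
#
#     pos = SinusoidalPositionalEmbedding(embed_dim=64, max_seq_length=32)
#     tokens = torch.randn(2, 16, 64)
#     tokens = pos(tokens)
#     # Same shape (2, 16, 64); only the first 16 rows of the table are used.
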
class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
    height and width of the latent space.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    Output vector embeddings are used as input for the transformer.

    Note that the vector embeddings for the transformer are different from the vector embeddings from the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixels embeddings.
        height (`int`):
            Height of the latent image i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)

    def forward(self, index):
        emb = self.emb(index)

        height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))

        # 1 x H x D -> 1 x H x 1 x D
        height_emb = height_emb.unsqueeze(2)

        width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))

        # 1 x W x D -> 1 x 1 x W x D
        width_emb = width_emb.unsqueeze(1)

        pos_emb = height_emb + width_emb

        # 1 x H x W x D -> 1 x L x D
        pos_emb = pos_emb.view(1, self.height * self.width, -1)

        emb = emb + pos_emb[:, : emb.shape[1], :]

        return emb

class LabelEmbedding(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

    Args:
        num_classes (`int`): The number of classes.
        hidden_size (`int`): The size of the vector embeddings.
        dropout_prob (`float`): The probability of dropping a label.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = torch.tensor(force_drop_ids == 1)
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (self.training and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings

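# Illustrative usage sketch (not part of the original module): dropped labels are
# remapped to the extra "null" row (index == num_classes) used for
# classifier-free guidance.
#
#     label_emb = LabelEmbedding(num_classes=1000, hidden_size=1152, dropout_prob=0.1)
#     labels = torch.tensor([3, 7])
#     unconditional = label_emb(labels, force_drop_ids=np.array([1, 1]))
#     # Both rows now look up embedding_table index 1000 (the null class).
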
class TextImageProjection(nn.Module):
    def __init__(
        self,
        text_embed_dim: int = 1024,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 10,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
        batch_size = text_embeds.shape[0]

        # image
        image_text_embeds = self.image_embeds(image_embeds)
        image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)

        # text
        text_embeds = self.text_proj(text_embeds)

        return torch.cat([image_text_embeds, text_embeds], dim=1)

class ImageProjection(nn.Module):
    def __init__(
        self,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 32,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.FloatTensor):
        batch_size = image_embeds.shape[0]

        # image
        image_embeds = self.image_embeds(image_embeds)
        image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
        image_embeds = self.norm(image_embeds)
        return image_embeds

class IPAdapterFullImageProjection(nn.Module):
    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
        super().__init__()
        from .attention import FeedForward

        self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.FloatTensor):
        return self.norm(self.ff(image_embeds))

class CombinedTimestepLabelEmbeddings(nn.Module):
    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        class_labels = self.class_embedder(class_labels)  # (N, D)

        conditioning = timesteps_emb + class_labels  # (N, D)

        return conditioning

class TextTimeEmbedding(nn.Module):
    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
        super().__init__()
        self.norm1 = nn.LayerNorm(encoder_dim)
        self.pool = AttentionPooling(num_heads, encoder_dim)
        self.proj = nn.Linear(encoder_dim, time_embed_dim)
        self.norm2 = nn.LayerNorm(time_embed_dim)

    def forward(self, hidden_states):
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.pool(hidden_states)
        hidden_states = self.proj(hidden_states)
        hidden_states = self.norm2(hidden_states)
        return hidden_states

class TextImageTimeEmbedding(nn.Module):
    def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
        self.text_norm = nn.LayerNorm(time_embed_dim)
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
        # text
        time_text_embeds = self.text_proj(text_embeds)
        time_text_embeds = self.text_norm(time_text_embeds)

        # image
        time_image_embeds = self.image_proj(image_embeds)

        return time_image_embeds + time_text_embeds


class ImageTimeEmbedding(nn.Module):
    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

    def forward(self, image_embeds: torch.FloatTensor):
        # image
        time_image_embeds = self.image_proj(image_embeds)
        time_image_embeds = self.image_norm(time_image_embeds)
        return time_image_embeds

class ImageHintTimeEmbedding(nn.Module):
    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)
        self.input_hint_block = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(256, 4, 3, padding=1),
        )

    def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
        # image
        time_image_embeds = self.image_proj(image_embeds)
        time_image_embeds = self.image_norm(time_image_embeds)
        hint = self.input_hint_block(hint)
        return time_image_embeds, hint

class AttentionPooling(nn.Module):
    # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54

    def __init__(self, num_heads, embed_dim, dtype=None):
        super().__init__()
        self.dtype = dtype
        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.num_heads = num_heads
        self.dim_per_head = embed_dim // self.num_heads

    def forward(self, x):
        bs, length, width = x.size()

        def shape(x):
            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
            x = x.view(bs, -1, self.num_heads, self.dim_per_head)
            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
            x = x.transpose(1, 2)
            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
            x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
            # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
            x = x.transpose(1, 2)
            return x

        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
        x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)

        # (bs*n_heads, class_token_length, dim_per_head)
        q = shape(self.q_proj(class_token))
        # (bs*n_heads, length+class_token_length, dim_per_head)
        k = shape(self.k_proj(x))
        v = shape(self.v_proj(x))

        # (bs*n_heads, class_token_length, length+class_token_length):
        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

        # (bs*n_heads, dim_per_head, class_token_length)
        a = torch.einsum("bts,bcs->bct", weight, v)

        # (bs, length+1, width)
        a = a.reshape(bs, -1, 1).transpose(1, 2)

        return a[:, 0, :]  # cls_token

def get_fourier_embeds_from_boundingbox(embed_dim, box):
    """
    Args:
        embed_dim: int
        box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
    Returns:
        [B x N x embed_dim*2*4] tensor of positional embeddings
    """

    batch_size, num_boxes = box.shape[:2]

    emb = 100 ** (torch.arange(embed_dim) / embed_dim)
    emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
    emb = emb * box.unsqueeze(-1)

    emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
    emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)

    return emb

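# Illustrative usage sketch (not part of the original module): Fourier features
# for two boxes in xyxy format, with 8 frequencies per coordinate.
#
#     boxes = torch.tensor([[[0.1, 0.1, 0.5, 0.5], [0.2, 0.3, 0.9, 0.8]]])  # (1, 2, 4)
#     emb = get_fourier_embeds_from_boundingbox(8, boxes)
#     # emb.shape == (1, 2, 64): 8 frequencies * 2 (sin/cos) * 4 coordinates.
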
class GLIGENTextBoundingboxProjection(nn.Module):
    def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
        super().__init__()
        self.positive_len = positive_len
        self.out_dim = out_dim

        self.fourier_embedder_dim = fourier_freqs
        self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy

        if isinstance(out_dim, tuple):
            out_dim = out_dim[0]

        if feature_type == "text-only":
            self.linears = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        elif feature_type == "text-image":
            self.linears_text = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.linears_image = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

    def forward(
        self,
        boxes,
        masks,
        positive_embeddings=None,
        phrases_masks=None,
        image_masks=None,
        phrases_embeddings=None,
        image_embeddings=None,
    ):
        masks = masks.unsqueeze(-1)

        # embedding position (it may include padding as placeholder)
        xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes)  # B*N*4 -> B*N*C

        # learnable null embedding
        xyxy_null = self.null_position_feature.view(1, 1, -1)

        # replace padding with learnable null embedding
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        # positionet with text-only information
        if positive_embeddings is not None:
            # learnable null embedding
            positive_null = self.null_positive_feature.view(1, 1, -1)

            # replace padding with learnable null embedding
            positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null

            objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))

        # positionet with text and image information
        else:
            phrases_masks = phrases_masks.unsqueeze(-1)
            image_masks = image_masks.unsqueeze(-1)

            # learnable null embedding
            text_null = self.null_text_feature.view(1, 1, -1)
            image_null = self.null_image_feature.view(1, 1, -1)

            # replace padding with learnable null embedding
            phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
            image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null

            objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
            objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
            objs = torch.cat([objs_text, objs_image], dim=1)

        return objs

class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if use_additional_conditions:
            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        if self.use_additional_conditions:
            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
        else:
            conditioning = timesteps_emb

        return conditioning

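# Illustrative usage sketch (not part of the original module): folding timestep,
# resolution, and aspect ratio into one conditioning vector; the dimensions are
# assumptions chosen so the concatenated size embeddings match embedding_dim.
#
#     cond_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
#         embedding_dim=1152, size_emb_dim=1152 // 3, use_additional_conditions=True
#     )
#     timestep = torch.tensor([999, 999])
#     resolution = torch.tensor([[1024.0, 1024.0], [512.0, 768.0]])
#     aspect_ratio = torch.tensor([[1.0], [0.667]])
#     conditioning = cond_embedder(timestep, resolution, aspect_ratio, batch_size=2, hidden_dtype=torch.float32)
#     # conditioning.shape == (2, 1152): resolution contributes 2 * 384 channels and
#     # aspect ratio 384, matching the 1152-dim timestep embedding.
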
class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, num_tokens=120):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        self.act_1 = nn.GELU(approximate="tanh")
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states

class IPAdapterPlusImageProjection(nn.Module):
    """Resampler of IP-Adapter Plus.

    Args:
    ----
        embed_dims (int): The feature dimension. Defaults to 768.
        output_dims (int): The number of output channels, that is the same
            number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
        hidden_dims (int): The number of hidden channels. Defaults to 1280.
        depth (int): The number of blocks. Defaults to 4.
        dim_head (int): The number of head channels. Defaults to 64.
        heads (int): Parallel attention heads. Defaults to 16.
        num_queries (int): The number of queries. Defaults to 8.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        output_dims: int = 1024,
        hidden_dims: int = 1280,
        depth: int = 4,
        dim_head: int = 64,
        heads: int = 16,
        num_queries: int = 8,
        ffn_ratio: float = 4,
    ) -> None:
        super().__init__()
        from .attention import FeedForward  # Lazy import to avoid circular import

        self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)

        self.proj_in = nn.Linear(embed_dims, hidden_dims)

        self.proj_out = nn.Linear(hidden_dims, output_dims)
        self.norm_out = nn.LayerNorm(output_dims)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        nn.LayerNorm(hidden_dims),
                        nn.LayerNorm(hidden_dims),
                        Attention(
                            query_dim=hidden_dims,
                            dim_head=dim_head,
                            heads=heads,
                            out_bias=False,
                        ),
                        nn.Sequential(
                            nn.LayerNorm(hidden_dims),
                            FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
                        ),
                    ]
                )
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
        ----
            x (torch.Tensor): Input Tensor.

        Returns:
        -------
            torch.Tensor: Output Tensor.
        """
        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        for ln0, ln1, attn, ff in self.layers:
            residual = latents

            encoder_hidden_states = ln0(x)
            latents = ln1(latents)
            encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
            latents = attn(latents, encoder_hidden_states) + residual
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)

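# Illustrative usage sketch (not part of the original module): resampling a
# sequence of 257 image-encoder patch embeddings down to 8 query tokens in an
# assumed cross-attention dimension of 2048.
#
#     resampler = IPAdapterPlusImageProjection(
#         embed_dims=1280, output_dims=2048, hidden_dims=1280, depth=4, num_queries=8
#     )
#     clip_patch_embeds = torch.randn(2, 257, 1280)
#     image_prompt_tokens = resampler(clip_patch_embeds)
#     # image_prompt_tokens.shape == (2, 8, 2048)
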