Spaces:
Build error
Build error
| import math | |
| from typing import List, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from diffusers.models.activations import get_activation, FP32SiLU | |
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Build sinusoidal timestep embeddings (DDPM-style).

    Frequencies follow ``exp(-log(max_period) * i / (embedding_dim // 2 - downscale_freq_shift))``
    for ``i`` in ``[0, embedding_dim // 2)``. Each timestep is multiplied by every frequency and by
    ``scale``, and the sine and cosine of that phase are concatenated.

    Args:
        timesteps (torch.Tensor): 1-D tensor with one (possibly fractional) value per batch element.
        embedding_dim (int): width of the returned embedding.
        flip_sin_to_cos (bool): put the cosine half first (``[cos, sin]``) instead of ``[sin, cos]``.
        downscale_freq_shift (float): subtracted from ``embedding_dim // 2`` in the frequency denominator.
        scale (float): multiplier applied to the phase before taking sin/cos.
        max_period (int): sets the lowest frequency of the embedding.

    Returns:
        torch.Tensor: shape ``[N, embedding_dim]``. An odd ``embedding_dim`` gets one trailing zero column.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    n_freqs = embedding_dim // 2
    idx = torch.arange(n_freqs, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(-math.log(max_period) * idx / (n_freqs - downscale_freq_shift))

    # Outer product: every timestep against every frequency, then the global scale.
    phase = scale * (timesteps[:, None].float() * freqs[None, :])

    sin_part, cos_part = torch.sin(phase), torch.cos(phase)
    halves = (cos_part, sin_part) if flip_sin_to_cos else (sin_part, cos_part)
    emb = torch.cat(halves, dim=-1)

    # 2 * (embedding_dim // 2) loses one column when embedding_dim is odd; pad it back with zeros.
    if embedding_dim % 2 == 1:
        emb = F.pad(emb, (0, 1))
    return emb
class Timesteps(nn.Module):
    """
    Module wrapper around ``get_timestep_embedding`` with the embedding settings fixed at construction.

    Args:
        num_channels (int): width of the produced embedding (passed as ``embedding_dim``).
        flip_sin_to_cos (bool): if True, the cosine half comes first.
        downscale_freq_shift (float): forwarded to ``get_timestep_embedding``.
        scale (float): phase multiplier forwarded to ``get_timestep_embedding``.
        max_period (int): forwarded to ``get_timestep_embedding``. The default of 10000 matches
            that function's default, so existing callers see no change.
    """

    def __init__(
        self,
        num_channels: int,
        flip_sin_to_cos: bool,
        downscale_freq_shift: float,
        scale: float = 1,
        max_period: int = 10000,
    ):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale
        self.max_period = max_period

    def forward(self, timesteps):
        """Embed a 1-D tensor of timesteps using the settings stored on this module."""
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
            max_period=self.max_period,
        )
        return t_emb
class TimestepEmbedding(nn.Module):
    """
    Two-layer MLP applied to a timestep embedding.

    The pipeline is: an optional ``cond_proj(condition)`` added to the input, then ``linear_1``,
    then ``act``, then ``linear_2``, then an optional ``post_act``.

    Args:
        in_channels (int): feature size of the incoming ``sample`` (input of ``linear_1``).
        time_embed_dim (int): hidden size (output of ``linear_1``).
        act_fn (str): activation name resolved through ``get_activation``.
        out_dim (int, optional): output size of ``linear_2``. Defaults to ``time_embed_dim``.
        post_act_fn (str, optional): activation name applied after ``linear_2``. None skips it.
        cond_proj_dim (int, optional): if given, a bias-free ``Linear(cond_proj_dim, in_channels)``
            is created so ``forward`` can accept a ``condition``.
        sample_proj_bias (bool): whether ``linear_1`` and ``linear_2`` use a bias.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim: Optional[int] = None,
        sample_proj_bias: bool = True,
    ):
        super().__init__()
        # Creation order (linear_1, cond_proj, act, linear_2, post_act) is kept as-is so that
        # state_dict layout and seeded parameter initialisation stay unchanged.
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=sample_proj_bias)

        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        """
        Project ``sample`` (optionally shifted by a projected ``condition``) through the MLP.

        Raises:
            ValueError: if ``condition`` is given but the module was built without ``cond_proj_dim``.
        """
        if condition is not None:
            if self.cond_proj is None:
                # Previously this failed with an opaque "'NoneType' object is not callable".
                raise ValueError(
                    "A `condition` was passed, but this TimestepEmbedding was constructed without "
                    "`cond_proj_dim`, so there is no projection to apply it with."
                )
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)
        if self.act is not None:
            sample = self.act(sample)
        sample = self.linear_2(sample)
        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings with ``Linear -> activation -> Linear``.

    No dropout or classifier-free-guidance masking is performed in this module.
    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py

    Args:
        in_features: input feature size of ``linear_1``.
        hidden_size: output size of ``linear_1`` and input size of ``linear_2``.
        out_features: output size of ``linear_2``. Defaults to ``hidden_size``.
        act_fn: one of ``"gelu_tanh"``, ``"silu"``, ``"silu_fp32"``. Anything else raises ``ValueError``.
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
        super().__init__()
        # Factories are evaluated lazily, so only the requested activation is ever constructed.
        activation_factories = {
            "gelu_tanh": lambda: nn.GELU(approximate="tanh"),
            "silu": lambda: nn.SiLU(),
            "silu_fp32": lambda: FP32SiLU(),
        }
        final_width = hidden_size if out_features is None else out_features

        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        make_activation = activation_factories.get(act_fn)
        if make_activation is None:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.act_1 = make_activation()
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=final_width, bias=True)

    def forward(self, caption):
        """Apply ``linear_1``, the activation, and ``linear_2`` to ``caption``."""
        return self.linear_2(self.act_1(self.linear_1(caption)))
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
    """
    Sum of a timestep embedding, a guidance embedding and a projected pooled-text vector.

    ``timestep`` and ``guidance`` share one sinusoidal ``time_proj`` (256 channels, cosine first,
    no frequency shift). Each is then passed through its own ``TimestepEmbedding`` MLP.
    ``pooled_projection`` goes through a SiLU ``PixArtAlphaTextProjection``. The three results are added.
    """

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()
        sinusoid_channels = 256
        self.time_proj = Timesteps(num_channels=sinusoid_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=sinusoid_channels, time_embed_dim=embedding_dim)
        self.guidance_embedder = TimestepEmbedding(in_channels=sinusoid_channels, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, guidance, pooled_projection):
        """Return ``(timestep_emb + guidance_emb) + text_emb``. Sinusoids are cast to ``pooled_projection.dtype``."""
        target_dtype = pooled_projection.dtype
        t_emb = self.timestep_embedder(self.time_proj(timestep).to(dtype=target_dtype))  # (N, D)
        g_emb = self.guidance_embedder(self.time_proj(guidance).to(dtype=target_dtype))  # (N, D)
        text_emb = self.text_embedder(pooled_projection)
        return (t_emb + g_emb) + text_emb
class CombinedTimestepTextProjEmbeddings(nn.Module):
    """
    Sum of a timestep embedding and a projected pooled-text vector.

    ``timestep`` goes through a sinusoidal ``time_proj`` (256 channels, cosine first, no frequency
    shift) and a ``TimestepEmbedding`` MLP. ``pooled_projection`` goes through a SiLU
    ``PixArtAlphaTextProjection``. The two results are added.
    """

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()
        sinusoid_channels = 256
        self.time_proj = Timesteps(num_channels=sinusoid_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=sinusoid_channels, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        """Return ``timestep_emb + text_emb``. Sinusoids are cast to ``pooled_projection.dtype`` first."""
        sinusoids = self.time_proj(timestep).to(dtype=pooled_projection.dtype)
        t_emb = self.timestep_embedder(sinusoids)  # (N, D)
        text_emb = self.text_embedder(pooled_projection)
        return t_emb + text_emb