Spaces:
Build error
Build error
import math | |
from typing import List, Optional, Tuple, Union | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from diffusers.models.activations import get_activation, FP32SiLU | |
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings, as in Denoising Diffusion Probabilistic Models.

    Args:
        timesteps (torch.Tensor):
            A 1-D tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            The dimension of the output. When odd, the last column is zero padding.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
        downscale_freq_shift (float):
            Controls the spacing between frequencies: frequency ``i`` is
            ``max_period ** (-i / (embedding_dim // 2 - downscale_freq_shift))``.
        scale (float):
            Scaling factor applied to the sin/cos arguments.
        max_period (int):
            Controls the lowest frequency (longest period) of the embeddings.

    Returns:
        torch.Tensor: an [N x embedding_dim] float32 tensor of positional embeddings, on the
        same device as ``timesteps``.

    Raises:
        ValueError: if ``timesteps`` is not 1-D, or if ``embedding_dim // 2`` equals
            ``downscale_freq_shift`` (the frequency exponent would divide by zero).
    """
    # Explicit check instead of `assert`, which is removed when Python runs with -O.
    if timesteps.dim() != 1:
        raise ValueError("Timesteps should be a 1d-array")

    half_dim = embedding_dim // 2
    denominator = half_dim - downscale_freq_shift
    # A zero denominator turns the exponent into 0/0 (NaN) or x/0 (-inf) and silently yields a
    # NaN-poisoned embedding. With half_dim == 0 there are no frequencies, so nothing is divided.
    if half_dim > 0 and denominator == 0:
        raise ValueError(
            f"embedding_dim // 2 ({half_dim}) must differ from downscale_freq_shift "
            f"({downscale_freq_shift}); otherwise the frequency exponent divides by zero."
        )

    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / denominator
    freqs = torch.exp(exponent)

    # Outer product: one row per timestep, one column per frequency.
    args = timesteps[:, None].float() * freqs[None, :]
    args = scale * args

    sin_part = torch.sin(args)
    cos_part = torch.cos(args)
    # Concatenate the halves directly in the requested order (no concat-then-reslice).
    halves = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]
    emb = torch.cat(halves, dim=-1)

    # Zero-pad one column so odd embedding sizes are honoured.
    if embedding_dim % 2 == 1:
        emb = F.pad(emb, (0, 1, 0, 0))
    return emb
class Timesteps(nn.Module):
    """Parameter-free module that turns a 1-D tensor of timesteps into sinusoidal features.

    It only stores the embedding configuration and delegates the actual computation to
    ``get_timestep_embedding`` on every call.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        # Configuration forwarded verbatim to get_timestep_embedding in forward().
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        """Embed ``timesteps`` into ``num_channels`` sinusoidal features."""
        options = dict(
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
        return get_timestep_embedding(timesteps, self.num_channels, **options)
class TimestepEmbedding(nn.Module):
    """Two-layer MLP applied to a (sinusoidal) timestep embedding.

    Computes ``post_act(linear_2(act(linear_1(sample + cond_proj(condition)))))``, where the
    conditioning projection, ``act`` (if ``get_activation`` returns None) and the
    post-activation are each optional.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim: Optional[int] = None,
        sample_proj_bias: bool = True,
    ):
        """
        Args:
            in_channels: Feature size of the incoming ``sample``.
            time_embed_dim: Hidden size (output size of ``linear_1``).
            act_fn: Activation name between the two linear layers, resolved by diffusers'
                ``get_activation``.
            out_dim: Output feature size; defaults to ``time_embed_dim``.
            post_act_fn: Optional activation name applied after ``linear_2``.
            cond_proj_dim: If given, creates a bias-free projection from a ``condition`` of this
                feature size to ``in_channels``; ``forward`` adds it to ``sample``.
            sample_proj_bias: Whether ``linear_1`` and ``linear_2`` have bias terms.
        """
        super().__init__()
        # Submodules are created in a fixed order; it determines parameter init and state_dict order.
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None
        self.act = get_activation(act_fn)
        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        """Project ``sample`` (optionally shifted by the projected ``condition``).

        Raises:
            ValueError: if ``condition`` is given but the module was built without
                ``cond_proj_dim``.
        """
        if condition is not None:
            # Without this check the call below fails with an opaque
            # "'NoneType' object is not callable".
            if self.cond_proj is None:
                raise ValueError(
                    "`condition` was passed, but this TimestepEmbedding was created without "
                    "`cond_proj_dim`, so it has no conditioning projection."
                )
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)
        if self.act is not None:
            sample = self.act(sample)
        sample = self.linear_2(sample)
        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings with a two-layer MLP: ``linear_2(act_1(linear_1(caption)))``.

    This module applies no dropout and no classifier-free-guidance token dropping. Any such
    behaviour has to be implemented outside this class.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
        """
        Args:
            in_features: Feature size of the incoming caption embedding.
            hidden_size: Width of the hidden layer; also the output size when
                ``out_features`` is None.
            out_features: Output feature size; defaults to ``hidden_size``.
            act_fn: One of ``"gelu_tanh"`` (GELU with tanh approximation), ``"silu"``, or
                ``"silu_fp32"`` (diffusers' ``FP32SiLU``).

        Raises:
            ValueError: if ``act_fn`` is not one of the supported names.
        """
        super().__init__()
        if out_features is None:
            out_features = hidden_size
        # Creation order (linear_1, act_1, linear_2) fixes parameter init and state_dict order.
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        self.act_1 = self._make_activation(act_fn)
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

    @staticmethod
    def _make_activation(act_fn):
        """Return a fresh activation module for the given name, or raise ValueError."""
        if act_fn == "gelu_tanh":
            return nn.GELU(approximate="tanh")
        if act_fn == "silu":
            return nn.SiLU()
        if act_fn == "silu_fp32":
            return FP32SiLU()
        raise ValueError(f"Unknown activation function: {act_fn}")

    def forward(self, caption):
        """Apply ``linear_1 -> act_1 -> linear_2`` to ``caption``."""
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        return self.linear_2(hidden_states)
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
    """Sums timestep, guidance and pooled-text embeddings into one conditioning tensor.

    Timestep and guidance values share one 256-channel sinusoidal projection (``time_proj``)
    but each has its own MLP. The pooled text embedding goes through a SiLU text projection.
    """

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()
        sinusoid_channels = 256
        # Keep this creation order: it fixes parameter init and state_dict layout.
        self.time_proj = Timesteps(num_channels=sinusoid_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=sinusoid_channels, time_embed_dim=embedding_dim)
        self.guidance_embedder = TimestepEmbedding(in_channels=sinusoid_channels, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, guidance, pooled_projection):
        """Return ``(timestep_emb + guidance_emb) + text_emb``."""
        target_dtype = pooled_projection.dtype
        # The sinusoidal features are cast to the text dtype before entering the MLPs.
        step_features = self.time_proj(timestep).to(dtype=target_dtype)
        step_emb = self.timestep_embedder(step_features)  # (N, D)
        guide_features = self.time_proj(guidance).to(dtype=target_dtype)
        guide_emb = self.guidance_embedder(guide_features)  # (N, D)
        text_emb = self.text_embedder(pooled_projection)
        return (step_emb + guide_emb) + text_emb
class CombinedTimestepTextProjEmbeddings(nn.Module):
    """Adds a timestep embedding to a projected pooled-text embedding.

    The timestep goes through a 256-channel sinusoidal projection (``time_proj``) followed by an
    MLP (``timestep_embedder``). The pooled text goes through a SiLU text projection (``text_embedder``).
    """

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()
        freq_channels = 256
        # Keep this creation order: it fixes parameter init and state_dict layout.
        self.time_proj = Timesteps(num_channels=freq_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=freq_channels, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        """Return ``timestep_embedder(time_proj(timestep)) + text_embedder(pooled_projection)``."""
        # The sinusoidal features are cast to the text dtype before entering the MLP.
        sinusoid = self.time_proj(timestep).to(dtype=pooled_projection.dtype)
        time_emb = self.timestep_embedder(sinusoid)  # (N, D)
        text_emb = self.text_embedder(pooled_projection)
        return time_emb + text_emb