import math
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.embeddings import (PixArtAlphaTextProjection, get_timestep_embedding,
                                          TimestepEmbedding, Timesteps)
from einops import rearrange
from torch import nn


class HunyuanDiTAttentionPool(nn.Module):
    # Attention pooling: a mean-pooled query token attends over the full sequence,
    # collapsing (N, L, embed_dim) text features into a single (N, output_dim) vector.
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: Optional[int] = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        # Prepend the mean token, then add learned positional embeddings: (N, L, C) -> (N, L + 1, C).
        x = torch.cat([x.mean(dim=1, keepdim=True), x], dim=1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)

        query = self.q_proj(x[:, :1])
        key = self.k_proj(x)
        value = self.v_proj(x)

        batch_size, _, _ = query.size()
        query = query.reshape(batch_size, -1, self.num_heads, query.size(-1) // self.num_heads).transpose(1, 2)  # (N, H, 1, E/H)
        key = key.reshape(batch_size, -1, self.num_heads, key.size(-1) // self.num_heads).transpose(1, 2)  # (N, H, L+1, E/H)
        value = value.reshape(batch_size, -1, self.num_heads, value.size(-1) // self.num_heads).transpose(1, 2)  # (N, H, L+1, E/H)

        x = F.scaled_dot_product_attention(query=query, key=key, value=value, attn_mask=None, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, 1, -1)
        x = x.to(query.dtype)

        x = self.c_proj(x)
        return x.squeeze(1)
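

# Illustrative usage (a sketch, not part of the original module): with the defaults of the
# combined embedder below (seq_len=256, cross_attention_dim=2048, pooled_projection_dim=1024),
# the pooler reduces a text-encoder sequence to a single pooled vector, e.g.
#
#   pool = HunyuanDiTAttentionPool(spacial_dim=256, embed_dim=2048, num_heads=8, output_dim=1024)
#   pooled = pool(torch.randn(2, 256, 2048))  # -> torch.Size([2, 1024])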


class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
    # Combines the timestep embedding with pooled text features, image meta-size
    # embeddings, and a style embedding into a single conditioning vector.
    def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.pooler = HunyuanDiTAttentionPool(
            seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
        )
        # Here we use a default learned embedder layer for future extension.
        self.style_embedder = nn.Embedding(1, embedding_dim)

        extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
        self.extra_embedder = PixArtAlphaTextProjection(
            in_features=extra_in_dim,
            hidden_size=embedding_dim * 4,
            out_features=embedding_dim,
            act_fn="silu_fp32",
        )

    def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, embedding_dim)

        # extra condition 1: pooled text embedding
        pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)

        # extra condition 2: image meta size embedding
        image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
        image_meta_size = image_meta_size.to(dtype=hidden_dtype)
        image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)

        # extra condition 3: style embedding
        style_embedding = self.style_embedder(style)  # (N, embedding_dim)

        # Concatenate all extra vectors and project them to the conditioning dimension.
        extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
        conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # (N, embedding_dim)

        return conditioning


class TimePositionalEncoding(nn.Module):
    # Standard sinusoidal positional encoding applied along the temporal (frame) axis
    # of a (B, C, F, H, W) video feature tensor.
    def __init__(self, d_model, dropout=0.0, max_len=24):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        b, c, f, h, w = x.size()
        # Fold the spatial dimensions into the batch so the encoding is added per frame index.
        x = rearrange(x, "b c f h w -> (b h w) f c")
        x = x + self.pe[:, :x.size(1)]
        x = rearrange(x, "(b h w) f c -> b c f h w", b=b, h=h, w=w)
        return self.dropout(x)
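

# Minimal smoke test below: an illustrative sketch, not part of the original module. The
# shapes are assumptions chosen to match the defaults above (seq_len=256 text tokens with
# cross_attention_dim=2048 features, six meta-size scalars per sample, a single style id),
# and embedding_dim=1408 is just an example width.
if __name__ == "__main__":
    embedding_dim = 1408  # assumed model width for the sake of the example

    embedder = HunyuanCombinedTimestepTextSizeStyleEmbedding(embedding_dim)
    timestep = torch.randint(0, 1000, (2,))
    encoder_hidden_states = torch.randn(2, 256, 2048)  # (N, seq_len, cross_attention_dim)
    image_meta_size = torch.ones(2, 6) * 1024.0  # six size-related scalars per sample
    style = torch.zeros(2, dtype=torch.long)  # index into the single learned style embedding
    conditioning = embedder(
        timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=torch.float32
    )
    print(conditioning.shape)  # torch.Size([2, 1408])

    pos_enc = TimePositionalEncoding(d_model=320, max_len=24)
    video_features = torch.randn(2, 320, 16, 8, 8)  # (B, C, F, H, W)
    print(pos_enc(video_features).shape)  # torch.Size([2, 320, 16, 8, 8])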