import math
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.embeddings import (PixArtAlphaTextProjection, get_timestep_embedding,
TimestepEmbedding, Timesteps)
from einops import rearrange
from torch import nn


class HunyuanDiTAttentionPool(nn.Module):
    """Attention pooling over a token sequence: the mean token is prepended as a single
    learned query that attends over the sequence, and the result is projected to `output_dim`."""

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: Optional[int] = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads

    def forward(self, x):
        # prepend the mean token; it serves as the single pooling query
        x = torch.cat([x.mean(dim=1, keepdim=True), x], dim=1)  # (N, L+1, C)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)
        query = self.q_proj(x[:, :1])
        key = self.k_proj(x)
        value = self.v_proj(x)
batch_size, _, _ = query.size()
        query = query.reshape(batch_size, -1, self.num_heads, query.size(-1) // self.num_heads).transpose(1, 2)  # (N, H, 1, E/H)
        key = key.reshape(batch_size, -1, self.num_heads, key.size(-1) // self.num_heads).transpose(1, 2)  # (N, H, L+1, E/H)
        value = value.reshape(batch_size, -1, self.num_heads, value.size(-1) // self.num_heads).transpose(1, 2)  # (N, H, L+1, E/H)
x = F.scaled_dot_product_attention(query=query, key=key, value=value, attn_mask=None, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, 1, -1)
x = x.to(query.dtype)
x = self.c_proj(x)
return x.squeeze(1)
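

# Minimal usage sketch for HunyuanDiTAttentionPool (not part of the original module; the
# batch size and shapes below are illustrative assumptions matching the defaults used further down).
def _example_attention_pool():
    pool = HunyuanDiTAttentionPool(spacial_dim=256, embed_dim=2048, num_heads=8, output_dim=1024)
    tokens = torch.randn(2, 256, 2048)  # (batch, seq_len, embed_dim)
    pooled = pool(tokens)               # -> (2, 1024)
    return pooled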


class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
    """Combines the timestep embedding with pooled text, image-meta-size, and style conditions
    into a single conditioning vector of size `embedding_dim`."""

    def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.pooler = HunyuanDiTAttentionPool(
seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
)
# Here we use a default learned embedder layer for future extension.
self.style_embedder = nn.Embedding(1, embedding_dim)
extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
self.extra_embedder = PixArtAlphaTextProjection(
in_features=extra_in_dim,
hidden_size=embedding_dim * 4,
out_features=embedding_dim,
act_fn="silu_fp32",
)

    def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
        timesteps_proj = self.time_proj(timestep)  # (N, 256)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, embedding_dim)
        # extra condition 1: pooled text embedding
        pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)
        # extra condition 2: image meta size embedding
        image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, flip_sin_to_cos=True, downscale_freq_shift=0)
        image_meta_size = image_meta_size.to(dtype=hidden_dtype)
        image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
        # extra condition 3: style embedding
        style_embedding = self.style_embedder(style)  # (N, embedding_dim)
        # concatenate all extra conditioning vectors and fold them into the timestep embedding
        extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
        conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # (N, embedding_dim)
return conditioning
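

# Minimal usage sketch for HunyuanCombinedTimestepTextSizeStyleEmbedding (not part of the
# original module; embedding_dim=1408, the batch of 2, and the size tokens are illustrative assumptions).
def _example_combined_embedding():
    embedder = HunyuanCombinedTimestepTextSizeStyleEmbedding(embedding_dim=1408)
    timestep = torch.randint(0, 1000, (2,))
    encoder_hidden_states = torch.randn(2, 256, 2048)  # (N, seq_len, cross_attention_dim)
    image_meta_size = torch.tensor([[1024, 1024, 1024, 1024, 0, 0]] * 2)  # (N, 6) size tokens
    style = torch.zeros(2, dtype=torch.long)  # single learned style embedding (index 0)
    conditioning = embedder(timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=torch.float32)
    return conditioning  # -> (2, 1408)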


class TimePositionalEncoding(nn.Module):
    """Adds a sinusoidal positional encoding along the frame (temporal) axis of a
    (b, c, f, h, w) video feature tensor; `d_model` must equal the channel dim `c`."""

    def __init__(
        self,
        d_model,
        dropout=0.0,
        max_len=24,
    ):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
        # standard sinusoidal table: sin on even channels, cos on odd channels
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)

    def forward(self, x):
        b, c, f, h, w = x.size()
        # fold batch and spatial dims together so the encoding is added along the frame axis
        x = rearrange(x, "b c f h w -> (b h w) f c")
        x = x + self.pe[:, :x.size(1)]
        x = rearrange(x, "(b h w) f c -> b c f h w", b=b, h=h, w=w)
return self.dropout(x)
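

# Minimal usage sketch for TimePositionalEncoding (not part of the original module; the video
# feature shape below is an illustrative assumption with d_model equal to the channel dim).
def _example_time_positional_encoding():
    pos_enc = TimePositionalEncoding(d_model=320, max_len=24)
    video_features = torch.randn(1, 320, 16, 32, 32)  # (b, c, f, h, w), c must equal d_model
    encoded = pos_enc(video_features)  # same shape, sinusoidal encoding added per frame index
    return encoded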