import torch
from torch import nn
from .nn import (
TimeEmbeddings,
TextEmbeddings,
VisualEmbeddings,
RoPE1D,
RoPE3D,
Modulation,
MultiheadSelfAttentionEnc,
MultiheadSelfAttentionDec,
MultiheadCrossAttention,
FeedForward,
OutLayer,
apply_scale_shift_norm,
apply_gate_sum,
)
from .utils import fractal_flatten, fractal_unflatten


class TransformerEncoderBlock(nn.Module):
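    """Text-stream encoder block.

    Runs AdaLN-style modulated self-attention and a feed-forward layer, each
    wrapped in a gated residual; the six shift/scale/gate parameters come
    from the time embedding via `Modulation`.
    """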
def __init__(self, model_dim, time_dim, ff_dim, head_dim):
super().__init__()
self.text_modulation = Modulation(time_dim, model_dim, 6)
self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.self_attention = MultiheadSelfAttentionEnc(model_dim, head_dim)
self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.feed_forward = FeedForward(model_dim, ff_dim)

    def forward(self, x, time_embed, rope):
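        """Apply modulated self-attention, then feed-forward, to text tokens `x`."""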
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
out = self.self_attention(out, rope)
x = apply_gate_sum(x, out, gate)
shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
out = self.feed_forward(out)
x = apply_gate_sum(x, out, gate)
        return x


class TransformerDecoderBlock(nn.Module):
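    """Visual-stream decoder block.

    Adds cross-attention to the text stream between modulated self-attention
    and feed-forward; `Modulation` emits nine parameters, one
    shift/scale/gate triplet per sub-layer.
    """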
def __init__(self, model_dim, time_dim, ff_dim, head_dim):
super().__init__()
self.visual_modulation = Modulation(time_dim, model_dim, 9)
self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.self_attention = MultiheadSelfAttentionDec(model_dim, head_dim)
self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.cross_attention = MultiheadCrossAttention(model_dim, head_dim)
self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.feed_forward = FeedForward(model_dim, ff_dim)

    def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
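        """Self-attend over visual tokens (with RoPE and optional sparse-attention
        params), cross-attend to `text_embed`, then apply the feed-forward layer."""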
self_attn_params, cross_attn_params, ff_params = torch.chunk(
self.visual_modulation(time_embed), 3, dim=-1
)
shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
visual_out = self.self_attention(visual_out, rope, sparse_params)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
visual_out = self.cross_attention(visual_out, text_embed)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
visual_out = self.feed_forward(visual_out)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
        return visual_embed


class DiffusionTransformer3D(nn.Module):
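    """3D diffusion transformer.

    A shallow text encoder stack (`num_text_blocks`) refines text embeddings
    with 1D RoPE; a deep visual decoder stack (`num_visual_blocks`) then
    processes patchified visual tokens with 3D RoPE, cross-attending to the
    text stream. Timestep and pooled text embeddings drive the AdaLN
    modulation throughout.
    """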
def __init__(
self,
in_visual_dim=4,
in_text_dim=3584,
in_text_dim2=768,
time_dim=512,
out_visual_dim=4,
patch_size=(1, 2, 2),
model_dim=2048,
ff_dim=5120,
num_text_blocks=2,
num_visual_blocks=32,
axes_dims=(16, 24, 24),
visual_cond=False,
):
super().__init__()
head_dim = sum(axes_dims)
self.in_visual_dim = in_visual_dim
self.model_dim = model_dim
self.patch_size = patch_size
self.visual_cond = visual_cond
        # With visual conditioning, the input concatenates noisy and conditioning
        # latents plus one extra channel (presumably a conditioning mask).
        visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim
self.time_embeddings = TimeEmbeddings(model_dim, time_dim)
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim)
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim)
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size)
self.text_rope_embeddings = RoPE1D(head_dim)
self.text_transformer_blocks = nn.ModuleList(
[
TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim)
for _ in range(num_text_blocks)
]
)
self.visual_rope_embeddings = RoPE3D(axes_dims)
self.visual_transformer_blocks = nn.ModuleList(
[
TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim)
for _ in range(num_visual_blocks)
]
)
        self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size)

    @torch.compile()
def before_text_transformer_blocks(self, text_embed, time, pooled_text_embed, x,
text_rope_pos):
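        """Embed text, timestep (summed with pooled text), and visual inputs,
        and build 1D RoPE for the text tokens."""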
text_embed = self.text_embeddings(text_embed)
time_embed = self.time_embeddings(time)
time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed)
visual_embed = self.visual_embeddings(x)
text_rope = self.text_rope_embeddings(text_rope_pos)
        return text_embed, time_embed, text_rope, visual_embed

    @torch.compile()
def before_visual_transformer_blocks(self, visual_embed, visual_rope_pos, scale_factor,
sparse_params):
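        """Build 3D RoPE for the visual token grid and flatten it into a
        sequence, in fractal order when `sparse_params["to_fractal"]` asks
        for it."""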
visual_shape = visual_embed.shape[:-1]
visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)
to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape,
block_mask=to_fractal)
        return visual_embed, visual_shape, to_fractal, visual_rope

    @torch.compile()
def after_blocks(self, visual_embed, visual_shape, to_fractal, text_embed, time_embed):
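        """Unflatten visual tokens back to their grid and project them to
        output channels through the modulated `OutLayer`."""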
visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal)
x = self.out_layer(visual_embed, text_embed, time_embed)
        return x

    def forward(
self,
x,
text_embed,
pooled_text_embed,
time,
visual_rope_pos,
text_rope_pos,
scale_factor=(1.0, 1.0, 1.0),
sparse_params=None
):
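        """Denoise visual latents `x` conditioned on text and timestep.

        `visual_rope_pos` and `text_rope_pos` index the rotary embeddings,
        `scale_factor` rescales the three visual RoPE axes, and the optional
        `sparse_params` configures sparse attention; when provided, it must
        contain a "to_fractal" flag.
        """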
text_embed, time_embed, text_rope, visual_embed = self.before_text_transformer_blocks(
text_embed, time, pooled_text_embed, x, text_rope_pos)
for text_transformer_block in self.text_transformer_blocks:
text_embed = text_transformer_block(text_embed, time_embed, text_rope)
visual_embed, visual_shape, to_fractal, visual_rope = self.before_visual_transformer_blocks(
visual_embed, visual_rope_pos, scale_factor, sparse_params)
for visual_transformer_block in self.visual_transformer_blocks:
visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed,
visual_rope, sparse_params)
x = self.after_blocks(visual_embed, visual_shape, to_fractal, text_embed, time_embed)
        return x


def get_dit(conf):
dit = DiffusionTransformer3D(**conf)
return dit
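
# Usage sketch (illustrative, not from the original file): `get_dit` forwards
# a config dict straight to the DiffusionTransformer3D constructor, so its
# keys must match the __init__ signature above. Because this module uses
# relative imports, import it from its package context; the module path below
# is hypothetical.
#
#     from .dit import get_dit  # hypothetical module path
#
#     conf = dict(model_dim=2048, num_visual_blocks=32, visual_cond=False)
#     dit = get_dit(conf)
#     n_params = sum(p.numel() for p in dit.parameters())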