# Source: LN3Diff_I23D / dit / dit_i23d.py
# Author: NIRVANALAN — last commit 592a426 ("update dep"), 26 kB
import torch.nn as nn
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from pdb import set_trace as st
from ldm.modules.attention import MemoryEfficientCrossAttention
from .dit_models_xformers import DiT, get_2d_sincos_pos_embed, ImageCondDiTBlock, FinalLayer, CaptionEmbedder, approx_gelu, ImageCondDiTBlockPixelArt, t2i_modulate, ImageCondDiTBlockPixelArtRMSNorm, T2IFinalLayer, ImageCondDiTBlockPixelArtRMSNormNoClip
from timm.models.vision_transformer import Mlp
try:
from apex.normalization import FusedLayerNorm as LayerNorm
from apex.normalization import FusedRMSNorm as RMSNorm
except:
from torch.nn import LayerNorm
from dit.norm import RMSNorm
# from vit.vit_triplane import XYZPosEmbed
class DiT_I23D(DiT):
    """DiT with 3D-aware (triplane roll-out) operations.

    The three planes stacked along the input channel axis are patch-embedded
    with a shared embedder. They are then laid out side by side along the
    token axis, so every transformer block attends across all planes.
    A frozen sin-cos positional embedding covering all planes is installed
    in ``init_PE_3D_aware``.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        mixing_logit_init=-3,
        mixed_prediction=True,
        context_dim=False,
        pooling_ctx_dim=768,
        roll_out=False,
        vit_blk=ImageCondDiTBlock,
        final_layer_blk=T2IFinalLayer,
    ):
        """Build the DiT backbone plus the CLIP/DINO conditioning projections.

        Args:
            context_dim: channel width of the DINO features, i.e. the channels
                of ``context['crossattn']`` after the first ``clip_ctx_dim``.
                They are projected to ``hidden_size`` by ``self.dino_proj``.
            pooling_ctx_dim: not used by this class (subclasses consume it).
            roll_out: must be truthy (asserted below); ``forward`` always
                rolls the planes out along the token axis.
            vit_blk: transformer block class forwarded to ``DiT``.
            final_layer_blk: NOTE(review): ignored. ``T2IFinalLayer`` is always
                forwarded to ``DiT``, so subclasses passing ``FinalLayer`` still
                get ``T2IFinalLayer``. Left as-is so existing checkpoints keep
                loading; confirm before changing.
        """
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, roll_out, vit_blk,
                         T2IFinalLayer)
        assert self.roll_out

        # Width of the CLIP (ViT-L) spatial features that occupy the first
        # channels of context['crossattn']; the remaining channels are DINO.
        self.clip_ctx_dim = 1024
        self.dino_proj = CaptionEmbedder(context_dim,
                                         hidden_size,
                                         act_layer=approx_gelu)
        self.clip_spatial_proj = CaptionEmbedder(1024,  # CLIP ViT-L width
                                                 hidden_size,
                                                 act_layer=approx_gelu)

    def init_PE_3D_aware(self):
        """Install a frozen sin-cos positional embedding spanning all planes.

        The embedding grid is ``(plane_n, p * p)``: one row per plane and one
        column per patch token of that plane. This matches the ``(n l)`` token
        layout that ``forward`` builds.
        """
        self.pos_embed = nn.Parameter(torch.zeros(
            1, self.plane_n * self.x_embedder.num_patches, self.embed_dim),
                                      requires_grad=False)
        p = int(self.x_embedder.num_patches**0.5)
        D = self.pos_embed.shape[-1]
        grid_size = (self.plane_n, p * p)
        pos_embed = get_2d_sincos_pos_embed(D, grid_size).reshape(
            self.plane_n * p * p, D)
        self.pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0))

    def initialize_weights(self):
        """Run the base initialisation, then add the 3D-aware positional embedding."""
        # NOTE(review): presumably invoked from DiT.__init__; confirm.
        super().initialize_weights()
        self.init_PE_3D_aware()

    def forward(self,
                x,
                timesteps=None,
                context=None,
                y=None,
                get_attr='',
                **kwargs):
        """Denoise a batch of rolled-out triplane latents.

        Args:
            x: (B, C*3, H, W) latents. The channel axis is read as ``(c n)``
                with n=3 planes.
            timesteps: (B,) diffusion timesteps.
            context: conditioning dict with these keys:
                ``'vector'``: pooled embedding, projected by
                    ``self.clip_text_proj`` and added to the timestep embedding.
                ``'crossattn'``: spatial tokens whose first ``clip_ctx_dim``
                    channels are CLIP features and whose rest are DINO features.
            y, get_attr, kwargs: accepted for interface compatibility; unused.

        Returns:
            float32 tensor (B, C_out*3, H, W) in the same ``(c n)`` layout.
        """
        assert isinstance(context, dict)
        clip_cls_token = self.clip_text_proj(context['vector'])
        crossattn = context['crossattn']
        clip_spatial_token = crossattn[..., :self.clip_ctx_dim]
        dino_spatial_token = self.dino_proj(crossattn[..., self.clip_ctx_dim:])
        t = self.t_embedder(timesteps) + clip_cls_token  # (B, D)

        # Fold the planes into the batch so they share one patch-embedding
        # conv, then lay them out along the token axis.
        x = rearrange(x, 'b (c n) h w->(b n) c h w', n=3)
        x = self.x_embedder(x)
        x = rearrange(x, '(b n) l c -> b (n l) c', n=3)
        x = x + self.pos_embed

        for block in self.blocks:
            x = block(x, t, dino_spatial_token=dino_spatial_token,
                      clip_spatial_token=clip_spatial_token)

        x = self.final_layer(x, t)
        if self.roll_out:  # move the plane axis from tokens back to batch
            x = rearrange(x, 'b (n l) c ->(b n) l c', n=3)
        x = self.unpatchify(x)
        if self.roll_out:
            x = rearrange(x, '(b n) c h w -> b (c n) h w', n=3)
        # cast to float32 for better accuracy
        return x.to(torch.float32).contiguous()

    def forward_with_cfg(self, x, t, context, cfg_scale):
        """Forward pass with classifier-free guidance.

        ``x`` and ``context`` must already stack the conditional half and the
        unconditional half along the batch axis, conditional half first.
        The guided prediction ``uncond + cfg_scale * (cond - uncond)`` is
        duplicated, so the output batch size equals the input batch size.

        Raises:
            ValueError: if the batch size is odd, so the halves cannot be paired.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        eps = self.forward(x, t, context)
        if len(eps) % 2 != 0:
            raise ValueError(
                'forward_with_cfg expects an even batch (cond + uncond '
                f'halves), got batch size {len(eps)}')
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        return torch.cat([half_eps, half_eps], dim=0)
class DiT_I23D_PixelArt(DiT_I23D):
    """PixArt-flavoured variant of ``DiT_I23D``.

    The variant makes three changes:
    * A single adaLN modulation MLP is shared by all blocks; its output is fed
      to each block.
    * The pooled ``context['vector']`` embedding goes through ``cap_embedder``
      (LayerNorm + Linear, zero-initialised) instead of ``clip_text_proj``.
    * The CLIP spatial tokens are RMS-normalised once, before the blocks.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        mixing_logit_init=-3,
        mixed_prediction=True,
        context_dim=False,
        pooling_ctx_dim=768,
        roll_out=False,
        vit_blk=ImageCondDiTBlockPixelArtRMSNorm,
        final_layer_blk=FinalLayer,
    ):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, vit_blk, final_layer_blk)

        # One modulation MLP shared by every block (hidden -> 6 * hidden),
        # zero-initialised.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        modulation_head = self.adaLN_modulation[-1]
        for param in (modulation_head.weight, modulation_head.bias):
            nn.init.constant_(param, 0)

        # The pooled condition is embedded by cap_embedder instead of the
        # parent's clip_text_proj.
        del self.clip_text_proj
        self.cap_embedder = nn.Sequential(
            LayerNorm(pooling_ctx_dim),
            nn.Linear(pooling_ctx_dim, hidden_size),
        )
        caption_head = self.cap_embedder[-1]
        for param in (caption_head.weight, caption_head.bias):
            nn.init.constant_(param, 0)

        print(self)  # dump the model architecture

        # Normalises the CLIP spatial tokens once per forward pass, following
        # https://github.com/Alpha-VLLM/Lumina-T2X/blob/0c8dd6a07a3b7c18da3d91f37b1e00e7ae661293/lumina_t2i/models/model.py#L570C9-L570C61
        self.attention_y_norm = RMSNorm(self.clip_ctx_dim, eps=1e-5)

    def forward(self,
                x,
                timesteps=None,
                context=None,
                y=None,
                get_attr='',
                **kwargs):
        """Denoise rolled-out triplane latents.

        Args:
            x: (B, C*3, H, W) latents, channel axis laid out as ``(c n)``.
            timesteps: (B,) diffusion timesteps.
            context: dict with ``'vector'`` (pooled embedding) and
                ``'crossattn'`` (CLIP channels first, DINO channels after).
            y, get_attr, kwargs: unused; kept for interface compatibility.

        Returns:
            float32 tensor (B, C_out*3, H, W).
        """
        assert isinstance(context, dict)
        pooled = self.cap_embedder(context['vector'])
        crossattn = context['crossattn']
        clip_raw = crossattn[..., :self.clip_ctx_dim]
        dino_spatial_token = self.dino_proj(crossattn[..., self.clip_ctx_dim:])
        clip_spatial_token = self.attention_y_norm(clip_raw)

        t = self.t_embedder(timesteps) + pooled
        shared_mod = self.adaLN_modulation(t)  # one modulation for all blocks

        # Shared patch embedding per plane, then planes side by side in tokens.
        folded = rearrange(x, 'b (c n) h w->(b n) c h w', n=3)
        tokens = rearrange(self.x_embedder(folded),
                           '(b n) l c -> b (n l) c', n=3)
        tokens = tokens + self.pos_embed

        for block in self.blocks:
            tokens = block(tokens,
                           shared_mod,
                           dino_spatial_token=dino_spatial_token,
                           clip_spatial_token=clip_spatial_token)

        tokens = self.final_layer(tokens, t)
        if self.roll_out:
            tokens = rearrange(tokens, 'b (n l) c ->(b n) l c', n=3)
        out = self.unpatchify(tokens)
        if self.roll_out:
            out = rearrange(out, '(b n) c h w -> b (c n) h w', n=3)
        return out.to(torch.float32).contiguous()
class DiT_I23D_PixelArt_MVCond(DiT_I23D_PixelArt):
    """Multi-view-conditioned PixArt DiT.

    Conditioning sources:
    * the pooled ``context['vector']`` is added to the timestep embedding
      (via ``cap_embedder``);
    * ``context['crossattn']`` (CLIP spatial tokens) goes through
      ``clip_spatial_proj``;
    * ``context['concat']`` (multi-view DINO features, (B, V, L, C)) is
      flattened over views into one token sequence.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        mixing_logit_init=-3,
        mixed_prediction=True,
        context_dim=False,
        pooling_ctx_dim=768,
        roll_out=False,
        vit_blk=ImageCondDiTBlockPixelArt,
        final_layer_blk=FinalLayer,
    ):
        # NOTE(review): ``vit_blk`` is ignored; ImageCondDiTBlockPixelArtRMSNorm
        # is always used.
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, ImageCondDiTBlockPixelArtRMSNorm,
                         final_layer_blk)
        # DINO tokens are consumed without a dedicated projection here.
        del self.dino_proj

    def forward(self,
                x,
                timesteps=None,
                context=None,
                y=None,
                get_attr='',
                **kwargs):
        """Denoise rolled-out triplane latents under multi-view conditioning.

        Args:
            x: (B, C*3, H, W) latents, channel axis laid out as ``(c n)``.
            timesteps: (B,) diffusion timesteps.
            context: dict with ``'vector'``, ``'crossattn'`` and ``'concat'``.
                Observed shapes from a debug session: (2, 768), (2, 256, 1024)
                and (2, 4, 256, 768).
            y, get_attr, kwargs: unused.

        Returns:
            float32 tensor (B, C_out*3, H, W).
        """
        assert isinstance(context, dict)
        pooled = self.cap_embedder(context['vector'])
        clip_tokens = self.clip_spatial_proj(context['crossattn'])
        mv_dino_tokens = rearrange(context['concat'], 'b v l c -> b (v l) c')

        t = self.t_embedder(timesteps) + pooled
        shared_mod = self.adaLN_modulation(t)

        folded = rearrange(x, 'b (c n) h w->(b n) c h w', n=3)
        tokens = rearrange(self.x_embedder(folded),
                           '(b n) l c -> b (n l) c', n=3)
        tokens = tokens + self.pos_embed

        for block in self.blocks:
            # The keyword names are intentionally crossed. Per the author's
            # notes, the projected CLIP tokens are appended to self-attention
            # and the DINO tokens are used for cross-attention.
            tokens = block(tokens,
                           shared_mod,
                           dino_spatial_token=clip_tokens,
                           clip_spatial_token=mv_dino_tokens)

        tokens = self.final_layer(tokens, t)
        if self.roll_out:
            tokens = rearrange(tokens, 'b (n l) c ->(b n) l c', n=3)
        out = self.unpatchify(tokens)
        if self.roll_out:
            out = rearrange(out, '(b n) c h w -> b (c n) h w', n=3)
        return out.to(torch.float32).contiguous()
class DiT_I23D_PixelArt_MVCond_noClip(DiT_I23D_PixelArt):
    """Multi-view-conditioned PixArt DiT that uses only DINO features.

    The CLIP projections (``dino_proj``, ``clip_spatial_proj``,
    ``cap_embedder``) are removed. The timestep embedding alone drives the
    shared adaLN modulation. ``context['concat']`` (B, V, L, C) is flattened
    over views and passed to every block.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        mixing_logit_init=-3,
        mixed_prediction=True,
        context_dim=False,
        pooling_ctx_dim=768,
        roll_out=False,
        vit_blk=ImageCondDiTBlockPixelArt,
        final_layer_blk=FinalLayer,
    ):
        # NOTE(review): ``vit_blk`` is ignored; the NoClip block is always used.
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, ImageCondDiTBlockPixelArtRMSNormNoClip,
                         final_layer_blk)
        # Drop every CLIP-related projection (deletion order kept).
        del self.dino_proj
        del self.clip_spatial_proj
        del self.cap_embedder
        # NOTE(review): the inherited attention_y_norm stays registered even
        # though forward() never uses it.

    def forward(self,
                x,
                timesteps=None,
                context=None,
                y=None,
                get_attr='',
                **kwargs):
        """Denoise rolled-out triplane latents with DINO-only conditioning.

        Args:
            x: (B, C*3, H, W) latents, channel axis laid out as ``(c n)``.
            timesteps: (B,) diffusion timesteps.
            context: dict; only ``'concat'`` (B, V, L, C) is read here.
            y, get_attr, kwargs: unused.

        Returns:
            float32 tensor (B, C_out*3, H, W).
        """
        assert isinstance(context, dict)
        mv_dino_tokens = rearrange(context['concat'], 'b v l c -> b (v l) c')

        t = self.t_embedder(timesteps)
        shared_mod = self.adaLN_modulation(t)

        folded = rearrange(x, 'b (c n) h w->(b n) c h w', n=3)
        tokens = rearrange(self.x_embedder(folded),
                           '(b n) l c -> b (n l) c', n=3)
        tokens = tokens + self.pos_embed

        for block in self.blocks:
            tokens = block(tokens, shared_mod, dino_spatial_token=mv_dino_tokens)

        tokens = self.final_layer(tokens, t)
        if self.roll_out:
            tokens = rearrange(tokens, 'b (n l) c ->(b n) l c', n=3)
        out = self.unpatchify(tokens)
        if self.roll_out:
            out = rearrange(out, '(b n) c h w -> b (c n) h w', n=3)
        return out.to(torch.float32).contiguous()
# pcd-structured latent ddpm
class DiT_pcd_I23D_PixelArt_MVCond(DiT_I23D_PixelArt_MVCond):
    """Multi-view-conditioned PixArt DiT over point-cloud-structured latents.

    The patch embedder is replaced by a token-wise MLP and the grid positional
    embedding is removed. There is no plane roll-out or unpatchify step.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        mixing_logit_init=-3,
        mixed_prediction=True,
        context_dim=False,
        pooling_ctx_dim=768,
        roll_out=False,
        vit_blk=ImageCondDiTBlockPixelArt,
        final_layer_blk=FinalLayer,
    ):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, ImageCondDiTBlockPixelArtRMSNorm,
                         final_layer_blk)
        # TODO (from the author's notes, not implemented here):
        #   - normalise xyz from [-0.45, 0.45] to [-1, 1];
        #   - add Fourier-feature xyz encoding as the positional embedding;
        #   - map back to 16 + 3 dims through final_layer.
        # For now a single MLP embeds each token's in_channels features.
        self.x_embedder = Mlp(in_features=in_channels,
                              hidden_features=hidden_size,
                              out_features=hidden_size,
                              act_layer=approx_gelu,
                              drop=0)
        del self.pos_embed

    def forward(self,
                x,
                timesteps=None,
                context=None,
                y=None,
                get_attr='',
                **kwargs):
        """Denoise point-cloud-structured latents.

        Args:
            x: presumably (B, N, in_channels) per-point latents. The Mlp
                embedder acts on the last axis; TODO confirm the layout with
                the caller.
            timesteps: (B,) diffusion timesteps.
            context: dict with ``'vector'``, ``'crossattn'`` and ``'concat'``
                ((B, V, L, C) multi-view DINO features).
            y, get_attr, kwargs: unused.

        Returns:
            The final-layer output cast to float32 (no unpatchify).
        """
        assert isinstance(context, dict)
        pooled = self.cap_embedder(context['vector'])
        clip_tokens = self.clip_spatial_proj(context['crossattn'])
        mv_dino_tokens = rearrange(context['concat'], 'b v l c -> b (v l) c')

        t = self.t_embedder(timesteps) + pooled
        shared_mod = self.adaLN_modulation(t)

        tokens = self.x_embedder(x)
        for block in self.blocks:
            # Crossed on purpose: projected CLIP tokens are appended to
            # self-attention, DINO tokens feed cross-attention.
            tokens = block(tokens,
                           shared_mod,
                           dino_spatial_token=clip_tokens,
                           clip_spatial_token=mv_dino_tokens)

        out = self.final_layer(tokens, t)
        return out.to(torch.float32).contiguous()
#################################################################################
# DiT_I23D Configs #
#################################################################################
def DiT_XL_2(**kwargs):
    """DiT_I23D at XL/2 size: 28 blocks, width 1152, 16 heads, 2x2 patches."""
    arch = dict(depth=28, hidden_size=1152, patch_size=2, num_heads=16)
    return DiT_I23D(**arch, **kwargs)
def DiT_L_2(**kwargs):
    """DiT_I23D at L/2 size: 24 blocks, width 1024, 16 heads, 2x2 patches."""
    arch = dict(depth=24, hidden_size=1024, patch_size=2, num_heads=16)
    return DiT_I23D(**arch, **kwargs)
def DiT_B_2(**kwargs):
    """DiT_I23D at B/2 size: 12 blocks, width 768, 12 heads, 2x2 patches."""
    arch = dict(depth=12, hidden_size=768, patch_size=2, num_heads=12)
    return DiT_I23D(**arch, **kwargs)
def DiT_B_1(**kwargs):
    """DiT_I23D at B/1 size: 12 blocks, width 768, 12 heads, 1x1 patches."""
    arch = dict(depth=12, hidden_size=768, patch_size=1, num_heads=12)
    return DiT_I23D(**arch, **kwargs)
def DiT_L_Pixelart_2(**kwargs):
    """DiT_I23D_PixelArt at L/2 size: 24 blocks, width 1024, 16 heads."""
    arch = dict(depth=24, hidden_size=1024, patch_size=2, num_heads=16)
    return DiT_I23D_PixelArt(**arch, **kwargs)
def DiT_B_Pixelart_2(**kwargs):
    """DiT_I23D_PixelArt at B/2 size: 12 blocks, width 768, 12 heads."""
    arch = dict(depth=12, hidden_size=768, patch_size=2, num_heads=12)
    return DiT_I23D_PixelArt(**arch, **kwargs)
def DiT_L_Pixelart_MV_2(**kwargs):
    """DiT_I23D_PixelArt_MVCond at L/2 size: 24 blocks, width 1024, 16 heads."""
    arch = dict(depth=24, hidden_size=1024, patch_size=2, num_heads=16)
    return DiT_I23D_PixelArt_MVCond(**arch, **kwargs)
def DiT_L_Pixelart_MV_2_noclip(**kwargs):
    """DINO-only multi-view PixArt DiT at L/2 size: 24 blocks, width 1024, 16 heads."""
    arch = dict(depth=24, hidden_size=1024, patch_size=2, num_heads=16)
    return DiT_I23D_PixelArt_MVCond_noClip(**arch, **kwargs)
def DiT_XL_Pixelart_MV_2(**kwargs):
    """DiT_I23D_PixelArt_MVCond at XL/2 size: 28 blocks, width 1152, 16 heads."""
    arch = dict(depth=28, hidden_size=1152, patch_size=2, num_heads=16)
    return DiT_I23D_PixelArt_MVCond(**arch, **kwargs)
def DiT_B_Pixelart_MV_2(**kwargs):
    """DiT_I23D_PixelArt_MVCond at B/2 size: 12 blocks, width 768, 12 heads."""
    arch = dict(depth=12, hidden_size=768, patch_size=2, num_heads=12)
    return DiT_I23D_PixelArt_MVCond(**arch, **kwargs)
# pcd latent
def DiT_L_Pixelart_MV_pcd(**kwargs):
    """Point-cloud-latent multi-view PixArt DiT, L size (24 blocks, width 1024, 16 heads).

    patch_size is 1 because no spatial compression is applied to point latents.
    """
    arch = dict(depth=24, hidden_size=1024, patch_size=1, num_heads=16)
    return DiT_pcd_I23D_PixelArt_MVCond(**arch, **kwargs)
# Registry mapping architecture name -> factory function (insertion order kept).
# NOTE: 'DiT-PixArt-MV-L/2' resolves to the DINO-only variant
# (DiT_L_Pixelart_MV_2_noclip). DiT_L_Pixelart_MV_2 is currently not
# registered under any name.
DiT_models = dict([
    ('DiT-XL/2', DiT_XL_2),
    ('DiT-L/2', DiT_L_2),
    ('DiT-B/2', DiT_B_2),
    ('DiT-B/1', DiT_B_1),
    ('DiT-PixArt-L/2', DiT_L_Pixelart_2),
    ('DiT-PixArt-MV-XL/2', DiT_XL_Pixelart_MV_2),
    ('DiT-PixArt-MV-L/2', DiT_L_Pixelart_MV_2_noclip),
    ('DiT-PixArt-MV-PCD-L', DiT_L_Pixelart_MV_pcd),
    ('DiT-PixArt-MV-B/2', DiT_B_Pixelart_MV_2),
    ('DiT-PixArt-B/2', DiT_B_Pixelart_2),
])