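# NOTE(review): the six lines below look like page residue (author, commit
# message, commit hash, viewer links, file size) rather than source code;
# they are kept verbatim as comments so the module remains valid Python.
# zxl
# first commit
# 07c6a04
# raw
# history blame
# 20.6 kB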
# Adapted from OpenSora
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# OpenSora: https://github.com/hpcaitech/Open-Sora
# --------------------------------------------------------
import os
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp
from transformers import PretrainedConfig, PreTrainedModel
from videosys.core.comm import (
all_to_all_with_pad,
gather_sequence,
get_spatial_pad,
get_temporal_pad,
set_spatial_pad,
set_temporal_pad,
split_sequence,
)
from videosys.core.pab_mgr import (
enable_pab,
get_mlp_output,
if_broadcast_cross,
if_broadcast_mlp,
if_broadcast_spatial,
if_broadcast_temporal,
save_mlp_output,
)
from videosys.core.parallel_mgr import (
enable_sequence_parallel,
get_cfg_parallel_size,
get_data_parallel_group,
get_sequence_parallel_group,
)
from videosys.utils.utils import batch_func
from .modules import (
Attention,
CaptionEmbedder,
MultiHeadCrossAttention,
PatchEmbed3D,
PositionEmbedding2D,
SizeEmbedder,
T2IFinalLayer,
TimestepEmbedder,
approx_gelu,
get_layernorm,
t2i_modulate,
)
from .utils import auto_grad_checkpoint, load_checkpoint
class STDiT3Block(nn.Module):
    """Single STDiT3 transformer block.

    Applies, in order and each as a residual update: timestep-modulated
    self-attention, cross-attention against ``y``, and a timestep-modulated
    MLP.  With ``temporal=True`` self-attention runs over the ``T`` axis for
    every spatial position (``(B S) T C``); otherwise it runs over the ``S``
    axis for every frame (``(B T) S C``).

    When PAB is enabled (``enable_pab()`` from ``videosys.core.pab_mgr``) the
    block may reuse a cached attention / cross-attention / MLP output instead
    of recomputing it; the counters and ``last_*`` attributes hold that state
    between calls.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        drop_path=0.0,
        rope=None,
        qk_norm=False,
        temporal=False,
        enable_flash_attn=False,
        block_idx=None,
    ):
        """Build the block's sub-modules.

        Args:
            hidden_size: Channel width ``C`` of the token stream.
            num_heads: Number of heads for self- and cross-attention.
            mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
            drop_path: Drop-path rate; ``nn.Identity`` is used when it is 0.
            rope: Forwarded unchanged to ``Attention`` as ``rope``.
            qk_norm: Forwarded unchanged to ``Attention`` as ``qk_norm``.
            temporal: Attend over frames (True) or over patches (False).
            enable_flash_attn: Forwarded unchanged to ``Attention``.
            block_idx: Index of this block, handed to the PAB helpers.
        """
        super().__init__()
        self.temporal = temporal
        self.hidden_size = hidden_size
        self.enable_flash_attn = enable_flash_attn

        attn_cls = Attention
        mha_cls = MultiHeadCrossAttention

        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False)
        self.attn = attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=qk_norm,
            rope=rope,
            enable_flash_attn=enable_flash_attn,
        )
        self.cross_attn = mha_cls(hidden_size, num_heads)
        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False)
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        # Learned base for the six modulation vectors
        # (shift/scale/gate for attention, then shift/scale/gate for the MLP).
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)

        # pab
        self.block_idx = block_idx
        self.attn_count = 0
        self.last_attn = None
        self.cross_count = 0
        self.last_cross = None
        self.mlp_count = 0

    def t_mask_select(self, x_mask, x, masked_x, T, S):
        """Pick ``x`` or ``masked_x`` per (batch, frame) according to ``x_mask``.

        Args:
            x_mask: [B, T]; used as the ``torch.where`` condition, so it is
                presumably a boolean tensor - TODO confirm against callers.
            x: [B, (T, S), C], taken where ``x_mask`` is true.
            masked_x: [B, (T, S), C], taken where ``x_mask`` is false.
            T: Number of frames.
            S: Number of spatial patches per frame.

        Returns:
            Tensor of shape [B, (T, S), C].
        """
        x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
        masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
        # Broadcast the per-frame mask over the spatial and channel axes.
        x = torch.where(x_mask[:, :, None, None], x, masked_x)
        x = rearrange(x, "B T S C -> B (T S) C")
        return x

    def forward(
        self,
        x,
        y,
        t,
        mask=None,  # text mask
        x_mask=None,  # temporal mask
        t0=None,  # t with timestamp=0
        T=None,  # number of frames
        S=None,  # number of pixel patches
        timestep=None,
        all_timesteps=None,
    ):
        """Run self-attention, cross-attention and the MLP on ``x``.

        Args:
            x: Token stream of shape [B, T*S, C].
            y: Conditioning passed to cross-attention.
            t: Timestep conditioning, reshaped here to [B, 6, -1].
            mask: Passed to cross-attention as its third argument.
            x_mask: Optional [B, T] temporal mask.  When given, frames where
                it is false use the modulation derived from ``t0`` instead
                of ``t``; ``t0`` must then be provided.
            t0: Conditioning for timestep 0, same layout as ``t``.
            T: Number of frames in ``x``.
            S: Number of spatial patches per frame in ``x``.
            timestep: Only read when PAB is enabled, as ``int(timestep[0])``.
            all_timesteps: Only read when PAB is enabled; forwarded to
                ``if_broadcast_mlp``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        pab_enabled = enable_pab()
        # The PAB helpers are keyed on the integer value of the first timestep
        # entry; convert it once instead of at every use.
        step = int(timestep[0]) if pab_enabled else None
        has_mask = x_mask is not None

        # prepare modulate parameters
        B, _, _ = x.shape
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + t.reshape(B, 6, -1)
        ).chunk(6, dim=1)
        if has_mask:
            shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = (
                self.scale_shift_table[None] + t0.reshape(B, 6, -1)
            ).chunk(6, dim=1)

        # ---- self-attention ----
        broadcast_attn = False
        if pab_enabled:
            if self.temporal:
                broadcast_attn, self.attn_count = if_broadcast_temporal(step, self.attn_count)
            else:
                broadcast_attn, self.attn_count = if_broadcast_spatial(step, self.attn_count, self.block_idx)

        if broadcast_attn:
            # Reuse the gated attention output cached by an earlier call.
            x_m_s = self.last_attn
        else:
            # modulate (attention); the normalised input is shared by both
            # the regular and the timestep-0 modulation.
            x_norm = self.norm1(x)
            x_m = t2i_modulate(x_norm, shift_msa, scale_msa)
            if has_mask:
                x_m_zero = t2i_modulate(x_norm, shift_msa_zero, scale_msa_zero)
                x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)

            # attention
            if self.temporal:
                sequence_parallel = enable_sequence_parallel()
                if sequence_parallel:
                    # Switch sharding so each rank sees the full T axis.
                    x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=True)
                x_m = rearrange(x_m, "B (T S) C -> (B S) T C", T=T, S=S)
                x_m = self.attn(x_m)
                x_m = rearrange(x_m, "(B S) T C -> B (T S) C", T=T, S=S)
                if sequence_parallel:
                    # Switch back to the sharding the caller uses.
                    x_m, S, T = self.dynamic_switch(x_m, S, T, to_spatial_shard=False)
            else:
                x_m = rearrange(x_m, "B (T S) C -> (B T) S C", T=T, S=S)
                x_m = self.attn(x_m)
                x_m = rearrange(x_m, "(B T) S C -> B (T S) C", T=T, S=S)

            # modulate (attention)
            x_m_s = gate_msa * x_m
            if has_mask:
                x_m_s_zero = gate_msa_zero * x_m
                x_m_s = self.t_mask_select(x_mask, x_m_s, x_m_s_zero, T, S)

            if pab_enabled:
                self.last_attn = x_m_s

        # residual
        x = x + self.drop_path(x_m_s)

        # ---- cross attention ----
        broadcast_cross = False
        if pab_enabled:
            broadcast_cross, self.cross_count = if_broadcast_cross(step, self.cross_count)

        if broadcast_cross:
            x = x + self.last_cross
        else:
            x_cross = self.cross_attn(x, y, mask)
            if pab_enabled:
                self.last_cross = x_cross
            x = x + x_cross

        # ---- MLP ----
        broadcast_mlp = False
        broadcast_next = False
        skip_range = None
        if pab_enabled:
            broadcast_mlp, self.mlp_count, broadcast_next, skip_range = if_broadcast_mlp(
                step,
                self.mlp_count,
                self.block_idx,
                all_timesteps,
                is_temporal=self.temporal,
            )

        if broadcast_mlp:
            x_m_s = get_mlp_output(
                skip_range,
                timestep=step,
                block_idx=self.block_idx,
                is_temporal=self.temporal,
            )
        else:
            # modulate (MLP); as above, normalise once and reuse.
            x_norm = self.norm2(x)
            x_m = t2i_modulate(x_norm, shift_mlp, scale_mlp)
            if has_mask:
                x_m_zero = t2i_modulate(x_norm, shift_mlp_zero, scale_mlp_zero)
                x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)

            # MLP
            x_m = self.mlp(x_m)

            # modulate (MLP)
            x_m_s = gate_mlp * x_m
            if has_mask:
                x_m_s_zero = gate_mlp_zero * x_m
                x_m_s = self.t_mask_select(x_mask, x_m_s, x_m_s_zero, T, S)

            if broadcast_next:
                save_mlp_output(
                    timestep=step,
                    block_idx=self.block_idx,
                    ff_output=x_m_s,
                    is_temporal=self.temporal,
                )

        # residual
        x = x + self.drop_path(x_m_s)

        return x

    def dynamic_switch(self, x, s, t, to_spatial_shard: bool):
        """Exchange ``x`` across the sequence-parallel group via all-to-all.

        Args:
            x: Tensor of shape [b, (t s), d].
            s: Current size of the spatial axis of ``x``.
            t: Current size of the temporal axis of ``x``.
            to_spatial_shard: If True, scatter along the spatial axis (dim 2)
                and gather along the temporal axis (dim 1); if False, the
                reverse.

        Returns:
            ``(x, new_s, new_t)`` where ``x`` is [b, (new_t new_s), d] and the
            sizes are read from the exchanged tensor.
        """
        if to_spatial_shard:
            scatter_dim, gather_dim = 2, 1
            scatter_pad = get_spatial_pad()
            gather_pad = get_temporal_pad()
        else:
            scatter_dim, gather_dim = 1, 2
            scatter_pad = get_temporal_pad()
            gather_pad = get_spatial_pad()

        x = rearrange(x, "b (t s) d -> b t s d", t=t, s=s)
        x = all_to_all_with_pad(
            x,
            get_sequence_parallel_group(),
            scatter_dim=scatter_dim,
            gather_dim=gather_dim,
            scatter_pad=scatter_pad,
            gather_pad=gather_pad,
        )
        new_s, new_t = x.shape[2], x.shape[1]
        x = rearrange(x, "b t s d -> b (t s) d")
        return x, new_s, new_t
class STDiT3Config(PretrainedConfig):
    """Configuration container for :class:`STDiT3`.

    Every constructor argument is stored as an attribute of the same name;
    any remaining keyword arguments are handed to ``PretrainedConfig``.
    """

    model_type = "STDiT3"

    def __init__(
        self,
        input_size=(None, None, None),
        input_sq_size=512,
        in_channels=4,
        patch_size=(1, 2, 2),
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        pred_sigma=True,
        drop_path=0.0,
        caption_channels=4096,
        model_max_length=300,
        qk_norm=True,
        enable_flash_attn=False,
        only_train_temporal=False,
        freeze_y_embedder=False,
        skip_y_embedder=False,
        **kwargs,
    ):
        # Collected in declaration order so attributes are assigned in the
        # same sequence as the signature lists them.
        own_fields = {
            "input_size": input_size,
            "input_sq_size": input_sq_size,
            "in_channels": in_channels,
            "patch_size": patch_size,
            "hidden_size": hidden_size,
            "depth": depth,
            "num_heads": num_heads,
            "mlp_ratio": mlp_ratio,
            "class_dropout_prob": class_dropout_prob,
            "pred_sigma": pred_sigma,
            "drop_path": drop_path,
            "caption_channels": caption_channels,
            "model_max_length": model_max_length,
            "qk_norm": qk_norm,
            "enable_flash_attn": enable_flash_attn,
            "only_train_temporal": only_train_temporal,
            "freeze_y_embedder": freeze_y_embedder,
            "skip_y_embedder": skip_y_embedder,
        }
        for field_name, field_value in own_fields.items():
            setattr(self, field_name, field_value)

        # The base class consumes whatever keyword arguments are left.
        super().__init__(**kwargs)
class STDiT3(PreTrainedModel):
    """STDiT3 video diffusion transformer.

    The model patch-embeds a 5-D input, adds a 2-D positional embedding,
    then alternates spatial and temporal :class:`STDiT3Block` layers
    (``config.depth`` of each) conditioned on the timestep, the fps and a
    caption embedding, and finally maps the tokens back to the input layout
    with ``unpatchify``.
    """

    config_class = STDiT3Config

    def __init__(self, config):
        """Build all sub-modules from ``config`` (an ``STDiT3Config``)."""
        super().__init__(config)
        self.pred_sigma = config.pred_sigma
        self.in_channels = config.in_channels
        # Twice the input channels when pred_sigma is set, otherwise the same.
        self.out_channels = config.in_channels * 2 if config.pred_sigma else config.in_channels

        # model size related
        self.depth = config.depth
        self.mlp_ratio = config.mlp_ratio
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads

        # computation related
        self.drop_path = config.drop_path
        self.enable_flash_attn = config.enable_flash_attn

        # input size related
        self.patch_size = config.patch_size
        self.input_sq_size = config.input_sq_size
        self.pos_embed = PositionEmbedding2D(config.hidden_size)

        from rotary_embedding_torch import RotaryEmbedding

        self.rope = RotaryEmbedding(dim=self.hidden_size // self.num_heads)

        # embedding
        self.x_embedder = PatchEmbed3D(config.patch_size, config.in_channels, config.hidden_size)
        self.t_embedder = TimestepEmbedder(config.hidden_size)
        self.fps_embedder = SizeEmbedder(self.hidden_size)
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True),
        )
        self.y_embedder = CaptionEmbedder(
            in_channels=config.caption_channels,
            hidden_size=config.hidden_size,
            uncond_prob=config.class_dropout_prob,
            act_layer=approx_gelu,
            token_num=config.model_max_length,
        )

        # Per-depth drop-path rates, rising linearly from 0 to the configured
        # maximum; spatial and temporal stacks use the same schedule.
        drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path, config.depth)]

        # spatial blocks
        self.spatial_blocks = nn.ModuleList(
            [
                STDiT3Block(
                    hidden_size=config.hidden_size,
                    num_heads=config.num_heads,
                    mlp_ratio=config.mlp_ratio,
                    drop_path=drop_path_rates[i],
                    qk_norm=config.qk_norm,
                    enable_flash_attn=config.enable_flash_attn,
                    block_idx=i,
                )
                for i in range(config.depth)
            ]
        )

        # temporal blocks
        self.temporal_blocks = nn.ModuleList(
            [
                STDiT3Block(
                    hidden_size=config.hidden_size,
                    num_heads=config.num_heads,
                    mlp_ratio=config.mlp_ratio,
                    drop_path=drop_path_rates[i],
                    qk_norm=config.qk_norm,
                    enable_flash_attn=config.enable_flash_attn,
                    # temporal
                    temporal=True,
                    rope=self.rope.rotate_queries_or_keys,
                    block_idx=i,
                )
                for i in range(config.depth)
            ]
        )

        # final layer
        self.final_layer = T2IFinalLayer(config.hidden_size, np.prod(self.patch_size), self.out_channels)

        self.initialize_weights()
        if config.only_train_temporal:
            # Freeze everything, then re-enable only the temporal blocks.
            for param in self.parameters():
                param.requires_grad = False
            for block in self.temporal_blocks:
                for param in block.parameters():
                    param.requires_grad = True

        if config.freeze_y_embedder:
            for param in self.y_embedder.parameters():
                param.requires_grad = False

    def initialize_weights(self):
        """Initialise linear layers, the fps embedder and the temporal blocks."""

        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize fps_embedder: its last layer is zeroed, so the embedder
        # outputs zeros at initialisation.
        nn.init.normal_(self.fps_embedder.mlp[0].weight, std=0.02)
        nn.init.constant_(self.fps_embedder.mlp[0].bias, 0)
        nn.init.constant_(self.fps_embedder.mlp[2].weight, 0)
        nn.init.constant_(self.fps_embedder.mlp[2].bias, 0)

        # Initialize temporal blocks: zero the output projection weights of
        # attention, cross-attention and the MLP.
        for block in self.temporal_blocks:
            nn.init.constant_(block.attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.mlp.fc2.weight, 0)

    def get_dynamic_size(self, x):
        """Return the patch-grid size ``(T, H, W)`` for a 5-D input ``x``.

        Each of the last three dimensions of ``x`` is divided by the matching
        ``patch_size`` entry, rounding up, so a partial trailing patch counts
        as one full patch.
        """
        _, _, T, H, W = x.size()
        # -(-a // b) is ceil(a / b) for positive integers.
        return tuple(-(-extent // patch) for extent, patch in zip((T, H, W), self.patch_size))

    def encode_text(self, y, mask=None):
        """Embed captions and pack the kept tokens into one sequence.

        Args:
            y: Caption features handed to ``y_embedder``.
            mask: Optional token mask; tokens whose mask entry is non-zero
                are kept.  It is tiled along dim 0 when its batch size
                differs from that of the embedded captions.

        Returns:
            ``(y, y_lens)``: ``y`` has shape [1, total_tokens, hidden_size]
            and ``y_lens`` lists the number of tokens per sample.
        """
        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]
        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, self.hidden_size)
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, self.hidden_size)
        return y, y_lens

    def forward(
        self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
    ):
        """Run the full model.

        Args:
            x: 5-D input; its last three dimensions are treated as T, H, W.
            timestep: Diffusion timestep tensor, embedded by ``t_embedder``.
            all_timesteps: Forwarded unchanged to every block.
            y: Caption input.  Passed through ``encode_text`` unless
                ``config.skip_y_embedder`` is set, in which case it is used
                as is.
            mask: Caption token mask; with ``skip_y_embedder`` it is used
                directly as the per-sample token lengths.
            x_mask: Optional [B, T] temporal mask forwarded to the blocks and
                the final layer.
            fps: Tensor embedded by ``fps_embedder``.  Required although the
                default is ``None``: it is used unconditionally.
            height: Indexable whose first element supports ``.item()``.
                Required although the default is ``None``.
            width: Same as ``height``.
            **kwargs: Accepted and ignored.

        Returns:
            float32 tensor of shape [B, C_out, T, H, W], cropped to the
            T/H/W of the input.
        """
        # === Split batch ===
        if get_cfg_parallel_size() > 1:
            x, timestep, y, x_mask, mask = batch_func(
                partial(split_sequence, process_group=get_data_parallel_group(), dim=0), x, timestep, y, x_mask, mask
            )

        dtype = self.x_embedder.proj.weight.dtype
        B = x.size(0)
        x = x.to(dtype)
        timestep = timestep.to(dtype)
        y = y.to(dtype)

        # === get pos embed ===
        _, _, Tx, Hx, Wx = x.size()
        T, H, W = self.get_dynamic_size(x)
        S = H * W
        base_size = round(S**0.5)
        resolution_sq = (height[0].item() * width[0].item()) ** 0.5
        scale = resolution_sq / self.input_sq_size
        pos_emb = self.pos_embed(x, H, W, scale=scale, base_size=base_size)

        # === get timestep embed ===
        t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
        fps = self.fps_embedder(fps.unsqueeze(1), B)
        t = t + fps
        t_mlp = self.t_block(t)
        t0 = t0_mlp = None
        if x_mask is not None:
            # Conditioning for timestep 0, used by the blocks wherever
            # x_mask is false.
            t0_timestep = torch.zeros_like(timestep)
            t0 = self.t_embedder(t0_timestep, dtype=x.dtype)
            t0 = t0 + fps
            t0_mlp = self.t_block(t0)

        # === get y embed ===
        if self.config.skip_y_embedder:
            y_lens = mask
            if isinstance(y_lens, torch.Tensor):
                y_lens = y_lens.long().tolist()
        else:
            y, y_lens = self.encode_text(y, mask)

        # === get x embed ===
        x = self.x_embedder(x)  # [B, N, C]
        x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
        x = x + pos_emb

        # shard over the sequence dim if sp is enabled
        sequence_parallel = enable_sequence_parallel()
        if sequence_parallel:
            set_temporal_pad(T)
            set_spatial_pad(S)
            x = split_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad())
            T = x.shape[1]
            x_mask_org = x_mask
            # x_mask is optional; only shard it when one was supplied.
            if x_mask is not None:
                x_mask = split_sequence(
                    x_mask, get_sequence_parallel_group(), dim=1, grad_scale="down", pad=get_temporal_pad()
                )

        x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)

        # === blocks ===
        for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
            x = auto_grad_checkpoint(
                spatial_block,
                x,
                y,
                t_mlp,
                y_lens,
                x_mask,
                t0_mlp,
                T,
                S,
                timestep,
                all_timesteps=all_timesteps,
            )

            x = auto_grad_checkpoint(
                temporal_block,
                x,
                y,
                t_mlp,
                y_lens,
                x_mask,
                t0_mlp,
                T,
                S,
                timestep,
                all_timesteps=all_timesteps,
            )

        if sequence_parallel:
            x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
            x = gather_sequence(x, get_sequence_parallel_group(), dim=1, grad_scale="up", pad=get_temporal_pad())
            T, S = x.shape[1], x.shape[2]
            x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
            # The final layer works on the gathered sequence, so it needs the
            # unsharded mask again.
            x_mask = x_mask_org

        # === final layer ===
        x = self.final_layer(x, t, x_mask, t0, T, S)
        x = self.unpatchify(x, T, H, W, Tx, Hx, Wx)

        # cast to float32 for better accuracy
        x = x.to(torch.float32)

        # === Gather Output ===
        if get_cfg_parallel_size() > 1:
            x = gather_sequence(x, get_data_parallel_group(), dim=0)

        return x

    def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w):
        """Fold patch tokens back into a 5-D tensor and crop the padding.

        Args:
            x (torch.Tensor): of shape [B, N, C]
            N_t, N_h, N_w: Number of patches along T, H and W.
            R_t, R_h, R_w: Sizes to crop T, H and W to.

        Return:
            x (torch.Tensor): of shape [B, C_out, T, H, W]
        """
        T_p, H_p, W_p = self.patch_size
        x = rearrange(
            x,
            "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
            N_t=N_t,
            N_h=N_h,
            N_w=N_w,
            T_p=T_p,
            H_p=H_p,
            W_p=W_p,
            C_out=self.out_channels,
        )
        # unpad
        x = x[:, :, :R_t, :R_h, :R_w]
        return x
def STDiT3_XL_2(from_pretrained=None, **kwargs):
    """Create the XL/2 variant of :class:`STDiT3`.

    Args:
        from_pretrained: ``None``, a path to an existing directory, or any
            other identifier.  A value that is not an existing directory is
            handed to ``STDiT3.from_pretrained``; a directory is loaded with
            ``load_checkpoint`` into a freshly configured model; ``None``
            yields a freshly configured model with no weights loaded.
        **kwargs: Forwarded to ``STDiT3.from_pretrained`` or to
            ``STDiT3Config``, depending on the branch taken.

    Returns:
        The constructed ``STDiT3`` model.
    """
    build_from_config = from_pretrained is None or os.path.isdir(from_pretrained)
    if not build_from_config:
        return STDiT3.from_pretrained(from_pretrained, **kwargs)

    xl_config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
    model = STDiT3(xl_config)
    if from_pretrained is not None:
        load_checkpoint(model, from_pretrained)
    return model