# OpenSora-STDiT-v1-HQ-16x256x256 / modeling_stdit.py

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from .configuration_stdit import STDiTConfig
from .layers import (
STDiTBlock,
CaptionEmbedder,
PatchEmbed3D,
T2IFinalLayer,
TimestepEmbedder,
)
from .utils import (
approx_gelu,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
)
from transformers import PreTrainedModel
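
# NOTE: when config.enable_sequence_parallelism is set, this module calls
# get_sequence_parallel_group, split_forward_gather_backward and
# gather_forward_split_backward. These helpers are expected to be provided by the
# surrounding Open-Sora codebase; they are not defined or imported in this file.
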
class STDiT(PreTrainedModel):
config_class = STDiTConfig

    def __init__(self, config):
super().__init__(config)
self.pred_sigma = config.pred_sigma
self.in_channels = config.in_channels
self.out_channels = config.in_channels * 2 if config.pred_sigma else config.in_channels
self.hidden_size = config.hidden_size
self.patch_size = config.patch_size
self.input_size = config.input_size
num_patches = np.prod([config.input_size[i] // config.patch_size[i] for i in range(3)])
self.num_patches = num_patches
self.num_temporal = config.input_size[0] // config.patch_size[0]
self.num_spatial = num_patches // self.num_temporal
self.num_heads = config.num_heads
self.no_temporal_pos_emb = config.no_temporal_pos_emb
self.depth = config.depth
self.mlp_ratio = config.mlp_ratio
self.enable_flash_attn = config.enable_flash_attn
self.enable_layernorm_kernel = config.enable_layernorm_kernel
self.space_scale = config.space_scale
self.time_scale = config.time_scale
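        # fixed (non-learnable) sin-cos positional embeddings for the spatial grid and
        # the temporal axis, registered as buffers so they move with the module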
self.register_buffer("pos_embed", self.get_spatial_pos_embed())
self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
self.x_embedder = PatchEmbed3D(config.patch_size, config.in_channels, config.hidden_size)
self.t_embedder = TimestepEmbedder(config.hidden_size)
self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True))
self.y_embedder = CaptionEmbedder(
in_channels=config.caption_channels,
hidden_size=config.hidden_size,
uncond_prob=config.class_dropout_prob,
act_layer=approx_gelu,
token_num=config.model_max_length,
)
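        # stochastic depth: drop-path rate grows linearly from 0 to config.drop_path
        # over the transformer blocks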
drop_path = [x.item() for x in torch.linspace(0, config.drop_path, config.depth)]
self.blocks = nn.ModuleList(
[
STDiTBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=self.mlp_ratio,
drop_path=drop_path[i],
enable_flash_attn=self.enable_flash_attn,
enable_layernorm_kernel=self.enable_layernorm_kernel,
enable_sequence_parallelism=config.enable_sequence_parallelism,
d_t=self.num_temporal,
d_s=self.num_spatial,
)
for i in range(self.depth)
]
)
self.final_layer = T2IFinalLayer(config.hidden_size, np.prod(self.patch_size), self.out_channels)
# init model
self.initialize_weights()
self.initialize_temporal()
if config.freeze is not None:
assert config.freeze in ["not_temporal", "text"]
if config.freeze == "not_temporal":
self.freeze_not_temporal()
elif config.freeze == "text":
self.freeze_text()
# sequence parallel related configs
self.enable_sequence_parallelism = config.enable_sequence_parallelism
if config.enable_sequence_parallelism:
self.sp_rank = dist.get_rank(get_sequence_parallel_group())
else:
self.sp_rank = None

    def forward(self, x, timestep, y, mask=None):
"""
Forward pass of STDiT.
Args:
x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
timestep (torch.Tensor): diffusion time steps; of shape [B]
y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]
Returns:
x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
"""
x = x.to(self.final_layer.linear.weight.dtype)
timestep = timestep.to(self.final_layer.linear.weight.dtype)
y = y.to(self.final_layer.linear.weight.dtype)
# embedding
x = self.x_embedder(x) # [B, N, C]
x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial)
x = x + self.pos_embed
x = rearrange(x, "B T S C -> B (T S) C")
# shard over the sequence dim if sp is enabled
if self.enable_sequence_parallelism:
x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
        t0 = self.t_block(t)  # [B, 6 * C]
y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
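        # Pack caption tokens: when a mask is provided, keep only unmasked tokens and
        # concatenate them across the batch into a single sequence; y_lens records the
        # per-sample token counts so cross-attention can treat samples separately.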
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
# blocks
for i, block in enumerate(self.blocks):
if i == 0:
if self.enable_sequence_parallelism:
tpe = torch.chunk(
self.pos_embed_temporal, dist.get_world_size(get_sequence_parallel_group()), dim=1
)[self.sp_rank].contiguous()
else:
tpe = self.pos_embed_temporal
else:
tpe = None
x = block(x, y, t0, y_lens, tpe)
# x = auto_grad_checkpoint(block, x, y, t0, y_lens, tpe)
if self.enable_sequence_parallelism:
x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
# x.shape: [B, N, C]
# final process
x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out]
x = self.unpatchify(x) # [B, C_out, T, H, W]
# cast to float32 for better accuracy
x = x.to(torch.float32)
return x

    def unpatchify(self, x):
"""
Args:
x (torch.Tensor): of shape [B, N, C]
Return:
x (torch.Tensor): of shape [B, C_out, T, H, W]
"""
N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
T_p, H_p, W_p = self.patch_size
x = rearrange(
x,
"B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
N_t=N_t,
N_h=N_h,
N_w=N_w,
T_p=T_p,
H_p=H_p,
W_p=W_p,
C_out=self.out_channels,
)
return x

    def unpatchify_old(self, x):
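        """Legacy version of `unpatchify` using an explicit reshape + rearrange; kept for reference."""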
c = self.out_channels
t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
pt, ph, pw = self.patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs

    def get_spatial_pos_embed(self, grid_size=None):
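        """Build the fixed 2D sin-cos positional embedding for the spatial patch grid.

        Returns a tensor of shape [1, num_spatial, hidden_size]; `grid_size` defaults
        to the configured spatial input size (H, W).
        """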
if grid_size is None:
grid_size = self.input_size[1:]
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size,
(grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
scale=self.space_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed

    def get_temporal_pos_embed(self):
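        """Build the fixed 1D sin-cos positional embedding for the temporal axis.

        Returns a tensor of shape [1, num_temporal, hidden_size].
        """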
pos_embed = get_1d_sincos_pos_embed(
self.hidden_size,
self.input_size[0] // self.patch_size[0],
scale=self.time_scale,
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
return pos_embed

    def freeze_not_temporal(self):
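        """Freeze all parameters except those of the temporal attention (`attn_temp`) layers."""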
for n, p in self.named_parameters():
if "attn_temp" not in n:
p.requires_grad = False

    def freeze_text(self):
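        """Freeze the cross-attention parameters that attend to the text condition."""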
for n, p in self.named_parameters():
if "cross_attn" in n:
p.requires_grad = False

    def initialize_temporal(self):
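        """Zero-init the output projection of every temporal attention layer so the
        temporal branch initially contributes nothing to the residual stream."""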
for block in self.blocks:
nn.init.constant_(block.attn_temp.proj.weight, 0)
nn.init.constant_(block.attn_temp.proj.bias, 0)

    def initialize_weights(self):
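        """Initialize weights following the DiT/PixArt recipe: Xavier-uniform linear
        layers, normal(std=0.02) timestep/caption MLPs, and zero-init for the
        cross-attention output projections and the final linear layer."""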
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)
# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
        # Zero-out the output projection of cross-attention layers (PixArt-style):
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
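

if __name__ == "__main__":
    # Minimal smoke-test sketch (a usage illustration only, not part of the model
    # definition). The Hub id below is a placeholder; point it at the repository
    # this file lives in, or construct an STDiTConfig manually and call STDiT(config).
    model = STDiT.from_pretrained("<namespace>/OpenSora-STDiT-v1-HQ-16x256x256").eval()
    cfg = model.config

    T, H, W = cfg.input_size
    x = torch.randn(1, cfg.in_channels, T, H, W)  # dummy video latent [B, C, T, H, W]
    timestep = torch.randint(0, 1000, (1,))  # dummy diffusion timestep [B]
    y = torch.randn(1, 1, cfg.model_max_length, cfg.caption_channels)  # dummy prompt embedding

    with torch.no_grad():
        out = model(x, timestep, y)
    print(out.shape)  # expected: [1, out_channels, T, H, W]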