import math
import collections.abc
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
from einops import rearrange
from itertools import repeat
from functools import partial
from .utils import approx_gelu, get_layernorm, t2i_modulate
from typing import Optional

try:
    import xformers

    HAS_XFORMERS = True
except Exception:
    # Optional dependency. `Exception` (not a bare `except`) so that
    # KeyboardInterrupt / SystemExit still propagate, while a missing or
    # broken xformers install keeps degrading to HAS_XFORMERS = False.
    HAS_XFORMERS = False


# =================
# STDiT2Block
# =================


class STDiT2Block(nn.Module):
    """Transformer block with spatial, temporal, cross-attention and MLP stages.

    ``forward`` applies, in order: spatial self-attention (over ``S`` within
    each frame), temporal self-attention (over ``T`` at each spatial
    location), cross-attention against the condition ``y``, and an MLP.
    The self-attention and MLP stages are modulated with shift/scale/gate
    values derived from learned tables plus the ``t`` / ``t_tmp`` inputs.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        drop_path=0.0,
        enable_flash_attn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
        rope=None,
        qk_norm=False,
    ):
        """Build the block.

        Args:
            hidden_size: Channel width of the token features.
            num_heads: Number of attention heads for all attention layers.
            mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
            drop_path: Stochastic-depth probability; ``0.0`` disables it.
            enable_flash_attn: Forwarded to the self-attention layers.
            enable_layernorm_kernel: Forwarded to ``get_layernorm`` as
                ``use_kernel``.
            enable_sequence_parallelism: Must be falsy (asserted below).
            rope: Optional rotary-embedding callable, used only by the
                temporal attention.
            qk_norm: Whether the self-attention layers normalize q and k.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.enable_flash_attn = enable_flash_attn
        self._enable_sequence_parallelism = enable_sequence_parallelism

        assert not self._enable_sequence_parallelism, "Sequence parallelism is not supported."
        # NOTE(review): the first branch is only reachable when asserts are
        # disabled (python -O); the SeqParallel* names are not defined in the
        # visible part of this file -- confirm before enabling.
        if enable_sequence_parallelism:
            self.attn_cls = SeqParallelAttention
            self.mha_cls = SeqParallelMultiHeadCrossAttention
        else:
            self.attn_cls = Attention
            self.mha_cls = MultiHeadCrossAttention

        # spatial branch
        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.attn = self.attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flash_attn=enable_flash_attn,
            qk_norm=qk_norm,
        )
        # Six rows: shift/scale/gate for the spatial attention and for the MLP.
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)

        # cross attn
        self.cross_attn = self.mha_cls(hidden_size, num_heads)

        # mlp branch
        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # temporal branch
        self.norm_temp = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)  # new
        self.attn_temp = self.attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flash_attn=self.enable_flash_attn,
            rope=rope,
            qk_norm=qk_norm,
        )
        # Three rows: shift/scale/gate for the temporal attention.
        self.scale_shift_table_temporal = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5)  # new

    def t_mask_select(self, x_mask, x, masked_x, T, S):
        """Pick ``x`` where ``x_mask`` is true and ``masked_x`` elsewhere, per frame.

        The mask is broadcast over the spatial and channel axes.
        """
        # x: [B, (T, S), C]
        # masked_x: [B, (T, S), C]
        # x_mask: [B, T]
        x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
        masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
        x = torch.where(x_mask[:, :, None, None], x, masked_x)
        x = rearrange(x, "B T S C -> B (T S) C")
        return x

    def forward(self, x, y, t, t_tmp, mask=None, x_mask=None, t0=None, t0_tmp=None, T=None, S=None):
        """Run the block.

        Args:
            x: Tokens of shape ``[B, T * S, C]``.
            y: Condition passed to the cross-attention.
            t: Modulation input, reshaped to ``[B, 6, -1]``.
            t_tmp: Temporal modulation input, reshaped to ``[B, 3, -1]``.
            mask: Forwarded to the cross-attention.
            x_mask: Optional ``[B, T]`` boolean mask. Frames where it is true
                use the ``t`` / ``t_tmp`` modulation, the others use
                ``t0`` / ``t0_tmp``.
            t0: Like ``t``; required when ``x_mask`` is given.
            t0_tmp: Like ``t_tmp``; required when ``x_mask`` is given.
            T: Number of frames (needed by the rearranges).
            S: Number of spatial tokens per frame (needed by the rearranges).

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, N, C = x.shape

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + t.reshape(B, 6, -1)
        ).chunk(6, dim=1)
        shift_tmp, scale_tmp, gate_tmp = (self.scale_shift_table_temporal[None] + t_tmp.reshape(B, 3, -1)).chunk(
            3, dim=1
        )
        if x_mask is not None:
            shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = (
                self.scale_shift_table[None] + t0.reshape(B, 6, -1)
            ).chunk(6, dim=1)
            shift_tmp_zero, scale_tmp_zero, gate_tmp_zero = (
                self.scale_shift_table_temporal[None] + t0_tmp.reshape(B, 3, -1)
            ).chunk(3, dim=1)

        # modulate
        x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
        if x_mask is not None:
            x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero, scale_msa_zero)
            x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)

        # spatial branch
        x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=T, S=S)
        x_s = self.attn(x_s)
        x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=T, S=S)
        if x_mask is not None:
            x_s_zero = gate_msa_zero * x_s
            x_s = gate_msa * x_s
            x_s = self.t_mask_select(x_mask, x_s, x_s_zero, T, S)
        else:
            x_s = gate_msa * x_s
        x = x + self.drop_path(x_s)

        # modulate
        x_m = t2i_modulate(self.norm_temp(x), shift_tmp, scale_tmp)
        if x_mask is not None:
            x_m_zero = t2i_modulate(self.norm_temp(x), shift_tmp_zero, scale_tmp_zero)
            x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)

        # temporal branch
        x_t = rearrange(x_m, "B (T S) C -> (B S) T C", T=T, S=S)
        x_t = self.attn_temp(x_t)
        x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=T, S=S)
        if x_mask is not None:
            x_t_zero = gate_tmp_zero * x_t
            x_t = gate_tmp * x_t
            x_t = self.t_mask_select(x_mask, x_t, x_t_zero, T, S)
        else:
            x_t = gate_tmp * x_t
        x = x + self.drop_path(x_t)

        # cross attn
        x = x + self.cross_attn(x, y, mask)

        # modulate
        x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)
        if x_mask is not None:
            x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero, scale_mlp_zero)
            x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)

        # mlp
        x_mlp = self.mlp(x_m)
        if x_mask is not None:
            x_mlp_zero = gate_mlp_zero * x_mlp
            x_mlp = gate_mlp * x_mlp
            x_mlp = self.t_mask_select(x_mask, x_mlp, x_mlp_zero, T, S)
        else:
            x_mlp = gate_mlp * x_mlp
        x = x + self.drop_path(x_mlp)

        return x


# =================
# Attention
# =================


class LlamaRMSNorm(nn.Module):
    """RMS normalization over the last dimension with a learned scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize in float32, cast back to the input dtype, then scale."""
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class Attention(nn.Module):
    """Multi-head self-attention with optional q/k norm, rotary embedding and flash-attn."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = LlamaRMSNorm,
        enable_flash_attn: bool = False,
        rope=None,
    ) -> None:
        """Build the layer.

        Args:
            dim: Input/output channel width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            qkv_bias: Whether the fused qkv projection has a bias.
            qk_norm: If true, apply ``norm_layer(head_dim)`` to q and k.
            attn_drop: Dropout probability on the attention weights.
            proj_drop: Dropout probability after the output projection.
            norm_layer: Factory for the q/k normalization layers.
            enable_flash_attn: Allow the ``flash_attn`` code path.
            rope: Optional callable applied to q and k before normalization.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.enable_flash_attn = enable_flash_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.rope = False
        if rope is not None:
            self.rope = True
            self.rotary_emb = rope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape ``[B, N, C]``; output has the same shape."""
        B, N, C = x.shape
        # flash attn is not memory efficient for small sequences, this is empirical
        enable_flash_attn = self.enable_flash_attn and (N > B)
        qkv = self.qkv(x)
        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)

        # -> (3, B, #heads, N, #dim)
        qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        if self.rope:
            q = self.rotary_emb(q)
            k = self.rotary_emb(k)
        q, k = self.q_norm(q), self.k_norm(k)

        if enable_flash_attn:
            from flash_attn import flash_attn_func

            # (B, #heads, N, #dim) -> (B, N, #heads, #dim)
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)
            x = flash_attn_func(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                softmax_scale=self.scale,
            )
        else:
            dtype = q.dtype
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)  # translate attn to float32
            attn = attn.to(torch.float32)
            attn = attn.softmax(dim=-1)
            attn = attn.to(dtype)  # cast back attn to original dtype
            attn = self.attn_drop(attn)
            x = attn @ v

        x_output_shape = (B, N, C)
        if not enable_flash_attn:
            # (B, #heads, N, #dim) -> (B, N, #heads, #dim); flash path is already in this layout
            x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# ========================
# MultiHeadCrossAttention
# ========================


class MultiHeadCrossAttention(nn.Module):
    """Cross-attention from ``x`` (queries) to ``cond`` (keys/values) via xformers."""

    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
        """Build the layer; ``d_model`` must be divisible by ``num_heads``."""
        super(MultiHeadCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        """Attend from ``x`` to ``cond``.

        Args:
            x: Query tokens of shape ``[B, N, C]``.
            cond: Condition tokens; flattened into a single packed sequence.
            mask: Optional per-sample key/value sequence lengths handed to
                ``BlockDiagonalMask.from_seqlens`` (assumed to be a list of
                ints -- TODO confirm against the caller).

        Returns:
            Tensor of shape ``[B, N, C]``.

        Raises:
            ImportError: If xformers is not installed.
        """
        # Imported lazily (like flash_attn in Attention.forward) so that a
        # missing xformers install yields a clear ImportError instead of a
        # NameError, and so that the ``ops`` submodule is guaranteed loaded.
        import xformers.ops

        # query/value: img tokens; key: condition; mask: if padding tokens
        B, N, C = x.shape

        # Pack the whole batch into one sequence; the block-diagonal bias
        # below keeps samples from attending to each other.
        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)

        attn_bias = None
        if mask is not None:
            attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
        # Dropout only while training, consistent with Attention.forward.
        dropout_p = self.attn_drop.p if self.training else 0.0
        x = xformers.ops.memory_efficient_attention(q, k, v, p=dropout_p, attn_bias=attn_bias)

        x = x.view(B, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# =================
# Timm Components
# =================


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Returns ``x`` unchanged when ``drop_prob`` is 0 or ``training`` is false.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        # Rescale kept samples so the expected value is preserved.
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Apply ``drop_path`` using this module's training flag."""
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'


def _ntuple(n):
    """Return a parser that turns a scalar into an ``n``-tuple and an iterable into a tuple."""

    def parse(x):
        # Strings are iterable but are treated as scalars here.
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.,
        use_conv=False,
    ):
        """Build the MLP.

        Args:
            in_features: Input width.
            hidden_features: Hidden width; defaults to ``in_features``.
            out_features: Output width; defaults to ``in_features``.
            act_layer: Zero-argument factory returning the activation module.
            norm_layer: Optional factory called with ``hidden_features``.
            bias: Bool or pair of bools for the two projections.
            drop: Dropout probability, or a pair (one per projection).
            use_conv: Use 1x1 ``nn.Conv2d`` instead of ``nn.Linear``.
        """
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        """fc1 -> act -> drop -> norm -> fc2 -> drop."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


# =================
# Embedding
# =================


class CaptionEmbedder(nn.Module):
    """
    Projects caption features into the model's hidden size. Also handles
    caption dropout (replacement by a fixed buffer) for classifier-free guidance.
    """

    def __init__(
        self,
        in_channels,
        hidden_size,
        uncond_prob,
        act_layer=partial(nn.GELU, approximate="tanh"),
        token_num=120,
    ):
        """Build the embedder.

        Args:
            in_channels: Width of the incoming caption features.
            hidden_size: Output width of the projection.
            uncond_prob: Probability of dropping a caption during training.
            act_layer: Zero-argument factory for the activation. ``Mlp``
                calls it, so it must be a factory rather than a module
                instance (the default is therefore a ``partial``).
            token_num: Number of tokens in the unconditional embedding.
        """
        super().__init__()
        self.y_proj = Mlp(
            in_features=in_channels,
            hidden_features=hidden_size,
            out_features=hidden_size,
            act_layer=act_layer,
            drop=0,
        )
        self.register_buffer(
            "y_embedding",
            torch.randn(token_num, in_channels) / in_channels**0.5,
        )
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None):
        """
        Replaces whole captions by ``y_embedding`` to enable classifier-free guidance.

        Without ``force_drop_ids`` each sample is dropped with probability
        ``uncond_prob``; otherwise samples whose id equals 1 are dropped.
        """
        if force_drop_ids is None:
            # Sampled on CPU as before, then moved to the caption's device
            # (previously hard-coded to ``.cuda()``, which broke on CPU).
            drop_ids = torch.rand(caption.shape[0]).to(caption.device) < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        # The mask is broadcast over the three trailing axes of a 4-D caption.
        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
        return caption

    def forward(self, caption, train, force_drop_ids=None):
        """Optionally drop captions, then project them with ``y_proj``."""
        if train:
            assert caption.shape[2:] == self.y_embedding.shape
        use_dropout = self.uncond_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            caption = self.token_drop(caption, force_drop_ids)
        caption = self.y_proj(caption)
        return caption


class PatchEmbed3D(nn.Module):
    """Video to Patch Embedding.

    Args:
        patch_size (int): Patch token size. Default: (2,4,4).
        in_chans (int): Number of input video channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
        flatten (bool): If true, return tokens as (B, N, C) instead of (B, C, T, H, W). Default: True
    """

    def __init__(
        self,
        patch_size=(2, 4, 4),
        in_chans=3,
        embed_dim=96,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.flatten = flatten

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding: extend W, H, D at their end up to a multiple of the patch size
        _, _, D, H, W = x.size()
        if W % self.patch_size[2] != 0:
            x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
        if H % self.patch_size[1] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
        if D % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))

        x = self.proj(x)  # (B C T H W)
        if self.norm is not None:
            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
            # Normalize over channels in token layout, then restore the grid.
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
        return x


class T2IFinalLayer(nn.Module):
    """
    The final layer of PixArt.
    """

    def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None):
        """Build the layer.

        Args:
            hidden_size: Width of the incoming tokens.
            num_patch: Multiplier on ``out_channels`` for the output width.
            out_channels: Number of output channels.
            d_t: Default ``T`` used by ``forward`` when none is passed.
            d_s: Default ``S`` used by ``forward`` when none is passed.
        """
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
        self.out_channels = out_channels
        self.d_t = d_t
        self.d_s = d_s

    def t_mask_select(self, x_mask, x, masked_x, T, S):
        """Pick ``x`` where ``x_mask`` is true and ``masked_x`` elsewhere, per frame."""
        # x: [B, (T, S), C]
        # masked_x: [B, (T, S), C]
        # x_mask: [B, T]
        x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
        masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
        x = torch.where(x_mask[:, :, None, None], x, masked_x)
        x = rearrange(x, "B T S C -> B (T S) C")
        return x

    def forward(self, x, t, x_mask=None, t0=None, T=None, S=None):
        """Modulate the normalized tokens and project them to the output width.

        When ``x_mask`` is given, frames where it is false are modulated
        with ``t0`` instead of ``t``. ``T`` / ``S`` fall back to ``d_t`` /
        ``d_s``.
        """
        if T is None:
            T = self.d_t
        if S is None:
            S = self.d_s
        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        if x_mask is not None:
            shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1)
            x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)
            x = self.t_mask_select(x_mask, x, x_zero, T, S)
        x = self.linear(x)
        return x


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
        freqs = freqs.to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd ``dim``: pad one zero column so the width is exactly ``dim``.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype):
        """Embed ``t`` sinusoidally, cast to ``dtype`` and run the MLP."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        if t_freq.dtype != dtype:
            t_freq = t_freq.to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class SizeEmbedder(TimestepEmbedder):
    """
    Embeds one or more scalar values per sample (``s``) and concatenates the
    per-value embeddings along the channel axis.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
        # NOTE(review): duplicates the MLP already built by the parent
        # constructor; kept so parameter initialization order is unchanged.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.outdim = hidden_size

    def forward(self, s, bs):
        """Embed ``s``.

        Args:
            s: Tensor of shape ``[b]`` or ``[b, d]``; tiled along the batch
                axis when ``b`` differs from ``bs``.
            bs: Target batch size.

        Returns:
            Tensor of shape ``[bs, d * hidden_size]``.
        """
        if s.ndim == 1:
            s = s[:, None]
        assert s.ndim == 2
        if s.shape[0] != bs:
            s = s.repeat(bs // s.shape[0], 1)
            assert s.shape[0] == bs
        b, dims = s.shape[0], s.shape[1]
        s = rearrange(s, "b d -> (b d)")
        s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
        s_emb = self.mlp(s_freq)
        s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
        return s_emb

    @property
    def dtype(self):
        """Dtype of this module's first parameter."""
        return next(self.parameters()).dtype


class PositionEmbedding2D(nn.Module):
    """Sin/cos position embedding over an ``h`` x ``w`` grid."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        assert dim % 4 == 0, "dim must be divisible by 4"
        half_dim = dim // 2
        inv_freq = 1.0 / (10000 ** (torch.arange(0, half_dim, 2).float() / half_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _get_sin_cos_emb(self, t: torch.Tensor):
        """Return ``[sin, cos]`` features of the outer product of ``t`` and ``inv_freq``."""
        out = torch.einsum("i,d->id", t, self.inv_freq)
        emb_cos = torch.cos(out)
        emb_sin = torch.sin(out)
        return torch.cat((emb_sin, emb_cos), dim=-1)

    # NOTE(review): lru_cache on a method keys on ``self`` and keeps the
    # instance (and cached tensors) alive for the cache's lifetime; callers
    # also receive the cached tensor object itself. Left unchanged.
    @functools.lru_cache(maxsize=512)
    def _get_cached_emb(
        self,
        device: torch.device,
        dtype: torch.dtype,
        h: int,
        w: int,
        scale: float = 1.0,
        base_size: Optional[int] = None,
    ):
        """Build (and memoize) the embedding of shape ``[1, h * w, dim]``."""
        grid_h = torch.arange(h, device=device) / scale
        grid_w = torch.arange(w, device=device) / scale
        if base_size is not None:
            grid_h *= base_size / h
            grid_w *= base_size / w
        grid_h, grid_w = torch.meshgrid(
            grid_w,
            grid_h,
            indexing="ij",
        )  # here w goes first
        grid_h = grid_h.t().reshape(-1)
        grid_w = grid_w.t().reshape(-1)
        emb_h = self._get_sin_cos_emb(grid_h)
        emb_w = self._get_sin_cos_emb(grid_w)
        return torch.concat([emb_h, emb_w], dim=-1).unsqueeze(0).to(dtype)

    def forward(
        self,
        x: torch.Tensor,
        h: int,
        w: int,
        scale: Optional[float] = 1.0,
        base_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Return the position embedding on ``x``'s device and dtype; only those two attributes of ``x`` are used."""
        return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)