| from typing import List, Optional, Tuple |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.amp as amp |
|
|
| import numpy as np |
| import torch.nn.functional as F |
| from einops import rearrange, repeat |
| from .wan_video_dit import flash_attention |
| from ..core.gradient import gradient_checkpoint_forward |
|
|
|
|
| class RMSNorm_FP32(torch.nn.Module): |
| def __init__(self, dim: int, eps: float): |
| super().__init__() |
| self.eps = eps |
| self.weight = nn.Parameter(torch.ones(dim)) |
|
|
| def _norm(self, x): |
| return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) |
|
|
| def forward(self, x): |
| output = self._norm(x.float()).type_as(x) |
| return output * self.weight |
|
|
|
|
| def broadcat(tensors, dim=-1): |
| num_tensors = len(tensors) |
| shape_lens = set(list(map(lambda t: len(t.shape), tensors))) |
| assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" |
| shape_len = list(shape_lens)[0] |
| dim = (dim + shape_len) if dim < 0 else dim |
| dims = list(zip(*map(lambda t: list(t.shape), tensors))) |
| expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] |
|     assert all( |
|         [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] |
|     ), "invalid dimensions for broadcastable concatenation" |
| max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) |
| expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) |
| expanded_dims.insert(dim, (dim, dims[dim])) |
| expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) |
| tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) |
| return torch.cat(tensors, dim=dim) |
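|
| # Illustrative sketch (not part of the model): broadcat expands singleton axes before |
| # concatenating along `dim`, e.g. |
| #   a = torch.zeros(2, 1, 3) |
| #   b = torch.zeros(1, 4, 5) |
| #   broadcat((a, b), dim=-1).shape  # -> torch.Size([2, 4, 8]) |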
|
|
|
|
| def rotate_half(x): |
| x = rearrange(x, "... (d r) -> ... d r", r=2) |
| x1, x2 = x.unbind(dim=-1) |
| x = torch.stack((-x2, x1), dim=-1) |
| return rearrange(x, "... d r -> ... (d r)") |
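|
| # For a trailing dimension packed as interleaved pairs, rotate_half maps |
| # [x1, x2, x3, x4] -> [-x2, x1, -x4, x3], i.e. each adjacent pair (a, b) becomes (-b, a), |
| # which supplies the "rotated" term multiplied by sin in RoPE. |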
|
|
|
|
| class RotaryPositionalEmbedding(nn.Module): |
|
|
| def __init__(self, |
| head_dim, |
| cp_split_hw=None |
| ): |
| """Rotary positional embedding for 3D |
| Reference : https://blog.eleuther.ai/rotary-embeddings/ |
| Paper: https://arxiv.org/pdf/2104.09864.pdf |
| Args: |
| dim: Dimension of embedding |
| base: Base value for exponential |
| """ |
| super().__init__() |
| self.head_dim = head_dim |
|         assert self.head_dim % 8 == 0, 'Dim must be a multiple of 8 for 3D RoPE.' |
| self.cp_split_hw = cp_split_hw |
| |
| self.base = 10000 |
| self.freqs_dict = {} |
|
|
| def register_grid_size(self, grid_size): |
| if grid_size not in self.freqs_dict: |
| self.freqs_dict.update({ |
| grid_size: self.precompute_freqs_cis_3d(grid_size) |
| }) |
|
|
| def precompute_freqs_cis_3d(self, grid_size): |
| num_frames, height, width = grid_size |
| dim_t = self.head_dim - 4 * (self.head_dim // 6) |
| dim_h = 2 * (self.head_dim // 6) |
| dim_w = 2 * (self.head_dim // 6) |
| freqs_t = 1.0 / (self.base ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t)) |
| freqs_h = 1.0 / (self.base ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h)) |
| freqs_w = 1.0 / (self.base ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w)) |
| grid_t = np.linspace(0, num_frames, num_frames, endpoint=False, dtype=np.float32) |
| grid_h = np.linspace(0, height, height, endpoint=False, dtype=np.float32) |
| grid_w = np.linspace(0, width, width, endpoint=False, dtype=np.float32) |
| grid_t = torch.from_numpy(grid_t).float() |
| grid_h = torch.from_numpy(grid_h).float() |
| grid_w = torch.from_numpy(grid_w).float() |
| freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t) |
| freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h) |
| freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w) |
| freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2) |
| freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) |
| freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) |
| freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) |
| |
|         # Flatten the (T, H, W) grid into a single token axis. |
|         freqs = rearrange(freqs, "T H W D -> (T H W) D") |
|         return freqs |
|
|
| def forward(self, q, k, grid_size): |
| """3D RoPE. |
| |
| Args: |
| query: [B, head, seq, head_dim] |
| key: [B, head, seq, head_dim] |
| Returns: |
| query and key with the same shape as input. |
| """ |
|
|
| if grid_size not in self.freqs_dict: |
| self.register_grid_size(grid_size) |
|
|
|         freqs_cis = self.freqs_dict[grid_size].float().to(q.device) |
|         q_, k_ = q.float(), k.float() |
|         cos, sin = freqs_cis.cos(), freqs_cis.sin() |
| cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') |
| q_ = (q_ * cos) + (rotate_half(q_) * sin) |
| k_ = (k_ * cos) + (rotate_half(k_) * sin) |
|
|
| return q_.type_as(q), k_.type_as(k) |
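|
| # Illustrative sketch (hypothetical sizes, not part of the model): with head_dim=128 the |
| # per-axis frequency split is dim_t=44, dim_h=dim_w=42 (44 + 42 + 42 = 128); a grid of |
| # (T, H, W) = (8, 30, 52) caches a freqs table of shape (8*30*52, 128). |
| #   rope = RotaryPositionalEmbedding(head_dim=128) |
| #   q = torch.randn(1, 32, 8 * 30 * 52, 128) |
| #   k = torch.randn_like(q) |
| #   q_rot, k_rot = rope(q, k, (8, 30, 52))  # same shapes as q, k |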
|
|
|
|
| class Attention(nn.Module): |
| def __init__( |
| self, |
| dim: int, |
| num_heads: int, |
| enable_flashattn3: bool = False, |
| enable_flashattn2: bool = False, |
| enable_xformers: bool = False, |
| enable_bsa: bool = False, |
| bsa_params: dict = None, |
| cp_split_hw: Optional[List[int]] = None |
| ) -> None: |
| super().__init__() |
| assert dim % num_heads == 0, "dim should be divisible by num_heads" |
| self.dim = dim |
| self.num_heads = num_heads |
| self.head_dim = dim // num_heads |
| self.scale = self.head_dim**-0.5 |
| self.enable_flashattn3 = enable_flashattn3 |
| self.enable_flashattn2 = enable_flashattn2 |
| self.enable_xformers = enable_xformers |
| self.enable_bsa = enable_bsa |
| self.bsa_params = bsa_params |
| self.cp_split_hw = cp_split_hw |
|
|
| self.qkv = nn.Linear(dim, dim * 3, bias=True) |
| self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) |
| self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) |
| self.proj = nn.Linear(dim, dim) |
|
|
| self.rope_3d = RotaryPositionalEmbedding( |
| self.head_dim, |
| cp_split_hw=cp_split_hw |
| ) |
|
|
| def _process_attn(self, q, k, v, shape): |
| q = rearrange(q, "B H S D -> B S (H D)") |
| k = rearrange(k, "B H S D -> B S (H D)") |
| v = rearrange(v, "B H S D -> B S (H D)") |
| x = flash_attention(q, k, v, num_heads=self.num_heads) |
| x = rearrange(x, "B S (H D) -> B H S D", H=self.num_heads) |
| return x |
|
|
| def forward(self, x: torch.Tensor, shape=None, num_cond_latents=None, return_kv=False) -> torch.Tensor: |
| """ |
| """ |
| B, N, C = x.shape |
| qkv = self.qkv(x) |
|
|
| qkv_shape = (B, N, 3, self.num_heads, self.head_dim) |
| qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) |
| q, k, v = qkv.unbind(0) |
| q, k = self.q_norm(q), self.k_norm(k) |
|
|
|         if return_kv: |
|             # Cache keys/values before RoPE so they can be re-rotated against an extended grid later. |
|             k_cache, v_cache = k.clone(), v.clone() |
|
|
| q, k = self.rope_3d(q, k, shape) |
|
|
| |
|         if num_cond_latents is not None and num_cond_latents > 0: |
|             # Tokens per conditioning frame = tokens per latent frame (N // T). |
|             num_cond_latents_thw = num_cond_latents * (N // shape[0]) |
|             # Conditioning tokens attend only to each other. |
|             q_cond = q[:, :, :num_cond_latents_thw].contiguous() |
|             k_cond = k[:, :, :num_cond_latents_thw].contiguous() |
|             v_cond = v[:, :, :num_cond_latents_thw].contiguous() |
|             x_cond = self._process_attn(q_cond, k_cond, v_cond, shape) |
|             # Noise tokens attend to the full sequence (conditioning + noise). |
|             q_noise = q[:, :, num_cond_latents_thw:].contiguous() |
|             x_noise = self._process_attn(q_noise, k, v, shape) |
|             x = torch.cat([x_cond, x_noise], dim=2).contiguous() |
| else: |
| x = self._process_attn(q, k, v, shape) |
|
|
| x_output_shape = (B, N, C) |
| x = x.transpose(1, 2) |
| x = x.reshape(x_output_shape) |
| x = self.proj(x) |
|
|
| if return_kv: |
| return x, (k_cache, v_cache) |
| else: |
| return x |
|
|
| def forward_with_kv_cache(self, x: torch.Tensor, shape=None, num_cond_latents=None, kv_cache=None) -> torch.Tensor: |
| """ |
| """ |
| B, N, C = x.shape |
| qkv = self.qkv(x) |
| |
| qkv_shape = (B, N, 3, self.num_heads, self.head_dim) |
| qkv = qkv.view(qkv_shape).permute((2, 0, 3, 1, 4)) |
| q, k, v = qkv.unbind(0) |
| q, k = self.q_norm(q), self.k_norm(k) |
|
|
| T, H, W = shape |
| k_cache, v_cache = kv_cache |
| assert k_cache.shape[0] == v_cache.shape[0] and k_cache.shape[0] in [1, B] |
| if k_cache.shape[0] == 1: |
| k_cache = k_cache.repeat(B, 1, 1, 1) |
| v_cache = v_cache.repeat(B, 1, 1, 1) |
| |
|         if num_cond_latents is not None and num_cond_latents > 0: |
|             # Prepend the cached conditioning keys/values and re-apply 3D RoPE over the |
|             # extended (T + num_cond_latents, H, W) grid. The empty query padding only |
|             # aligns positions; just the last N (noise) queries are kept. |
|             k_full = torch.cat([k_cache, k], dim=2).contiguous() |
|             v_full = torch.cat([v_cache, v], dim=2).contiguous() |
|             q_padding = torch.cat([torch.empty_like(k_cache), q], dim=2).contiguous() |
|             q_padding, k_full = self.rope_3d(q_padding, k_full, (T + num_cond_latents, H, W)) |
|             q = q_padding[:, :, -N:].contiguous() |
|
|         x = self._process_attn(q, k_full, v_full, shape) |
| |
| x_output_shape = (B, N, C) |
| x = x.transpose(1, 2) |
| x = x.reshape(x_output_shape) |
| x = self.proj(x) |
|
|
| return x |
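|
|     # Illustrative two-pass sketch (hypothetical names and sizes, not part of the model): |
|     # a first call with return_kv=True caches a chunk's pre-RoPE keys/values, and a later |
|     # call treats those cached frames as conditioning context: |
|     #   out, (k_c, v_c) = attn(x_chunk, shape=(T, H, W), return_kv=True) |
|     #   out_next = attn.forward_with_kv_cache(x_next, shape=(T_next, H, W), |
|     #                                         num_cond_latents=T, kv_cache=(k_c, v_c)) |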
|
|
|
|
| class MultiHeadCrossAttention(nn.Module): |
| def __init__( |
| self, |
| dim, |
| num_heads, |
| enable_flashattn3=False, |
| enable_flashattn2=False, |
| enable_xformers=False, |
| ): |
| super(MultiHeadCrossAttention, self).__init__() |
| assert dim % num_heads == 0, "d_model must be divisible by num_heads" |
|
|
| self.dim = dim |
| self.num_heads = num_heads |
| self.head_dim = dim // num_heads |
|
|
| self.q_linear = nn.Linear(dim, dim) |
| self.kv_linear = nn.Linear(dim, dim * 2) |
| self.proj = nn.Linear(dim, dim) |
|
|
| self.q_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) |
| self.k_norm = RMSNorm_FP32(self.head_dim, eps=1e-6) |
|
|
| self.enable_flashattn3 = enable_flashattn3 |
| self.enable_flashattn2 = enable_flashattn2 |
| self.enable_xformers = enable_xformers |
|
|
| def _process_cross_attn(self, x, cond, kv_seqlen): |
| B, N, C = x.shape |
| assert C == self.dim and cond.shape[2] == self.dim |
|
|
|         # Pack the whole batch into a single "sequence" of batch size 1 (variable-length |
|         # text is concatenated along the token axis). |
|         q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) |
|         kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) |
|         k, v = kv.unbind(2) |
|
|
| q, k = self.q_norm(q), self.k_norm(k) |
|
|
| q = rearrange(q, "B S H D -> B S (H D)") |
| k = rearrange(k, "B S H D -> B S (H D)") |
| v = rearrange(v, "B S H D -> B S (H D)") |
| x = flash_attention(q, k, v, num_heads=self.num_heads) |
|
|
| x = x.view(B, -1, C) |
| x = self.proj(x) |
| return x |
|
|
| def forward(self, x, cond, kv_seqlen, num_cond_latents=None, shape=None): |
| """ |
| x: [B, N, C] |
| cond: [B, M, C] |
| """ |
| if num_cond_latents is None or num_cond_latents == 0: |
| return self._process_cross_attn(x, cond, kv_seqlen) |
| else: |
| B, N, C = x.shape |
|             if num_cond_latents is not None and num_cond_latents > 0: |
|                 assert shape is not None, "shape must be provided when num_cond_latents > 0" |
|                 num_cond_latents_thw = num_cond_latents * (N // shape[0]) |
|                 # Only the noise tokens cross-attend to text; conditioning tokens receive a zero output. |
|                 x_noise = x[:, num_cond_latents_thw:] |
|                 output_noise = self._process_cross_attn(x_noise, cond, kv_seqlen) |
|                 output = torch.cat([ |
|                     torch.zeros((B, num_cond_latents_thw, C), dtype=output_noise.dtype, device=output_noise.device), |
|                     output_noise |
|                 ], dim=1).contiguous() |
| else: |
| raise NotImplementedError |
| |
| return output |
|
|
|
|
| class LayerNorm_FP32(nn.LayerNorm): |
| def __init__(self, dim, eps, elementwise_affine): |
| super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine) |
|
|
| def forward(self, inputs: torch.Tensor) -> torch.Tensor: |
| origin_dtype = inputs.dtype |
| out = F.layer_norm( |
| inputs.float(), |
| self.normalized_shape, |
| None if self.weight is None else self.weight.float(), |
| None if self.bias is None else self.bias.float() , |
| self.eps |
| ).to(origin_dtype) |
| return out |
|
|
|
|
| def modulate_fp32(norm_func, x, shift, scale): |
|     # AdaLN modulation in fp32: norm(x) * (1 + scale) + shift. |
|     assert shift.dtype == torch.float32 and scale.dtype == torch.float32 |
| dtype = x.dtype |
| x = norm_func(x.to(torch.float32)) |
| x = x * (scale + 1) + shift |
| x = x.to(dtype) |
| return x |
|
|
|
|
| class FinalLayer_FP32(nn.Module): |
| """ |
| The final layer of DiT. |
| """ |
|
|
| def __init__(self, hidden_size, num_patch, out_channels, adaln_tembed_dim): |
| super().__init__() |
| self.hidden_size = hidden_size |
| self.num_patch = num_patch |
| self.out_channels = out_channels |
| self.adaln_tembed_dim = adaln_tembed_dim |
|
|
| self.norm_final = LayerNorm_FP32(hidden_size, elementwise_affine=False, eps=1e-6) |
| self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True) |
| self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(adaln_tembed_dim, 2 * hidden_size, bias=True)) |
|
|
| def forward(self, x, t, latent_shape): |
| |
| assert t.dtype == torch.float32 |
| B, N, C = x.shape |
| T, _, _ = latent_shape |
|
|
| with amp.autocast('cuda', dtype=torch.float32): |
| shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) |
| x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C) |
| x = self.linear(x) |
| return x |
|
|
|
|
| class FeedForwardSwiGLU(nn.Module): |
| def __init__( |
| self, |
| dim: int, |
| hidden_dim: int, |
| multiple_of: int = 256, |
| ffn_dim_multiplier: Optional[float] = None, |
| ): |
| super().__init__() |
| hidden_dim = int(2 * hidden_dim / 3) |
| |
| if ffn_dim_multiplier is not None: |
| hidden_dim = int(ffn_dim_multiplier * hidden_dim) |
| hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) |
|
|
| self.dim = dim |
| self.hidden_dim = hidden_dim |
| self.w1 = nn.Linear(dim, hidden_dim, bias=False) |
| self.w2 = nn.Linear(hidden_dim, dim, bias=False) |
| self.w3 = nn.Linear(dim, hidden_dim, bias=False) |
|
|
| def forward(self, x): |
| return self.w2(F.silu(self.w1(x)) * self.w3(x)) |
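|
|     # Worked example: the model passes hidden_dim = hidden_size * mlp_ratio, so for |
|     # dim=4096 and mlp_ratio=4 this becomes int(2 * 16384 / 3) = 10922, rounded up to a |
|     # multiple of 256 -> self.hidden_dim == 11008. |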
|
|
|
|
| class TimestepEmbedder(nn.Module): |
| """ |
| Embeds scalar timesteps into vector representations. |
| """ |
|
|
| def __init__(self, t_embed_dim, frequency_embedding_size=256): |
| super().__init__() |
| self.t_embed_dim = t_embed_dim |
| self.frequency_embedding_size = frequency_embedding_size |
| self.mlp = nn.Sequential( |
| nn.Linear(frequency_embedding_size, t_embed_dim, bias=True), |
| nn.SiLU(), |
| nn.Linear(t_embed_dim, t_embed_dim, bias=True), |
| ) |
|
|
| @staticmethod |
| def timestep_embedding(t, dim, max_period=10000): |
| """ |
| Create sinusoidal timestep embeddings. |
| :param t: a 1-D Tensor of N indices, one per batch element. |
| These may be fractional. |
| :param dim: the dimension of the output. |
| :param max_period: controls the minimum frequency of the embeddings. |
| :return: an (N, D) Tensor of positional embeddings. |
| """ |
| half = dim // 2 |
| freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half) |
| freqs = freqs.to(device=t.device) |
| args = t[:, None].float() * freqs[None] |
| embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) |
| if dim % 2: |
| embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) |
| return embedding |
|
|
| def forward(self, t, dtype): |
| t_freq = self.timestep_embedding(t, self.frequency_embedding_size) |
| if t_freq.dtype != dtype: |
| t_freq = t_freq.to(dtype) |
| t_emb = self.mlp(t_freq) |
| return t_emb |
|
|
|
|
| class CaptionEmbedder(nn.Module): |
| """ |
|     Projects caption (text-encoder) embeddings into the transformer hidden size. |
| """ |
|
|
| def __init__(self, in_channels, hidden_size): |
| super().__init__() |
| self.in_channels = in_channels |
| self.hidden_size = hidden_size |
| self.y_proj = nn.Sequential( |
| nn.Linear(in_channels, hidden_size, bias=True), |
| nn.GELU(approximate="tanh"), |
| nn.Linear(hidden_size, hidden_size, bias=True), |
| ) |
|
|
| def forward(self, caption): |
| B, _, N, C = caption.shape |
| caption = self.y_proj(caption) |
| return caption |
|
|
|
|
| class PatchEmbed3D(nn.Module): |
| """Video to Patch Embedding. |
| |
| Args: |
|         patch_size (tuple[int]): Patch token size (T, H, W). Default: (2, 4, 4). |
| in_chans (int): Number of input video channels. Default: 3. |
| embed_dim (int): Number of linear projection output channels. Default: 96. |
| norm_layer (nn.Module, optional): Normalization layer. Default: None |
| """ |
|
|
| def __init__( |
| self, |
| patch_size=(2, 4, 4), |
| in_chans=3, |
| embed_dim=96, |
| norm_layer=None, |
| flatten=True, |
| ): |
| super().__init__() |
| self.patch_size = patch_size |
| self.flatten = flatten |
|
|
| self.in_chans = in_chans |
| self.embed_dim = embed_dim |
|
|
| self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) |
| if norm_layer is not None: |
| self.norm = norm_layer(embed_dim) |
| else: |
| self.norm = None |
|
|
| def forward(self, x): |
| """Forward function.""" |
| |
| _, _, D, H, W = x.size() |
| if W % self.patch_size[2] != 0: |
| x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) |
| if H % self.patch_size[1] != 0: |
| x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) |
| if D % self.patch_size[0] != 0: |
| x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) |
|
|
| B, C, T, H, W = x.shape |
| x = self.proj(x) |
| if self.norm is not None: |
| D, Wh, Ww = x.size(2), x.size(3), x.size(4) |
| x = x.flatten(2).transpose(1, 2) |
| x = self.norm(x) |
| x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) |
| if self.flatten: |
| x = x.flatten(2).transpose(1, 2) |
| return x |
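|
|     # Illustrative sketch (hypothetical sizes, not part of the model): with |
|     # patch_size=(1, 2, 2) and embed_dim=4096, an input of shape (1, 16, 8, 60, 104) |
|     # needs no padding and is flattened to (1, 8 * 30 * 52, 4096) = (1, 12480, 4096). |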
|
|
|
|
| class LongCatSingleStreamBlock(nn.Module): |
| def __init__( |
| self, |
| hidden_size: int, |
| num_heads: int, |
| mlp_ratio: int, |
| adaln_tembed_dim: int, |
| enable_flashattn3: bool = False, |
| enable_flashattn2: bool = False, |
| enable_xformers: bool = False, |
| enable_bsa: bool = False, |
| bsa_params=None, |
| cp_split_hw=None |
| ): |
| super().__init__() |
|
|
| self.hidden_size = hidden_size |
|
|
| |
| self.adaLN_modulation = nn.Sequential( |
| nn.SiLU(), |
| nn.Linear(adaln_tembed_dim, 6 * hidden_size, bias=True) |
| ) |
|
|
| self.mod_norm_attn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False) |
| self.mod_norm_ffn = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=False) |
| self.pre_crs_attn_norm = LayerNorm_FP32(hidden_size, eps=1e-6, elementwise_affine=True) |
|
|
| self.attn = Attention( |
| dim=hidden_size, |
| num_heads=num_heads, |
| enable_flashattn3=enable_flashattn3, |
| enable_flashattn2=enable_flashattn2, |
| enable_xformers=enable_xformers, |
| enable_bsa=enable_bsa, |
| bsa_params=bsa_params, |
| cp_split_hw=cp_split_hw |
| ) |
| self.cross_attn = MultiHeadCrossAttention( |
| dim=hidden_size, |
| num_heads=num_heads, |
| enable_flashattn3=enable_flashattn3, |
| enable_flashattn2=enable_flashattn2, |
| enable_xformers=enable_xformers, |
| ) |
| self.ffn = FeedForwardSwiGLU(dim=hidden_size, hidden_dim=int(hidden_size * mlp_ratio)) |
|
|
| def forward(self, x, y, t, y_seqlen, latent_shape, num_cond_latents=None, return_kv=False, kv_cache=None, skip_crs_attn=False): |
| """ |
| x: [B, N, C] |
| y: [1, N_valid_tokens, C] |
| t: [B, T, C_t] |
|         y_seqlen: list of length B giving the number of valid text tokens per item |
|         latent_shape: (T, H, W) latent grid of a single item |
| """ |
| x_dtype = x.dtype |
|
|
| B, N, C = x.shape |
| T, _, _ = latent_shape |
|
|
| |
| with amp.autocast(device_type='cuda', dtype=torch.float32): |
| shift_msa, scale_msa, gate_msa, \ |
| shift_mlp, scale_mlp, gate_mlp = \ |
| self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) |
|
|
| |
| x_m = modulate_fp32(self.mod_norm_attn, x.view(B, T, -1, C), shift_msa, scale_msa).view(B, N, C) |
|
|
| if kv_cache is not None: |
| kv_cache = (kv_cache[0].to(x.device), kv_cache[1].to(x.device)) |
| attn_outputs = self.attn.forward_with_kv_cache(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, kv_cache=kv_cache) |
| else: |
| attn_outputs = self.attn(x_m, shape=latent_shape, num_cond_latents=num_cond_latents, return_kv=return_kv) |
| |
| if return_kv: |
| x_s, kv_cache = attn_outputs |
| else: |
| x_s = attn_outputs |
|
|
| with amp.autocast(device_type='cuda', dtype=torch.float32): |
| x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) |
| x = x.to(x_dtype) |
|
|
| |
|         if not skip_crs_attn: |
|             if kv_cache is not None: |
|                 # With a KV cache, x holds only noise tokens, so none are treated as conditioning. |
|                 num_cond_latents = None |
|             x = x + self.cross_attn(self.pre_crs_attn_norm(x), y, y_seqlen, num_cond_latents=num_cond_latents, shape=latent_shape) |
|
|
| |
| x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C) |
| x_s = self.ffn(x_m) |
| with amp.autocast(device_type='cuda', dtype=torch.float32): |
| x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) |
| x = x.to(x_dtype) |
|
|
| if return_kv: |
| return x, kv_cache |
| else: |
| return x |
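|
|     # Block flow, as implemented above: t -> adaLN gives per-frame shift/scale/gate pairs; |
|     # x then passes through (1) modulated self-attention with 3D RoPE, (2) cross-attention |
|     # to the packed text tokens y, and (3) a gated, modulated SwiGLU FFN, each added back |
|     # to x as an fp32 residual. |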
|
|
|
|
| class LongCatVideoTransformer3DModel(torch.nn.Module): |
| def __init__( |
| self, |
| in_channels: int = 16, |
| out_channels: int = 16, |
| hidden_size: int = 4096, |
| depth: int = 48, |
| num_heads: int = 32, |
| caption_channels: int = 4096, |
| mlp_ratio: int = 4, |
| adaln_tembed_dim: int = 512, |
| frequency_embedding_size: int = 256, |
| |
| patch_size: Tuple[int] = (1, 2, 2), |
| |
| enable_flashattn3: bool = False, |
| enable_flashattn2: bool = True, |
| enable_xformers: bool = False, |
| enable_bsa: bool = False, |
| bsa_params: dict = {'sparsity': 0.9375, 'chunk_3d_shape_q': [4, 4, 4], 'chunk_3d_shape_k': [4, 4, 4]}, |
| cp_split_hw: Optional[List[int]] = [1, 1], |
| text_tokens_zero_pad: bool = True, |
| ) -> None: |
| super().__init__() |
|
|
| self.patch_size = patch_size |
| self.in_channels = in_channels |
| self.out_channels = out_channels |
| self.cp_split_hw = cp_split_hw |
|
|
| self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size) |
| self.t_embedder = TimestepEmbedder(t_embed_dim=adaln_tembed_dim, frequency_embedding_size=frequency_embedding_size) |
| self.y_embedder = CaptionEmbedder( |
| in_channels=caption_channels, |
| hidden_size=hidden_size, |
| ) |
|
|
| self.blocks = nn.ModuleList( |
| [ |
| LongCatSingleStreamBlock( |
| hidden_size=hidden_size, |
| num_heads=num_heads, |
| mlp_ratio=mlp_ratio, |
| adaln_tembed_dim=adaln_tembed_dim, |
| enable_flashattn3=enable_flashattn3, |
| enable_flashattn2=enable_flashattn2, |
| enable_xformers=enable_xformers, |
| enable_bsa=enable_bsa, |
| bsa_params=bsa_params, |
| cp_split_hw=cp_split_hw |
| ) |
| for i in range(depth) |
| ] |
| ) |
|
|
| self.final_layer = FinalLayer_FP32( |
| hidden_size, |
| np.prod(self.patch_size), |
| out_channels, |
| adaln_tembed_dim, |
| ) |
|
|
| self.gradient_checkpointing = False |
| self.text_tokens_zero_pad = text_tokens_zero_pad |
|
|
| self.lora_dict = {} |
| self.active_loras = [] |
|
|
| def enable_loras(self, lora_key_list=[]): |
| self.disable_all_loras() |
| |
| module_loras = {} |
| model_device = next(self.parameters()).device |
| model_dtype = next(self.parameters()).dtype |
| |
| for lora_key in lora_key_list: |
| if lora_key in self.lora_dict: |
| for lora in self.lora_dict[lora_key].loras: |
| lora.to(model_device, dtype=model_dtype, non_blocking=True) |
| module_name = lora.lora_name.replace("lora___lorahyphen___", "").replace("___lorahyphen___", ".") |
| if module_name not in module_loras: |
| module_loras[module_name] = [] |
| module_loras[module_name].append(lora) |
| self.active_loras.append(lora_key) |
| |
| for module_name, loras in module_loras.items(): |
| module = self._get_module_by_name(module_name) |
| if not hasattr(module, 'org_forward'): |
| module.org_forward = module.forward |
| module.forward = self._create_multi_lora_forward(module, loras) |
| |
| def _create_multi_lora_forward(self, module, loras): |
| def multi_lora_forward(x, *args, **kwargs): |
| weight_dtype = x.dtype |
| org_output = module.org_forward(x, *args, **kwargs) |
| |
| total_lora_output = 0 |
| for lora in loras: |
| if lora.use_lora: |
| lx = lora.lora_down(x.to(lora.lora_down.weight.dtype)) |
| lx = lora.lora_up(lx) |
| lora_output = lx.to(weight_dtype) * lora.multiplier * lora.alpha_scale |
| total_lora_output += lora_output |
| |
| return org_output + total_lora_output |
| |
| return multi_lora_forward |
| |
| def _get_module_by_name(self, module_name): |
| try: |
| module = self |
| for part in module_name.split('.'): |
| module = getattr(module, part) |
| return module |
| except AttributeError as e: |
| raise ValueError(f"Cannot find module: {module_name}, error: {e}") |
| |
| def disable_all_loras(self): |
| for name, module in self.named_modules(): |
| if hasattr(module, 'org_forward'): |
| module.forward = module.org_forward |
| delattr(module, 'org_forward') |
| |
| for lora_key, lora_network in self.lora_dict.items(): |
| for lora in lora_network.loras: |
| lora.to("cpu") |
| |
| self.active_loras.clear() |
|
|
| def enable_bsa(self,): |
| for block in self.blocks: |
| block.attn.enable_bsa = True |
| |
| def disable_bsa(self,): |
| for block in self.blocks: |
| block.attn.enable_bsa = False |
|
|
| def forward( |
| self, |
| hidden_states, |
| timestep, |
| encoder_hidden_states, |
| encoder_attention_mask=None, |
| num_cond_latents=0, |
| return_kv=False, |
| kv_cache_dict={}, |
| skip_crs_attn=False, |
| offload_kv_cache=False, |
| use_gradient_checkpointing=False, |
| use_gradient_checkpointing_offload=False, |
| ): |
|
|
| B, _, T, H, W = hidden_states.shape |
|
|
| N_t = T // self.patch_size[0] |
| N_h = H // self.patch_size[1] |
| N_w = W // self.patch_size[2] |
|
|
| assert self.patch_size[0]==1, "Currently, 3D x_embedder should not compress the temporal dimension." |
|
|
| |
| if len(timestep.shape) == 1: |
| timestep = timestep.unsqueeze(1).expand(-1, N_t).clone() |
| timestep[:, :num_cond_latents] = 0 |
|
|
| dtype = hidden_states.dtype |
| hidden_states = hidden_states.to(dtype) |
| timestep = timestep.to(dtype) |
| encoder_hidden_states = encoder_hidden_states.to(dtype) |
|
|
| hidden_states = self.x_embedder(hidden_states) |
|
|
| with amp.autocast(device_type='cuda', dtype=torch.float32): |
| t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) |
|
|
| encoder_hidden_states = self.y_embedder(encoder_hidden_states) |
|
|
|         if self.text_tokens_zero_pad and encoder_attention_mask is not None: |
|             # Zero out padded text tokens, then mark every position as valid downstream. |
|             encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[:, None, :, None] |
|             encoder_attention_mask = (encoder_attention_mask * 0 + 1).to(encoder_attention_mask.dtype) |
|
|
| if encoder_attention_mask is not None: |
| encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1) |
| encoder_hidden_states = encoder_hidden_states.squeeze(1).masked_select(encoder_attention_mask.unsqueeze(-1) != 0).view(1, -1, hidden_states.shape[-1]) |
| y_seqlens = encoder_attention_mask.sum(dim=1).tolist() |
| else: |
| y_seqlens = [encoder_hidden_states.shape[2]] * encoder_hidden_states.shape[0] |
| encoder_hidden_states = encoder_hidden_states.squeeze(1).view(1, -1, hidden_states.shape[-1]) |
|
|
| kv_cache_dict_ret = {} |
| for i, block in enumerate(self.blocks): |
| block_outputs = gradient_checkpoint_forward( |
| block, |
| use_gradient_checkpointing=use_gradient_checkpointing, |
| use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, |
| x=hidden_states, |
| y=encoder_hidden_states, |
| t=t, |
| y_seqlen=y_seqlens, |
| latent_shape=(N_t, N_h, N_w), |
| num_cond_latents=num_cond_latents, |
| return_kv=return_kv, |
| kv_cache=kv_cache_dict.get(i, None), |
| skip_crs_attn=skip_crs_attn, |
| ) |
| |
| if return_kv: |
| hidden_states, kv_cache = block_outputs |
| if offload_kv_cache: |
| kv_cache_dict_ret[i] = (kv_cache[0].cpu(), kv_cache[1].cpu()) |
| else: |
| kv_cache_dict_ret[i] = (kv_cache[0].contiguous(), kv_cache[1].contiguous()) |
| else: |
| hidden_states = block_outputs |
|
|
| hidden_states = self.final_layer(hidden_states, t, (N_t, N_h, N_w)) |
|
|
| |
| |
|
|
| hidden_states = self.unpatchify(hidden_states, N_t, N_h, N_w) |
|
|
| |
| hidden_states = hidden_states.to(torch.float32) |
|
|
| if return_kv: |
| return hidden_states, kv_cache_dict_ret |
| else: |
| return hidden_states |
| |
|
|
| def unpatchify(self, x, N_t, N_h, N_w): |
| """ |
| Args: |
| x (torch.Tensor): of shape [B, N, C] |
| |
| Return: |
| x (torch.Tensor): of shape [B, C_out, T, H, W] |
| """ |
| T_p, H_p, W_p = self.patch_size |
| x = rearrange( |
| x, |
| "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)", |
| N_t=N_t, |
| N_h=N_h, |
| N_w=N_w, |
| T_p=T_p, |
| H_p=H_p, |
| W_p=W_p, |
| C_out=self.out_channels, |
| ) |
| return x |
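|
|     # Example: with patch_size=(1, 2, 2) and out_channels=16, unpatchify turns |
|     # (B, N_t*N_h*N_w, 1*2*2*16) back into (B, 16, N_t, 2*N_h, 2*N_w). |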
|
|
| @staticmethod |
| def state_dict_converter(): |
| return LongCatVideoTransformer3DModelDictConverter() |
|
|
|
|
| class LongCatVideoTransformer3DModelDictConverter: |
| def __init__(self): |
| pass |
|
|
| def from_diffusers(self, state_dict): |
| return state_dict |
| |
| def from_civitai(self, state_dict): |
| return state_dict |
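|
|
| # Illustrative usage sketch (hypothetical sizes, not part of the module; a CUDA device and |
| # the attention backend used by flash_attention are assumed): |
| #   model = LongCatVideoTransformer3DModel(hidden_size=1024, depth=2, num_heads=8, |
| #                                          caption_channels=1024).cuda().bfloat16() |
| #   latents = torch.randn(1, 16, 8, 60, 104, dtype=torch.bfloat16, device="cuda") |
| #   t = torch.tensor([999.0], device="cuda") |
| #   text = torch.randn(1, 1, 77, 1024, dtype=torch.bfloat16, device="cuda") |
| #   out = model(latents, t, text)  # -> (1, 16, 8, 60, 104), cast to float32 |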
|
|
|
|