| """ |
| Wan Video DiT with instance-aware control (T5 semantics + bbox/mask). |
| |
| This refactor keeps the original Wan DiT backbone while integrating: |
| - Instance tokens: `<class> is <state>` text (T5) + instance_id embedding. |
| - Mask-guided cross-attention: per-patch gating via bbox- or mask-projected hints. |
- Backward compatibility: id-based class/state config arguments are still accepted (though unused here); per-patch masks are derived from bboxes.
| """ |
|
|
| import math |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
|
|
| from .wan_video_camera_controller import SimpleAdapter |
|
|
| try: |
| import flash_attn_interface |
| FLASH_ATTN_3_AVAILABLE = True |
| except ModuleNotFoundError: |
| FLASH_ATTN_3_AVAILABLE = False |
|
|
| try: |
| import flash_attn |
| FLASH_ATTN_2_AVAILABLE = True |
| except ModuleNotFoundError: |
| FLASH_ATTN_2_AVAILABLE = False |
|
|
| try: |
| from sageattention import sageattn |
| SAGE_ATTN_AVAILABLE = True |
| except ModuleNotFoundError: |
| SAGE_ATTN_AVAILABLE = False |
|
|
|
|
| |
| |
| |
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode: bool = False):
    """
    Multi-head attention over packed (B, S, num_heads * head_dim) tensors.

    Dispatches to the fastest available backend: FlashAttention-3, then
    FlashAttention-2, then SageAttention, then the PyTorch SDPA fallback.
    `compatibility_mode=True` forces the SDPA fallback.
    """
    if compatibility_mode:
| q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) |
| k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) |
| v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) |
| x = F.scaled_dot_product_attention(q, k, v) |
| x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) |
| elif FLASH_ATTN_3_AVAILABLE: |
| q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) |
| k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) |
| v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) |
| x = flash_attn_interface.flash_attn_func(q, k, v) |
| if isinstance(x, tuple): |
| x = x[0] |
| x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) |
| elif FLASH_ATTN_2_AVAILABLE: |
| q = rearrange(q, "b s (n d) -> b s n d", n=num_heads) |
| k = rearrange(k, "b s (n d) -> b s n d", n=num_heads) |
| v = rearrange(v, "b s (n d) -> b s n d", n=num_heads) |
| x = flash_attn.flash_attn_func(q, k, v) |
| x = rearrange(x, "b s n d -> b s (n d)", n=num_heads) |
| elif SAGE_ATTN_AVAILABLE: |
| q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) |
| k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) |
| v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) |
| x = sageattn(q, k, v) |
| x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) |
| else: |
| q = rearrange(q, "b s (n d) -> b n s d", n=num_heads) |
| k = rearrange(k, "b s (n d) -> b n s d", n=num_heads) |
| v = rearrange(v, "b s (n d) -> b n s d", n=num_heads) |
| x = F.scaled_dot_product_attention(q, k, v) |
| x = rearrange(x, "b n s d -> b s (n d)", n=num_heads) |
| return x |
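

# Usage sketch (illustrative only; the `_example_*` helpers in this file are
# assumptions added for documentation, not part of the original module). This
# one shows the packed (B, S, num_heads * head_dim) layout `flash_attention`
# expects; `compatibility_mode=True` keeps it runnable without any optional
# attention backend installed.
def _example_flash_attention():
    b, s, n, d = 2, 16, 8, 64
    q = torch.randn(b, s, n * d)
    k = torch.randn(b, s, n * d)
    v = torch.randn(b, s, n * d)
    out = flash_attention(q, k, v, num_heads=n, compatibility_mode=True)
    assert out.shape == (b, s, n * d)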
|
|
|
|
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    # AdaLN-style modulation: element-wise scale and shift.
    return x * (1 + scale) + shift
|
|
|
|
def sinusoidal_embedding_1d(dim, position):
    # Standard sinusoidal embedding, computed in float64 for numerical
    # stability and cast back to the input dtype.
    sinusoid = torch.outer(position.type(torch.float64), torch.pow(
| 10000, -torch.arange(dim // 2, dtype=torch.float64, device=position.device).div(dim // 2))) |
| x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) |
| return x.to(position.dtype) |
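

# Shape sketch (illustrative assumption): a (B,) float timestep vector maps to
# a (B, dim) embedding; integer positions would be truncated by the final cast.
def _example_sinusoidal_embedding():
    emb = sinusoidal_embedding_1d(32, torch.tensor([0.0, 500.0]))
    assert emb.shape == (2, 32)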
|
|
|
|
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    # Split the per-head dim across (frame, height, width); the frame axis
    # takes the remainder so the three parts sum to `dim`.
    f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
| h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) |
| w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta) |
| return f_freqs_cis, h_freqs_cis, w_freqs_cis |
|
|
|
|
| def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): |
| freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].double() / dim)) |
| freqs = torch.outer(torch.arange(end, device=freqs.device), freqs) |
| freqs_cis = torch.polar(torch.ones_like(freqs), freqs) |
| return freqs_cis |
|
|
|
|
| def rope_apply(x, freqs, num_heads): |
| x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) |
| x_out = torch.view_as_complex(x.to(torch.float64).reshape( |
| x.shape[0], x.shape[1], x.shape[2], -1, 2)) |
| x_out = torch.view_as_real(x_out * freqs).flatten(2) |
| return x_out.to(x.dtype) |
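

# Usage sketch (assumption, mirroring how `WanModel.forward` assembles `freqs`):
# build 3D rotary frequencies for an (f, h, w) patch grid and rotate a packed
# query. head_dim is chosen so dim // 3 is even and the parts sum to head_dim.
def _example_rope_apply():
    num_heads, head_dim = 8, 48
    f, h, w = 4, 6, 6
    f_freqs, h_freqs, w_freqs = precompute_freqs_cis_3d(head_dim)
    freqs = torch.cat([
        f_freqs[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        h_freqs[:h].view(1, h, 1, -1).expand(f, h, w, -1),
        w_freqs[:w].view(1, 1, w, -1).expand(f, h, w, -1),
    ], dim=-1).reshape(f * h * w, 1, -1)
    q = torch.randn(2, f * h * w, num_heads * head_dim)
    assert rope_apply(q, freqs, num_heads).shape == q.shape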
|
|
|
|
| |
| |
| |
| class RMSNorm(nn.Module): |
| def __init__(self, dim, eps=1e-5): |
| super().__init__() |
| self.eps = eps |
| self.weight = nn.Parameter(torch.ones(dim)) |
|
|
| def norm(self, x): |
| return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) |
|
|
| def forward(self, x): |
| dtype = x.dtype |
| return self.norm(x.float()).to(dtype) * self.weight |
|
|
|
|
| class AttentionModule(nn.Module): |
| def __init__(self, num_heads): |
| super().__init__() |
| self.num_heads = num_heads |
| |
| def forward(self, q, k, v): |
| x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads) |
| return x |
|
|
|
|
| class SelfAttention(nn.Module): |
| def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): |
| super().__init__() |
| self.dim = dim |
| self.num_heads = num_heads |
| self.head_dim = dim // num_heads |
|
|
| self.q = nn.Linear(dim, dim) |
| self.k = nn.Linear(dim, dim) |
| self.v = nn.Linear(dim, dim) |
| self.o = nn.Linear(dim, dim) |
| self.norm_q = RMSNorm(dim, eps=eps) |
| self.norm_k = RMSNorm(dim, eps=eps) |
| |
| self.attn = AttentionModule(self.num_heads) |
|
|
| def forward(self, x, freqs): |
| q = self.norm_q(self.q(x)) |
| k = self.norm_k(self.k(x)) |
| v = self.v(x) |
| q = rope_apply(q, freqs, self.num_heads) |
| k = rope_apply(k, freqs, self.num_heads) |
| x = self.attn(q, k, v) |
| return self.o(x) |
|
|
|
|
| class CrossAttention(nn.Module): |
| def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False): |
| super().__init__() |
| self.dim = dim |
| self.num_heads = num_heads |
| self.head_dim = dim // num_heads |
|
|
| self.q = nn.Linear(dim, dim) |
| self.k = nn.Linear(dim, dim) |
| self.v = nn.Linear(dim, dim) |
| self.o = nn.Linear(dim, dim) |
| self.norm_q = RMSNorm(dim, eps=eps) |
| self.norm_k = RMSNorm(dim, eps=eps) |
| self.has_image_input = has_image_input |
| if has_image_input: |
| self.k_img = nn.Linear(dim, dim) |
| self.v_img = nn.Linear(dim, dim) |
| self.norm_k_img = RMSNorm(dim, eps=eps) |
| |
| self.attn = AttentionModule(self.num_heads) |
|
|
| def forward(self, x: torch.Tensor, y: torch.Tensor): |
        if self.has_image_input:
            # The first 257 context tokens are CLIP image embeddings; the
            # remainder is the T5 text context.
            img = y[:, :257]
            ctx = y[:, 257:]
| else: |
| ctx = y |
| q = self.norm_q(self.q(x)) |
| k = self.norm_k(self.k(ctx)) |
| v = self.v(ctx) |
| x = self.attn(q, k, v) |
| if self.has_image_input: |
| k_img = self.norm_k_img(self.k_img(img)) |
| v_img = self.v_img(img) |
| y = flash_attention(q, k_img, v_img, num_heads=self.num_heads) |
| x = x + y |
| return self.o(x) |
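

# Usage sketch (assumption): with has_image_input=True, the first 257 context
# rows are CLIP image tokens routed through the separate k_img / v_img path.
def _example_cross_attention_with_image():
    attn = CrossAttention(dim=64, num_heads=4, has_image_input=True)
    x = torch.randn(1, 32, 64)
    y = torch.randn(1, 257 + 20, 64)  # 257 image tokens followed by 20 text tokens
    assert attn(x, y).shape == x.shape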
|
|
|
|
| class GateModule(nn.Module): |
| def __init__(self): |
| super().__init__() |
|
|
| def forward(self, x, gate, residual): |
| return x + gate * residual |
|
|
|
|
| class MaskGuidedCrossAttention(nn.Module): |
| """ |
| 每个 patch 只关注覆盖它的实例 token,使用 log-mask trick 保证 0 区域被强制屏蔽。 |
| """ |
| def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): |
| super().__init__() |
| self.dim = dim |
| self.num_heads = num_heads |
| self.head_dim = dim // num_heads |
| self.scale = self.head_dim ** -0.5 |
|
|
| self.to_q = nn.Linear(dim, dim, bias=False) |
| self.to_k = nn.Linear(dim, dim, bias=False) |
| self.to_v = nn.Linear(dim, dim, bias=False) |
| |
| self.to_out = nn.Linear(dim, dim) |
| self.norm = nn.LayerNorm(dim, eps=eps) |
| self.gate = nn.Parameter(torch.zeros(1)) |
|
|
| def _attend(self, x: torch.Tensor, instance_tokens: torch.Tensor, instance_masks: torch.Tensor) -> torch.Tensor: |
| B, L, _ = x.shape |
| _, N, _ = instance_tokens.shape |
| if N == 0: |
| return x |
| if instance_masks.shape != (B, N, L): |
| raise ValueError(f"instance_masks shape mismatch, expect (B,N,L)=({B},{N},{L}), got {tuple(instance_masks.shape)}") |
|
|
| h = self.num_heads |
| q = rearrange(self.to_q(self.norm(x)), "b l (h d) -> b h l d", h=h) |
| k = rearrange(self.to_k(instance_tokens), "b n (h d) -> b h n d", h=h) |
| v = rearrange(self.to_v(instance_tokens), "b n (h d) -> b h n d", h=h) |
| sim = torch.einsum("b h l d, b h n d -> b h l n", q, k) * self.scale |
|
|
        # (B, N, L) -> (B, 1, L, N); the log of the clamped mask adds ~-13.8
        # to the logits of non-covering instances, effectively masking them.
        mask_bias = instance_masks.transpose(1, 2).unsqueeze(1).to(dtype=sim.dtype)
        sim = sim + torch.log(mask_bias.clamp(min=1e-6))
| attn = sim.softmax(dim=-1) |
| out = torch.einsum("b h l n, b h n d -> b h l d", attn, v) |
| out = rearrange(out, "b h l d -> b l (h d)") |
| return x + self.gate * self.to_out(out) |
|
|
| def forward(self, x: torch.Tensor, instance_tokens: torch.Tensor, instance_masks: torch.Tensor) -> torch.Tensor: |
| """ |
| instance_tokens supports: |
| - (B, N, D): static tokens for the whole sequence |
| - (B, T, N, D): tokens per patch-time (sequence assumed laid out as T contiguous chunks) |
| - (B, L, N, D): tokens per token position (used for sequence parallel chunking) |
| """ |
| B, L, _ = x.shape |
| if instance_tokens.ndim == 3: |
| return self._attend(x, instance_tokens, instance_masks) |
|
|
| if instance_tokens.ndim != 4: |
| raise ValueError(f"instance_tokens must be 3D or 4D, got {tuple(instance_tokens.shape)}") |
|
|
        if instance_tokens.shape[1] == L:
            # Per-token-position tokens (B, L, N, D): every position attends
            # only to its own token set (sequence-parallel chunking).
            _, _, N, _ = instance_tokens.shape
| if instance_masks.shape != (B, N, L): |
| raise ValueError(f"instance_masks shape mismatch, expect (B,N,L)=({B},{N},{L}), got {tuple(instance_masks.shape)}") |
| h = self.num_heads |
| q = rearrange(self.to_q(self.norm(x)), "b l (h d) -> b h l d", h=h) |
| k = rearrange(self.to_k(instance_tokens), "b l n (h d) -> b h l n d", h=h) |
| v = rearrange(self.to_v(instance_tokens), "b l n (h d) -> b h l n d", h=h) |
| sim = torch.einsum("b h l d, b h l n d -> b h l n", q, k) * self.scale |
| mask_bias = instance_masks.transpose(1, 2).unsqueeze(1).to(dtype=sim.dtype) |
| sim = sim + torch.log(mask_bias.clamp(min=1e-6)) |
| attn = sim.softmax(dim=-1) |
| out = torch.einsum("b h l n, b h l n d -> b h l d", attn, v) |
| out = rearrange(out, "b h l d -> b l (h d)") |
| return x + self.gate * self.to_out(out) |
|
|
| |
        # Per-time tokens (B, T, N, D): the sequence is laid out as T
        # contiguous (h * w) chunks, so attend chunk by chunk.
        _, T, _, _ = instance_tokens.shape
| if L % T != 0: |
| raise ValueError(f"Token length L={L} must be divisible by T={T} for per-time instance tokens.") |
| hw = L // T |
| chunks = [] |
| for t in range(T): |
| s, e = t * hw, (t + 1) * hw |
| chunks.append(self._attend(x[:, s:e], instance_tokens[:, t], instance_masks[:, :, s:e])) |
| return torch.cat(chunks, dim=1) |
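

# Usage sketch (illustrative only): the static (B, N, D) token path with a
# binary coverage mask. The gate is zero-initialized, so a freshly constructed
# module returns its input unchanged.
def _example_mask_guided_cross_attention():
    attn = MaskGuidedCrossAttention(dim=128, num_heads=8)
    x = torch.randn(1, 24, 128)
    tokens = torch.randn(1, 3, 128)               # 3 instance tokens
    masks = (torch.rand(1, 3, 24) > 0.5).float()  # 1 where an instance covers a patch
    assert torch.allclose(attn(x, tokens, masks), x)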
|
|
|
|
| class DiTBlock(nn.Module): |
| def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): |
| super().__init__() |
| self.dim = dim |
| self.num_heads = num_heads |
| self.ffn_dim = ffn_dim |
|
|
| self.self_attn = SelfAttention(dim, num_heads, eps) |
| self.cross_attn = CrossAttention(dim, num_heads, eps, has_image_input=has_image_input) |
| self.instance_cross_attn = MaskGuidedCrossAttention(dim, num_heads, eps) |
|
|
| self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) |
| self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) |
| self.norm3 = nn.LayerNorm(dim, eps=eps) |
| self.ffn = nn.Sequential( |
| nn.Linear(dim, ffn_dim), |
| nn.GELU(approximate='tanh'), |
| nn.Linear(ffn_dim, dim), |
| ) |
| self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5) |
| self.gate = GateModule() |
|
|
| def forward(self, x, context, t_mod, freqs, instance_tokens=None, instance_masks=None): |
| has_seq = len(t_mod.shape) == 4 |
| chunk_dim = 2 if has_seq else 1 |
| shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( |
| self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod |
| ).chunk(6, dim=chunk_dim) |
| if has_seq: |
| shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( |
| shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2), |
| shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2), |
| ) |
|
|
| |
        # Self-attention with AdaLN (shift/scale/gate) modulation.
        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
|
|
| |
        # Prompt (and optional CLIP image) cross-attention.
        x = x + self.cross_attn(self.norm3(x), context)
|
|
| |
        # Gated instance cross-attention; the gate is zero-initialized, so
        # this path is a no-op at the start of training.
        if instance_tokens is not None and instance_masks is not None:
            x = self.instance_cross_attn(x, instance_tokens, instance_masks)
|
|
| |
        # Feed-forward with AdaLN modulation.
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))
| return x |
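

# Usage sketch (illustrative only): one block without instance control; dim and
# num_heads give head_dim = 48, so the 3D RoPE split is even.
def _example_dit_block():
    dim, num_heads = 96, 2
    block = DiTBlock(has_image_input=False, dim=dim, num_heads=num_heads, ffn_dim=192)
    f, h, w = 2, 4, 4
    f_freqs, h_freqs, w_freqs = precompute_freqs_cis_3d(dim // num_heads)
    freqs = torch.cat([
        f_freqs[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        h_freqs[:h].view(1, h, 1, -1).expand(f, h, w, -1),
        w_freqs[:w].view(1, 1, w, -1).expand(f, h, w, -1),
    ], dim=-1).reshape(f * h * w, 1, -1)
    x = block(torch.randn(1, f * h * w, dim), torch.randn(1, 10, dim), torch.randn(1, 6, dim), freqs)
    assert x.shape == (1, f * h * w, dim)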
|
|
|
|
| class MLP(torch.nn.Module): |
| def __init__(self, in_dim, out_dim, has_pos_emb=False): |
| super().__init__() |
| self.proj = torch.nn.Sequential( |
| nn.LayerNorm(in_dim), |
| nn.Linear(in_dim, in_dim), |
| nn.GELU(), |
| nn.Linear(in_dim, out_dim), |
| nn.LayerNorm(out_dim) |
| ) |
| self.has_pos_emb = has_pos_emb |
        if has_pos_emb:
            # Positional embedding over the 514 (= 2 x 257) CLIP image tokens.
            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))
|
|
| def forward(self, x): |
| if self.has_pos_emb: |
| x = x + self.emb_pos.to(dtype=x.dtype, device=x.device) |
| return self.proj(x) |
|
|
|
|
| class Head(nn.Module): |
| def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float): |
| super().__init__() |
| self.dim = dim |
| self.patch_size = patch_size |
| self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False) |
| self.head = nn.Linear(dim, out_dim * math.prod(patch_size)) |
| self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim ** 0.5) |
|
|
| def forward(self, x, t_mod): |
| if len(t_mod.shape) == 3: |
| shift, scale = ( |
| self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2) |
| ).chunk(2, dim=2) |
| x = self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2)) |
| else: |
| shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1) |
| x = self.head(self.norm(x) * (1 + scale) + shift) |
| return x |
|
|
|
|
| class InstanceFeatureExtractor(nn.Module): |
| """ |
| 将 `instance_id` 与 (class/state 组合短语) 的文本语义融合为实例 token,并支持按时间(帧/patch-time) |
| 的 state weights 做动态加权: |
| - 输入:`state_text_embeds_multi` 形状 (B, N, S, text_dim),其中每个 state 对应短语 `"<class> is <state>"` |
| - 输入:`state_weights` 形状 (B, N, F, S),F 为帧数(或任意时间长度),每帧对 S 个 state 的权重 |
| - 输出:实例 token 形状 (B, T, N, D),T 为时间长度(可选下采样到 patch-time) |
| """ |
| def __init__( |
| self, |
| num_instance_ids: int = 1000, |
| embedding_dim: int = 1280, |
| hidden_dim: int = 1280, |
| text_dim: int = 4096, |
| ): |
| super().__init__() |
| self.inst_id_emb = nn.Embedding(num_instance_ids, hidden_dim, padding_idx=0) |
| self.text_proj = nn.Sequential( |
| nn.Linear(int(text_dim), hidden_dim, bias=False), |
| nn.SiLU(), |
| nn.Linear(hidden_dim, hidden_dim, bias=False), |
| nn.LayerNorm(hidden_dim), |
| ) |
|
|
| self.fusion = nn.Sequential( |
| nn.Linear(hidden_dim * 2, embedding_dim), |
| nn.SiLU(), |
| nn.Linear(embedding_dim, embedding_dim), |
| nn.LayerNorm(embedding_dim), |
| ) |
|
|
| @staticmethod |
| def _pool_time_to_patches(weights: torch.Tensor, num_time_patches: int) -> torch.Tensor: |
| """ |
| Average-pool per-frame weights (B,N,F,S) to per-patch-time weights (B,N,T,S), |
| where mapping uses pt = floor(t * T / F). |
| """ |
| if weights.ndim != 4: |
| raise ValueError(f"state_weights must be (B,N,F,S), got {tuple(weights.shape)}") |
        B, N, F_in, S = weights.shape  # F_in: frames; avoids shadowing torch.nn.functional as F
        T = int(num_time_patches)
        if T <= 0:
            raise ValueError("num_time_patches must be > 0")
        if F_in == T:
            return weights
        device = weights.device
        idx = (torch.arange(F_in, device=device, dtype=torch.float32) * (T / max(float(F_in), 1.0))).floor().clamp(0, T - 1).long()
        idx = idx.view(1, 1, F_in, 1).expand(B, N, F_in, S)
| out = torch.zeros((B, N, T, S), device=device, dtype=weights.dtype) |
| out.scatter_add_(2, idx, weights) |
| cnt = torch.zeros((B, N, T, S), device=device, dtype=weights.dtype) |
| cnt.scatter_add_(2, idx, torch.ones_like(weights)) |
| return out / cnt.clamp(min=1.0) |
|
|
| def forward( |
| self, |
| instance_ids: torch.Tensor, |
| state_text_embeds_multi: torch.Tensor, |
| state_weights: torch.Tensor, |
| num_time_patches: Optional[int] = None, |
| ): |
| if state_text_embeds_multi is None: |
| raise ValueError("state_text_embeds_multi is required.") |
| if state_weights is None: |
| raise ValueError("state_weights is required.") |
| if state_text_embeds_multi.ndim != 4: |
| raise ValueError(f"state_text_embeds_multi must be (B,N,S,D), got {tuple(state_text_embeds_multi.shape)}") |
| if state_weights.ndim != 4: |
| raise ValueError(f"state_weights must be (B,N,F,S), got {tuple(state_weights.shape)}") |
|
|
| B, N, S, _ = state_text_embeds_multi.shape |
| if instance_ids.shape[:2] != (B, N): |
| raise ValueError(f"instance_ids must be (B,N)=({B},{N}), got {tuple(instance_ids.shape)}") |
| if state_weights.shape[0] != B or state_weights.shape[1] != N or state_weights.shape[-1] != S: |
| raise ValueError(f"state_weights must be (B,N,F,S)=({B},{N},F,{S}), got {tuple(state_weights.shape)}") |
|
|
| sem_multi = self.text_proj(state_text_embeds_multi) |
| weights = state_weights.to(dtype=sem_multi.dtype, device=sem_multi.device).clamp(min=0) |
| if num_time_patches is not None: |
| weights = self._pool_time_to_patches(weights, int(num_time_patches)) |
| |
        # Weighted mixture over states per time step:
        # (B,N,1,S,H) * (B,N,T,S,1) -> sum over S -> (B,N,T,H).
        sem_multi = sem_multi.unsqueeze(2)
        weights = weights.unsqueeze(-1)
        denom = weights.sum(dim=3).clamp(min=1e-6)
        sem_time = (sem_multi * weights).sum(dim=3) / denom
|
|
| i_feat = self.inst_id_emb(instance_ids % self.inst_id_emb.num_embeddings).to(dtype=sem_time.dtype, device=sem_time.device) |
| i_time = i_feat.unsqueeze(2).expand(-1, -1, sem_time.shape[2], -1) |
| tokens = self.fusion(torch.cat([sem_time, i_time], dim=-1)) |
| return tokens.permute(0, 2, 1, 3).contiguous() |
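

# Usage sketch (illustrative only): two instances, three candidate states, and
# per-frame weights pooled down to 4 patch-time steps; all sizes are assumptions.
def _example_instance_feature_extractor():
    enc = InstanceFeatureExtractor(num_instance_ids=100, embedding_dim=256, hidden_dim=256, text_dim=512)
    B, N, S, frames = 1, 2, 3, 16
    instance_ids = torch.tensor([[1, 2]])
    state_text = torch.randn(B, N, S, 512)  # T5 embeddings of "<class> is <state>"
    state_w = torch.rand(B, N, frames, S)   # per-frame weights over the S states
    tokens = enc(instance_ids, state_text, state_w, num_time_patches=4)
    assert tokens.shape == (B, 4, N, 256)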
|
|
|
|
| |
| |
| |
| class WanModel(torch.nn.Module): |
| def __init__( |
| self, |
| dim: int, |
| in_dim: int, |
| ffn_dim: int, |
| out_dim: int, |
| text_dim: int, |
| freq_dim: int, |
| eps: float, |
| patch_size: Tuple[int, int, int], |
| num_heads: int, |
| num_layers: int, |
| has_image_input: bool, |
| has_image_pos_emb: bool = False, |
| has_ref_conv: bool = False, |
| add_control_adapter: bool = False, |
| in_dim_control_adapter: int = 24, |
        seperated_timestep: bool = False,  # (sic) spelling kept for config/checkpoint compatibility
| require_vae_embedding: bool = True, |
| require_clip_embedding: bool = True, |
| fuse_vae_embedding_in_latents: bool = False, |
| |
        # Instance control. `num_class_ids`, `num_state_ids`, and
        # `state_feature_dim` are accepted for backward compatibility with
        # id-based configs but are unused by the T5-based encoder below.
        num_class_ids: int = 200,
        num_state_ids: int = 100,
        num_instance_ids: int = 1000,
        state_feature_dim: int = 256,
        instance_text_dim: Optional[int] = 4096,
| ): |
| super().__init__() |
| self.dim = dim |
| self.in_dim = in_dim |
| self.freq_dim = freq_dim |
| self.has_image_input = has_image_input |
| self.patch_size = patch_size |
| self.seperated_timestep = seperated_timestep |
| self.require_vae_embedding = require_vae_embedding |
| self.require_clip_embedding = require_clip_embedding |
| self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents |
|
|
| self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) |
| self.text_embedding = nn.Sequential( |
| nn.Linear(text_dim, dim), |
| nn.GELU(approximate="tanh"), |
| nn.Linear(dim, dim), |
| ) |
| self.time_embedding = nn.Sequential( |
| nn.Linear(freq_dim, dim), |
| nn.SiLU(), |
| nn.Linear(dim, dim), |
| ) |
| self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) |
|
|
| self.blocks = nn.ModuleList([DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)]) |
| self.head = Head(dim, out_dim, patch_size, eps) |
| head_dim = dim // num_heads |
| self.freqs = precompute_freqs_cis_3d(head_dim) |
|
|
| if has_image_input: |
| self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) |
| if has_ref_conv: |
| self.ref_conv = nn.Conv2d(16, dim, kernel_size=(2, 2), stride=(2, 2)) |
| self.has_image_pos_emb = has_image_pos_emb |
| self.has_ref_conv = has_ref_conv |
| if add_control_adapter: |
| self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:]) |
| else: |
| self.control_adapter = None |
|
|
| instance_text_dim = int(text_dim) if instance_text_dim is None else int(instance_text_dim) |
| self.instance_encoder = InstanceFeatureExtractor( |
| num_instance_ids=num_instance_ids, |
| embedding_dim=dim, |
| hidden_dim=dim, |
| text_dim=instance_text_dim, |
| ) |
| self.instance_text_dim = instance_text_dim |
|
|
| |
| def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None): |
| """ |
| Returns: |
| tokens: (B, L, D) |
| grid_size: (F_p, H_p, W_p) |
| """ |
| x = self.patch_embedding(x) |
| if self.control_adapter is not None and control_camera_latents_input is not None: |
| y_camera = self.control_adapter(control_camera_latents_input) |
| if isinstance(y_camera, (list, tuple)): |
| x = x + y_camera[0] |
| else: |
| x = x + y_camera |
| grid_size = x.shape[2:] |
| x = rearrange(x, "b c f h w -> b (f h w) c").contiguous() |
| return x, grid_size |
|
|
    def unpatchify(self, x: torch.Tensor, grid_size: Tuple[int, int, int]):
| return rearrange( |
| x, "b (f h w) (x y z c) -> b c (f x) (h y) (w z)", |
| f=grid_size[0], h=grid_size[1], w=grid_size[2], |
| x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2], |
| ) |
|
|
| |
| def process_masks( |
| self, |
| grid_size, |
| image_size: Tuple[int, int, int], |
| bboxes: torch.Tensor, |
| bbox_mask: Optional[torch.Tensor] = None, |
| ): |
| """ |
| bbox-only path: |
| bboxes: (B, N, F, 4) or (B, N, 4), xyxy in pixel coords |
| bbox_mask: (B, N, F) or (B, N, 1), optional existence mask |
| Returns: |
| (B, N, L) flattened patch mask |
| """ |
| if bboxes is None: |
| raise ValueError("bboxes must be provided for instance control.") |
| return self._bboxes_to_masks(bboxes, bbox_mask, grid_size, image_size) |
|
|
| def _bboxes_to_masks( |
| self, |
| bboxes: torch.Tensor, |
| bbox_mask: Optional[torch.Tensor], |
| grid_size: Tuple[int, int, int], |
| image_size: Tuple[int, int, int], |
| ): |
| f_p, h_p, w_p = grid_size |
        F_img, H_img, W_img = image_size

        # Normalize bboxes to (B, N, F, 4); a single static box per instance
        # is broadcast across all f_p patch-time steps.
        if bboxes.ndim == 3:
            bboxes = bboxes.unsqueeze(2).expand(-1, -1, f_p, -1)
| if bboxes.ndim != 4 or bboxes.shape[-1] != 4: |
| raise ValueError(f"bboxes must be (B,N,F,4) or (B,N,4); got {tuple(bboxes.shape)}") |
|
|
| if bbox_mask is None: |
| bbox_mask = torch.ones(bboxes.shape[:3], device=bboxes.device, dtype=torch.bool) |
| else: |
| if bbox_mask.ndim == 3: |
| pass |
| elif bbox_mask.ndim == 2: |
| bbox_mask = bbox_mask.unsqueeze(-1).expand(-1, -1, bboxes.shape[2]) |
| else: |
| raise ValueError(f"bbox_mask must be (B,N,F) or (B,N,1); got {tuple(bbox_mask.shape)}") |
| bbox_mask = bbox_mask.to(dtype=torch.bool, device=bboxes.device) |
|
|
| mask = bboxes.new_zeros((bboxes.shape[0], bboxes.shape[1], f_p, h_p, w_p), dtype=torch.float32) |
| f_bbox = int(bboxes.shape[2]) |
| w_scale = (w_p / max(float(W_img), 1.0)) |
| h_scale = (h_p / max(float(H_img), 1.0)) |
|
|
        # Rasterize each box onto the patch grid. Plain Python loops are
        # acceptable here since B, N, and F are small.
        for b in range(bboxes.shape[0]):
| for n in range(bboxes.shape[1]): |
| for t in range(f_bbox): |
| if not bbox_mask[b, n, t]: |
| continue |
| x0, y0, x1, y1 = bboxes[b, n, t] |
| x0 = max(0, min(float(x0), W_img)) |
| x1 = max(0, min(float(x1), W_img)) |
| y0 = max(0, min(float(y0), H_img)) |
| y1 = max(0, min(float(y1), H_img)) |
| if x1 <= x0 or y1 <= y0: |
| continue |
|
|
| px0 = int(math.floor(x0 * w_scale)) |
| py0 = int(math.floor(y0 * h_scale)) |
| px1 = int(math.ceil(x1 * w_scale)) |
| py1 = int(math.ceil(y1 * h_scale)) |
| px0 = max(0, min(px0, w_p - 1)) |
| py0 = max(0, min(py0, h_p - 1)) |
| px1 = max(px0 + 1, min(px1, w_p)) |
| py1 = max(py0 + 1, min(py1, h_p)) |
|
|
                    # Map frame t to its patch-time index, consistent with
                    # InstanceFeatureExtractor._pool_time_to_patches.
                    pt = min(int(math.floor(t * f_p / max(f_bbox, 1))), f_p - 1)
                    mask[b, n, pt, py0:py1, px0:px1] = 1.0
|
|
| mask_flat = rearrange(mask, "b n f h w -> b n (f h w)") |
| return mask_flat |
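
    # Usage sketch (illustrative assumption): a static (B, N, 4) box is
    # broadcast over patch-time, e.g. on a (2, 4, 4) grid for a 32x32 image:
    #   >>> boxes = torch.tensor([[[0.0, 0.0, 16.0, 16.0]]])        # (1, 1, 4)
    #   >>> m = model.process_masks((2, 4, 4), (8, 32, 32), boxes)  # hypothetical `model`
    #   >>> m.shape                                                 # (1, 1, 32)
    # The box covers the top-left image quarter, so patches [0:2, 0:2] are 1
    # at every patch-time step.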
|
|
| |
| def forward( |
| self, |
| x: torch.Tensor, |
| timestep: torch.Tensor, |
| context: torch.Tensor, |
| clip_feature: Optional[torch.Tensor] = None, |
| y: Optional[torch.Tensor] = None, |
| use_gradient_checkpointing: bool = False, |
| use_gradient_checkpointing_offload: bool = False, |
| |
        # Instance-control inputs; all four must be provided to enable the
        # mask-guided cross-attention path.
        instance_ids: Optional[torch.Tensor] = None,
        instance_state_text_embeds_multi: Optional[torch.Tensor] = None,
        instance_state_weights: Optional[torch.Tensor] = None,
        instance_bboxes: Optional[torch.Tensor] = None,
| **kwargs, |
| ): |
| |
        # Timestep embedding and the per-block AdaLN modulation parameters.
        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
|
| |
        # Project the T5 text context into the DiT width.
        context = self.text_embedding(context)
|
|
| |
        # Image-conditioned variant: concatenate conditioning latents along
        # channels and prepend CLIP image tokens to the context.
        if self.has_image_input:
| x = torch.cat([x, y], dim=1) |
| clip_embedding = self.img_emb(clip_feature) |
| context = torch.cat([clip_embedding, context], dim=1) |
|
|
| orig_frames, orig_h, orig_w = x.shape[2:] |
| x, (f, h, w) = self.patchify(x) |
| grid_size = (f, h, w) |
|
|
| |
        # Instance control is active only when all four inputs are present.
        inst_tokens = None
        inst_mask_flat = None
        has_instance = (
| instance_ids is not None |
| and instance_bboxes is not None |
| and instance_state_text_embeds_multi is not None |
| and instance_state_weights is not None |
| and instance_ids.shape[1] > 0 |
| ) |
| if has_instance: |
| inst_tokens = self.instance_encoder( |
| instance_ids=instance_ids, |
| state_text_embeds_multi=instance_state_text_embeds_multi, |
| state_weights=instance_state_weights, |
| num_time_patches=f, |
| ) |
|
|
| inst_mask_flat = self.process_masks( |
| grid_size, |
| image_size=(orig_frames, orig_h, orig_w), |
| bboxes=instance_bboxes, |
| ) |
|
|
| |
        # Assemble 3D RoPE frequencies for the (f, h, w) patch grid.
        freqs = torch.cat([
| self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), |
| self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), |
| self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), |
| ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) |
|
|
| def create_custom_forward(module): |
| def custom_forward(*inputs): |
| return module(*inputs) |
| return custom_forward |
|
|
| def create_custom_forward_with_instance(module): |
| def custom_forward(x, context, t_mod, freqs, instance_tokens, instance_masks): |
| return module(x, context, t_mod, freqs, instance_tokens=instance_tokens, instance_masks=instance_masks) |
| return custom_forward |
|
|
        use_instance = inst_tokens is not None and inst_mask_flat is not None
        for block in self.blocks:
| if self.training and use_gradient_checkpointing: |
| if use_gradient_checkpointing_offload: |
| with torch.autograd.graph.save_on_cpu(): |
| if use_instance: |
| x = torch.utils.checkpoint.checkpoint( |
| create_custom_forward_with_instance(block), |
| x, context, t_mod, freqs, inst_tokens, inst_mask_flat, |
| use_reentrant=False, |
| ) |
| else: |
| x = torch.utils.checkpoint.checkpoint( |
| create_custom_forward(block), |
| x, context, t_mod, freqs, |
| use_reentrant=False, |
| ) |
| else: |
| if use_instance: |
| x = torch.utils.checkpoint.checkpoint( |
| create_custom_forward_with_instance(block), |
| x, context, t_mod, freqs, inst_tokens, inst_mask_flat, |
| use_reentrant=False, |
| ) |
| else: |
| x = torch.utils.checkpoint.checkpoint( |
| create_custom_forward(block), |
| x, context, t_mod, freqs, |
| use_reentrant=False, |
| ) |
| else: |
| if use_instance: |
| x = block(x, context, t_mod, freqs, instance_tokens=inst_tokens, instance_masks=inst_mask_flat) |
| else: |
| x = block(x, context, t_mod, freqs) |
|
|
        # Project back to output channels and fold patch tokens into latents.
        x = self.head(x, t)
        x = self.unpatchify(x, (f, h, w))
        return x
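

# End-to-end usage sketch (illustrative only; every size here is an assumption).
# dim=96 with num_heads=2 gives head_dim=48, whose 3D RoPE split is even.
def _example_wan_model_forward():
    model = WanModel(
        dim=96, in_dim=16, ffn_dim=192, out_dim=16, text_dim=32, freq_dim=32,
        eps=1e-6, patch_size=(1, 2, 2), num_heads=2, num_layers=1,
        has_image_input=False, instance_text_dim=32,
    )
    x = torch.randn(1, 16, 4, 8, 8)   # (B, C, F, H, W) latents
    context = torch.randn(1, 77, 32)  # T5 text context
    out = model(
        x, torch.tensor([500.0]), context,
        instance_ids=torch.tensor([[1]]),
        instance_state_text_embeds_multi=torch.randn(1, 1, 2, 32),  # (B, N, S, text_dim)
        instance_state_weights=torch.rand(1, 1, 4, 2),              # (B, N, F, S)
        instance_bboxes=torch.tensor([[[0.0, 0.0, 4.0, 4.0]]]),     # (B, N, 4), xyxy pixels
    )
    assert out.shape == x.shape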
|
|