| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """WorldModel transformer for frame generation.""" |
| |
|
| | from typing import Optional, List |
| | import math |
| |
|
| | import einops as eo |
| | import torch |
| | from torch import nn, Tensor |
| | import torch.nn.functional as F |
| | from tensordict import TensorDict |
| | from diffusers.configuration_utils import ConfigMixin, register_to_config |
| | from diffusers.models.modeling_utils import ModelMixin |
| |
|
| | from .attn import Attn, MergedQKVAttn, CrossAttention |
| | from .nn import AdaLN, MLP, NoiseConditioner, ada_gate, ada_rmsnorm, rms_norm |
| | from .quantize import quantize_model |
| | from .cache import CachedDenoiseStepEmb, CachedCondHead |
| |
|
| |
|
def patch_cached_noise_conditioning(model) -> None:
    """Replace the noise-level embedding and every block's cond head with cached wrappers.

    ``model.denoise_step_emb`` is wrapped in a ``CachedDenoiseStepEmb`` built from
    ``model.config.scheduler_sigmas``. Each ``model.transformer.blocks[i].cond_head``
    is wrapped in a ``CachedCondHead`` that shares that same cached embedding.

    The call is idempotent, matching the isinstance guards in the other ``patch_*``
    helpers. If the model is already patched, nothing is re-wrapped. Without the guard,
    a second ``apply_inference_patches()`` call would stack a second layer of cached
    wrappers around the first.
    """
    if isinstance(model.denoise_step_emb, CachedDenoiseStepEmb):
        return

    cached_denoise_step_emb = CachedDenoiseStepEmb(
        model.denoise_step_emb, model.config.scheduler_sigmas
    )
    model.denoise_step_emb = cached_denoise_step_emb
    for blk in model.transformer.blocks:
        blk.cond_head = CachedCondHead(blk.cond_head, cached_denoise_step_emb)
| |
|
| |
|
def patch_Attn_merge_qkv(model) -> None:
    """Swap every plain ``Attn`` submodule of ``model`` for a ``MergedQKVAttn``.

    Modules that are already ``MergedQKVAttn`` are skipped, so repeated calls do not
    re-wrap them. Candidates are collected from a full snapshot of
    ``model.named_modules()`` before any replacement, so the module tree is not
    mutated while it is being walked.
    """
    targets = [
        (path, attn)
        for path, attn in model.named_modules()
        if isinstance(attn, Attn) and not isinstance(attn, MergedQKVAttn)
    ]
    for path, attn in targets:
        model.set_submodule(path, MergedQKVAttn(attn, model.config))
| |
|
| |
|
def patch_MLPFusion_split(model) -> None:
    """Swap each packed ``MLPFusion`` in ``model`` for an equivalent ``SplitMLPFusion``.

    Modules that are already ``SplitMLPFusion`` are left untouched. Candidates are
    gathered from a snapshot of ``model.named_modules()`` before anything is replaced.
    """
    to_split = [
        (path, fusion)
        for path, fusion in model.named_modules()
        if isinstance(fusion, MLPFusion) and not isinstance(fusion, SplitMLPFusion)
    ]
    for path, fusion in to_split:
        model.set_submodule(path, SplitMLPFusion(fusion))
| |
|
| |
|
def _apply_inference_patches(model) -> None:
    """Apply all inference-time module rewrites to ``model``, in a fixed order.

    The order is:
    1. cached noise conditioning;
    2. merged-QKV attention;
    3. split MLPFusion layers.
    """
    for patch_fn in (
        patch_cached_noise_conditioning,
        patch_Attn_merge_qkv,
        patch_MLPFusion_split,
    ):
        patch_fn(model)
| |
|
| |
|
class CFG(nn.Module):
    """Classifier-free-guidance switch between an embedding and a learned null embedding.

    ``null_emb`` is one learned vector of shape ``(1, 1, d_model)``, initialised to
    zeros. It is broadcast to ``[B, L, d_model]`` whenever the conditioning is dropped.
    """

    def __init__(self, d_model: int, dropout: float):
        super().__init__()
        # Per-sample probability of replacing the conditioning with the null embedding.
        self.dropout = dropout
        self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(
        self, x: torch.Tensor, is_conditioned: Optional[bool] = None
    ) -> torch.Tensor:
        """Return ``x``, the broadcast null embedding, or a per-sample mix of the two.

        Args:
            x: ``[B, L, D]`` conditioning embedding.
            is_conditioned: Selects how the conditioning is chosen.
                - ``None``: random per-sample dropout with probability ``self.dropout``.
                - A bool: selects the whole batch (``True`` -> ``x``,
                  ``False`` -> null embedding).
                The bool is only honoured in eval mode. In training mode the random
                dropout path is always taken.

        Returns:
            Tensor of shape ``[B, L, D]``.
        """
        batch, length, _ = x.shape
        null = self.null_emb.expand(batch, length, -1)

        use_explicit_choice = not self.training and is_conditioned is not None
        if use_explicit_choice:
            return x if is_conditioned else null

        if self.dropout == 0.0:
            return x
        # One Bernoulli draw per sample, broadcast over L and D.
        drop = torch.rand(batch, 1, 1, device=x.device) < self.dropout
        return torch.where(drop, null, x)
| |
|
| |
|
class ControllerInputEmbedding(nn.Module):
    """Map per-frame controller state (mouse + buttons + scroll) to a ``d_model`` embedding.

    The three inputs are concatenated on the last axis and fed through one MLP. The
    MLP's input width is ``n_buttons + 3``, i.e. the button states plus three extra
    values. Presumably these are 2 mouse values and 1 scroll value; TODO confirm
    against callers. The hidden width is ``d_model * mlp_ratio``.
    """

    def __init__(self, n_buttons: int, d_model: int, mlp_ratio: int = 4):
        super().__init__()
        in_features = n_buttons + 3
        self.mlp = MLP(in_features, d_model * mlp_ratio, d_model)

    def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
        """Embed a batch of controller inputs.

        ``mouse`` must be 3-D. Presumably all three inputs are ``[B, N, k]`` so they
        can be concatenated on the last axis.
        """
        assert mouse.dim() == 3
        features = torch.cat([mouse, button, scroll], dim=-1)
        return self.mlp(features)
| |
|
| |
|
class MLPFusion(nn.Module):
    """Fuse one conditioning vector per group into every token of that group.

    ``self.mlp.fc1.weight`` is split column-wise into a token half ``W_x`` and a cond
    half ``W_c``. The output is ``fc2.weight @ silu(W_x x + W_c cond)``.

    The cond projection is computed once per group and broadcast over the group's
    tokens. Only the weight matrices of ``fc1``/``fc2`` are used; any biases on those
    layers are ignored here.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.mlp = MLP(2 * d_model, d_model, d_model)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Fuse ``cond`` into ``x``.

        Args:
            x: ``[B, G*T, D]`` tokens. ``G*T`` must be divisible by ``G``.
            cond: ``[B, G, D]`` one conditioning vector per group of ``T`` tokens.

        Returns:
            ``[B, G*T, D_out]`` fused tokens.
        """
        batch, _, dim = x.shape
        groups = cond.shape[1]

        w_tok, w_cond = self.mlp.fc1.weight.chunk(2, dim=1)

        tokens = x.view(batch, groups, -1, dim)
        # [B, G, 1, H]: broadcast the per-group cond projection over the T tokens.
        cond_term = F.linear(cond, w_cond)[:, :, None]
        hidden = F.silu(F.linear(tokens, w_tok) + cond_term)
        out = F.linear(hidden, self.mlp.fc2.weight)
        return out.flatten(1, 2)
| |
|
| |
|
class SplitMLPFusion(nn.Module):
    """Inference form of ``MLPFusion`` with fc1 split into two separate linears.

    The layers are built from an existing ``MLPFusion``:
    - the two column halves of ``src.mlp.fc1.weight`` are copied into ``fc1_x``
      (token half) and ``fc1_c`` (cond half);
    - ``src.mlp.fc2.weight`` is copied into ``fc2``.

    All three layers are bias-free and live on the source weight's device and dtype.
    Using two plain ``nn.Linear`` modules avoids concatenating ``[x, cond]``, which the
    original design notes as quantization-friendly.
    """

    def __init__(self, src: MLPFusion):
        super().__init__()
        ref_weight = src.mlp.fc2.weight
        dim = src.mlp.fc2.in_features
        factory = dict(bias=False, device=ref_weight.device, dtype=ref_weight.dtype)

        self.fc1_x = nn.Linear(dim, dim, **factory)
        self.fc1_c = nn.Linear(dim, dim, **factory)
        self.fc2 = nn.Linear(dim, dim, **factory)

        with torch.no_grad():
            w_tok, w_cond = src.mlp.fc1.weight.chunk(2, dim=1)
            self.fc1_x.weight.copy_(w_tok)
            self.fc1_c.weight.copy_(w_cond)
            self.fc2.weight.copy_(ref_weight)

        # Inherit train/eval mode from the module being replaced.
        self.train(src.training)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Fuse ``cond`` ([B, G, D]) into the tokens ``x`` ([B, G*T, D]).

        Returns the fused tokens as ``[B, G*T, D]``.
        """
        batch, _, dim = x.shape
        grouped = x.reshape(batch, cond.shape[1], -1, dim)
        pre_act = self.fc1_x(grouped) + self.fc1_c(cond).unsqueeze(2)
        return self.fc2(F.silu(pre_act)).flatten(1, 2)
| |
|
| |
|
class CondHead(nn.Module):
    """Per-layer conditioning head that produces ``n_cond`` modulation tensors.

    It computes ``h = silu(cond + bias_in)`` and returns a tuple holding one bias-free
    linear projection of ``h`` per entry of ``cond_proj``. The learned ``bias_in`` only
    exists for the ``"wan"`` noise-conditioning scheme; otherwise it is ``None`` and
    ``h = silu(cond)``.
    """

    # WorldDiTBlock.forward unpacks these six outputs as (s0, b0, g0, s1, b1, g1).
    n_cond = 6

    def __init__(self, d_model: int, noise_conditioning: str = "wan"):
        super().__init__()
        if noise_conditioning == "wan":
            self.bias_in = nn.Parameter(torch.zeros(d_model))
        else:
            self.bias_in = None
        self.cond_proj = nn.ModuleList(
            nn.Linear(d_model, d_model, bias=False) for _ in range(self.n_cond)
        )

    def forward(self, cond):
        """Return a tuple of ``n_cond`` projections of ``silu(cond [+ bias_in])``."""
        if self.bias_in is not None:
            cond = cond + self.bias_in
        activated = F.silu(cond)
        return tuple([proj(activated) for proj in self.cond_proj])
| |
|
| |
|
class WorldDiTBlock(nn.Module):
    """Single transformer block with self-attention, optional cross-attention, and MLP.

    Sub-layers, applied in order on a residual stream:
      0) self-attention (``Attn``), AdaLN-modulated and gated by the cond head;
      1) optional prompt cross-attention (only on layers where
         ``layer_idx % prompt_conditioning_period == 0`` and prompt conditioning is on);
      2) optional controller fusion via ``MLPFusion`` (only on layers where
         ``layer_idx % ctrl_conditioning_period == 0``);
      3) MLP, AdaLN-modulated and gated by the cond head.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        mlp_ratio: int,
        layer_idx: int,
        prompt_conditioning: Optional[str],
        prompt_conditioning_period: int,
        prompt_embedding_dim: int,
        ctrl_conditioning_period: int,
        noise_conditioning: str,
        config,
    ):
        # NOTE(review): n_heads is accepted but never used in this block; Attn is
        # constructed from `config` instead.
        super().__init__()
        self.config = config
        self.attn = Attn(config, layer_idx)
        self.mlp = MLP(d_model, d_model * mlp_ratio, d_model)
        self.cond_head = CondHead(d_model, noise_conditioning)

        do_prompt_cond = (
            prompt_conditioning is not None
            and layer_idx % prompt_conditioning_period == 0
        )
        self.prompt_cross_attn = (
            CrossAttention(config, prompt_embedding_dim) if do_prompt_cond else None
        )
        # NOTE(review): unlike the prompt path above, this does not check
        # config.ctrl_conditioning, so every period-aligned layer gets an MLPFusion
        # even when ctrl_conditioning is None. Confirm this is intended.
        do_ctrl_cond = layer_idx % ctrl_conditioning_period == 0
        self.ctrl_mlpfusion = MLPFusion(d_model) if do_ctrl_cond else None

    def forward(self, x, pos_ids, cond, ctx, v, kv_cache=None):
        """
        0) Causal Frame Attention
        1) Frame->CTX Cross Attention (prompt, only if this layer has it)
        2) Controller MLPFusion (only if this layer has it)
        3) MLP

        Args:
            x: token stream.
            pos_ids: positional ids forwarded to ``self.attn``.
            cond: noise-level conditioning fed to ``self.cond_head``.
            ctx: dict read for ``"prompt_emb"`` / ``"prompt_pad_mask"`` (cross-attn
                layers) and ``"ctrl_emb"`` (fusion layers). ``"prompt_emb"`` must not
                be None on layers with cross-attention, since it goes through rms_norm.
            v: value tensor threaded from the previous block (None for the first);
                presumably used for value-residual attention. TODO confirm in Attn.
            kv_cache: forwarded unchanged to ``self.attn``.

        Returns:
            ``(x, v)``: the updated token stream and the ``v`` returned by attention.
        """
        # Six modulation tensors: (shift/scale, bias, gate) for the attention and MLP
        # sub-layers, consumed by ada_rmsnorm / ada_gate below.
        s0, b0, g0, s1, b1, g1 = self.cond_head(cond)

        # 0) Self-attention with AdaLN input modulation and gated residual.
        residual = x
        x = ada_rmsnorm(x, s0, b0)
        x, v = self.attn(x, pos_ids, v, kv_cache=kv_cache)
        x = ada_gate(x, g0) + residual

        # 1) Prompt cross-attention (plain rms_norm, ungated residual).
        if self.prompt_cross_attn is not None:
            x = (
                self.prompt_cross_attn(
                    rms_norm(x),
                    context=rms_norm(ctx["prompt_emb"]),
                    context_pad_mask=ctx["prompt_pad_mask"],
                )
                + x
            )

        # 2) Controller fusion (plain rms_norm on both inputs, ungated residual).
        if self.ctrl_mlpfusion is not None:
            x = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"])) + x

        # 3) MLP with AdaLN input modulation and gated residual.
        x = ada_gate(self.mlp(ada_rmsnorm(x, s1, b1)), g1) + x

        return x, v
| |
|
| |
|
class WorldDiT(nn.Module):
    """Stack of ``config.n_layers`` WorldDiTBlocks with some cross-layer sharing.

    Sharing set up at construction time:
      * When ``config.noise_conditioning`` is ``"dit_air"`` or ``"wan"``, every block's
        ``cond_head.cond_proj`` linear weights are tied to block 0's. Any per-block
        ``cond_head.bias_in`` is not tied.
      * Every block's ``attn.rope`` is replaced by block 0's instance.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList(
            [
                WorldDiTBlock(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                    mlp_ratio=config.mlp_ratio,
                    layer_idx=idx,
                    prompt_conditioning=config.prompt_conditioning,
                    prompt_conditioning_period=config.prompt_conditioning_period,
                    prompt_embedding_dim=config.prompt_embedding_dim,
                    ctrl_conditioning_period=config.ctrl_conditioning_period,
                    noise_conditioning=config.noise_conditioning,
                    config=config,
                )
                for idx in range(config.n_layers)
            ]
        )

        # Tie the cond-head projection weights across layers (same Parameter objects).
        if config.noise_conditioning in ("dit_air", "wan"):
            ref_proj = self.blocks[0].cond_head.cond_proj
            for blk in self.blocks[1:]:
                for blk_mod, ref_mod in zip(blk.cond_head.cond_proj, ref_proj):
                    blk_mod.weight = ref_mod.weight

        # Share a single rotary-embedding module across all layers.
        ref_rope = self.blocks[0].attn.rope
        for blk in self.blocks[1:]:
            blk.attn.rope = ref_rope

    def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
        """Run ``x`` through every block in order and return the final tokens.

        The ``v`` returned by each block is passed into the next block; the first
        block receives ``v=None``. ``pos_ids``, ``cond``, ``ctx`` and ``kv_cache`` are
        passed unchanged to every block.
        """
        v = None
        for block in self.blocks:
            x, v = block(x, pos_ids, cond, ctx, v, kv_cache=kv_cache)
        return x
| |
|
| |
|
class WorldModel(ModelMixin, ConfigMixin):
    """
    WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser.

    Denoises a frame given:
    - All previous frames (via KV cache)
    - The prompt embedding
    - The controller input embedding
    - The current noise level

    Constructor arguments are recorded in ``self.config`` by diffusers'
    ``register_to_config``. Submodules such as ``WorldDiT`` read their
    hyperparameters from that config object.
    """

    # NOTE(review): the mlp_/block_gradient_checkpointing config args below are not
    # read anywhere in this file.
    _supports_gradient_checkpointing = False
    # Module names diffusers keeps in float32 when loading in reduced precision.
    _keep_in_fp32_modules = ["denoise_step_emb", "rope"]

    @register_to_config
    def __init__(
        self,
        # Transformer / input geometry and conditioning hyperparameters.
        d_model: int = 2560,
        n_heads: int = 40,
        n_kv_heads: Optional[int] = 20,
        n_layers: int = 22,
        mlp_ratio: int = 5,
        channels: int = 16,
        height: int = 16,
        width: int = 16,
        patch: tuple = (2, 2),
        tokens_per_frame: int = 256,
        n_frames: int = 512,
        local_window: int = 16,
        global_window: int = 128,
        global_attn_period: int = 4,
        global_pinned_dilation: int = 8,
        global_attn_offset: int = -1,
        value_residual: bool = False,
        gated_attn: bool = True,
        n_buttons: int = 256,
        ctrl_conditioning: Optional[str] = "mlp_fusion",
        ctrl_conditioning_period: int = 3,
        ctrl_cond_dropout: float = 0.0,
        prompt_conditioning: Optional[str] = "cross_attention",
        prompt_conditioning_period: int = 3,
        prompt_embedding_dim: int = 2048,
        prompt_cond_dropout: float = 0.0,
        noise_conditioning: str = "wan",
        # NOTE(review): mutable list default. The same list object is shared by every
        # instance built with defaults (and by its config); never mutate it in place.
        scheduler_sigmas: Optional[List[float]] = [
            1.0,
            0.9483006596565247,
            0.8379597067832947,
            0.0,
        ],
        base_fps: int = 60,
        causal: bool = True,
        mlp_gradient_checkpointing: bool = True,
        block_gradient_checkpointing: bool = True,
        rope_impl: str = "ortho",
    ):
        super().__init__()

        self.denoise_step_emb = NoiseConditioner(d_model)
        self.ctrl_emb = ControllerInputEmbedding(n_buttons, d_model, mlp_ratio)

        # NOTE(review): ctrl_cfg / prompt_cfg are created here but forward() below
        # never applies them. Confirm whether CFG dropout is applied elsewhere.
        if self.config.ctrl_conditioning is not None:
            self.ctrl_cfg = CFG(self.config.d_model, self.config.ctrl_cond_dropout)
        if self.config.prompt_conditioning is not None:
            self.prompt_cfg = CFG(
                self.config.prompt_embedding_dim, self.config.prompt_cond_dropout
            )

        self.transformer = WorldDiT(self.config)
        self.patch = tuple(patch)

        # Patch embedding (non-overlapping conv) and its inverse projection
        # back to C * ph * pw values per token.
        C, D = channels, d_model
        self.patchify = nn.Conv2d(
            C, D, kernel_size=self.patch, stride=self.patch, bias=False
        )
        self.unpatchify = nn.Linear(D, C * math.prod(self.patch), bias=True)
        self.out_norm = AdaLN(d_model)

        # Per-token position ids for a single frame (non-persistent buffers).
        # _t_pos_1f is left uninitialised (torch.empty) and is overwritten with the
        # frame timestamp on every forward() call.
        # _y_pos_1f / _x_pos_1f are the row / column of each token in a row-major grid
        # `width` columns wide. NOTE(review): this assumes config.width equals the
        # patch-grid width Wp; forward() only checks Hp * Wp == tokens_per_frame.
        T = tokens_per_frame
        idx = torch.arange(T, dtype=torch.long)
        self.register_buffer(
            "_t_pos_1f", torch.empty(T, dtype=torch.long), persistent=False
        )
        self.register_buffer(
            "_y_pos_1f", idx.div(width, rounding_mode="floor"), persistent=False
        )
        self.register_buffer("_x_pos_1f", idx.remainder(width), persistent=False)

    def forward(
        self,
        x: Tensor,
        sigma: Tensor,
        frame_timestamp: Tensor,
        prompt_emb: Optional[Tensor] = None,
        prompt_pad_mask: Optional[Tensor] = None,
        mouse: Optional[Tensor] = None,
        button: Optional[Tensor] = None,
        scroll: Optional[Tensor] = None,
        kv_cache=None,
    ):
        """
        Args:
            x: [B, N, C, H, W] - latent frames
            sigma: [B, N] - noise levels
            frame_timestamp: [B, N] - frame indices
            prompt_emb: [B, P, D] - prompt embeddings. Required when
                config.prompt_conditioning is not None (block 0 then always has
                prompt cross-attention).
            prompt_pad_mask: [B, P] - padding mask for prompts
            mouse: [B, N, 2] - mouse velocity (required despite the Optional type;
                it is concatenated by ctrl_emb)
            button: [B, N, n_buttons] - button states (required)
            scroll: [B, N, 1] - scroll wheel sign (-1, 0, 1) (required despite the
                Optional type)
            kv_cache: StaticKVCache instance, passed through to every block's attention

        Returns:
            [B, N, C, H, W] denoiser output, with C == config.channels.

        Only B == 1 and N == 1 are currently supported (checked via torch._assert).
        This call writes the frame timestamp into the ``_t_pos_1f`` buffer in place,
        so the returned computation's pos_ids alias module state.
        """
        B, N, C, H, W = x.shape
        ph, pw = self.patch
        assert (H % ph == 0) and (W % pw == 0), "H, W must be divisible by patch"
        Hp, Wp = H // ph, W // pw
        torch._assert(
            Hp * Wp == self.config.tokens_per_frame,
            f"{Hp} * {Wp} != {self.config.tokens_per_frame}",
        )

        torch._assert(
            B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1"
        )
        # Broadcast the (single) frame timestamp to every token of the frame.
        self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
        pos_ids = TensorDict(
            {
                "t_pos": self._t_pos_1f[None],
                "y_pos": self._y_pos_1f[None],
                "x_pos": self._x_pos_1f[None],
            },
            batch_size=[1, self._t_pos_1f.numel()],
        )
        cond = self.denoise_step_emb(sigma)

        assert button is not None
        ctx = {
            "ctrl_emb": self.ctrl_emb(mouse, button, scroll),
            "prompt_emb": prompt_emb,
            "prompt_pad_mask": prompt_pad_mask,
        }

        # Patchify -> [B, N*Hp*Wp, D] tokens -> transformer -> AdaLN + SiLU -> unpatchify.
        D = self.unpatchify.in_features
        x = self.patchify(x.reshape(B * N, C, H, W))
        x = eo.rearrange(x.view(B, N, D, Hp, Wp), "b n d hp wp -> b (n hp wp) d")
        x = self.transformer(x, pos_ids, cond, ctx, kv_cache)
        x = F.silu(self.out_norm(x, cond))
        x = eo.rearrange(
            self.unpatchify(x),
            "b (n hp wp) (c ph pw) -> b n c (hp ph) (wp pw)",
            n=N,
            hp=Hp,
            wp=Wp,
            ph=ph,
            pw=pw,
        )

        return x

    def quantize(self, quant_type: str):
        """Delegate to ``quantize_model(self, quant_type)``; presumably modifies the model in place."""
        quantize_model(self, quant_type)

    def apply_inference_patches(self):
        """Apply the module-level inference patches to this model.

        The patches are:
        - cached noise conditioning;
        - merged-QKV attention;
        - split MLPFusion.
        """
        _apply_inference_patches(self)
| |
|