| import math |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embedding for timestep inputs.

    Maps a batch of (typically [0, 1]-normalized) timesteps to a
    ``(batch, dim)`` tensor of concatenated sin/cos features, following the
    Transformer positional-encoding frequency scheme.
    """

    def __init__(self, dim: int):
        """
        Args:
            dim: Output embedding width. Odd values are supported: the last
                feature is zero-padded so the output is exactly ``dim`` wide.
        """
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed timesteps ``t`` (scalar or 1-D) as a ``(batch, dim)`` tensor."""
        # Promote a 0-dim scalar to a batch of one.
        if t.ndim == 0:
            t = t.unsqueeze(0)

        # Integer timesteps must be cast before the trig/exp math below.
        if not torch.is_floating_point(t):
            t = t.float()

        # Spread normalized t over a 0..1000 range so the frequency ladder
        # covers a useful spectrum (mirrors discrete diffusion timesteps).
        t = t * 1000.0

        half_dim = self.dim // 2
        # Geometric frequencies from 1 down to 1/10000; max() guards against
        # division by zero when half_dim == 1 (i.e. dim < 4).
        emb_scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, device=t.device, dtype=t.dtype) * -emb_scale
        )
        angles = t.unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)

        # Fix: for odd `dim` the sin/cos pair only covers 2 * (dim // 2)
        # features; pad one zero so the output width is exactly `dim`.
        if self.dim % 2 == 1:
            emb = nn.functional.pad(emb, (0, 1))
        return emb
|
|
|
|
class MultiTokenFusion(nn.Module):
    """Project each modality to a shared hidden space and fuse across modalities.

    Every modality stream is projected to ``hidden_dim`` (Linear -> LayerNorm
    -> GELU), tagged with a learned additive modality embedding, optionally
    dropped as a whole modality during training (inverted-dropout style with
    a rescue that keeps modality 0 when all were dropped), then averaged and
    passed through an output projection.
    """

    def __init__(
        self,
        modality_dims: list[int],
        hidden_dim: int = 256,
        dropout: float = 0.1,
        modality_dropout: float = 0.0,
    ):
        super().__init__()
        self.modality_dims = modality_dims
        self.n_modalities = len(modality_dims)
        self.hidden_dim = hidden_dim
        self.modality_dropout = modality_dropout

        # One projector per modality stream.
        self.projectors = nn.ModuleList()
        for in_dim in modality_dims:
            self.projectors.append(
                nn.Sequential(
                    nn.Linear(in_dim, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.GELU(),
                )
            )

        # Learned per-modality embedding, small init to start near neutral.
        self.modality_emb = nn.Parameter(
            torch.randn(self.n_modalities, hidden_dim) * 0.02
        )

        self.output_proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def forward(self, modality_features: list[torch.Tensor]) -> torch.Tensor:
        """Fuse per-modality features into a single ``hidden_dim`` tensor.

        Raises:
            ValueError: if the number of feature tensors does not match the
                number of configured modalities.
        """
        if len(modality_features) != self.n_modalities:
            raise ValueError(
                f"Expected {self.n_modalities} modalities, got {len(modality_features)}."
            )

        projected = [
            proj(feat) + self.modality_emb[idx]
            for idx, (feat, proj) in enumerate(zip(modality_features, self.projectors))
        ]

        if self.training and self.modality_dropout > 0:
            ref = projected[0]
            # Per-position keep decision for each modality.
            keep = (
                torch.rand(
                    ref.shape[0],
                    ref.shape[1],
                    self.n_modalities,
                    device=ref.device,
                )
                > self.modality_dropout
            )
            # Rescue: wherever every modality was dropped, re-keep modality 0.
            none_kept = keep.sum(dim=2, keepdim=True) == 0
            keep[:, :, 0:1] = torch.max(keep[:, :, 0:1], none_kept)

            # Inverted-dropout rescale so eval-time activations match.
            inv_keep = 1.0 / max(1.0 - self.modality_dropout, 1e-6)
            projected = [
                h * keep[:, :, idx : idx + 1] * inv_keep
                for idx, h in enumerate(projected)
            ]

        fused = torch.stack(projected, dim=0).mean(dim=0)
        return self.output_proj(fused)
|
|
|
|
class SimpleFiLMBlock(nn.Module):
    """Residual FiLM block with feed-forward and context cross-attention.

    The timestep embedding is mapped to a per-feature (scale, shift) pair that
    modulates the normalized state before a feed-forward sub-layer; the result
    then cross-attends into the context sequence. Both sub-layers are wired as
    residual connections.
    """

    def __init__(
        self,
        dim: int,
        time_dim: int,
        context_dim: int,
        n_heads: int = 8,
        dropout: float = 0.1,
    ):
        super().__init__()
        # FiLM head: one linear map producing concatenated (scale, shift).
        self.film = nn.Linear(time_dim, dim * 2)
        self.norm1 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout),
        )

        # Separate norms for the query state and the key/value context.
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(context_dim)
        self.cross_attn = nn.MultiheadAttention(
            dim,
            n_heads,
            dropout=dropout,
            batch_first=True,
            kdim=context_dim,
            vdim=context_dim,
        )

    def forward(
        self,
        x: torch.Tensor,
        t_emb: torch.Tensor,
        context: torch.Tensor,
    ) -> torch.Tensor:
        """Apply FiLM-modulated FFN, then cross-attention over ``context``."""
        # Split the FiLM projection into its scale and shift halves.
        scale, shift = torch.chunk(self.film(t_emb), 2, dim=-1)

        # Residual feed-forward over the modulated, normalized state.
        modulated = self.norm1(x) * (scale + 1) + shift
        x = x + self.ffn(modulated)

        # Single-query cross-attention: the state attends over the context.
        query = self.norm_q(x)[:, None, :]
        keys = self.norm_kv(context)
        attended, _ = self.cross_attn(query, keys, keys, need_weights=False)
        return x + attended[:, 0, :]
|
|
|
|
class VelocityNet(nn.Module):
    """DiT-style velocity estimator with late-fusion context conditioning.

    Multi-modal context features are fused into a hidden sequence (optionally
    refined by a small temporal transformer), the flow timestep is embedded
    sinusoidally, and the noisy sample is processed by a stack of FiLM /
    cross-attention blocks before a zero-initialized output head.
    """

    def __init__(
        self,
        output_dim: int,
        hidden_dim: int = 256,
        modality_dims: Optional[list[int]] = None,
        n_blocks: int = 4,
        n_heads: int = 8,
        dropout: float = 0.1,
        modality_dropout: float = 0.0,
        max_seq_len: int = 2048,
        temporal_attn_layers: int = 2,
    ):
        super().__init__()
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        # Default to a single modality spanning the full output width.
        self.modality_dims = modality_dims or [output_dim]
        self.max_seq_len = max_seq_len

        self.fusion_block = MultiTokenFusion(
            modality_dims=self.modality_dims,
            hidden_dim=hidden_dim,
            dropout=dropout,
            modality_dropout=modality_dropout,
        )

        # Learned absolute positional embedding for the context sequence.
        self.context_pos_emb = nn.Parameter(
            torch.randn(1, max_seq_len, hidden_dim) * 0.02
        )

        # Optional temporal self-attention over the fused context.
        if temporal_attn_layers > 0:
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=n_heads,
                dim_feedforward=hidden_dim * 4,
                dropout=dropout,
                activation="gelu",
                batch_first=True,
                norm_first=True,
            )
            self.temporal_attn = nn.TransformerEncoder(
                encoder_layer, num_layers=temporal_attn_layers
            )
        else:
            self.temporal_attn = nn.Identity()

        self.temporal_norm = nn.LayerNorm(hidden_dim)

        # Lift the noisy sample into the hidden space.
        self.input_proj = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.GELU(),
        )

        # Sinusoidal timestep embedding followed by a small MLP.
        self.time_emb = SinusoidalPosEmb(hidden_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        self.blocks = nn.ModuleList(
            SimpleFiLMBlock(
                dim=hidden_dim,
                time_dim=hidden_dim,
                context_dim=hidden_dim,
                n_heads=n_heads,
                dropout=dropout,
            )
            for _ in range(n_blocks)
        )

        self.final_norm = nn.LayerNorm(hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, output_dim)

        # Zero-init the head so the model starts by predicting zero velocity.
        nn.init.zeros_(self.output_layer.weight)
        nn.init.zeros_(self.output_layer.bias)

    def encode_context(self, cond: torch.Tensor) -> torch.Tensor:
        """Encode context tensor from (B, T, total_dim) to (B, T, hidden_dim).

        Raises:
            ValueError: if ``cond`` is not rank-3, exceeds ``max_seq_len``,
                or its last dimension does not equal ``sum(modality_dims)``.
        """
        if cond.ndim != 3:
            raise ValueError(f"Expected cond with shape (B, T, D), got {tuple(cond.shape)}")

        B, T, D = cond.shape
        if T > self.max_seq_len:
            raise ValueError(
                f"Sequence length {T} exceeds max_seq_len={self.max_seq_len}. "
                "Increase max_seq_len in stage2.velocity_net config."
            )

        expected = sum(self.modality_dims)
        if expected != D:
            raise ValueError(
                f"Context dim mismatch: expected sum(modality_dims)={expected}, got {D}."
            )

        # Slice the packed context back into per-modality chunks.
        chunks = list(torch.split(cond, self.modality_dims, dim=-1))

        fused = self.fusion_block(chunks)
        fused = fused + self.context_pos_emb[:, :T, :]
        fused = self.temporal_attn(fused)
        return self.temporal_norm(fused)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        cond: Optional[torch.Tensor] = None,
        pre_encoded_context: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Predict the velocity for sample ``x`` at flow time ``t``.

        Args:
            x: Noisy sample whose last dimension is ``output_dim``.
            t: Flow timestep(s); a 0-dim scalar is broadcast over the batch.
            cond: Packed context of shape (B, T, sum(modality_dims)),
                encoded on the fly when ``pre_encoded_context`` is absent.
            pre_encoded_context: Already-encoded context (takes precedence).
            **kwargs: Ignored; accepted for interface compatibility.
        """
        # Broadcast a scalar timestep across the batch.
        if t.ndim == 0:
            t = t.expand(x.shape[0])

        if pre_encoded_context is not None:
            context = pre_encoded_context
        elif cond is not None:
            context = self.encode_context(cond)
        else:
            # Unconditional path: one all-zero context token per sample.
            context = x.new_zeros(x.shape[0], 1, self.hidden_dim)

        t_vec = self.time_mlp(self.time_emb(t))

        h = self.input_proj(x)
        for blk in self.blocks:
            h = blk(h, t_vec, context)

        return self.output_layer(self.final_norm(h))
|
|