| """ |
| Models for T10 Triplet Next-Action Prediction. |
| |
| Two classes live here: |
| |
| * TripletHead — shared head module producing (verb_fine, verb_composite, |
| noun, hand) logits from a pooled feature vector. |
| * DeepConvLSTMTriplet — single-flow CNN+LSTM baseline (concatenates all |
| available modalities along the feature axis). |
| * DailyActFormer — our full-modality cross-modal Transformer that keeps |
| each modality in its own stem, fuses via a modality |
| token, and runs a causal temporal Transformer. Supports |
| the anticipatory auxiliary loss mentioned in the paper |
| plan (currently as a stub; enabled later in training). |
| |
| All models take: |
| x: dict[mod_name -> (B, T, F_mod)] |
| mask: BoolTensor (B, T) |
| and return a dict: |
| {'verb_fine': (B, NUM_VERB_FINE), |
| 'verb_composite': (B, NUM_VERB_COMPOSITE), |
| 'noun': (B, NUM_NOUN), |
| 'hand': (B, NUM_HAND)} |
| """ |
|
|
from __future__ import annotations

import sys
from pathlib import Path
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# Make both `experiments.taxonomy` and bare `taxonomy` importable regardless
# of whether this file is run as a module or as a script.
_THIS = Path(__file__).resolve()
sys.path.insert(0, str(_THIS.parent))
sys.path.insert(0, str(_THIS.parent.parent))

try:
    from experiments.taxonomy import (
        NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN, NUM_HAND,
    )
except ModuleNotFoundError:
    from taxonomy import (
        NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN, NUM_HAND,
    )
|
|
|
|
class _PrevActionConcat(nn.Module):
    """Embeds the previous-segment (verb_composite, noun) ground-truth labels
    and concatenates them to a pooled feature vector. Used by every model
    when `use_prev_action=True`. The +1 vocab slot is the BOS / no-prev
    sentinel emitted by the dataset for the first kept segment of each
    recording. Output dim added to pooled = 2 * prev_emb_dim."""

    def __init__(self, prev_emb_dim: int = 32):
        super().__init__()
        self.vc_emb = nn.Embedding(NUM_VERB_COMPOSITE + 1, prev_emb_dim)
        self.n_emb = nn.Embedding(NUM_NOUN + 1, prev_emb_dim)
        self.out_dim = 2 * prev_emb_dim

    def forward(self, pooled: torch.Tensor,
                prev_v_comp: Optional[torch.Tensor] = None,
                prev_noun: Optional[torch.Tensor] = None) -> torch.Tensor:
        if prev_v_comp is None or prev_noun is None:
            # No labels provided: fall back to the BOS / no-prev sentinel
            # (the last embedding index) for the whole batch.
            B = pooled.size(0)
            prev_v_comp = torch.full((B,), self.vc_emb.num_embeddings - 1,
                                     dtype=torch.long, device=pooled.device)
            prev_noun = torch.full((B,), self.n_emb.num_embeddings - 1,
                                   dtype=torch.long, device=pooled.device)
        pe = torch.cat([self.vc_emb(prev_v_comp), self.n_emb(prev_noun)], dim=-1)
        return torch.cat([pooled, pe], dim=-1)
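
# Usage sketch for `_PrevActionConcat` (illustrative shapes; not executed at
# import time):
#
#     pc = _PrevActionConcat(prev_emb_dim=32)
#     pooled = torch.randn(4, 128)               # (B, D)
#     pc(pooled).shape                           # -> (4, 192); BOS sentinel used
#     labels = torch.zeros(4, dtype=torch.long)
#     pc(pooled, labels, labels).shape           # -> (4, 192)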
|
|
|
|
class TripletHead(nn.Module):
    """Shared head: LayerNorm + MLP trunk feeding the four linear classifiers."""

    def __init__(self, feat_dim: int, hidden: int = 256, dropout: float = 0.2):
        super().__init__()
        self.norm = nn.LayerNorm(feat_dim)
        self.trunk = nn.Sequential(
            nn.Linear(feat_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.verb_fine = nn.Linear(hidden, NUM_VERB_FINE)
        self.verb_composite = nn.Linear(hidden, NUM_VERB_COMPOSITE)
        self.noun = nn.Linear(hidden, NUM_NOUN)
        self.hand = nn.Linear(hidden, NUM_HAND)

    def forward(self, feat: torch.Tensor) -> Dict[str, torch.Tensor]:
        h = self.trunk(self.norm(feat))
        return {
            "verb_fine": self.verb_fine(h),
            "verb_composite": self.verb_composite(h),
            "noun": self.noun(h),
            "hand": self.hand(h),
        }
|
|
|
|
def _masked_mean_pool(h: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean over the time axis of `h` (B, T, D) using a boolean mask (B, T).

    `mask` follows the dataset convention: True = valid frame. (Attention
    layers below use the inverted `~mask` as `key_padding_mask`, where
    True = padding.)
    """
    m = mask.to(h.dtype).unsqueeze(-1)
    return (h * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
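
# Quick illustration of the mask convention: with the second frame masked
# out, only the first frame contributes to the mean.
#
#     h = torch.tensor([[[2.0], [4.0]]])         # (B=1, T=2, D=1)
#     m = torch.tensor([[True, False]])
#     _masked_mean_pool(h, m)                    # -> tensor([[2.]])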
|
|
|
|
|
|
class DeepConvLSTMTriplet(nn.Module):
    """Single-flow CNN+LSTM. Concatenates per-modality features on the F axis."""

    def __init__(
        self,
        modality_dims: Dict[str, int],
        conv_filters: int = 64,
        conv_kernel: int = 5,
        num_conv_layers: int = 4,
        lstm_hidden: int = 128,
        num_lstm_layers: int = 2,
        dropout: float = 0.2,
        head_hidden: int = 256,
        use_prev_action: bool = False,
        prev_emb_dim: int = 32,
    ):
        super().__init__()
        self.modality_dims = dict(modality_dims)
        self.use_prev_action = use_prev_action
        in_ch = sum(modality_dims.values())

        convs: List[nn.Module] = []
        c = in_ch
        for i in range(num_conv_layers):
            convs.append(nn.Sequential(
                nn.Conv1d(c, conv_filters, conv_kernel, padding=conv_kernel // 2),
                nn.BatchNorm1d(conv_filters),
                nn.ReLU(),
                # Slightly stronger dropout on the last conv layer.
                nn.Dropout(dropout if i < num_conv_layers - 1 else dropout + 0.1),
            ))
            c = conv_filters
        self.convs = nn.Sequential(*convs)

        self.lstm = nn.LSTM(
            conv_filters, lstm_hidden, num_layers=num_lstm_layers,
            batch_first=True, bidirectional=False,
            dropout=dropout if num_lstm_layers > 1 else 0.0,
        )
        head_in = lstm_hidden
        if use_prev_action:
            self.prev_concat = _PrevActionConcat(prev_emb_dim)
            head_in += self.prev_concat.out_dim
        else:
            self.prev_concat = None
        self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)

    def forward(
        self, x: Dict[str, torch.Tensor], mask: torch.Tensor,
        prev_v_comp: Optional[torch.Tensor] = None,
        prev_noun: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        # (B, T, sum F_mod) -> (B, C, T) for Conv1d. NOTE: `mask` is accepted
        # for API parity but unused; conv/LSTM consume the padded window as-is.
        feats = torch.cat([x[m] for m in x], dim=-1).transpose(1, 2)
        feats = self.convs(feats).transpose(1, 2)
        # Pool with the final hidden state of the top LSTM layer.
        out, (h_n, _) = self.lstm(feats)
        pooled = h_n[-1]
        if self.use_prev_action:
            pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
        return self.head(pooled)
|
|
|
|
|
|
class _ModalityStem(nn.Module):
    """Multi-scale 1-D conv stem (kernels 3, 5, 9) per modality.

    Borrowed from HandFormer (the top-1 baseline on T10 recognition): three
    parallel convolutions capture fast (k=3, ~0.15 s @ 20 Hz), medium (k=5),
    and slow (k=9, ~0.45 s) temporal patterns. The three branches are fused
    by a 1x1 convolution that projects back to d_model.
    """

    def __init__(self, in_dim: int, d_model: int, kernels=(3, 5, 9),
                 dropout: float = 0.1):
        super().__init__()
        self.kernels = kernels
        self.branches = nn.ModuleList([
            nn.Conv1d(in_dim, d_model, k, padding=k // 2) for k in kernels
        ])
        self.merge = nn.Sequential(
            nn.GELU(),
            nn.Conv1d(d_model * len(kernels), d_model, 1),
        )
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, F) -> (B, F, T) for Conv1d.
        z = x.transpose(1, 2)
        multi = [c(z) for c in self.branches]
        h = self.merge(torch.cat(multi, dim=1)).transpose(1, 2)
        return self.drop(self.norm(h))
|
|
|
|
class _QueryPool(nn.Module):
    """Learnable-query cross-attention pooling (replaces mean pooling).

    Inspired by FUTR (the top-5 baseline winner): a single learnable query
    cross-attends to the entire encoder output, producing one summary vector.
    Compared to a plain mean pool, this lets the model weight informative
    frames more heavily.
    """

    def __init__(self, d_model: int, n_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.q = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.trunc_normal_(self.q, std=0.02)
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True,
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, h: torch.Tensor, key_padding_mask: Optional[torch.Tensor]):
        # h: (B, T, D); key_padding_mask: (B, T), True = padding.
        B = h.size(0)
        q = self.q.expand(B, -1, -1)
        out, _ = self.attn(q, h, h, key_padding_mask=key_padding_mask,
                           need_weights=False)
        return self.norm(out.squeeze(1))
|
|
|
|
class _CrossModalTemporalShift(nn.Module):
    """Cross-modal temporal-shift attention between two modalities.

    Motivation (paper case study, §sec:grasp-phase-main): EMG activation leads
    motion onset by roughly 20 ms in our 100 Hz recordings. After the 5x
    downsample to 20 Hz, that lag is ~0.4 frames, but per-subject variability
    plus slack in our segment annotations introduces a few frames of drift
    that a fixed alignment cannot capture.

    We learn a discrete temporal shift Δ ∈ {-max_shift, …, +max_shift} frames
    applied to one of the two modalities (EMG by default), so the shifted
    tokens align with the other branch (MoCap) before cross-modal fusion. The
    shift is sampled via straight-through Gumbel-softmax during training; at
    inference we take the argmax (deterministic).

    Inputs are per-modality token sequences (B, T, D). Outputs the same shape.
    Only the `shift_modality` branch is shifted; other modalities pass through.
    """

    def __init__(self, max_shift: int = 3, tau: float = 1.0):
        super().__init__()
        self.max_shift = max_shift
        self.tau = tau
        # One logit per candidate shift Δ ∈ {-max_shift, ..., +max_shift}.
        self.shift_logits = nn.Parameter(torch.zeros(2 * max_shift + 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Select a shift: hard one-hot sample (straight-through) in training,
        # deterministic argmax at inference.
        if self.training:
            w = F.gumbel_softmax(self.shift_logits, tau=self.tau, hard=True, dim=-1)
        else:
            w = F.one_hot(self.shift_logits.argmax(),
                          num_classes=2 * self.max_shift + 1).float()
        # Weighted sum over all candidate rolls; with a hard one-hot `w` this
        # reduces to a single torch.roll (which wraps circularly at the
        # sequence boundaries).
        shifted = []
        for i, s in enumerate(range(-self.max_shift, self.max_shift + 1)):
            shifted.append(w[i] * torch.roll(x, shifts=s, dims=1))
        return torch.stack(shifted, dim=0).sum(dim=0)
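
# Illustrative check: in eval mode the module reduces to a single torch.roll
# by the argmax shift (with max_shift=3, logit index 1 corresponds to Δ = -2):
#
#     xs = _CrossModalTemporalShift(max_shift=3).eval()
#     with torch.no_grad():
#         xs.shift_logits[1] = 5.0
#         x = torch.randn(2, 10, 8)
#         assert torch.allclose(xs(x), torch.roll(x, shifts=-2, dims=1))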
|
|
|
|
class _CausalTransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block; causality is imposed via `attn_mask`."""

    def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0,
                 dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout,
                                          batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        mlp_dim = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, mlp_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mlp_dim, d_model), nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor,
                key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
        h = self.norm1(x)
        h, _ = self.attn(h, h, h, attn_mask=attn_mask,
                         key_padding_mask=key_padding_mask, need_weights=False)
        x = x + h
        x = x + self.mlp(self.norm2(x))
        return x
|
|
|
|
class DailyActFormer(nn.Module):
    """Cross-modal Transformer that uses every available modality.

    Architecture outline:
        per-modality stem → learnable modality embedding →
        stack across modalities (each frame -> M modality tokens) →
        one cross-modal fusion layer (attention compresses M → 1 per frame) →
        temporal Transformer (bidirectional by default; causal when
        `causal=True` for anticipation-style next-action prediction) →
        query pooling → TripletHead

    For simplicity, the fusion step is attention pooling with a learnable
    query rather than a full cross-modal block. This keeps the parameter
    count modest (2–4 M with d_model=128).
    """

    def __init__(
        self,
        modality_dims: Dict[str, int],
        d_model: int = 128,
        n_layers: int = 4,
        n_heads: int = 4,
        dropout: float = 0.1,
        head_hidden: int = 256,
        max_T: int = 256,
        causal: bool = False,
        xshift_modality: Optional[str] = "emg",
        xshift_max: int = 3,
        use_prev_action: bool = False,
        prev_emb_dim: int = 32,
    ):
        super().__init__()
        self.modalities = list(modality_dims.keys())
        self.causal = causal
        self.use_prev_action = use_prev_action

        # Optional previous-action conditioning.
        if use_prev_action:
            self.prev_concat = _PrevActionConcat(prev_emb_dim)
            self._prev_extra_dim = self.prev_concat.out_dim
        else:
            self.prev_concat = None
            self._prev_extra_dim = 0

        # Optional learned temporal shift for one modality (EMG by default),
        # applied to its stem tokens before fusion.
        if xshift_modality is not None and xshift_modality in modality_dims:
            self.xshift_modality = xshift_modality
            self.xshift = _CrossModalTemporalShift(max_shift=xshift_max)
        else:
            self.xshift_modality = None
            self.xshift = None

        # Per-modality multi-scale conv stems.
        self.stems = nn.ModuleDict({
            m: _ModalityStem(f_dim, d_model, dropout=dropout)
            for m, f_dim in modality_dims.items()
        })

        # Learnable modality embeddings, added to each stem's tokens.
        self.modality_embed = nn.Parameter(
            torch.zeros(len(self.modalities), d_model)
        )
        nn.init.trunc_normal_(self.modality_embed, std=0.02)

        # Cross-modal fusion: one learnable query attends over the M modality
        # tokens of each frame.
        self.fusion_q = nn.Parameter(torch.zeros(1, 1, d_model))
        self.fusion_kv = nn.LayerNorm(d_model)
        self.fusion_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

        # Learned absolute positional embedding over time.
        self.pos_embed = nn.Parameter(torch.zeros(1, max_T, d_model))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.max_T = max_T

        # Temporal Transformer.
        self.temporal_norm = nn.LayerNorm(d_model)
        self.temporal = nn.ModuleList([
            _CausalTransformerBlock(d_model, n_heads, dropout=dropout)
            for _ in range(n_layers)
        ])

        # Learnable-query pooling over time.
        self.pool = _QueryPool(d_model, n_heads=n_heads, dropout=dropout)

        # Classification head.
        head_in = d_model + self._prev_extra_dim
        self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)

        nn.init.trunc_normal_(self.fusion_q, std=0.02)

    def _causal_mask(self, T: int, device) -> torch.Tensor:
        # Additive mask: -inf strictly above the diagonal, so position t
        # attends only to positions <= t.
        m = torch.full((T, T), float("-inf"), device=device)
        m.triu_(diagonal=1)
        return m
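
    # For T = 3 the additive mask built above is:
    #
    #     [[0., -inf, -inf],
    #      [0.,   0., -inf],
    #      [0.,   0.,   0.]]
    #
    # i.e. frame t may attend only to frames <= t.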
|
|
    def forward(
        self, x: Dict[str, torch.Tensor], mask: torch.Tensor,
        prev_v_comp: Optional[torch.Tensor] = None,
        prev_noun: Optional[torch.Tensor] = None,
        return_features: bool = False,
    ) -> Dict[str, torch.Tensor]:
        # 1) Per-modality stems (+ optional temporal shift) + modality embedding.
        stem_tokens: List[torch.Tensor] = []
        mods_in = [m for m in self.modalities if m in x]
        if not mods_in:
            raise ValueError("No modality from the model signature was provided.")
        for m in mods_in:
            h = self.stems[m](x[m])
            # Apply the learned temporal shift before adding the modality
            # embedding, so only the shifted modality's content moves in time.
            if self.xshift is not None and m == self.xshift_modality:
                h = self.xshift(h)
            h = h + self.modality_embed[self.modalities.index(m)]
            stem_tokens.append(h)

        # 2) Cross-modal fusion: per frame, a learnable query attends over the
        #    M modality tokens and compresses them to a single fused token.
        B, T, D = stem_tokens[0].shape
        stacked = torch.stack(stem_tokens, dim=2)            # (B, T, M, D)
        M = stacked.size(2)
        stacked = stacked.reshape(B * T, M, D)
        kv = self.fusion_kv(stacked)
        q = self.fusion_q.expand(B * T, -1, -1)
        fused, _ = self.fusion_attn(q, kv, kv, need_weights=False)
        fused = fused.reshape(B, T, D)

        # 3) Temporal Transformer with positional embedding.
        if T > self.max_T:
            raise ValueError(f"T={T} exceeds max_T={self.max_T}")
        h = fused + self.pos_embed[:, :T, :]
        h = self.temporal_norm(h)

        attn_mask = self._causal_mask(T, h.device) if self.causal else None
        key_padding = ~mask if mask is not None else None
        for block in self.temporal:
            h = block(h, attn_mask=attn_mask, key_padding_mask=key_padding)

        # 4) Pool over time and classify.
        pooled = self.pool(h, key_padding_mask=key_padding)

        if self.use_prev_action:
            pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)

        logits = self.head(pooled)
        if return_features:
            logits["_pooled"] = pooled
        return logits
|
|
|
|
|
|
|
|
|
|
class _RULSTMBranch(nn.Module):
    """Single-modality Rolling-Unrolling LSTM branch (after RULSTM)."""

    def __init__(self, in_dim: int, hidden: int, future_steps: int,
                 dropout: float = 0.2):
        super().__init__()
        self.future_steps = future_steps
        self.rolling = nn.LSTM(in_dim, hidden, batch_first=True)
        self.unrolling = nn.LSTMCell(hidden, hidden)
        self.drop = nn.Dropout(dropout)
        self.out_dim = hidden

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # The "rolling" LSTM encodes the observed window; the "unrolling"
        # LSTMCell then steps `future_steps` into the future from its final
        # state. (`mask` is accepted for API parity but unused.)
        _, (h_n, c_n) = self.rolling(x)
        h = h_n.squeeze(0)
        c = c_n.squeeze(0)
        inp = h
        for _ in range(self.future_steps):
            h, c = self.unrolling(inp, (h, c))
            inp = h
        return self.drop(h)
|
|
|
|
class RULSTMTriplet(nn.Module):
    """RULSTM-style baseline: one rolling-unrolling branch per modality,
    fused by averaging the branch outputs."""

    def __init__(self, modality_dims: Dict[str, int], hidden: int = 128,
                 future_steps: int = 8, dropout: float = 0.2,
                 head_hidden: int = 256,
                 use_prev_action: bool = False, prev_emb_dim: int = 32):
        super().__init__()
        self.use_prev_action = use_prev_action
        self.branches = nn.ModuleDict({
            m: _RULSTMBranch(f_dim, hidden, future_steps, dropout)
            for m, f_dim in modality_dims.items()
        })
        head_in = hidden
        if use_prev_action:
            self.prev_concat = _PrevActionConcat(prev_emb_dim)
            head_in += self.prev_concat.out_dim
        else:
            self.prev_concat = None
        self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)

    def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
        feats = [self.branches[m](x[m], mask) for m in x]
        fused = torch.stack(feats, dim=0).mean(dim=0)
        if self.use_prev_action:
            fused = self.prev_concat(fused, prev_v_comp, prev_noun)
        return self.head(fused)
|
|
|
|
|
|
class FUTRTriplet(nn.Module):
    """FUTR-style baseline: Transformer encoder over concatenated modalities,
    pooled by a single learnable "future" query via cross-attention."""

    def __init__(self, modality_dims: Dict[str, int], d_model: int = 128,
                 n_heads: int = 4, n_layers: int = 3, dropout: float = 0.1,
                 head_hidden: int = 256, max_T: int = 256,
                 use_prev_action: bool = False, prev_emb_dim: int = 32):
        super().__init__()
        self.use_prev_action = use_prev_action
        in_dim = sum(modality_dims.values())
        self.in_proj = nn.Linear(in_dim, d_model)
        self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.max_T = max_T

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)

        self.future_q = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.trunc_normal_(self.future_q, std=0.02)
        self.cross_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True,
        )
        head_in = d_model
        if use_prev_action:
            self.prev_concat = _PrevActionConcat(prev_emb_dim)
            head_in += self.prev_concat.out_dim
        else:
            self.prev_concat = None
        self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)

    def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
        feats = torch.cat([x[m] for m in x], dim=-1)
        B, T, _ = feats.shape
        if T > self.max_T:
            raise ValueError(f"T={T} exceeds FUTR max_T={self.max_T}")
        h = self.in_proj(feats) + self.pos[:, :T, :]
        h = self.encoder(h, src_key_padding_mask=~mask)
        q = self.future_q.expand(B, -1, -1)
        out, _ = self.cross_attn(q, h, h, key_padding_mask=~mask,
                                 need_weights=False)
        pooled = out.squeeze(1)
        if self.use_prev_action:
            pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
        return self.head(pooled)
|
|
|
|
|
|
class AFFTTriplet(nn.Module):
    """AFFT-style baseline: per-modality linear stems, modality + positional
    embeddings, and a token-level Transformer over all T*M modality tokens
    with a frame-level causal mask."""

    def __init__(self, modality_dims: Dict[str, int], d_model: int = 96,
                 n_heads: int = 4, n_layers: int = 3, dropout: float = 0.1,
                 head_hidden: int = 256, max_T: int = 256,
                 use_prev_action: bool = False, prev_emb_dim: int = 32):
        super().__init__()
        self.use_prev_action = use_prev_action
        self.modalities = list(modality_dims.keys())
        self.stems = nn.ModuleDict({
            m: nn.Linear(f_dim, d_model) for m, f_dim in modality_dims.items()
        })
        self.mod_embed = nn.Parameter(
            torch.zeros(len(self.modalities), d_model)
        )
        nn.init.trunc_normal_(self.mod_embed, std=0.02)
        self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.max_T = max_T
        self.d_model = d_model

        self.blocks = nn.ModuleList([
            _CausalTransformerBlock(d_model, n_heads, dropout=dropout)
            for _ in range(n_layers)
        ])
        head_in = d_model
        if use_prev_action:
            self.prev_concat = _PrevActionConcat(prev_emb_dim)
            head_in += self.prev_concat.out_dim
        else:
            self.prev_concat = None
        self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)

    def _expand_causal_mask(self, T: int, M: int, device) -> torch.Tensor:
        # Token t*M + m belongs to frame t. Build a frame-level causal mask at
        # token granularity: True where the key's frame is strictly in the
        # future of the query's frame (same-frame tokens attend freely).
        ts = torch.arange(T, device=device).unsqueeze(1).expand(-1, M).reshape(-1)
        return ts[:, None] < ts[None, :]
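
    # Example (T=2, M=2): tokens are ordered [(t0,m0), (t0,m1), (t1,m0),
    # (t1,m1)], and the boolean mask marks keys from strictly future frames:
    #
    #     [[False, False,  True,  True],
    #      [False, False,  True,  True],
    #      [False, False, False, False],
    #      [False, False, False, False]]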
|
|
    def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
        # Stack per-modality tokens along a new modality axis, then flatten
        # frames and modalities into a single token sequence of length T*M.
        mods = [m for m in self.modalities if m in x]
        per_mod_tokens = []
        B, T, _ = x[mods[0]].shape
        for m in mods:
            h = self.stems[m](x[m]) + self.mod_embed[self.modalities.index(m)]
            per_mod_tokens.append(h)
        stacked = torch.stack(per_mod_tokens, dim=2)          # (B, T, M, D)
        M = stacked.size(2)
        tokens = stacked.reshape(B, T * M, self.d_model)
        if T > self.max_T:
            raise ValueError(f"T={T} exceeds AFFT max_T={self.max_T}")
        # Every token of a frame shares that frame's positional embedding.
        pos_per_frame = self.pos[:, :T, :].unsqueeze(2).expand(-1, -1, M, -1)
        tokens = tokens + pos_per_frame.reshape(1, T * M, self.d_model)
        # Convert the boolean frame-causal mask into an additive -inf mask.
        attn_mask = self._expand_causal_mask(T, M, tokens.device)
        attn_mask = torch.where(attn_mask,
                                torch.tensor(float("-inf"), device=tokens.device),
                                torch.tensor(0.0, device=tokens.device))
        # Expand the frame-level padding mask to token level.
        kp = (~mask).unsqueeze(2).expand(-1, -1, M).reshape(B, T * M)
        for blk in self.blocks:
            tokens = blk(tokens, attn_mask=attn_mask, key_padding_mask=kp)
        # Pool over the M modality tokens of the final frame.
        last_slice = tokens[:, -M:, :]
        pooled = last_slice.mean(dim=1)
        if self.use_prev_action:
            pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
        return self.head(pooled)
|
|
|
|
|
|
class HandFormerTriplet(nn.Module):
    """HandFormer-style baseline: multi-scale conv stem (kernels 3/5/9) over
    concatenated modalities, followed by a bidirectional Transformer encoder
    and masked mean pooling."""

    def __init__(self, modality_dims: Dict[str, int], d_model: int = 128,
                 n_heads: int = 4, n_layers: int = 3, kernels=(3, 5, 9),
                 dropout: float = 0.1, head_hidden: int = 256, max_T: int = 256,
                 use_prev_action: bool = False, prev_emb_dim: int = 32):
        super().__init__()
        self.use_prev_action = use_prev_action
        in_dim = sum(modality_dims.values())
        self.multi_conv = nn.ModuleList([
            nn.Conv1d(in_dim, d_model, k, padding=k // 2) for k in kernels
        ])
        self.conv_merge = nn.Conv1d(d_model * len(kernels), d_model, 1)

        self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.max_T = max_T

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        head_in = d_model
        if use_prev_action:
            self.prev_concat = _PrevActionConcat(prev_emb_dim)
            head_in += self.prev_concat.out_dim
        else:
            self.prev_concat = None
        self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)

    def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
        # (B, T, sum F_mod) -> (B, C, T) for the conv stem.
        feats = torch.cat([x[m] for m in x], dim=-1).transpose(1, 2)
        multi = [c(feats) for c in self.multi_conv]
        h = self.conv_merge(torch.cat(multi, dim=1))
        h = h.transpose(1, 2)
        T = h.size(1)
        if T > self.max_T:
            raise ValueError(f"T={T} exceeds HandFormer max_T={self.max_T}")
        h = h + self.pos[:, :T, :]
        h = self.encoder(h, src_key_padding_mask=~mask)
        pooled = _masked_mean_pool(h, mask)
        if self.use_prev_action:
            pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
        return self.head(pooled)
|
|
|
|
|
|
class ActionLLMSurrogate(nn.Module):
    """Lightweight surrogate for the ActionLLM baseline: a conv stem feeding
    a small Transformer encoder with masked mean pooling."""

    def __init__(self, modality_dims: Dict[str, int], d_model: int = 192,
                 n_heads: int = 6, n_layers: int = 2, dropout: float = 0.1,
                 head_hidden: int = 256, max_T: int = 256,
                 use_prev_action: bool = False, prev_emb_dim: int = 32):
        super().__init__()
        self.use_prev_action = use_prev_action
        in_dim = sum(modality_dims.values())
        self.stem = nn.Sequential(
            nn.Conv1d(in_dim, d_model, 5, padding=2),
            nn.GELU(),
            nn.Conv1d(d_model, d_model, 5, padding=2),
        )
        self.pos = nn.Parameter(torch.zeros(1, max_T, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        self.max_T = max_T
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4 * d_model,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        head_in = d_model
        if use_prev_action:
            self.prev_concat = _PrevActionConcat(prev_emb_dim)
            head_in += self.prev_concat.out_dim
        else:
            self.prev_concat = None
        self.head = TripletHead(head_in, hidden=head_hidden, dropout=dropout)

    def forward(self, x, mask, prev_v_comp=None, prev_noun=None):
        feats = torch.cat([x[m] for m in x], dim=-1).transpose(1, 2)
        h = self.stem(feats).transpose(1, 2)
        T = h.size(1)
        if T > self.max_T:
            raise ValueError(f"T={T} exceeds ActionLLM max_T={self.max_T}")
        h = h + self.pos[:, :T, :]
        h = self.encoder(h, src_key_padding_mask=~mask)
        pooled = _masked_mean_pool(h, mask)
        if self.use_prev_action:
            pooled = self.prev_concat(pooled, prev_v_comp, prev_noun)
        return self.head(pooled)
|
|
|
|
|
|
def build_model(
    name: str, modality_dims: Dict[str, int], **kwargs,
) -> nn.Module:
    """Factory mapping a (case-insensitive) model name or alias to an instance."""
    name = name.lower()
    if name in ("deepconvlstm", "dcl"):
        return DeepConvLSTMTriplet(modality_dims, **kwargs)
    if name in ("dailyactformer", "ours", "daf"):
        return DailyActFormer(modality_dims, **kwargs)
    if name in ("rulstm",):
        return RULSTMTriplet(modality_dims, **kwargs)
    if name in ("futr",):
        return FUTRTriplet(modality_dims, **kwargs)
    if name in ("afft",):
        return AFFTTriplet(modality_dims, **kwargs)
    if name in ("handformer",):
        return HandFormerTriplet(modality_dims, **kwargs)
    if name in ("actionllm",):
        return ActionLLMSurrogate(modality_dims, **kwargs)
    raise ValueError(f"Unknown model: {name}")
|
|
|
|
|
|
| if __name__ == "__main__": |
| B, T = 2, 160 |
| dims = {"imu": 180, "emg": 8, "eyetrack": 24} |
| x = {m: torch.randn(B, T, d) for m, d in dims.items()} |
| mask = torch.ones(B, T, dtype=torch.bool) |
|
|
| for name in ("deepconvlstm", "dailyactformer", "rulstm", "futr", "afft", |
| "handformer", "actionllm"): |
| model = build_model(name, dims) |
| n_params = sum(p.numel() for p in model.parameters()) |
| out = model(x, mask) |
| print(f"{name:16s} params={n_params:>10,} shapes=" |
| f"vf={tuple(out['verb_fine'].shape)} " |
| f"vc={tuple(out['verb_composite'].shape)} " |
| f"n={tuple(out['noun'].shape)} " |
| f"h={tuple(out['hand'].shape)}") |
|
|