| """ |
| SyncFuse — our proposed method for T1 scene recognition. |
| |
| Four components (all toggleable via args for ablation): |
| |
| (1) Modality dropout: per-sample independent Bernoulli(p=0.3) drop on each |
| modality during training; at test time all modalities |
| are active. Keeps at least 1 modality. |
| (2) Pretrained transfer: each per-modality backbone is optionally loaded from |
| an independently pretrained single-modality |
| checkpoint and frozen during fine-tuning. |
| (3) Cross-modal temporal-shift attention: |
| a late cross-attention block where EMG queries |
| attend to MoCap keys/values at a LEARNED temporal |
| offset Δ (Gumbel-softmax over {-10,...,+10} bins at |
| 20 Hz = ±500 ms). Motivated by the paper's case-study |
| finding (EMG leads motion by ~20 ms sub-frame). |
| (4) Learnable late fusion: |
| per-modality classifier logits are combined with a |
| learnable softmax-weighted average (temperature is |
| also learned). Equivalent to `late_agg='learned'` |
| in the repo's existing LateFusionModel. |
| """ |
import torch
import torch.nn as nn
import torch.nn.functional as F
import random


def masked_mean(x, mask):
    """Mean over time of the valid (mask=True) positions.

    x: (B, T, D) features; mask: (B, T) bool. The clamp avoids division by
    zero for fully padded sequences."""
    m = mask.unsqueeze(-1).float()
    return (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)


class ModTransformer(nn.Module):
    """Per-modality temporal encoder: linear projection, learned positional
    embedding, and a small Transformer encoder over the frame sequence."""

    def __init__(self, feat_dim, hidden=128, n_layers=2, n_heads=4, dropout=0.1):
        super().__init__()
        self.in_proj = nn.Linear(feat_dim, hidden)
        # Learned positional embedding; 4096 caps the supported sequence length.
        self.pos = nn.Parameter(torch.zeros(1, 4096, hidden))
        nn.init.trunc_normal_(self.pos, std=0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden, nhead=n_heads, dim_feedforward=4 * hidden,
            dropout=dropout, batch_first=True, activation='gelu',
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.output_dim = hidden

    def forward(self, x, mask):
        # x: (B, T, feat_dim); mask: (B, T) bool with True = valid frame.
        T = x.size(1)
        h = self.in_proj(x) + self.pos[:, :T, :]
        h = self.encoder(h, src_key_padding_mask=~mask)
        return h


class TemporalShiftAttention(nn.Module):
    """Multi-head cross-attention where the keys/values are temporally shifted by
    a LEARNED offset Δ relative to the queries. Δ is drawn from the discrete set
    {-max_shift, ..., +max_shift} (default ±3) via straight-through Gumbel-softmax:
    ONE shift is sampled per forward pass, and gradient flows back into
    shift_logits through the straight-through estimator.

    At 20 Hz bins, ±3 ≈ ±150 ms, which comfortably brackets the paper's ~20 ms
    EMG-motion lead. Memory cost is ~1 attention pass, not one per candidate
    shift."""

    def __init__(self, d_model, n_heads=4, dropout=0.1, max_shift=3,
                 gumbel_tau=1.0):
        super().__init__()
        self.max_shift = max_shift
        self.shifts = list(range(-max_shift, max_shift + 1))
        # One logit per candidate shift, learned jointly with the rest of the model.
        self.shift_logits = nn.Parameter(torch.zeros(len(self.shifts)))
        self.tau = gumbel_tau
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True,
        )
        self.norm = nn.LayerNorm(d_model)

    def _shift_tensor(self, x, shift, mask):
        """Shift x (B, T, D) and its bool mask (B, T) along time by `shift` bins.

        shift > 0 advances the key/value sequence; shift < 0 delays it.
        Positions shifted in from outside the clip are zero-filled and masked."""
        if shift == 0:
            return x, mask
        B, T, D = x.shape
        if shift > 0:
            pad = torch.zeros(B, shift, D, device=x.device, dtype=x.dtype)
            x_s = torch.cat([x[:, shift:, :], pad], dim=1)
            m_s = torch.cat([mask[:, shift:],
                             torch.zeros(B, shift, device=mask.device, dtype=torch.bool)],
                            dim=1)
        else:
            s = -shift
            pad = torch.zeros(B, s, D, device=x.device, dtype=x.dtype)
            x_s = torch.cat([pad, x[:, :-s, :]], dim=1)
            m_s = torch.cat([torch.zeros(B, s, device=mask.device, dtype=torch.bool),
                             mask[:, :-s]], dim=1)
        return x_s, m_s

    def forward(self, q_tokens, kv_tokens, q_mask, kv_mask, hard=False):
        # q_mask is accepted for interface symmetry; padded query positions are
        # handled downstream by masked_mean.
        if hard or not self.training:
            # Inference (or hard mode): use the argmax shift deterministically.
            with torch.no_grad():
                idx = self.shift_logits.argmax().item()
            shift = self.shifts[idx]
            shifted_kv, shifted_mask = self._shift_tensor(kv_tokens, shift, kv_mask)
            out, _ = self.attn(q_tokens, shifted_kv, shifted_kv,
                               key_padding_mask=~shifted_mask)
            return self.norm(q_tokens + out)

        # Training: sample ONE shift via straight-through Gumbel-softmax so that
        # gradient reaches shift_logits.
        one_hot = F.gumbel_softmax(self.shift_logits, tau=self.tau, hard=True)
        idx = int(one_hot.argmax().item())
        shift = self.shifts[idx]
        shifted_kv, shifted_mask = self._shift_tensor(kv_tokens, shift, kv_mask)
        out, _ = self.attn(q_tokens, shifted_kv, shifted_kv,
                           key_padding_mask=~shifted_mask)
        # Multiplying by the selected (hard) one-hot entry leaves the value
        # unchanged but routes gradient to that shift's logit.
        out = out * one_hot[idx]
        return self.norm(q_tokens + out)
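

# Illustrative sketch (not part of the model): a quick shape check of
# TemporalShiftAttention on toy tensors. The sizes below are assumptions made
# purely for illustration; real token dims come from ModTransformer. Call this
# manually, nothing here runs on import.
def _demo_temporal_shift_attention():
    torch.manual_seed(0)
    B, T, D = 2, 12, 32
    attn = TemporalShiftAttention(d_model=D, n_heads=4, max_shift=3)
    q = torch.randn(B, T, D)                       # e.g. EMG tokens
    kv = torch.randn(B, T, D)                      # e.g. MoCap tokens
    mask = torch.ones(B, T, dtype=torch.bool)
    attn.train()
    out_train = attn(q, kv, mask, mask)            # Gumbel-sampled shift
    attn.eval()
    out_eval = attn(q, kv, mask, mask, hard=True)  # deterministic argmax shift
    print(out_train.shape, out_eval.shape)         # both torch.Size([2, 12, 32])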
class SyncFuse(nn.Module):
    """Multimodal late-fusion model implementing the four components described
    in the module docstring."""

    def __init__(self, modality_dims: dict, num_classes, hidden=128, n_heads=4,
                 n_layers=2, dropout=0.1,
                 use_xmod_shift=True, use_learned_late=True):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = modality_dims
        self.use_xmod_shift = use_xmod_shift
        self.use_learned_late = use_learned_late

        # One Transformer branch and one classifier head per modality.
        self.branches = nn.ModuleDict({
            m: ModTransformer(d, hidden, n_layers, n_heads, dropout)
            for m, d in modality_dims.items()
        })
        self.classifiers = nn.ModuleDict({
            m: nn.Sequential(nn.LayerNorm(hidden), nn.Dropout(dropout),
                             nn.Linear(hidden, num_classes))
            for m in self.mod_names
        })

        # Component 3: cross-modal temporal-shift attention, built only when both
        # the EMG and MoCap branches are present.
        if use_xmod_shift and 'emg' in self.mod_names and 'mocap' in self.mod_names:
            self.xmod_emg2mocap = TemporalShiftAttention(hidden, n_heads, dropout)
            self.xmod_mocap2emg = TemporalShiftAttention(hidden, n_heads, dropout)
        else:
            self.xmod_emg2mocap = None
            self.xmod_mocap2emg = None

        # Component 4: learnable late-fusion weights and temperature.
        if use_learned_late:
            self.late_logits = nn.Parameter(torch.zeros(len(self.mod_names)))
            self.late_temperature = nn.Parameter(torch.ones(1))
    def load_pretrained(self, pretrain_paths: dict, freeze=True):
        """Load pretrained single-modality checkpoints into the branches.

        pretrain_paths: {modality_name: path_to_checkpoint_state_dict}."""
        for m, path in pretrain_paths.items():
            if m not in self.branches:
                continue
            try:
                sd = torch.load(path, weights_only=True, map_location='cpu')
            except TypeError:
                # Older torch versions do not accept the weights_only argument.
                sd = torch.load(path, map_location='cpu')
            # Keep only tensors saved under 'backbone.*' whose names match this branch.
            mapped = {}
            for k, v in sd.items():
                if k.startswith('backbone.'):
                    new_k = k[len('backbone.'):]
                    if new_k in self.branches[m].state_dict():
                        mapped[new_k] = v
            if mapped:
                self.branches[m].load_state_dict(mapped, strict=False)
                if freeze:
                    for p in self.branches[m].parameters():
                        p.requires_grad = False
            print(f" [SyncFuse] loaded {len(mapped)} tensors into branch '{m}' (frozen={freeze})")

    def forward(self, x, mask, mod_dropout_p=0.0, training_time=True):
        """
        x: (B, T, F_total) concatenated per-modality features
        mask: (B, T) bool, True = valid frame
        mod_dropout_p: probability of dropping each modality (training only)
        training_time: set False to disable modality dropout even in train mode
        Returns fused logits of shape (B, num_classes).
        """
        B, T, _ = x.shape

        # Split the concatenated features back into per-modality slices, in the
        # order given by modality_dims.
        offset = 0
        feats = {}
        for m in self.mod_names:
            d = self.mod_dims[m]
            feats[m] = x[..., offset:offset + d]
            offset += d

        # Component 1: per-sample modality dropout (training only).
        active = {m: torch.ones(B, dtype=torch.bool, device=x.device) for m in self.mod_names}
        if training_time and self.training and mod_dropout_p > 0:
            drop_map = {m: (torch.rand(B, device=x.device) < mod_dropout_p)
                        for m in self.mod_names}
            all_dropped = torch.stack([drop_map[m] for m in self.mod_names], dim=0).all(dim=0)
            if all_dropped.any():
                # Rescue samples where every modality was dropped: re-activate one
                # randomly chosen modality so each sample keeps at least one input.
                rescue_idx = torch.randint(0, len(self.mod_names),
                                           (all_dropped.sum().item(),),
                                           device=x.device)
                j = 0
                for b in range(B):
                    if all_dropped[b]:
                        r = self.mod_names[rescue_idx[j].item()]
                        drop_map[r][b] = False
                        j += 1
            for m in self.mod_names:
                active[m] = ~drop_map[m]
                # Zero out the features of dropped modalities.
                feats[m] = feats[m] * active[m].view(B, 1, 1).float()

        # Per-modality Transformer encoding.
        tokens = {}
        for m in self.mod_names:
            tokens[m] = self.branches[m](feats[m], mask)
        # Component 3: cross-modal temporal-shift attention. Both directions read
        # the pre-update tokens so neither attends to the other's refined output.
        if self.xmod_emg2mocap is not None:
            emg_in, mocap_in = tokens['emg'], tokens['mocap']
            tokens['emg'] = self.xmod_emg2mocap(
                emg_in, mocap_in, mask, mask,
                hard=not self.training,
            )
            tokens['mocap'] = self.xmod_mocap2emg(
                mocap_in, emg_in, mask, mask,
                hard=not self.training,
            )

        # Per-modality pooling and classification.
        logits_per = []
        for m in self.mod_names:
            pooled = masked_mean(tokens[m], mask)
            logits_per.append(self.classifiers[m](pooled))
        stacked = torch.stack(logits_per, dim=0)  # (n_modalities, B, num_classes)

        # Component 4: late fusion. Under modality dropout, dropped branches are
        # excluded from each sample's weighting.
        if training_time and self.training and mod_dropout_p > 0:
            act_mask = torch.stack([active[m].float() for m in self.mod_names], dim=0)  # (M, B)
            if self.use_learned_late:
                w = F.softmax(self.late_logits / self.late_temperature.clamp(min=0.1), dim=0)
                w = w.view(-1, 1) * act_mask
                w = w / w.sum(dim=0, keepdim=True).clamp(min=1e-6)
                out = (stacked * w.unsqueeze(-1)).sum(dim=0)
            else:
                w = act_mask / act_mask.sum(dim=0, keepdim=True).clamp(min=1e-6)
                out = (stacked * w.unsqueeze(-1)).sum(dim=0)
        else:
            # Evaluation (or no dropout): every modality contributes.
            if self.use_learned_late:
                w = F.softmax(self.late_logits / self.late_temperature.clamp(min=0.1), dim=0)
                out = (stacked * w.view(-1, 1, 1)).sum(dim=0)
            else:
                out = stacked.mean(dim=0)
        return out
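

# Minimal smoke test: the modality feature sizes below (8-channel EMG, 66-dim
# MoCap) and the sequence length are illustrative assumptions; the real values
# come from the dataset loaders elsewhere in the repo. This only checks shapes.
if __name__ == "__main__":
    torch.manual_seed(0)
    dims = {'emg': 8, 'mocap': 66}
    model = SyncFuse(dims, num_classes=5)
    B, T = 4, 100
    x = torch.randn(B, T, sum(dims.values()))
    mask = torch.ones(B, T, dtype=torch.bool)
    mask[:, 80:] = False                              # treat the last 20 frames as padding
    model.train()
    logits_train = model(x, mask, mod_dropout_p=0.3)  # modality dropout active
    model.eval()
    with torch.no_grad():
        logits_eval = model(x, mask)                  # all modalities, learned late fusion
    print(logits_train.shape, logits_eval.shape)      # both torch.Size([4, 5])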