| |
| """ |
| Training loop for T10 Triplet Next-Action Prediction. |
| |
| Usage example: |
| python3 experiments/train_seqpred.py \ |
| --model dailyactformer \ |
| --modalities imu,emg,eyetrack,mocap,pressure \ |
| --t_obs 8 --t_fut 2 \ |
| --epochs 40 --batch_size 32 --lr 3e-4 \ |
| --output_dir results/seqpred/ours_all5_tfut2_seed42 \ |
| --seed 42 |
| """ |
|
|
| from __future__ import annotations |
|
|
| |
| |
| import pandas |
|
|
| import argparse |
| import json |
| import os |
| import random |
| import sys |
| import time |
| from pathlib import Path |
| from typing import Dict |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| |
| |
| |
# Make the imports below resolvable both as `experiments.<mod>` (repo root on
# sys.path) and as bare `<mod>` (running this file from inside experiments/).
THIS = Path(__file__).resolve()
sys.path.insert(0, str(THIS.parent))
sys.path.insert(0, str(THIS.parents[1]))
|
|
# Dual-mode project imports: prefer the package form (`experiments.*`, works
# when the repo root is on sys.path); fall back to flat imports when the
# script is executed directly from the experiments/ directory.
try:
    from experiments.dataset_seqpred import (
        TripletSeqPredDataset, build_train_test, collate_triplet,
        TRAIN_VOLS_V3, TEST_VOLS_V3,
    )
    from experiments.models_seqpred import build_model
    from experiments.taxonomy import (
        NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN, NUM_HAND,
    )
except ModuleNotFoundError:
    from dataset_seqpred import (
        TripletSeqPredDataset, build_train_test, collate_triplet,
        TRAIN_VOLS_V3, TEST_VOLS_V3,
    )
    from models_seqpred import build_model
    from taxonomy import (
        NUM_VERB_FINE, NUM_VERB_COMPOSITE, NUM_NOUN, NUM_HAND,
    )
|
|
|
|
| |
| |
| |
|
|
def set_seed(seed: int) -> None:
    """Seed every RNG the training run depends on.

    Covers python's `random`, numpy, torch CPU, and all CUDA devices
    (the CUDA call is a no-op on CPU-only machines).
    """
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
|
|
|
|
def top_k_correct(logits: torch.Tensor, target: torch.Tensor, k: int) -> torch.Tensor:
    """Return a bool tensor (B,) indicating whether `target` is in top-k of logits.

    `k` is clamped to the number of classes so oversized k never raises.
    """
    k_eff = min(k, logits.size(1))
    top_idx = logits.topk(k_eff, dim=1).indices
    return top_idx.eq(target.view(-1, 1)).any(dim=1)
|
|
|
|
def mean_class_recall(logits: torch.Tensor, target: torch.Tensor,
                      num_classes: int) -> float:
    """Macro-averaged (mean per-class) recall.

    Classes absent from `target` are skipped so they neither help nor hurt
    the average; returns 0.0 when no class is present at all.
    """
    pred = logits.argmax(dim=1)
    recalls = []
    for cls in range(num_classes):
        mask = target.eq(cls)
        if not bool(mask.any()):
            continue  # class never appears in this split
        recalls.append(float(pred[mask].eq(cls).float().mean().item()))
    return float(np.mean(recalls)) if recalls else 0.0
|
|
|
|
def build_class_weights(counts: np.ndarray) -> torch.Tensor:
    """Inverse-frequency weights, normalized so mean weight = 1.

    Counts are floored at 1 so unseen classes get a finite (maximal) weight
    instead of dividing by zero.
    """
    safe = np.clip(counts.astype(np.float32), 1.0, None)
    inv = np.reciprocal(safe)
    return torch.from_numpy(inv / inv.mean())
|
|
|
|
| |
| |
| |
|
|
def triplet_loss(
    logits: Dict[str, torch.Tensor],
    y: Dict[str, torch.Tensor],
    weights: Dict[str, torch.Tensor],
    lambda_cfg: Dict[str, float],
    label_smoothing: float = 0.05,
) -> Dict[str, torch.Tensor]:
    """Label-smoothed cross-entropy per prediction head plus a lambda-weighted
    sum under the key "total".

    `weights` may supply optional per-class CE weights per head (moved to the
    logits' device lazily); heads missing from `lambda_cfg` default to 1.0.
    """
    heads = ("verb_fine", "verb_composite", "noun", "hand")
    out: Dict[str, torch.Tensor] = {}
    for head in heads:
        head_w = weights.get(head)
        if head_w is not None:
            head_w = head_w.to(logits[head].device)
        out[head] = F.cross_entropy(
            logits[head], y[head],
            weight=head_w, label_smoothing=label_smoothing,
        )
    out["total"] = sum(lambda_cfg.get(h, 1.0) * out[h] for h in heads)
    return out
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def evaluate(model, loader, device) -> Dict[str, float]:
    """Run `model` over `loader` and return a dict of test-set metrics.

    Per head (verb_fine, verb_composite, noun, hand): top-1 accuracy, mean
    class recall, and top-5 accuracy for heads with more than 5 classes.
    Also reports composed "action" metrics: verb_fine+noun both correct
    (action_vn_*) and verb_fine+noun+hand all correct (action_*).
    """
    model.eval()
    # Accumulate per-batch logits and labels on CPU so every metric is
    # computed once over the full test set rather than averaged per batch.
    all_logits: Dict[str, list] = {k: [] for k in
                                   ("verb_fine", "verb_composite", "noun", "hand")}
    all_y: Dict[str, list] = {k: [] for k in
                              ("verb_fine", "verb_composite", "noun", "hand")}

    for batch in loader:
        # The collate function may append previous-action labels as a 6th item.
        if len(batch) == 6:
            x, mask, lens, y, meta, prev = batch
        else:
            x, mask, lens, y, meta = batch
            prev = None
        x = {m: t.to(device) for m, t in x.items()}
        mask = mask.to(device)
        kwargs = {}
        # Only models that opted in (use_prev_action) receive the prev labels.
        if prev is not None and getattr(model, "use_prev_action", False):
            kwargs["prev_v_comp"] = prev["verb_composite"].to(device)
            kwargs["prev_noun"] = prev["noun"].to(device)
        logits = model(x, mask, **kwargs)
        for k in all_logits:
            all_logits[k].append(logits[k].cpu())
            all_y[k].append(y[k])

    logits_cat = {k: torch.cat(v, dim=0) for k, v in all_logits.items()}
    y_cat = {k: torch.cat(v, dim=0) for k, v in all_y.items()}

    # Per-head metrics.
    m = {}
    for k, K in [("verb_fine", NUM_VERB_FINE),
                 ("verb_composite", NUM_VERB_COMPOSITE),
                 ("noun", NUM_NOUN),
                 ("hand", NUM_HAND)]:
        preds = logits_cat[k].argmax(dim=1)
        acc1 = float((preds == y_cat[k]).float().mean().item())
        m[f"{k}_top1"] = acc1
        if K > 5:  # top-5 is meaningless for tiny label spaces (e.g. hand)
            acc5 = float(top_k_correct(logits_cat[k], y_cat[k], 5).float().mean().item())
            m[f"{k}_top5"] = acc5
        m[f"{k}_mcr"] = mean_class_recall(logits_cat[k], y_cat[k], K)

    # Composed action metrics, built from the per-head argmax predictions.
    vf_pred = logits_cat["verb_fine"].argmax(dim=1)
    n_pred = logits_cat["noun"].argmax(dim=1)
    h_pred = logits_cat["hand"].argmax(dim=1)

    # Action = verb_fine AND noun both correct (primary model-selection metric).
    vn_correct = (vf_pred == y_cat["verb_fine"]) & (n_pred == y_cat["noun"])
    m["action_vn_top1"] = float(vn_correct.float().mean().item())

    # Top-5 composed: both verb and noun must each appear in their own top-5.
    vf_top5 = top_k_correct(logits_cat["verb_fine"], y_cat["verb_fine"], 5)
    n_top5 = top_k_correct(logits_cat["noun"], y_cat["noun"], 5)
    m["action_vn_top5"] = float((vf_top5 & n_top5).float().mean().item())

    # Strictest ("legacy") action metric: verb + noun + hand all correct.
    vfn_h_correct = vn_correct & (h_pred == y_cat["hand"])
    m["action_top1"] = float(vfn_h_correct.float().mean().item())
    h_top1 = (h_pred == y_cat["hand"])
    m["action_top5"] = float((vf_top5 & n_top5 & h_top1).float().mean().item())
    return m
|
|
|
|
| |
| |
| |
|
|
def apply_modality_dropout(x: Dict[str, torch.Tensor], p: float) -> Dict[str, torch.Tensor]:
    """Per-sample per-modality dropout: zero out each (sample, modality) cell
    independently with probability p, but force-keep at least one modality
    per sample so the model never receives an all-zero input."""
    # Fast paths: dropout disabled, or nothing to drop against.
    if p <= 0.0:
        return x
    names = list(x.keys())
    if len(names) <= 1:
        return x
    ref = next(iter(x.values()))
    batch = ref.shape[0]
    dev = ref.device
    # Bernoulli keep mask per (sample, modality); then force one randomly
    # chosen modality per sample to survive. RNG call order (rand before
    # randint) matches the original implementation.
    keep = torch.rand(batch, len(names), device=dev) >= p
    forced = torch.randint(len(names), (batch,), device=dev)
    keep[torch.arange(batch, device=dev), forced] = True
    dropped: Dict[str, torch.Tensor] = {}
    for col, name in enumerate(names):
        t = x[name]
        gate = keep[:, col].to(t.dtype).view(batch, *([1] * (t.ndim - 1)))
        dropped[name] = t * gate
    return dropped
|
|
|
|
| |
| |
| |
|
|
def main():
    """CLI entry point: parse args, build train/test datasets and model,
    train with early stopping, and write config.json / model_best.pt /
    results.json into --output_dir.

    Model selection (and early stopping) is driven by composed
    verb_fine+noun top-1 accuracy on the test split ("action_vn_top1").
    """
    ap = argparse.ArgumentParser()
    # ---- task / data selection ----
    ap.add_argument("--model", type=str, default="deepconvlstm",
                    choices=["deepconvlstm", "dailyactformer",
                             "rulstm", "futr", "afft",
                             "handformer", "actionllm"])
    ap.add_argument("--modalities", type=str,
                    default="imu,emg,eyetrack,mocap,pressure")
    ap.add_argument("--t_obs", type=float, default=8.0,
                    help="Anticipation mode only: observation window length (s).")
    ap.add_argument("--t_fut", type=float, default=2.0,
                    help="Anticipation mode only: prediction horizon (s).")
    ap.add_argument("--mode", type=str, default="recognition",
                    choices=["recognition", "anticipation"],
                    help="recognition = classify segment from its own [start,end] sensor "
                         "window (default). anticipation = legacy T10 setup, predict from "
                         "[start-t_fut-t_obs, start-t_fut].")
    ap.add_argument("--downsample", type=int, default=5)

    # ---- optimization hyperparameters ----
    ap.add_argument("--epochs", type=int, default=40)
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight_decay", type=float, default=1e-4)
    ap.add_argument("--grad_clip", type=float, default=1.0)
    ap.add_argument("--label_smoothing", type=float, default=0.05)
    ap.add_argument("--dropout", type=float, default=0.1,
                    help="Dropout used inside DAF stems / transformer / pool.")
    ap.add_argument("--use_prev_action", action="store_true",
                    help="Condition DAF on previous-segment (verb_composite, noun) "
                         "labels via embedding concat to pooled features. Only DAF "
                         "uses this; baselines ignore it.")
    ap.add_argument("--modality_dropout", type=float, default=0.0,
                    help="Train-time per-sample per-modality dropout prob "
                         "(0.0=off). At least one modality is always kept.")

    # ---- loss shaping ----
    ap.add_argument("--use_class_weights", action="store_true",
                    help="Weight CE by inverse class frequency (better for tail).")
    ap.add_argument("--lambda_verb_fine", type=float, default=1.0)
    ap.add_argument("--lambda_verb_composite", type=float, default=0.5)
    ap.add_argument("--lambda_noun", type=float, default=1.0)
    ap.add_argument("--lambda_hand", type=float, default=0.5)

    # ---- schedule / bookkeeping ----
    ap.add_argument("--patience", type=int, default=12)
    ap.add_argument("--warmup_epochs", type=int, default=0,
                    help="Linear LR warmup over the first N epochs (0=off).")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--output_dir", type=str, required=True)
    ap.add_argument("--num_workers", type=int, default=0)
    ap.add_argument("--tag", type=str, default="")
    args = ap.parse_args()

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.mode == "anticipation":
        print(f"[cfg] model={args.model} modalities={args.modalities} "
              f"mode={args.mode} T_obs={args.t_obs}s T_fut={args.t_fut}s seed={args.seed}")
    else:
        print(f"[cfg] model={args.model} modalities={args.modalities} "
              f"mode={args.mode} (segment-aligned window) seed={args.seed}")
    print(f"[cfg] device={device} epochs={args.epochs} lr={args.lr} "
          f"batch_size={args.batch_size}")

    mods = tuple(args.modalities.split(","))
    train_ds, test_ds = build_train_test(
        modalities=mods, t_obs_sec=args.t_obs, t_fut_sec=args.t_fut,
        downsample=args.downsample, mode=args.mode,
    )
    print(f"[data] train={len(train_ds)} test={len(test_ds)} "
          f"modality_dims={train_ds.modality_dims}")

    # Optional inverse-frequency CE weights, computed on the train split only.
    counts = train_ds.class_counts()
    weights: Dict[str, torch.Tensor] = {}
    if args.use_class_weights:
        for k in ("verb_fine", "verb_composite", "noun", "hand"):
            weights[k] = build_class_weights(counts[k])

    # drop_last on train keeps batch statistics stable; test keeps all samples.
    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True,
        collate_fn=collate_triplet, num_workers=args.num_workers, drop_last=True,
    )
    test_loader = DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False,
        collate_fn=collate_triplet, num_workers=args.num_workers,
    )

    # Extra constructor kwargs are only honored by the DailyActFormer family;
    # NOTE(review): "ours"/"daf" aliases checked here are not in the argparse
    # choices above — confirm whether build_model accepts them directly.
    extra_kwargs = {}
    if args.model in ("dailyactformer", "ours", "daf"):
        extra_kwargs["causal"] = (args.mode == "anticipation")
        extra_kwargs["dropout"] = args.dropout
        extra_kwargs["use_prev_action"] = args.use_prev_action
    model = build_model(args.model, train_ds.modality_dims, **extra_kwargs).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[model] {args.model} params={n_params:,}")

    opt = torch.optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
    )
    # Optional linear warmup, then cosine decay down to 5% of the base LR.
    if args.warmup_epochs > 0:
        warmup = torch.optim.lr_scheduler.LinearLR(
            opt, start_factor=1.0 / max(1, args.warmup_epochs), end_factor=1.0,
            total_iters=args.warmup_epochs,
        )
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=max(1, args.epochs - args.warmup_epochs),
            eta_min=args.lr * 0.05,
        )
        sched = torch.optim.lr_scheduler.SequentialLR(
            opt, schedulers=[warmup, cosine], milestones=[args.warmup_epochs],
        )
    else:
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=args.epochs, eta_min=args.lr * 0.05,
        )

    # Per-head loss weights consumed by triplet_loss.
    lambda_cfg = {
        "verb_fine": args.lambda_verb_fine,
        "verb_composite": args.lambda_verb_composite,
        "noun": args.lambda_noun,
        "hand": args.lambda_hand,
    }

    # Resolve output directory (optionally suffixed by --tag) and persist config.
    out_dir = Path(args.output_dir)
    if args.tag:
        out_dir = out_dir.parent / f"{out_dir.name}_{args.tag}"
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / "config.json", "w") as f:
        # dict union (|) requires Python 3.9+.
        json.dump(vars(args) | {"n_params": n_params}, f, indent=2)

    # Best-epoch tracking: selection metric is action_vn_top1 on the test set.
    best = {"action_vn_top1": -1.0, "action_top1": -1.0}
    best_epoch = 0
    best_path = out_dir / "model_best.pt"
    patience = 0
    history = []

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        model.train()
        losses_epoch = {k: 0.0 for k in
                        ("verb_fine", "verb_composite", "noun", "hand", "total")}
        n_batches = 0
        for batch in train_loader:
            # Collate may append previous-action labels as a 6th element.
            if len(batch) == 6:
                x, mask, lens, y, meta, prev = batch
            else:
                x, mask, lens, y, meta = batch
                prev = None
            x = {m: t.to(device) for m, t in x.items()}
            mask = mask.to(device)
            y = {k: v.to(device) for k, v in y.items()}

            # Train-time robustness augmentation (off by default).
            if args.modality_dropout > 0.0:
                x = apply_modality_dropout(x, args.modality_dropout)

            kwargs = {}
            if prev is not None and getattr(model, "use_prev_action", False):
                kwargs["prev_v_comp"] = prev["verb_composite"].to(device)
                kwargs["prev_noun"] = prev["noun"].to(device)

            opt.zero_grad()
            logits = model(x, mask, **kwargs)
            l = triplet_loss(logits, y, weights, lambda_cfg,
                             label_smoothing=args.label_smoothing)
            l["total"].backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            opt.step()

            for k in losses_epoch:
                losses_epoch[k] += float(l[k].detach().item())
            n_batches += 1

        # Average accumulated losses over the epoch; scheduler steps per epoch.
        for k in losses_epoch:
            losses_epoch[k] /= max(1, n_batches)
        sched.step()

        metrics = evaluate(model, test_loader, device)
        dur = time.time() - t0

        print(
            f"  E{epoch:3d} loss={losses_epoch['total']:.3f} "
            f"(vf={losses_epoch['verb_fine']:.2f} "
            f"n={losses_epoch['noun']:.2f} "
            f"h={losses_epoch['hand']:.2f}) | "
            f"act_vn@1={metrics['action_vn_top1']:.3f} "
            f"vf@1={metrics['verb_fine_top1']:.3f} "
            f"n@1={metrics['noun_top1']:.3f} "
            f"h@1={metrics['hand_top1']:.3f} | "
            f"{dur:.1f}s",
            flush=True,
        )

        history.append({"epoch": epoch, **losses_epoch, **metrics})
        if metrics["action_vn_top1"] > best["action_vn_top1"]:
            best = dict(metrics)
            best_epoch = epoch
            patience = 0
            # Clone tensors to CPU so the checkpoint loads on any device.
            torch.save(
                {"state_dict": {k: v.cpu().clone()
                                for k, v in model.state_dict().items()},
                 "epoch": epoch,
                 "metrics": metrics},
                best_path,
            )
        else:
            patience += 1
            if patience >= args.patience:
                print(f"  early stop at epoch {epoch} (best epoch {best_epoch})")
                break

    # Persist full run summary (best metrics + per-epoch history).
    results = {
        "best_epoch": best_epoch,
        "best_test_metrics": best,
        "history": history,
        "n_params": n_params,
        "train_size": len(train_ds),
        "test_size": len(test_ds),
        "train_class_counts": {k: v.tolist() for k, v in counts.items()},
        "modality_dims": train_ds.modality_dims,
        "args": vars(args),
    }
    with open(out_dir / "results.json", "w") as f:
        json.dump(results, f, indent=2)
    print(f"\n[done] best action_vn@1 = {best['action_vn_top1']:.4f} "
          f"(legacy action@1 = {best['action_top1']:.4f}, epoch {best_epoch}) "
          f"saved to {out_dir}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|