# --- Hosting-page artifacts (repo path / uploader / commit), kept as comments so the file parses ---
# PULSE-code / experiments / tasks / train_exp_missing.py
# velvet-pine-22's picture
# Upload folder using huggingface_hub
# b4b2877 verified
#!/usr/bin/env python3
"""
Experiment A: Missing-modality robustness for scene recognition (T1).
Train a late-fusion Transformer on all 5 modalities with random per-sample
modality dropout. At test time, systematically evaluate every modality subset
(single modalities, leave-one-out, and full set) by zeroing out the
slices of the concatenated input tensor that correspond to the dropped
modalities.
Reuses: experiments.dataset.get_dataloaders, experiments.models.build_model,
and the pretrained-backbone-transfer helper from train_exp1.py.
"""
import os
import sys
import json
import time
import random
import argparse
import itertools
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import get_dataloaders, NUM_CLASSES
from nets.models import build_model
from tasks.train_exp1 import (
set_seed, apply_augmentation, _load_and_freeze_backbone,
)
def modality_slices(modality_dims):
    """Return {mod_name: (start, end)} index ranges into the concatenated
    feature dimension.

    modality_dims: mapping {mod_name: feature_dim}. Its iteration (insertion)
        order must match the concatenation order used to build the input
        tensor's last dimension.
    Returns half-open feature-column index pairs (NOT byte offsets), suitable
    for slicing x[..., start:end].
    """
    slices = {}
    offset = 0
    for name, dim in modality_dims.items():
        slices[name] = (offset, offset + dim)
        offset += dim
    return slices
def mask_modalities(x, slices, active_mods):
    """Return a copy of x with the feature slices of inactive modalities zeroed.

    x: (B, T, F_total) tensor; slices: {mod_name: (start, end)} column ranges;
    active_mods: modalities to keep.
    If every modality is active, x itself is returned (no copy is made);
    otherwise a clone is returned and x is left untouched.
    """
    keep = set(active_mods)
    if keep == set(slices):
        return x
    out = x.clone()
    for mod, (lo, hi) in slices.items():
        if mod in keep:
            continue
        out[..., lo:hi] = 0.0
    return out
def train_one_epoch_with_dropout(model, loader, criterion, optimizer, device,
                                 slices, mod_dropout_p=0.0,
                                 augment=False, noise_std=0.1, time_mask_ratio=0.1):
    """Train one epoch. With probability mod_dropout_p, for each training sample
    independently drop a random non-empty subset of modalities.

    Strategy: for each sample, flip an independent Bernoulli(p) per modality;
    if ALL modalities would be dropped, keep one at random.

    Returns (mean_loss, accuracy) over all samples seen this epoch; (0.0, 0.0)
    if the loader is empty (mirrors evaluate_with_mask's empty-loader guard).
    """
    model.train()
    mods = list(slices.keys())
    total_loss = 0.0
    all_preds, all_labels = [], []
    for x, y, mask, _ in loader:
        x, y, mask = x.to(device), y.to(device), mask.to(device)
        if augment:
            x = apply_augmentation(x, mask, noise_std, time_mask_ratio)
        if mod_dropout_p > 0:
            # Clone before zeroing: x.to(device) is a no-op alias when the
            # batch is already on `device`, so in-place writes could otherwise
            # leak back into tensors owned by the dataset/dataloader.
            x = x.clone()
            for i in range(x.size(0)):
                dropped = [m for m in mods if random.random() < mod_dropout_p]
                # ensure at least one modality survives
                if len(dropped) == len(mods):
                    dropped = random.sample(dropped, len(dropped) - 1)
                for m in dropped:
                    s, e = slices[m]
                    x[i, :, s:e] = 0.0
        optimizer.zero_grad()
        logits = model(x, mask)
        loss = criterion(logits, y)
        loss.backward()
        # Clip only trainable params (frozen pretrained backbones are excluded)
        torch.nn.utils.clip_grad_norm_(
            [p for p in model.parameters() if p.requires_grad], 1.0
        )
        optimizer.step()
        total_loss += loss.item() * y.size(0)
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_labels.extend(y.cpu().numpy())
    n = len(all_labels)
    if n == 0:
        # Empty loader: avoid ZeroDivisionError, consistent with eval guard.
        return 0.0, 0.0
    return total_loss / n, accuracy_score(all_labels, all_preds)
@torch.no_grad()
def evaluate_with_mask(model, loader, criterion, device, slices, active_mods):
    """Evaluate `model` with only `active_mods` visible (other modality
    slices are zeroed via mask_modalities).

    Returns (mean_loss, accuracy, macro_f1, confusion_matrix); all-zero
    results and an empty confusion matrix when the loader yields nothing.
    """
    model.eval()
    loss_sum = 0.0
    y_hat, y_true = [], []
    for x, y, mask, _ in loader:
        x = x.to(device)
        y = y.to(device)
        mask = mask.to(device)
        masked = mask_modalities(x, slices, set(active_mods))
        logits = model(masked, mask)
        loss_sum += criterion(logits, y).item() * y.size(0)
        y_hat.extend(logits.argmax(dim=1).cpu().numpy())
        y_true.extend(y.cpu().numpy())
    if not y_true:
        return 0.0, 0.0, 0.0, np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int)
    n = len(y_true)
    return (
        loss_sum / n,
        accuracy_score(y_true, y_hat),
        f1_score(y_true, y_hat, average='macro', zero_division=0),
        confusion_matrix(y_true, y_hat, labels=list(range(NUM_CLASSES))),
    )
def run_experiment(args):
    """Train a fusion model with per-sample modality dropout, then measure
    robustness on every modality subset (full set, leave-one-out, singletons).

    args: argparse.Namespace produced by main() (see main for field list).
    Side effects: creates <output_dir>/<exp_name>/, writes model_best.pt and
    results.json there, and prints progress to stdout.
    Returns the results dict that is also serialized to results.json.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    modalities = args.modalities.split(',')
    print(f"Model: {args.model} | Fusion: {args.fusion} | Modalities: {modalities}")
    print(f"Training dropout p={args.mod_dropout_p}")
    train_loader, val_loader, test_loader, info = get_dataloaders(
        modalities, batch_size=args.batch_size, downsample=args.downsample
    )
    # No held-out validation split: fall back to the test loader for model
    # selection (early stopping then peeks at test — acceptable here but
    # worth noting when comparing results).
    if info['val_size'] == 0:
        val_loader = test_loader
    print(f"Train: {info['train_size']}, Test: {info['test_size']}")
    print(f"Feature dim: {info['feat_dim']}, Modality dims: {info['modality_dims']}")
    # (start, end) feature-column ranges per modality in the concatenated input
    slices = modality_slices(info['modality_dims'])
    print(f"Modality slices: {slices}")
    model = build_model(
        args.model, args.fusion, info['feat_dim'],
        info['modality_dims'], info['num_classes'],
        hidden_dim=args.hidden_dim, proj_dim=args.proj_dim,
        late_agg=args.late_agg,
    ).to(device)
    # Optional pretrained backbone loading (per-modality)
    if args.pretrained_dir:
        for i, mod in enumerate(modalities):
            # Checkpoint layout assumed: <dir>/transformer_<mod>_early/model_best.pt
            pt_path = os.path.join(args.pretrained_dir,
                                   f"transformer_{mod}_early", "model_best.pt")
            if os.path.exists(pt_path):
                _load_and_freeze_backbone(model, pt_path, i, args.fusion)
            else:
                # Missing checkpoints are tolerated: that branch trains from scratch
                print(f" WARN: no pretrained ckpt for {mod} at {pt_path}")
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Params: {trainable:,}/{total:,}")
    # Class-weighted CE to compensate for label imbalance (weights from the dataset)
    class_weights = info['class_weights'].to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights,
                                    label_smoothing=args.label_smoothing)
    # Optimize only unfrozen parameters (frozen pretrained backbones excluded)
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr, weight_decay=args.weight_decay,
    )
    # Halve LR after 7 epochs without val-loss improvement, floor at 1e-6
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=7, min_lr=1e-6,
    )
    mod_str = '-'.join(modalities)
    exp_name = f"{args.model}_{mod_str}_{args.fusion}_drop{args.mod_dropout_p}_seed{args.seed}"
    if args.tag:
        exp_name += f"_{args.tag}"
    out_dir = os.path.join(args.output_dir, exp_name)
    os.makedirs(out_dir, exist_ok=True)
    # Early stopping on validation loss; best checkpoint kept on disk
    best_val_loss = float('inf')
    best_epoch = 0
    patience_counter = 0
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_loss, train_acc = train_one_epoch_with_dropout(
            model, train_loader, criterion, optimizer, device,
            slices=slices, mod_dropout_p=args.mod_dropout_p,
            augment=args.augment,
        )
        # Validate on FULL modalities (baseline performance)
        val_loss, val_acc, val_f1, _ = evaluate_with_mask(
            model, val_loader, criterion, device, slices, modalities,
        )
        scheduler.step(val_loss)
        print(f" E{epoch:3d} | tr_loss {train_loss:.4f} tr_acc {train_acc:.4f} | "
              f"va_loss {val_loss:.4f} va_acc {val_acc:.4f} va_f1 {val_f1:.4f} | "
              f"{time.time()-t0:.1f}s")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            patience_counter = 0
            torch.save(model.state_dict(), os.path.join(out_dir, 'model_best.pt'))
        else:
            patience_counter += 1
            if patience_counter >= args.patience:
                print(f" Early stop at epoch {epoch} (best {best_epoch})")
                break
    # Restore best model
    model.load_state_dict(torch.load(os.path.join(out_dir, 'model_best.pt'),
                                     weights_only=True))
    # Systematic evaluation: full, leave-one-out, and all singletons
    print("\n=== Robustness Evaluation ===")
    eval_configs = []
    eval_configs.append(('full', modalities))
    for m in modalities:
        remaining = [x for x in modalities if x != m]
        eval_configs.append((f'drop_{m}', remaining))
    for m in modalities:
        eval_configs.append((f'only_{m}', [m]))
    results_matrix = {}
    for name, active in eval_configs:
        # Test-time masking zeroes inactive slices; no retraining per subset
        _, acc, f1, _ = evaluate_with_mask(
            model, test_loader, criterion, device, slices, active,
        )
        results_matrix[name] = {'active': active, 'acc': float(acc), 'f1': float(f1)}
        print(f" {name:<15s} mods={active} | acc {acc:.4f} f1 {f1:.4f}")
    results = {
        'experiment': exp_name,
        'training_dropout_p': args.mod_dropout_p,
        'seed': args.seed,
        'best_epoch': best_epoch,
        'eval_configs': results_matrix,
        'train_size': info['train_size'],
        'test_size': info['test_size'],
        'modality_dims': info['modality_dims'],
        'args': vars(args),
    }
    with open(os.path.join(out_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Saved: {out_dir}/results.json")
    return results
def main():
    """Build the CLI, parse arguments, and launch the experiment."""
    parser = argparse.ArgumentParser()
    # (flags, kwargs) table keeps the argument surface easy to scan; flags,
    # types, defaults and help strings are identical to the original CLI.
    arg_specs = [
        (['--model'], dict(type=str, default='transformer')),
        (['--modalities'], dict(type=str, default='mocap,emg,eyetrack,imu,pressure')),
        (['--fusion'], dict(type=str, default='late')),
        (['--late_agg'], dict(type=str, default='mean')),
        (['--mod_dropout_p'], dict(
            type=float, default=0.3,
            help='Per-modality independent dropout prob at training time')),
        (['--pretrained_dir'], dict(
            type=str, default='',
            help='Directory with pretrained single-modality ckpts')),
        (['--epochs'], dict(type=int, default=100)),
        (['--batch_size'], dict(type=int, default=16)),
        (['--lr'], dict(type=float, default=1e-3)),
        (['--weight_decay'], dict(type=float, default=1e-4)),
        (['--hidden_dim'], dict(type=int, default=128)),
        (['--proj_dim'], dict(type=int, default=0)),
        (['--downsample'], dict(type=int, default=5)),
        (['--patience'], dict(type=int, default=15)),
        (['--label_smoothing'], dict(type=float, default=0.1)),
        (['--augment'], dict(action='store_true')),
        (['--seed'], dict(type=int, default=42)),
        (['--output_dir'], dict(type=str, required=True)),
        (['--tag'], dict(type=str, default='')),
    ]
    for flags, kwargs in arg_specs:
        parser.add_argument(*flags, **kwargs)
    run_experiment(parser.parse_args())


if __name__ == '__main__':
    main()