| """ |
| Training script for 2-stage fMRI encoding with Flow Matching. |
| Stage 1: Train MultiSubjectConvLinearEncoder (Mean Anchor) |
| Stage 2: Train Conditional Flow Matching (Neural Vector Field) per subject. |
| """ |
|
|
| import argparse |
| import json |
| import math |
| import sys |
| import time |
| from pathlib import Path |
| from typing import Dict, Any, Optional |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from omegaconf import DictConfig, OmegaConf |
| from torch.utils.data import DataLoader |
| from timm.utils import AverageMeter, random_seed |
|
|
| from .visualize import plot_loss_curve |
| from .data import ( |
| Algonauts2025Dataset, |
| load_algonauts2025_friends_fmri, |
| load_algonauts2025_movie10_fmri, |
| load_sharded_features, |
| episode_filter, |
| ) |
| from .stage1.medarc_architecture import MultiSubjectConvLinearEncoder |
| from .stage2.CFM import CFM |
| from .metric import pearsonr_score |
|
|
| |
| DEFAULT_DATA_DIR = Path("/raid/lttung05/fmri_encoder/data") |
| SUBJECTS = (1, 2, 3, 5) |
|
|
|
|
def load_features(cfg: DictConfig, model: str, layer: str) -> dict[str, np.ndarray]:
    """Load sharded features for both stimulus series and merge them.

    Returns a single mapping from episode key to feature array covering
    the Friends and Movie10 stimuli.
    """
    feature_root = Path(cfg.datasets_root or DEFAULT_DATA_DIR) / "features"
    merged: dict[str, np.ndarray] = {}
    for series in ("friends", "movie10"):
        merged.update(
            load_sharded_features(feature_root, model=model, layer=layer, series=series)
        )
    return merged
|
|
|
|
def pool_features(features: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Average-pool every 3-D feature array over its middle axis.

    2-D arrays pass through unchanged; any other rank is rejected.
    """

    def _pool(arr: np.ndarray) -> np.ndarray:
        assert arr.ndim in {2, 3}
        return arr.mean(axis=1) if arr.ndim == 3 else arr

    return {name: _pool(arr) for name, arr in features.items()}
|
|
|
|
def make_data_loaders(cfg: DictConfig) -> dict[str, DataLoader]:
    """Build one DataLoader per dataset split declared in ``cfg.datasets``.

    Loads fMRI responses for both series, loads each configured feature
    set (optionally average-pooled per the Stage 1 model config), then
    filters episodes per split and wraps them in Algonauts2025Dataset.

    Returns:
        Mapping of split name -> DataLoader (train uses cfg.batch_size,
        all other splits use batch size 1).
    """
    print("loading fmri data")

    data_dir = Path(cfg.datasets_root or DEFAULT_DATA_DIR)
    subjects = cfg.get("subjects", SUBJECTS)

    friends_fmri = load_algonauts2025_friends_fmri(
        data_dir / "algonauts_2025.competitors", subjects=subjects
    )
    movie10_fmri = load_algonauts2025_movie10_fmri(
        data_dir / "algonauts_2025.competitors", subjects=subjects
    )
    all_fmri = {**friends_fmri, **movie10_fmri}
    # Episode keys are shared between fMRI and feature dicts.
    all_episodes = list(all_fmri)

    all_features = []
    for feat_name in cfg.include_features:
        # Entries look like "<model>/<layer-alias>"; the alias is resolved
        # through cfg.features to concrete model/layer names.
        model, layer = feat_name.split("/")
        feat_cfg = cfg.features[model]
        model_name = feat_cfg.model
        layer_name = feat_cfg.layers[layer]
        print(f"loading features {feat_name} ({model_name}/{layer_name})")
        features = load_features(cfg, model_name, layer_name)

        # Collapse token/time axis early if Stage 1 uses global avg pooling.
        if cfg.stage1.model.global_pool == "avg":
            features = pool_features(features)

        all_features.append(features)

    data_loaders = {}

    for ds_name, ds_cfg in cfg.datasets.items():
        print(f"loading dataset: {ds_name}\n\n{OmegaConf.to_yaml(ds_cfg)}")

        # Copy before pop so the original cfg is not mutated; remaining
        # keys are forwarded straight to the Dataset constructor.
        ds_cfg = ds_cfg.copy()
        filter_cfg = ds_cfg.pop("filter")
        filter_fn = episode_filter(**filter_cfg)
        ds_episodes = list(filter(filter_fn, all_episodes))

        dataset = Algonauts2025Dataset(
            episode_list=ds_episodes,
            fmri_data=all_fmri,
            feat_data=all_features,
            **ds_cfg,
        )

        # Validation/test loaders use batch size 1 (evaluation code
        # assumes a leading batch dim of 1).
        batch_size = cfg.batch_size if ds_name == "train" else 1
        # NOTE(review): the train loader is not shuffled (shuffle defaults
        # to False) — confirm this is intentional for sequential fMRI data.
        loader = DataLoader(dataset, batch_size=batch_size)

        data_loaders[ds_name] = loader

    return data_loaders
|
|
|
|
def train_one_epoch_condition(
    *,
    epoch: int,
    model: torch.nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
):
    """Run one Stage 1 training epoch (MSE regression onto fMRI targets).

    Args:
        epoch: Epoch index, used only for logging.
        model: The multi-subject encoder being trained.
        train_loader: Yields batches with "features" (list of tensors)
            and "fmri" (target tensor).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device to run on.

    Returns:
        The sample-weighted mean training loss over the epoch.

    Raises:
        RuntimeError: If a NaN/Inf loss is encountered (fail fast rather
            than training through divergence).
    """
    model.train()

    use_cuda = device.type == "cuda"
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    # Hoisted loop-invariant: the original constructed nn.MSELoss() anew
    # on every step.
    criterion = nn.MSELoss()

    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()

    end = time.monotonic()

    for batch_idx, batch in enumerate(train_loader):
        feats = [f.to(device) for f in batch["features"]]
        fmri = batch["fmri"].to(device)

        batch_size = fmri.size(0)
        data_time = time.monotonic() - end

        pred = model(feats)

        loss = criterion(pred, fmri)
        loss_item = loss.item()

        if math.isnan(loss_item) or math.isinf(loss_item):
            raise RuntimeError(
                f"NaN/Inf loss encountered on step {batch_idx + 1}; exiting"
            )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if use_cuda:
            # Synchronize so step_time reflects completed GPU work.
            torch.cuda.synchronize()
        step_time = time.monotonic() - end

        loss_m.update(loss_item, batch_size)
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)

        # Periodic progress log with throughput and peak memory.
        if (batch_idx + 1) % 20 == 0:
            tput = batch_size / step_time_m.avg
            if use_cuda:
                alloc_mem_gb = torch.cuda.max_memory_allocated() / 1e9
                res_mem_gb = torch.cuda.max_memory_reserved() / 1e9
            else:
                alloc_mem_gb = res_mem_gb = 0.0

            print(
                f"Stage 1 Train: {epoch:>3d} [{batch_idx:>3d}]"
                f" Loss: {loss_m.val:#.3g} ({loss_m.avg:#.3g})"
                f" Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
                f" Mem: {alloc_mem_gb:.2f},{res_mem_gb:.2f} GB"
            )

        end = time.monotonic()

    return loss_m.avg
|
|
|
|
def train_one_epoch_flow_matching(
    *,
    epoch: int,
    stage1_model: torch.nn.Module,
    stage2_models: nn.ModuleDict,
    train_loader: DataLoader,
    optimizers: Dict[str, torch.optim.Optimizer],
    device: torch.device,
    subjects: list,
):
    """Run one Stage 2 epoch: train each subject's CFM on Stage 1 anchors.

    The frozen Stage 1 encoder produces a mean anchor; each per-subject
    CFM is trained (with its own optimizer) to flow from that anchor to
    the subject's true fMRI response.

    Returns:
        The sample-weighted mean of the subject-averaged per-batch loss.

    Raises:
        RuntimeError: If a NaN/Inf loss is encountered.
    """
    stage1_model.eval()
    for model in stage2_models.values():
        model.train()

    use_cuda = device.type == "cuda"
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()

    end = time.monotonic()

    for batch_idx, batch in enumerate(train_loader):
        feats = [f.to(device) for f in batch["features"]]
        fmri = batch["fmri"].to(device)
        batch_size = fmri.size(0)
        data_time = time.monotonic() - end

        # The anchor is only a conditioning signal here; no gradients
        # should flow back into the Stage 1 encoder.
        with torch.no_grad():
            mu_anchor = stage1_model(feats)

        batch_loss = 0

        # Each subject has an independent CFM and optimizer; step each one
        # on its own slice of the batch.
        for i, sub in enumerate(subjects):
            sub_key = str(sub)
            cfm = stage2_models[sub_key]
            optimizer = optimizers[sub_key]

            # Per-subject slice, transposed to channels-first for the CFM
            # (presumably (N, S, L, C) -> (N, C, L) — matches the
            # N, S, L, C unpacking in the evaluators).
            x1 = fmri[:, i].transpose(1, 2)
            mu = mu_anchor[:, i].transpose(1, 2)

            loss, _ = cfm.compute_loss(x1, mu)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss += loss.item()

        # Average across subjects so the logged loss is per-subject scale.
        loss_item = batch_loss / len(subjects)

        if math.isnan(loss_item) or math.isinf(loss_item):
            raise RuntimeError(
                f"NaN/Inf loss encountered on step {batch_idx + 1}; exiting"
            )

        if use_cuda:
            # Synchronize so step_time reflects completed GPU work.
            torch.cuda.synchronize()
        step_time = time.monotonic() - end

        loss_m.update(loss_item, fmri.size(0))
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)

        # Periodic progress log with throughput and peak memory.
        if (batch_idx + 1) % 20 == 0:
            tput = batch_size / step_time_m.avg
            if use_cuda:
                alloc_mem_gb = torch.cuda.max_memory_allocated() / 1e9
                res_mem_gb = torch.cuda.max_memory_reserved() / 1e9
            else:
                alloc_mem_gb = res_mem_gb = 0.0

            print(
                f"Stage 2 Train: {epoch:>3d} [{batch_idx:>3d}]"
                f" Loss: {loss_m.val:#.3g} ({loss_m.avg:#.3g})"
                f" Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
                f" Mem: {alloc_mem_gb:.2f},{res_mem_gb:.2f} GB"
            )

        end = time.monotonic()

    return loss_m.avg
|
|
|
|
@torch.no_grad()
def evaluate_stage1(
    *,
    epoch: int,
    model: torch.nn.Module,
    val_loader: DataLoader,
    device: torch.device,
    subjects: list,
    ds_name: str = "val",
):
    """Evaluate the Stage 1 encoder on one validation split.

    Computes MSE loss plus per-subject Pearson correlation ("accuracy")
    between predicted and measured fMRI, and logs a one-line summary.

    Returns:
        (acc, metrics): subject-averaged accuracy, and a dict with
        per-subject accuracy maps/means and the subject averages.
    """
    model.eval()

    criterion = nn.MSELoss()
    loss_m = AverageMeter()
    samples = []
    outputs = []

    for batch in val_loader:
        feats = [f.to(device) for f in batch["features"]]
        fmri = batch["fmri"].to(device)
        batch_size = fmri.size(0)

        pred = model(feats)
        loss_m.update(criterion(pred, fmri).item(), batch_size)

        N, S, L, C = fmri.shape
        # Bug fix: the original `assert N, S == (1, 4)` asserted on the
        # tuple (N, S == (1, 4)), which is always truthy for N > 0.
        # Check both dims properly and tie the subject axis to the
        # configured subject list rather than a hard-coded 4.
        assert N == 1 and S == len(subjects)

        # Reorder to subject-major: (N, S, L, C) -> (S, N*L, C).
        outputs.append(pred.cpu().numpy().swapaxes(0, 1).reshape((S, N * L, C)))
        samples.append(fmri.cpu().numpy().swapaxes(0, 1).reshape((S, N * L, C)))

    outputs = np.concatenate(outputs, axis=1)
    samples = np.concatenate(samples, axis=1)

    metrics = {}

    # Per-subject Pearson r per voxel (acc map) and its mean (acc),
    # plus subject averages.
    dim = samples.shape[-1]
    acc = 0.0
    acc_map = np.zeros(dim)
    for ii, sub in enumerate(subjects):
        y_true = samples[ii].reshape(-1, dim)
        y_pred = outputs[ii].reshape(-1, dim)
        metrics[f"accmap_sub-{sub}"] = acc_map_i = pearsonr_score(y_true, y_pred)
        metrics[f"acc_sub-{sub}"] = acc_i = np.mean(acc_map_i)
        acc_map += acc_map_i / len(subjects)
        acc += acc_i / len(subjects)

    metrics["accmap_avg"] = acc_map
    metrics["acc_avg"] = acc
    accs_fmt = ",".join(
        f"{val:.3f}" for key, val in metrics.items() if key.startswith("acc_sub-")
    )

    print(
        f"Evaluate Stage 1 ({ds_name}): {epoch:>3d}"
        f" Loss: {loss_m.avg:#.3g}"
        f" Acc: {accs_fmt} ({acc:.3f})"
    )

    return acc, metrics
|
|
|
|
@torch.no_grad()
def evaluate_stage2(
    *,
    epoch: int,
    stage1_model: torch.nn.Module,
    stage2_models: nn.ModuleDict,
    val_loader: DataLoader,
    device: torch.device,
    subjects: list,
    ds_name: str = "val",
    n_timesteps: int = 10,
):
    """Evaluate the full 2-stage pipeline on one validation split.

    Runs the frozen Stage 1 encoder to get anchors, samples each
    subject's CFM with ``n_timesteps`` integration steps, and scores the
    samples with per-subject Pearson correlation.

    Returns:
        (acc, metrics): subject-averaged accuracy, and a dict with
        per-subject accuracy maps/means and the subject averages.
    """
    stage1_model.eval()
    for model in stage2_models.values():
        model.eval()

    samples = []
    outputs = []

    for batch in val_loader:
        feats = [f.to(device) for f in batch["features"]]
        fmri = batch["fmri"].to(device)

        mu_anchor = stage1_model(feats)

        batch_preds = []
        for i, sub in enumerate(subjects):
            sub_key = str(sub)
            cfm = stage2_models[sub_key]

            # Per-subject anchor, channels-first for the CFM.
            mu = mu_anchor[:, i].transpose(1, 2)

            pred = cfm(mu, n_timesteps=n_timesteps)
            # Back to (N, 1, L, C) so subjects concatenate on dim 1.
            pred = pred.transpose(1, 2).unsqueeze(1)
            batch_preds.append(pred)

        pred_combined = torch.cat(batch_preds, dim=1)

        N, S, L, C = fmri.shape
        # Bug fix: the original `assert N, S == (1, 4)` asserted on the
        # tuple (N, S == (1, 4)), which is always truthy for N > 0.
        # Check both dims properly and tie the subject axis to the
        # configured subject list rather than a hard-coded 4.
        assert N == 1 and S == len(subjects)

        # Reorder to subject-major: (N, S, L, C) -> (S, N*L, C).
        outputs.append(
            pred_combined.cpu().numpy().swapaxes(0, 1).reshape((S, N * L, C))
        )
        samples.append(fmri.cpu().numpy().swapaxes(0, 1).reshape((S, N * L, C)))

    outputs = np.concatenate(outputs, axis=1)
    samples = np.concatenate(samples, axis=1)

    metrics = {}

    # Per-subject Pearson r per voxel (acc map) and its mean (acc),
    # plus subject averages.
    dim = samples.shape[-1]
    acc = 0.0
    acc_map = np.zeros(dim)
    for ii, sub in enumerate(subjects):
        y_true = samples[ii].reshape(-1, dim)
        y_pred = outputs[ii].reshape(-1, dim)
        metrics[f"accmap_sub-{sub}"] = acc_map_i = pearsonr_score(y_true, y_pred)
        metrics[f"acc_sub-{sub}"] = acc_i = np.mean(acc_map_i)
        acc_map += acc_map_i / len(subjects)
        acc += acc_i / len(subjects)

    metrics["accmap_avg"] = acc_map
    metrics["acc_avg"] = acc
    accs_fmt = ",".join(
        f"{val:.3f}" for key, val in metrics.items() if key.startswith("acc_sub-")
    )

    print(f"Evaluate Stage 2 ({ds_name}): {epoch:>3d}" f" Acc: {accs_fmt} ({acc:.3f})")

    return acc, metrics
|
|
|
|
def main():
    """Entry point: run the full 2-stage training pipeline.

    Stage 1 trains the shared multi-subject encoder, keeping the best
    checkpoint by validation Pearson r. Stage 2 freezes that encoder and
    trains one conditional flow-matching model per subject on top of it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg-path", type=str, default="config.yml")
    args = parser.parse_args()

    cfg = OmegaConf.load(args.cfg_path)
    print("Config loaded:\n", OmegaConf.to_yaml(cfg))

    out_dir = Path(cfg.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Persist the resolved config next to the outputs for reproducibility.
    OmegaConf.save(cfg, out_dir / "config.yaml")

    random_seed(cfg.seed)
    device = torch.device(cfg.device)

    # All non-"train" splits are used for evaluation.
    data_loaders = make_data_loaders(cfg)
    train_loader = data_loaders["train"]
    val_loaders = data_loaders.copy()
    val_loaders.pop("train")

    print("Creating Stage 1 Model (Encoder)...")

    # Infer per-feature input dims from a single sample batch.
    sample_batch = next(iter(train_loader))
    feat_dims = [f.shape[-1] for f in sample_batch["features"]]

    subjects_list = cfg.get("subjects", SUBJECTS)

    stage1_model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects_list),
        feat_dims=feat_dims,
        **cfg.stage1.model,
    ).to(device)

    optimizer1 = torch.optim.AdamW(
        stage1_model.parameters(),
        lr=cfg.stage1.lr,
        weight_decay=cfg.stage1.weight_decay,
    )

    print("--- Starting Stage 1 Training (Mean Anchor) ---")
    best_score_s1 = -1.0
    stage1_train_losses = []
    stage1_val_accs = []

    for epoch in range(cfg.stage1.epochs):
        train_loss = train_one_epoch_condition(
            epoch=epoch,
            model=stage1_model,
            train_loader=train_loader,
            optimizer=optimizer1,
            device=device,
        )
        stage1_train_losses.append(train_loss)

        # Evaluate on every split; only cfg.val_set_name drives model
        # selection.
        val_acc = None
        for name, loader in val_loaders.items():
            acc, _ = evaluate_stage1(
                epoch=epoch,
                model=stage1_model,
                val_loader=loader,
                device=device,
                subjects=subjects_list,
                ds_name=name,
            )
            if name == cfg.val_set_name:
                val_acc = acc

        stage1_val_accs.append(val_acc if val_acc is not None else 0.0)

        if val_acc is not None and val_acc > best_score_s1:
            best_score_s1 = val_acc
            torch.save(stage1_model.state_dict(), out_dir / "stage1_best.pt")
            print("Saved best Stage 1 model.")

    plot_loss_curve(
        stage1_train_losses,
        stage1_val_accs,
        out_dir,
        filename="stage1_loss_curve.png",
        prefix="Stage 1",
    )

    print(f"Stage 1 Training Complete. Best model at Pearson's r {best_score_s1}")

    # NOTE(review): this load fails if no epoch matched cfg.val_set_name
    # (stage1_best.pt would never have been written) — confirm
    # val_set_name always names one of the configured datasets.
    stage1_model.load_state_dict(torch.load(out_dir / "stage1_best.pt"))
    stage1_model.eval()

    # One CFM model (and optimizer) per subject, all sharing the same
    # hyperparameters from cfg.stage2.
    print("Creating Stage 2 Models (Flow Matching)...")
    stage2_models = nn.ModuleDict()
    optimizers2 = {}

    # fMRI target dimensionality drives the CFM feature size.
    target_dim = sample_batch["fmri"].shape[-1]

    cfm_params = cfg.stage2.cfm
    velocity_net_params = cfg.stage2.velocity_net
    source_ve_params = cfg.stage2.source_ve
    transport_params = cfg.stage2.transport

    for sub in subjects_list:
        sub_key = str(sub)
        cfm_model = CFM(
            feat_dim=target_dim,
            cfm_params=cfm_params,
            velocity_net_params=velocity_net_params,
            source_ve_params=source_ve_params,
            transport_params=transport_params,
        ).to(device)

        stage2_models[sub_key] = cfm_model
        optimizers2[sub_key] = torch.optim.AdamW(
            cfm_model.parameters(),
            lr=cfg.stage2.lr,
            weight_decay=cfg.stage2.weight_decay,
        )

    print("--- Starting Stage 2 Training (Vector Fields) ---")

    stage2_train_losses = []

    for epoch in range(cfg.stage2.epochs):
        train_loss = train_one_epoch_flow_matching(
            epoch=epoch,
            stage1_model=stage1_model,
            stage2_models=stage2_models,
            train_loader=train_loader,
            optimizers=optimizers2,
            device=device,
            subjects=subjects_list,
        )
        stage2_train_losses.append(train_loss)

        # Periodic checkpointing every 5 epochs plus the final epoch.
        if epoch % 5 == 0 or epoch == cfg.stage2.epochs - 1:
            ckpt_path = out_dir / f"stage2_epoch_{epoch}.pt"
            torch.save(stage2_models.state_dict(), ckpt_path)
            print(f"Saved Stage 2 checkpoint to {ckpt_path}")

    # Stage 2 is only evaluated once, after training completes (sampling
    # the CFMs is expensive).
    print("Evaluating final Stage 2 model...")
    for name, loader in val_loaders.items():
        evaluate_stage2(
            epoch=cfg.stage2.epochs,
            stage1_model=stage1_model,
            stage2_models=stage2_models,
            val_loader=loader,
            device=device,
            subjects=subjects_list,
            ds_name=name,
            n_timesteps=cfg.stage2.get("n_timesteps", 25),
        )

    plot_loss_curve(
        stage2_train_losses,
        out_path=out_dir,
        filename="stage2_loss_curve.png",
        prefix="Stage 2",
    )

    print("Done! All training complete.")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|