| """Training loop shared across all four models. |
| |
| Differences across runs are entirely in the model registered under `config.model`. |
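
Typical usage (a sketch; the module path and config filename are illustrative):

    from physiojepa.train import load_yaml_config, train

    cfg = load_yaml_config("configs/model_f.yaml")
    train(cfg)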
| """ |
| from __future__ import annotations |
|
|
| import json |
| import math |
| import os |
| import time |
| from dataclasses import dataclass, field |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import yaml |
| from torch.utils.data import DataLoader |
|
|
| from .data import MIMICAlignedDataset, collate_with_dt, split_by_subject |
| from .ema import ema_tau |
| from .models import MODEL_REGISTRY, ModelConfig |
| from .monitor import CollapseMonitor, cross_modal_cosine, effective_rank |
|
|
|
|
@dataclass
class TrainConfig:
    run_name: str = "debug"
    model: str = "F"
    epochs: int = 100
    batch_size: int = 64
    lr: float = 1e-4
    weight_decay: float = 0.04
    warmup_epochs: int = 10
    ema_start: float = 0.996
    ema_end: float = 0.9999
    ema_warmup_frac: float = 0.30
    grad_clip: float = 1.0
    log_every: int = 100
    ckpt_every_epochs: int = 5
    seed: int = 0
    wandb_project: str = "physiojepa"
    wandb_mode: str = "online"
    wandb_entity: str | None = None
    output_dir: str = "runs"
    index_path: str = "cache/mimic_index.json"
    shard_roots: list[str] = field(default_factory=list)
    num_workers: int = 4
    amp: bool = True

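    # Passed to collate_with_dt: fraction of dt targets drawn log-uniformly.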
    log_uniform_frac: float = 0.6

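    # Fraction of training windows to keep (applied to the train split only).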
    subset_frac: float = 1.0

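    # Predictor / masking hyper-parameters forwarded to ModelConfig.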
    pred_depth: int = 4
    query_mode: str = "learned"
    mask_ratio: float = 0.50

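    # If non-empty, train from a pre-built window cache (MIMICFastDataset) instead of shards.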
    fast_cache_dir: str = ""


def load_yaml_config(path: str) -> TrainConfig:
    with open(path, "r") as f:
        d = yaml.safe_load(f)
    # safe_load returns None for an empty file; fall back to all defaults.
    return TrainConfig(**(d or {}))
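
# Example YAML for load_yaml_config (a sketch; any TrainConfig field is a valid
# key, and unknown keys raise TypeError in the dataclass constructor):
#
#   run_name: f_baseline
#   model: F
#   epochs: 100
#   shard_roots: [data/shards/mimic]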


class _Collator:
    """Top-level callable so DataLoader workers can pickle it (closures and
    lambdas break when workers use the ``spawn`` start method)."""

    def __init__(self, log_uniform_frac: float, seed: int):
        self.log_uniform_frac = log_uniform_frac
        self.seed = seed
        self._rng = None

    def __call__(self, items):
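        # Create the RNG lazily, after the worker process starts, so each worker
        # (distinguished by PID) gets its own stream.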
        if self._rng is None:
            self._rng = np.random.default_rng(self.seed + os.getpid())
        return collate_with_dt(items, log_uniform_frac=self.log_uniform_frac, rng=self._rng)


def _build_dataloaders(cfg: TrainConfig) -> tuple[DataLoader, DataLoader, list[str]]:
    if cfg.fast_cache_dir:
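        # Imported lazily: the fast-path module is only needed when a cache dir is set.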
        from .data_fast import MIMICFastDataset
        cache_dir = Path(cfg.fast_cache_dir)
        meta = json.loads((cache_dir / "windows_meta.json").read_text())
        subjects = sorted(set(meta["subjects"]))
        train_subj, val_subj = split_by_subject(subjects, frac=0.9, seed=cfg.seed)
        train_ds = MIMICFastDataset(cache_dir, subjects_allow=train_subj)
        val_ds = MIMICFastDataset(cache_dir, subjects_allow=val_subj)
    else:
        shard_roots = [Path(p) for p in cfg.shard_roots]
        ds_full = MIMICAlignedDataset(
            shard_roots=shard_roots,
            index_path=Path(cfg.index_path),
            build_index=not Path(cfg.index_path).exists(),
        )
        subjects = sorted({r["subject_id"] for r in ds_full.index})
        train_subj, val_subj = split_by_subject(subjects, frac=0.9, seed=cfg.seed)
        train_ds = MIMICAlignedDataset(
            shard_roots, Path(cfg.index_path), build_index=False, subjects_allow=train_subj,
            subset_frac=cfg.subset_frac, subset_seed=cfg.seed,
        )
        val_ds = MIMICAlignedDataset(
            shard_roots, Path(cfg.index_path), build_index=False, subjects_allow=val_subj,
        )
    collate = _Collator(cfg.log_uniform_frac, cfg.seed)
    train_loader = DataLoader(
        train_ds, batch_size=cfg.batch_size, shuffle=True,
        num_workers=cfg.num_workers, collate_fn=collate, drop_last=True,
        persistent_workers=cfg.num_workers > 0,
    )
    val_loader = DataLoader(
        val_ds, batch_size=cfg.batch_size, shuffle=False,
        num_workers=max(cfg.num_workers, 1), collate_fn=collate, drop_last=False,
    )
    return train_loader, val_loader, subjects


def _cosine_lr(step: int, total_steps: int, base: float, warmup_steps: int) -> float:
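    """Linear warmup to ``base`` over ``warmup_steps``, then cosine decay to 0."""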
    if step < warmup_steps:
        return base * (step + 1) / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base * (1 + math.cos(math.pi * progress))


def train(cfg: TrainConfig) -> dict:
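    # Lazy import: the rest of the module stays importable without wandb installed.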
    import wandb

    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available()
                          else ("mps" if torch.backends.mps.is_available() else "cpu"))
    train_loader, val_loader, subjects = _build_dataloaders(cfg)
    print(f"[trainer] device={device} n_train_windows={len(train_loader.dataset)} "
          f"n_val_windows={len(val_loader.dataset)} subjects={len(subjects)}")

    mcfg = ModelConfig(
        pred_depth=cfg.pred_depth,
        query_mode=cfg.query_mode,
        mask_ratio=cfg.mask_ratio,
    )
    model = MODEL_REGISTRY[cfg.model](mcfg).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scaler = torch.amp.GradScaler(device.type) if cfg.amp and device.type == "cuda" else None

    total_steps = cfg.epochs * len(train_loader)
    warmup_steps = cfg.warmup_epochs * len(train_loader)

    wandb.init(project=cfg.wandb_project, name=cfg.run_name, config=cfg.__dict__,
               mode=cfg.wandb_mode, entity=cfg.wandb_entity)

    monitor = CollapseMonitor()
    step = 0
    out_root = Path(cfg.output_dir) / cfg.run_name
    out_root.mkdir(parents=True, exist_ok=True)
    aborted = False
    for epoch in range(cfg.epochs):
        model.train(True)
        for batch in train_loader:

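            # Move the tensor fields the model consumes onto the training device.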
            for k in ("ecg", "ppg", "dt_seconds", "ptt_ms"):
                if k in batch and isinstance(batch[k], torch.Tensor):
                    batch[k] = batch[k].to(device)

            lr_now = _cosine_lr(step, total_steps, cfg.lr, warmup_steps)
            for g in opt.param_groups:
                g["lr"] = lr_now
            opt.zero_grad(set_to_none=True)
            if scaler is not None:
                with torch.amp.autocast("cuda"):
                    out = model.step(batch)
                scaler.scale(out["loss"]).backward()
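                # Unscale before clipping so the threshold applies to true gradient magnitudes.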
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                scaler.step(opt)
                scaler.update()
            else:
                out = model.step(batch)
                out["loss"].backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                opt.step()

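            # EMA-update the target networks toward their online counterparts,
            # with momentum tau scheduled by ema_tau.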
            tau = ema_tau(step, total_steps, cfg.ema_start, cfg.ema_end, cfg.ema_warmup_frac)
            for online, tgt in model.targets():
                tgt.update(online, tau)

            if step % cfg.log_every == 0:
                metrics = {
                    "step": step, "epoch": epoch, "lr": lr_now, "tau": tau,
                    "loss": float(out["loss"].detach().item()),
                    "L_cross": float(out.get("L_cross", torch.tensor(0.0)).item()),
                    "L_self": float(out.get("L_self", torch.tensor(0.0)).item()),
                }
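                # Collapse diagnostics: latent variance, effective rank, and
                # predictor-target cosine.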
                z_e = out.get("z_ecg")
                if z_e is not None and z_e.shape[0] > 1:
                    metrics["ecg_latent_var"] = float(z_e.var(dim=0).mean().item())
                    metrics["ecg_eff_rank"] = effective_rank(z_e)
                z_p_pred = out.get("z_pred")
                z_p_tgt = out.get("z_ppg")
                if z_p_pred is not None and z_p_tgt is not None and z_p_pred.shape[0] > 1:
                    cosine = cross_modal_cosine(z_p_pred, z_p_tgt)
                    metrics["cross_modal_cosine"] = cosine
                    if monitor.update(cosine):
                        print(f"[trainer] COLLAPSE DETECTED at step={step} cosine={cosine:.4f}")
                        aborted = True
                wandb.log(metrics, step=step)
                print(f"[step {step}] loss={metrics['loss']:.4f} "
                      f"L_cross={metrics['L_cross']:.4f} L_self={metrics['L_self']:.4f} "
                      f"tau={tau:.4f}")
            step += 1
            if aborted:
                break
        if aborted:
            break
        if (epoch + 1) % cfg.ckpt_every_epochs == 0 or epoch == cfg.epochs - 1:
            ckpt = out_root / f"ckpt_epoch{epoch + 1:03d}.pt"
            torch.save({"model": model.state_dict(), "cfg": cfg.__dict__, "epoch": epoch + 1,
                        "step": step}, ckpt)
            print(f"[trainer] saved {ckpt}")

    final_ckpt = out_root / "ckpt_final.pt"
    torch.save({"model": model.state_dict(), "cfg": cfg.__dict__, "aborted": aborted,
                "step": step}, final_ckpt)
    wandb.finish()
    return {"aborted": aborted, "final_step": step, "ckpt": str(final_ckpt)}