PULSE-code / scripts /eval_topk_v3.py
velvet-pine-22's picture
Upload folder using huggingface_hub
b4b2877 verified
#!/usr/bin/env python3
"""Re-evaluate v3 saved models to compute action_vn@3 and action_vn@5.
Loads model_best.pt from each seed dir, runs test set, computes:
- action_vn_top1 / top3 / top5 (verb_fine top-K AND noun top-K)
- verb_fine_top1 / top3 / top5
- noun_top1 / top3 / top5
Writes results into <seed_dir>/eval_topk.json so the aggregator can pick them up.
"""
from __future__ import annotations
import json, sys, time
from pathlib import Path
import pandas as pd # noqa
import torch
from torch.utils.data import DataLoader
REPO = Path("${PULSE_ROOT}")
sys.path.insert(0, str(REPO / "experiments"))
from dataset_seqpred import build_train_test, collate_triplet # noqa
from models_seqpred import build_model # noqa
def topk_correct(logits, y, k):
    """Flag, per sample, whether the true label is among the ``k`` highest logits.

    Args:
        logits: score tensor whose dim 1 indexes classes (one row per sample).
        y: integer class labels, one per row of ``logits``.
        k: number of top-scoring classes to accept; a ``k`` larger than the
            number of classes is clamped to the number of classes.

    Returns:
        Bool tensor with one entry per row: True when ``y`` is in the top-k.
    """
    n_classes = logits.shape[1]
    top_idx = logits.topk(min(k, n_classes), dim=1).indices
    # Broadcast each label against its row of top-k class indices.
    return top_idx.eq(y.unsqueeze(1)).any(dim=1)
def find_v3_seed_dirs():
    """List seed dirs under table1_main_comparison/row*/seeds_v3{,_bidir,_sf}/seed*.

    A seed dir is kept only if it holds both ``model_best.pt`` and
    ``results.json``. Ordering: rows sorted, then variant in the fixed order
    below, then seed dirs sorted within each variant.
    """
    required_files = ("model_best.pt", "results.json")
    variants = ("seeds_v3", "seeds_v3_bidir", "seeds_v3_sf")
    found = []
    for row_dir in sorted((REPO / "table1_main_comparison").glob("row*")):
        for variant in variants:
            candidates = sorted((row_dir / variant).glob("seed*"))
            found.extend(
                cand for cand in candidates
                if all((cand / fname).exists() for fname in required_files)
            )
    return found
# Process-wide cache of test-set loaders keyed by (tuple of modality names, mode);
# each value is (test_loader, modality_dims). Seed dirs that share a data config
# reuse one DataLoader instead of calling build_train_test again.
_loader_cache: dict[tuple, tuple] = {}
# Model output heads evaluated here; this order fixes the key order in eval_topk.json.
_HEADS = ("verb_fine", "verb_composite", "noun", "hand")


def _get_test_loader(mods_list, mode):
    """Return (test_loader, modality_dims) for a data config, building it at most once."""
    key = (tuple(mods_list), mode)
    if key not in _loader_cache:
        print(f" [build loader] mode={mode} modalities={mods_list}", flush=True)
        train_ds, test_ds = build_train_test(modalities=mods_list, mode=mode)
        del train_ds  # only the test split is evaluated; drop the train split right away
        test_loader = DataLoader(test_ds, batch_size=64, shuffle=False,
                                 collate_fn=collate_triplet, num_workers=0)
        _loader_cache[key] = (test_loader, test_ds.modality_dims)
    return _loader_cache[key]


def _collect_logits(model, test_loader, device):
    """Run ``model`` over ``test_loader``; return (logits, labels) as dicts head -> concatenated CPU tensor."""
    all_logits = {k: [] for k in _HEADS}
    all_y = {k: [] for k in _HEADS}
    with torch.no_grad():
        for x, mask, _lens, y, _meta in test_loader:
            x = {m: t.to(device) for m, t in x.items()}
            logits = model(x, mask.to(device))
            for k in _HEADS:
                all_logits[k].append(logits[k].cpu())
                all_y[k].append(y[k])
    logits_cat = {k: torch.cat(v, dim=0) for k, v in all_logits.items()}
    y_cat = {k: torch.cat(v, dim=0) for k, v in all_y.items()}
    return logits_cat, y_cat


def _compute_metrics(logits_cat, y_cat):
    """Per-head top-1/3/5 accuracy plus joint action_vn top-1/3/5 as a flat dict of floats."""
    out = {}
    for k in _HEADS:
        preds_top1 = logits_cat[k].argmax(dim=1)
        out[f"{k}_top1"] = float((preds_top1 == y_cat[k]).float().mean())
        out[f"{k}_top3"] = float(topk_correct(logits_cat[k], y_cat[k], 3).float().mean())
        out[f"{k}_top5"] = float(topk_correct(logits_cat[k], y_cat[k], 5).float().mean())
    # Joint action_vn: a sample counts as correct at K when the verb_fine label is in
    # the verb top-K AND the noun label is in the noun top-K (each ranked independently).
    for K, lbl in [(1, "top1"), (3, "top3"), (5, "top5")]:
        verb_ok = topk_correct(logits_cat["verb_fine"], y_cat["verb_fine"], K)
        noun_ok = topk_correct(logits_cat["noun"], y_cat["noun"], K)
        out[f"action_vn_{lbl}"] = float((verb_ok & noun_ok).float().mean())
    return out


def main():
    """Re-evaluate every v3 seed checkpoint and write ``<seed_dir>/eval_topk.json``.

    For each seed dir, the training args stored in ``results.json`` are used to
    rebuild the matching test loader and model. ``model_best.pt`` is then loaded
    and top-K metrics are computed on the test set. A failure in one seed dir is
    reported with its traceback and counted, and the loop moves on to the next.
    """
    import traceback  # only needed to report per-seed failures

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"device={device}", flush=True)
    seed_dirs = find_v3_seed_dirs()
    print(f"Found {len(seed_dirs)} v3 seed dirs", flush=True)
    t0 = time.time()
    n_ok, n_fail = 0, 0
    for i, sd in enumerate(seed_dirs, 1):
        model = state = None
        try:
            with open(sd / "results.json") as f:
                results = json.load(f)
            args = results["args"]
            mods_list = args["modalities"].split(",")
            mode = args.get("mode", "anticipation")
            test_loader, modality_dims = _get_test_loader(mods_list, mode)
            extra = {}
            if args["model"] in ("dailyactformer", "ours", "daf"):
                # These models take a causal flag: causal only in anticipation mode.
                extra["causal"] = (mode == "anticipation")
            model = build_model(args["model"], modality_dims, **extra).to(device)
            # NOTE: weights_only=False fully unpickles the checkpoint; only load trusted files.
            state = torch.load(sd / "model_best.pt", map_location=device, weights_only=False)
            model.load_state_dict(state["state_dict"])
            state = None  # weights are copied into the model; free the device-resident copy
            model.eval()
            logits_cat, y_cat = _collect_logits(model, test_loader, device)
            out = _compute_metrics(logits_cat, y_cat)
            with open(sd / "eval_topk.json", "w") as f:
                json.dump(out, f, indent=2)
            n_ok += 1
            if i % 5 == 0 or i <= 3:
                rel = sd.relative_to(REPO)
                print(f" [{i:>3}/{len(seed_dirs)}] {rel} vn@1={out['action_vn_top1']:.4f} vn@3={out['action_vn_top3']:.4f} vn@5={out['action_vn_top5']:.4f}", flush=True)
        except Exception as e:
            # Batch boundary: record the failure with its traceback and keep going.
            n_fail += 1
            print(f" [{i:>3}/{len(seed_dirs)}] FAIL {sd.relative_to(REPO)}: {e}", flush=True)
            print(traceback.format_exc().rstrip(), flush=True)
        finally:
            # Release model/checkpoint memory on both success and failure paths.
            model = state = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    print(f"Done. ok={n_ok} fail={n_fail} elapsed={time.time()-t0:.1f}s", flush=True)
if __name__ == "__main__":
main()