"""Persist training metrics + loss/reward plots to disk. Why this exists: the hackathon submission asks for "evidence you actually trained — at minimum loss and reward plots from a real run." Since we run as a script (not a notebook), nothing renders automatically. This module: * Snapshots ``trainer.state.log_history`` every N steps via a TrainerCallback (so a crashed run still leaves partial evidence behind), and * Dumps a final set of artifacts (CSV, JSON, PNGs) after ``trainer.train()``. All artifacts land in the trainer's ``output_dir`` so they ride back to the Hugging Face Hub when ``push_to_hub=True``. """ from __future__ import annotations import csv import json import logging from pathlib import Path from typing import Any, Dict, Iterable, List, Optional logger = logging.getLogger(__name__) # Reward keys we track. TRL logs reward functions under "rewards/" # (and a single-scalar "reward" = sum of weighted rewards). PRIMARY_REWARD_KEY = "rewards/reward_total" PHASE_REWARD_KEYS = ( "rewards/reward_market", "rewards/reward_warehouse", "rewards/reward_showroom", ) LOSS_KEY = "loss" STEP_KEY = "step" def _flatten_log_history(log_history: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Make sure every row carries a `step` field even when TRL omits it on epoch logs.""" cleaned: List[Dict[str, Any]] = [] last_step = 0 for row in log_history: step = row.get("step", row.get("global_step", last_step)) last_step = step or last_step merged = {"step": last_step, **{k: v for k, v in row.items() if k != "step"}} cleaned.append(merged) return cleaned def _series(rows: List[Dict[str, Any]], key: str) -> List[tuple]: """Return ``[(step, value), ...]`` for the given metric key.""" out: List[tuple] = [] for r in rows: if key in r and r[key] is not None: try: out.append((int(r["step"]), float(r[key]))) except (TypeError, ValueError): continue return out def _save_csv(rows: List[Dict[str, Any]], path: Path) -> None: if not rows: return columns: List[str] = [] seen = set() for r in rows: for k in r.keys(): if k not in seen: seen.add(k) columns.append(k) with path.open("w", newline="") as f: writer = csv.DictWriter(f, fieldnames=columns) writer.writeheader() writer.writerows(rows) def _save_json(rows: List[Dict[str, Any]], path: Path) -> None: with path.open("w") as f: json.dump(rows, f, indent=2, default=str) def _try_plot( series: Iterable[tuple], title: str, ylabel: str, out_path: Path, *, label: Optional[str] = None, ) -> bool: """Draw a single-series line plot. Silently no-ops if matplotlib is missing.""" try: import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt except Exception as exc: logger.warning("matplotlib unavailable, skipping %s (%s)", out_path.name, exc) return False pts = list(series) if not pts: logger.warning("no data for %s, skipping plot", out_path.name) return False xs, ys = zip(*pts) fig, ax = plt.subplots(figsize=(8, 4.5)) ax.plot(xs, ys, marker="o", linewidth=1.5, label=label or ylabel) ax.set_xlabel("training step") ax.set_ylabel(ylabel) ax.set_title(title) ax.grid(True, alpha=0.3) if label: ax.legend(loc="best") fig.tight_layout() fig.savefig(out_path, dpi=120) plt.close(fig) return True def _try_plot_multi( name_to_series: Dict[str, Iterable[tuple]], title: str, ylabel: str, out_path: Path, ) -> bool: """Draw a multi-series line plot.""" try: import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt except Exception as exc: logger.warning("matplotlib unavailable, skipping %s (%s)", out_path.name, exc) return False fig, ax = plt.subplots(figsize=(8.5, 5)) drew_any = False for label, pts in name_to_series.items(): pts = list(pts) if not pts: continue xs, ys = zip(*pts) ax.plot(xs, ys, marker="o", linewidth=1.3, label=label) drew_any = True if not drew_any: plt.close(fig) logger.warning("no data for %s, skipping plot", out_path.name) return False ax.set_xlabel("training step") ax.set_ylabel(ylabel) ax.set_title(title) ax.grid(True, alpha=0.3) ax.legend(loc="best") fig.tight_layout() fig.savefig(out_path, dpi=120) plt.close(fig) return True def _summary_stats(series: List[tuple]) -> Dict[str, float]: if not series: return {"final": 0.0, "max": 0.0, "min": 0.0, "mean": 0.0, "n": 0} ys = [v for _, v in series] return { "final": float(ys[-1]), "max": float(max(ys)), "min": float(min(ys)), "mean": float(sum(ys) / len(ys)), "n": len(ys), } def save_training_artifacts( log_history: List[Dict[str, Any]], output_dir: str | Path, *, run_config: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Write metrics + loss/reward plots into ``output_dir``. Returns the summary dict that was also written to ``training_summary.json``. """ out = Path(output_dir) out.mkdir(parents=True, exist_ok=True) rows = _flatten_log_history(log_history) _save_csv(rows, out / "metrics.csv") _save_json(rows, out / "metrics.json") loss_series = _series(rows, LOSS_KEY) total_reward_series = _series(rows, PRIMARY_REWARD_KEY) # Some TRL versions log a flat "reward" scalar in addition. Prefer the # named primary; fall back to "reward" if the named one is empty. if not total_reward_series: total_reward_series = _series(rows, "reward") phase_series = { "market": _series(rows, "rewards/reward_market"), "warehouse": _series(rows, "rewards/reward_warehouse"), "showroom": _series(rows, "rewards/reward_showroom"), } _try_plot( loss_series, title="Training loss (GRPO)", ylabel="loss", out_path=out / "loss_curve.png", label="loss", ) _try_plot( total_reward_series, title="Reward (total) — env cumulative_reward in [0, 1]", ylabel="reward", out_path=out / "reward_total_curve.png", label="reward_total", ) _try_plot_multi( { "reward_total": total_reward_series, **{f"reward_{k}": v for k, v in phase_series.items()}, }, title="Rewards over training", ylabel="reward", out_path=out / "reward_curve.png", ) summary: Dict[str, Any] = { "loss": _summary_stats(loss_series), "reward_total": _summary_stats(total_reward_series), "reward_market": _summary_stats(phase_series["market"]), "reward_warehouse": _summary_stats(phase_series["warehouse"]), "reward_showroom": _summary_stats(phase_series["showroom"]), "n_log_rows": len(rows), "output_dir": str(out.resolve()), } if run_config is not None: summary["run_config"] = run_config with (out / "training_summary.json").open("w") as f: json.dump(summary, f, indent=2, default=str) logger.info("Wrote training artifacts to %s", out.resolve()) return summary def build_metrics_callback(output_dir: str | Path, snapshot_every: int = 5): """Return a TrainerCallback that snapshots metrics every N steps + on end. Imported lazily so this module can be inspected on a machine without transformers installed (e.g. for the local --smoke run). """ from transformers.trainer_callback import TrainerCallback out = Path(output_dir) class MetricsSaverCallback(TrainerCallback): """Persist metrics CSV/JSON + plots periodically and at the end.""" def __init__(self) -> None: self._last_snapshot_step = -1 def _snapshot(self, state) -> None: try: save_training_artifacts(list(state.log_history or []), out) except Exception as exc: # never let plotting kill training logger.warning("metrics snapshot failed: %s", exc) def on_log(self, args, state, control, **kwargs): step = int(getattr(state, "global_step", 0) or 0) if step <= 0: return control if (step - self._last_snapshot_step) >= max(snapshot_every, 1): self._snapshot(state) self._last_snapshot_step = step return control def on_train_end(self, args, state, control, **kwargs): self._snapshot(state) return control return MetricsSaverCallback() def upload_training_artifacts_to_hub( output_dir: str | Path, repo_id: str, *, path_in_repo: str = "training_artifacts", ) -> list[str]: """Upload small evidence files to the same model repo (PNGs, CSV, JSON). ``GRPOTrainer.push_to_hub`` typically uploads weights/tokenizer only; this adds ``metrics.csv``, ``loss_curve.png``, and related files under ``path_in_repo/`` on the Hub so they survive ephemeral cloud jobs. """ from huggingface_hub import HfApi, create_repo out = Path(output_dir) if not out.is_dir(): return [] create_repo(repo_id, repo_type="model", exist_ok=True) api = HfApi() names = ( "metrics.csv", "metrics.json", "loss_curve.png", "reward_curve.png", "reward_total_curve.png", "training_summary.json", ) prefix = path_in_repo.strip("/") uploaded: list[str] = [] for name in names: path = out / name if not path.is_file(): continue dest = f"{prefix}/{name}" if prefix else name api.upload_file( path_or_fileobj=str(path), path_in_repo=dest, repo_id=repo_id, repo_type="model", ) uploaded.append(dest) return uploaded