# Inject-Arena / scripts/make_plots.py
# Last commit ff63792 by Jaswanth1210:
#   feat: 5 training plots (reward±std, KL/loss, completion stats) + Drive backup in Cell 8
"""Generate training plots from TRL trainer_state.json or JSONL logs + eval results.
Plots produced
--------------
1. reward_curve.png — mean reward ± std band + smoothed trend
2. kl_loss_curve.png — KL divergence & policy loss on twin axes
3. completion_stats.png — mean completion length + clipped-ratio line
4. bypass_bars.png — RL-trained vs handcrafted baseline (eval results)
5. per_category.png — per-scenario-category breakdown (eval results)
Usage
-----
python scripts/make_plots.py \\
--trainer-state /path/to/trainer_state.json \\
--eval docs/eval_results.json \\
--out docs/plots/
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser()
p.add_argument("--logs", type=str, default="logs/",
help="Directory of JSONL log files (legacy fallback)")
p.add_argument("--trainer-state", type=str, default=None,
help="Path to TRL trainer_state.json (preferred)")
p.add_argument("--out", type=str, default="docs/plots/")
p.add_argument("--eval", type=str, default="docs/eval_results.json")
return p.parse_args()
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def _load_jsonl(path: Path) -> List[Dict[str, Any]]:
rows = []
with open(path) as f:
for line in f:
line = line.strip()
if line:
try:
rows.append(json.loads(line))
except json.JSONDecodeError:
pass
return rows
def _load_all_logs(logs_dir: Path) -> List[Dict[str, Any]]:
    """Load every ``*.jsonl`` file in *logs_dir* into one step-ordered row list.

    The plotting code reads each row's ``step``. A row that carries only
    ``global_step`` gets a ``step`` entry filled in from it. A row with neither
    key cannot be placed on the step axis and is dropped, mirroring how
    ``_load_trainer_state`` skips step-less log entries.

    Args:
        logs_dir: Directory scanned (non-recursively) for ``*.jsonl`` files.

    Returns:
        Rows sorted by ascending step. Files are read in sorted name order and
        the sort is stable, so rows sharing a step keep a deterministic order.
    """
    rows: List[Dict[str, Any]] = []
    for p in sorted(logs_dir.glob("*.jsonl")):
        for r in _load_jsonl(p):
            step = r.get("step", r.get("global_step"))
            if step is None:
                continue
            # Copy rather than mutate when adding the normalised "step" key.
            rows.append(r if "step" in r else {**r, "step": step})
    rows.sort(key=lambda r: r["step"])
    return rows
def _load_trainer_state(state_path: Path) -> List[Dict[str, Any]]:
"""Parse TRL trainer_state.json log_history into a normalised row list."""
if not state_path.exists():
print(f"trainer_state.json not found at {state_path}")
return []
with open(state_path) as f:
data = json.load(f)
# TRL GRPO key → normalised key mapping
KEY_MAP = {
"reward": "reward/mean",
"rewards/mean": "reward/mean",
"reward/mean": "reward/mean",
"reward_std": "reward/std",
"rewards/std": "reward/std",
"reward/std": "reward/std",
"kl": "kl",
"loss": "loss",
"train/loss": "loss",
"learning_rate": "lr",
"completions/mean_length": "completion/mean_length",
"completions/clipped_ratio": "completion/clipped_ratio",
}
rows: List[Dict[str, Any]] = []
for entry in data.get("log_history", []):
step = entry.get("step")
if step is None:
continue
row: Dict[str, Any] = {"step": step}
for src, dst in KEY_MAP.items():
if src in entry:
row[dst] = entry[src]
if len(row) > 1: # has at least one metric besides step
rows.append(row)
rows.sort(key=lambda r: r["step"])
print(f"Loaded {len(rows)} log entries from {state_path}")
return rows
def _extract(rows: List[Dict[str, Any]], key: str) -> Tuple[List[int], List[float]]:
steps, vals = [], []
for r in rows:
v = r.get(key)
if v is not None:
steps.append(r["step"])
vals.append(float(v))
return steps, vals
# ---------------------------------------------------------------------------
# Plot helpers
# ---------------------------------------------------------------------------
def _smooth(vals: List[float], window: int) -> List[float]:
import numpy as np
if window <= 1 or len(vals) < window:
return vals
return list(np.convolve(vals, np.ones(window) / window, mode="valid"))
# Colour palette shared by the plot functions in this file (hex RGB strings
# passed as matplotlib ``color=`` values).
BLUE = "#3b82f6"    # raw reward + std band; RL-trained and composed-bypass bars
DBLUE = "#1d4ed8"   # smoothed reward trend; value labels on the RL-trained bars
RED = "#ef4444"     # loss and clipped-ratio lines on the right-hand twin axes
GREEN = "#22c55e"   # per-category task-success bars
ORANGE = "#f97316"  # mean completion length line
PURPLE = "#a855f7"  # KL divergence line
GRAY = "#94a3b8"    # handcrafted-baseline bars
# ---------------------------------------------------------------------------
# Plot 1: Reward curve with ±std band
# ---------------------------------------------------------------------------
def _plot_reward_curve(rows: List[Dict[str, Any]], out_dir: Path) -> None:
    """Plot mean reward per step with an optional ±1 std band and a smoothed trend.

    Writes ``reward_curve.png`` into *out_dir*. If no ``reward/mean`` values
    are present, prints a notice and returns without writing anything.

    The std band is drawn only when ``reward/std`` is logged at exactly the
    same steps as ``reward/mean``. The y-axis is pinned at 0 only when every
    plotted value (including the band) is non-negative.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    steps_r, rewards = _extract(rows, "reward/mean")
    steps_s, stds = _extract(rows, "reward/std")
    if not steps_r:
        print("No reward data — skipping reward_curve.png")
        return

    # Window is ~1/15 of the run. _smooth's "valid" moving average is
    # window-1 points shorter than its input, so trim the x values to match.
    window = max(1, len(rewards) // 15)
    smoothed = _smooth(rewards, window)
    smooth_steps = steps_r[window - 1:] if window > 1 else steps_r

    r_arr = np.asarray(rewards, dtype=float)
    lowest = float(r_arr.min())  # smallest plotted y value; decides the y floor

    fig, ax = plt.subplots(figsize=(10, 5))
    # Require identical step sequences, not just equal lengths: equal-length
    # series logged at different steps would pair each mean with the wrong std.
    if steps_s == steps_r:
        s_arr = np.asarray(stds, dtype=float)
        lower, upper = r_arr - s_arr, r_arr + s_arr
        ax.fill_between(steps_r, lower, upper,
                        alpha=0.15, color=BLUE, label="±1 std")
        lowest = min(lowest, float(lower.min()))
    ax.plot(steps_r, rewards, alpha=0.35, color=BLUE, linewidth=0.9, label="raw reward")
    ax.plot(smooth_steps, smoothed, color=DBLUE, linewidth=2.2,
            label=f"smoothed (w={window})")
    ax.axhline(0, color="gray", linestyle="--", linewidth=0.6)
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Mean Reward")
    # Report the actual last logged step instead of a hard-coded run length.
    ax.set_title(f"InjectArena — GRPO Reward Curve ({max(steps_r)} steps)")
    ax.legend(loc="lower right")
    # Pinning the floor at 0 unconditionally would hide negative rewards and
    # the lower half of the std band, so only do it for non-negative data.
    if lowest >= 0:
        ax.set_ylim(bottom=0)
    ax.grid(alpha=0.25)

    out_path = out_dir / "reward_curve.png"
    fig.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {out_path}")
# ---------------------------------------------------------------------------
# Plot 2: KL divergence + policy loss on twin axes
# ---------------------------------------------------------------------------
def _plot_kl_loss(rows: List[Dict[str, Any]], out_dir: Path) -> None:
    """Draw KL divergence and loss against training step on twin y-axes.

    KL goes on the left axis (purple, solid) and loss on a right-hand twin
    axis (red, dashed), with one combined legend. Writes ``kl_loss_curve.png``
    into *out_dir*. If neither ``kl`` nor ``loss`` values exist in *rows*,
    prints a notice and returns.
    """
    import matplotlib.pyplot as plt

    kl_steps, kl_vals = _extract(rows, "kl")
    loss_steps, loss_vals = _extract(rows, "loss")
    if not (kl_steps or loss_steps):
        print("No KL/loss data — skipping kl_loss_curve.png")
        return

    _, kl_ax = plt.subplots(figsize=(10, 4))
    legend_handles = []

    if kl_steps:
        kl_line, = kl_ax.plot(kl_steps, kl_vals, color=PURPLE, linewidth=1.8,
                              label="KL divergence")
        kl_ax.set_ylabel("KL Divergence", color=PURPLE)
        kl_ax.tick_params(axis="y", labelcolor=PURPLE)
        legend_handles.append(kl_line)

    if loss_steps:
        loss_ax = kl_ax.twinx()
        loss_line, = loss_ax.plot(loss_steps, loss_vals, color=RED, linewidth=1.8,
                                  linestyle="--", label="Policy loss")
        loss_ax.set_ylabel("Policy Loss", color=RED)
        loss_ax.tick_params(axis="y", labelcolor=RED)
        legend_handles.append(loss_line)

    kl_ax.set_xlabel("Training Step")
    kl_ax.set_title("InjectArena — KL Divergence & Policy Loss")
    kl_ax.grid(alpha=0.2)

    # A single legend for both axes, attached to the left-hand axis.
    if legend_handles:
        kl_ax.legend(legend_handles,
                     [handle.get_label() for handle in legend_handles],
                     loc="upper right")

    out_path = out_dir / "kl_loss_curve.png"
    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved {out_path}")
# ---------------------------------------------------------------------------
# Plot 3: Completion length + clipped ratio
# ---------------------------------------------------------------------------
def _plot_completion_stats(rows: List[Dict[str, Any]], out_dir: Path) -> None:
    """Plot mean completion length (left axis) and clipped ratio (right axis).

    Writes ``completion_stats.png`` into *out_dir*. If neither
    ``completion/mean_length`` nor ``completion/clipped_ratio`` is present,
    prints a notice and returns.
    """
    import matplotlib.pyplot as plt

    steps_l, lengths = _extract(rows, "completion/mean_length")
    steps_c, clipped = _extract(rows, "completion/clipped_ratio")
    if not steps_l and not steps_c:
        print("No completion stats — skipping completion_stats.png")
        return

    fig, ax1 = plt.subplots(figsize=(10, 4))
    # Collect legend handles explicitly. The right axis also carries an
    # unlabeled reference axhline at 1.0 whose auto-generated label starts
    # with "_". Passing ax2.get_lines() would hand that line to the legend,
    # and matplotlib would warn about it and drop it.
    handles = []
    if steps_l:
        len_line, = ax1.plot(steps_l, lengths, color=ORANGE, linewidth=1.8,
                             label="Mean completion length (tokens)")
        ax1.set_ylabel("Mean Length (tokens)", color=ORANGE)
        ax1.tick_params(axis="y", labelcolor=ORANGE)
        handles.append(len_line)
    if steps_c:
        ax2 = ax1.twinx()
        clip_line, = ax2.plot(steps_c, clipped, color=RED, linewidth=1.8, linestyle="--",
                              label="Clipped ratio (hit max_len)")
        ax2.set_ylabel("Clipped Ratio", color=RED)
        ax2.set_ylim(0, 1.05)
        ax2.tick_params(axis="y", labelcolor=RED)
        ax2.axhline(1.0, color=RED, linestyle=":", linewidth=0.7, alpha=0.5)
        handles.append(clip_line)
    ax1.set_xlabel("Training Step")
    ax1.set_title("InjectArena — Completion Length & Clipping")
    ax1.grid(alpha=0.2)
    if handles:
        ax1.legend(handles, [h.get_label() for h in handles], loc="upper right")

    out_path = out_dir / "completion_stats.png"
    fig.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {out_path}")
# ---------------------------------------------------------------------------
# Plot 4: Bypass bars (eval results vs baseline)
# ---------------------------------------------------------------------------
def _plot_bypass_bars(eval_path: Path, out_dir: Path, *,
                      baselines: Optional[Dict[str, float]] = None) -> None:
    """Grouped bar chart of eval metrics: handcrafted baseline vs RL-trained attacker.

    Reads ``pg2_bypass_rate``, ``fw_bypass_rate``, ``task_success_rate`` and
    ``composed_bypass_rate`` from the eval JSON and writes ``bypass_bars.png``
    into *out_dir*. Missing or null values are plotted as 0. Bars taller than
    1% get a percentage label.

    Args:
        eval_path: Path to the eval results JSON. If it does not exist, a
            notice is printed and nothing is written.
        out_dir: Directory the PNG is written to.
        baselines: Optional overrides for the baseline bar heights, keyed by
            metric label ("PG2 Bypass", "FW Bypass", "Task Success",
            "Composed Bypass"). Labels not given keep the built-in defaults.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    if not eval_path.exists():
        print(f"Eval results not found at {eval_path} — skipping bypass_bars.png")
        return
    with open(eval_path, encoding="utf-8") as f:
        data = json.load(f)

    def _rate(key: str) -> float:
        # Missing keys and explicit JSON nulls both become 0 so a None never
        # reaches ax.bar as a bar height.
        value = data.get(key)
        return 0.0 if value is None else float(value)

    metrics = {
        "PG2 Bypass": _rate("pg2_bypass_rate"),
        "FW Bypass": _rate("fw_bypass_rate"),
        "Task Success": _rate("task_success_rate"),
        "Composed Bypass": _rate("composed_bypass_rate"),
    }
    # NOTE(review): default baseline rates are hard-coded; where they were
    # measured is not visible in this file — confirm before publishing.
    baseline_rates = {
        "PG2 Bypass": 0.15, "FW Bypass": 0.20,
        "Task Success": 0.05, "Composed Bypass": 0.02,
    }
    if baselines:
        baseline_rates.update(baselines)

    labels = list(metrics)
    x = np.arange(len(labels))
    width = 0.35

    fig, ax = plt.subplots(figsize=(9, 5))
    bars_base = ax.bar(x - width / 2, [baseline_rates[k] for k in labels],
                       width, label="Handcrafted Baseline", color=GRAY, edgecolor="white")
    bars_rl = ax.bar(x + width / 2, [metrics[k] for k in labels],
                     width, label="InjectArena (RL-trained)", color=BLUE, edgecolor="white")
    ax.set_ylabel("Rate")
    ax.set_title("InjectArena — Attacker Performance vs Baseline")
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.set_ylim(0, 1.05)
    ax.legend()
    ax.grid(axis="y", alpha=0.3)

    # Percentage labels above each bar taller than 1%; baseline bars first.
    for bars, text_color in ((bars_base, "#475569"), (bars_rl, DBLUE)):
        for bar in bars:
            h = bar.get_height()
            if h > 0.01:
                ax.text(bar.get_x() + bar.get_width() / 2, h + 0.01,
                        f"{h:.0%}", ha="center", va="bottom", fontsize=9, color=text_color)

    out_path = out_dir / "bypass_bars.png"
    fig.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {out_path}")
# ---------------------------------------------------------------------------
# Plot 5: Per-category breakdown
# ---------------------------------------------------------------------------
def _plot_per_category(eval_path: Path, out_dir: Path) -> None:
    """Grouped bar chart of task success and composed bypass per category.

    Reads the eval JSON's ``per_category`` mapping of category name to a dict
    holding ``task_success`` and ``composed_bypass`` rates, and writes
    ``per_category.png`` into *out_dir*. If a category is missing a rate, or
    the rate is null, it is plotted as 0. This matches the 0 default
    ``_plot_bypass_bars`` uses for absent metrics, instead of aborting the
    whole plot with ``KeyError``.

    Returns without a message when *eval_path* does not exist. In ``main``,
    the preceding ``_plot_bypass_bars`` call already reports the missing file.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    if not eval_path.exists():
        return
    with open(eval_path, encoding="utf-8") as f:
        data = json.load(f)
    per_cat = data.get("per_category", {})
    if not per_cat:
        print("No per_category data in eval results — skipping per_category.png")
        return

    def _rate(stats: Dict[str, Any], key: str) -> float:
        value = stats.get(key)
        return 0.0 if value is None else float(value)

    cats = list(per_cat.keys())
    task_rates = [_rate(per_cat[c], "task_success") for c in cats]
    bypass_rates = [_rate(per_cat[c], "composed_bypass") for c in cats]
    x = np.arange(len(cats))
    width = 0.35

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.bar(x - width / 2, task_rates, width, label="Task Success", color=GREEN, edgecolor="white")
    ax.bar(x + width / 2, bypass_rates, width, label="Composed Bypass", color=BLUE, edgecolor="white")
    ax.set_ylabel("Rate")
    ax.set_title("InjectArena — Per-Category Breakdown")
    ax.set_xticks(x)
    ax.set_xticklabels(cats, rotation=15, ha="right")
    ax.set_ylim(0, 1.05)
    ax.legend()
    ax.grid(axis="y", alpha=0.3)

    out_path = out_dir / "per_category.png"
    fig.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {out_path}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """Script entry point: parse CLI args, load training logs, and write all plots.

    Training-curve plots use ``--trainer-state`` when it yields rows, and
    otherwise the JSONL files in ``--logs``. Eval plots use ``--eval``. The
    function finishes by listing every PNG currently in the output directory.
    """
    args = _parse_args()
    out_dir, eval_path = Path(args.out), Path(args.eval)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Select the non-interactive Agg backend before any plot helper imports
    # pyplot (each helper imports it lazily inside its body).
    try:
        import matplotlib
        matplotlib.use("Agg")
    except ImportError:
        print("matplotlib not installed — pip install matplotlib")
        return

    # Preferred source: TRL trainer_state.json; fallback: legacy JSONL logs.
    rows: List[Dict[str, Any]] = (
        _load_trainer_state(Path(args.trainer_state)) if args.trainer_state else []
    )
    if not rows:
        logs_dir = Path(args.logs)
        if logs_dir.exists():
            rows = _load_all_logs(logs_dir)
            if rows:
                print(f"Loaded {len(rows)} log rows from {logs_dir}")

    if not rows:
        print("No training log data found — skipping reward/KL/completion plots.")
    else:
        for plot_fn in (_plot_reward_curve, _plot_kl_loss, _plot_completion_stats):
            plot_fn(rows, out_dir)

    # Eval plots read only the eval results JSON, not the training rows.
    _plot_bypass_bars(eval_path, out_dir)
    _plot_per_category(eval_path, out_dir)

    print("\nAll plots done.")
    for png in sorted(out_dir.glob("*.png")):
        print(f" {png}")


if __name__ == "__main__":
    main()