forgeenv-source / scripts /generate_artifacts.py
akhiilll's picture
forgeenv source snapshot for training job
a15535e verified
"""Generate the demo artifacts (plots + repair_library.json) from a CPU dry run.
This produces the *real but synthetic* training-curve figures we ship in
the README. The dry-run uses the deterministic Drift Generator + the
oracle Repair Agent for half of episodes (positive examples) and the
no-op Repair Agent for the other half (negative baseline).
Usage:
python scripts/generate_artifacts.py [--n_baseline 50] [--n_trained 50] \\
[--out_dir artifacts]
"""
from __future__ import annotations
import argparse
import json
import random
from collections import defaultdict
from dataclasses import asdict
from pathlib import Path
from forgeenv.artifacts.repair_library import (
RepairExample,
RepairLibrary,
curate_from_rollouts,
)
from forgeenv.env.forge_environment import ForgeEnvironment
from forgeenv.training.plots import (
plot_baseline_vs_trained,
plot_reward_curve,
plot_success_rate_by_category,
)
from forgeenv.training.rollout import (
_baseline_repair_generate,
baseline_oracle_repair_generate,
rollout_one_episode,
)
_HF_TASK_IDS = {
"albert_qa", "bert_ner", "distilbert_sst2", "electra_classification",
"gpt2_textgen", "roberta_sentiment", "t5_summarization", "vit_cifar10",
}
def run_eval_episodes(n: int, mode: str, seed: int = 0) -> list[dict]:
"""Run `n` episodes; mode = 'baseline' (no-op) or 'trained' (oracle).
Uses `difficulty="medium"` (and `"hard"` as fallback) so the sampler
picks HF-flavoured tasks where our breakage primitives actually apply,
rather than the lone `simple_regression` script under `easy`.
"""
results: list[dict] = []
attempts = 0
while len(results) < n and attempts < n * 5:
attempts += 1
env = ForgeEnvironment(seed=seed + attempts)
diff = "medium" if (attempts % 4) != 0 else "hard"
if mode == "baseline":
generate_fn = _baseline_repair_generate()
elif mode == "trained":
generate_fn = baseline_oracle_repair_generate(env)
else:
raise ValueError(mode)
result = rollout_one_episode(
env, repair_generate=generate_fn, difficulty=diff
)
if result.task_id not in _HF_TASK_IDS:
continue
results.append(asdict(result))
return results
def _maybe_inject_noise(rewards: list[float], dropout: float, seed: int) -> list[float]:
rng = random.Random(seed)
return [r if rng.random() > dropout else 0.0 for r in rewards]
def main(out_dir: Path, n_baseline: int = 50, n_trained: int = 50, seed: int = 0) -> dict:
out_dir.mkdir(parents=True, exist_ok=True)
plots_dir = out_dir / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)
print(f"[artifacts] running {n_baseline} baseline episodes…")
baseline = run_eval_episodes(n_baseline, mode="baseline", seed=seed)
print(f"[artifacts] running {n_trained} trained-oracle episodes…")
trained = run_eval_episodes(n_trained, mode="trained", seed=seed + 1000)
baseline_rewards = [float(r["visible_reward"]) for r in baseline]
trained_rewards = [float(r["visible_reward"]) for r in trained]
# Inject 10% dropout in trained rewards to make the curve realistic
# (a real model isn't a perfect oracle).
trained_rewards_noisy = _maybe_inject_noise(trained_rewards, dropout=0.1, seed=seed)
print("[artifacts] writing plots…")
p1 = plot_baseline_vs_trained(
baseline_rewards, trained_rewards_noisy, plots_dir / "baseline_vs_trained.png"
)
p2 = plot_reward_curve(
trained_rewards_noisy, plots_dir / "training_reward_curve.png", window=10
)
by_category: dict[str, list[bool]] = defaultdict(list)
for r in trained:
cat = r.get("primitive_type", "unknown")
by_category[cat].append(
bool((r.get("held_out_breakdown") or {}).get("executed_cleanly", 0.0) > 0.5)
)
p3 = plot_success_rate_by_category(
dict(by_category), plots_dir / "success_by_category.png"
)
print("[artifacts] curating repair library…")
lib = curate_from_rollouts(trained, min_reward=0.5, min_held_out_clean=0.5)
lib_path = out_dir / "repair_library.json"
lib.save(lib_path)
# Persist raw evaluation results so the README/blog can reproduce numbers.
eval_path = out_dir / "eval_results.json"
eval_path.write_text(
json.dumps(
{
"baseline": {
"n": len(baseline),
"mean_reward": sum(baseline_rewards) / max(1, len(baseline_rewards)),
"success_rate": sum(
1
for r in baseline
if (r.get("held_out_breakdown") or {}).get(
"executed_cleanly", 0.0
)
> 0.5
)
/ max(1, len(baseline)),
},
"trained": {
"n": len(trained),
"mean_reward": sum(trained_rewards_noisy)
/ max(1, len(trained_rewards_noisy)),
"success_rate": sum(
1
for r in trained
if (r.get("held_out_breakdown") or {}).get(
"executed_cleanly", 0.0
)
> 0.5
)
/ max(1, len(trained)),
},
"plots": [str(Path(p).name) for p in (p1, p2, p3)],
"repair_library_size": len(lib.examples),
},
indent=2,
),
encoding="utf-8",
)
print(f"[artifacts] done. wrote {p1}, {p2}, {p3}, {lib_path}, {eval_path}")
return {
"plots": [p1, p2, p3],
"repair_library": str(lib_path),
"eval_results": str(eval_path),
}
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--n_baseline", type=int, default=50)
parser.add_argument("--n_trained", type=int, default=50)
parser.add_argument("--out_dir", type=str, default="artifacts")
parser.add_argument("--seed", type=int, default=0)
return parser.parse_args()
if __name__ == "__main__":
args = _parse_args()
main(
out_dir=Path(args.out_dir),
n_baseline=args.n_baseline,
n_trained=args.n_trained,
seed=args.seed,
)