"""Smoke-test the trained Repair Agent locally on one episode.

Loads the LoRA adapter pushed to ``akhiilll/forgeenv-repair-agent``, hits
the live ForgeEnv Space for a fresh broken script, asks the model to
emit a unified diff, applies it, and prints the verifier breakdown.

Usage::

    python scripts/test_repair_agent.py --seed 7
    python scripts/test_repair_agent.py --seed 7 --base-model unsloth/Qwen2.5-Coder-1.5B-Instruct

Requires GPU + transformers/peft. Skip this if you only want a quick
demo -- use ``scripts/test_live_env.py`` or the Gradio Space instead.
"""
| from __future__ import annotations |
|
|
| import argparse |
| import asyncio |
| import json |
|
|
| from openenv.core import GenericAction, GenericEnvClient |
|
|
| ENV_URL = "https://akhiilll-forgeenv.hf.space" |
| LORA_REPO = "akhiilll/forgeenv-repair-agent" |
|
|
| REPAIR_PROMPT = """\ |
| You are a senior ML engineer fixing a HuggingFace training script that just broke. |
| Output ONLY a unified diff (`--- a/script.py` / `+++ b/script.py`) that fixes the |
| breakage signaled by the error trace. No prose, no fences, no explanation. |
| |
| # Broken script |
| ```python |
| {script} |
| ``` |
| |
| # Error trace |
| ``` |
| {error} |
| ``` |
| |
| # Diff |
| """ |
|
|
|
|
| async def fetch_broken_episode(seed: int): |
| client = GenericEnvClient(base_url=ENV_URL) |
| res = await client.reset(seed=seed, options={"difficulty": "medium"}) |
| target = res.observation["target_category"] |
| res = await client.step(GenericAction( |
| breakage={"action_type": "breakage", "primitive_type": target, "params": {}}, |
| repair=None, |
| )) |
| obs = res.observation |
| return client, obs.get("script_content") or obs.get("broken_script") or "", obs.get("error_trace", "") |
|
|
|
|
| async def submit_repair(client: GenericEnvClient, diff: str): |
| res = await client.step(GenericAction( |
| breakage=None, |
| repair={"action_type": "repair", "unified_diff": diff}, |
| )) |
| return res |
|
|
|
|
| def generate_diff(base_model: str, lora_repo: str, prompt: str) -> str: |
| import torch |
| from peft import PeftModel |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| print(f"loading base model: {base_model}") |
| tok = AutoTokenizer.from_pretrained(base_model) |
| model = AutoModelForCausalLM.from_pretrained( |
| base_model, |
| torch_dtype=torch.bfloat16, |
| device_map="auto", |
| ) |
| print(f"attaching LoRA: {lora_repo}") |
| model = PeftModel.from_pretrained(model, lora_repo) |
| model.eval() |
|
|
| inputs = tok(prompt, return_tensors="pt").to(model.device) |
| with torch.no_grad(): |
| out = model.generate( |
| **inputs, |
| max_new_tokens=512, |
| do_sample=False, |
| temperature=0.0, |
| pad_token_id=tok.eos_token_id, |
| ) |
| text = tok.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True) |
| return text.strip() |
|
|
|
|
| async def main(args) -> None: |
| print(f"--- pulling broken episode (seed={args.seed}) from {ENV_URL}") |
| client, broken_script, error_trace = await fetch_broken_episode(args.seed) |
| if not broken_script: |
| raise SystemExit("env returned empty script_content; pick a different seed") |
| print(f"broken script length: {len(broken_script)} chars") |
| print(f"error trace : {(error_trace[:200] + '...') if len(error_trace) > 200 else error_trace}") |
|
|
| prompt = REPAIR_PROMPT.format(script=broken_script, error=error_trace or "<env did not surface a trace>") |
| diff = generate_diff(args.base_model, args.lora_repo, prompt) |
|
|
| print("\n=== model diff ===") |
| print(diff) |
|
|
| print("\n=== submitting diff to env ===") |
| res = await submit_repair(client, diff) |
| print(f"reward: {res.reward} done: {res.done}") |
| breakdown = res.observation.get("reward_breakdown") if isinstance(res.observation, dict) else None |
| if breakdown: |
| print("reward_breakdown:") |
| print(json.dumps(breakdown, indent=2)) |
|
|
|
|
| if __name__ == "__main__": |
| p = argparse.ArgumentParser() |
| p.add_argument("--seed", type=int, default=7) |
| p.add_argument("--base-model", default="unsloth/Qwen2.5-Coder-1.5B-Instruct") |
| p.add_argument("--lora-repo", default=LORA_REPO) |
| args = p.parse_args() |
| asyncio.run(main(args)) |
|
|