#!/usr/bin/env python
"""Evaluate any policy / model on the held-out scenario set.
Two modes:
  policy  Use the deterministic POLICY_PLANS asker from inference.py:
          no LLM calls, free, reproducible; used as the floor baseline.
  api     Use an OpenAI-compatible chat endpoint (the same code path the
          submission validator exercises in inference.py). Set:
            MODEL_NAME    e.g. Qwen/Qwen3-0.6B
            API_BASE_URL  e.g. https://router.huggingface.co/v1
            HF_TOKEN      Hugging Face access token (read or write)
Output: a single JSON file with per-scenario scores, breakdowns,
question counts, and aggregate metrics, formatted exactly as
`scripts/make_plots.py` expects.
Usage:
# baseline (deterministic policy)
python scripts/run_eval.py --mode policy --out outputs/eval_policy.json --limit 100
# untrained Qwen3-0.6B via HF Inference router
HF_TOKEN=hf_xxx MODEL_NAME=Qwen/Qwen3-0.6B \\
python scripts/run_eval.py --mode api --out outputs/eval_qwen3-0.6b_base.json --limit 100
# trained model via HF Inference Endpoints (supply your endpoint URL)
API_BASE_URL=https://my-endpoint.endpoints.huggingface.cloud/v1 \\
MODEL_NAME=clarify-rl-grpo-qwen3-0.6b HF_TOKEN=hf_xxx \\
python scripts/run_eval.py --mode api --out outputs/eval_qwen3-0.6b_trained.json --limit 100
"""
from __future__ import annotations
import argparse
import asyncio
import json
import os
import sys
import time
from pathlib import Path
from typing import Any, Optional
# Make the inference.py helpers importable without copy-paste.
_HERE = Path(__file__).resolve().parent
_REPO = _HERE.parent
sys.path.insert(0, str(_REPO))
def _lazy_import_inference():
"""Lazy-import inference.py so `--help` works without openai installed."""
import inference as _inf # type: ignore
return _inf
def _make_ws_url(base_url: str) -> str:
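    """Turn the Space's http(s) base URL into its websocket endpoint."""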
return base_url.replace("https://", "wss://").replace("http://", "ws://").rstrip("/") + "/ws"
async def _ws_reset_with_seed(ws, task_id: str, seed: int) -> dict:
"""Reset env to a specific (task_id, seed) — exact replay of an eval scenario."""
await ws.send(json.dumps({"type": "reset", "data": {"task_id": task_id, "seed": seed}}))
resp = json.loads(await ws.recv())
if resp.get("type") == "error":
return {"observation": {}, "reward": 0.0, "done": False, "error": resp.get("data", {})}
data = resp.get("data", {})
return {
"observation": data.get("observation", {}),
"reward": float(data.get("reward", 0.0)),
"done": bool(data.get("done", False)),
}
def _parse_observation(obs: dict) -> dict:
"""Pull the canonical tool-result dict out of an MCP observation."""
result = obs.get("result")
if isinstance(result, dict):
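        # Prefer structured payloads when the server provides them.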
if isinstance(result.get("structured_content"), dict):
return result["structured_content"]
if isinstance(result.get("data"), dict):
return result["data"]
content = result.get("content")
if isinstance(content, list) and content:
txt = content[0].get("text", "")
try:
parsed = json.loads(txt)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
return result
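    # The result may also arrive as a bare JSON string.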
if isinstance(result, str):
try:
parsed = json.loads(result)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
return {}
async def _eval_one_scenario(
ws,
scenario: dict,
mode: str,
llm_client,
timeout_s: float,
inf,
) -> dict:
"""Run a single scenario end-to-end. Returns a result row."""
seed = scenario["seed"]
task_id = scenario["task_id"]
family = scenario.get("family", "")
t0 = time.time()
reset = await _ws_reset_with_seed(ws, task_id, seed)
if "error" in reset:
return {
"seed": seed,
"task_id": task_id,
"scenario_id": f"seed{seed:05d}_{family}_{task_id}",
"family": family,
"request": "",
"final_score": 0.0,
"score_breakdown": {},
"questions_asked": 0,
"format_pass": False,
"error": str(reset["error"]),
"messages": [],
"trace": [],
"elapsed_s": time.time() - t0,
}
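    # The reset observation carries the user request text and the step budget.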
initial_data = _parse_observation(reset["observation"])
request_text = initial_data.get("request", "")
max_steps = int(initial_data.get("max_steps", 10))
messages = [
{"role": "system", "content": inf.SYSTEM_PROMPT},
{"role": "user", "content": (
f"USER REQUEST:\n{request_text}\n\nYou have {max_steps} steps. "
"Available tools: ask_question(question), propose_plan(plan), get_task_info().\n\n"
"RESPONSE FORMAT: Reply with ONE function call only, no other text.\n"
"Examples:\n"
" ask_question(\"What is the date?\")\n"
" propose_plan('{\"event_type\": \"birthday\", \"date\": \"2024-12-25\"}')\n"
" get_task_info()\n"
)},
]
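    # Per-scenario state: tool-call trace, fields revealed by answers, and scoring.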
trace: list[dict] = []
revealed: dict[str, Any] = {}
questions_asked = 0
final_score = 0.0
score_breakdown: dict[str, float] = {}
format_pass: Optional[bool] = None
parse_error: Optional[str] = None
llm_attempts = 0
used_policy_step = 0
done = False
for step in range(max_steps):
if time.time() - t0 > timeout_s:
trace.append({"step": step, "error": "timeout"})
break
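        # Pick the next tool call: the fixed policy plan, or the LLM with a policy fallback.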
if mode == "policy":
tool_name, args = inf._next_policy_action( # type: ignore[attr-defined]
task_id, used_policy_step, request_text, revealed
)
used_policy_step += 1
else: # api
tool_name, args, fellback, llm_attempts = inf._choose_action( # type: ignore[attr-defined]
task_id, messages, llm_client, used_policy_step, llm_attempts, request_text, revealed
)
if fellback:
used_policy_step += 1
try:
step_resp = await inf.ws_step(ws, tool_name, args)
except Exception as exc: # noqa: BLE001
trace.append({"step": step, "error": f"ws_step exception: {exc}"})
break
obs = step_resp.get("observation", {}) or {}
result = _parse_observation(obs)
done = bool(step_resp.get("done"))
record = {
"step": step,
"tool": tool_name,
"args": args,
"reward": float(step_resp.get("reward", 0.0)),
"done": done,
"result": result,
}
trace.append(record)
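        # Re-assert the one-tool-call response format on every turn.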
format_reminder = (
"\n\nReminder: Reply with ONE function call only "
"(ask_question/propose_plan/get_task_info), no other text."
)
if tool_name == "ask_question":
questions_asked += 1
if isinstance(result, dict) and result.get("field_revealed"):
fld = result["field_revealed"]
ans = result.get("answer", "")
revealed[fld] = ans
messages.append({"role": "user", "content": json.dumps(result) + format_reminder})
elif tool_name == "get_task_info":
messages.append({"role": "user", "content": json.dumps(result) + format_reminder})
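        # propose_plan ends the episode; pull the score, breakdown, and format result.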
elif tool_name == "propose_plan":
if isinstance(result, dict):
final_score = float(result.get("score", step_resp.get("reward", 0.0)))
score_breakdown = result.get("breakdown", {}) or {}
parse_error = result.get("parse_error")
                # Use an explicit None check so a FormatCheck of 0.0 still
                # records a format failure instead of being dropped as falsy.
                fmt = score_breakdown.get("FormatCheck")
                if fmt is None:
                    fmt = score_breakdown.get("format_check")
                if fmt is not None:
                    format_pass = fmt > 0
done = True
if done:
break
return {
"seed": seed,
"task_id": task_id,
"scenario_id": f"seed{seed:05d}_{family}_{task_id}",
"family": family,
"request": request_text,
"final_score": final_score,
"score_breakdown": score_breakdown,
"questions_asked": questions_asked,
"format_pass": format_pass,
"parse_error": parse_error,
"messages": messages,
"trace": trace,
"elapsed_s": time.time() - t0,
}
async def _run(args) -> dict:
inf = _lazy_import_inference()
eval_path = Path(args.scenarios)
if not eval_path.exists():
raise FileNotFoundError(f"Scenario file not found: {eval_path}")
scenarios = json.loads(eval_path.read_text())
if args.limit and args.limit < len(scenarios):
scenarios = scenarios[: args.limit]
print(f"Loaded {len(scenarios)} scenarios from {eval_path}")
llm_client = None
if args.mode == "api":
if not inf.API_KEY:
raise RuntimeError("api mode requires HF_TOKEN / OPENAI_API_KEY")
llm_client = inf.create_client()
if llm_client is None:
raise RuntimeError("Failed to create OpenAI client (check API_BASE_URL/HF_TOKEN)")
print(f"Using OpenAI client with base_url={inf.API_BASE_URL} model={inf.MODEL_NAME}")
else:
print("Mode: policy (deterministic, no LLM)")
import websockets
results: list[dict] = []
ws_url = _make_ws_url(args.env)
print(f"Env WS: {ws_url}")
print(f"Output to: {args.out}")
print()
overall_t0 = time.time()
async with websockets.connect(
ws_url, open_timeout=30, close_timeout=10, max_size=2**24
) as ws:
for i, scn in enumerate(scenarios):
print(f"[{i+1}/{len(scenarios)}] family={scn.get('family','?')} task={scn['task_id']} seed={scn['seed']}", flush=True)
row = await _eval_one_scenario(ws, scn, args.mode, llm_client, args.timeout, inf)
results.append(row)
print(
f" score={row['final_score']:.3f} q={row['questions_asked']} fmt={row['format_pass']} "
f"err={row.get('error') or row.get('parse_error') or ''}",
flush=True,
)
total_s = time.time() - overall_t0
scores = [r["final_score"] for r in results]
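    # format_pass is None when a scenario never reached a scored propose_plan.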
fmt_passes = [r["format_pass"] for r in results if r["format_pass"] is not None]
qs = [r["questions_asked"] for r in results]
summary = {
"model": inf.MODEL_NAME if args.mode == "api" else None,
"mode": args.mode,
"scenarios_total": len(results),
"elapsed_s": total_s,
"avg_score": sum(scores) / len(scores) if scores else 0.0,
"avg_questions": sum(qs) / len(qs) if qs else 0.0,
"format_pass_rate": (sum(1 for f in fmt_passes if f) / len(fmt_passes)) if fmt_passes else 0.0,
"completion_rate": sum(1 for r in results if r["final_score"] > 0) / max(1, len(results)),
}
payload = {
"summary": summary,
"config": {
"mode": args.mode,
"model": inf.MODEL_NAME if args.mode == "api" else None,
"api_base_url": inf.API_BASE_URL if args.mode == "api" else None,
"env_base_url": args.env,
"scenarios_file": str(eval_path),
"limit": args.limit,
},
"results": results,
}
out_path = Path(args.out)
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(json.dumps(payload, indent=2))
print()
print(f"Saved {len(results)} results to {out_path}")
print(f"Avg score: {summary['avg_score']:.4f}")
print(f"Format pass rate: {summary['format_pass_rate']:.4f}")
print(f"Completion rate: {summary['completion_rate']:.4f}")
print(f"Avg questions: {summary['avg_questions']:.2f}")
print(f"Total elapsed: {total_s:.1f} s")
return summary
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument("--mode", choices=("policy", "api"), required=True)
parser.add_argument(
"--scenarios",
default=str(_REPO / "scenarios" / "eval_held_out.json"),
help="Path to eval scenario JSON (default: scenarios/eval_held_out.json)",
)
parser.add_argument("--out", required=True, help="Output JSON file (e.g. outputs/eval_policy.json)")
parser.add_argument("--limit", type=int, default=None, help="Cap to first N scenarios")
parser.add_argument(
"--env",
default=os.environ.get("ENV_BASE_URL", "https://agarwalanu3103-clarify-rl.hf.space"),
help="Env Space URL",
)
parser.add_argument("--timeout", type=float, default=180.0, help="Per-scenario timeout in seconds")
args = parser.parse_args()
asyncio.run(_run(args))
if __name__ == "__main__":
main()