# debugZero / eval / api_baseline.py
# Author: The-Fool-09
# Uploaded with huggingface_hub (commit 8412998, verified)
import asyncio
import inspect
import json
import os
import sys
import textwrap
from typing import Any, Optional
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from dotenv import load_dotenv
from openai import OpenAI
from client import DebugzeroEnv
from models import DebugzeroAction
load_dotenv()

# --- OpenAI-compatible endpoint --------------------------------------------
# Each setting reads a primary variable first and falls back to an alternative
# OPENAI_* name; an empty primary value is treated the same as an unset one.
API_BASE_URL = os.environ.get("API_BASE_URL") or os.environ.get(
    "OPENAI_BASE_URL", "https://openrouter.ai/api/v1"
)
MODEL_NAME = os.environ.get("MODEL_NAME") or os.environ.get(
    "OPENAI_MODEL", "meta-llama/llama-3.1-8b-instruct"
)
API_KEY = (
    os.environ.get("API_KEY")
    or os.environ.get("OPENAI_API_KEY")
    or os.environ.get("HF_TOKEN")
)

# --- Environment server and run shape --------------------------------------
ENV_URL = os.environ.get("DEBUGZERO_ENV_URL", "http://localhost:8000")
NUM_EPISODES = int(os.environ.get("NUM_EPISODES", "6"))
MAX_STEPS = int(os.environ.get("MAX_STEPS", "8"))

# --- Sampling ---------------------------------------------------------------
PROPOSER_TEMPERATURE = float(os.environ.get("PROPOSER_TEMPERATURE", "0.7"))
SOLVER_TEMPERATURE = float(os.environ.get("SOLVER_TEMPERATURE", "0.2"))
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "1024"))

# Optional mutation family the Proposer is asked to focus on; when unset or
# empty, no focus instruction is added to the Proposer prompt.
BUG_FOCUS = os.environ.get("DEBUGZERO_BUG_FOCUS")
def extract_python_code(text: str) -> str:
    """Extract the Python source from a model reply.

    The reply may be bare code, a single fenced block, or a fenced block
    surrounded by prose (e.g. "Here is the fix:" before it, or an explanation
    after it). The body of the FIRST fenced block is returned. If the closing
    fence is missing (e.g. the reply was truncated by ``max_tokens``),
    everything after the opening fence line is returned.

    Args:
        text: Raw model output. ``None``/empty is treated as ``""``.

    Returns:
        The extracted code with surrounding whitespace stripped. Replies that
        contain no fence at all are returned stripped but otherwise unchanged.
    """
    fence = "```"
    content = (text or "").strip()

    start = content.find(fence)
    if start == -1:
        # No markdown fence: assume the whole reply is code.
        return content

    header_end = content.find("\n", start)
    if header_end == -1:
        # A fence with nothing on the following lines; nothing better to extract.
        return content

    # Drop any prose before the fence and the fence line itself (``` / ```python).
    body = content[header_end + 1 :]

    # Closing fences sit at the start of a line; cutting at the first one also
    # discards any trailing explanation the model added after the block.
    closing = body.find("\n" + fence)
    if closing != -1:
        body = body[:closing]
    elif body.rstrip().endswith(fence):
        # Empty block ("```python\n```") or a fence glued to the last code line.
        body = body.rstrip()[: -len(fence)]
    return body.strip()
def summarize_error(text: str, max_chars: int = 220) -> str:
    """Collapse whitespace in *text* and truncate it for one-line log output.

    Args:
        text: Raw error / execution output. ``None`` is treated as empty.
        max_chars: Maximum length of the returned summary, including the
            ``"..."`` suffix appended on truncation.

    Returns:
        The literal string ``"null"`` for empty or whitespace-only input
        (matching the ``error=null`` step log format). Otherwise the text with
        every whitespace run replaced by a single space, truncated to at most
        *max_chars* characters.
    """
    # str.split() with no argument already drops leading/trailing whitespace.
    cleaned = " ".join((text or "").split())
    if not cleaned:
        return "null"
    if len(cleaned) <= max_chars:
        return cleaned
    ellipsis = "..."
    if max_chars < len(ellipsis):
        # No room for the suffix. Slicing with a negative bound here would keep
        # almost the whole string instead of truncating it.
        return cleaned[: max(0, max_chars)]
    return cleaned[: max_chars - len(ellipsis)].rstrip() + ellipsis
def extract_env_error(result: Any) -> Optional[str]:
    """Pick the most relevant error text out of an environment step/reset result.

    Lookup order:
      1. the first truthy of ``result.last_action_error`` / ``.error`` / ``.message``;
      2. the first truthy of ``result.observation.last_action_error`` / ``.error``;
      3. ``observation.execution_result`` (only when it is a non-empty ``str``):
         summarized if ``syntax_error`` is set, returned verbatim if it starts
         with ``"Unsafe import detected."``, summarized if tests did not pass.

    Returns:
        The error text, or ``None`` when nothing indicates a failure.
    """

    def _first_truthy_attr(source: Any, names: tuple[str, ...]) -> Optional[str]:
        # Missing attributes are skipped, as are present-but-falsy ones.
        for name in names:
            found = getattr(source, name, None)
            if found:
                return str(found)
        return None

    direct = _first_truthy_attr(result, ("last_action_error", "error", "message"))
    if direct is not None:
        return direct

    observation = getattr(result, "observation", None)
    if observation is None:
        return None

    nested = _first_truthy_attr(observation, ("last_action_error", "error"))
    if nested is not None:
        return nested

    output = getattr(observation, "execution_result", "")
    if not (isinstance(output, str) and output):
        return None
    if getattr(observation, "syntax_error", False):
        return summarize_error(output)
    if output.startswith("Unsafe import detected."):
        # Kept verbatim (not whitespace-collapsed or truncated).
        return output
    if getattr(observation, "tests_passed", False):
        return None
    return summarize_error(output)
def compact_action_string(role: str, code: str) -> str:
    """Serialize an action as single-line JSON ``{"role": ..., "code": ...}``.

    Uses no whitespace between separators, and keeps non-ASCII characters
    literal rather than ``\\uXXXX``-escaped.
    """
    payload = dict(role=role, code=code)
    return json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
def build_prompt(obs_dict: dict[str, Any], history: list[str]) -> str:
    """Render the user prompt for the next actor (Proposer or Solver).

    The prompt is assembled from explicit lines rather than a dedented
    template. ``textwrap.dedent`` over an f-string whose interpolated values are
    multi-line (the current code, execution output, history) computes the
    common indent across the *interpolated* lines too. It therefore either
    leaves the template indentation in place or strips part of the code's own
    indentation, corrupting the Python shown to the model.

    Args:
        obs_dict: Observation fields. Reads ``role_next``, ``current_code``,
            ``execution_result``, ``tests_passed``, ``syntax_error`` and
            ``metadata["seed_id"]``; missing keys fall back to defaults.
        history: Previous action summaries; only the last four are included.

    Returns:
        The prompt text, with the current code embedded verbatim.
    """
    role = str(obs_dict.get("role_next", "proposer"))
    current_code = str(obs_dict.get("current_code", ""))
    execution_result = str(obs_dict.get("execution_result", ""))
    tests_passed = bool(obs_dict.get("tests_passed", False))
    syntax_error = bool(obs_dict.get("syntax_error", False))
    metadata = obs_dict.get("metadata", {}) or {}
    seed_id = metadata.get("seed_id", "unknown")
    history_block = "\n".join(history[-4:]) if history else "None"

    if role == "proposer":
        instruction_lines = [
            "You are the Proposer in a debugging self-play environment.",
            "Return a full Python function with exactly one small logical bug injected.",
            "Rules:",
            "- Keep the code valid Python.",
            "- Keep the same function signature.",
            "- Preserve the overall structure and formatting as much as possible.",
            "- Make exactly one small local behavioral change.",
            "- Avoid comments, explanations, markdown outside the code block, and broad rewrites.",
        ]
        if BUG_FOCUS:
            instruction_lines.append(
                f"- Focus specifically on the `{BUG_FOCUS}` mutation family."
            )
        instruction_lines.append(
            "- Your goal is to make tests fail without creating a syntax error."
        )
    else:
        instruction_lines = [
            "You are the Solver in a debugging self-play environment.",
            "Return the full fixed Python function.",
            "Rules:",
            "- Keep the code valid Python.",
            "- Keep the same function signature.",
            "- Make the smallest correct local fix you can.",
            "- Use the failure output to guide the repair.",
            "- Avoid comments, explanations, markdown outside the code block, and unrelated refactors.",
        ]

    prompt_lines = [
        *instruction_lines,
        "Current environment state:",
        f"- seed_id: {seed_id}",
        f"- role_next: {role}",
        f"- tests_passed: {tests_passed}",
        f"- syntax_error: {syntax_error}",
        "Current code:",
        "```python",
        current_code,  # embedded verbatim: indentation is significant in Python
        "```",
        "Execution result:",
        execution_result if execution_result else "None",
        "Previous actions:",
        history_block,
        "Return only the full Python code inside triple backticks.",
    ]
    return "\n".join(prompt_lines).strip()
def get_model_code(client: OpenAI, obs_dict: dict[str, Any], history: list[str]) -> str:
    """Ask the chat model for the next action's code and return the extracted source.

    The sampling temperature depends on ``obs_dict["role_next"]``: the
    Proposer uses PROPOSER_TEMPERATURE; any other role uses SOLVER_TEMPERATURE.
    This is a blocking network call; the async caller runs it via
    ``asyncio.to_thread``. API errors propagate to the caller.
    """
    is_proposer = str(obs_dict.get("role_next", "proposer")) == "proposer"
    messages = [
        {"role": "system", "content": "You are an expert Python coder."},
        {"role": "user", "content": build_prompt(obs_dict, history)},
    ]
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        temperature=PROPOSER_TEMPERATURE if is_proposer else SOLVER_TEMPERATURE,
        max_tokens=MAX_TOKENS,
    )
    reply = completion.choices[0].message.content
    return extract_python_code(reply or "")
async def maybe_await(value: Any) -> Any:
    """Await *value* if it is awaitable; otherwise return it unchanged.

    This lets callers treat sync and async client methods uniformly.
    """
    if not inspect.isawaitable(value):
        return value
    return await value
async def call_env_method(obj: Any, method_name: str, *args: Any) -> Any:
    """Call ``obj.<method_name>(*args)`` and return its result, awaiting if needed.

    Raises:
        AttributeError: if *obj* has no attribute named *method_name*.
    """
    bound_method = getattr(obj, method_name)
    return await maybe_await(bound_method(*args))
async def make_env(*, max_retries: int = 30, retry_delay: float = 5.0) -> Any:
    """Create a ``DebugzeroEnv`` client for ENV_URL, retrying if construction fails.

    Each failed attempt is reported on stderr; between attempts the coroutine
    sleeps *retry_delay* seconds without blocking the event loop. Whether
    construction actually contacts the server is not visible from this file.
    The retry only helps if it does.

    Args:
        max_retries: Total number of construction attempts (default 30).
        retry_delay: Seconds to wait between attempts (default 5.0).

    Returns:
        The constructed environment client.

    Raises:
        ValueError: if *max_retries* is less than 1. Previously a zero count
            silently returned ``None`` instead of a client.
        Exception: whatever the final failed attempt raised, re-raised as-is.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    for attempt in range(1, max_retries + 1):
        try:
            return DebugzeroEnv(base_url=ENV_URL)
        except Exception as exc:
            print(
                f"[SYSTEM ERROR] Env connection to {ENV_URL} failed (attempt {attempt}/{max_retries}): {exc}",
                file=sys.stderr,
                flush=True,
            )
            if attempt == max_retries:
                raise
            await asyncio.sleep(retry_delay)
    # Unreachable: every iteration either returns or (on the last attempt) raises.
    raise RuntimeError("make_env retry loop exited unexpectedly")
def print_live_summary(metrics: dict[str, Any]) -> None:
    """Print a human-readable summary of the metrics from ``run_live_api_probe``.

    Rates are printed as percentages and averages to two decimals. Every
    denominator is clamped to at least 1, so a run with no attempts reports 0%
    instead of raising ``ZeroDivisionError``.
    """

    def _ratio(numerator_key: str, denominator_key: str) -> float:
        return metrics[numerator_key] / max(1, int(metrics[denominator_key]))

    rewards = metrics["rewards"]
    mean_reward = sum(rewards) / len(rewards) if rewards else 0.0
    rule = "=" * 80

    report = [
        "\n" + rule,
        "Live API summary",
        rule,
        f"Episode success rate: {_ratio('episode_successes', 'episodes'):.2%}",
        f"Proposer syntax rate: {_ratio('proposer_syntax_errors', 'proposer_attempts'):.2%}",
        f"Solver syntax rate: {_ratio('solver_syntax_errors', 'solver_attempts'):.2%}",
        f"Average step reward: {mean_reward:.2f}",
        f"Average steps/episode: {_ratio('total_steps', 'episodes'):.2f}",
        f"Representative success: {metrics['representative_success']}",
        f"Representative failure: {metrics['representative_failure']}",
    ]
    for line in report:
        print(line)
def _metadata_seed_id(obs: Any, default: Any = "unknown") -> Any:
    """Return ``obs.metadata["seed_id"]``, or *default* when it is unavailable.

    Tolerates ``obs`` being ``None``, having no ``metadata`` attribute, or
    having ``metadata`` set to ``None`` or an empty mapping.
    """
    metadata = getattr(obs, "metadata", None) or {}
    return metadata.get("seed_id", default)


async def run_live_api_probe() -> dict[str, Any] | None:
    """Play NUM_EPISODES self-play episodes against the env using a live LLM.

    Each episode resets the environment, then for at most MAX_STEPS steps asks
    the model for code in the role given by the observation's ``role_next`` and
    submits it as a ``DebugzeroAction``. One status line is printed per step.
    An episode counts as a success when it ends (``done``) with tests passing
    and no syntax error. The env client is always closed, even on error.

    Returns:
        A metrics dict (the shape consumed by ``print_live_summary``), or
        ``None`` when the probe is skipped because no API key or model name is
        configured.
    """
    if not API_KEY:
        print("Skipping live API probe: OPENAI_API_KEY/API_KEY/HF_TOKEN is not set.")
        return None
    if not MODEL_NAME:
        print("Skipping live API probe: OPENAI_MODEL/MODEL_NAME is not set.")
        return None

    client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
    env = await make_env()
    metrics: dict[str, Any] = {
        "episodes": NUM_EPISODES,
        "episode_successes": 0,
        "proposer_attempts": 0,
        "solver_attempts": 0,
        "proposer_syntax_errors": 0,
        "solver_syntax_errors": 0,
        "rewards": [],
        "total_steps": 0,
        "representative_success": None,
        "representative_failure": None,
    }
    print("=" * 80)
    print("Live API probe")
    print("=" * 80)
    print(f"API base URL: {API_BASE_URL}")
    print(f"Model: {MODEL_NAME}")
    print(f"Env URL: {ENV_URL}")
    try:
        for episode in range(1, NUM_EPISODES + 1):
            result = await call_env_method(env, "reset")
            obs = getattr(result, "observation", None)
            done = bool(getattr(result, "done", False))
            history: list[str] = []
            success = False
            seed_id = _metadata_seed_id(obs)
            print(f"\nEpisode {episode}/{NUM_EPISODES} | seed={seed_id}")

            for step in range(1, MAX_STEPS + 1):
                if done or obs is None:
                    break
                obs_dict = obs.model_dump() if hasattr(obs, "model_dump") else obs.dict()
                role = str(obs_dict.get("role_next", "proposer"))
                is_proposer = role == "proposer"
                metrics["proposer_attempts" if is_proposer else "solver_attempts"] += 1

                try:
                    # The OpenAI client is synchronous; run it off the event loop.
                    code = await asyncio.to_thread(get_model_code, client, obs_dict, history)
                except Exception as exc:
                    print(f"[SYSTEM ERROR] Model generation failed: {exc}", file=sys.stderr, flush=True)
                    # Best effort: resubmit the unchanged code so the episode continues.
                    code = str(obs_dict.get("current_code", ""))

                action = DebugzeroAction(role=role, code=code)
                action_str = compact_action_string(role, code)
                result = await call_env_method(env, "step", action)
                obs = getattr(result, "observation", None)
                done = bool(getattr(result, "done", False))
                reward = float(getattr(result, "reward", 0.0) or 0.0)
                error = extract_env_error(result)
                metrics["rewards"].append(reward)
                metrics["total_steps"] += 1
                if obs is not None and getattr(obs, "syntax_error", False):
                    metrics["proposer_syntax_errors" if is_proposer else "solver_syntax_errors"] += 1
                print(
                    f"  step={step} role={role} reward={reward:.2f} done={str(done).lower()} error={error or 'null'}",
                    flush=True,
                )
                history.append(f"Step {step}: {action_str} -> reward {reward:.2f}")

                if done and obs is not None:
                    success = bool(getattr(obs, "tests_passed", False)) and not bool(
                        getattr(obs, "syntax_error", False)
                    )
                    if success:
                        metrics["episode_successes"] += 1
                        if metrics["representative_success"] is None:
                            metrics["representative_success"] = {
                                # metadata may be None; fall back to the reset-time seed.
                                "seed_id": _metadata_seed_id(obs, seed_id),
                                "steps": step,
                                "reward": reward,
                            }
                    elif metrics["representative_failure"] is None:
                        metrics["representative_failure"] = {
                            "seed_id": _metadata_seed_id(obs, seed_id),
                            "steps": step,
                            "execution_result": getattr(obs, "execution_result", ""),
                        }
                    break

            # Episodes that ran out of steps (or lost their observation) also
            # count as failures.
            if not success and metrics["representative_failure"] is None:
                metrics["representative_failure"] = {
                    "seed_id": seed_id,
                    "steps": min(MAX_STEPS, len(history)),
                    "execution_result": getattr(obs, "execution_result", "") if obs is not None else "",
                }
        return metrics
    finally:
        await call_env_method(env, "close")
async def main() -> None:
    """Run the live API probe and print its summary if the probe actually ran."""
    metrics = await run_live_api_probe()
    if metrics is None:
        # Probe was skipped (missing credentials/model); nothing to summarize.
        return
    print_live_summary(metrics)


if __name__ == "__main__":
    asyncio.run(main())