# sql-agent-openenv / inference.py
# ar9avg — "Bulletproof _safe_score for all bad inputs (None, NaN, strings, bool)" (commit 2014920)
"""
SQL Agent OpenEnv β€” Baseline Inference Script
==============================================
Runs a baseline LLM agent against all 3 tasks of the SQL Agent OpenEnv environment.
Environment variables (required):
API_BASE_URL β€” OpenAI-compatible base URL (default: https://router.huggingface.co/v1)
MODEL_NAME β€” Model identifier (default: Qwen/Qwen2.5-72B-Instruct)
HF_TOKEN β€” Hugging Face / API key
STDOUT format (strictly enforced):
[START] task=<task_id> env=sql-agent-openenv model=<model>
[STEP] step=<n> action=<action> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
"""
from __future__ import annotations

import asyncio
import os
import sys
import textwrap
from typing import List, Optional

# ── Path setup (inference.py lives at repo root; backend is a subdirectory) ──
# Prepend the backend directory so `env.sql_env` resolves when running from repo root.
_BACKEND = os.path.join(os.path.dirname(os.path.abspath(__file__)), "backend")
if _BACKEND not in sys.path:
    sys.path.insert(0, _BACKEND)

# These imports must come after the sys.path mutation above, hence the E402 suppressions.
from openai import OpenAI  # noqa: E402
from env.sql_env import SQLAgentEnv, Action, Observation  # noqa: E402
# ── Config ────────────────────────────────────────────────────────────────────
# API key resolution order: HF_TOKEN, then API_KEY, then OPENAI_API_KEY ("" if none set).
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY", "")
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
BENCHMARK = "sql-agent-openenv"  # env name reported in every [START] line
TASKS = ["simple_queries", "join_queries", "complex_queries"]  # one episode per task
MAX_STEPS = 5  # hard cap on environment steps per episode
TEMPERATURE = 0.2
MAX_TOKENS = 50  # the model replies with just an action name, so keep it short
# Valid repair actions the LLM may choose after the first attempt
# ("generate" is the implicit first-step action, see pick_action).
REPAIR_ACTIONS = [
    "rewrite_full",
    "fix_column",
    "fix_table",
    "add_groupby",
    "rewrite_cte",
    "fix_syntax",
    "change_dialect",
    "relax_filter",
]
# System prompt sent verbatim on every LLM call; the model must answer with
# exactly one action name. NOTE(review): the "β€”" separators look like a
# mojibake'd em dash — left untouched here because this is a runtime string.
SYSTEM_PROMPT = textwrap.dedent("""
You are an expert SQL agent interacting with a SQL repair environment.
At each step you receive a natural language question, a database schema,
and optionally the last SQL attempt + error message.
Your job: pick ONE repair action from the list below that is most likely
to fix the SQL error on the next attempt.
Available actions:
generate β€” write fresh SQL from scratch (use on first attempt)
rewrite_full β€” completely rewrite the query from scratch
fix_column β€” fix wrong column name references
fix_table β€” fix wrong table name references
add_groupby β€” add or fix GROUP BY / aggregation clauses
rewrite_cte β€” restructure subqueries or CTEs
fix_syntax β€” fix syntax errors (brackets, commas, keywords)
change_dialect β€” convert to SQLite-compatible functions
relax_filter β€” broaden or remove overly strict WHERE conditions
Reply with ONLY the action name. No explanation. No punctuation.
Example: fix_column
""").strip()
# ── Logging ───────────────────────────────────────────────────────────────────
# Hard bounds: every score/reward we ever emit is clamped to this closed range.
# 0.05 margin guarantees that :.2f and :.3f formatting never produces
# "0.00", "0.000", "1.00", or "1.000" (all of which parse as exactly 0.0 / 1.0).
_MIN_SCORE = 0.05
_MAX_SCORE = 0.95
def _safe_score(x) -> float:
"""Coerce anything (None, NaN, str, bool, int, float) to a float strictly in (0, 1)."""
try:
if x is None:
return _MIN_SCORE
if isinstance(x, bool):
return _MAX_SCORE if x else _MIN_SCORE
v = float(x)
if v != v: # NaN
return _MIN_SCORE
if v == float("inf"):
return _MAX_SCORE
if v == float("-inf"):
return _MIN_SCORE
except (TypeError, ValueError):
return _MIN_SCORE
return max(_MIN_SCORE, min(_MAX_SCORE, v))
def log_start(task: str, model: str) -> None:
    """Emit the structured [START] banner line for one task."""
    banner = f"[START] task={task} env={BENCHMARK} model={model}"
    print(banner, flush=True)
def log_step(step: int, action: str, reward, done: bool, error: Optional[str]) -> None:
    """Emit one structured [STEP] line for a single environment step."""
    score = _safe_score(reward)
    # Use the literal token "null" when there is no usable error text.
    err = error if error else "null"
    if hasattr(err, "replace"):
        # Flatten multi-line errors so the [STEP] line stays single-line.
        flattened = err.replace("\n", " ").strip()
        err = flattened if flattened else "null"
    message = (
        f"[STEP] step={int(step)} action={action or 'noop'} reward={score:.2f} "
        f"done={str(bool(done)).lower()} error={err}"
    )
    print(message, flush=True)
def log_end(success: bool, steps: int, score, rewards: List) -> None:
    """Emit the final structured [END] line for an episode."""
    final = _safe_score(score)
    # Sanitize every reward; guarantee at least one entry so the field is never empty.
    clamped = [_safe_score(r) for r in (rewards or [])] or [_MIN_SCORE]
    joined = ",".join(f"{r:.2f}" for r in clamped)
    flag = str(bool(success)).lower()
    print(
        f"[END] success={flag} steps={int(steps)} "
        f"score={final:.3f} rewards={joined}",
        flush=True,
    )
# ── LLM helper ────────────────────────────────────────────────────────────────
def pick_action(
    client: OpenAI,
    obs: Observation,
    step: int,
) -> str:
    """Ask the LLM to pick a repair action given the current observation."""
    # First attempt (or no SQL produced yet): always start with fresh generation.
    if step == 1 or obs.current_sql is None:
        return "generate"
    prompt = textwrap.dedent(f"""
Question: {obs.question}
Current SQL (failed):
{obs.current_sql}
Error: {obs.error_message or "unknown"}
Error class: {obs.error_class or "unknown"}
Attempt number: {obs.attempt_number} of {obs.max_attempts}
Which repair action should I use next?
""").strip()
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
        )
        reply = (response.choices[0].message.content or "").strip().lower()
    except Exception as exc:
        # Best-effort: any API failure degrades to a full rewrite.
        print(f"[DEBUG] LLM call failed: {exc}", flush=True)
        return "rewrite_full"
    # Normalise the free-form reply to the first known action it mentions.
    for candidate in REPAIR_ACTIONS:
        if candidate in reply:
            return candidate
    return "rewrite_full"
# ── Single-episode runner ─────────────────────────────────────────────────────
async def run_episode(
    env: SQLAgentEnv,
    client: OpenAI,
    task_id: str,
) -> None:
    """Run one full episode for a task, emitting structured stdout logs.

    Always prints exactly one [START] line, at least one [STEP] line, and
    exactly one [END] line (guaranteed by the try/except/finally layout),
    regardless of environment or LLM failures.
    """
    log_start(task=task_id, model=MODEL_NAME)
    rewards: List[float] = []
    steps_taken = 0
    score = _MIN_SCORE
    success = False
    last_error: Optional[str] = None
    try:
        try:
            obs = env.reset(task_id)
        except Exception as exc:
            # Reset failed: emit a synthetic step so output stays parseable,
            # then bail out (the finally below still emits [END]).
            log_step(step=1, action="reset", reward=_MIN_SCORE, done=True, error=str(exc))
            rewards.append(_MIN_SCORE)
            steps_taken = 1
            return
        for step in range(1, MAX_STEPS + 1):
            try:
                action_name = pick_action(client, obs, step)
            except Exception:
                # LLM selection failed entirely — fall back to fresh generation.
                action_name = "generate"
            action = Action(repair_action=action_name)
            try:
                obs, reward_info = await env.step(action)
            except Exception as exc:
                # A failed step still produces a [STEP] line and a floor reward.
                log_step(step=step, action=action_name, reward=_MIN_SCORE, done=True, error=str(exc))
                rewards.append(_MIN_SCORE)
                steps_taken = step
                break
            # Defensive getattr: tolerate reward_info/obs shapes missing fields.
            reward = _safe_score(getattr(reward_info, "value", None))
            done = bool(getattr(reward_info, "done", False))
            last_error = getattr(obs, "error_message", None)
            success = bool(getattr(reward_info, "success", False))
            rewards.append(reward)
            steps_taken = step
            log_step(
                step=step,
                action=action_name,
                reward=reward,
                done=done,
                error=last_error,
            )
            if done:
                break
        # Episode score = mean of per-step rewards, re-clamped into (0, 1).
        denom = max(len(rewards), 1)
        avg = sum(rewards) / denom if rewards else _MIN_SCORE
        score = _safe_score(avg)
    except Exception as exc:
        # Catch-all so we always emit a valid [END] line
        log_step(step=steps_taken or 1, action="error", reward=_MIN_SCORE, done=True, error=str(exc))
        if not rewards:
            rewards.append(_MIN_SCORE)
        score = _MIN_SCORE
    finally:
        log_end(
            success=success,
            steps=max(int(steps_taken), 1),
            score=score,
            rewards=rewards,
        )
# ── Main ──────────────────────────────────────────────────────────────────────
async def main() -> None:
    """Run the baseline agent over every task, always emitting valid log lines."""
    try:
        llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        environment = SQLAgentEnv()
    except Exception as exc:
        # Environment couldn't init — still emit a valid [START]/[STEP]/[END] per task
        for task in TASKS:
            log_start(task=task, model=MODEL_NAME)
            log_step(step=1, action="init_error", reward=_MIN_SCORE, done=True, error=str(exc))
            log_end(success=False, steps=1, score=_MIN_SCORE, rewards=[_MIN_SCORE])
            print("", flush=True)
        return
    for task in TASKS:
        try:
            await run_episode(environment, llm_client, task)
        except Exception as exc:
            # run_episode already has its own catch-all, but guard against anything leaking
            log_end(success=False, steps=1, score=_MIN_SCORE, rewards=[_MIN_SCORE])
            print(f"[DEBUG] run_episode({task}) crashed: {exc}", flush=True)
        # Blank separator line between tasks.
        print("", flush=True)
if __name__ == "__main__":
asyncio.run(main())