| """ |
| inference.py -- AuditRepairEnv++ Baseline Inference Agent |
| ========================================================= |
| OpenEnv Submission | Cost-Constrained Ledger Repair |
| |
| STDOUT format (strict -- must match exactly): |
| |
| [START] |
| Task: easy |
| |
| [STEP] |
| Action: FIX_ENTRY 1 |
| Reward: 0.2 |
| |
| [END] |
| Final Score: 0.85 |
| |
| Uses OpenAI Client for LLM calls. |
| Reads env variables: API_BASE_URL, MODEL_NAME, HF_TOKEN |
| Runs all tasks: easy, medium, hard |
| """ |
|
|
| import asyncio |
| import json |
| import os |
| import textwrap |
| import urllib.request |
| import urllib.error |
| from typing import List, Optional |
|
|
| from openai import OpenAI |
|
|
|
|
| |
| |
| |
| HF_TOKEN = os.getenv("HF_TOKEN") |
| API_KEY = HF_TOKEN or os.getenv("API_KEY", "") |
| API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") |
| MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct") |
|
|
| |
| if not HF_TOKEN: |
| raise ValueError( |
| "HF_TOKEN environment variable is required. " |
| "Set it via: export HF_TOKEN='your_token_here'" |
| ) |
| if not API_KEY: |
| raise ValueError( |
| "API_KEY environment variable must be set (or HF_TOKEN)" |
| ) |
|
|
| |
| ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860") |
|
|
| BENCHMARK = "auditrepairenv" |
| TASKS = ["easy", "medium", "hard"] |
|
|
| MAX_STEPS = 15 |
| MAX_TOTAL_REWARD = 2.0 |
| SUCCESS_SCORE_THRESHOLD = 0.5 |
| TEMPERATURE = 0.2 |
| MAX_TOKENS = 300 |
|
|
|
|
| |
| |
| |
| def log_start(task: str) -> None: |
| print(f"\n[START]\nTask: {task}", flush=True) |
|
|
|
|
| def log_step(action: str, reward: float) -> None: |
| action_clean = action.replace("\n", " ").replace("\r", "").strip()[:200] |
| print(f"\n[STEP]\nAction: {action_clean}\nReward: {reward}", flush=True) |
|
|
|
|
| def log_end(score: float) -> None: |
| print(f"\n[END]\nFinal Score: {score}", flush=True) |
|
|
|
|
| |
| |
| |
| def env_request(path: str, method: str = "GET", body: dict = None) -> dict: |
| url = ENV_BASE_URL.rstrip("/") + path |
| data = json.dumps(body or {}).encode() if body is not None else b"{}" |
| req = urllib.request.Request( |
| url, data=data, method=method, |
| headers={"Content-Type": "application/json"} |
| ) |
| try: |
| with urllib.request.urlopen(req, timeout=30) as r: |
| return json.loads(r.read().decode()) |
| except urllib.error.HTTPError as e: |
| return {"error": f"HTTP {e.code}: {e.read().decode()[:100]}"} |
| except Exception as ex: |
| return {"error": str(ex)} |
|
|
|
|
| def env_reset(task_id: str) -> dict: |
| return env_request("/reset", "POST", {"task_id": task_id}) |
|
|
|
|
| def env_step(message: str) -> dict: |
| return env_request("/step", "POST", {"message": message}) |
|
|
|
|
| |
| |
| |
| SYSTEM_PROMPT = textwrap.dedent(""" |
| You are AuditRepairAgent -- an AI that repairs financial ledger inconsistencies. |
| |
| You are given a ledger with entries that may have errors (value != expected_value). |
| Each entry has an id, value, expected_value, and dependencies list. |
| |
| Available actions (respond with exactly ONE per step): |
| FIX_ENTRY <id> -- Sets value = expected_value. May trigger dependency changes. |
| ADJUST_ENTRY <id> <delta> -- Increment/decrement the entry's value by delta. |
| REVERT_ENTRY <id> -- Undo the last change to an entry. |
| NO_OP -- Do nothing. |
| |
| Rules: |
| 1. Each action costs budget. Minimize total actions. |
| 2. Fixing an already-correct entry is overcorrection (penalty). |
| 3. Dependencies: fixing one entry may change expected_value of other entries. |
| 4. Goal: fix all errors within budget. |
| |
| Respond with ONLY the action, nothing else: |
| FIX_ENTRY 3 |
| """).strip() |
|
|
|
|
| def build_prompt(obs: dict, step_num: int, last_echoed: str, |
| last_reward: float, history: List[str]) -> str: |
| """Build user prompt from the current observation.""" |
| ledger_str = "" |
| for entry in obs.get("ledger", []): |
| status = "OK" if entry["value"] == entry["expected_value"] else "ERR" |
| deps = entry.get("dependencies", []) |
| dep_str = f", deps={deps}" if deps else "" |
| ledger_str += ( |
| f" [{status}] id={entry['id']}: value={entry['value']}, " |
| f"expected={entry['expected_value']}{dep_str}\n" |
| ) |
|
|
| errors_str = "" |
| for err in obs.get("errors", []): |
| errors_str += ( |
| f" Entry {err['entry_id']}: value={err['current_value']}, " |
| f"expected={err['expected_value']}, delta={err['delta']}\n" |
| ) |
|
|
| history_block = "\n".join(history[-3:]) if history else "None" |
|
|
| return textwrap.dedent(f""" |
| Task: {obs.get('task_description', '')} |
| Step {step_num} of {obs.get('max_steps', 10)} |
| |
| Ledger: |
| {ledger_str} |
| Current Errors: |
| {errors_str if errors_str else ' None -- all entries correct!'} |
| Budget: {obs.get('remaining_budget', 0)} / {obs.get('initial_budget', 0)} |
| Last result: {last_echoed} |
| Last reward: {last_reward:+.2f} |
| History: {history_block} |
| |
| Respond with the single best action (e.g. FIX_ENTRY 3): |
| """).strip() |
|
|
|
|
| def get_model_message(client: OpenAI, step_num: int, obs: dict, |
| last_echoed: str, last_reward: float, |
| history: List[str]) -> str: |
| """Get agent action from LLM, with fallback to heuristic.""" |
| try: |
| prompt = build_prompt(obs, step_num, last_echoed, last_reward, history) |
| completion = client.chat.completions.create( |
| model=MODEL_NAME, |
| messages=[ |
| {"role": "system", "content": SYSTEM_PROMPT}, |
| {"role": "user", "content": prompt}, |
| ], |
| max_tokens=MAX_TOKENS, |
| temperature=TEMPERATURE, |
| ) |
| response = (completion.choices[0].message.content or "").strip() |
| |
| for line in response.split("\n"): |
| line = line.strip() |
| if line and any( |
| line.upper().startswith(a) |
| for a in ["FIX_ENTRY", "ADJUST_ENTRY", "REVERT_ENTRY", "NO_OP"] |
| ): |
| return line |
| return response.split("\n")[0].strip() if response else "NO_OP" |
| except Exception: |
| |
| return _fallback_action(obs) |
|
|
|
|
| def _fallback_action(obs: dict) -> str: |
| """Deterministic fallback: fix the first error found.""" |
| errors = obs.get("errors", []) |
| if errors: |
| return f"FIX_ENTRY {errors[0]['entry_id']}" |
| return "NO_OP" |
|
|
|
|
| |
| |
| |
| def run_task(client: OpenAI, task_id: str) -> float: |
| """Run a single task episode. Returns score in [0.0, 1.0].""" |
| history: List[str] = [] |
| rewards: List[float] = [] |
| score = 0.0 |
|
|
| log_start(task=task_id) |
|
|
| try: |
| |
| result = env_reset(task_id) |
| if "error" in result: |
| log_end(score=0.0) |
| return 0.0 |
|
|
| obs = result |
| last_echoed = obs.get("echoed_message", "") |
| last_reward = 0.0 |
|
|
| max_steps = obs.get("max_steps", MAX_STEPS) |
|
|
| for step in range(1, max_steps + 1): |
| if obs.get("done", False): |
| break |
|
|
| |
| message = get_model_message( |
| client, step, obs, last_echoed, last_reward, history |
| ) |
|
|
| |
| step_result = env_step(message) |
|
|
| if "error" in step_result and "observation" not in step_result: |
| reward = 0.0 |
| done = False |
| error = step_result["error"][:80] |
| else: |
| reward = float(step_result.get("reward", 0) or 0) |
| done = bool(step_result.get("done", False)) |
| error = step_result.get("last_action_error") |
| obs = step_result.get("observation", obs) |
|
|
| rewards.append(reward) |
| last_echoed = obs.get("echoed_message", "") |
| last_reward = reward |
|
|
| log_step(action=message, reward=reward) |
|
|
| history.append(f"Step {step}: {message!r} -> reward {reward:+.2f}") |
|
|
| if done: |
| |
| info = step_result.get("info", {}) |
| final_score = info.get("final_score") |
| if final_score is not None: |
| score = float(final_score) |
| break |
|
|
| |
| if score == 0.0 and rewards: |
| score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0 |
|
|
| score = min(max(score, 0.0), 1.0) |
|
|
| except Exception: |
| pass |
|
|
| finally: |
| log_end(score=score) |
|
|
| return score |
|
|
|
|
| |
| |
| |
| async def main() -> None: |
| client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) |
|
|
| for task_id in TASKS: |
| run_task(client, task_id) |
|
|
|
|
| if __name__ == "__main__": |
| asyncio.run(main()) |