ai_backend / inference.py
sumit989's picture
Update inference.py
97ae92b verified
import os
import json
from openai import OpenAI
# ── Config ────────────────────────────────────────────────────────────────────
# Endpoint, model name and benchmark label can each be overridden through the
# environment; the literals below are the fallbacks.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
BENCHMARK = os.environ.get("BENCHMARK", "code-fix-env")

# API_KEY wins when set and non-empty; otherwise HF_TOKEN is used.
API_KEY = os.environ.get("API_KEY") or os.environ.get("HF_TOKEN")
if not API_KEY:
    # Fail at import time rather than on the first request.
    raise ValueError("Missing API_KEY or HF_TOKEN environment variable")

# Shared client used by solve() for every chat-completion request.
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
# ── Shared grader ─────────────────────────────────────────────────────────────
def grader(task_id, output) -> float:
    """Score a model reply by which JSON fields it contains.

    The score depends only on the structure of ``output``; the content of the
    fields is not inspected.

    Args:
        task_id: Identifier of the task being graded. Currently unused; kept
            so every grader shares the ``(task_id, output)`` call signature.
        output: Raw model reply, expected to be a JSON object encoded as text.

    Returns:
        0.0 if ``output`` is empty or ``None``;
        0.2 if it cannot be parsed as JSON or is not a JSON object;
        otherwise 0.5, plus 0.2 for a non-empty "fixed_code", 0.2 for a
        non-empty "explanation" and 0.1 for a non-empty "language",
        capped at 1.0 and rounded to two decimals.
    """
    if not output:
        return 0.0

    # Only the parse itself is guarded, so a genuine bug in the scoring below
    # surfaces instead of being silently graded as 0.2.
    try:
        data = json.loads(output)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError; TypeError covers a reply
        # that is not str/bytes at all.
        return 0.2

    # Valid JSON that is not an object (list, string, number, null) has no
    # fields to score; treat it the same as an unparseable reply.
    if not isinstance(data, dict):
        return 0.2

    score = 0.5
    field_weights = (("fixed_code", 0.2), ("explanation", 0.2), ("language", 0.1))
    for field, weight in field_weights:
        if data.get(field):
            score += weight

    # Rounding also absorbs float accumulation error (e.g. 0.8999999999999999).
    return round(min(score, 1.0), 2)
# ── Tasks ─────────────────────────────────────────────────────────────────────
# Benchmark tasks. Each entry carries:
#   "id"       - task name used in the log lines,
#   "input"    - the buggy snippet sent to the model,
#   "expected" - a plain-text hint describing the fix, included in the prompt,
#   "grader"   - callable invoked as grader(task_id, output) to score the reply.
# The "expected" hints use a real em dash; the previous text was a mis-decoded
# byte sequence that was sent to the model verbatim.
TASKS = [
    {
        "id": "task_1",
        "input": "def add(a,b): return a-b",
        "expected": "fix subtraction bug — should return a+b",
        "grader": grader,
    },
    {
        "id": "task_2",
        "input": "function x() { return 1+ }",
        "expected": "fix syntax error — incomplete expression",
        "grader": grader,
    },
    {
        "id": "task_3",
        "input": "async function f(){ fetchData() }",
        "expected": "fix missing await before fetchData()",
        "grader": grader,
    },
]
# ── Solver ────────────────────────────────────────────────────────────────────
def solve(task: dict) -> str:
    """Ask the model to fix one task's code, grade the reply and log the run.

    Progress is reported on stdout as ``[START]``, ``[STEP]`` and ``[END]``
    lines, each flushed immediately.

    Args:
        task: Task mapping with the keys "id", "input", "expected" and
            "grader"; the grader is invoked as ``grader(task_id, output)``.

    Returns:
        The raw model reply (an empty string when the API returns no content),
        or ``"Error: <message>"`` if the request or the grading raised.

    Raises:
        KeyError: If ``task`` lacks "id" or "input"; these are read before the
            error handling starts.
    """
    task_name = task["id"]
    task_input = task["input"]
    print(f"[START] task={task_name} env={BENCHMARK} model={MODEL_NAME}", flush=True)

    system_prompt = (
        "You are an expert developer.\n"
        "Return ONLY valid JSON — no markdown, no preamble.\n"
        "Explanation must be MAX 2 lines.\n"
        "Fixed code must be SHORT and COMPLETE.\n"
        "Preserve all newlines and indentation in fixed_code.\n"
        "Format strictly:\n"
        '{"explanation":"...","fixed_code":"...","language":"..."}'
    )
    user_prompt = (
        f"Expected fix: {task['expected']}\n"
        f"Fix this code:\n{task_input}"
    )

    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            max_tokens=2048,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        # The message content can be None; normalise it so the function
        # always returns a str. The grader scores "" and None identically.
        output = response.choices[0].message.content or ""
        score = task["grader"](task_name, output)
        print(f"[STEP] step=1 action=solve reward={round(score, 2)} done=false error=null", flush=True)
        print(f"[STEP] step=2 action=grade reward={round(score, 2)} done=true error=null", flush=True)
        print(f"[END] success=true steps=2 score={score} rewards={score},{score}", flush=True)
        return output
    except Exception as e:
        # Top-level boundary: any failure is reported in the log and returned
        # as text. Newlines are flattened so the error stays on one log line.
        err = str(e).replace("\n", " ")
        print(f"[STEP] step=1 action=solve reward=0.00 done=true error={err}", flush=True)
        print("[END] success=false steps=1 score=0.0 rewards=0.00", flush=True)
        return f"Error: {err}"
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Script mode solves only the first entry of TASKS.
    # NOTE(review): the remaining tasks are never run from here - confirm
    # that this is intentional.
    solve(task=TASKS[0])