# code-review-env / inference.py
# codemaverick2
# Code Review Environment OpenEnv hackathon submission
# ff9fcbd
"""
Inference script for the Code Review Environment.
Environment variables:
API_BASE_URL — LLM API endpoint (e.g. https://openrouter.ai/api/v1)
MODEL_NAME — Model identifier (e.g. openai/gpt-4o-mini)
HF_TOKEN — API key for the LLM provider
ENV_URL — Environment base URL (default: http://localhost:7860)
Usage:
export API_BASE_URL=https://openrouter.ai/api/v1
export MODEL_NAME=openai/gpt-4o-mini
export HF_TOKEN=sk-...
python inference.py
"""
from __future__ import annotations
import os
import sys
import json
import time
import httpx
# Runtime configuration, read once from the environment at import time.
# Trailing slashes are stripped from the URLs so that endpoint paths can be
# appended with a plain "/".
API_BASE_URL: str = os.environ.get("API_BASE_URL", "").rstrip("/")
MODEL_NAME: str = os.environ.get("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN: str = os.environ.get("HF_TOKEN", "")
ENV_URL: str = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")

# Task identifiers processed by main(), in this order.
TASK_IDS = ["bug-detection", "security-audit", "comprehensive-review"]

# System message placed first in every conversation. It fixes the reply
# protocol: exactly one raw JSON action object per model response.
SYSTEM_PROMPT = """\
You are an expert software engineer performing a thorough code review.
Your job is to identify bugs, security vulnerabilities, and performance issues in code.
For each issue you find, respond with a single JSON object:
{"action_type": "flag_issue", "line_number": <int>, "filename": "<file>", "issue_type": "bug|security|performance|logic", "severity": "low|medium|high|critical", "description": "<explanation>", "fix_suggestion": "<fix>"}
When done, respond with:
{"action_type": "submit_review"}
Rules:
- Respond with raw JSON only — no markdown fences, no extra text
- One action per response
- Be precise with line numbers (count from line 1)
- Only flag real issues, not style preferences
"""
def chat_completion(messages: list) -> str:
    """Send ``messages`` to the configured chat model and return its reply text.

    A client is built on every call from the module-level settings:
    ``HF_TOKEN`` is used as the API key and ``API_BASE_URL`` is passed as
    ``base_url`` only when it is non-empty.

    Args:
        messages: Chat messages (dicts with ``role`` and ``content``) forwarded
            unchanged to the chat-completions endpoint.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or ``""`` when the response has no choices or the message
        carries no text content.

    Raises:
        ImportError: If the ``openai`` package is not installed.
    """
    # Imported lazily: only this function needs the openai package.
    try:
        from openai import OpenAI
    except ImportError as err:
        raise ImportError("pip install openai") from err

    # "no-key" is a placeholder so the client can still be constructed when
    # HF_TOKEN is empty.
    kwargs = {"api_key": HF_TOKEN or "no-key"}
    if API_BASE_URL:
        kwargs["base_url"] = API_BASE_URL
    client = OpenAI(**kwargs)

    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        temperature=0.0,
        max_tokens=400,
    )

    # The API may return no choices, or a message whose content is None
    # (e.g. a refusal or filtered output); treat both as an empty reply
    # instead of raising IndexError / AttributeError.
    if not response.choices:
        return ""
    content = response.choices[0].message.content
    return (content or "").strip()
def parse_action(text: str) -> dict:
    """Extract the first JSON action object from a model reply.

    If the reply contains Markdown code fences, the first fenced segment that
    starts with ``{`` or ``[`` (after an optional ``json`` tag) replaces the
    text to scan. The text is then scanned left to right; at every ``{`` or
    ``[`` a JSON value is decoded if possible. The first decoded dict is
    returned; for a decoded list, its first dict element is returned.

    Args:
        text: Raw reply text from the model.

    Returns:
        The parsed action dict, or ``{"action_type": "submit_review"}`` when
        no JSON object can be found.
    """
    candidate = text.strip()

    # Prefer the content of a fenced block when one looks like JSON.
    if "```" in candidate:
        for segment in candidate.split("```"):
            segment = segment.strip()
            if segment.startswith("json"):
                segment = segment[4:].strip()
            if segment.startswith(("{", "[")):
                candidate = segment
                break

    decoder = json.JSONDecoder()
    openers = ("{", "[")
    for start, char in enumerate(candidate):
        if char not in openers:
            continue
        try:
            value, _end = decoder.raw_decode(candidate, start)
        except json.JSONDecodeError:
            # Not valid JSON from this position; try the next opener.
            continue
        if isinstance(value, dict):
            return value
        if isinstance(value, list):
            first_dict = next(
                (element for element in value if isinstance(element, dict)),
                None,
            )
            if first_dict is not None:
                return first_dict

    # Nothing usable was found: end the review.
    return {"action_type": "submit_review"}
def run_keyword_fallback(base_url: str, task_id: str) -> dict:
    """Score ``task_id`` through the environment's ``/baseline`` endpoint.

    No LLM is involved: a single ``POST {base_url}/baseline`` is issued and
    the score for ``task_id`` is read from the ``baseline_scores`` mapping in
    the JSON response.

    Args:
        base_url: Environment base URL without a trailing slash.
        task_id: Task whose baseline score should be reported.

    Returns:
        A result dict with ``task_id``, ``score`` (0.0 when the task or its
        score is absent from the response), ``steps`` fixed at 0 and
        ``method`` set to ``"keyword_heuristic"``.
    """
    baseline_url = f"{base_url}/baseline"
    with httpx.Client(timeout=30) as session:
        response = session.post(baseline_url)
        response.raise_for_status()
        payload = response.json()

    task_entry = payload["baseline_scores"].get(task_id, {})
    return {
        "task_id": task_id,
        "score": task_entry.get("score", 0.0),
        "steps": 0,
        "method": "keyword_heuristic",
    }
def _to_score(value: object) -> float:
    """Coerce a score-like value from the environment to float (None -> 0.0)."""
    return 0.0 if value is None else float(value)


def run_task(task_id: str, http_client: httpx.Client) -> dict:
    """Run one review episode for ``task_id`` with the LLM acting as reviewer.

    The environment is reset via ``POST {ENV_URL}/reset``; the returned
    ``code_files`` (filename -> code) are shown to the model together with the
    task description. The model is then asked for one action at a time; each
    action is posted to ``{ENV_URL}/step`` and the environment's feedback is
    appended to the conversation. The loop stops when the environment reports
    ``done``, the model sends ``submit_review``, the step budget is used up,
    or an LLM or step request fails (the failure is printed and the score
    reached so far is kept).

    Args:
        task_id: Identifier of the task to reset the environment to.
        http_client: Open ``httpx.Client`` used for all environment requests.

    Returns:
        A result dict with ``task_id``, ``score`` (float), ``steps`` (number
        of successful environment steps) and ``method`` set to ``"llm"``.

    Raises:
        httpx.HTTPError: If the initial ``/reset`` request fails.
        ImportError: If the LLM client package is missing; this is a setup
            problem, so it is not treated as a per-step failure.
    """
    resp = http_client.post(f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=30)
    resp.raise_for_status()
    obs = resp.json()

    max_steps = obs.get("max_steps", 20)
    code_display = "\n\n".join(
        f"=== {fname} ===\n{code}"
        for fname, code in obs.get("code_files", {}).items()
    )
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": (
                f"Task: {obs.get('task_description', '')}\n\n"
                f"{code_display}\n\n"
                f"Review this code carefully. Flag every issue you find. "
                f"You have {max_steps} steps total."
            ),
        },
    ]

    done = False
    step_count = 0
    final_score = 0.0
    while not done and step_count < max_steps:
        # An LLM failure ends the episode the same way a step failure does,
        # so results of already finished tasks are not lost to a traceback.
        try:
            action_text = chat_completion(messages)
        except ImportError:
            raise
        except Exception as e:
            print(f" LLM error: {e}")
            break
        action = parse_action(action_text)

        try:
            step_resp = http_client.post(f"{ENV_URL}/step", json=action, timeout=30)
            step_resp.raise_for_status()
            obs = step_resp.json()
        except Exception as e:
            print(f" Step error: {e}")
            break

        done = obs.get("done", False)
        step_count += 1
        # A missing or null score is reported as 0.0 rather than crashing the
        # ":.3f" formatting below.
        current_score = _to_score(obs.get("current_score"))
        final_score = current_score
        reward = obs.get("reward")

        messages.append({"role": "assistant", "content": action_text})
        messages.append({
            "role": "user",
            "content": (
                f"Feedback: {obs.get('feedback', '')} "
                f"(step {step_count}/{max_steps}, score: {current_score:.3f})"
            ),
        })

        atype = str(action.get("action_type", ""))
        print(f" Step {step_count:2d}: {atype:20s} | reward={str(reward):8s} | score={current_score:.3f}")

        if atype == "submit_review":
            # The reward of the submit step is taken as the final score; when
            # the environment sends no reward (absent or null), keep the
            # running score instead of collapsing to 0.0.
            if reward is not None:
                final_score = _to_score(reward)
            break

        # Short pause between steps before the next model call.
        time.sleep(0.3)

    return {
        "task_id": task_id,
        "score": float(final_score),
        "steps": step_count,
        "method": "llm",
    }
def main() -> dict:
    """Run every task in ``TASK_IDS`` and print a score summary.

    The LLM agent is used only when both ``HF_TOKEN`` and ``API_BASE_URL`` are
    set; otherwise each task is scored through the environment's built-in
    keyword baseline. Before any task runs, ``{ENV_URL}/health`` is probed and
    the process exits with status 1 if the environment cannot be reached.

    Returns:
        Mapping of task id to the result dict produced for that task.
    """
    use_llm = bool(HF_TOKEN and API_BASE_URL)

    print("Code Review Environment — Inference")
    print(f" Model : {MODEL_NAME}")
    print(f" API URL : {API_BASE_URL or '(not set — using keyword heuristic)'}")
    print(f" Env URL : {ENV_URL}")
    print(f" Tasks : {TASK_IDS}\n")

    # Fail fast with a readable message when the environment is not up.
    try:
        with httpx.Client(timeout=10) as probe:
            health = probe.get(f"{ENV_URL}/health")
            health.raise_for_status()
            print(f" Health: {health.json()}\n")
    except Exception as e:
        print(f"ERROR: Cannot reach environment at {ENV_URL}: {e}")
        sys.exit(1)

    results = {}
    if use_llm:
        # One client is shared by all tasks of the LLM run.
        with httpx.Client(timeout=60) as client:
            for task_id in TASK_IDS:
                print(f"Running task: {task_id}")
                result = run_task(task_id, client)
                results[task_id] = result
                print(f" → score: {result['score']:.4f} ({result['steps']} steps)\n")
    else:
        print("HF_TOKEN / API_BASE_URL not set — using built-in keyword heuristic baseline.\n")
        for task_id in TASK_IDS:
            print(f"Running task: {task_id}")
            result = run_keyword_fallback(ENV_URL, task_id)
            results[task_id] = result
            print(f" → score: {result['score']:.4f}\n")

    print("=" * 50)
    print("INFERENCE RESULTS")
    print("=" * 50)
    for task_id, r in results.items():
        print(f" {task_id:30s} score={r['score']:.4f}")
    # Guard the average so an empty task list does not divide by zero.
    overall = sum(r["score"] for r in results.values()) / len(results) if results else 0.0
    print(f"\n Overall average: {overall:.4f}")
    print("=" * 50)
    return results


if __name__ == "__main__":
    main()