# python_env / inference.py
# darshanajudiya7's picture
# Upload folder using huggingface_hub
# d25ab77 verified
"""Baseline inference script for the Python code-review environment."""
from __future__ import annotations
import asyncio
import json
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional
from openai import OpenAI
from client import PythonEnv
from models import ActionType, PythonReviewAction
# Read all runtime configuration from environment variables so the script can
# be reused unchanged across local runs, CI, and HF Spaces validation.
# Required: OpenAI-compatible endpoint URL and model name to benchmark.
API_BASE_URL = os.environ["API_BASE_URL"]
MODEL_NAME = os.environ["MODEL_NAME"]
# Either credential works; main() raises a clear error when both are unset.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
# Optional: point at an already-running environment; otherwise _make_env()
# launches DOCKER_IMAGE locally.
ENV_BASE_URL = os.getenv("ENV_BASE_URL")
DOCKER_IMAGE = os.getenv("PYTHON_ENV_IMAGE", "python_env-env:latest")
MAX_STEPS = int(os.getenv("MAX_STEPS", "25"))  # per-task step budget
REPORT_PATH = Path(os.getenv("INFERENCE_REPORT_PATH", "inference_results.json"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))  # sampling temperature
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "900"))  # completion token cap
# Task ids exercised by main(), run in listed order.
TASK_IDS = ["task_easy", "task_medium", "task_hard"]
# System message sent on every completion call. The reply is parsed as strict
# JSON downstream, so the schema and the "JSON only" rule here must stay in
# sync with _parse_response / PythonReviewAction.
SYSTEM_PROMPT = """You are a precise senior Python code reviewer.
Return strict JSON using this schema:
{
"action_type": "ADD_COMMENT|APPROVE|REQUEST_CHANGES|ASK_CONTEXT|SKIP_LINE",
"line_number": 1,
"issue_type": "STYLE|LOGIC|SECURITY|PERFORMANCE|DOCS",
"severity": "LOW|MEDIUM|HIGH|CRITICAL",
"comment": "why this matters",
"suggestion": "optional fix suggestion",
"question": "optional context question"
}
Rules:
- Output JSON only. No markdown fences.
- Only report issues supported by the visible code.
- Use one action per step.
- Prefer high precision over quantity.
- Use REQUEST_CHANGES once you believe the code should be rejected.
- Use APPROVE only when the snippet is genuinely clean.
"""
def _build_prompt(observation, step: int, history: List[str]) -> str:
"""Build the task prompt sent to the model for one step."""
numbered_lines = "\n".join(
f"{index + 1:>3}: {line}" for index, line in enumerate(observation.lines)
)
history_text = "\n".join(history[-4:]) if history else "No previous attempts."
return (
f"Task ID: {observation.task_id}\n"
f"Step: {step}\n"
f"Current score: {observation.metrics.current_score:.2f}\n"
f"Last reward: {observation.reward_summary.step_reward:.2f}\n"
f"Cumulative reward: {observation.reward_summary.cumulative_reward:.2f}\n"
f"Latest feedback: {observation.feedback or 'None'}\n"
f"Attempt history:\n{history_text}\n\n"
f"Filename: {observation.filename}\n"
f"Context: {observation.context or 'None'}\n"
"Code to review:\n"
f"{numbered_lines}"
)
def _extract_text_content(message_content: Any) -> str:
"""Normalize OpenAI response content into one text string."""
if isinstance(message_content, str):
return message_content
if isinstance(message_content, list):
parts: List[str] = []
for item in message_content:
if isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "\n".join(parts)
return ""
def _extract_json_blob(content: str) -> str:
"""Extract a JSON object from plain or fenced model output."""
fenced_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", content, re.DOTALL)
if fenced_match:
return fenced_match.group(1)
start = content.find("{")
end = content.rfind("}")
if start != -1 and end != -1 and end > start:
return content[start : end + 1]
return content
def _parse_response(content: str) -> Dict[str, Any]:
    """Parse the model response into a normalized payload dict.

    Returns the decoded JSON object on success. When the text is not valid
    JSON, or decodes to something other than an object (a bare list, number,
    or string), returns ``{"_parse_error": raw}`` so every caller can rely
    on dict semantics.
    """
    raw = _extract_json_blob(content)
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return {"_parse_error": raw}
    # json.loads happily yields lists/numbers/strings; downstream code calls
    # .get() on the payload, so anything that is not a dict must be treated
    # as a parse error instead of leaking an AttributeError later.
    if not isinstance(data, dict):
        return {"_parse_error": raw}
    return data
def _completion(client: OpenAI, prompt: str) -> Dict[str, Any]:
    """Run one chat completion and return the parsed JSON payload.

    An empty or unreadable response body falls back to parsing ``"{}"`` so
    the caller always receives a dict.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    response = client.chat.completions.create(
        model=MODEL_NAME,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        messages=messages,
    )
    raw = _extract_text_content(response.choices[0].message.content)
    return _parse_response(raw or "{}")
def _build_fallback_action(observation, note: str) -> PythonReviewAction:
    """Create a safe fallback action when model output is unusable.

    On the final step this emits REQUEST_CHANGES (a terminal, conservative
    verdict); earlier steps ask for context and surface *note* as the
    question.
    """
    on_final_step = observation.current_step + 1 >= observation.max_steps
    if on_final_step:
        return PythonReviewAction(
            action_type=ActionType.REQUEST_CHANGES,
            question=None,
        )
    return PythonReviewAction(action_type=ActionType.ASK_CONTEXT, question=note)
def _to_action(
    payload: Dict[str, Any],
    observation,
) -> PythonReviewAction:
    """Convert a parsed model payload into a valid environment action.

    Payloads that fail schema validation — or are not dicts at all — are
    replaced with a safe fallback action rather than propagating the error.
    """
    if not isinstance(payload, dict):
        # Defensive: a non-dict payload would make payload.get() below raise
        # AttributeError instead of yielding the intended fallback.
        return _build_fallback_action(observation, "Model returned no valid action.")
    try:
        return PythonReviewAction.model_validate(payload)
    except Exception:
        note = "Model returned no valid action."
        if payload.get("_parse_error"):
            note = f"{note} Raw response could not be parsed as JSON."
        return _build_fallback_action(observation, note)
def _make_env():
    """Connect to a live environment or launch the Docker image.

    A configured ENV_BASE_URL takes precedence; otherwise a container is
    started from DOCKER_IMAGE. Either path returns the synchronous wrapper.
    """
    if not ENV_BASE_URL:
        env = asyncio.run(PythonEnv.from_docker_image(DOCKER_IMAGE))
        return env.sync()
    return PythonEnv(base_url=ENV_BASE_URL).sync()
def _task_result_dict(observation, step_logs: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Build the report payload for one completed task run."""
return {
"task_id": observation.task_id,
"snippet_id": observation.snippet_id,
"score": observation.metrics.current_score,
"precision": observation.metrics.precision,
"recall": observation.metrics.recall,
"f1": observation.metrics.f1,
"true_positives": observation.metrics.true_positives,
"false_positives": observation.metrics.false_positives,
"missed_issues": observation.metrics.missed_issues,
"cumulative_reward": observation.metrics.cumulative_reward,
"steps": step_logs,
}
def main() -> None:
    """Run the configured model against the benchmark task set.

    For each id in TASK_IDS: reset the environment, loop the model for at
    most MAX_STEPS steps (stopping early when the episode reports done),
    and collect per-step logs. Afterwards a JSON summary is written to
    REPORT_PATH and printed to stdout.

    Raises:
        RuntimeError: if neither HF_TOKEN nor OPENAI_API_KEY is set.
    """
    if not API_KEY:
        raise RuntimeError("Set HF_TOKEN or OPENAI_API_KEY before running inference.py")
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env = _make_env()
    episode_results: List[Dict[str, Any]] = []
    try:
        for index, task_id in enumerate(TASK_IDS, start=1):
            result = env.reset(task_id=task_id)
            observation = result.observation
            # Compact per-step summaries; the last four are fed back into
            # each subsequent prompt by _build_prompt.
            history: List[str] = []
            step_logs: List[Dict[str, Any]] = []
            print(f"Task {index}: {task_id} ({observation.snippet_id})")
            for step in range(1, MAX_STEPS + 1):
                prompt = _build_prompt(observation, step, history)
                try:
                    payload = _completion(client, prompt)
                except Exception as exc:
                    # API/network failures become a payload marker so the
                    # step still produces a valid fallback action below.
                    payload = {"_error": str(exc)}
                action = _to_action(payload=payload, observation=observation)
                result = env.step(action)
                observation = result.observation
                step_log = {
                    "step": step,
                    "action_type": action.action_type.value,
                    "line_number": action.line_number,
                    # A missing reward is recorded as 0.0.
                    "reward": result.reward or 0.0,
                    "score": observation.metrics.current_score,
                    "done": result.done,
                    "feedback": observation.feedback,
                }
                if payload.get("_error"):
                    step_log["model_error"] = payload["_error"]
                if payload.get("_parse_error"):
                    step_log["parse_error"] = True
                step_logs.append(step_log)
                history.append(
                    f"step={step} action={action.action_type.value} "
                    f"line={action.line_number} score={observation.metrics.current_score:.2f} "
                    f"reward={(result.reward or 0.0):.2f} feedback={observation.feedback}"
                )
                print(
                    f"  step={step} action={action.action_type.value} "
                    f"score={observation.metrics.current_score:.2f} reward={(result.reward or 0.0):.2f} "
                    f"done={result.done}"
                )
                if result.done:
                    break
            episode_results.append(_task_result_dict(observation, step_logs))
    finally:
        # Always release the environment (and any container it launched),
        # even when the loop above raises.
        env.close()
    mean_score = sum(item["score"] for item in episode_results) / len(episode_results) if episode_results else 0.0
    summary = {
        "model_name": MODEL_NAME,
        "api_base_url": API_BASE_URL,
        "task_count": len(episode_results),
        "mean_score": mean_score,
        "results": episode_results,
    }
    REPORT_PATH.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(json.dumps(summary, indent=2))
    print(f"\nSaved report to {REPORT_PATH}")
# Allow `python inference.py` while keeping the module importable.
if __name__ == "__main__":
    main()