Spaces:
Sleeping
Sleeping
| """Baseline inference script for the Python code-review environment.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import os | |
| import re | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| from openai import OpenAI | |
| from client import PythonEnv | |
| from models import ActionType, PythonReviewAction | |
| # Read all runtime configuration from environment variables so the script can | |
| # be reused unchanged across local runs, CI, and HF Spaces validation. | |
| API_BASE_URL = os.environ["API_BASE_URL"] | |
| MODEL_NAME = os.environ["MODEL_NAME"] | |
| API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") | |
| ENV_BASE_URL = os.getenv("ENV_BASE_URL") | |
| DOCKER_IMAGE = os.getenv("PYTHON_ENV_IMAGE", "python_env-env:latest") | |
| MAX_STEPS = int(os.getenv("MAX_STEPS", "25")) | |
| REPORT_PATH = Path(os.getenv("INFERENCE_REPORT_PATH", "inference_results.json")) | |
| TEMPERATURE = float(os.getenv("TEMPERATURE", "0")) | |
| MAX_TOKENS = int(os.getenv("MAX_TOKENS", "900")) | |
| TASK_IDS = ["task_easy", "task_medium", "task_hard"] | |
| SYSTEM_PROMPT = """You are a precise senior Python code reviewer. | |
| Return strict JSON using this schema: | |
| { | |
| "action_type": "ADD_COMMENT|APPROVE|REQUEST_CHANGES|ASK_CONTEXT|SKIP_LINE", | |
| "line_number": 1, | |
| "issue_type": "STYLE|LOGIC|SECURITY|PERFORMANCE|DOCS", | |
| "severity": "LOW|MEDIUM|HIGH|CRITICAL", | |
| "comment": "why this matters", | |
| "suggestion": "optional fix suggestion", | |
| "question": "optional context question" | |
| } | |
| Rules: | |
| - Output JSON only. No markdown fences. | |
| - Only report issues supported by the visible code. | |
| - Use one action per step. | |
| - Prefer high precision over quantity. | |
| - Use REQUEST_CHANGES once you believe the code should be rejected. | |
| - Use APPROVE only when the snippet is genuinely clean. | |
| """ | |
| def _build_prompt(observation, step: int, history: List[str]) -> str: | |
| """Build the task prompt sent to the model for one step.""" | |
| numbered_lines = "\n".join( | |
| f"{index + 1:>3}: {line}" for index, line in enumerate(observation.lines) | |
| ) | |
| history_text = "\n".join(history[-4:]) if history else "No previous attempts." | |
| return ( | |
| f"Task ID: {observation.task_id}\n" | |
| f"Step: {step}\n" | |
| f"Current score: {observation.metrics.current_score:.2f}\n" | |
| f"Last reward: {observation.reward_summary.step_reward:.2f}\n" | |
| f"Cumulative reward: {observation.reward_summary.cumulative_reward:.2f}\n" | |
| f"Latest feedback: {observation.feedback or 'None'}\n" | |
| f"Attempt history:\n{history_text}\n\n" | |
| f"Filename: {observation.filename}\n" | |
| f"Context: {observation.context or 'None'}\n" | |
| "Code to review:\n" | |
| f"{numbered_lines}" | |
| ) | |
| def _extract_text_content(message_content: Any) -> str: | |
| """Normalize OpenAI response content into one text string.""" | |
| if isinstance(message_content, str): | |
| return message_content | |
| if isinstance(message_content, list): | |
| parts: List[str] = [] | |
| for item in message_content: | |
| if isinstance(item, dict): | |
| text = item.get("text") | |
| if isinstance(text, str): | |
| parts.append(text) | |
| return "\n".join(parts) | |
| return "" | |
| def _extract_json_blob(content: str) -> str: | |
| """Extract a JSON object from plain or fenced model output.""" | |
| fenced_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", content, re.DOTALL) | |
| if fenced_match: | |
| return fenced_match.group(1) | |
| start = content.find("{") | |
| end = content.rfind("}") | |
| if start != -1 and end != -1 and end > start: | |
| return content[start : end + 1] | |
| return content | |
def _parse_response(content: str) -> Dict[str, Any]:
    """Parse the model response into a normalized payload dict.

    Returns the decoded JSON object, or ``{"_parse_error": raw}`` when the
    text is not valid JSON *or* decodes to a non-object. The non-object
    guard matters because valid JSON such as ``null``, a list, or a bare
    number would otherwise be returned as-is, and every caller immediately
    invokes ``payload.get(...)``, which would raise AttributeError.
    """
    raw = _extract_json_blob(content)
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return {"_parse_error": raw}
    if not isinstance(data, dict):
        # Valid JSON but not an object -> treat like a parse failure.
        return {"_parse_error": raw}
    return data
def _completion(client: OpenAI, prompt: str) -> Dict[str, Any]:
    """Request one review action from the model and parse its reply."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    response = client.chat.completions.create(
        model=MODEL_NAME,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        messages=messages,
    )
    # Empty/odd content degrades to "{}" so parsing yields a payload dict.
    raw_text = _extract_text_content(response.choices[0].message.content)
    return _parse_response(raw_text or "{}")
def _build_fallback_action(observation, note: str) -> PythonReviewAction:
    """Create a safe action when the model output is unusable.

    On the final available step we must commit to a verdict, so a
    REQUEST_CHANGES action is returned; earlier steps ask for context and
    surface *note* as the question.
    """
    on_last_step = observation.current_step + 1 >= observation.max_steps
    if on_last_step:
        return PythonReviewAction(action_type=ActionType.REQUEST_CHANGES, question=None)
    return PythonReviewAction(action_type=ActionType.ASK_CONTEXT, question=note)
def _to_action(
    payload: Dict[str, Any],
    observation,
) -> PythonReviewAction:
    """Validate the model payload into an environment action, with fallback."""
    try:
        action = PythonReviewAction.model_validate(payload)
    except Exception:
        # Anything the schema rejects (missing fields, bad enums, junk)
        # funnels into the fallback path rather than crashing the episode.
        if payload.get("_parse_error"):
            note = "Model returned no valid action. Raw response could not be parsed as JSON."
        else:
            note = "Model returned no valid action."
        return _build_fallback_action(observation, note)
    return action
def _make_env():
    """Return a synchronous environment handle.

    Connects to a live server when ENV_BASE_URL is set; otherwise launches
    the configured Docker image locally.
    """
    if ENV_BASE_URL:
        remote_env = PythonEnv(base_url=ENV_BASE_URL)
        return remote_env.sync()
    launched_env = asyncio.run(PythonEnv.from_docker_image(DOCKER_IMAGE))
    return launched_env.sync()
| def _task_result_dict(observation, step_logs: List[Dict[str, Any]]) -> Dict[str, Any]: | |
| """Build the report payload for one completed task run.""" | |
| return { | |
| "task_id": observation.task_id, | |
| "snippet_id": observation.snippet_id, | |
| "score": observation.metrics.current_score, | |
| "precision": observation.metrics.precision, | |
| "recall": observation.metrics.recall, | |
| "f1": observation.metrics.f1, | |
| "true_positives": observation.metrics.true_positives, | |
| "false_positives": observation.metrics.false_positives, | |
| "missed_issues": observation.metrics.missed_issues, | |
| "cumulative_reward": observation.metrics.cumulative_reward, | |
| "steps": step_logs, | |
| } | |
def _run_task(client: OpenAI, env, task_id: str, index: int) -> Dict[str, Any]:
    """Run one benchmark task episode to completion.

    Resets the environment to *task_id*, loops up to MAX_STEPS asking the
    model for one action per step, and returns the report entry for the
    finished episode. Model/transport errors are recorded in the step log
    and converted into safe fallback actions rather than aborting the run.
    """
    result = env.reset(task_id=task_id)
    observation = result.observation
    history: List[str] = []
    step_logs: List[Dict[str, Any]] = []
    print(f"Task {index}: {task_id} ({observation.snippet_id})")
    for step in range(1, MAX_STEPS + 1):
        prompt = _build_prompt(observation, step, history)
        try:
            payload = _completion(client, prompt)
        except Exception as exc:
            # Keep the episode alive: _to_action maps this to a fallback.
            payload = {"_error": str(exc)}
        action = _to_action(payload=payload, observation=observation)
        result = env.step(action)
        observation = result.observation
        reward = result.reward or 0.0
        step_log: Dict[str, Any] = {
            "step": step,
            "action_type": action.action_type.value,
            "line_number": action.line_number,
            "reward": reward,
            "score": observation.metrics.current_score,
            "done": result.done,
            "feedback": observation.feedback,
        }
        if payload.get("_error"):
            step_log["model_error"] = payload["_error"]
        if payload.get("_parse_error"):
            step_log["parse_error"] = True
        step_logs.append(step_log)
        history.append(
            f"step={step} action={action.action_type.value} "
            f"line={action.line_number} score={observation.metrics.current_score:.2f} "
            f"reward={reward:.2f} feedback={observation.feedback}"
        )
        print(
            f" step={step} action={action.action_type.value} "
            f"score={observation.metrics.current_score:.2f} reward={reward:.2f} "
            f"done={result.done}"
        )
        if result.done:
            break
    return _task_result_dict(observation, step_logs)


def main() -> None:
    """Run the configured model against the benchmark task set.

    Executes every task in TASK_IDS, always closes the environment, then
    writes the aggregate JSON report to REPORT_PATH and echoes it to stdout.

    Raises:
        RuntimeError: if no API key is available in the environment.
    """
    if not API_KEY:
        raise RuntimeError("Set HF_TOKEN or OPENAI_API_KEY before running inference.py")
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env = _make_env()
    episode_results: List[Dict[str, Any]] = []
    try:
        for index, task_id in enumerate(TASK_IDS, start=1):
            episode_results.append(_run_task(client, env, task_id, index))
    finally:
        env.close()
    # Mean defaults to 0.0 when no episode completed.
    if episode_results:
        mean_score = sum(item["score"] for item in episode_results) / len(episode_results)
    else:
        mean_score = 0.0
    summary = {
        "model_name": MODEL_NAME,
        "api_base_url": API_BASE_URL,
        "task_count": len(episode_results),
        "mean_score": mean_score,
        "results": episode_results,
    }
    REPORT_PATH.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(json.dumps(summary, indent=2))
    print(f"\nSaved report to {REPORT_PATH}")


if __name__ == "__main__":
    main()