SimranShaikh committed on
Commit
cfbd548
·
verified ·
1 Parent(s): 998a566
Files changed (1) hide show
  1. inference.py +302 -0
inference.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference.py — Baseline inference script for CodeReview-Env.
3
+
4
+ Runs an LLM agent through all 3 tasks and logs results in the
5
+ mandatory [START] / [STEP] / [END] format required by OpenEnv evaluators.
6
+
7
+ Environment variables required:
8
+ API_BASE_URL — LLM API base URL (OpenAI-compatible)
9
+ MODEL_NAME — model identifier (e.g. gpt-4o-mini)
10
+ HF_TOKEN — Hugging Face / API key
11
+ SPACE_URL — URL of deployed HF Space (e.g. https://my-space.hf.space)
12
+ defaults to http://localhost:7860
13
+ """
14
+
15
+ import json
16
+ import os
17
+ import sys
18
+ import time
19
+ from typing import Any, Dict, List, Optional
20
+
21
+ import httpx
22
+ from openai import OpenAI
23
+
24
+ # ─────────────────────────────────────────────────────────────
25
+ # Config
26
+ # ─────────────────────────────────────────────────────────────
27
# Endpoint and model settings are read from the environment, with defaults.
API_BASE_URL: str = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME: str = os.getenv("MODEL_NAME", "gpt-4o-mini")
# HF_TOKEN takes precedence; OPENAI_API_KEY is the fallback; "" if neither is set.
API_KEY: str = os.getenv("HF_TOKEN", os.getenv("OPENAI_API_KEY", ""))
# Trailing slashes are removed so paths can be appended as f"{SPACE_URL}/...".
SPACE_URL: str = os.getenv("SPACE_URL", "http://localhost:7860").rstrip("/")

BENCHMARK = "CodeReview-Env"  # value logged as "env" in START records
MAX_TOKENS = 1024  # max_tokens passed to the chat-completion call
SUCCESS_SCORE_THRESHOLD = 0.6  # an episode score >= this counts as a pass

# Task ids, run in this order.
TASKS = ["easy_syntax", "medium_logic", "hard_security"]
37
+
38
+
39
+ # ─────────────────────────────────────────────────────────────
40
+ # Structured stdout logging (MANDATORY format)
41
+ # ─────────────────────────────────────────────────────────────
42
+
43
def log_start(task: str, env: str, model: str) -> None:
    """Print the episode-opening record as a single JSON line on stdout.

    The record has ``"type": "START"`` followed by the given ``task``,
    ``env`` and ``model`` values. stdout is flushed immediately.
    """
    # NOTE(review): the file header talks about a "[START] / [STEP] / [END]"
    # format, while this emits a JSON object tagged "START" — confirm that the
    # evaluator accepts JSON lines.
    record = {"type": "START", "task": task, "env": env, "model": model}
    print(json.dumps(record), flush=True)
48
+
49
+
50
def log_step(
    step: int,
    action: Any,
    reward: float,
    done: bool,
    error: Optional[str] = None,
) -> None:
    """Print one per-step record as a single JSON line on stdout.

    Args:
        step: Step number to report.
        action: Any object; it is converted with ``str()`` and cut to the
            first 300 characters to keep log lines readable.
        reward: Reward to report for this step.
        done: Whether the episode finished on this step.
        error: Optional error text; serialised as ``null`` when ``None``.
    """
    action_text = str(action)
    record = {
        "type": "STEP",
        "step": step,
        "action": action_text[:300],
        "reward": reward,
        "done": done,
        "error": error,
    }
    line = json.dumps(record)
    print(line, flush=True)
70
+
71
+
72
def log_end(
    success: bool, steps: int, score: float, rewards: List[float]
) -> None:
    """Print the episode-closing record as a single JSON line on stdout.

    The record has ``"type": "END"`` followed by ``success``, ``steps``,
    ``score`` and ``rewards`` exactly as passed in. stdout is flushed
    immediately.
    """
    # Keyword order is preserved, so the keys are serialised in this order.
    summary = dict(
        type="END",
        success=success,
        steps=steps,
        score=score,
        rewards=rewards,
    )
    print(json.dumps(summary), flush=True)
87
+
88
+
89
+ # ─────────────────────────────────────────────────────────────
90
+ # Environment HTTP client (thin wrapper around the HF Space API)
91
+ # ─────────────────────────────────────────────────────────────
92
+
93
class CodeReviewEnvClient:
    """Thin synchronous HTTP wrapper around the environment server API.

    Each call issues one request through a shared ``httpx.Client`` (60 s
    timeout), raises on a non-success HTTP status via ``raise_for_status()``,
    and returns the decoded JSON body.

    The instance can be used as a context manager so the underlying
    connection pool is always released::

        with CodeReviewEnvClient(url) as env:
            env.reset("easy_syntax")
    """

    def __init__(self, base_url: str) -> None:
        # Drop trailing slashes so f"{base_url}/reset" never becomes "//reset".
        self.base_url = base_url.rstrip("/")
        self.client = httpx.Client(timeout=60.0)

    def __enter__(self) -> "CodeReviewEnvClient":
        return self

    def __exit__(self, *exc_info: object) -> None:
        # Always release the HTTP client; exceptions are not suppressed.
        self.close()

    def reset(self, task_id: str) -> Dict:
        """POST ``/reset`` with ``task_id`` as a query parameter; return the JSON body."""
        r = self.client.post(f"{self.base_url}/reset", params={"task_id": task_id})
        r.raise_for_status()
        return r.json()

    def step(self, action_payload: Dict) -> Dict:
        """POST ``/step`` with ``action_payload`` as the JSON body; return the JSON reply."""
        r = self.client.post(f"{self.base_url}/step", json=action_payload)
        r.raise_for_status()
        return r.json()

    def state(self) -> Dict:
        """GET ``/state`` and return the JSON body."""
        r = self.client.get(f"{self.base_url}/state")
        r.raise_for_status()
        return r.json()

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self.client.close()
115
+
116
+
117
+ # ─────────────────────────────────────────────────────────────
118
+ # Agent: LLM-powered code reviewer
119
+ # ─────────────────────────────────────────────────────────────
120
+
121
# System prompt sent with every chat-completion request. It instructs the
# model to answer with a single JSON object in the action shape below.
# The trailing backslashes join the first sentences into one line.
# The dash in the final sentence is a real em dash (it was mis-encoded before).
SYSTEM_PROMPT = """\
You are an expert software engineer specialising in code review, debugging, \
and security auditing. You will be shown a code snippet along with a task \
description. Your job is to:

1. Carefully analyse the code.
2. Identify ALL bugs, logic errors, and security vulnerabilities.
3. Return a structured JSON action in EXACTLY the following format:

{
"identified_issues": [
{
"line_number": <int or null>,
"issue_type": "<syntax_error|logic_bug|security_vulnerability|performance|style>",
"description": "<clear description of the issue>",
"severity": "<low|medium|high|critical>"
}
],
"suggested_fix": "<complete corrected code as a string, or null>",
"explanation": "<brief explanation of all findings>",
"done": true
}

Output ONLY the JSON object — no prose, no markdown fences.
"""
146
+
147
+
148
def build_user_message(obs: Dict, step: int, prev_feedback: Optional[str]) -> str:
    """Render the user-turn prompt for one review step.

    Args:
        obs: Observation mapping; the keys ``task_name``, ``difficulty``,
            ``language``, ``context``, ``code_snippet`` and ``max_steps`` are
            read (a missing key raises ``KeyError``).
        step: Current step number, shown as ``(Step step/max_steps)``.
        prev_feedback: Feedback text from the previous step; when truthy it
            is appended under a "Previous grader feedback:" heading.

    Returns:
        The newline-joined message text.
    """
    message = "\n".join(
        (
            f"Task: {obs['task_name']} ({obs['difficulty']})",
            f"Language: {obs['language']}",
            f"Context: {obs['context']}",
            "",
            "Code to review:",
            "```",
            obs["code_snippet"],
            "```",
            f"(Step {step}/{obs['max_steps']})",
        )
    )
    if not prev_feedback:
        return message
    # A blank line separates the feedback section from the main message.
    return "\n".join((message, "", "Previous grader feedback:", prev_feedback))
163
+
164
+
165
def call_llm(llm_client: OpenAI, user_message: str) -> str:
    """Ask the chat model for a review action and return its raw text reply.

    The request uses ``SYSTEM_PROMPT`` as the system turn, ``user_message`` as
    the user turn, ``MODEL_NAME``, ``MAX_TOKENS`` and a temperature of 0.2.
    An empty or missing reply is replaced by ``"{}"``. The result is stripped
    of surrounding whitespace.

    Any exception raised while calling the API or reading the reply is
    reported on stdout as a ``[DEBUG]`` line and turned into a serialised
    minimal action (no issues, ``done`` set), so this function does not raise
    for API failures.
    """
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_message},
    ]
    try:
        response = llm_client.chat.completions.create(
            model=MODEL_NAME,
            messages=conversation,
            max_tokens=MAX_TOKENS,
            temperature=0.2,
        )
        reply = response.choices[0].message.content
        if not reply:
            reply = "{}"
        return reply.strip()
    except Exception as exc:
        print(f"[DEBUG] LLM call failed: {exc}", flush=True)
        # Best-effort fallback: an empty review that ends the episode.
        fallback_action = {
            "identified_issues": [],
            "suggested_fix": None,
            "explanation": "LLM call failed",
            "done": True,
        }
        return json.dumps(fallback_action)
186
+
187
+
188
def parse_action(raw: str) -> Dict:
    """Parse LLM output into an action dict, tolerating common formatting noise.

    Handled cases:
      * a plain JSON object;
      * a JSON object wrapped in a markdown code fence, with or without a
        ``json`` language tag (any letter case);
      * a JSON object surrounded by extra prose (the span from the first
        ``{`` to the last ``}`` is tried).

    Args:
        raw: Raw text returned by the model.

    Returns:
        The parsed JSON object. If no JSON *object* can be recovered (invalid
        JSON, or valid JSON that is a list/string/number), a minimal fallback
        action is returned whose ``explanation`` holds the first 500
        characters of the cleaned text and whose ``done`` flag is set.
    """
    text = raw.strip()

    if text.startswith("```"):
        # Keep what sits between the opening fence and the next fence.
        text = text.split("```")[1]
        if text[:4].lower() == "json":
            text = text[4:]
        text = text.strip()

    candidates = [text]
    first_brace, last_brace = text.find("{"), text.rfind("}")
    if 0 <= first_brace < last_brace:
        # Second chance: the model may have wrapped the object in prose.
        candidates.append(text[first_brace:last_brace + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Callers use dict methods on the result, so only accept JSON objects.
        if isinstance(parsed, dict):
            return parsed

    return {
        "identified_issues": [],
        "suggested_fix": None,
        "explanation": text[:500],
        "done": True,
    }
205
+
206
+
207
+ # ─────────────────────────────────────────────────────────────
208
+ # Main: run agent on all tasks
209
+ # ─────────────────────────────────────────────────────────────
210
+
211
def run_task(
    task_id: str,
    env_client: CodeReviewEnvClient,
    llm_client: OpenAI,
) -> float:
    """Run one full episode and return the episode score in [0, 1].

    Flow: log START, reset the environment for ``task_id``, then for up to
    ``max_steps`` steps build a prompt, query the LLM, parse the action, send
    it to the environment and log the STEP. The episode score is the best
    single-step reward, clamped to [0, 1]; it counts as a success when it
    reaches ``SUCCESS_SCORE_THRESHOLD``.

    An END record is always logged, including when the reset or a step fails.
    In that case it reports score 0.0 / success False together with the
    rewards collected so far, and the exception propagates to the caller.

    Args:
        task_id: Task identifier passed to the environment's reset call.
        env_client: Client used for ``reset`` and ``step`` calls.
        llm_client: Chat-completion client handed to ``call_llm``.

    Returns:
        The clamped episode score.

    Raises:
        Whatever ``env_client`` raises for failed requests, and ``KeyError``
        if a response lacks the ``observation`` / ``max_steps`` fields.
    """
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    try:
        # Reset runs inside the try so that a failure here still emits END.
        result = env_client.reset(task_id=task_id)
        obs = result["observation"]
        max_steps = obs["max_steps"]

        prev_feedback: Optional[str] = None

        for step in range(1, max_steps + 1):
            user_msg = build_user_message(obs, step, prev_feedback)
            raw_action = call_llm(llm_client, user_msg)
            action_dict = parse_action(raw_action)
            if not isinstance(action_dict, dict):
                # Valid JSON that is not an object cannot be used as an action.
                action_dict = {
                    "identified_issues": [],
                    "suggested_fix": None,
                    "explanation": str(action_dict)[:500],
                    "done": True,
                }

            step_result = env_client.step(action_dict)

            # "or" guards against explicit nulls in the response.
            reward = float(step_result.get("reward") or 0.0)
            done = bool(step_result.get("done", False))
            info = step_result.get("info") or {}
            prev_feedback = info.get("feedback")

            rewards.append(reward)
            steps_taken = step

            log_step(step=step, action=action_dict.get("explanation", ""), reward=reward, done=done)

            if done:
                break

            # Only needed when another step follows.
            obs = step_result["observation"]

        # Score = best single-step reward (agent submits full review each step)
        score = max(rewards) if rewards else 0.0
        score = min(max(score, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD

    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return score
262
+
263
+
264
def _wait_for_server(
    env_client: "CodeReviewEnvClient", attempts: int = 10, delay: float = 3.0
) -> bool:
    """Poll ``/health`` until it answers successfully; return False if it never does."""
    for attempt in range(1, attempts + 1):
        try:
            env_client.client.get(f"{env_client.base_url}/health").raise_for_status()
            return True
        except Exception:
            print(f"[DEBUG] Waiting for server... attempt {attempt}/{attempts}", flush=True)
            time.sleep(delay)
    return False


def main() -> None:
    """Run every task in ``TASKS`` against the environment and print a summary.

    Waits for the server's ``/health`` endpoint (useful right after a docker
    start) and exits with status 1 if it never becomes ready. The environment
    client is closed on every exit path, including that early exit and any
    exception raised while running a task.
    """
    llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env_client = CodeReviewEnvClient(SPACE_URL)

    task_scores: Dict[str, float] = {}
    try:
        if not _wait_for_server(env_client):
            print("[ERROR] Server did not become ready. Exiting.", flush=True)
            sys.exit(1)

        for task_id in TASKS:
            print(f"\n{'='*60}", flush=True)
            print(f"Running task: {task_id}", flush=True)
            print("=" * 60, flush=True)
            task_scores[task_id] = run_task(task_id, env_client, llm_client)
            time.sleep(1)
    finally:
        # Release the HTTP connection pool even on sys.exit() or an error.
        env_client.close()

    # Summary
    print("\n" + "=" * 60, flush=True)
    print("FINAL SCORES", flush=True)
    print("=" * 60, flush=True)
    for task_id, s in task_scores.items():
        status = "✅ PASS" if s >= SUCCESS_SCORE_THRESHOLD else "❌ FAIL"
        print(f" {task_id:25s}: {s:.4f} {status}", flush=True)
    if task_scores:
        # Guarded so an empty task list does not divide by zero.
        overall = sum(task_scores.values()) / len(task_scores)
        print(f"\n Overall average: {overall:.4f}", flush=True)


if __name__ == "__main__":
    main()