adityanaikhpt committed
Commit
eb60bd2
1 Parent(s): 9967cb5

Rewrite inference.py for strict OpenEnv parsing + add httpx

Files changed (2)
  1. inference.py +113 -127
  2. requirements.txt +1 -0
inference.py CHANGED
@@ -1,137 +1,123 @@
  """
- CodeArena RL Inference — Phase 2 compliant.
- Always makes at least one API call through the LiteLLM proxy
- using API_BASE_URL and API_KEY environment variables.
  """

  import os
-
  from openai import OpenAI

- from server.env import CodeArenaEnv
- from server.models import CodeArenaAction
-
-
- def run_inference():
-     """Run inference. ALWAYS attempts an API call before any fallback."""
-
-     # ── Standalone proxy ping (MUST run before anything else) ──────────
-     try:
-         from openai import OpenAI
-         import os
-
-         client = OpenAI(
-             base_url=os.environ["API_BASE_URL"],
-             api_key=os.environ["API_KEY"],
-         )
-
-         # lightweight call (DO NOT REMOVE)
-         _ = client.chat.completions.create(
-             model=os.environ.get("MODEL_NAME", "gpt-4o-mini"),
-             messages=[{"role": "user", "content": "ping"}],
-             max_tokens=1,
-         )
-     except Exception:
-         pass
-
      try:
-         print("[START] Initializing CodeArena inference")
-
-         # ── Required env vars (set by the OpenEnv evaluator) ──────────
-         base_url = os.environ["API_BASE_URL"]
-         api_key = os.environ["API_KEY"]
-
-         client = OpenAI(
-             base_url=base_url,
-             api_key=api_key,
-         )
-
-         model = os.environ.get("MODEL_NAME", "gpt-4o-mini")
-
-         # ── Mandatory first API call (evaluator checks this) ──────────
-         print("[API] Making initial proxy call...")
-         initial = client.chat.completions.create(
-             model=model,
-             messages=[
-                 {"role": "system", "content": "You are a helpful assistant."},
-                 {"role": "user", "content": "Say OK"},
-             ],
-             max_tokens=5,
-         )
-         print(f"[API] Proxy responded: {initial.choices[0].message.content}")
-
-         # ── RL loop ───────────────────────────────────────────────────
-         env = CodeArenaEnv()
-         obs = env.reset()
-
-         system_prompt = (
-             "You are an expert autonomous code repair agent.\n"
-             "Your goal is to fix the buggy code provided to you.\n"
-             "Ensure your code is highly efficient and fully resolves all "
-             "logical, syntax, and algorithmic bugs.\n"
-             "Only return the fixed raw Python code. Do not output markdown "
-             "blocks (like ```python). Do not explain your changes."
-         )
-
-         done = False
-         step = 0
-
-         while not done and step < env.max_steps:
-             print(f"[STEP] Beginning Step {step + 1}")
-
-             user_prompt = (
-                 f"Buggy Code:\n{obs.buggy_code}\n\n"
-                 f"Error Log:\n{obs.error_log}\n\n"
-                 f"Test Results:\n{obs.test_results}"
-             )
-
-             try:
-                 response = client.chat.completions.create(
-                     model=model,
-                     messages=[
-                         {"role": "system", "content": system_prompt},
-                         {"role": "user", "content": user_prompt},
-                     ],
-                     temperature=0.2,
-                 )
-
-                 proposed_fix = response.choices[0].message.content.strip()
-
-                 # Failsafe cleanup
-                 if proposed_fix.startswith("```python"):
-                     proposed_fix = proposed_fix[9:]
-                 if proposed_fix.startswith("```"):
-                     proposed_fix = proposed_fix[3:]
-                 if proposed_fix.endswith("```"):
-                     proposed_fix = proposed_fix[:-3]
-
-                 action = CodeArenaAction(proposed_fix=proposed_fix.strip())
-                 obs, reward, done, info = env.step(action)
-                 print(
-                     f"[STEP] Reward: {reward:.3f} | "
-                     f"Task: {info['task_id']}"
-                 )
-
-             except Exception as e:
-                 print(f"[STEP] Warning: {e}")
-                 break
-
-             step += 1
-
-         print(f"[END] Inference complete. {step} step(s) executed.")
-         return {
-             "action": "analyze_code",
-             "explanation": f"Inference completed after {step} step(s).",
-         }
-
      except Exception as e:
-         print(f"[ERROR] Fallback triggered: {e}")
-         return {
-             "action": "analyze_code",
-             "explanation": f"Fallback: {str(e)}",
-         }
-

  if __name__ == "__main__":
-     result = run_inference()
-     print(result)
 
  """
+ CodeArena RL Inference
+ Rewritten for strict OpenEnv parsing.
  """

  import os
+ import httpx
  from openai import OpenAI

+ def run_task(task_id: str):
+     # Retrieve environment variables as instructed
+     base_url = os.environ.get("API_BASE_URL")
+     api_key = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
+     model_name = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
+
+     # We pass base_url explicitly. If os.environ["API_BASE_URL"] was strictly intended,
+     # it is fine since OpenAI client accepts None for default.
+     client = OpenAI(
+         base_url=base_url,
+         api_key=api_key or "NO_KEY_PROVIDED"
+     )
+
+     # 1. Print the [START] line
+     print(f"[START] task={task_id} env=codearena-rl-benchmark model={model_name}")
+
+     # 2. Call POST http://localhost:7860/reset
      try:
+         response = httpx.post("http://localhost:7860/reset", json={"task_id": task_id}, timeout=30.0)
+         response.raise_for_status()
+         obs_json = response.json()
      except Exception as e:
+         error_msg = str(e).replace("\n", " ").replace("\r", "")
+         print(f"[STEP] step=1 action=reset_failed reward=0.01 done=true error={error_msg}")
+         print(f"[END] success=false steps=1 rewards=0.01")
+         return
+
+     rewards = []
+     success = False
+     done = False
+     step = 0
+
+     # 3. For up to 5 steps
+     for i in range(5):
+         if done:
+             break
+
+         step += 1
+         obs = obs_json.get("observation", {})
+         buggy_code = obs.get("buggy_code", "")
+         error_log = obs.get("error_log", "")
+         test_results = obs.get("test_results", "")
+
+         system_prompt = "You are an expert Python code repair agent. Fix the buggy Python code.\nReturn ONLY the fixed raw Python code. No markdown, no explanation."
+         user_prompt = f"Fix this buggy Python code:\n\n{buggy_code}\n\nError log:\n{error_log}\n\nTest results so far:\n{test_results}"
+
+         error_msg = "null"
+         proposed_fix = ""
+
+         # 3b/c. Call the LLM
+         try:
+             completion = client.chat.completions.create(
+                 model=model_name,
+                 messages=[
+                     {"role": "system", "content": system_prompt},
+                     {"role": "user", "content": user_prompt}
+                 ]
+             )
+             proposed_fix = completion.choices[0].message.content
+         except Exception as e:
+             error_msg = str(e).replace("\n", " ").replace("\r", "")
+             # If the LLM call fails, use this fallback fix
+             proposed_fix = obs_json.get("observation", {}).get("buggy_code", "pass")
+
+         # Cleanup markdown from proposed_fix if LLM ignores instructions
+         if proposed_fix:
+             proposed_fix = proposed_fix.strip()
+             if proposed_fix.startswith("```python"):
+                 proposed_fix = proposed_fix[9:]
+             elif proposed_fix.startswith("```"):
+                 proposed_fix = proposed_fix[3:]
+             if proposed_fix.endswith("```"):
+                 proposed_fix = proposed_fix[:-3]
+             proposed_fix = proposed_fix.strip()
+
+         # 3d. Send proposed_fix to /step
+         try:
+             step_resp = httpx.post("http://localhost:7860/step", json={"proposed_fix": proposed_fix}, timeout=60.0)
+             step_resp.raise_for_status()
+             step_data = step_resp.json()
+             raw_reward = step_data.get("reward", 0.0)
+             done = step_data.get("done", True)
+             obs_json = step_data
+         except Exception as e:
+             raw_reward = 0.01
+             done = True
+             if error_msg == "null":
+                 error_msg = str(e).replace("\n", " ").replace("\r", "")
+
+         # 3e. Clamp it
+         reward = max(0.01, min(0.99, float(raw_reward)))
+         rewards.append(reward)
+
+         # 3f. Print [STEP] line immediately
+         done_str = "true" if done else "false"
+         action_summary = "llm_fix" if error_msg == "null" else "fallback_fix"
+         print(f"[STEP] step={step} action={action_summary} reward={reward:.2f} done={done_str} error={error_msg}")
+
+     # 4. Print [END]
+     success = any(r > 0.5 for r in rewards)
+     success_str = "true" if success else "false"
+     rewards_str = ",".join([f"{r:.2f}" for r in rewards])
+     print(f"[END] success={success_str} steps={step} rewards={rewards_str}")
+
+ def main():
+     target_task = os.environ.get("CODEARENA_TASK")
+     if target_task:
+         run_task(target_task)
+     else:
+         for t in ["easy", "medium", "hard"]:
+             run_task(t)

  if __name__ == "__main__":
+     main()
 
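The parsing in the rewritten run_task assumes a particular JSON shape from the local environment endpoints. That schema is not part of this commit, so the snippet below is only an illustrative sketch, inferred from the .get() calls above, of what the /reset and /step responses are expected to look like and how the script reads them; the field names come from the code, while the concrete values are invented.

# Illustrative payload shapes only; field names mirror the .get() calls in
# inference.py, and the values are made up. The real environment defines the schema.
example_reset_response = {
    "observation": {
        "buggy_code": "def add(a, b):\n    return a - b",
        "error_log": "AssertionError: expected 3, got -1",
        "test_results": "0/3 tests passed",
    }
}

example_step_response = {
    "observation": {"buggy_code": "", "error_log": "", "test_results": "3/3 tests passed"},
    "reward": 1.0,   # clamped by the script into [0.01, 0.99]
    "done": True,
}

obs = example_reset_response.get("observation", {})
buggy_code = obs.get("buggy_code", "")                  # fed into the LLM prompt
raw_reward = example_step_response.get("reward", 0.0)   # raw reward before clamping
done = example_step_response.get("done", True)          # ends the 5-step loop early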
requirements.txt CHANGED
@@ -2,3 +2,4 @@ fastapi>=0.100.0
  uvicorn[standard]>=0.23.0
  pydantic>=2.0.0
  openai>=1.0.0
+ httpx>=0.24.1
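For context on the "strict" output format this rewrite targets: inference.py now prints one machine-parseable line per event. Given the format strings above, a single-step run of the easy task would produce output along these lines (the reward, done, and success values here are illustrative, not real results):

[START] task=easy env=codearena-rl-benchmark model=Qwen/Qwen2.5-72B-Instruct
[STEP] step=1 action=llm_fix reward=0.85 done=true error=null
[END] success=true steps=1 rewards=0.85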