shank committed on
Commit
2a482a5
·
1 Parent(s): a2ff803

Fix: exception handling in inference.py

Browse files
Files changed (1) hide show
  1. inference.py +73 -11
inference.py CHANGED
@@ -14,7 +14,8 @@ import os
14
  import json
15
  import time
16
  import re
17
- from openai import OpenAI
 
18
  import requests
19
 
20
  # ── Environment variables (never hardcode these) ──────────────────────────────
@@ -65,6 +66,37 @@ Guidelines:
65
  - For concurrent tasks, ensure atomic operations and proper synchronization.
66
  """
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  def parse_action(raw: str) -> dict:
70
  """Parse LLM response to action dict. Handle markdown code blocks."""
@@ -148,15 +180,20 @@ def run_episode(task_id: str) -> dict:
148
  action = {}
149
 
150
  while not done:
151
- # Get LLM action
152
- completion = client.chat.completions.create(
153
- model=MODEL_NAME,
154
- messages=messages,
155
- max_tokens=1200,
156
- temperature=0.2
157
- )
158
- raw = completion.choices[0].message.content
159
- action = parse_action(raw)
 
 
 
 
 
160
 
161
  # Submit action to environment
162
  step_resp = requests.post(f"{ENV_BASE_URL}/step", json=action)
@@ -193,9 +230,19 @@ def run_episode(task_id: str) -> dict:
193
 
194
  def main():
195
  print("AgentDebuggerEnv β€” Baseline Inference")
 
 
 
 
 
196
  print(f"Model: {MODEL_NAME}")
197
  print(f"API: {API_BASE_URL}")
 
198
  print(f"Env: {ENV_BASE_URL}")
 
 
 
 
199
  print("=" * 55)
200
 
201
  results = []
@@ -204,7 +251,22 @@ def main():
204
  for task_id in ["easy", "medium", "hard"]:
205
  print(f"\nTask: {task_id}")
206
  t0 = time.time()
207
- result = run_episode(task_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  elapsed = time.time() - t0
209
 
210
  solved_str = "βœ“ SOLVED" if result["solved"] else "βœ— UNSOLVED"
 
14
  import json
15
  import time
16
  import re
17
+ import random
18
+ from openai import OpenAI, APIError, RateLimitError, APIConnectionError, APITimeoutError
19
  import requests
20
 
21
  # ── Environment variables (never hardcode these) ──────────────────────────────
 
66
  - For concurrent tasks, ensure atomic operations and proper synchronization.
67
  """
68
 
69
# ── Robust API Completion Helper ──────────────────────────────────────────────

def get_completion(messages: list, model: str = MODEL_NAME, max_retries: int = 5) -> str:
    """Get an LLM chat completion with exponential backoff and retry logic.

    Args:
        messages: Chat messages in OpenAI chat-completions format.
        model: Model to query; defaults to the module-level MODEL_NAME.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The content string of the first completion choice.

    Raises:
        RateLimitError, APIConnectionError, APITimeoutError: re-raised after
            the final failed attempt.
        APIError: re-raised after the final failed attempt.
        Exception: any unexpected error is re-raised immediately (no retry).
    """
    for attempt in range(max_retries):
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=1200,
                temperature=0.2,
                timeout=60.0  # Add a timeout to prevent hanging forever
            )
            return completion.choices[0].message.content
        except (RateLimitError, APIConnectionError, APITimeoutError) as e:
            # Transient errors: retry with exponential backoff plus jitter
            # so concurrent clients don't retry in lockstep.
            if attempt == max_retries - 1:
                raise  # bare raise preserves the original traceback
            wait_time = (2 ** attempt) + random.random()
            print(f" [!] API Error ({type(e).__name__}). Retrying in {wait_time:.1f}s... (Attempt {attempt+1}/{max_retries})")
            time.sleep(wait_time)
        except APIError as e:
            # General API errors: log and retry unconditionally with a short
            # fixed delay (no status-code inspection is performed here).
            print(f" [!] OpenAI API Error: {e}")
            if attempt == max_retries - 1:
                raise  # bare raise preserves the original traceback
            time.sleep(2)
        except Exception as e:
            # Anything else is unexpected — surface it immediately.
            print(f" [!] Unexpected error during completion: {e}")
            raise
    # Unreachable: the final loop iteration either returns or raises.
100
 
101
  def parse_action(raw: str) -> dict:
102
  """Parse LLM response to action dict. Handle markdown code blocks."""
 
180
  action = {}
181
 
182
  while not done:
183
+ # Get LLM action using the robust helper
184
+ try:
185
+ raw = get_completion(messages)
186
+ if not raw:
187
+ raise ValueError("Empty response from LLM")
188
+ action = parse_action(raw)
189
+ except Exception as e:
190
+ print(f" [βœ—] Failed to get response from LLM after retries: {e}")
191
+ # Fallback action to avoid crashing the whole episode
192
+ action = {
193
+ "action_type": "give_up",
194
+ "final_diagnosis": f"Inference system failure: {str(e)}"
195
+ }
196
+ raw = json.dumps(action)
197
 
198
  # Submit action to environment
199
  step_resp = requests.post(f"{ENV_BASE_URL}/step", json=action)
 
230
 
231
  def main():
232
  print("AgentDebuggerEnv β€” Baseline Inference")
233
+
234
+ # ── Environment validation ────────────────────────────────────────────────
235
+ has_token = bool(HF_TOKEN and len(HF_TOKEN) > 5)
236
+ masked_token = f"{HF_TOKEN[:4]}...{HF_TOKEN[-4:]}" if has_token else "MISSING"
237
+
238
  print(f"Model: {MODEL_NAME}")
239
  print(f"API: {API_BASE_URL}")
240
+ print(f"Token: {masked_token}")
241
  print(f"Env: {ENV_BASE_URL}")
242
+
243
+ if not has_token and "openai.com" in API_BASE_URL:
244
+ print("WARNING: HF_TOKEN is missing but using default OpenAI endpoint. This may fail.")
245
+
246
  print("=" * 55)
247
 
248
  results = []
 
251
  for task_id in ["easy", "medium", "hard"]:
252
  print(f"\nTask: {task_id}")
253
  t0 = time.time()
254
+ try:
255
+ result = run_episode(task_id)
256
+ except Exception as e:
257
+ print(f" [βœ—] Error running episode '{task_id}': {e}")
258
+ result = {
259
+ "task_id": task_id,
260
+ "grader_score": 0.0,
261
+ "cumulative_reward": 0.0,
262
+ "steps_taken": 0,
263
+ "attempts_used": 0,
264
+ "tests_passed": 0,
265
+ "tests_total": 0,
266
+ "solved": False,
267
+ "final_action_type": "error"
268
+ }
269
+
270
  elapsed = time.time() - t0
271
 
272
  solved_str = "βœ“ SOLVED" if result["solved"] else "βœ— UNSOLVED"