shank committed on
Commit Β·
2a482a5
1
Parent(s): a2ff803
Fix: exception handling in inference.py
Browse files- inference.py +73 -11
inference.py
CHANGED
|
@@ -14,7 +14,8 @@ import os
|
|
| 14 |
import json
|
| 15 |
import time
|
| 16 |
import re
|
| 17 |
-
|
|
|
|
| 18 |
import requests
|
| 19 |
|
| 20 |
# ββ Environment variables (never hardcode these) ββββββββββββββββββββββββββββββ
|
|
@@ -65,6 +66,37 @@ Guidelines:
|
|
| 65 |
- For concurrent tasks, ensure atomic operations and proper synchronization.
|
| 66 |
"""
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
def parse_action(raw: str) -> dict:
|
| 70 |
"""Parse LLM response to action dict. Handle markdown code blocks."""
|
|
@@ -148,15 +180,20 @@ def run_episode(task_id: str) -> dict:
|
|
| 148 |
action = {}
|
| 149 |
|
| 150 |
while not done:
|
| 151 |
-
# Get LLM action
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
# Submit action to environment
|
| 162 |
step_resp = requests.post(f"{ENV_BASE_URL}/step", json=action)
|
|
@@ -193,9 +230,19 @@ def run_episode(task_id: str) -> dict:
|
|
| 193 |
|
| 194 |
def main():
|
| 195 |
print("AgentDebuggerEnv β Baseline Inference")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
print(f"Model: {MODEL_NAME}")
|
| 197 |
print(f"API: {API_BASE_URL}")
|
|
|
|
| 198 |
print(f"Env: {ENV_BASE_URL}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
print("=" * 55)
|
| 200 |
|
| 201 |
results = []
|
|
@@ -204,7 +251,22 @@ def main():
|
|
| 204 |
for task_id in ["easy", "medium", "hard"]:
|
| 205 |
print(f"\nTask: {task_id}")
|
| 206 |
t0 = time.time()
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
elapsed = time.time() - t0
|
| 209 |
|
| 210 |
solved_str = "β SOLVED" if result["solved"] else "β UNSOLVED"
|
|
|
|
| 14 |
import json
|
| 15 |
import time
|
| 16 |
import re
|
| 17 |
+
import random
|
| 18 |
+
from openai import OpenAI, APIError, RateLimitError, APIConnectionError, APITimeoutError
|
| 19 |
import requests
|
| 20 |
|
| 21 |
# ββ Environment variables (never hardcode these) ββββββββββββββββββββββββββββββ
|
|
|
|
| 66 |
- For concurrent tasks, ensure atomic operations and proper synchronization.
|
| 67 |
"""
|
| 68 |
|
| 69 |
+
# ── Robust API Completion Helper ──────────────────────────────────────────────

def get_completion(messages: list, model: str = MODEL_NAME, max_retries: int = 5) -> str:
    """Get an LLM chat completion with exponential backoff and retry logic.

    Args:
        messages: Chat history in OpenAI chat-completions format.
        model: Model identifier; defaults to the module-level MODEL_NAME.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The assistant message content ("" if the SDK returns a None content,
        so the declared str return type always holds).

    Raises:
        RateLimitError / APIConnectionError / APITimeoutError: after the
            final failed retry of a transient error.
        APIError: after the final failed retry of a general API error.
        Exception: immediately on any unexpected error (not retried).
    """
    for attempt in range(max_retries):
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=1200,
                temperature=0.2,
                timeout=60.0  # Add a timeout to prevent hanging forever
            )
            # message.content is Optional in the SDK — coerce None to "".
            return completion.choices[0].message.content or ""
        except (RateLimitError, APIConnectionError, APITimeoutError) as e:
            if attempt == max_retries - 1:
                raise  # bare raise preserves the original traceback
            # Exponential backoff with jitter to avoid synchronized retries.
            wait_time = (2 ** attempt) + random.random()
            print(f" [!] API Error ({type(e).__name__}). Retrying in {wait_time:.1f}s... (Attempt {attempt+1}/{max_retries})")
            time.sleep(wait_time)
        except APIError as e:
            # For general API errors, log and potentially retry if it's a 5xx
            print(f" [!] OpenAI API Error: {e}")
            if attempt == max_retries - 1:
                raise  # bare raise preserves the original traceback
            time.sleep(2)
        except Exception as e:
            # Unexpected failures are not retried — surface them immediately.
            print(f" [!] Unexpected error during completion: {e}")
            raise
    # Defensive default: unreachable in practice, since the final attempt
    # either returns or re-raises.
    return ""
|
| 99 |
+
|
| 100 |
|
| 101 |
def parse_action(raw: str) -> dict:
|
| 102 |
"""Parse LLM response to action dict. Handle markdown code blocks."""
|
|
|
|
| 180 |
action = {}
|
| 181 |
|
| 182 |
while not done:
|
| 183 |
+
# Get LLM action using the robust helper
|
| 184 |
+
try:
|
| 185 |
+
raw = get_completion(messages)
|
| 186 |
+
if not raw:
|
| 187 |
+
raise ValueError("Empty response from LLM")
|
| 188 |
+
action = parse_action(raw)
|
| 189 |
+
except Exception as e:
|
| 190 |
+
print(f" [β] Failed to get response from LLM after retries: {e}")
|
| 191 |
+
# Fallback action to avoid crashing the whole episode
|
| 192 |
+
action = {
|
| 193 |
+
"action_type": "give_up",
|
| 194 |
+
"final_diagnosis": f"Inference system failure: {str(e)}"
|
| 195 |
+
}
|
| 196 |
+
raw = json.dumps(action)
|
| 197 |
|
| 198 |
# Submit action to environment
|
| 199 |
step_resp = requests.post(f"{ENV_BASE_URL}/step", json=action)
|
|
|
|
| 230 |
|
| 231 |
def main():
|
| 232 |
print("AgentDebuggerEnv β Baseline Inference")
|
| 233 |
+
|
| 234 |
+
# ββ Environment validation ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 235 |
+
has_token = bool(HF_TOKEN and len(HF_TOKEN) > 5)
|
| 236 |
+
masked_token = f"{HF_TOKEN[:4]}...{HF_TOKEN[-4:]}" if has_token else "MISSING"
|
| 237 |
+
|
| 238 |
print(f"Model: {MODEL_NAME}")
|
| 239 |
print(f"API: {API_BASE_URL}")
|
| 240 |
+
print(f"Token: {masked_token}")
|
| 241 |
print(f"Env: {ENV_BASE_URL}")
|
| 242 |
+
|
| 243 |
+
if not has_token and "openai.com" in API_BASE_URL:
|
| 244 |
+
print("WARNING: HF_TOKEN is missing but using default OpenAI endpoint. This may fail.")
|
| 245 |
+
|
| 246 |
print("=" * 55)
|
| 247 |
|
| 248 |
results = []
|
|
|
|
| 251 |
for task_id in ["easy", "medium", "hard"]:
|
| 252 |
print(f"\nTask: {task_id}")
|
| 253 |
t0 = time.time()
|
| 254 |
+
try:
|
| 255 |
+
result = run_episode(task_id)
|
| 256 |
+
except Exception as e:
|
| 257 |
+
print(f" [β] Error running episode '{task_id}': {e}")
|
| 258 |
+
result = {
|
| 259 |
+
"task_id": task_id,
|
| 260 |
+
"grader_score": 0.0,
|
| 261 |
+
"cumulative_reward": 0.0,
|
| 262 |
+
"steps_taken": 0,
|
| 263 |
+
"attempts_used": 0,
|
| 264 |
+
"tests_passed": 0,
|
| 265 |
+
"tests_total": 0,
|
| 266 |
+
"solved": False,
|
| 267 |
+
"final_action_type": "error"
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
elapsed = time.time() - t0
|
| 271 |
|
| 272 |
solved_str = "β SOLVED" if result["solved"] else "β UNSOLVED"
|