Priyansh Saxena committed on
Commit
5b04645
·
1 Parent(s): 72a7241

fix: harden inference runtime and add logging tests

Browse files
Files changed (2) hide show
  1. inference.py +49 -19
  2. tests/test_inference_logging.py +28 -0
inference.py CHANGED
@@ -17,15 +17,22 @@ SUCCESS_SCORE_THRESHOLD = float(os.environ.get("SUCCESS_SCORE_THRESHOLD", "0.7")
17
  MAX_TOTAL_REWARD = float(os.environ.get("MAX_TOTAL_REWARD", "1.0"))
18
 
19
 
 
 
 
 
 
 
20
  def log_start(task, env, model):
21
  print(f"[START] task={task} env={env} model={model}", flush=True)
22
 
23
 
24
  def log_step(step, action, reward, done, error):
25
- err = "null" if error is None else str(error)
 
26
  done_str = "true" if done else "false"
27
  print(
28
- f"[STEP] step={step} action={action} reward={reward:.2f} done={done_str} error={err}",
29
  flush=True,
30
  )
31
 
@@ -65,29 +72,42 @@ History: {history}
65
  return (completion.choices[0].message.content or "").strip()
66
 
67
 
68
- async def main():
69
- client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
70
- tasks = [task.strip() for task in TASKS.split(",") if task.strip()]
 
71
 
72
- for task in tasks:
73
- rewards = []
74
- history = []
75
- steps_taken = 0
76
-
77
- log_start(task=task, env="pytorch-debug-env", model=MODEL_NAME)
78
 
 
79
  async with httpx.AsyncClient(timeout=60.0) as session:
80
  reset_resp = await session.post(f"{ENV_URL}/reset", params={"task_id": task})
81
  reset_resp.raise_for_status()
82
  result = reset_resp.json()
 
83
  session_id = result.get("session_id")
84
- observation = result["observation"]
 
 
 
 
85
 
86
  for step in range(1, MAX_STEPS + 1):
87
  if result.get("done"):
88
  break
89
 
90
- action_text = get_model_message(client, observation, history)
 
 
 
 
 
 
 
 
 
 
 
91
  try:
92
  action_json = json.loads(action_text)
93
  step_resp = await session.post(
@@ -99,12 +119,12 @@ async def main():
99
  result = step_resp.json()
100
  reward = result.get("reward", 0.0)
101
  done = result.get("done", False)
102
- error = None
103
- observation = result["observation"]
104
  except Exception as exc:
105
  reward = 0.0
106
  done = True
107
- error = str(exc)
108
 
109
  rewards.append(reward)
110
  steps_taken = step
@@ -113,10 +133,20 @@ async def main():
113
 
114
  if done:
115
  break
 
 
 
 
 
 
 
116
 
117
- score = min(max(rewards[-1] if rewards else 0.0, 0.0), 1.0)
118
- success = score >= SUCCESS_SCORE_THRESHOLD
119
- log_end(success=success, steps=steps_taken, rewards=rewards)
 
 
 
120
 
121
 
122
  if __name__ == "__main__":
 
17
  MAX_TOTAL_REWARD = float(os.environ.get("MAX_TOTAL_REWARD", "1.0"))
18
 
19
 
20
+ def _sanitize_field(value: object) -> str:
21
+ text = str(value)
22
+ text = text.replace("\n", " ").replace("\r", " ").replace("\t", " ")
23
+ return " ".join(text.split())
24
+
25
+
26
  def log_start(task, env, model):
27
  print(f"[START] task={task} env={env} model={model}", flush=True)
28
 
29
 
30
  def log_step(step, action, reward, done, error):
31
+ safe_action = _sanitize_field(action)
32
+ err = "null" if error is None else _sanitize_field(error)
33
  done_str = "true" if done else "false"
34
  print(
35
+ f"[STEP] step={step} action={safe_action} reward={reward:.2f} done={done_str} error={err}",
36
  flush=True,
37
  )
38
 
 
72
  return (completion.choices[0].message.content or "").strip()
73
 
74
 
75
+ async def _run_task(task: str, client: OpenAI) -> None:
76
+ rewards: List[float] = []
77
+ history: List[str] = []
78
+ steps_taken = 0
79
 
80
+ log_start(task=task, env="pytorch-debug-env", model=MODEL_NAME)
 
 
 
 
 
81
 
82
+ try:
83
  async with httpx.AsyncClient(timeout=60.0) as session:
84
  reset_resp = await session.post(f"{ENV_URL}/reset", params={"task_id": task})
85
  reset_resp.raise_for_status()
86
  result = reset_resp.json()
87
+
88
  session_id = result.get("session_id")
89
+ observation = result.get("observation")
90
+ if not session_id:
91
+ raise RuntimeError("Missing session_id in reset response")
92
+ if observation is None:
93
+ raise RuntimeError("Missing observation in reset response")
94
 
95
  for step in range(1, MAX_STEPS + 1):
96
  if result.get("done"):
97
  break
98
 
99
+ action_text = "null"
100
+ try:
101
+ action_text = get_model_message(client, observation, history)
102
+ except Exception as exc:
103
+ reward = 0.0
104
+ done = True
105
+ error = f"model_error: {exc}"
106
+ rewards.append(reward)
107
+ steps_taken = step
108
+ log_step(step=step, action=action_text, reward=reward, done=done, error=error)
109
+ break
110
+
111
  try:
112
  action_json = json.loads(action_text)
113
  step_resp = await session.post(
 
119
  result = step_resp.json()
120
  reward = result.get("reward", 0.0)
121
  done = result.get("done", False)
122
+ error = result.get("error")
123
+ observation = result.get("observation", observation)
124
  except Exception as exc:
125
  reward = 0.0
126
  done = True
127
+ error = f"step_error: {exc}"
128
 
129
  rewards.append(reward)
130
  steps_taken = step
 
133
 
134
  if done:
135
  break
136
+ except Exception:
137
+ pass
138
+
139
+ score = min(max(rewards[-1] if rewards else 0.0, 0.0), 1.0)
140
+ success = score >= SUCCESS_SCORE_THRESHOLD
141
+ log_end(success=success, steps=steps_taken, rewards=rewards)
142
+
143
 
144
+ async def main():
145
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
146
+ tasks = [task.strip() for task in TASKS.split(",") if task.strip()]
147
+
148
+ for task in tasks:
149
+ await _run_task(task, client)
150
 
151
 
152
  if __name__ == "__main__":
tests/test_inference_logging.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inference import log_end, log_start, log_step
2
+
3
+
4
+ def test_log_start_format(capsys):
5
+ log_start(task="easy", env="pytorch-debug-env", model="test-model")
6
+ out = capsys.readouterr().out.strip()
7
+ assert out == "[START] task=easy env=pytorch-debug-env model=test-model"
8
+
9
+
10
+ def test_log_step_sanitizes_fields(capsys):
11
+ log_step(
12
+ step=1,
13
+ action="line1\nline2",
14
+ reward=0.0,
15
+ done=False,
16
+ error="bad\nerr",
17
+ )
18
+ out = capsys.readouterr().out.strip()
19
+ assert "\n" not in out
20
+ assert "action=line1 line2" in out
21
+ assert "error=bad err" in out
22
+ assert "done=false" in out
23
+
24
+
25
+ def test_log_end_format(capsys):
26
+ log_end(success=True, steps=3, rewards=[0.0, 0.1, 1.0])
27
+ out = capsys.readouterr().out.strip()
28
+ assert out == "[END] success=true steps=3 rewards=0.00,0.10,1.00"