Naman Gupta committed on
Commit
55c0431
·
1 Parent(s): ff8a596

Fix inference grade call when episode not done; update baseline scores from real run

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. inference.py +26 -0
README.md CHANGED
@@ -248,7 +248,7 @@ Scores produced by running `inference.py` with `llama-3.1-8b-instant` against th
248
  |--------|-------|--------------|-------|
249
  | Easy | 0.55 | D | 5 |
250
  | Medium | 0.63 | C | 8 |
251
- | Hard | 0.51 | D | 10 |
252
 
253
  Run baseline yourself:
254
  ```bash
 
248
  |--------|-------|--------------|-------|
249
  | Easy | 0.55 | D | 5 |
250
  | Medium | 0.63 | C | 8 |
251
+ | Hard | 0.63 | C | 10 |
252
 
253
  Run baseline yourself:
254
  ```bash
inference.py CHANGED
@@ -128,6 +128,32 @@ async def run_task(task: str) -> None:
128
  if done:
129
  break
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  # Grade the episode
132
  try:
133
  grade_resp = await http.post("/grade")
 
128
  if done:
129
  break
130
 
131
+ # If episode didn't end naturally, exhaust remaining turns to close it
132
+ state_resp = await http.get("/state")
133
+ if state_resp.json().get("is_active", False):
134
+ while True:
135
+ step += 1
136
+ action = generate_attack(defender_resp, step, prev_success)
137
+ action_str = f"strategy={action['strategy_type']} category={action['target_category']} intensity={action['intensity']}"
138
+ try:
139
+ step_resp = await http.post("/step", json=action)
140
+ step_data = step_resp.json()
141
+ obs = step_data["observation"]
142
+ reward = round(float(step_data.get("reward", 0.0)), 2)
143
+ done = bool(obs["episode_done"])
144
+ defender_resp = obs["defender_response"]
145
+ prev_success = obs["attack_success_estimate"]
146
+ rewards.append(reward)
147
+ print(
148
+ f"[STEP] step={step} action={action_str!r} "
149
+ f"reward={reward:.2f} done={str(done).lower()} error=null",
150
+ flush=True,
151
+ )
152
+ except Exception:
153
+ done = True
154
+ if done:
155
+ break
156
+
157
  # Grade the episode
158
  try:
159
  grade_resp = await http.post("/grade")