databoysu commited on
Commit
8bf2ad3
·
1 Parent(s): 92da498

mfw tasks were going by names not id

Browse files
Files changed (1) hide show
  1. inference.py +184 -168
inference.py CHANGED
@@ -48,6 +48,16 @@ BENCHMARK = os.getenv("BENCHMARK", "tracefix_rl")
48
  MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
49
  SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.98"))
50
 
 
 
 
 
 
 
 
 
 
 
51
  SYSTEM_PROMPT = """\
52
  You are a deterministic debugging policy agent.
53
  You must output exactly one valid CodeAction JSON object per turn and nothing else.
@@ -306,18 +316,6 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> N
306
  )
307
 
308
  env: Optional[TraceFixRLEnv] = None
309
- rewards: list[float] = []
310
- history: list[str] = []
311
- history_messages: list[dict[str, str]] = []
312
- action_trajectory: list[str] = []
313
- steps_taken = 0
314
- score = 0.0
315
- success = False
316
- started = False
317
- kill_switch_triggered = False
318
- last_action_type: Optional[str] = None
319
- consecutive_same_action_count = 0
320
- consecutive_parse_error_count = 0
321
 
322
  try:
323
  if LOCAL_IMAGE_NAME:
@@ -325,179 +323,197 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> N
325
  else:
326
  env = TraceFixRLEnv(base_url=ENV_BASE_URL)
327
 
328
- reset_kwargs = {}
329
- if difficulty:
330
- reset_kwargs["difficulty"] = difficulty
331
- if TASK_NAME and TASK_NAME != "tracefix_rl":
332
- reset_kwargs["task_name"] = TASK_NAME
333
-
334
- result = await env.reset(**reset_kwargs)
335
- task_name = result.observation.info.get("task_name") or TASK_NAME
336
- log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
337
- started = True
338
-
339
- for step in range(1, MAX_STEPS + 1):
340
- if result.done:
341
- break
342
-
343
- action: Optional[CodeAction] = None
344
- parse_error_note: Optional[str] = None
345
- if step == 1:
346
- action = CodeAction(
347
- action_type="VIEW_CODE",
348
- thought="First step policy: inspect source before testing or editing.",
349
- )
350
- if show_thought:
351
- print("[THOUGHT]", file=sys.stderr, flush=True)
352
- print(action.thought, file=sys.stderr, flush=True)
353
- else:
354
- obs_text = _build_observation_text(result.observation)
355
- obs_last_output = str(getattr(result.observation, "last_execution_output", "") or "")
356
- pass_count_text, all_tests_pass_signal = _extract_pass_signal_fields(obs_last_output)
357
- last_action = action_trajectory[-1] if action_trajectory else "none"
358
- dynamic_override = ""
359
- if action_trajectory and action_trajectory[-1] == "REPLACE_LINES":
360
- dynamic_override = (
361
- "\n[SYSTEM OVERRIDE]: Your last action was REPLACE_LINES. "
362
- "You are STRICTLY FORBIDDEN from editing the code again. "
363
- "Your action_type MUST be RUN_TESTS to verify the changes.\n"
364
- )
365
- elif action_trajectory and action_trajectory[-1] == "VIEW_CODE":
366
- dynamic_override = (
367
- "\n[SYSTEM OVERRIDE]: Your last action was VIEW_CODE. "
368
- "You MUST choose RUN_TESTS next to get test evidence.\n"
369
- )
370
- if show_thought:
371
- output_preview = "\\n".join(obs_last_output.splitlines()[:6])
372
- print("[OBS_DEBUG]", file=sys.stderr, flush=True)
373
- print(
374
- f"chars={len(obs_last_output)} pass_count={pass_count_text} all_pass={str(all_tests_pass_signal).lower()} last_action={last_action}",
375
- file=sys.stderr,
376
- flush=True,
377
- )
378
- print(output_preview if output_preview else "<empty last_execution_output>", file=sys.stderr, flush=True)
379
- history_messages.append(
380
- {
381
- "role": "user",
382
- "content": (
383
- "Pick the single best next action and return only one valid CodeAction JSON object. "
384
- "Use localized_context/last_execution_output as evidence, and do not SUBMIT unless all tests explicitly pass. "
385
- "If all_tests_pass_signal=true, you must choose SUBMIT now and must not choose RUN_TESTS again. "
386
- "Do not wait for additional test output when all_tests_pass_signal=true. "
387
- "If last_action was RUN_TESTS and all_tests_pass_signal=false, choose REPLACE_LINES or VIEW_CODE next, not RUN_TESTS again.\n\n"
388
- f"action_trajectory={(' -> '.join(action_trajectory) if action_trajectory else 'none')}\n"
389
- f"{dynamic_override}\n"
390
- f"decision_guard: last_action={last_action}, pass_count_summary={pass_count_text}, all_tests_pass_signal={str(all_tests_pass_signal).lower()}\n\n"
391
- f"{obs_text}"
392
- ),
393
- }
394
- )
395
- try:
396
- action, assistant_json = _get_model_action(client, history_messages)
397
- consecutive_parse_error_count = 0
398
- history_messages.append({"role": "assistant", "content": assistant_json})
399
- if show_thought:
400
- _print_thought(action, assistant_json)
401
- except ModelParseError as exc:
402
- cause = str(exc).replace("\n", " ")
403
- parse_error_note = cause
404
- consecutive_parse_error_count += 1
405
- raw_response = (exc.raw_response or "").strip()
406
- if raw_response:
407
- history_messages.append({"role": "assistant", "content": raw_response})
408
- history_messages.append(
409
- {
410
- "role": "user",
411
- "content": (
412
- f"PARSE_ERROR: {cause}. "
413
- "Return one valid CodeAction object only. "
414
- "Include thought and ensure strict field types."
415
- ),
416
- }
417
- )
418
- history.append(f"PARSE_ERROR: {cause}")
419
- if consecutive_parse_error_count >= 3:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  kill_switch_triggered = True
421
  history.append(
422
- "KILL_SWITCH: PARSE_ERROR occurred 3 times consecutively. "
423
- "Terminating episode early to prevent token burn."
424
  )
425
  steps_taken = step
426
  success = False
427
  score = 0.0
428
  break
429
- action = CodeAction(
430
- action_type="RUN_TESTS",
431
- thought=(
432
- "PARSE_ERROR recovery step: run tests so the step is explicit and "
433
- "collect fresh traceback context for the next valid action."
434
- ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  )
 
 
436
 
437
- if kill_switch_triggered:
438
- break
439
 
440
- current_action_type = action.action_type
441
- if current_action_type == last_action_type:
442
- consecutive_same_action_count += 1
443
- else:
444
- consecutive_same_action_count = 1
445
- last_action_type = current_action_type
446
-
447
- if consecutive_same_action_count >= 3:
448
- kill_switch_triggered = True
449
- history.append(
450
- f"KILL_SWITCH: {current_action_type} selected 3 times consecutively. "
451
- "Terminating episode early to prevent looping."
452
- )
453
- steps_taken = step
454
- success = False
455
  score = 0.0
456
- break
457
-
458
- result = await env.step(action)
459
-
460
- reward = float(result.reward or 0.0)
461
- done = bool(result.done)
462
- action_str = action.action_type
463
-
464
- obs_meta = result.observation.metadata or {}
465
- error = obs_meta.get("last_action_error")
466
- if error is not None:
467
- error = str(error).replace("\n", " ")
468
- if parse_error_note:
469
- error = f"PARSE_ERROR: {parse_error_note}"
470
-
471
- rewards.append(reward)
472
- steps_taken = step
473
- action_thought = (action.thought or "").strip()
474
- history.append(
475
- f"Action {action_str}; reward {reward:.2f}; error {error or 'null'}."
476
- + (f" Thought: {action_thought}" if action_thought else "")
477
- )
478
- action_trajectory.append(action_str)
479
- log_step(step=step, action=action_str, reward=reward, done=done, error=error)
480
-
481
- if done:
482
- break
483
-
484
- if not kill_switch_triggered:
485
- score = _compute_score(result, rewards)
486
- success = score >= SUCCESS_SCORE_THRESHOLD
487
-
488
- except Exception as exc:
489
- if not started:
490
- log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
491
- started = True
492
- score = 0.0
493
- success = False
494
  finally:
495
  if env is not None:
496
  try:
497
  await env.close()
498
  except Exception:
499
  pass
500
- log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
501
 
502
 
503
  if __name__ == "__main__":
 
48
  MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
49
  SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.98"))
50
 
51
+ # Must match openenv.yaml task ids exactly.
52
+ TASKS = ["task1_easy", "task2_medium", "task3_hard"]
53
+
54
+ # The server reset API currently resolves tasks by internal task `name`.
55
+ TASK_ID_TO_RESET_NAME = {
56
+ "task1_easy": "valid_parentheses_wrong_mapping",
57
+ "task2_medium": "binary_search_off_by_one",
58
+ "task3_hard": "reverse_string_returns_original",
59
+ }
60
+
61
  SYSTEM_PROMPT = """\
62
  You are a deterministic debugging policy agent.
63
  You must output exactly one valid CodeAction JSON object per turn and nothing else.
 
316
  )
317
 
318
  env: Optional[TraceFixRLEnv] = None
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
  try:
321
  if LOCAL_IMAGE_NAME:
 
323
  else:
324
  env = TraceFixRLEnv(base_url=ENV_BASE_URL)
325
 
326
+ for task_id in TASKS:
327
+ rewards: list[float] = []
328
+ history: list[str] = []
329
+ history_messages: list[dict[str, str]] = []
330
+ action_trajectory: list[str] = []
331
+ steps_taken = 0
332
+ score = 0.0
333
+ success = False
334
+ kill_switch_triggered = False
335
+ last_action_type: Optional[str] = None
336
+ consecutive_same_action_count = 0
337
+ consecutive_parse_error_count = 0
338
+ task_started = False
339
+
340
+ try:
341
+ reset_kwargs: dict[str, Any] = {}
342
+ if difficulty:
343
+ reset_kwargs["difficulty"] = difficulty
344
+ reset_kwargs["task_name"] = TASK_ID_TO_RESET_NAME.get(task_id, task_id)
345
+
346
+ result = await env.reset(**reset_kwargs)
347
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
348
+ task_started = True
349
+
350
+ for step in range(1, MAX_STEPS + 1):
351
+ if result.done:
352
+ break
353
+
354
+ action: Optional[CodeAction] = None
355
+ parse_error_note: Optional[str] = None
356
+ if step == 1:
357
+ action = CodeAction(
358
+ action_type="VIEW_CODE",
359
+ thought="First step policy: inspect source before testing or editing.",
360
+ )
361
+ if show_thought:
362
+ print("[THOUGHT]", file=sys.stderr, flush=True)
363
+ print(action.thought, file=sys.stderr, flush=True)
364
+ else:
365
+ obs_text = _build_observation_text(result.observation)
366
+ obs_last_output = str(getattr(result.observation, "last_execution_output", "") or "")
367
+ pass_count_text, all_tests_pass_signal = _extract_pass_signal_fields(obs_last_output)
368
+ last_action = action_trajectory[-1] if action_trajectory else "none"
369
+ dynamic_override = ""
370
+ if action_trajectory and action_trajectory[-1] == "REPLACE_LINES":
371
+ dynamic_override = (
372
+ "\n[SYSTEM OVERRIDE]: Your last action was REPLACE_LINES. "
373
+ "You are STRICTLY FORBIDDEN from editing the code again. "
374
+ "Your action_type MUST be RUN_TESTS to verify the changes.\n"
375
+ )
376
+ elif action_trajectory and action_trajectory[-1] == "VIEW_CODE":
377
+ dynamic_override = (
378
+ "\n[SYSTEM OVERRIDE]: Your last action was VIEW_CODE. "
379
+ "You MUST choose RUN_TESTS next to get test evidence.\n"
380
+ )
381
+ if show_thought:
382
+ output_preview = "\\n".join(obs_last_output.splitlines()[:6])
383
+ print("[OBS_DEBUG]", file=sys.stderr, flush=True)
384
+ print(
385
+ f"chars={len(obs_last_output)} pass_count={pass_count_text} all_pass={str(all_tests_pass_signal).lower()} last_action={last_action}",
386
+ file=sys.stderr,
387
+ flush=True,
388
+ )
389
+ print(output_preview if output_preview else "<empty last_execution_output>", file=sys.stderr, flush=True)
390
+ history_messages.append(
391
+ {
392
+ "role": "user",
393
+ "content": (
394
+ "Pick the single best next action and return only one valid CodeAction JSON object. "
395
+ "Use localized_context/last_execution_output as evidence, and do not SUBMIT unless all tests explicitly pass. "
396
+ "If all_tests_pass_signal=true, you must choose SUBMIT now and must not choose RUN_TESTS again. "
397
+ "Do not wait for additional test output when all_tests_pass_signal=true. "
398
+ "If last_action was RUN_TESTS and all_tests_pass_signal=false, choose REPLACE_LINES or VIEW_CODE next, not RUN_TESTS again.\n\n"
399
+ f"action_trajectory={(' -> '.join(action_trajectory) if action_trajectory else 'none')}\n"
400
+ f"{dynamic_override}\n"
401
+ f"decision_guard: last_action={last_action}, pass_count_summary={pass_count_text}, all_tests_pass_signal={str(all_tests_pass_signal).lower()}\n\n"
402
+ f"{obs_text}"
403
+ ),
404
+ }
405
+ )
406
+ try:
407
+ action, assistant_json = _get_model_action(client, history_messages)
408
+ consecutive_parse_error_count = 0
409
+ history_messages.append({"role": "assistant", "content": assistant_json})
410
+ if show_thought:
411
+ _print_thought(action, assistant_json)
412
+ except ModelParseError as exc:
413
+ cause = str(exc).replace("\n", " ")
414
+ parse_error_note = cause
415
+ consecutive_parse_error_count += 1
416
+ raw_response = (exc.raw_response or "").strip()
417
+ if raw_response:
418
+ history_messages.append({"role": "assistant", "content": raw_response})
419
+ history_messages.append(
420
+ {
421
+ "role": "user",
422
+ "content": (
423
+ f"PARSE_ERROR: {cause}. "
424
+ "Return one valid CodeAction object only. "
425
+ "Include thought and ensure strict field types."
426
+ ),
427
+ }
428
+ )
429
+ history.append(f"PARSE_ERROR: {cause}")
430
+ if consecutive_parse_error_count >= 3:
431
+ kill_switch_triggered = True
432
+ history.append(
433
+ "KILL_SWITCH: PARSE_ERROR occurred 3 times consecutively. "
434
+ "Terminating episode early to prevent token burn."
435
+ )
436
+ steps_taken = step
437
+ success = False
438
+ score = 0.0
439
+ break
440
+ action = CodeAction(
441
+ action_type="RUN_TESTS",
442
+ thought=(
443
+ "PARSE_ERROR recovery step: run tests so the step is explicit and "
444
+ "collect fresh traceback context for the next valid action."
445
+ ),
446
+ )
447
+
448
+ if kill_switch_triggered:
449
+ break
450
+
451
+ current_action_type = action.action_type
452
+ if current_action_type == last_action_type:
453
+ consecutive_same_action_count += 1
454
+ else:
455
+ consecutive_same_action_count = 1
456
+ last_action_type = current_action_type
457
+
458
+ if consecutive_same_action_count >= 3:
459
  kill_switch_triggered = True
460
  history.append(
461
+ f"KILL_SWITCH: {current_action_type} selected 3 times consecutively. "
462
+ "Terminating episode early to prevent looping."
463
  )
464
  steps_taken = step
465
  success = False
466
  score = 0.0
467
  break
468
+
469
+ result = await env.step(action)
470
+
471
+ reward = float(result.reward or 0.0)
472
+ done = bool(result.done)
473
+ action_str = action.action_type
474
+
475
+ obs_meta = result.observation.metadata or {}
476
+ error = obs_meta.get("last_action_error")
477
+ if error is not None:
478
+ error = str(error).replace("\n", " ")
479
+ if parse_error_note:
480
+ error = f"PARSE_ERROR: {parse_error_note}"
481
+
482
+ rewards.append(reward)
483
+ steps_taken = step
484
+ action_thought = (action.thought or "").strip()
485
+ history.append(
486
+ f"Action {action_str}; reward {reward:.2f}; error {error or 'null'}."
487
+ + (f" Thought: {action_thought}" if action_thought else "")
488
  )
489
+ action_trajectory.append(action_str)
490
+ log_step(step=step, action=action_str, reward=reward, done=done, error=error)
491
 
492
+ if done:
493
+ break
494
 
495
+ if not kill_switch_triggered:
496
+ score = _compute_score(result, rewards)
497
+ success = score >= SUCCESS_SCORE_THRESHOLD
498
+
499
+ except Exception:
500
+ if not task_started:
501
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
502
+ task_started = True
 
 
 
 
 
 
 
503
  score = 0.0
504
+ success = False
505
+ finally:
506
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
507
+
508
+ except Exception:
509
+ # Preserve existing behavior: unexpected top-level failures should not crash silently.
510
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
511
  finally:
512
  if env is not None:
513
  try:
514
  await env.close()
515
  except Exception:
516
  pass
 
517
 
518
 
519
  if __name__ == "__main__":