databoysu commited on
Commit
d1abaef
·
1 Parent(s): e5add8c

Fixing state leak bug

Browse files
Files changed (3) hide show
  1. backend/tracefix_rl_environment.py +13 -2
  2. inference.py +6 -4
  3. vision_ui.py +10 -0
backend/tracefix_rl_environment.py CHANGED
@@ -20,7 +20,7 @@ class TraceFixRLEnvironment(Environment):
20
  self._gym = TraceFixRLGym()
21
  self._state = State(episode_id="", step_count=0)
22
 
23
- def reset(self, difficulty: str | None = None) -> CodeObservation:
24
  if difficulty == "easy":
25
  self._gym.training_step = 1
26
  elif difficulty == "medium":
@@ -28,7 +28,18 @@ class TraceFixRLEnvironment(Environment):
28
  elif difficulty == "hard":
29
  self._gym.training_step = 6000
30
 
31
- obs, system_prompt = self._gym.reset()
 
 
 
 
 
 
 
 
 
 
 
32
  self._state = State(
33
  episode_id=obs.info.get("episode_id", ""),
34
  step_count=obs.step_count,
 
20
  self._gym = TraceFixRLGym()
21
  self._state = State(episode_id="", step_count=0)
22
 
23
+ def reset(self, difficulty: str | None = None, task_name: str | None = None) -> CodeObservation:
24
  if difficulty == "easy":
25
  self._gym.training_step = 1
26
  elif difficulty == "medium":
 
28
  elif difficulty == "hard":
29
  self._gym.training_step = 6000
30
 
31
+ task_dict = None
32
+ if task_name and task_name != "tracefix_rl":
33
+ try:
34
+ from tasks.tasks import ALL_TASKS
35
+ for t in ALL_TASKS:
36
+ if t.get("name") == task_name:
37
+ task_dict = t
38
+ break
39
+ except ImportError:
40
+ pass
41
+
42
+ obs, system_prompt = self._gym.reset(task_index=task_dict)
43
  self._state = State(
44
  episode_id=obs.info.get("episode_id", ""),
45
  step_count=obs.step_count,
inference.py CHANGED
@@ -325,11 +325,13 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> N
325
  else:
326
  env = TraceFixRLEnv(base_url=ENV_BASE_URL)
327
 
 
328
  if difficulty:
329
- reset_kwargs = {"difficulty": difficulty}
330
- result = await env.reset(**reset_kwargs)
331
- else:
332
- result = await env.reset()
 
333
  task_name = result.observation.info.get("task_name") or TASK_NAME
334
  log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
335
  started = True
 
325
  else:
326
  env = TraceFixRLEnv(base_url=ENV_BASE_URL)
327
 
328
+ reset_kwargs = {}
329
  if difficulty:
330
+ reset_kwargs["difficulty"] = difficulty
331
+ if TASK_NAME and TASK_NAME != "tracefix_rl":
332
+ reset_kwargs["task_name"] = TASK_NAME
333
+
334
+ result = await env.reset(**reset_kwargs)
335
  task_name = result.observation.info.get("task_name") or TASK_NAME
336
  log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
337
  started = True
vision_ui.py CHANGED
@@ -287,6 +287,16 @@ def _build_env(
287
 
288
 
289
  def sync_tasks(selected, grid_name):
 
 
 
 
 
 
 
 
 
 
290
  if grid_name == "easy":
291
  easy_val = selected
292
  med_val = None
 
287
 
288
 
289
  def sync_tasks(selected, grid_name):
290
+ if not selected:
291
+ return (
292
+ gr.skip(),
293
+ gr.skip(),
294
+ gr.skip(),
295
+ gr.skip(),
296
+ gr.skip(),
297
+ gr.skip(),
298
+ gr.skip()
299
+ )
300
  if grid_name == "easy":
301
  easy_val = selected
302
  med_val = None