Spaces:
Sleeping
Sleeping
databoysu commited on
Commit ·
d1abaef
1
Parent(s): e5add8c
Fixing state leak bug
Browse files- backend/tracefix_rl_environment.py +13 -2
- inference.py +6 -4
- vision_ui.py +10 -0
backend/tracefix_rl_environment.py
CHANGED
|
@@ -20,7 +20,7 @@ class TraceFixRLEnvironment(Environment):
|
|
| 20 |
self._gym = TraceFixRLGym()
|
| 21 |
self._state = State(episode_id="", step_count=0)
|
| 22 |
|
| 23 |
-
def reset(self, difficulty: str | None = None) -> CodeObservation:
|
| 24 |
if difficulty == "easy":
|
| 25 |
self._gym.training_step = 1
|
| 26 |
elif difficulty == "medium":
|
|
@@ -28,7 +28,18 @@ class TraceFixRLEnvironment(Environment):
|
|
| 28 |
elif difficulty == "hard":
|
| 29 |
self._gym.training_step = 6000
|
| 30 |
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
self._state = State(
|
| 33 |
episode_id=obs.info.get("episode_id", ""),
|
| 34 |
step_count=obs.step_count,
|
|
|
|
| 20 |
self._gym = TraceFixRLGym()
|
| 21 |
self._state = State(episode_id="", step_count=0)
|
| 22 |
|
| 23 |
+
def reset(self, difficulty: str | None = None, task_name: str | None = None) -> CodeObservation:
|
| 24 |
if difficulty == "easy":
|
| 25 |
self._gym.training_step = 1
|
| 26 |
elif difficulty == "medium":
|
|
|
|
| 28 |
elif difficulty == "hard":
|
| 29 |
self._gym.training_step = 6000
|
| 30 |
|
| 31 |
+
task_dict = None
|
| 32 |
+
if task_name and task_name != "tracefix_rl":
|
| 33 |
+
try:
|
| 34 |
+
from tasks.tasks import ALL_TASKS
|
| 35 |
+
for t in ALL_TASKS:
|
| 36 |
+
if t.get("name") == task_name:
|
| 37 |
+
task_dict = t
|
| 38 |
+
break
|
| 39 |
+
except ImportError:
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
obs, system_prompt = self._gym.reset(task_index=task_dict)
|
| 43 |
self._state = State(
|
| 44 |
episode_id=obs.info.get("episode_id", ""),
|
| 45 |
step_count=obs.step_count,
|
inference.py
CHANGED
|
@@ -325,11 +325,13 @@ async def run(difficulty: Optional[str] = None, show_thought: bool = False) -> N
|
|
| 325 |
else:
|
| 326 |
env = TraceFixRLEnv(base_url=ENV_BASE_URL)
|
| 327 |
|
|
|
|
| 328 |
if difficulty:
|
| 329 |
-
reset_kwargs
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
|
|
|
| 333 |
task_name = result.observation.info.get("task_name") or TASK_NAME
|
| 334 |
log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
|
| 335 |
started = True
|
|
|
|
| 325 |
else:
|
| 326 |
env = TraceFixRLEnv(base_url=ENV_BASE_URL)
|
| 327 |
|
| 328 |
+
reset_kwargs = {}
|
| 329 |
if difficulty:
|
| 330 |
+
reset_kwargs["difficulty"] = difficulty
|
| 331 |
+
if TASK_NAME and TASK_NAME != "tracefix_rl":
|
| 332 |
+
reset_kwargs["task_name"] = TASK_NAME
|
| 333 |
+
|
| 334 |
+
result = await env.reset(**reset_kwargs)
|
| 335 |
task_name = result.observation.info.get("task_name") or TASK_NAME
|
| 336 |
log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
|
| 337 |
started = True
|
vision_ui.py
CHANGED
|
@@ -287,6 +287,16 @@ def _build_env(
|
|
| 287 |
|
| 288 |
|
| 289 |
def sync_tasks(selected, grid_name):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
if grid_name == "easy":
|
| 291 |
easy_val = selected
|
| 292 |
med_val = None
|
|
|
|
| 287 |
|
| 288 |
|
| 289 |
def sync_tasks(selected, grid_name):
|
| 290 |
+
if not selected:
|
| 291 |
+
return (
|
| 292 |
+
gr.skip(),
|
| 293 |
+
gr.skip(),
|
| 294 |
+
gr.skip(),
|
| 295 |
+
gr.skip(),
|
| 296 |
+
gr.skip(),
|
| 297 |
+
gr.skip(),
|
| 298 |
+
gr.skip()
|
| 299 |
+
)
|
| 300 |
if grid_name == "easy":
|
| 301 |
easy_val = selected
|
| 302 |
med_val = None
|