databoysu committed on
Commit
dcdd52f
·
1 Parent(s): d1abaef

revert to server directory

Browse files
server/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """TraceFix-RL server components."""
2
+
3
+ from .tracefix_rl_environment import TraceFixRLEnvironment
4
+
5
+ __all__ = ["TraceFixRLEnvironment"]
server/app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI entry point for TraceFix-RL."""
2
+
3
+ import gradio as gr
4
+ from vision_ui import demo
5
+
6
+ try:
7
+ from openenv.core.env_server.http_server import create_app
8
+ except Exception as e: # pragma: no cover
9
+ raise ImportError(
10
+ "openenv is required for the web interface. Install dependencies with '\n uv sync\n'"
11
+ ) from e
12
+
13
+ try:
14
+ from core.models import CodeAction, CodeObservation
15
+ from server.tracefix_rl_environment import TraceFixRLEnvironment
16
+ except ImportError:
17
+ import sys
18
+ from pathlib import Path
19
+
20
+ sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
21
+ from core.models import CodeAction, CodeObservation
22
+ from server.tracefix_rl_environment import TraceFixRLEnvironment
23
+
24
+
25
# Build the application object with OpenEnv's ``create_app``, handing it the
# environment class together with its action and observation models.
app = create_app(
    TraceFixRLEnvironment, CodeAction, CodeObservation,
    env_name="tracefix_rl", max_concurrent_envs=1,
)
32
+
33
+ from fastapi.responses import RedirectResponse
34
+
35
@app.get("/", include_in_schema=False)
async def root_redirect():
    """Redirect requests for the site root to ``/web/``."""
    target = "/web/"
    return RedirectResponse(url=target)
38
+
39
@app.get("/web", include_in_schema=False)
async def web_no_slash_redirect():
    """Redirect ``/web`` (no trailing slash) to ``/web/``."""
    target = "/web/"
    return RedirectResponse(url=target)
42
+
43
# Attach the Gradio ``demo`` interface to the application under the /web path.
app = gr.mount_gradio_app(
    app,
    demo,
    path="/web",
)
44
+
45
+
46
def main() -> None:
    """Serve ``app`` with uvicorn for local and container execution.

    The bind address comes from the ``HOST`` environment variable
    (default ``0.0.0.0``) and the port from ``PORT`` (default ``7860``).
    """
    import os

    import uvicorn

    bind_host = os.getenv("HOST", "0.0.0.0")
    bind_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host=bind_host, port=bind_port)


if __name__ == "__main__":
    main()
server/tracefix_rl_environment.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenEnv adapter around the TraceFix-RL core environment."""
2
+
3
+ from openenv.core.env_server.interfaces import Environment
4
+ from openenv.core.env_server.types import State
5
+
6
try:
    from core.environment import TraceFixRLGym
    from core.models import CodeAction, CodeObservation
except ImportError:
    # Retrying the identical imports unchanged can never succeed. Put the
    # parent of this file's directory on sys.path first (the same fallback
    # server/app.py uses) so ``core`` resolves when this module is loaded
    # without that directory on the import path.
    import sys
    from pathlib import Path

    sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
    from core.environment import TraceFixRLGym
    from core.models import CodeAction, CodeObservation
12
+
13
+
14
class TraceFixRLEnvironment(Environment):
    """Environment implementation compatible with OpenEnv's server interface.

    Wraps a ``TraceFixRLGym`` instance and, after every ``reset`` and
    ``step``, mirrors the observation's episode id and step count into an
    OpenEnv ``State`` object exposed through the ``state`` property.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Value written to the gym's ``training_step`` attribute for each
    # recognised ``difficulty`` argument of ``reset``.
    # NOTE(review): presumably these select curriculum stages inside
    # TraceFixRLGym -- confirm against core.environment.
    _TRAINING_STEP_BY_DIFFICULTY = {"easy": 1, "medium": 2000, "hard": 6000}

    def __init__(self):
        # Run the base-class initializer so that any attributes it sets up
        # exist on this instance.
        super().__init__()
        self._gym = TraceFixRLGym()
        self._state = State(episode_id="", step_count=0)

    def reset(self, difficulty: str | None = None, task_name: str | None = None) -> CodeObservation:
        """Reset the wrapped gym and return its first observation.

        Args:
            difficulty: ``"easy"``, ``"medium"`` or ``"hard"`` overwrites the
                gym's ``training_step`` with 1, 2000 or 6000 respectively;
                any other value leaves ``training_step`` untouched.
            task_name: Name of an entry in ``tasks.tasks.ALL_TASKS`` to hand
                to the gym. ``None``, an empty string, ``"tracefix_rl"``, or
                a name that cannot be resolved results in ``None`` being
                passed to the gym instead.

        Returns:
            The gym's observation, with the system prompt returned by the gym
            stored under ``metadata["system_prompt"]``.
        """
        table = self._TRAINING_STEP_BY_DIFFICULTY
        # The isinstance guard keeps unhashable inputs from raising on lookup.
        if isinstance(difficulty, str) and difficulty in table:
            self._gym.training_step = table[difficulty]

        task_dict = self._find_task(task_name)

        # NOTE(review): the gym's keyword is ``task_index`` but it receives
        # the task mapping (or None) -- confirm against TraceFixRLGym.reset.
        obs, system_prompt = self._gym.reset(task_index=task_dict)
        self._sync_state(obs)
        metadata = dict(obs.metadata or {})
        metadata["system_prompt"] = system_prompt
        obs.metadata = metadata
        return obs

    def step(self, action: CodeAction) -> CodeObservation:  # type: ignore[override]
        """Apply ``action`` to the wrapped gym and return the observation.

        The reward and done flag returned by the gym are copied onto the
        observation, and the gym's ``info`` mapping is merged into the
        observation's metadata.
        """
        obs, reward, done, info = self._gym.step(action)
        obs.reward = reward
        obs.done = done
        metadata = dict(obs.metadata or {})
        metadata.update(info)
        obs.metadata = metadata
        self._sync_state(obs)
        return obs

    @property
    def state(self) -> State:
        """Return the ``State`` recorded by the latest ``reset`` or ``step``."""
        return self._state

    @staticmethod
    def _find_task(task_name):
        """Return the first ``ALL_TASKS`` entry named ``task_name``, else None."""
        if not task_name or task_name == "tracefix_rl":
            return None

        import logging

        log = logging.getLogger(__name__)
        try:
            from tasks.tasks import ALL_TASKS
        except ImportError:
            # Best effort: continue without a specific task, but say so
            # instead of silently ignoring the request.
            log.warning(
                "tasks.tasks could not be imported; ignoring task_name=%r", task_name
            )
            return None

        for task in ALL_TASKS:
            if task.get("name") == task_name:
                return task
        log.warning("No task named %r in ALL_TASKS; ignoring task_name", task_name)
        return None

    def _sync_state(self, obs) -> None:
        """Record the observation's episode id and step count as the state."""
        self._state = State(
            episode_id=obs.info.get("episode_id", ""),
            step_count=obs.step_count,
        )