burtenshaw HF Staff commited on
Commit
73532b2
·
verified ·
1 Parent(s): 14fd9fa

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +36 -15
  2. README.md +401 -27
  3. __init__.py +31 -0
  4. client.py +120 -0
  5. envs/atari_env/README.md +408 -0
  6. envs/atari_env/__init__.py +31 -0
  7. envs/atari_env/client.py +120 -0
  8. envs/atari_env/models.py +85 -0
  9. envs/atari_env/server/Dockerfile +43 -0
  10. envs/atari_env/server/__init__.py +15 -0
  11. envs/atari_env/server/app.py +80 -0
  12. envs/atari_env/server/atari_environment.py +246 -0
  13. envs/atari_env/server/requirements.txt +3 -0
  14. envs/atari_env/test_atari_docker.sh +333 -0
  15. models.py +85 -0
  16. pyproject.toml +147 -0
  17. server/Dockerfile +43 -0
  18. server/__init__.py +15 -0
  19. server/app.py +80 -0
  20. server/atari_environment.py +246 -0
  21. server/requirements.txt +3 -0
  22. src/__init__.py +7 -0
  23. src/core/README.md +212 -0
  24. src/core/__init__.py +70 -8
  25. src/core/client_types.py +23 -0
  26. src/core/containers/__init__.py +1 -1
  27. src/core/containers/images/Dockerfile +29 -11
  28. src/core/containers/images/README.md +8 -8
  29. src/core/containers/runtime/__init__.py +12 -2
  30. src/core/containers/runtime/daytona_provider.py +572 -0
  31. src/core/containers/runtime/providers.py +389 -9
  32. src/core/containers/runtime/uv_provider.py +224 -0
  33. src/core/containers/test_local_docker_provider.py +8 -6
  34. src/core/env_client.py +484 -0
  35. src/core/env_server/__init__.py +118 -3
  36. src/core/env_server/base_transforms.py +1 -1
  37. src/core/env_server/exceptions.py +105 -0
  38. src/core/env_server/gradio_theme.py +128 -0
  39. src/core/env_server/gradio_ui.py +240 -0
  40. src/core/env_server/http_server.py +1263 -105
  41. src/core/env_server/interfaces.py +189 -10
  42. src/core/env_server/mcp_environment.py +624 -0
  43. src/core/env_server/mcp_types.py +321 -0
  44. src/core/env_server/route_config.py +57 -0
  45. src/core/env_server/serialization.py +137 -0
  46. src/core/env_server/types.py +361 -31
  47. src/core/env_server/web_interface.py +426 -1395
  48. src/core/evals/__init__.py +18 -0
  49. src/core/evals/base.py +62 -0
  50. src/core/evals/inspect_harness.py +160 -0
Dockerfile CHANGED
@@ -1,20 +1,39 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
 
 
 
 
 
 
3
  #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- # Use the specified openenv-base image
8
- FROM ghcr.io/meta-pytorch/openenv-base:sha-7dd8148
9
- # Install ALE-specific dependencies
10
- RUN pip install --no-cache-dir \
11
- gymnasium>=0.29.0 \
12
- ale-py>=0.8.0 \
13
- numpy>=1.24.0
14
-
15
- # Copy only what's needed for this environment
16
  COPY src/core/ /app/src/core/
17
- COPY src/envs/atari_env/ /app/src/envs/atari_env/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # Health check
20
  HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
@@ -22,4 +41,6 @@ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
22
 
23
  # Run the FastAPI server
24
  CMD ["uvicorn", "envs.atari_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
 
 
25
  ENV ENABLE_WEB_INTERFACE=true
 
1
+ # Dockerfile for Atari Environment
2
+ # This image provides Atari 2600 games via the Arcade Learning Environment (ALE)
3
+
4
+ # Configurable base image - defaults to local build, can be overridden for CI/CD
5
+ # Base image provides: fastapi, uvicorn, requests, curl, PYTHONPATH=/app/src
6
+ #
7
+ # Local build: docker build -t envtorch-base:latest -f src/core/containers/images/Dockerfile .
8
+ # docker build -f envs/atari_env/server/Dockerfile -t atari-env:latest .
9
  #
10
+ # CI/CD build: docker build --build-arg BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest \
11
+ # -f envs/atari_env/server/Dockerfile -t atari-env:latest .
12
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
13
+ FROM ghcr.io/meta-pytorch/openenv-base:latest
14
+
15
+ # Install dependencies
16
+ COPY envs/atari_env/server/requirements.txt /tmp/requirements.txt
17
+ RUN pip install --no-cache-dir -r /tmp/requirements.txt && rm /tmp/requirements.txt
18
+
19
+ # Copy OpenEnv core (base image already set WORKDIR=/app)
 
 
20
  COPY src/core/ /app/src/core/
21
+
22
+ # Copy Atari environment code
23
+ COPY envs/atari_env/ /app/envs/atari_env/
24
+
25
+ # Copy README for web interface documentation
26
+ COPY envs/atari_env/README.md /app/README.md
27
+
28
+ # Atari-specific environment variables (can be overridden at runtime)
29
+ ENV ATARI_GAME=pong
30
+ ENV ATARI_OBS_TYPE=rgb
31
+ ENV ATARI_FULL_ACTION_SPACE=false
32
+ ENV ATARI_REPEAT_ACTION_PROB=0.0
33
+ ENV ATARI_FRAMESKIP=4
34
+
35
+ # Expose port
36
+ EXPOSE 8000
37
 
38
  # Health check
39
  HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
 
41
 
42
  # Run the FastAPI server
43
  CMD ["uvicorn", "envs.atari_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
44
+ ENV PYTHONPATH=/app/src/core:/app/src:${PYTHONPATH}
45
+
46
  ENV ENABLE_WEB_INTERFACE=true
README.md CHANGED
@@ -1,51 +1,425 @@
1
  ---
2
- title: atari_env
3
  emoji: 🕹️
4
- colorFrom: red
5
- colorTo: yellow
6
  sdk: docker
7
  pinned: false
8
  app_port: 8000
9
  base_path: /web
10
  tags:
 
11
  - openenv
12
  ---
13
 
14
- # Atari_env Environment Server
15
 
16
- FastAPI server for atari_env environment powered by Meta's OpenEnv.
17
 
18
- ## About
 
 
19
 
20
- This Space provides a containerized environment for atari_env interactions.
21
- Built with FastAPI and OpenEnv framework.
22
 
23
- ## Web Interface
 
24
 
25
- This deployment includes an interactive web interface for exploring the environment:
26
- - **HumanAgent Interface**: Interact with the environment using a web form
27
- - **State Observer**: Real-time view of environment state and action history
28
- - **Live Updates**: WebSocket-based real-time updates
29
 
30
- Access the web interface at: `/web`
31
 
32
- ## Atari Environment
33
 
34
- Provides Atari 2600 games via the Arcade Learning Environment (ALE).
35
 
36
- ### Usage
37
- Send a POST request to `/step` with:
38
- ```json
39
- {
40
- "action_id": 0,
41
- "game_name": "pong"
42
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  ```
44
 
45
- ## API Documentation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- Visit `/docs` for interactive API documentation.
 
48
 
49
- ## Health Check
 
 
 
 
50
 
51
- The environment provides a health check endpoint at `/health`.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Atari Environment Server
3
  emoji: 🕹️
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: docker
7
  pinned: false
8
  app_port: 8000
9
  base_path: /web
10
  tags:
11
+ - openenv-main
12
  - openenv
13
  ---
14
 
15
+ ## Hugging Face Space Deployment
16
 
17
+ This Space is built from OpenEnv environment `atari_env`.
18
 
19
+ - Space URL: `https://huggingface.co/spaces/openenv/atari_env`
20
+ - OpenEnv pinned ref: `main`
21
+ - Hub tag: `openenv`
22
 
23
+ ### Connecting from Code
 
24
 
25
+ ```python
26
+ from envs.atari_env import AtariEnv
27
 
28
+ env = AtariEnv(base_url="https://huggingface.co/spaces/openenv/atari_env")
29
+ ```
 
 
30
 
31
+ # Atari Environment
32
 
33
+ Integration of Atari 2600 games with the OpenEnv framework via the Arcade Learning Environment (ALE). ALE provides access to 100+ classic Atari games for RL research.
34
 
35
+ ## Supported Games
36
 
37
+ ALE supports 100+ Atari 2600 games including:
38
+
39
+ ### Popular Games
40
+ - **Pong** - Classic two-player tennis
41
+ - **Breakout** - Break bricks with a ball
42
+ - **Space Invaders** - Shoot descending aliens
43
+ - **Pac-Man / Ms. Pac-Man** - Navigate mazes and eat pellets
44
+ - **Asteroids** - Destroy asteroids in space
45
+ - **Defender** - Side-scrolling space shooter
46
+ - **Centipede** - Shoot segmented centipede
47
+ - **Donkey Kong** - Jump over barrels to save princess
48
+ - **Frogger** - Cross road and river safely
49
+ - **Q*bert** - Jump on pyramid cubes
50
+
51
+ And many more! For a complete list, see [ALE documentation](https://ale.farama.org/environments/complete_list/).
52
+
53
+ ## Architecture
54
+
55
+ ```
56
+ ┌────────────────────────────────────┐
57
+ │ RL Training Code (Client) │
58
+ │ AtariEnv.step(action) │
59
+ └──────────────┬─────────────────────┘
60
+ │ HTTP
61
+ ┌──────────────▼─────────────────────┐
62
+ │ FastAPI Server (Docker) │
63
+ │ AtariEnvironment │
64
+ │ ├─ Wraps ALEInterface │
65
+ │ ├─ Handles observations │
66
+ │ └─ Action execution │
67
+ └────────────────────────────────────┘
68
+ ```
69
+
70
+ ## Installation & Usage
71
+
72
+ ### Option 1: Local Development (without Docker)
73
+
74
+ **Requirements:**
75
+ - Python 3.11+
76
+ - ale-py installed: `pip install ale-py`
77
+
78
+ The client is **async by default**:
79
+
80
+ ```python
81
+ import asyncio
82
+ from atari_env import AtariEnv, AtariAction
83
+
84
+ async def main():
85
+ # Start local server manually: python -m atari_env.server.app
86
+ async with AtariEnv(base_url="http://localhost:8000") as env:
87
+ # Reset environment
88
+ result = await env.reset()
89
+ print(f"Screen shape: {result.observation.screen_shape}")
90
+ print(f"Legal actions: {result.observation.legal_actions}")
91
+
92
+ # Take actions
93
+ for _ in range(10):
94
+ result = await env.step(AtariAction(action_id=2, game_name="pong"))
95
+ print(f"Reward: {result.reward}, Done: {result.done}")
96
+ if result.done:
97
+ break
98
+
99
+ asyncio.run(main())
100
+ ```
101
+
102
+ For **synchronous usage**, use the `.sync()` wrapper:
103
+
104
+ ```python
105
+ from atari_env import AtariEnv, AtariAction
106
+
107
+ with AtariEnv(base_url="http://localhost:8000").sync() as env:
108
+ result = env.reset()
109
+ result = env.step(AtariAction(action_id=2, game_name="pong"))
110
+ print(f"Reward: {result.reward}")
111
  ```
112
 
113
+ ### Option 2: Docker (Recommended)
114
+
115
+ **Build Atari image:**
116
+
117
+ ```bash
118
+ cd OpenEnv
119
+
120
+ # Build the image
121
+ docker build \
122
+ -f envs/atari_env/server/Dockerfile \
123
+ -t atari-env:latest \
124
+ .
125
+ ```
126
+
127
+ **Run specific games:**
128
+
129
+ ```bash
130
+ # Pong (default)
131
+ docker run -p 8000:8000 atari-env:latest
132
 
133
+ # Breakout
134
+ docker run -p 8000:8000 -e ATARI_GAME=breakout atari-env:latest
135
 
136
+ # Space Invaders with grayscale observation
137
+ docker run -p 8000:8000 \
138
+ -e ATARI_GAME=space_invaders \
139
+ -e ATARI_OBS_TYPE=grayscale \
140
+ atari-env:latest
141
 
142
+ # Ms. Pac-Man with full action space
143
+ docker run -p 8000:8000 \
144
+ -e ATARI_GAME=ms_pacman \
145
+ -e ATARI_FULL_ACTION_SPACE=true \
146
+ atari-env:latest
147
+ ```
148
+
149
+ **Use with from_docker_image():**
150
+
151
+ ```python
152
+ import asyncio
153
+ import numpy as np
154
+ from atari_env import AtariEnv, AtariAction
155
+
156
+ async def main():
157
+ # Automatically starts container
158
+ client = await AtariEnv.from_docker_image("atari-env:latest")
159
+
160
+ async with client:
161
+ result = await client.reset()
162
+ result = await client.step(AtariAction(action_id=2)) # UP
163
+
164
+ # Reshape screen for visualization
165
+ screen = np.array(result.observation.screen).reshape(result.observation.screen_shape)
166
+ print(f"Screen shape: {screen.shape}") # (210, 160, 3) for RGB
167
+
168
+ asyncio.run(main())
169
+ ```
170
+
171
+ ## Observation Types
172
+
173
+ ### 1. RGB (Default)
174
+ - **Shape**: [210, 160, 3]
175
+ - **Description**: Full-color screen observation
176
+ - **Usage**: Most realistic, good for vision-based learning
177
+
178
+ ```python
179
+ docker run -p 8000:8000 -e ATARI_OBS_TYPE=rgb atari-env:latest
180
+ ```
181
+
182
+ ### 2. Grayscale
183
+ - **Shape**: [210, 160]
184
+ - **Description**: Grayscale screen observation
185
+ - **Usage**: Reduced dimensionality, faster processing
186
+
187
+ ```python
188
+ docker run -p 8000:8000 -e ATARI_OBS_TYPE=grayscale atari-env:latest
189
+ ```
190
+
191
+ ### 3. RAM
192
+ - **Shape**: [128]
193
+ - **Description**: Raw 128-byte Atari 2600 RAM contents
194
+ - **Usage**: Compact representation, useful for specific research
195
+
196
+ ```python
197
+ docker run -p 8000:8000 -e ATARI_OBS_TYPE=ram atari-env:latest
198
+ ```
199
+
200
+ ## Action Spaces
201
+
202
+ ### Minimal Action Set (Default)
203
+ Game-specific minimal actions (typically 4-9 actions).
204
+ - Pong: 6 actions (NOOP, FIRE, UP, DOWN, etc.)
205
+ - Breakout: 4 actions (NOOP, FIRE, LEFT, RIGHT)
206
+
207
+ ```python
208
+ docker run -p 8000:8000 -e ATARI_FULL_ACTION_SPACE=false atari-env:latest
209
+ ```
210
+
211
+ ### Full Action Set
212
+ All 18 possible Atari 2600 actions:
213
+ 0. NOOP
214
+ 1. FIRE
215
+ 2. UP
216
+ 3. RIGHT
217
+ 4. LEFT
218
+ 5. DOWN
219
+ 6. UPRIGHT
220
+ 7. UPLEFT
221
+ 8. DOWNRIGHT
222
+ 9. DOWNLEFT
223
+ 10. UPFIRE
224
+ 11. RIGHTFIRE
225
+ 12. LEFTFIRE
226
+ 13. DOWNFIRE
227
+ 14. UPRIGHTFIRE
228
+ 15. UPLEFTFIRE
229
+ 16. DOWNRIGHTFIRE
230
+ 17. DOWNLEFTFIRE
231
+
232
+ ```python
233
+ docker run -p 8000:8000 -e ATARI_FULL_ACTION_SPACE=true atari-env:latest
234
+ ```
235
+
236
+ ## Configuration
237
+
238
+ ### Environment Variables
239
+
240
+ - `ATARI_GAME`: Game name (default: "pong")
241
+ - `ATARI_OBS_TYPE`: Observation type - "rgb", "grayscale", "ram" (default: "rgb")
242
+ - `ATARI_FULL_ACTION_SPACE`: Use full action space - "true"/"false" (default: "false")
243
+ - `ATARI_MODE`: Game mode (optional, game-specific)
244
+ - `ATARI_DIFFICULTY`: Game difficulty (optional, game-specific)
245
+ - `ATARI_REPEAT_ACTION_PROB`: Sticky action probability 0.0-1.0 (default: "0.0")
246
+ - `ATARI_FRAMESKIP`: Frames to skip per action (default: "4")
247
+
248
+ ### Example: Breakout with Custom Settings
249
+
250
+ ```bash
251
+ docker run -p 8000:8000 \
252
+ -e ATARI_GAME=breakout \
253
+ -e ATARI_OBS_TYPE=grayscale \
254
+ -e ATARI_FULL_ACTION_SPACE=true \
255
+ -e ATARI_REPEAT_ACTION_PROB=0.25 \
256
+ -e ATARI_FRAMESKIP=4 \
257
+ atari-env:latest
258
+ ```
259
+
260
+ ## API Reference
261
+
262
+ ### AtariAction
263
+
264
+ ```python
265
+ @dataclass
266
+ class AtariAction(Action):
267
+ action_id: int # Action index to execute
268
+ game_name: str = "pong" # Game name
269
+ obs_type: str = "rgb" # Observation type
270
+ full_action_space: bool = False # Full or minimal action space
271
+ ```
272
+
273
+ ### AtariObservation
274
+
275
+ ```python
276
+ @dataclass
277
+ class AtariObservation(Observation):
278
+ screen: List[int] # Flattened screen pixels
279
+ screen_shape: List[int] # Original screen shape
280
+ legal_actions: List[int] # Legal action indices
281
+ lives: int # Lives remaining
282
+ episode_frame_number: int # Frame # in episode
283
+ frame_number: int # Total frame #
284
+ done: bool # Episode finished
285
+ reward: Optional[float] # Reward from last action
286
+ ```
287
+
288
+ ### AtariState
289
+
290
+ ```python
291
+ @dataclass
292
+ class AtariState(State):
293
+ episode_id: str # Unique episode ID
294
+ step_count: int # Number of steps
295
+ game_name: str # Game name
296
+ obs_type: str # Observation type
297
+ full_action_space: bool # Action space type
298
+ mode: Optional[int] # Game mode
299
+ difficulty: Optional[int] # Game difficulty
300
+ repeat_action_probability: float # Sticky action prob
301
+ frameskip: int # Frameskip setting
302
+ ```
303
+
304
+ ## Example Script
305
+
306
+ ```python
307
+ #!/usr/bin/env python3
308
+ """Example training loop with Atari environment."""
309
+
310
+ import asyncio
311
+ import numpy as np
312
+ from atari_env import AtariEnv, AtariAction
313
+
314
+ async def train():
315
+ # Start environment
316
+ client = await AtariEnv.from_docker_image("atari-env:latest")
317
+
318
+ async with client:
319
+ # Training loop
320
+ for episode in range(10):
321
+ result = await client.reset()
322
+ episode_reward = 0
323
+ steps = 0
324
+
325
+ while not result.done:
326
+ # Random policy (replace with your RL agent)
327
+ action_id = np.random.choice(result.observation.legal_actions)
328
+
329
+ # Take action
330
+ result = await client.step(AtariAction(action_id=action_id))
331
+
332
+ episode_reward += result.reward or 0
333
+ steps += 1
334
+
335
+ # Reshape screen for processing
336
+ screen = np.array(result.observation.screen).reshape(
337
+ result.observation.screen_shape
338
+ )
339
+
340
+ # Your RL training code here
341
+ # ...
342
+
343
+ print(f"Episode {episode}: reward={episode_reward:.2f}, steps={steps}")
344
+
345
+ asyncio.run(train())
346
+ ```
347
+
348
+ ## Testing
349
+
350
+ ### Local Testing
351
+
352
+ ```bash
353
+ # Install dependencies
354
+ pip install ale-py fastapi uvicorn requests
355
+
356
+ # Start server
357
+ export PYTHONPATH=src:envs
358
+ python -m atari_env.server.app
359
+
360
+ # Test from another terminal (using sync wrapper for simplicity)
361
+ python -c "
362
+ from atari_env import AtariEnv, AtariAction
363
+ with AtariEnv(base_url='http://localhost:8000').sync() as env:
364
+ result = env.reset()
365
+ print(f'Initial obs: {result.observation.screen_shape}')
366
+ result = env.step(AtariAction(action_id=2))
367
+ print(f'After step: reward={result.reward}, done={result.done}')
368
+ "
369
+ ```
370
+
371
+ ### Docker Testing
372
+
373
+ ```bash
374
+ # Build and run
375
+ docker build -f envs/atari_env/server/Dockerfile -t atari-env:latest .
376
+ docker run -p 8000:8000 atari-env:latest
377
+
378
+ # Test in another terminal
379
+ curl http://localhost:8000/health
380
+ curl -X POST http://localhost:8000/reset
381
+ ```
382
+
383
+ ## Popular Games and Their Characteristics
384
+
385
+ | Game | Minimal Actions | Lives | Difficulty | Notes |
386
+ |------|----------------|-------|-----------|-------|
387
+ | Pong | 6 | 1 | Low | Good for learning basics |
388
+ | Breakout | 4 | 5 | Medium | Classic RL benchmark |
389
+ | Space Invaders | 6 | 3 | Medium | Shooting game |
390
+ | Ms. Pac-Man | 9 | 3 | High | Complex navigation |
391
+ | Asteroids | 14 | 3 | Medium | Continuous shooting |
392
+ | Montezuma's Revenge | 18 | 5 | Very High | Exploration challenge |
393
+ | Pitfall | 18 | 1 | High | Platformer |
394
+ | Seaquest | 18 | 3 | High | Submarine rescue |
395
+
396
+ ## Limitations & Notes
397
+
398
+ - **Frame perfect timing**: Some games require precise timing
399
+ - **Exploration**: Games like Montezuma's Revenge are notoriously difficult
400
+ - **Observation delay**: HTTP adds minimal latency vs local gym
401
+ - **Determinism**: Set `ATARI_REPEAT_ACTION_PROB=0.0` for deterministic behavior
402
+ - **ROMs**: All ROMs are bundled with ale-py package
403
+
404
+ ## References
405
+
406
+ - [Arcade Learning Environment Paper (2013)](https://jair.org/index.php/jair/article/view/10819)
407
+ - [ALE GitHub](https://github.com/Farama-Foundation/Arcade-Learning-Environment)
408
+ - [ALE Documentation](https://ale.farama.org/)
409
+ - [Gymnasium Atari Environments](https://gymnasium.farama.org/environments/atari/)
410
+
411
+ ## Citation
412
+
413
+ If you use ALE in your research, please cite:
414
+
415
+ ```bibtex
416
+ @Article{bellemare13arcade,
417
+ author = {{Bellemare}, M.~G. and {Naddaf}, Y. and {Veness}, J. and {Bowling}, M.},
418
+ title = {The Arcade Learning Environment: An Evaluation Platform for General Agents},
419
+ journal = {Journal of Artificial Intelligence Research},
420
+ year = "2013",
421
+ month = "jun",
422
+ volume = "47",
423
+ pages = "253--279",
424
+ }
425
+ ```
__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Atari Environment for OpenEnv.
9
+
10
+ This module provides OpenEnv integration for Atari 2600 games via the
11
+ Arcade Learning Environment (ALE).
12
+
13
+ Example:
14
+ >>> from envs.atari_env import AtariEnv, AtariAction
15
+ >>>
16
+ >>> # Connect to a running server or start via Docker
17
+ >>> env = AtariEnv.from_docker_image("atari-env:latest")
18
+ >>>
19
+ >>> # Reset and interact
20
+ >>> result = env.reset()
21
+ >>> result = env.step(AtariAction(action_id=2)) # UP
22
+ >>> print(result.reward, result.done)
23
+ >>>
24
+ >>> # Cleanup
25
+ >>> env.close()
26
+ """
27
+
28
+ from .client import AtariEnv
29
+ from .models import AtariAction, AtariObservation, AtariState
30
+
31
+ __all__ = ["AtariEnv", "AtariAction", "AtariObservation", "AtariState"]
client.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Atari Environment Client.
9
+
10
+ This module provides the client for connecting to an Atari Environment server
11
+ via WebSocket for persistent sessions.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from typing import Any, Dict, TYPE_CHECKING
17
+
18
+ from openenv.core.client_types import StepResult
19
+ from openenv.core.env_client import EnvClient
20
+
21
+ from .models import AtariAction, AtariObservation, AtariState
22
+
23
+ if TYPE_CHECKING:
24
+ from openenv.core.containers.runtime import ContainerProvider
25
+
26
+
27
+ class AtariEnv(EnvClient[AtariAction, AtariObservation, AtariState]):
28
+ """
29
+ Client for Atari Environment.
30
+
31
+ This client maintains a persistent WebSocket connection to the environment
32
+ server, enabling efficient multi-step interactions with lower latency.
33
+
34
+ Example:
35
+ >>> # Connect to a running server
36
+ >>> with AtariEnv(base_url="http://localhost:8000") as client:
37
+ ... result = client.reset()
38
+ ... print(result.observation.screen_shape)
39
+ ...
40
+ ... result = client.step(AtariAction(action_id=2)) # UP
41
+ ... print(result.reward, result.done)
42
+
43
+ Example with Docker:
44
+ >>> # Automatically start container and connect
45
+ >>> client = AtariEnv.from_docker_image("atari-env:latest")
46
+ >>> try:
47
+ ... result = client.reset()
48
+ ... result = client.step(AtariAction(action_id=0)) # NOOP
49
+ ... finally:
50
+ ... client.close()
51
+ """
52
+
53
+ def _step_payload(self, action: AtariAction) -> Dict[str, Any]:
54
+ """
55
+ Convert AtariAction to JSON payload for step request.
56
+
57
+ Args:
58
+ action: AtariAction instance.
59
+
60
+ Returns:
61
+ Dictionary representation suitable for JSON encoding.
62
+ """
63
+ return {
64
+ "action_id": action.action_id,
65
+ "game_name": action.game_name,
66
+ "obs_type": action.obs_type,
67
+ "full_action_space": action.full_action_space,
68
+ }
69
+
70
+ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[AtariObservation]:
71
+ """
72
+ Parse server response into StepResult[AtariObservation].
73
+
74
+ Args:
75
+ payload: JSON response from server.
76
+
77
+ Returns:
78
+ StepResult with AtariObservation.
79
+ """
80
+ obs_data = payload.get("observation", {})
81
+
82
+ observation = AtariObservation(
83
+ screen=obs_data.get("screen", []),
84
+ screen_shape=obs_data.get("screen_shape", []),
85
+ legal_actions=obs_data.get("legal_actions", []),
86
+ lives=obs_data.get("lives", 0),
87
+ episode_frame_number=obs_data.get("episode_frame_number", 0),
88
+ frame_number=obs_data.get("frame_number", 0),
89
+ done=payload.get("done", False),
90
+ reward=payload.get("reward"),
91
+ metadata=obs_data.get("metadata", {}),
92
+ )
93
+
94
+ return StepResult(
95
+ observation=observation,
96
+ reward=payload.get("reward"),
97
+ done=payload.get("done", False),
98
+ )
99
+
100
+ def _parse_state(self, payload: Dict[str, Any]) -> AtariState:
101
+ """
102
+ Parse server response into AtariState object.
103
+
104
+ Args:
105
+ payload: JSON response from /state endpoint.
106
+
107
+ Returns:
108
+ AtariState object with environment state information.
109
+ """
110
+ return AtariState(
111
+ episode_id=payload.get("episode_id"),
112
+ step_count=payload.get("step_count", 0),
113
+ game_name=payload.get("game_name", "unknown"),
114
+ obs_type=payload.get("obs_type", "rgb"),
115
+ full_action_space=payload.get("full_action_space", False),
116
+ mode=payload.get("mode"),
117
+ difficulty=payload.get("difficulty"),
118
+ repeat_action_probability=payload.get("repeat_action_probability", 0.0),
119
+ frameskip=payload.get("frameskip", 4),
120
+ )
envs/atari_env/README.md ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Atari Environment Server
3
+ emoji: 🕹️
4
+ colorFrom: '#FF6200'
5
+ colorTo: '#D4151B'
6
+ sdk: docker
7
+ pinned: false
8
+ app_port: 8000
9
+ base_path: /web
10
+ tags:
11
+ - openenv
12
+ ---
13
+
14
+ # Atari Environment
15
+
16
+ Integration of Atari 2600 games with the OpenEnv framework via the Arcade Learning Environment (ALE). ALE provides access to 100+ classic Atari games for RL research.
17
+
18
+ ## Supported Games
19
+
20
+ ALE supports 100+ Atari 2600 games including:
21
+
22
+ ### Popular Games
23
+ - **Pong** - Classic two-player tennis
24
+ - **Breakout** - Break bricks with a ball
25
+ - **Space Invaders** - Shoot descending aliens
26
+ - **Pac-Man / Ms. Pac-Man** - Navigate mazes and eat pellets
27
+ - **Asteroids** - Destroy asteroids in space
28
+ - **Defender** - Side-scrolling space shooter
29
+ - **Centipede** - Shoot segmented centipede
30
+ - **Donkey Kong** - Jump over barrels to save princess
31
+ - **Frogger** - Cross road and river safely
32
+ - **Q*bert** - Jump on pyramid cubes
33
+
34
+ And many more! For a complete list, see [ALE documentation](https://ale.farama.org/environments/complete_list/).
35
+
36
+ ## Architecture
37
+
38
+ ```
39
+ ┌────────────────────────────────────┐
40
+ │ RL Training Code (Client) │
41
+ │ AtariEnv.step(action) │
42
+ └──────────────┬─────────────────────┘
43
+ │ HTTP
44
+ ┌──────────────▼─────────────────────┐
45
+ │ FastAPI Server (Docker) │
46
+ │ AtariEnvironment │
47
+ │ ├─ Wraps ALEInterface │
48
+ │ ├─ Handles observations │
49
+ │ └─ Action execution │
50
+ └────────────────────────────────────┘
51
+ ```
52
+
53
+ ## Installation & Usage
54
+
55
+ ### Option 1: Local Development (without Docker)
56
+
57
+ **Requirements:**
58
+ - Python 3.11+
59
+ - ale-py installed: `pip install ale-py`
60
+
61
+ The client is **async by default**:
62
+
63
+ ```python
64
+ import asyncio
65
+ from atari_env import AtariEnv, AtariAction
66
+
67
+ async def main():
68
+ # Start local server manually: python -m atari_env.server.app
69
+ async with AtariEnv(base_url="http://localhost:8000") as env:
70
+ # Reset environment
71
+ result = await env.reset()
72
+ print(f"Screen shape: {result.observation.screen_shape}")
73
+ print(f"Legal actions: {result.observation.legal_actions}")
74
+
75
+ # Take actions
76
+ for _ in range(10):
77
+ result = await env.step(AtariAction(action_id=2, game_name="pong"))
78
+ print(f"Reward: {result.reward}, Done: {result.done}")
79
+ if result.done:
80
+ break
81
+
82
+ asyncio.run(main())
83
+ ```
84
+
85
+ For **synchronous usage**, use the `.sync()` wrapper:
86
+
87
+ ```python
88
+ from atari_env import AtariEnv, AtariAction
89
+
90
+ with AtariEnv(base_url="http://localhost:8000").sync() as env:
91
+ result = env.reset()
92
+ result = env.step(AtariAction(action_id=2, game_name="pong"))
93
+ print(f"Reward: {result.reward}")
94
+ ```
95
+
96
+ ### Option 2: Docker (Recommended)
97
+
98
+ **Build Atari image:**
99
+
100
+ ```bash
101
+ cd OpenEnv
102
+
103
+ # Build the image
104
+ docker build \
105
+ -f envs/atari_env/server/Dockerfile \
106
+ -t atari-env:latest \
107
+ .
108
+ ```
109
+
110
+ **Run specific games:**
111
+
112
+ ```bash
113
+ # Pong (default)
114
+ docker run -p 8000:8000 atari-env:latest
115
+
116
+ # Breakout
117
+ docker run -p 8000:8000 -e ATARI_GAME=breakout atari-env:latest
118
+
119
+ # Space Invaders with grayscale observation
120
+ docker run -p 8000:8000 \
121
+ -e ATARI_GAME=space_invaders \
122
+ -e ATARI_OBS_TYPE=grayscale \
123
+ atari-env:latest
124
+
125
+ # Ms. Pac-Man with full action space
126
+ docker run -p 8000:8000 \
127
+ -e ATARI_GAME=ms_pacman \
128
+ -e ATARI_FULL_ACTION_SPACE=true \
129
+ atari-env:latest
130
+ ```
131
+
132
+ **Use with from_docker_image():**
133
+
134
+ ```python
135
+ import asyncio
136
+ import numpy as np
137
+ from atari_env import AtariEnv, AtariAction
138
+
139
+ async def main():
140
+ # Automatically starts container
141
+ client = await AtariEnv.from_docker_image("atari-env:latest")
142
+
143
+ async with client:
144
+ result = await client.reset()
145
+ result = await client.step(AtariAction(action_id=2)) # UP
146
+
147
+ # Reshape screen for visualization
148
+ screen = np.array(result.observation.screen).reshape(result.observation.screen_shape)
149
+ print(f"Screen shape: {screen.shape}") # (210, 160, 3) for RGB
150
+
151
+ asyncio.run(main())
152
+ ```
153
+
154
+ ## Observation Types
155
+
156
+ ### 1. RGB (Default)
157
+ - **Shape**: [210, 160, 3]
158
+ - **Description**: Full-color screen observation
159
+ - **Usage**: Most realistic, good for vision-based learning
160
+
161
+ ```python
162
+ docker run -p 8000:8000 -e ATARI_OBS_TYPE=rgb atari-env:latest
163
+ ```
164
+
165
+ ### 2. Grayscale
166
+ - **Shape**: [210, 160]
167
+ - **Description**: Grayscale screen observation
168
+ - **Usage**: Reduced dimensionality, faster processing
169
+
170
+ ```python
171
+ docker run -p 8000:8000 -e ATARI_OBS_TYPE=grayscale atari-env:latest
172
+ ```
173
+
174
+ ### 3. RAM
175
+ - **Shape**: [128]
176
+ - **Description**: Raw 128-byte Atari 2600 RAM contents
177
+ - **Usage**: Compact representation, useful for specific research
178
+
179
+ ```python
180
+ docker run -p 8000:8000 -e ATARI_OBS_TYPE=ram atari-env:latest
181
+ ```
182
+
183
+ ## Action Spaces
184
+
185
+ ### Minimal Action Set (Default)
186
+ Game-specific minimal actions (typically 4-9 actions).
187
+ - Pong: 6 actions (NOOP, FIRE, UP, DOWN, etc.)
188
+ - Breakout: 4 actions (NOOP, FIRE, LEFT, RIGHT)
189
+
190
+ ```python
191
+ docker run -p 8000:8000 -e ATARI_FULL_ACTION_SPACE=false atari-env:latest
192
+ ```
193
+
194
+ ### Full Action Set
195
+ All 18 possible Atari 2600 actions:
196
+ 0. NOOP
197
+ 1. FIRE
198
+ 2. UP
199
+ 3. RIGHT
200
+ 4. LEFT
201
+ 5. DOWN
202
+ 6. UPRIGHT
203
+ 7. UPLEFT
204
+ 8. DOWNRIGHT
205
+ 9. DOWNLEFT
206
+ 10. UPFIRE
207
+ 11. RIGHTFIRE
208
+ 12. LEFTFIRE
209
+ 13. DOWNFIRE
210
+ 14. UPRIGHTFIRE
211
+ 15. UPLEFTFIRE
212
+ 16. DOWNRIGHTFIRE
213
+ 17. DOWNLEFTFIRE
214
+
215
+ ```python
216
+ docker run -p 8000:8000 -e ATARI_FULL_ACTION_SPACE=true atari-env:latest
217
+ ```
218
+
219
+ ## Configuration
220
+
221
+ ### Environment Variables
222
+
223
+ - `ATARI_GAME`: Game name (default: "pong")
224
+ - `ATARI_OBS_TYPE`: Observation type - "rgb", "grayscale", "ram" (default: "rgb")
225
+ - `ATARI_FULL_ACTION_SPACE`: Use full action space - "true"/"false" (default: "false")
226
+ - `ATARI_MODE`: Game mode (optional, game-specific)
227
+ - `ATARI_DIFFICULTY`: Game difficulty (optional, game-specific)
228
+ - `ATARI_REPEAT_ACTION_PROB`: Sticky action probability 0.0-1.0 (default: "0.0")
229
+ - `ATARI_FRAMESKIP`: Frames to skip per action (default: "4")
230
+
231
+ ### Example: Breakout with Custom Settings
232
+
233
+ ```bash
234
+ docker run -p 8000:8000 \
235
+ -e ATARI_GAME=breakout \
236
+ -e ATARI_OBS_TYPE=grayscale \
237
+ -e ATARI_FULL_ACTION_SPACE=true \
238
+ -e ATARI_REPEAT_ACTION_PROB=0.25 \
239
+ -e ATARI_FRAMESKIP=4 \
240
+ atari-env:latest
241
+ ```
242
+
243
+ ## API Reference
244
+
245
+ ### AtariAction
246
+
247
+ ```python
248
+ @dataclass
249
+ class AtariAction(Action):
250
+ action_id: int # Action index to execute
251
+ game_name: str = "pong" # Game name
252
+ obs_type: str = "rgb" # Observation type
253
+ full_action_space: bool = False # Full or minimal action space
254
+ ```
255
+
256
+ ### AtariObservation
257
+
258
+ ```python
259
+ @dataclass
260
+ class AtariObservation(Observation):
261
+ screen: List[int] # Flattened screen pixels
262
+ screen_shape: List[int] # Original screen shape
263
+ legal_actions: List[int] # Legal action indices
264
+ lives: int # Lives remaining
265
+ episode_frame_number: int # Frame # in episode
266
+ frame_number: int # Total frame #
267
+ done: bool # Episode finished
268
+ reward: Optional[float] # Reward from last action
269
+ ```
270
+
271
+ ### AtariState
272
+
273
+ ```python
274
+ @dataclass
275
+ class AtariState(State):
276
+ episode_id: str # Unique episode ID
277
+ step_count: int # Number of steps
278
+ game_name: str # Game name
279
+ obs_type: str # Observation type
280
+ full_action_space: bool # Action space type
281
+ mode: Optional[int] # Game mode
282
+ difficulty: Optional[int] # Game difficulty
283
+ repeat_action_probability: float # Sticky action prob
284
+ frameskip: int # Frameskip setting
285
+ ```
286
+
287
+ ## Example Script
288
+
289
+ ```python
290
+ #!/usr/bin/env python3
291
+ """Example training loop with Atari environment."""
292
+
293
+ import asyncio
294
+ import numpy as np
295
+ from atari_env import AtariEnv, AtariAction
296
+
297
+ async def train():
298
+ # Start environment
299
+ client = await AtariEnv.from_docker_image("atari-env:latest")
300
+
301
+ async with client:
302
+ # Training loop
303
+ for episode in range(10):
304
+ result = await client.reset()
305
+ episode_reward = 0
306
+ steps = 0
307
+
308
+ while not result.done:
309
+ # Random policy (replace with your RL agent)
310
+ action_id = np.random.choice(result.observation.legal_actions)
311
+
312
+ # Take action
313
+ result = await client.step(AtariAction(action_id=action_id))
314
+
315
+ episode_reward += result.reward or 0
316
+ steps += 1
317
+
318
+ # Reshape screen for processing
319
+ screen = np.array(result.observation.screen).reshape(
320
+ result.observation.screen_shape
321
+ )
322
+
323
+ # Your RL training code here
324
+ # ...
325
+
326
+ print(f"Episode {episode}: reward={episode_reward:.2f}, steps={steps}")
327
+
328
+ asyncio.run(train())
329
+ ```
330
+
331
+ ## Testing
332
+
333
+ ### Local Testing
334
+
335
+ ```bash
336
+ # Install dependencies
337
+ pip install ale-py fastapi uvicorn requests
338
+
339
+ # Start server
340
+ export PYTHONPATH=src:envs
341
+ python -m atari_env.server.app
342
+
343
+ # Test from another terminal (using sync wrapper for simplicity)
344
+ python -c "
345
+ from atari_env import AtariEnv, AtariAction
346
+ with AtariEnv(base_url='http://localhost:8000').sync() as env:
347
+ result = env.reset()
348
+ print(f'Initial obs: {result.observation.screen_shape}')
349
+ result = env.step(AtariAction(action_id=2))
350
+ print(f'After step: reward={result.reward}, done={result.done}')
351
+ "
352
+ ```
353
+
354
+ ### Docker Testing
355
+
356
+ ```bash
357
+ # Build and run
358
+ docker build -f envs/atari_env/server/Dockerfile -t atari-env:latest .
359
+ docker run -p 8000:8000 atari-env:latest
360
+
361
+ # Test in another terminal
362
+ curl http://localhost:8000/health
363
+ curl -X POST http://localhost:8000/reset
364
+ ```
365
+
366
+ ## Popular Games and Their Characteristics
367
+
368
+ | Game | Minimal Actions | Lives | Difficulty | Notes |
369
+ |------|----------------|-------|-----------|-------|
370
+ | Pong | 6 | 1 | Low | Good for learning basics |
371
+ | Breakout | 4 | 5 | Medium | Classic RL benchmark |
372
+ | Space Invaders | 6 | 3 | Medium | Shooting game |
373
+ | Ms. Pac-Man | 9 | 3 | High | Complex navigation |
374
+ | Asteroids | 14 | 3 | Medium | Continuous shooting |
375
+ | Montezuma's Revenge | 18 | 5 | Very High | Exploration challenge |
376
+ | Pitfall | 18 | 1 | High | Platformer |
377
+ | Seaquest | 18 | 3 | High | Submarine rescue |
378
+
379
+ ## Limitations & Notes
380
+
381
+ - **Frame perfect timing**: Some games require precise timing
382
+ - **Exploration**: Games like Montezuma's Revenge are notoriously difficult
383
+ - **Observation delay**: HTTP adds minimal latency vs local gym
384
+ - **Determinism**: Set `ATARI_REPEAT_ACTION_PROB=0.0` for deterministic behavior
385
+ - **ROMs**: All ROMs are bundled with ale-py package
386
+
387
+ ## References
388
+
389
+ - [Arcade Learning Environment Paper (2013)](https://jair.org/index.php/jair/article/view/10819)
390
+ - [ALE GitHub](https://github.com/Farama-Foundation/Arcade-Learning-Environment)
391
+ - [ALE Documentation](https://ale.farama.org/)
392
+ - [Gymnasium Atari Environments](https://gymnasium.farama.org/environments/atari/)
393
+
394
+ ## Citation
395
+
396
+ If you use ALE in your research, please cite:
397
+
398
+ ```bibtex
399
+ @Article{bellemare13arcade,
400
+ author = {{Bellemare}, M.~G. and {Naddaf}, Y. and {Veness}, J. and {Bowling}, M.},
401
+ title = {The Arcade Learning Environment: An Evaluation Platform for General Agents},
402
+ journal = {Journal of Artificial Intelligence Research},
403
+ year = "2013",
404
+ month = "jun",
405
+ volume = "47",
406
+ pages = "253--279",
407
+ }
408
+ ```
envs/atari_env/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Atari Environment for OpenEnv.
9
+
10
+ This module provides OpenEnv integration for Atari 2600 games via the
11
+ Arcade Learning Environment (ALE).
12
+
13
+ Example:
14
+ >>> from envs.atari_env import AtariEnv, AtariAction
15
+ >>>
16
+ >>> # Connect to a running server or start via Docker
17
+ >>> env = AtariEnv.from_docker_image("atari-env:latest")
18
+ >>>
19
+ >>> # Reset and interact
20
+ >>> result = env.reset()
21
+ >>> result = env.step(AtariAction(action_id=2)) # UP
22
+ >>> print(result.reward, result.done)
23
+ >>>
24
+ >>> # Cleanup
25
+ >>> env.close()
26
+ """
27
+
28
+ from .client import AtariEnv
29
+ from .models import AtariAction, AtariObservation, AtariState
30
+
31
+ __all__ = ["AtariEnv", "AtariAction", "AtariObservation", "AtariState"]
envs/atari_env/client.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Atari Environment Client.
9
+
10
+ This module provides the client for connecting to an Atari Environment server
11
+ via WebSocket for persistent sessions.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from typing import Any, Dict, TYPE_CHECKING
17
+
18
+ from openenv.core.client_types import StepResult
19
+ from openenv.core.env_client import EnvClient
20
+
21
+ from .models import AtariAction, AtariObservation, AtariState
22
+
23
+ if TYPE_CHECKING:
24
+ from openenv.core.containers.runtime import ContainerProvider
25
+
26
+
27
+ class AtariEnv(EnvClient[AtariAction, AtariObservation, AtariState]):
28
+ """
29
+ Client for Atari Environment.
30
+
31
+ This client maintains a persistent WebSocket connection to the environment
32
+ server, enabling efficient multi-step interactions with lower latency.
33
+
34
+ Example:
35
+ >>> # Connect to a running server
36
+ >>> with AtariEnv(base_url="http://localhost:8000") as client:
37
+ ... result = client.reset()
38
+ ... print(result.observation.screen_shape)
39
+ ...
40
+ ... result = client.step(AtariAction(action_id=2)) # UP
41
+ ... print(result.reward, result.done)
42
+
43
+ Example with Docker:
44
+ >>> # Automatically start container and connect
45
+ >>> client = AtariEnv.from_docker_image("atari-env:latest")
46
+ >>> try:
47
+ ... result = client.reset()
48
+ ... result = client.step(AtariAction(action_id=0)) # NOOP
49
+ ... finally:
50
+ ... client.close()
51
+ """
52
+
53
+ def _step_payload(self, action: AtariAction) -> Dict[str, Any]:
54
+ """
55
+ Convert AtariAction to JSON payload for step request.
56
+
57
+ Args:
58
+ action: AtariAction instance.
59
+
60
+ Returns:
61
+ Dictionary representation suitable for JSON encoding.
62
+ """
63
+ return {
64
+ "action_id": action.action_id,
65
+ "game_name": action.game_name,
66
+ "obs_type": action.obs_type,
67
+ "full_action_space": action.full_action_space,
68
+ }
69
+
70
+ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[AtariObservation]:
71
+ """
72
+ Parse server response into StepResult[AtariObservation].
73
+
74
+ Args:
75
+ payload: JSON response from server.
76
+
77
+ Returns:
78
+ StepResult with AtariObservation.
79
+ """
80
+ obs_data = payload.get("observation", {})
81
+
82
+ observation = AtariObservation(
83
+ screen=obs_data.get("screen", []),
84
+ screen_shape=obs_data.get("screen_shape", []),
85
+ legal_actions=obs_data.get("legal_actions", []),
86
+ lives=obs_data.get("lives", 0),
87
+ episode_frame_number=obs_data.get("episode_frame_number", 0),
88
+ frame_number=obs_data.get("frame_number", 0),
89
+ done=payload.get("done", False),
90
+ reward=payload.get("reward"),
91
+ metadata=obs_data.get("metadata", {}),
92
+ )
93
+
94
+ return StepResult(
95
+ observation=observation,
96
+ reward=payload.get("reward"),
97
+ done=payload.get("done", False),
98
+ )
99
+
100
+ def _parse_state(self, payload: Dict[str, Any]) -> AtariState:
101
+ """
102
+ Parse server response into AtariState object.
103
+
104
+ Args:
105
+ payload: JSON response from /state endpoint.
106
+
107
+ Returns:
108
+ AtariState object with environment state information.
109
+ """
110
+ return AtariState(
111
+ episode_id=payload.get("episode_id"),
112
+ step_count=payload.get("step_count", 0),
113
+ game_name=payload.get("game_name", "unknown"),
114
+ obs_type=payload.get("obs_type", "rgb"),
115
+ full_action_space=payload.get("full_action_space", False),
116
+ mode=payload.get("mode"),
117
+ difficulty=payload.get("difficulty"),
118
+ repeat_action_probability=payload.get("repeat_action_probability", 0.0),
119
+ frameskip=payload.get("frameskip", 4),
120
+ )
envs/atari_env/models.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Data models for Atari Environment.
9
+
10
+ This module defines the Action, Observation, and State types for Atari games
11
+ via the Arcade Learning Environment (ALE).
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from typing import Any, Dict, List, Literal, Optional
17
+
18
+ from openenv.core.env_server import Action, Observation, State
19
+
20
+
21
+ class AtariAction(Action):
22
+ """
23
+ Action for Atari environments.
24
+
25
+ Attributes:
26
+ action_id: The integer action ID to take (from legal_actions).
27
+ game_name: Name of the Atari game (e.g., "pong", "breakout", "space_invaders").
28
+ obs_type: Observation type ("rgb", "grayscale", or "ram").
29
+ full_action_space: Whether to use full (18 actions) or minimal action space.
30
+ """
31
+
32
+ action_id: int
33
+ game_name: str = "pong"
34
+ obs_type: Literal["rgb", "grayscale", "ram"] = "rgb"
35
+ full_action_space: bool = False
36
+
37
+
38
+ class AtariObservation(Observation):
39
+ """
40
+ Observation from Atari environment.
41
+
42
+ This represents what the agent sees after taking an action.
43
+
44
+ Attributes:
45
+ screen: Screen observation as a flattened list of pixels.
46
+ Shape depends on obs_type:
47
+ - rgb: [210, 160, 3] flattened
48
+ - grayscale: [210, 160] flattened
49
+ - ram: [128] (RAM contents)
50
+ screen_shape: Original shape of the screen before flattening.
51
+ legal_actions: List of legal action IDs the agent can take.
52
+ lives: Number of lives remaining.
53
+ episode_frame_number: Frame number within current episode.
54
+ frame_number: Total frame number since environment creation.
55
+ """
56
+
57
+ screen: List[int]
58
+ screen_shape: List[int]
59
+ legal_actions: List[int]
60
+ lives: int = 0
61
+ episode_frame_number: int = 0
62
+ frame_number: int = 0
63
+
64
+
65
+ class AtariState(State):
66
+ """
67
+ State for Atari environment.
68
+
69
+ Attributes:
70
+ game_name: Name of the Atari game.
71
+ obs_type: Observation type ("rgb", "grayscale", or "ram").
72
+ full_action_space: Whether using full or minimal action space.
73
+ mode: Game mode (if applicable).
74
+ difficulty: Game difficulty (if applicable).
75
+ repeat_action_probability: Probability of repeating previous action (sticky actions).
76
+ frameskip: Number of frames to skip per action.
77
+ """
78
+
79
+ game_name: str = "pong"
80
+ obs_type: Literal["rgb", "grayscale", "ram"] = "rgb"
81
+ full_action_space: bool = False
82
+ mode: Optional[int] = None
83
+ difficulty: Optional[int] = None
84
+ repeat_action_probability: float = 0.0
85
+ frameskip: int = 4
envs/atari_env/server/Dockerfile ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dockerfile for Atari Environment
2
+ # This image provides Atari 2600 games via the Arcade Learning Environment (ALE)
3
+
4
+ # Configurable base image - defaults to local build, can be overridden for CI/CD
5
+ # Base image provides: fastapi, uvicorn, requests, curl, PYTHONPATH=/app/src
6
+ #
7
+ # Local build: docker build -t envtorch-base:latest -f src/core/containers/images/Dockerfile .
8
+ # docker build -f envs/atari_env/server/Dockerfile -t atari-env:latest .
9
+ #
10
+ # CI/CD build: docker build --build-arg BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest \
11
+ # -f envs/atari_env/server/Dockerfile -t atari-env:latest .
12
+ ARG BASE_IMAGE=openenv-base:latest
13
+ FROM ${BASE_IMAGE}
14
+
15
+ # Install dependencies
16
+ COPY envs/atari_env/server/requirements.txt /tmp/requirements.txt
17
+ RUN pip install --no-cache-dir -r /tmp/requirements.txt && rm /tmp/requirements.txt
18
+
19
+ # Copy OpenEnv core (base image already set WORKDIR=/app)
20
+ COPY src/core/ /app/src/core/
21
+
22
+ # Copy Atari environment code
23
+ COPY envs/atari_env/ /app/envs/atari_env/
24
+
25
+ # Copy README for web interface documentation
26
+ COPY envs/atari_env/README.md /app/README.md
27
+
28
+ # Atari-specific environment variables (can be overridden at runtime)
29
+ ENV ATARI_GAME=pong
30
+ ENV ATARI_OBS_TYPE=rgb
31
+ ENV ATARI_FULL_ACTION_SPACE=false
32
+ ENV ATARI_REPEAT_ACTION_PROB=0.0
33
+ ENV ATARI_FRAMESKIP=4
34
+
35
+ # Expose port
36
+ EXPOSE 8000
37
+
38
+ # Health check
39
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
40
+ CMD curl -f http://localhost:8000/health || exit 1
41
+
42
+ # Run the FastAPI server
43
+ CMD ["uvicorn", "envs.atari_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
envs/atari_env/server/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Atari Environment Server.
9
+
10
+ Server-side implementation of Atari environment for OpenEnv.
11
+ """
12
+
13
+ from .atari_environment import AtariEnvironment
14
+
15
+ __all__ = ["AtariEnvironment"]
envs/atari_env/server/app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ FastAPI application for the Atari Environment.
9
+
10
+ This module creates an HTTP server that exposes Atari games
11
+ over HTTP and WebSocket endpoints, compatible with EnvClient.
12
+
13
+ Usage:
14
+ # Development (with auto-reload):
15
+ uvicorn envs.atari_env.server.app:app --reload --host 0.0.0.0 --port 8000
16
+
17
+ # Production:
18
+ uvicorn envs.atari_env.server.app:app --host 0.0.0.0 --port 8000 --workers 4
19
+
20
+ # Or run directly:
21
+ python -m envs.atari_env.server.app
22
+
23
+ Environment variables:
24
+ ATARI_GAME: Game name to serve (default: "pong")
25
+ ATARI_OBS_TYPE: Observation type (default: "rgb")
26
+ ATARI_FULL_ACTION_SPACE: Use full action space (default: "false")
27
+ ATARI_MODE: Game mode (optional)
28
+ ATARI_DIFFICULTY: Game difficulty (optional)
29
+ ATARI_REPEAT_ACTION_PROB: Sticky action probability (default: "0.0")
30
+ ATARI_FRAMESKIP: Frameskip (default: "4")
31
+ """
32
+
33
+ import os
34
+
35
+ from openenv.core.env_server import create_app
36
+
37
+ from ..models import AtariAction, AtariObservation
38
+ from .atari_environment import AtariEnvironment
39
+
40
+ # Get configuration from environment variables
41
+ game_name = os.getenv("ATARI_GAME", "pong")
42
+ obs_type = os.getenv("ATARI_OBS_TYPE", "rgb")
43
+ full_action_space = os.getenv("ATARI_FULL_ACTION_SPACE", "false").lower() == "true"
44
+ repeat_action_prob = float(os.getenv("ATARI_REPEAT_ACTION_PROB", "0.0"))
45
+ frameskip = int(os.getenv("ATARI_FRAMESKIP", "4"))
46
+
47
+ # Optional parameters
48
+ mode = os.getenv("ATARI_MODE")
49
+ difficulty = os.getenv("ATARI_DIFFICULTY")
50
+
51
+ # Convert to int if specified
52
+ mode = int(mode) if mode is not None else None
53
+ difficulty = int(difficulty) if difficulty is not None else None
54
+
55
+
56
+ # Factory function to create AtariEnvironment instances
57
+ def create_atari_environment():
58
+ """Factory function that creates AtariEnvironment with config."""
59
+ return AtariEnvironment(
60
+ game_name=game_name,
61
+ obs_type=obs_type,
62
+ full_action_space=full_action_space,
63
+ mode=mode,
64
+ difficulty=difficulty,
65
+ repeat_action_probability=repeat_action_prob,
66
+ frameskip=frameskip,
67
+ )
68
+
69
+
70
+ # Create the FastAPI app with web interface and README integration
71
+ # Pass the factory function instead of an instance for WebSocket session support
72
+ app = create_app(
73
+ create_atari_environment, AtariAction, AtariObservation, env_name="atari_env"
74
+ )
75
+
76
+
77
+ if __name__ == "__main__":
78
+ import uvicorn
79
+
80
+ uvicorn.run(app, host="0.0.0.0", port=8000)
envs/atari_env/server/atari_environment.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Atari Environment Server Implementation.
9
+
10
+ This module wraps ALE's ALEInterface and exposes it
11
+ via the OpenEnv Environment interface.
12
+ """
13
+
14
+ import uuid
15
+ from typing import Any, Dict, Literal, Optional
16
+
17
+ from openenv.core.env_server import Action, Environment, Observation
18
+
19
+ from ..models import AtariAction, AtariObservation, AtariState
20
+
21
+ # Import ALE
22
+ try:
23
+ import numpy as np
24
+ from ale_py import ALEInterface, roms
25
+ except ImportError as e:
26
+ raise ImportError(
27
+ "ALE (Arcade Learning Environment) is not installed. "
28
+ "Please install it with: pip install ale-py"
29
+ ) from e
30
+
31
+
32
+ class AtariEnvironment(Environment):
33
+ """
34
+ Atari Environment wrapper for OpenEnv.
35
+
36
+ This environment wraps Atari 2600 games via the Arcade Learning Environment (ALE)
37
+ and provides a clean interface for RL training.
38
+
39
+ Supported games include: pong, breakout, space_invaders, and 100+ others.
40
+
41
+ Args:
42
+ game_name: Name of the Atari game (e.g., "pong", "breakout").
43
+ obs_type: Observation type - "rgb", "grayscale", or "ram".
44
+ full_action_space: Use full action space (18 actions) vs minimal.
45
+ mode: Game mode (if applicable).
46
+ difficulty: Game difficulty (if applicable).
47
+ repeat_action_probability: Sticky action probability (default 0.0).
48
+ frameskip: Number of frames to skip per action (default 4).
49
+
50
+ Example:
51
+ >>> env = AtariEnvironment("pong")
52
+ >>> obs = env.reset()
53
+ >>> print(obs.screen_shape) # [210, 160, 3]
54
+ >>> obs = env.step(AtariAction(action_id=2)) # UP
55
+ >>> print(obs.reward, obs.done)
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ game_name: str = "pong",
61
+ obs_type: Literal["rgb", "grayscale", "ram"] = "rgb",
62
+ full_action_space: bool = False,
63
+ mode: Optional[int] = None,
64
+ difficulty: Optional[int] = None,
65
+ repeat_action_probability: float = 0.0,
66
+ frameskip: int = 4,
67
+ ):
68
+ """Initialize Atari environment."""
69
+ super().__init__()
70
+
71
+ self.game_name = game_name
72
+ self.obs_type = obs_type
73
+ self.full_action_space = full_action_space
74
+ self.mode = mode
75
+ self.difficulty = difficulty
76
+ self.repeat_action_probability = repeat_action_probability
77
+ self.frameskip = frameskip
78
+
79
+ # Create ALE interface
80
+ self.ale = ALEInterface()
81
+
82
+ # Configure ALE
83
+ from ale_py import LoggerMode
84
+
85
+ self.ale.setLoggerMode(LoggerMode.Error) # Error mode only
86
+ self.ale.setFloat("repeat_action_probability", repeat_action_probability)
87
+
88
+ # Load ROM
89
+ try:
90
+ rom_path = roms.get_rom_path(game_name)
91
+ self.ale.loadROM(rom_path)
92
+ except Exception as e:
93
+ raise ValueError(
94
+ f"Failed to load Atari game '{game_name}': {e}\n"
95
+ f"Available games can be found via: ale_py.roms.list_roms()"
96
+ ) from e
97
+
98
+ # Set mode and difficulty if specified
99
+ if mode is not None:
100
+ self.ale.setMode(mode)
101
+ if difficulty is not None:
102
+ self.ale.setDifficulty(difficulty)
103
+
104
+ # Get action set
105
+ if full_action_space:
106
+ self._action_set = self.ale.getLegalActionSet()
107
+ else:
108
+ self._action_set = self.ale.getMinimalActionSet()
109
+
110
+ # Get screen dimensions for observation space
111
+ self.screen_height, self.screen_width = self.ale.getScreenDims()
112
+ if obs_type == "rgb":
113
+ self.screen_shape = [self.screen_height, self.screen_width, 3]
114
+ elif obs_type == "grayscale":
115
+ self.screen_shape = [self.screen_height, self.screen_width]
116
+ elif obs_type == "ram":
117
+ self.screen_shape = [self.ale.getRAMSize()]
118
+ else:
119
+ raise ValueError(f"Invalid obs_type: {obs_type}")
120
+
121
+ # Initialize state
122
+ self._state = AtariState(
123
+ game_name=game_name,
124
+ obs_type=obs_type,
125
+ full_action_space=full_action_space,
126
+ mode=mode,
127
+ difficulty=difficulty,
128
+ repeat_action_probability=repeat_action_probability,
129
+ frameskip=frameskip,
130
+ )
131
+
132
+ def reset(self) -> Observation:
133
+ """
134
+ Reset the environment and return initial observation.
135
+
136
+ Returns:
137
+ Initial observation for the agent.
138
+ """
139
+ # Reset ALE
140
+ self.ale.reset_game()
141
+
142
+ # Reset state tracking
143
+ self._state.episode_id = str(uuid.uuid4())
144
+ self._state.step_count = 0
145
+
146
+ # Get initial observation
147
+ return self._make_observation()
148
+
149
+ def step(self, action: Action) -> Observation:
150
+ """
151
+ Execute agent's action and return resulting observation.
152
+
153
+ Args:
154
+ action: AtariAction containing the action_id to execute.
155
+
156
+ Returns:
157
+ Observation after action execution.
158
+
159
+ Raises:
160
+ ValueError: If action is not an AtariAction.
161
+ """
162
+ if not isinstance(action, AtariAction):
163
+ raise ValueError(f"Expected AtariAction, got {type(action)}")
164
+
165
+ # Validate action_id
166
+ if action.action_id < 0 or action.action_id >= len(self._action_set):
167
+ raise ValueError(
168
+ f"Invalid action_id: {action.action_id}. "
169
+ f"Valid range: [0, {len(self._action_set) - 1}]"
170
+ )
171
+
172
+ # Get actual ALE action
173
+ ale_action = self._action_set[action.action_id]
174
+
175
+ # Execute action with frameskip
176
+ total_reward = 0.0
177
+ for _ in range(self.frameskip):
178
+ total_reward += self.ale.act(ale_action)
179
+ if self.ale.game_over():
180
+ break
181
+
182
+ self._state.step_count += 1
183
+
184
+ # Get observation
185
+ obs = self._make_observation()
186
+ obs.reward = total_reward
187
+
188
+ return obs
189
+
190
+ @property
191
+ def state(self) -> AtariState:
192
+ """Get current environment state."""
193
+ return self._state
194
+
195
+ def _make_observation(self) -> AtariObservation:
196
+ """
197
+ Create an AtariObservation from current ALE state.
198
+
199
+ Returns:
200
+ AtariObservation for the agent.
201
+ """
202
+ # Get screen observation
203
+ if self.obs_type == "rgb":
204
+ screen = self.ale.getScreenRGB()
205
+ elif self.obs_type == "grayscale":
206
+ screen = self.ale.getScreenGrayscale()
207
+ elif self.obs_type == "ram":
208
+ screen = self.ale.getRAM()
209
+ else:
210
+ raise ValueError(f"Invalid obs_type: {self.obs_type}")
211
+
212
+ # Flatten screen for JSON serialization
213
+ # Handle both numpy arrays and lists
214
+ if hasattr(screen, "flatten"):
215
+ screen_flat = screen.flatten().tolist()
216
+ elif hasattr(screen, "tolist"):
217
+ screen_flat = screen.tolist()
218
+ else:
219
+ screen_flat = list(screen)
220
+
221
+ # Get game info
222
+ lives = self.ale.lives()
223
+ episode_frame_number = self.ale.getEpisodeFrameNumber()
224
+ frame_number = self.ale.getFrameNumber()
225
+ done = self.ale.game_over()
226
+
227
+ # Create legal actions list (indices into action_set)
228
+ legal_actions = list(range(len(self._action_set)))
229
+
230
+ # Create observation
231
+ obs = AtariObservation(
232
+ screen=screen_flat,
233
+ screen_shape=self.screen_shape,
234
+ legal_actions=legal_actions,
235
+ lives=lives,
236
+ episode_frame_number=episode_frame_number,
237
+ frame_number=frame_number,
238
+ done=done,
239
+ reward=0.0, # Will be filled in by step()
240
+ metadata={
241
+ "game_name": self.game_name,
242
+ "action_meanings": [str(a) for a in self._action_set],
243
+ },
244
+ )
245
+
246
+ return obs
envs/atari_env/server/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gymnasium>=0.29.0
2
+ ale-py>=0.8.0
3
+ numpy>=1.24.0
envs/atari_env/test_atari_docker.sh ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Comprehensive Docker test for Atari environment
3
+ # Tests: Build, Start, Health, Reset, Step, State, Cleanup
4
+
5
+ set -e # Exit on error
6
+
7
+ # Colors for output
8
+ RED='\033[0;31m'
9
+ GREEN='\033[0;32m'
10
+ YELLOW='\033[1;33m'
11
+ BLUE='\033[0;34m'
12
+ NC='\033[0m' # No Color
13
+
14
+ # Configuration
15
+ IMAGE_NAME="atari-env"
16
+ IMAGE_TAG="test"
17
+ CONTAINER_NAME="atari-env-test"
18
+ PORT="8765" # Use non-standard port to avoid conflicts
19
+ HEALTH_RETRIES=30
20
+ HEALTH_DELAY=2
21
+
22
+ # Cleanup function
23
+ cleanup() {
24
+ echo -e "\n${BLUE}Cleaning up...${NC}"
25
+ docker stop ${CONTAINER_NAME} 2>/dev/null || true
26
+ docker rm ${CONTAINER_NAME} 2>/dev/null || true
27
+ echo -e "${GREEN}✓${NC} Cleanup complete"
28
+ }
29
+
30
+ # Set trap to cleanup on exit
31
+ trap cleanup EXIT
32
+
33
+ # Header
34
+ echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
35
+ echo " ATARI ENVIRONMENT DOCKER TEST"
36
+ echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
37
+ echo ""
38
+
39
+ # Check prerequisites
40
+ echo -e "${BLUE}Checking prerequisites...${NC}"
41
+ if ! command -v docker &> /dev/null; then
42
+ echo -e "${RED}✗${NC} Docker is not installed"
43
+ exit 1
44
+ fi
45
+ echo -e "${GREEN}✓${NC} Docker is installed"
46
+
47
+ if ! command -v curl &> /dev/null; then
48
+ echo -e "${RED}✗${NC} curl is not installed"
49
+ exit 1
50
+ fi
51
+ echo -e "${GREEN}✓${NC} curl is installed"
52
+
53
+ # Check if we're in the right directory
54
+ if [ ! -f "envs/atari_env/server/Dockerfile" ]; then
55
+ echo -e "${RED}✗${NC} Must run from OpenEnv root directory"
56
+ exit 1
57
+ fi
58
+ echo -e "${GREEN}✓${NC} In correct directory"
59
+
60
+ # Step 1: Build Docker image
61
+ echo ""
62
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
63
+ echo -e "${BLUE}STEP 1: Building Docker Image${NC}"
64
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
65
+
66
+ echo "Building ${IMAGE_NAME}:${IMAGE_TAG}..."
67
+ if docker build -f envs/atari_env/server/Dockerfile -t ${IMAGE_NAME}:${IMAGE_TAG} . 2>&1 | tee /tmp/atari_build.log | tail -n 20; then
68
+ echo -e "${GREEN}✓${NC} Docker image built successfully"
69
+ else
70
+ echo -e "${RED}✗${NC} Docker build failed"
71
+ echo "See /tmp/atari_build.log for full output"
72
+ exit 1
73
+ fi
74
+
75
+ # Check image exists
76
+ if docker image inspect ${IMAGE_NAME}:${IMAGE_TAG} &> /dev/null; then
77
+ IMAGE_SIZE=$(docker image inspect ${IMAGE_NAME}:${IMAGE_TAG} --format='{{.Size}}' | awk '{print $1/1024/1024}')
78
+ echo -e "${GREEN}✓${NC} Image size: ${IMAGE_SIZE} MB"
79
+ else
80
+ echo -e "${RED}✗${NC} Image not found after build"
81
+ exit 1
82
+ fi
83
+
84
+ # Step 2: Start container
85
+ echo ""
86
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
87
+ echo -e "${BLUE}STEP 2: Starting Container${NC}"
88
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
89
+
90
+ # Clean up any existing container
91
+ docker rm -f ${CONTAINER_NAME} 2>/dev/null || true
92
+
93
+ echo "Starting container on port ${PORT}..."
94
+ docker run -d \
95
+ --name ${CONTAINER_NAME} \
96
+ -p ${PORT}:8000 \
97
+ -e ATARI_GAME=pong \
98
+ -e ATARI_OBS_TYPE=ram \
99
+ -e ATARI_FRAMESKIP=4 \
100
+ ${IMAGE_NAME}:${IMAGE_TAG}
101
+
102
+ if [ $? -eq 0 ]; then
103
+ echo -e "${GREEN}✓${NC} Container started: ${CONTAINER_NAME}"
104
+ else
105
+ echo -e "${RED}✗${NC} Failed to start container"
106
+ exit 1
107
+ fi
108
+
109
+ # Wait for container to be running
110
+ sleep 2
111
+ if docker ps | grep -q ${CONTAINER_NAME}; then
112
+ echo -e "${GREEN}✓${NC} Container is running"
113
+ else
114
+ echo -e "${RED}✗${NC} Container is not running"
115
+ docker logs ${CONTAINER_NAME}
116
+ exit 1
117
+ fi
118
+
119
+ # Step 3: Wait for health check
120
+ echo ""
121
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
122
+ echo -e "${BLUE}STEP 3: Waiting for Server${NC}"
123
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
124
+
125
+ echo "Waiting for server to be ready (timeout: ${HEALTH_RETRIES}s)..."
126
+ for i in $(seq 1 ${HEALTH_RETRIES}); do
127
+ if curl -s http://localhost:${PORT}/health > /dev/null 2>&1; then
128
+ echo -e "${GREEN}✓${NC} Server is ready (${i}s)"
129
+ break
130
+ fi
131
+
132
+ if [ $i -eq ${HEALTH_RETRIES} ]; then
133
+ echo -e "${RED}✗${NC} Server did not become ready in time"
134
+ echo "Container logs:"
135
+ docker logs ${CONTAINER_NAME}
136
+ exit 1
137
+ fi
138
+
139
+ echo -n "."
140
+ sleep ${HEALTH_DELAY}
141
+ done
142
+
143
+ # Step 4: Test health endpoint
144
+ echo ""
145
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
146
+ echo -e "${BLUE}STEP 4: Testing Health Endpoint${NC}"
147
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
148
+
149
+ HEALTH_RESPONSE=$(curl -s http://localhost:${PORT}/health)
150
+ echo "Response: ${HEALTH_RESPONSE}"
151
+
152
+ if echo "${HEALTH_RESPONSE}" | grep -q "healthy"; then
153
+ echo -e "${GREEN}✓${NC} Health endpoint working"
154
+ else
155
+ echo -e "${RED}✗${NC} Health endpoint failed"
156
+ exit 1
157
+ fi
158
+
159
+ # Step 5: Test reset endpoint
160
+ echo ""
161
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
162
+ echo -e "${BLUE}STEP 5: Testing Reset Endpoint${NC}"
163
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
164
+
165
+ RESET_RESPONSE=$(curl -s -X POST http://localhost:${PORT}/reset -H "Content-Type: application/json" -d '{}')
166
+
167
+ if [ -z "${RESET_RESPONSE}" ]; then
168
+ echo -e "${RED}✗${NC} Reset endpoint returned empty response"
169
+ docker logs ${CONTAINER_NAME} | tail -20
170
+ exit 1
171
+ fi
172
+
173
+ echo "Response (first 200 chars): ${RESET_RESPONSE:0:200}..."
174
+
175
+ # Check if response contains expected fields
176
+ if echo "${RESET_RESPONSE}" | grep -q "observation" && \
177
+ echo "${RESET_RESPONSE}" | grep -q "screen" && \
178
+ echo "${RESET_RESPONSE}" | grep -q "legal_actions"; then
179
+ echo -e "${GREEN}✓${NC} Reset endpoint working"
180
+
181
+ # Extract some info
182
+ SCREEN_LEN=$(echo "${RESET_RESPONSE}" | grep -o '"screen":\[[^]]*\]' | wc -c)
183
+ echo " Screen data length: ${SCREEN_LEN} chars"
184
+ else
185
+ echo -e "${RED}✗${NC} Reset response missing required fields"
186
+ echo "Full response: ${RESET_RESPONSE}"
187
+ exit 1
188
+ fi
189
+
190
+ # Step 6: Test step endpoint
191
+ echo ""
192
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
193
+ echo -e "${BLUE}STEP 6: Testing Step Endpoint${NC}"
194
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
195
+
196
+ STEP_PAYLOAD='{"action": {"action_id": 0, "game_name": "pong"}}'
197
+ STEP_RESPONSE=$(curl -s -X POST http://localhost:${PORT}/step -H "Content-Type: application/json" -d "${STEP_PAYLOAD}")
198
+
199
+ if [ -z "${STEP_RESPONSE}" ]; then
200
+ echo -e "${RED}✗${NC} Step endpoint returned empty response"
201
+ docker logs ${CONTAINER_NAME} | tail -20
202
+ exit 1
203
+ fi
204
+
205
+ echo "Response (first 200 chars): ${STEP_RESPONSE:0:200}..."
206
+
207
+ # Check if response contains expected fields
208
+ if echo "${STEP_RESPONSE}" | grep -q "observation" && \
209
+ echo "${STEP_RESPONSE}" | grep -q "reward" && \
210
+ echo "${STEP_RESPONSE}" | grep -q "done"; then
211
+ echo -e "${GREEN}✓${NC} Step endpoint working"
212
+
213
+ # Extract reward and done
214
+ REWARD=$(echo "${STEP_RESPONSE}" | grep -o '"reward":[^,}]*' | cut -d: -f2)
215
+ DONE=$(echo "${STEP_RESPONSE}" | grep -o '"done":[^,}]*' | cut -d: -f2)
216
+ echo " Reward: ${REWARD}"
217
+ echo " Done: ${DONE}"
218
+ else
219
+ echo -e "${RED}✗${NC} Step response missing required fields"
220
+ echo "Full response: ${STEP_RESPONSE}"
221
+ exit 1
222
+ fi
223
+
224
+ # Step 7: Test state endpoint
225
+ echo ""
226
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
227
+ echo -e "${BLUE}STEP 7: Testing State Endpoint${NC}"
228
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
229
+
230
+ STATE_RESPONSE=$(curl -s http://localhost:${PORT}/state)
231
+
232
+ if [ -z "${STATE_RESPONSE}" ]; then
233
+ echo -e "${RED}✗${NC} State endpoint returned empty response"
234
+ docker logs ${CONTAINER_NAME} | tail -20
235
+ exit 1
236
+ fi
237
+
238
+ echo "Response: ${STATE_RESPONSE}"
239
+
240
+ # Check if response contains expected fields
241
+ if echo "${STATE_RESPONSE}" | grep -q "episode_id" && \
242
+ echo "${STATE_RESPONSE}" | grep -q "step_count" && \
243
+ echo "${STATE_RESPONSE}" | grep -q "game_name"; then
244
+ echo -e "${GREEN}✓${NC} State endpoint working"
245
+
246
+ # Extract info
247
+ GAME_NAME=$(echo "${STATE_RESPONSE}" | grep -o '"game_name":"[^"]*"' | cut -d'"' -f4)
248
+ STEP_COUNT=$(echo "${STATE_RESPONSE}" | grep -o '"step_count":[^,}]*' | cut -d: -f2)
249
+ echo " Game: ${GAME_NAME}"
250
+ echo " Steps: ${STEP_COUNT}"
251
+ else
252
+ echo -e "${RED}✗${NC} State response missing required fields"
253
+ echo "Full response: ${STATE_RESPONSE}"
254
+ exit 1
255
+ fi
256
+
257
+ # Step 8: Test multiple steps
258
+ echo ""
259
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
260
+ echo -e "${BLUE}STEP 8: Testing Multiple Steps${NC}"
261
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
262
+
263
+ echo "Taking 10 steps..."
264
+ TOTAL_REWARD=0
265
+ for i in {1..10}; do
266
+ ACTION_ID=$((RANDOM % 3)) # Random action 0-2
267
+ STEP_PAYLOAD="{\"action\": {\"action_id\": ${ACTION_ID}, \"game_name\": \"pong\"}}"
268
+ STEP_RESPONSE=$(curl -s -X POST http://localhost:${PORT}/step -H "Content-Type: application/json" -d "${STEP_PAYLOAD}")
269
+
270
+ if ! echo "${STEP_RESPONSE}" | grep -q "observation"; then
271
+ echo -e "${RED}✗${NC} Step ${i} failed"
272
+ exit 1
273
+ fi
274
+
275
+ REWARD=$(echo "${STEP_RESPONSE}" | grep -o '"reward":[^,}]*' | cut -d: -f2 | sed 's/null/0/')
276
+ DONE=$(echo "${STEP_RESPONSE}" | grep -o '"done":[^,}]*' | cut -d: -f2)
277
+
278
+ echo " Step ${i}: action=${ACTION_ID}, reward=${REWARD}, done=${DONE}"
279
+
280
+ if [ "${DONE}" = "true" ]; then
281
+ echo " Episode completed early at step ${i}"
282
+ break
283
+ fi
284
+ done
285
+
286
+ echo -e "${GREEN}✓${NC} Multiple steps completed successfully"
287
+
288
+ # Step 9: Check container logs for errors
289
+ echo ""
290
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
291
+ echo -e "${BLUE}STEP 9: Checking Container Logs${NC}"
292
+ echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
293
+
294
+ LOGS=$(docker logs ${CONTAINER_NAME} 2>&1)
295
+
296
+ if echo "${LOGS}" | grep -i "error" | grep -v "LoggerMode.Error"; then
297
+ echo -e "${YELLOW}⚠${NC} Found errors in logs:"
298
+ echo "${LOGS}" | grep -i "error" | head -5
299
+ else
300
+ echo -e "${GREEN}✓${NC} No errors in container logs"
301
+ fi
302
+
303
+ if echo "${LOGS}" | grep -i "exception"; then
304
+ echo -e "${RED}✗${NC} Found exceptions in logs:"
305
+ echo "${LOGS}" | grep -i "exception" | head -5
306
+ exit 1
307
+ else
308
+ echo -e "${GREEN}✓${NC} No exceptions in container logs"
309
+ fi
310
+
311
+ # Final Summary
312
+ echo ""
313
+ echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
314
+ echo -e "${GREEN}✅ ALL DOCKER TESTS PASSED${NC}"
315
+ echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
316
+ echo ""
317
+ echo "Summary:"
318
+ echo " ✓ Docker image built successfully"
319
+ echo " ✓ Container started and ran"
320
+ echo " ✓ Health endpoint working"
321
+ echo " ✓ Reset endpoint working"
322
+ echo " ✓ Step endpoint working"
323
+ echo " ✓ State endpoint working"
324
+ echo " ✓ Multiple steps working"
325
+ echo " ✓ No errors or exceptions"
326
+ echo ""
327
+ echo "Image: ${IMAGE_NAME}:${IMAGE_TAG}"
328
+ echo "Container: ${CONTAINER_NAME}"
329
+ echo "Port: ${PORT}"
330
+ echo ""
331
+ echo "To keep container running: docker start ${CONTAINER_NAME}"
332
+ echo "To view logs: docker logs ${CONTAINER_NAME}"
333
+ echo ""
models.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Data models for Atari Environment.
9
+
10
+ This module defines the Action, Observation, and State types for Atari games
11
+ via the Arcade Learning Environment (ALE).
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from typing import Any, Dict, List, Literal, Optional
17
+
18
+ from openenv.core.env_server import Action, Observation, State
19
+
20
+
21
+ class AtariAction(Action):
22
+ """
23
+ Action for Atari environments.
24
+
25
+ Attributes:
26
+ action_id: The integer action ID to take (from legal_actions).
27
+ game_name: Name of the Atari game (e.g., "pong", "breakout", "space_invaders").
28
+ obs_type: Observation type ("rgb", "grayscale", or "ram").
29
+ full_action_space: Whether to use full (18 actions) or minimal action space.
30
+ """
31
+
32
+ action_id: int
33
+ game_name: str = "pong"
34
+ obs_type: Literal["rgb", "grayscale", "ram"] = "rgb"
35
+ full_action_space: bool = False
36
+
37
+
38
+ class AtariObservation(Observation):
39
+ """
40
+ Observation from Atari environment.
41
+
42
+ This represents what the agent sees after taking an action.
43
+
44
+ Attributes:
45
+ screen: Screen observation as a flattened list of pixels.
46
+ Shape depends on obs_type:
47
+ - rgb: [210, 160, 3] flattened
48
+ - grayscale: [210, 160] flattened
49
+ - ram: [128] (RAM contents)
50
+ screen_shape: Original shape of the screen before flattening.
51
+ legal_actions: List of legal action IDs the agent can take.
52
+ lives: Number of lives remaining.
53
+ episode_frame_number: Frame number within current episode.
54
+ frame_number: Total frame number since environment creation.
55
+ """
56
+
57
+ screen: List[int]
58
+ screen_shape: List[int]
59
+ legal_actions: List[int]
60
+ lives: int = 0
61
+ episode_frame_number: int = 0
62
+ frame_number: int = 0
63
+
64
+
65
+ class AtariState(State):
66
+ """
67
+ State for Atari environment.
68
+
69
+ Attributes:
70
+ game_name: Name of the Atari game.
71
+ obs_type: Observation type ("rgb", "grayscale", or "ram").
72
+ full_action_space: Whether using full or minimal action space.
73
+ mode: Game mode (if applicable).
74
+ difficulty: Game difficulty (if applicable).
75
+ repeat_action_probability: Probability of repeating previous action (sticky actions).
76
+ frameskip: Number of frames to skip per action.
77
+ """
78
+
79
+ game_name: str = "pong"
80
+ obs_type: Literal["rgb", "grayscale", "ram"] = "rgb"
81
+ full_action_space: bool = False
82
+ mode: Optional[int] = None
83
+ difficulty: Optional[int] = None
84
+ repeat_action_probability: float = 0.0
85
+ frameskip: int = 4
pyproject.toml ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=45", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "openenv-core"
7
+ version = "0.2.2.dev0"
8
+ description = "A unified framework for reinforcement learning environments"
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ dependencies = [
12
+ # Core shared dependencies - minimal set required for all environments
13
+ # Heavy dependencies (torch, numpy, smolagents, etc.) should be in
14
+ # individual environment pyproject.toml files
15
+ "fastapi>=0.104.0",
16
+ "pydantic>=2.0.0",
17
+ "uvicorn>=0.24.0",
18
+ "requests>=2.25.0",
19
+ # CLI dependencies
20
+ "typer>=0.9.0",
21
+ "rich>=13.0.0",
22
+ "pyyaml>=6.0",
23
+ "huggingface_hub>=0.20.0",
24
+ "openai>=2.7.2",
25
+ "tomli>=2.3.0",
26
+ "tomli-w>=1.2.0",
27
+ "websockets>=15.0.1",
28
+ # MCP support
29
+ "fastmcp>=3.0.0",
30
+ # Web UI dependencies
31
+ "gradio>=4.0.0",
32
+ ]
33
+
34
+ [project.optional-dependencies]
35
+ core = [
36
+ "fastapi>=0.104.0",
37
+ "pydantic>=2.0.0",
38
+ "uvicorn>=0.24.0",
39
+ "requests>=2.25.0",
40
+ "websockets>=15.0.1",
41
+ ]
42
+ cli = [
43
+ "typer>=0.9.0",
44
+ "rich>=13.0.0",
45
+ "pyyaml>=6.0",
46
+ "huggingface_hub>=0.20.0",
47
+ "openai>=2.7.2",
48
+ "tomli>=2.3.0",
49
+ "tomli-w>=1.2.0",
50
+ ]
51
+ docs = [
52
+ "sphinx==7.2.6",
53
+ "pytorch-sphinx-theme2",
54
+ "sphinxcontrib.katex==0.9.10",
55
+ "docutils>=0.18.1,<0.21",
56
+ "sphinx-design==0.6.1",
57
+ "sphinxcontrib-mermaid==1.0.0",
58
+ "myst-parser",
59
+ "sphinxext-opengraph",
60
+ "sphinx-sitemap==2.7.1",
61
+ "sphinx-gallery>=0.14.0",
62
+ "matplotlib",
63
+ "nest-asyncio",
64
+ "smolagents",
65
+ ]
66
+ all = [
67
+ "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@main",
68
+ "openenv-core[cli]",
69
+ ]
70
+ daytona = [
71
+ "daytona>=0.136.0",
72
+ "pyyaml>=6.0",
73
+ ]
74
+ inspect = [
75
+ "inspect-ai>=0.3.0",
76
+ ]
77
+
78
+ [project.scripts]
79
+ openenv = "openenv.cli.__main__:main"
80
+
81
+ [tool.setuptools]
82
+ package-dir = {"" = "src"}
83
+ include-package-data = true
84
+
85
+ [tool.setuptools.package-data]
86
+ "openenv.cli" = ["templates/**/*"]
87
+
88
+ [tool.setuptools.packages.find]
89
+ where = ["src"]
90
+
91
+ [tool.coverage.run]
92
+ omit = [
93
+ "openenv/cli/templates/**",
94
+ "**/templates/**",
95
+ "openenv/cli/__main__.py",
96
+ ]
97
+
98
+ [tool.coverage.report]
99
+ exclude_lines = [
100
+ "pragma: no cover",
101
+ "def __repr__",
102
+ "raise AssertionError",
103
+ "raise NotImplementedError",
104
+ "if __name__ == .__main__.:",
105
+ "if TYPE_CHECKING:",
106
+ ]
107
+
108
+ [tool.pytest.ini_options]
109
+ asyncio_mode = "auto"
110
+ asyncio_default_fixture_loop_scope = "function"
111
+ markers = [
112
+ "docker: Tests that require Docker to be running",
113
+ "network: Tests that require network access (HuggingFace, etc.)",
114
+ "integration: Integration tests with external resources",
115
+ ]
116
+
117
+ [dependency-groups]
118
+ dev = [
119
+ "ruff>=0.14.0",
120
+ "usort>=1.1.0",
121
+ "pytest>=7.0",
122
+ "pytest-asyncio>=0.21",
123
+ ]
124
+
125
+ [tool.usort]
126
+ # Disable first_party auto-detection so all non-stdlib imports land in
127
+ # the same "third_party" bucket (the default_category). This matches
128
+ # pyfmt's usort behavior inside arc f, which groups openenv.* and env
129
+ # package imports together without blank-line separators.
130
+ first_party_detection = false
131
+
132
+ [tool.ruff]
133
+ line-length = 88
134
+
135
+ [tool.ruff.lint]
136
+ select = ["E", "F", "W"]
137
+ ignore = [
138
+ "E402", # Module level import not at top of file (needed for pytest.importorskip patterns)
139
+ "E501", # Line too long (not enforced previously, would require large refactor)
140
+ ]
141
+
142
+ [tool.ruff.lint.per-file-ignores]
143
+ # Context manager variables that are intentionally unused
144
+ "tests/envs/test_websockets.py" = ["F841"]
145
+ "tests/test_cli/test_push.py" = ["F841"]
146
+ # Compatibility shim module
147
+ "src/openenv_core/__init__.py" = ["F401"]
server/Dockerfile ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dockerfile for Atari Environment
2
+ # This image provides Atari 2600 games via the Arcade Learning Environment (ALE)
3
+
4
+ # Configurable base image - defaults to local build, can be overridden for CI/CD
5
+ # Base image provides: fastapi, uvicorn, requests, curl, PYTHONPATH=/app/src
6
+ #
7
+ # Local build: docker build -t envtorch-base:latest -f src/core/containers/images/Dockerfile .
8
+ # docker build -f envs/atari_env/server/Dockerfile -t atari-env:latest .
9
+ #
10
+ # CI/CD build: docker build --build-arg BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest \
11
+ # -f envs/atari_env/server/Dockerfile -t atari-env:latest .
12
+ ARG BASE_IMAGE=openenv-base:latest
13
+ FROM ${BASE_IMAGE}
14
+
15
+ # Install dependencies
16
+ COPY envs/atari_env/server/requirements.txt /tmp/requirements.txt
17
+ RUN pip install --no-cache-dir -r /tmp/requirements.txt && rm /tmp/requirements.txt
18
+
19
+ # Copy OpenEnv core (base image already set WORKDIR=/app)
20
+ COPY src/core/ /app/src/core/
21
+
22
+ # Copy Atari environment code
23
+ COPY envs/atari_env/ /app/envs/atari_env/
24
+
25
+ # Copy README for web interface documentation
26
+ COPY envs/atari_env/README.md /app/README.md
27
+
28
+ # Atari-specific environment variables (can be overridden at runtime)
29
+ ENV ATARI_GAME=pong
30
+ ENV ATARI_OBS_TYPE=rgb
31
+ ENV ATARI_FULL_ACTION_SPACE=false
32
+ ENV ATARI_REPEAT_ACTION_PROB=0.0
33
+ ENV ATARI_FRAMESKIP=4
34
+
35
+ # Expose port
36
+ EXPOSE 8000
37
+
38
+ # Health check
39
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
40
+ CMD curl -f http://localhost:8000/health || exit 1
41
+
42
+ # Run the FastAPI server
43
+ CMD ["uvicorn", "envs.atari_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
server/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Atari Environment Server.
9
+
10
+ Server-side implementation of Atari environment for OpenEnv.
11
+ """
12
+
13
+ from .atari_environment import AtariEnvironment
14
+
15
+ __all__ = ["AtariEnvironment"]
server/app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ FastAPI application for the Atari Environment.
9
+
10
+ This module creates an HTTP server that exposes Atari games
11
+ over HTTP and WebSocket endpoints, compatible with EnvClient.
12
+
13
+ Usage:
14
+ # Development (with auto-reload):
15
+ uvicorn envs.atari_env.server.app:app --reload --host 0.0.0.0 --port 8000
16
+
17
+ # Production:
18
+ uvicorn envs.atari_env.server.app:app --host 0.0.0.0 --port 8000 --workers 4
19
+
20
+ # Or run directly:
21
+ python -m envs.atari_env.server.app
22
+
23
+ Environment variables:
24
+ ATARI_GAME: Game name to serve (default: "pong")
25
+ ATARI_OBS_TYPE: Observation type (default: "rgb")
26
+ ATARI_FULL_ACTION_SPACE: Use full action space (default: "false")
27
+ ATARI_MODE: Game mode (optional)
28
+ ATARI_DIFFICULTY: Game difficulty (optional)
29
+ ATARI_REPEAT_ACTION_PROB: Sticky action probability (default: "0.0")
30
+ ATARI_FRAMESKIP: Frameskip (default: "4")
31
+ """
32
+
33
+ import os
34
+
35
+ from openenv.core.env_server import create_app
36
+
37
+ from ..models import AtariAction, AtariObservation
38
+ from .atari_environment import AtariEnvironment
39
+
40
+ # Get configuration from environment variables
41
+ game_name = os.getenv("ATARI_GAME", "pong")
42
+ obs_type = os.getenv("ATARI_OBS_TYPE", "rgb")
43
+ full_action_space = os.getenv("ATARI_FULL_ACTION_SPACE", "false").lower() == "true"
44
+ repeat_action_prob = float(os.getenv("ATARI_REPEAT_ACTION_PROB", "0.0"))
45
+ frameskip = int(os.getenv("ATARI_FRAMESKIP", "4"))
46
+
47
+ # Optional parameters
48
+ mode = os.getenv("ATARI_MODE")
49
+ difficulty = os.getenv("ATARI_DIFFICULTY")
50
+
51
+ # Convert to int if specified
52
+ mode = int(mode) if mode is not None else None
53
+ difficulty = int(difficulty) if difficulty is not None else None
54
+
55
+
56
+ # Factory function to create AtariEnvironment instances
57
+ def create_atari_environment():
58
+ """Factory function that creates AtariEnvironment with config."""
59
+ return AtariEnvironment(
60
+ game_name=game_name,
61
+ obs_type=obs_type,
62
+ full_action_space=full_action_space,
63
+ mode=mode,
64
+ difficulty=difficulty,
65
+ repeat_action_probability=repeat_action_prob,
66
+ frameskip=frameskip,
67
+ )
68
+
69
+
70
+ # Create the FastAPI app with web interface and README integration
71
+ # Pass the factory function instead of an instance for WebSocket session support
72
+ app = create_app(
73
+ create_atari_environment, AtariAction, AtariObservation, env_name="atari_env"
74
+ )
75
+
76
+
77
+ if __name__ == "__main__":
78
+ import uvicorn
79
+
80
+ uvicorn.run(app, host="0.0.0.0", port=8000)
server/atari_environment.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Atari Environment Server Implementation.
9
+
10
+ This module wraps ALE's ALEInterface and exposes it
11
+ via the OpenEnv Environment interface.
12
+ """
13
+
14
+ import uuid
15
+ from typing import Any, Dict, Literal, Optional
16
+
17
+ from openenv.core.env_server import Action, Environment, Observation
18
+
19
+ from ..models import AtariAction, AtariObservation, AtariState
20
+
21
+ # Import ALE
22
+ try:
23
+ import numpy as np
24
+ from ale_py import ALEInterface, roms
25
+ except ImportError as e:
26
+ raise ImportError(
27
+ "ALE (Arcade Learning Environment) is not installed. "
28
+ "Please install it with: pip install ale-py"
29
+ ) from e
30
+
31
+
32
+ class AtariEnvironment(Environment):
33
+ """
34
+ Atari Environment wrapper for OpenEnv.
35
+
36
+ This environment wraps Atari 2600 games via the Arcade Learning Environment (ALE)
37
+ and provides a clean interface for RL training.
38
+
39
+ Supported games include: pong, breakout, space_invaders, and 100+ others.
40
+
41
+ Args:
42
+ game_name: Name of the Atari game (e.g., "pong", "breakout").
43
+ obs_type: Observation type - "rgb", "grayscale", or "ram".
44
+ full_action_space: Use full action space (18 actions) vs minimal.
45
+ mode: Game mode (if applicable).
46
+ difficulty: Game difficulty (if applicable).
47
+ repeat_action_probability: Sticky action probability (default 0.0).
48
+ frameskip: Number of frames to skip per action (default 4).
49
+
50
+ Example:
51
+ >>> env = AtariEnvironment("pong")
52
+ >>> obs = env.reset()
53
+ >>> print(obs.screen_shape) # [210, 160, 3]
54
+ >>> obs = env.step(AtariAction(action_id=2)) # UP
55
+ >>> print(obs.reward, obs.done)
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ game_name: str = "pong",
61
+ obs_type: Literal["rgb", "grayscale", "ram"] = "rgb",
62
+ full_action_space: bool = False,
63
+ mode: Optional[int] = None,
64
+ difficulty: Optional[int] = None,
65
+ repeat_action_probability: float = 0.0,
66
+ frameskip: int = 4,
67
+ ):
68
+ """Initialize Atari environment."""
69
+ super().__init__()
70
+
71
+ self.game_name = game_name
72
+ self.obs_type = obs_type
73
+ self.full_action_space = full_action_space
74
+ self.mode = mode
75
+ self.difficulty = difficulty
76
+ self.repeat_action_probability = repeat_action_probability
77
+ self.frameskip = frameskip
78
+
79
+ # Create ALE interface
80
+ self.ale = ALEInterface()
81
+
82
+ # Configure ALE
83
+ from ale_py import LoggerMode
84
+
85
+ self.ale.setLoggerMode(LoggerMode.Error) # Error mode only
86
+ self.ale.setFloat("repeat_action_probability", repeat_action_probability)
87
+
88
+ # Load ROM
89
+ try:
90
+ rom_path = roms.get_rom_path(game_name)
91
+ self.ale.loadROM(rom_path)
92
+ except Exception as e:
93
+ raise ValueError(
94
+ f"Failed to load Atari game '{game_name}': {e}\n"
95
+ f"Available games can be found via: ale_py.roms.list_roms()"
96
+ ) from e
97
+
98
+ # Set mode and difficulty if specified
99
+ if mode is not None:
100
+ self.ale.setMode(mode)
101
+ if difficulty is not None:
102
+ self.ale.setDifficulty(difficulty)
103
+
104
+ # Get action set
105
+ if full_action_space:
106
+ self._action_set = self.ale.getLegalActionSet()
107
+ else:
108
+ self._action_set = self.ale.getMinimalActionSet()
109
+
110
+ # Get screen dimensions for observation space
111
+ self.screen_height, self.screen_width = self.ale.getScreenDims()
112
+ if obs_type == "rgb":
113
+ self.screen_shape = [self.screen_height, self.screen_width, 3]
114
+ elif obs_type == "grayscale":
115
+ self.screen_shape = [self.screen_height, self.screen_width]
116
+ elif obs_type == "ram":
117
+ self.screen_shape = [self.ale.getRAMSize()]
118
+ else:
119
+ raise ValueError(f"Invalid obs_type: {obs_type}")
120
+
121
+ # Initialize state
122
+ self._state = AtariState(
123
+ game_name=game_name,
124
+ obs_type=obs_type,
125
+ full_action_space=full_action_space,
126
+ mode=mode,
127
+ difficulty=difficulty,
128
+ repeat_action_probability=repeat_action_probability,
129
+ frameskip=frameskip,
130
+ )
131
+
132
+ def reset(self) -> Observation:
133
+ """
134
+ Reset the environment and return initial observation.
135
+
136
+ Returns:
137
+ Initial observation for the agent.
138
+ """
139
+ # Reset ALE
140
+ self.ale.reset_game()
141
+
142
+ # Reset state tracking
143
+ self._state.episode_id = str(uuid.uuid4())
144
+ self._state.step_count = 0
145
+
146
+ # Get initial observation
147
+ return self._make_observation()
148
+
149
+ def step(self, action: Action) -> Observation:
150
+ """
151
+ Execute agent's action and return resulting observation.
152
+
153
+ Args:
154
+ action: AtariAction containing the action_id to execute.
155
+
156
+ Returns:
157
+ Observation after action execution.
158
+
159
+ Raises:
160
+ ValueError: If action is not an AtariAction.
161
+ """
162
+ if not isinstance(action, AtariAction):
163
+ raise ValueError(f"Expected AtariAction, got {type(action)}")
164
+
165
+ # Validate action_id
166
+ if action.action_id < 0 or action.action_id >= len(self._action_set):
167
+ raise ValueError(
168
+ f"Invalid action_id: {action.action_id}. "
169
+ f"Valid range: [0, {len(self._action_set) - 1}]"
170
+ )
171
+
172
+ # Get actual ALE action
173
+ ale_action = self._action_set[action.action_id]
174
+
175
+ # Execute action with frameskip
176
+ total_reward = 0.0
177
+ for _ in range(self.frameskip):
178
+ total_reward += self.ale.act(ale_action)
179
+ if self.ale.game_over():
180
+ break
181
+
182
+ self._state.step_count += 1
183
+
184
+ # Get observation
185
+ obs = self._make_observation()
186
+ obs.reward = total_reward
187
+
188
+ return obs
189
+
190
+ @property
191
+ def state(self) -> AtariState:
192
+ """Get current environment state."""
193
+ return self._state
194
+
195
+ def _make_observation(self) -> AtariObservation:
196
+ """
197
+ Create an AtariObservation from current ALE state.
198
+
199
+ Returns:
200
+ AtariObservation for the agent.
201
+ """
202
+ # Get screen observation
203
+ if self.obs_type == "rgb":
204
+ screen = self.ale.getScreenRGB()
205
+ elif self.obs_type == "grayscale":
206
+ screen = self.ale.getScreenGrayscale()
207
+ elif self.obs_type == "ram":
208
+ screen = self.ale.getRAM()
209
+ else:
210
+ raise ValueError(f"Invalid obs_type: {self.obs_type}")
211
+
212
+ # Flatten screen for JSON serialization
213
+ # Handle both numpy arrays and lists
214
+ if hasattr(screen, "flatten"):
215
+ screen_flat = screen.flatten().tolist()
216
+ elif hasattr(screen, "tolist"):
217
+ screen_flat = screen.tolist()
218
+ else:
219
+ screen_flat = list(screen)
220
+
221
+ # Get game info
222
+ lives = self.ale.lives()
223
+ episode_frame_number = self.ale.getEpisodeFrameNumber()
224
+ frame_number = self.ale.getFrameNumber()
225
+ done = self.ale.game_over()
226
+
227
+ # Create legal actions list (indices into action_set)
228
+ legal_actions = list(range(len(self._action_set)))
229
+
230
+ # Create observation
231
+ obs = AtariObservation(
232
+ screen=screen_flat,
233
+ screen_shape=self.screen_shape,
234
+ legal_actions=legal_actions,
235
+ lives=lives,
236
+ episode_frame_number=episode_frame_number,
237
+ frame_number=frame_number,
238
+ done=done,
239
+ reward=0.0, # Will be filled in by step()
240
+ metadata={
241
+ "game_name": self.game_name,
242
+ "action_meanings": [str(a) for a in self._action_set],
243
+ },
244
+ )
245
+
246
+ return obs
server/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gymnasium>=0.29.0
2
+ ale-py>=0.8.0
3
+ numpy>=1.24.0
src/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """EnvTorch: Standardized agentic execution environments."""
src/core/README.md ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # <img width="35" height="35" alt="image" src="https://github.com/user-attachments/assets/2700a971-e5d6-4036-b03f-2f89c9791609" /> OpenEnv: Agentic Execution Environments
2
+
3
+ An e2e framework for creating, deploying and using isolated execution environments for agentic RL training, built using Gymnasium style simple APIs. OpenEnv provides a standard for interacting with agentic execution environments via simple Gymnasium style APIs - step(), reset(), state(). Users of agentic execution environments can interact with the environment during RL training loops using these simple APIs.
4
+
5
+ In addition to making it easier for researchers and RL framework writers, we also provide tools for environment creators making it easier for them to create richer environments and make them available over familiar protocols like HTTP and packaged using canonical technologies like docker. Environment creators can use the OpenEnv framework to create environments that are isolated, secure, and easy to deploy and use.
6
+
7
+
8
+ ## Overview
9
+ `openenv.core` provides the foundational building blocks for creating and interacting with containerized environments over HTTP. It enables you to build agent environments that can be deployed as Docker containers and accessed via a simple HTTP API.
10
+
11
+ > ⚠️ **Early Development Warning** OpenEnv is currently in an experimental
12
+ > stage. You should expect bugs, incomplete features, and APIs that may change
13
+ > in future versions. The project welcomes bugfixes, but to make sure things are
14
+ > well coordinated you should discuss any significant change before starting the
15
+ > work. It's recommended that you signal your intention to contribute in the
16
+ > issue tracker, either by filing a new issue or by claiming an existing one.
17
+
18
+
19
+ # OpenEnv Core
20
+
21
+ Core components for OpenEnv - a framework for building HTTP-based agentic environments.
22
+
23
+ ## Features
24
+
25
+ - **EnvClient**: Async-first client for interacting with remote environments
26
+ - **SyncEnvClient**: Synchronous wrapper via `.sync()` for sync codebases
27
+ - **HTTPEnvServer**: FastAPI-based server wrapper for exposing environments over HTTP/WebSocket
28
+ - **Container Providers**: Pluggable architecture for running containers (Docker, Kubernetes, etc.)
29
+ - **Type System**: Strongly-typed Action/Observation/State interfaces
30
+ - **Web Interface**: Optional web UI for interacting with environments
31
+
32
+ ## Installation
33
+
34
+ ```bash
35
+ pip install "openenv[core]"
36
+ ```
37
+
38
+ For development:
39
+ ```bash
40
+ pip install "openenv[core]"
41
+ ```
42
+
43
+ ## Quick Start
44
+
45
+ ### Creating an Environment Client
46
+
47
+ EnvClient is **async by default**. Use `async with` and `await` for all operations:
48
+
49
+ ```python
50
+ import asyncio
51
+ from openenv.core import EnvClient, StepResult
52
+ from dataclasses import dataclass
53
+ from typing import Any
54
+
55
+ @dataclass
56
+ class MyAction:
57
+ text: str
58
+
59
+ @dataclass
60
+ class MyObservation:
61
+ response: str
62
+
63
+ class MyEnvClient(EnvClient[MyAction, MyObservation, Any]):
64
+ def _step_payload(self, action: MyAction) -> dict:
65
+ return {"text": action.text}
66
+
67
+ def _parse_result(self, payload: dict) -> StepResult[MyObservation]:
68
+ obs_data = payload["observation"]
69
+ return StepResult(
70
+ observation=MyObservation(**obs_data),
71
+ reward=payload.get("reward"),
72
+ done=payload.get("done", False)
73
+ )
74
+
75
+ def _parse_state(self, payload: dict) -> Any:
76
+ return payload
77
+
78
+ # Async usage (recommended)
79
+ async def main():
80
+ client = await MyEnvClient.from_docker_image("my-env:latest")
81
+ async with client:
82
+ result = await client.reset()
83
+ step_result = await client.step(MyAction(text="hello"))
84
+
85
+ asyncio.run(main())
86
+
87
+ # Sync usage (via .sync() wrapper)
88
+ with MyEnvClient(base_url="http://localhost:8000").sync() as client:
89
+ result = client.reset()
90
+ step_result = client.step(MyAction(text="hello"))
91
+ ```
92
+
93
+ ### Creating an Environment Server
94
+
95
+ ```python
96
+ from openenv.core.env_server import Environment, HTTPEnvServer, create_app
97
+ from dataclasses import dataclass
98
+
99
+ @dataclass
100
+ class MyAction:
101
+ text: str
102
+
103
+ @dataclass
104
+ class MyObservation:
105
+ response: str
106
+ reward: float = 0.0
107
+ done: bool = False
108
+
109
+ class MyEnvironment(Environment):
110
+ def reset(self) -> MyObservation:
111
+ return MyObservation(response="Ready")
112
+
113
+ def step(self, action: MyAction) -> MyObservation:
114
+ return MyObservation(
115
+ response=f"Echo: {action.text}",
116
+ reward=1.0,
117
+ done=False
118
+ )
119
+
120
+ # Create FastAPI app
121
+ env = MyEnvironment()
122
+ app = create_app(env, MyAction, MyObservation)
123
+
124
+ # Run with: uvicorn module:app --host 0.0.0.0 --port 8000
125
+ ```
126
+
127
+ ## Container Providers
128
+
129
+ OpenEnv Core supports multiple container providers:
130
+
131
+ ### Local Docker Provider
132
+
133
+ ```python
134
+ from openenv.core.containers.runtime import LocalDockerProvider
135
+
136
+ provider = LocalDockerProvider()
137
+ base_url = provider.start_container("my-env:latest")
138
+ provider.wait_for_ready(base_url)
139
+ # Use environment...
140
+ provider.stop_container()
141
+ ```
142
+
143
+ ### Kubernetes Provider (Coming Soon)
144
+
145
+ ```python
146
+ from openenv.core.containers.runtime import KubernetesProvider
147
+
148
+ provider = KubernetesProvider(namespace="envs")
149
+ base_url = provider.start_container("my-env:latest")
150
+ # Use environment...
151
+ provider.stop_container()
152
+ ```
153
+
154
+
155
+ ## API Reference
156
+
157
+ ### EnvClient
158
+
159
+ Async base class for environment clients. Key methods:
160
+
161
+ - `async connect()`: Establish WebSocket connection
162
+ - `async reset(**kwargs)`: Reset environment
163
+ - `async step(action)`: Execute action
164
+ - `async state()`: Get current state
165
+ - `async close()`: Close connection and cleanup
166
+ - `sync()`: Return a SyncEnvClient wrapper for synchronous usage
167
+
168
+ Abstract methods to implement:
169
+ - `_step_payload(action)`: Convert action to JSON
170
+ - `_parse_result(payload)`: Parse response to StepResult
171
+ - `_parse_state(payload)`: Parse state response
172
+
173
+ ### SyncEnvClient
174
+
175
+ Synchronous wrapper around EnvClient. Use `client.sync()` to get one:
176
+
177
+ ```python
178
+ sync_client = async_client.sync()
179
+ with sync_client:
180
+ result = sync_client.reset()
181
+ result = sync_client.step(action)
182
+ ```
183
+
184
+ ### HTTPEnvServer
185
+
186
+ Server wrapper with these methods:
187
+
188
+ - `register_routes(app)`: Register endpoints on FastAPI app
189
+ - `_deserialize_action(data)`: Convert JSON to Action
190
+ - `_serialize_observation(obs)`: Convert Observation to JSON
191
+
192
+ ### Environment Interface
193
+
194
+ Base interface for environment implementations:
195
+
196
+ - `reset()`: Reset environment and return initial observation
197
+ - `step(action)`: Execute action and return observation
198
+ - `state`: Property returning current environment state
199
+
200
+ ## License
201
+
202
+ This project is licensed under the BSD-3-Clause License - see the LICENSE file for details.
203
+
204
+ ## Contributing
205
+
206
+ Contributions are welcome! Please see the main OpenEnv repository for contribution guidelines.
207
+
208
+ ## Links
209
+
210
+ - **Homepage**: https://github.com/meta-pytorch/OpenEnv
211
+ - **Documentation**: https://github.com/meta-pytorch/OpenEnv/blob/main/README.md
212
+ - **Bug Tracker**: https://github.com/meta-pytorch/OpenEnv/issues
src/core/__init__.py CHANGED
@@ -6,14 +6,76 @@
6
 
7
  """Core components for agentic environments."""
8
 
9
- # Re-export main components from submodules for convenience
10
- from .env_server import *
11
- from .http_env_client import HTTPEnvClient
12
- from .types import StepResult
13
 
14
- # Note: MCP module doesn't export anything yet
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  __all__ = [
17
- "HTTPEnvClient",
18
- "StepResult",
19
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  """Core components for agentic environments."""
8
 
9
+ from __future__ import annotations
 
 
 
10
 
11
+ from importlib import import_module
12
+ from typing import TYPE_CHECKING
13
+
14
+ from . import env_server
15
+ from .env_server import * # noqa: F403
16
+
17
+ if TYPE_CHECKING:
18
+ from .env_client import EnvClient
19
+ from .generic_client import GenericAction, GenericEnvClient
20
+ from .llm_client import (
21
+ AnthropicClient,
22
+ create_llm_client,
23
+ LLMClient,
24
+ LLMResponse,
25
+ OpenAIClient,
26
+ ToolCall,
27
+ )
28
+ from .mcp_client import MCPClientBase, MCPToolClient
29
+ from .sync_client import SyncEnvClient
30
 
31
  __all__ = [
32
+ "EnvClient",
33
+ "SyncEnvClient",
34
+ "GenericEnvClient",
35
+ "GenericAction",
36
+ "MCPClientBase",
37
+ "MCPToolClient",
38
+ "AnthropicClient",
39
+ "LLMClient",
40
+ "LLMResponse",
41
+ "OpenAIClient",
42
+ "ToolCall",
43
+ "create_llm_client",
44
+ ] + env_server.__all__ # type: ignore
45
+
46
+
47
+ _LAZY_ATTRS = {
48
+ "EnvClient": (".env_client", "EnvClient"),
49
+ "SyncEnvClient": (".sync_client", "SyncEnvClient"),
50
+ "GenericEnvClient": (".generic_client", "GenericEnvClient"),
51
+ "GenericAction": (".generic_client", "GenericAction"),
52
+ "MCPClientBase": (".mcp_client", "MCPClientBase"),
53
+ "MCPToolClient": (".mcp_client", "MCPToolClient"),
54
+ "AnthropicClient": (".llm_client", "AnthropicClient"),
55
+ "LLMClient": (".llm_client", "LLMClient"),
56
+ "LLMResponse": (".llm_client", "LLMResponse"),
57
+ "OpenAIClient": (".llm_client", "OpenAIClient"),
58
+ "ToolCall": (".llm_client", "ToolCall"),
59
+ "create_llm_client": (".llm_client", "create_llm_client"),
60
+ }
61
+
62
+
63
+ def __getattr__(name: str):
64
+ if name in _LAZY_ATTRS:
65
+ module_path, attr_name = _LAZY_ATTRS[name]
66
+ module = import_module(module_path, __name__)
67
+ value = getattr(module, attr_name)
68
+ globals()[name] = value
69
+ return value
70
+
71
+ try:
72
+ value = getattr(env_server, name)
73
+ except AttributeError as exc:
74
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from exc
75
+
76
+ globals()[name] = value
77
+ return value
78
+
79
+
80
+ def __dir__() -> list[str]:
81
+ return sorted(set(globals().keys()) | set(__all__))
src/core/client_types.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Type definitions for EnvTorch
2
+ from dataclasses import dataclass
3
+ from typing import Generic, Optional, TypeVar
4
+
5
+ # Generic type for observations
6
+ ObsT = TypeVar("ObsT")
7
+ StateT = TypeVar("StateT")
8
+
9
+
10
+ @dataclass
11
+ class StepResult(Generic[ObsT]):
12
+ """
13
+ Represents the result of one environment step.
14
+
15
+ Attributes:
16
+ observation: The environment's observation after the action.
17
+ reward: Scalar reward for this step (optional).
18
+ done: Whether the episode is finished.
19
+ """
20
+
21
+ observation: ObsT
22
+ reward: Optional[float] = None
23
+ done: bool = False
src/core/containers/__init__.py CHANGED
@@ -4,4 +4,4 @@
4
  # This source code is licensed under the BSD-style license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
- """Container management for environment servers."""
 
4
  # This source code is licensed under the BSD-style license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
+ """Container management for environment servers."""
src/core/containers/images/Dockerfile CHANGED
@@ -8,30 +8,47 @@
8
  # OpenEnv Base Image
9
  #
10
  # This is the standard base image for all OpenEnv environment servers.
11
- # It includes the minimal dependencies needed to run HTTP environment servers.
 
12
  #
13
- # Build: docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
14
- # Tag: docker tag openenv-base:latest openenv-base:0.1.0
15
  #
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  FROM python:3.11-slim
18
 
19
  # Set metadata
20
  LABEL maintainer="OpenEnv Team"
21
- LABEL description="Base image for OpenEnv based environment servers"
22
- LABEL version="0.1.0"
23
 
24
  # Install system dependencies
25
  RUN apt-get update && apt-get install -y --no-install-recommends \
26
  curl \
 
27
  && rm -rf /var/lib/apt/lists/*
28
 
29
- # Install Python dependencies that all environments need
30
- RUN pip install --no-cache-dir \
31
- fastapi>=0.104.0 \
32
- "uvicorn[standard]>=0.24.0" \
33
- requests>=2.25.0 \
34
- wsproto>=1.0.0
 
 
35
 
36
  # Set working directory
37
  WORKDIR /app
@@ -39,6 +56,7 @@ WORKDIR /app
39
  # Default environment variables
40
  ENV PYTHONPATH=/app/src
41
  ENV PYTHONUNBUFFERED=1
 
42
 
43
  # Default expose port (can be overridden)
44
  EXPOSE 8000
 
8
  # OpenEnv Base Image
9
  #
10
  # This is the standard base image for all OpenEnv environment servers.
11
+ # It includes the minimal dependencies needed to run HTTP environment servers
12
+ # and uv for fast dependency management.
13
  #
14
+ # Build from repo root: docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
15
+ # Tag: docker tag openenv-base:latest openenv-base:0.2.0
16
  #
17
 
18
+ FROM ghcr.io/astral-sh/uv:0.5.27-python3.11-bookworm-slim AS builder
19
+
20
+ # Set working directory
21
+ WORKDIR /app
22
+
23
+ # Copy core pyproject.toml and lockfile for dependency installation
24
+ COPY pyproject.toml uv.lock* ./
25
+
26
+ # Install core dependencies using uv with cache mount
27
+ RUN --mount=type=cache,target=/root/.cache/uv \
28
+ uv pip install --system -r pyproject.toml
29
+
30
+ # Final runtime stage
31
  FROM python:3.11-slim
32
 
33
  # Set metadata
34
  LABEL maintainer="OpenEnv Team"
35
+ LABEL description="Base image for OpenEnv based environment servers with uv"
36
+ LABEL version="0.2.0"
37
 
38
  # Install system dependencies
39
  RUN apt-get update && apt-get install -y --no-install-recommends \
40
  curl \
41
+ ca-certificates \
42
  && rm -rf /var/lib/apt/lists/*
43
 
44
+ # Copy uv from builder
45
+ COPY --from=builder /usr/local/bin/uv /usr/local/bin/uvx /usr/local/bin/
46
+
47
+ # Copy installed Python packages from builder
48
+ COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
49
+
50
+ # Copy console scripts installed by pip (uvicorn, fastapi, etc.)
51
+ COPY --from=builder /usr/local/bin/uvicorn /usr/local/bin/fastapi /usr/local/bin/
52
 
53
  # Set working directory
54
  WORKDIR /app
 
56
  # Default environment variables
57
  ENV PYTHONPATH=/app/src
58
  ENV PYTHONUNBUFFERED=1
59
+ ENV UV_SYSTEM_PYTHON=1
60
 
61
  # Default expose port (can be overridden)
62
  EXPOSE 8000
src/core/containers/images/README.md CHANGED
@@ -36,7 +36,7 @@ Total: 465 MB (base shared, minimal duplication)
36
 
37
  ```bash
38
  # From project root
39
- docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
40
  ```
41
 
42
  ## Usage in Environment Dockerfiles
@@ -47,8 +47,8 @@ Each environment Dockerfile should start with:
47
  FROM openenv-base:latest
48
 
49
  # Copy only environment-specific files
50
- COPY src/core/ /app/src/core/
51
- COPY src/envs/my_env/ /app/src/envs/my_env/
52
 
53
  # Run the server
54
  CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
@@ -66,10 +66,10 @@ CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "80
66
 
67
  ```bash
68
  # Step 1: Build base image (do this once)
69
- docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
70
 
71
  # Step 2: Build echo environment (uses base)
72
- docker build -t echo-env:latest -f src/envs/echo_env/server/Dockerfile .
73
 
74
  # Step 3: Run echo environment
75
  docker run -p 8000:8000 echo-env:latest
@@ -79,14 +79,14 @@ docker run -p 8000:8000 echo-env:latest
79
 
80
  When dependencies need updating:
81
 
82
- 1. Update `src/core/containers/images/Dockerfile`
83
  2. Rebuild base image
84
  3. Rebuild all environment images (they'll use new base)
85
 
86
  ```bash
87
  # Update base
88
- docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
89
 
90
  # Rebuild environments (they automatically use new base)
91
- docker build -t echo-env:latest -f src/envs/echo_env/server/Dockerfile .
92
  ```
 
36
 
37
  ```bash
38
  # From project root
39
+ docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
40
  ```
41
 
42
  ## Usage in Environment Dockerfiles
 
47
  FROM openenv-base:latest
48
 
49
  # Copy only environment-specific files
50
+ COPY src/openenv/core/ /app/src/openenv/core/
51
+ COPY envs/my_env/ /app/envs/my_env/
52
 
53
  # Run the server
54
  CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
 
66
 
67
  ```bash
68
  # Step 1: Build base image (do this once)
69
+ docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
70
 
71
  # Step 2: Build echo environment (uses base)
72
+ docker build -t echo-env:latest -f envs/echo_env/server/Dockerfile .
73
 
74
  # Step 3: Run echo environment
75
  docker run -p 8000:8000 echo-env:latest
 
79
 
80
  When dependencies need updating:
81
 
82
+ 1. Update `src/openenv/core/containers/images/Dockerfile`
83
  2. Rebuild base image
84
  3. Rebuild all environment images (they'll use new base)
85
 
86
  ```bash
87
  # Update base
88
+ docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
89
 
90
  # Rebuild environments (they automatically use new base)
91
+ docker build -t echo-env:latest -f envs/echo_env/server/Dockerfile .
92
  ```
src/core/containers/runtime/__init__.py CHANGED
@@ -6,10 +6,20 @@
6
 
7
  """Container runtime providers."""
8
 
9
- from .providers import ContainerProvider, KubernetesProvider, LocalDockerProvider
 
 
 
 
 
 
 
10
 
11
  __all__ = [
12
  "ContainerProvider",
 
13
  "LocalDockerProvider",
14
  "KubernetesProvider",
15
- ]
 
 
 
6
 
7
  """Container runtime providers."""
8
 
9
+ from .providers import (
10
+ ContainerProvider,
11
+ DockerSwarmProvider,
12
+ KubernetesProvider,
13
+ LocalDockerProvider,
14
+ RuntimeProvider,
15
+ )
16
+ from .uv_provider import UVProvider
17
 
18
  __all__ = [
19
  "ContainerProvider",
20
+ "DockerSwarmProvider",
21
  "LocalDockerProvider",
22
  "KubernetesProvider",
23
+ "RuntimeProvider",
24
+ "UVProvider",
25
+ ]
src/core/containers/runtime/daytona_provider.py ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Daytona container provider for running OpenEnv environments in Daytona cloud sandboxes.
9
+
10
+ Requires the ``daytona`` SDK: ``pip install daytona>=0.10``
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ import os
17
+ import shlex
18
+ import time
19
+ from typing import Any, Callable, Dict, Optional
20
+
21
+ import yaml
22
+
23
+ from .providers import ContainerProvider
24
+
25
+
26
+ class DaytonaProvider(ContainerProvider):
27
+ """
28
+ Container provider that runs environments in Daytona cloud sandboxes.
29
+
30
+ Example:
31
+ >>> provider = DaytonaProvider(api_key="your-key")
32
+ >>> image = DaytonaProvider.image_from_dockerfile("envs/echo_env/server/Dockerfile")
33
+ >>> base_url = provider.start_container(image)
34
+ >>> provider.wait_for_ready(base_url)
35
+ >>> provider.stop_container()
36
+ """
37
+
38
+ _dockerfile_registry: Dict[str, Dict[str, Any]] = {}
39
+
40
+ def __init__(
41
+ self,
42
+ *,
43
+ api_key: Optional[str] = None,
44
+ public: bool = False,
45
+ resources: Optional[Any] = None,
46
+ auto_stop_interval: int = 15,
47
+ target: Optional[str] = None,
48
+ on_snapshot_create_logs: Optional[Callable[[str], None]] = None,
49
+ cmd: Optional[str] = None,
50
+ create_timeout: float = 300,
51
+ ):
52
+ """
53
+ Args:
54
+ api_key: Daytona API key. Falls back to ``DAYTONA_API_KEY`` env var.
55
+ public: If True, the sandbox preview is publicly accessible.
56
+ resources: Optional ``daytona.Resources`` instance for CPU/memory.
57
+ auto_stop_interval: Minutes of inactivity before auto-stop (0 disables).
58
+ target: Daytona target region (e.g. "us").
59
+ on_snapshot_create_logs: Callback for snapshot build log lines.
60
+ cmd: Shell command to start the server inside the sandbox.
61
+ create_timeout: Seconds to wait for sandbox creation (default 300).
62
+ Heavy images (e.g. with Playwright/Chromium) may need more.
63
+ """
64
+ from daytona import Daytona, DaytonaConfig
65
+
66
+ config_kwargs: Dict[str, Any] = {}
67
+ resolved_key = api_key or os.environ.get("DAYTONA_API_KEY")
68
+ if resolved_key:
69
+ config_kwargs["api_key"] = resolved_key
70
+ if target:
71
+ config_kwargs["target"] = target
72
+
73
+ self._daytona = Daytona(DaytonaConfig(**config_kwargs))
74
+ self._public = public
75
+ self._resources = resources
76
+ self._auto_stop_interval = auto_stop_interval
77
+ self._on_snapshot_create_logs = on_snapshot_create_logs
78
+ self._cmd = cmd
79
+ self._create_timeout = create_timeout
80
+ self._sandbox: Any = None
81
+ self._preview_url: Optional[str] = None
82
+
83
+ def _discover_server_cmd(self, sandbox: Any, port: int = 8000) -> str:
84
+ """Discover the server command from ``openenv.yaml`` inside *sandbox*.
85
+
86
+ Finds the file, reads the ``app`` field, and constructs a command
87
+ of the form ``cd <env_root> && python -m uvicorn <app> --host 0.0.0.0 --port <port>``.
88
+
89
+ Raises:
90
+ ValueError: If ``openenv.yaml`` is not found or lacks an ``app`` field.
91
+ """
92
+ yaml_path = self._find_openenv_yaml(sandbox)
93
+ if yaml_path is None:
94
+ raise ValueError(
95
+ "Could not find openenv.yaml inside the sandbox. "
96
+ "Pass an explicit cmd= to DaytonaProvider or start_container()."
97
+ )
98
+
99
+ cat_resp = sandbox.process.exec(f"cat {shlex.quote(yaml_path)}", timeout=10)
100
+ content = cat_resp.result if hasattr(cat_resp, "result") else str(cat_resp)
101
+ app = self._parse_app_field(content)
102
+ if app is None:
103
+ raise ValueError(
104
+ f"openenv.yaml at {yaml_path} does not contain an 'app' field. "
105
+ "Pass an explicit cmd= to DaytonaProvider or start_container()."
106
+ )
107
+
108
+ # The directory containing openenv.yaml is the env root
109
+ env_root = yaml_path.rsplit("/", 1)[0]
110
+ return (
111
+ f"cd {shlex.quote(env_root)} && "
112
+ f"python -m uvicorn {shlex.quote(app)} --host 0.0.0.0 --port {port}"
113
+ )
114
+
115
+ def _find_openenv_yaml(self, sandbox: Any) -> Optional[str]:
116
+ """Locate ``openenv.yaml`` inside the sandbox.
117
+
118
+ Tries the modern layout path ``/app/env/openenv.yaml`` first,
119
+ then falls back to a ``find`` command for the old layout.
120
+ """
121
+ # Fast path: modern Dockerfile layout
122
+ resp = sandbox.process.exec(
123
+ "test -f /app/env/openenv.yaml && echo found", timeout=10
124
+ )
125
+ out = resp.result if hasattr(resp, "result") else str(resp)
126
+ if "found" in (out or ""):
127
+ return "/app/env/openenv.yaml"
128
+
129
+ # Fallback: search for it (redirect stderr so error messages
130
+ # like "No such file or directory" don't get mistaken for paths).
131
+ resp = sandbox.process.exec(
132
+ "find /app -maxdepth 4 -name openenv.yaml -print -quit 2>/dev/null",
133
+ timeout=10,
134
+ )
135
+ path = (resp.result if hasattr(resp, "result") else str(resp) or "").strip()
136
+ if path and path.startswith("/"):
137
+ return path
138
+
139
+ return None
140
+
141
+ @staticmethod
142
+ def _parse_app_field(yaml_content: str) -> Optional[str]:
143
+ """Extract the ``app`` value from raw openenv.yaml content.
144
+
145
+ Uses PyYAML to handle comments, quotes, and nested keys correctly.
146
+ """
147
+ try:
148
+ data = yaml.safe_load(yaml_content) or {}
149
+ except Exception:
150
+ return None
151
+
152
+ if not isinstance(data, dict):
153
+ return None
154
+
155
+ value = data.get("app")
156
+ if isinstance(value, str):
157
+ value = value.strip()
158
+ return value if value else None
159
+ return None
160
+
161
+ @staticmethod
162
+ def _parse_dockerfile_cmd(dockerfile_content: str) -> Optional[str]:
163
+ """Extract the server command from the last ``CMD`` in a Dockerfile.
164
+
165
+ Handles exec form (``CMD ["prog", "arg"]``) and shell form
166
+ (``CMD prog arg``). When a Dockerfile has multiple ``CMD``
167
+ instructions (e.g. multi-stage builds), the last one wins - same
168
+ semantics as Docker itself. Lines where ``CMD`` appears inside a
169
+ comment are ignored.
170
+
171
+ Returns:
172
+ The command as a single string, or ``None`` if no ``CMD`` found.
173
+ """
174
+ import re
175
+
176
+ last_cmd: Optional[str] = None
177
+ for line in dockerfile_content.splitlines():
178
+ stripped = line.strip()
179
+ # Skip comments
180
+ if stripped.startswith("#"):
181
+ continue
182
+ match = re.match(r"CMD\s+(.+)", stripped, flags=re.IGNORECASE)
183
+ if match:
184
+ last_cmd = match.group(1).strip()
185
+
186
+ if last_cmd is None:
187
+ return None
188
+
189
+ # Exec form: CMD ["executable", "param1", ...]
190
+ if last_cmd.startswith("["):
191
+ try:
192
+ parts = json.loads(last_cmd)
193
+ if isinstance(parts, list) and all(isinstance(p, str) for p in parts):
194
+ return " ".join(parts)
195
+ except (json.JSONDecodeError, TypeError):
196
+ pass
197
+
198
+ # Shell form: CMD executable param1 ...
199
+ return last_cmd if last_cmd else None
200
+
201
+ @staticmethod
202
+ def strip_buildkit_syntax(dockerfile_content: str) -> str:
203
+ """Remove BuildKit ``--mount=...`` flags from ``RUN`` instructions.
204
+
205
+ Handles single-line flags, multi-line continuations, and multiple
206
+ ``--mount`` flags spread across continuation lines. Only leading
207
+ ``--mount`` flags are removed (before the actual command starts).
208
+
209
+ Daytona's ``Image.from_dockerfile`` does not support BuildKit
210
+ ``--mount`` syntax. This helper strips the flags so that standard
211
+ Dockerfiles (like the ones generated by ``openenv build``) can
212
+ be used directly.
213
+ """
214
+ import re
215
+
216
+ def strip_leading_mounts(text: str) -> str:
217
+ remaining = text
218
+ while True:
219
+ match = re.match(r"\s*--mount=\S+\s*", remaining)
220
+ if not match:
221
+ return remaining
222
+ remaining = remaining[match.end() :]
223
+
224
+ lines = dockerfile_content.split("\n")
225
+ result: list[str] = []
226
+ in_run = False
227
+ in_mount_prefix = False
228
+
229
+ for line in lines:
230
+ line_out = line
231
+ run_start = False
232
+ if re.match(r"\s*RUN(\s+|$)", line, flags=re.IGNORECASE):
233
+ in_run = True
234
+ in_mount_prefix = True
235
+ run_start = True
236
+
237
+ if in_run and in_mount_prefix:
238
+ original_ends_with_slash = line_out.rstrip().endswith("\\")
239
+ if run_start:
240
+ match = re.match(r"(\s*RUN\s+)(.*)$", line_out, flags=re.IGNORECASE)
241
+ if match:
242
+ run_prefix, remainder = match.group(1), match.group(2)
243
+ else:
244
+ run_prefix, remainder = line_out, ""
245
+ new_remainder = strip_leading_mounts(remainder)
246
+ line_out = run_prefix + new_remainder
247
+ content_for_check = new_remainder
248
+ else:
249
+ new_remainder = strip_leading_mounts(line_out)
250
+ line_out = new_remainder
251
+ content_for_check = new_remainder
252
+
253
+ if original_ends_with_slash and not line_out.rstrip().endswith("\\"):
254
+ line_out = line_out.rstrip() + " \\"
255
+
256
+ if content_for_check.strip() not in ("", "\\"):
257
+ in_mount_prefix = False
258
+
259
+ if in_run and not line_out.rstrip().endswith("\\"):
260
+ in_run = False
261
+ in_mount_prefix = False
262
+
263
+ result.append(line_out)
264
+
265
+ return "\n".join(result)
266
+
267
+ @classmethod
268
+ def image_from_dockerfile(
269
+ cls,
270
+ dockerfile_path: str,
271
+ context_dir: str | None = None,
272
+ ) -> str:
273
+ """Validate a Dockerfile and return a ``dockerfile:`` URI for
274
+ :meth:`start_container`.
275
+
276
+ Eagerly validates the Dockerfile (existence, COPY sources,
277
+ BuildKit stripping) and stores the processed content in an
278
+ internal registry. The actual ``daytona.Image`` is created
279
+ later inside ``start_container``.
280
+
281
+ Args:
282
+ dockerfile_path: Path to the Dockerfile on disk.
283
+ context_dir: Build context directory. Defaults to the
284
+ Dockerfile's grandparent directory, matching the
285
+ ``openenv init`` convention where Dockerfiles live in
286
+ ``<env>/server/Dockerfile`` and the build context is
287
+ ``<env>/``. Pass explicitly for non-standard layouts
288
+ (e.g. ``context_dir="."`` for repo-root contexts).
289
+
290
+ Returns:
291
+ A ``"dockerfile:<abs_path>"`` string to pass to
292
+ ``start_container``.
293
+
294
+ Raises:
295
+ FileNotFoundError: If *dockerfile_path* does not exist.
296
+ ValueError: If *context_dir* is given but does not exist,
297
+ or if COPY sources in the Dockerfile cannot be found
298
+ under the resolved context directory.
299
+ """
300
+ import pathlib
301
+ import re
302
+
303
+ src = pathlib.Path(dockerfile_path).resolve()
304
+ if not src.is_file():
305
+ raise FileNotFoundError(f"Dockerfile not found: {dockerfile_path}")
306
+
307
+ if context_dir is not None:
308
+ ctx = pathlib.Path(context_dir)
309
+ if not ctx.is_dir():
310
+ raise ValueError(f"context_dir does not exist: {context_dir}")
311
+ else:
312
+ # Default: grandparent of the Dockerfile, matching the
313
+ # openenv init layout (<env>/server/Dockerfile -> <env>/).
314
+ ctx = src.parent.parent
315
+
316
+ content = src.read_text()
317
+ stripped = cls.strip_buildkit_syntax(content)
318
+
319
+ # Validate that COPY sources exist under the context directory.
320
+ # This catches mismatches early (e.g. a Dockerfile expecting repo
321
+ # root as context when we defaulted to the env directory).
322
+ for line in stripped.splitlines():
323
+ m = re.match(r"^\s*COPY\s+(?!--from=)(\S+)\s+", line, re.IGNORECASE)
324
+ if not m:
325
+ continue
326
+ copy_src = m.group(1)
327
+ if copy_src.startswith("/"):
328
+ continue
329
+ resolved = ctx / copy_src
330
+ if not resolved.exists() and not any(ctx.glob(copy_src)):
331
+ raise ValueError(
332
+ f"Dockerfile COPY source '{copy_src}' not found "
333
+ f"under context_dir '{ctx}'. This Dockerfile may "
334
+ f"expect a different build context (e.g. the repo "
335
+ f"root). Pass context_dir explicitly."
336
+ )
337
+
338
+ # Parse CMD from the original Dockerfile so start_container can
339
+ # use it as a fallback when openenv.yaml is unavailable.
340
+ parsed_cmd = cls._parse_dockerfile_cmd(content)
341
+
342
+ cls._dockerfile_registry[str(src)] = {
343
+ "stripped_content": stripped,
344
+ "context_dir": str(ctx),
345
+ "server_cmd": parsed_cmd,
346
+ }
347
+
348
+ return f"dockerfile:{src}"
349
+
350
+ def start_container(
351
+ self,
352
+ image: str,
353
+ port: Optional[int] = None,
354
+ env_vars: Optional[Dict[str, str]] = None,
355
+ **kwargs: Any,
356
+ ) -> str:
357
+ """
358
+ Create a Daytona sandbox from a Docker image or snapshot.
359
+
360
+ Daytona does not execute the image's CMD (known bug — ENTRYPOINT
361
+ runs, CMD does not). The server command is resolved in order:
362
+
363
+ 1. Explicit ``cmd`` passed to the constructor.
364
+ 2. ``cmd`` key in ``**kwargs`` (popped before forwarding).
365
+ 3. Auto-discovered from ``openenv.yaml`` inside the sandbox.
366
+ 4. ``CMD`` parsed from the Dockerfile (when *image* came from
367
+ ``image_from_dockerfile``).
368
+
369
+ Args:
370
+ image: Docker image name (e.g. ``"echo-env:latest"``),
371
+ ``"snapshot:<name>"`` to create from a pre-built snapshot,
372
+ or ``"dockerfile:<path>"`` returned by
373
+ :meth:`image_from_dockerfile`.
374
+ port: Must be ``None`` or ``8000``. Daytona exposes port 8000
375
+ via its preview proxy; other ports raise ``ValueError``.
376
+ env_vars: Environment variables forwarded to the sandbox.
377
+ **kwargs: ``cmd`` (str) to override the server command;
378
+ remaining kwargs passed through to ``Daytona.create()``.
379
+
380
+ Returns:
381
+ HTTPS preview URL for the sandbox (base_url).
382
+ """
383
+ if port is not None and port != 8000:
384
+ raise ValueError(
385
+ f"DaytonaProvider only supports port 8000 (got {port}). "
386
+ "The Daytona preview proxy routes to port 8000 inside the sandbox."
387
+ )
388
+
389
+ # Resolve the server command (may be None; discovery happens after
390
+ # sandbox creation when we can inspect the filesystem).
391
+ cmd = kwargs.pop("cmd", None) or self._cmd
392
+
393
+ # CMD parsed from Dockerfile (populated for "dockerfile:" images).
394
+ parsed_cmd: Optional[str] = None
395
+
396
+ # Build creation params
397
+ create_kwargs: Dict[str, Any] = {}
398
+ if env_vars:
399
+ create_kwargs["env_vars"] = env_vars
400
+ if self._public:
401
+ create_kwargs["public"] = True
402
+ if self._auto_stop_interval != 15:
403
+ create_kwargs["auto_stop_interval"] = self._auto_stop_interval
404
+
405
+ if image.startswith("snapshot:"):
406
+ from daytona import CreateSandboxFromSnapshotParams
407
+
408
+ snapshot_name = image[len("snapshot:") :]
409
+ params = CreateSandboxFromSnapshotParams(
410
+ snapshot=snapshot_name, **create_kwargs
411
+ )
412
+ elif image.startswith("dockerfile:"):
413
+ from daytona import CreateSandboxFromImageParams, Image
414
+
415
+ dockerfile_path = image[len("dockerfile:") :]
416
+ meta = self._dockerfile_registry.get(dockerfile_path)
417
+ if meta is None:
418
+ raise ValueError(
419
+ f"No registered Dockerfile metadata for {dockerfile_path}. "
420
+ "Call DaytonaProvider.image_from_dockerfile() first."
421
+ )
422
+
423
+ parsed_cmd = meta.get("server_cmd")
424
+
425
+ # Build the daytona Image from the pre-stripped content.
426
+ import pathlib
427
+ import uuid
428
+
429
+ ctx = pathlib.Path(meta["context_dir"])
430
+ tmp_name = f".daytona-{uuid.uuid4().hex[:8]}.dockerfile"
431
+ tmp_path = ctx / tmp_name
432
+ try:
433
+ tmp_path.write_text(meta["stripped_content"])
434
+ daytona_image = Image.from_dockerfile(str(tmp_path))
435
+ finally:
436
+ tmp_path.unlink(missing_ok=True)
437
+
438
+ img_kwargs: Dict[str, Any] = {
439
+ "image": daytona_image,
440
+ **create_kwargs,
441
+ }
442
+ if self._resources is not None:
443
+ img_kwargs["resources"] = self._resources
444
+ params = CreateSandboxFromImageParams(**img_kwargs)
445
+ else:
446
+ from daytona import CreateSandboxFromImageParams
447
+
448
+ img_kwargs = {"image": image, **create_kwargs}
449
+ if self._resources is not None:
450
+ img_kwargs["resources"] = self._resources
451
+ params = CreateSandboxFromImageParams(**img_kwargs)
452
+
453
+ # Create sandbox
454
+ extra: Dict[str, Any] = dict(kwargs)
455
+ if self._on_snapshot_create_logs is not None:
456
+ extra["on_snapshot_create_logs"] = self._on_snapshot_create_logs
457
+
458
+ self._sandbox = self._daytona.create(
459
+ params, timeout=self._create_timeout, **extra
460
+ )
461
+
462
+ try:
463
+ # Discover server command from openenv.yaml if not explicitly set.
464
+ if cmd is None:
465
+ try:
466
+ cmd = self._discover_server_cmd(self._sandbox)
467
+ except ValueError:
468
+ # Fall back to CMD parsed from Dockerfile (if available).
469
+ if parsed_cmd:
470
+ cmd = parsed_cmd
471
+ else:
472
+ raise
473
+
474
+ # Wrap in bash -c so compound commands (cd ... && uvicorn ...)
475
+ # are handled correctly by nohup. Write PID so we can check
476
+ # if the process crashed later in wait_for_ready().
477
+ escaped_cmd = shlex.quote(cmd)
478
+ self._sandbox.process.exec(
479
+ f"nohup bash -c {escaped_cmd} > /tmp/openenv-server.log 2>&1 &"
480
+ " echo $! > /tmp/openenv-server.pid",
481
+ timeout=10,
482
+ )
483
+
484
+ # Get a signed preview URL for port 8000. The token is
485
+ # embedded in the URL itself so no extra headers are needed.
486
+ signed = self._sandbox.create_signed_preview_url(
487
+ 8000, expires_in_seconds=86400
488
+ )
489
+ self._preview_url = signed.url
490
+ except Exception:
491
+ self.stop_container()
492
+ raise
493
+
494
+ return self._preview_url
495
+
496
+ def refresh_preview_url(self) -> str:
497
+ """Get a fresh signed preview URL (valid for 24h).
498
+
499
+ Daytona signed URLs expire after at most 24 hours. Call this to
500
+ get a new one for long-running sessions. The returned URL points
501
+ to the same sandbox — clients will need to reconnect using it.
502
+ """
503
+ if self._sandbox is None:
504
+ raise RuntimeError("No active sandbox to refresh URL for.")
505
+ signed = self._sandbox.create_signed_preview_url(8000, expires_in_seconds=86400)
506
+ self._preview_url = signed.url
507
+ return self._preview_url
508
+
509
+ def stop_container(self) -> None:
510
+ """Delete the Daytona sandbox."""
511
+ if self._sandbox is None:
512
+ return
513
+
514
+ try:
515
+ self._daytona.delete(self._sandbox)
516
+ finally:
517
+ self._sandbox = None
518
+ self._preview_url = None
519
+
520
+ def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None:
521
+ """
522
+ Poll the /health endpoint until the sandbox is ready.
523
+
524
+ Uses a longer default timeout (120s) than Docker providers because
525
+ Daytona sandboxes may have cold-start latency.
526
+
527
+ Args:
528
+ base_url: Preview URL returned by ``start_container()``.
529
+ timeout_s: Maximum seconds to wait.
530
+
531
+ Raises:
532
+ TimeoutError: If the sandbox doesn't become ready in time.
533
+ RuntimeError: If the server process died (detected via PID check).
534
+ """
535
+ import requests
536
+
537
+ health_url = f"{base_url}/health"
538
+
539
+ deadline = time.time() + timeout_s
540
+ while time.time() < deadline:
541
+ try:
542
+ response = requests.get(health_url, timeout=5.0)
543
+ if response.status_code == 200:
544
+ return
545
+ except requests.RequestException:
546
+ pass
547
+
548
+ # Early exit: if the server process died, raise immediately
549
+ # instead of waiting for the full health-check timeout.
550
+ if self._sandbox is not None:
551
+ resp = self._sandbox.process.exec(
552
+ "kill -0 $(cat /tmp/openenv-server.pid) 2>/dev/null"
553
+ " && echo RUNNING || echo DEAD",
554
+ timeout=10,
555
+ )
556
+ out = resp.result if hasattr(resp, "result") else str(resp)
557
+ if "DEAD" in (out or ""):
558
+ log_resp = self._sandbox.process.exec(
559
+ "cat /tmp/openenv-server.log 2>/dev/null", timeout=10
560
+ )
561
+ log = (
562
+ log_resp.result
563
+ if hasattr(log_resp, "result")
564
+ else str(log_resp)
565
+ )
566
+ raise RuntimeError(f"Server process died.\nLog:\n{log}")
567
+
568
+ time.sleep(1.0)
569
+
570
+ raise TimeoutError(
571
+ f"Daytona sandbox at {base_url} did not become ready within {timeout_s}s"
572
+ )
src/core/containers/runtime/providers.py CHANGED
@@ -8,13 +8,13 @@
8
  Container provider abstractions for running environment servers.
9
 
10
  This module provides a pluggable architecture for different container providers
11
- (local Docker, Kubernetes, cloud providers, etc.) to be used with HTTPEnvClient.
12
  """
13
 
14
  from __future__ import annotations
15
 
16
  from abc import ABC, abstractmethod
17
- from typing import Any, Dict, Optional
18
 
19
 
20
  class ContainerProvider(ABC):
@@ -118,7 +118,11 @@ class LocalDockerProvider(ContainerProvider):
118
  capture_output=True,
119
  timeout=5,
120
  )
121
- except (subprocess.CalledProcessError, FileNotFoundError, subprocess.TimeoutExpired):
 
 
 
 
122
  raise RuntimeError(
123
  "Docker is not available. Please install Docker Desktop or Docker Engine."
124
  )
@@ -154,10 +158,13 @@ class LocalDockerProvider(ContainerProvider):
154
 
155
  # Build docker run command
156
  cmd = [
157
- "docker", "run",
 
158
  "-d", # Detached
159
- "--name", self._container_name,
160
- "-p", f"{port}:8000", # Map port
 
 
161
  ]
162
 
163
  # Add environment variables
@@ -169,8 +176,12 @@ class LocalDockerProvider(ContainerProvider):
169
  cmd.append(image)
170
 
171
  # Run container
172
- result = subprocess.run(cmd, capture_output=True, text=True, check=True)
173
- self._container_id = result.stdout.strip()
 
 
 
 
174
 
175
  # Wait a moment for container to start
176
  time.sleep(1)
@@ -222,14 +233,18 @@ class LocalDockerProvider(ContainerProvider):
222
  TimeoutError: If container doesn't become ready
223
  """
224
  import time
 
225
  import requests
226
 
227
  start_time = time.time()
228
  health_url = f"{base_url}/health"
229
 
 
 
 
230
  while time.time() - start_time < timeout_s:
231
  try:
232
- response = requests.get(health_url, timeout=2.0)
233
  if response.status_code == 200:
234
  return
235
  except requests.RequestException:
@@ -273,6 +288,308 @@ class LocalDockerProvider(ContainerProvider):
273
  return f"{clean_image}-{timestamp}"
274
 
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  class KubernetesProvider(ContainerProvider):
277
  """
278
  Container provider for Kubernetes clusters.
@@ -286,4 +603,67 @@ class KubernetesProvider(ContainerProvider):
286
  >>> # Pod running in k8s, accessible via service or port-forward
287
  >>> provider.stop_container()
288
  """
 
289
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  Container provider abstractions for running environment servers.
9
 
10
  This module provides a pluggable architecture for different container providers
11
+ (local Docker, Kubernetes, cloud providers, etc.) to be used with EnvClient.
12
  """
13
 
14
  from __future__ import annotations
15
 
16
  from abc import ABC, abstractmethod
17
+ from typing import Any, Dict, Optional, Sequence
18
 
19
 
20
  class ContainerProvider(ABC):
 
118
  capture_output=True,
119
  timeout=5,
120
  )
121
+ except (
122
+ subprocess.CalledProcessError,
123
+ FileNotFoundError,
124
+ subprocess.TimeoutExpired,
125
+ ):
126
  raise RuntimeError(
127
  "Docker is not available. Please install Docker Desktop or Docker Engine."
128
  )
 
158
 
159
  # Build docker run command
160
  cmd = [
161
+ "docker",
162
+ "run",
163
  "-d", # Detached
164
+ "--name",
165
+ self._container_name,
166
+ "-p",
167
+ f"{port}:8000", # Map port
168
  ]
169
 
170
  # Add environment variables
 
176
  cmd.append(image)
177
 
178
  # Run container
179
+ try:
180
+ result = subprocess.run(cmd, capture_output=True, text=True, check=True)
181
+ self._container_id = result.stdout.strip()
182
+ except subprocess.CalledProcessError as e:
183
+ error_msg = f"Failed to start Docker container.\nCommand: {' '.join(cmd)}\nExit code: {e.returncode}\nStderr: {e.stderr}\nStdout: {e.stdout}"
184
+ raise RuntimeError(error_msg) from e
185
 
186
  # Wait a moment for container to start
187
  time.sleep(1)
 
233
  TimeoutError: If container doesn't become ready
234
  """
235
  import time
236
+
237
  import requests
238
 
239
  start_time = time.time()
240
  health_url = f"{base_url}/health"
241
 
242
+ # Bypass proxy for localhost to avoid proxy issues
243
+ proxies = {"http": None, "https": None}
244
+
245
  while time.time() - start_time < timeout_s:
246
  try:
247
+ response = requests.get(health_url, timeout=2.0, proxies=proxies)
248
  if response.status_code == 200:
249
  return
250
  except requests.RequestException:
 
288
  return f"{clean_image}-{timestamp}"
289
 
290
 
291
+ class DockerSwarmProvider(ContainerProvider):
292
+ """
293
+ Container provider that uses Docker Swarm services for local concurrency.
294
+
295
+ This provider creates a replicated Swarm service backed by the local Docker
296
+ engine. The built-in load-balancer fans requests across the replicas,
297
+ allowing multiple container instances to run concurrently on the developer
298
+ workstation (mirroring the workflow described in the Docker stack docs).
299
+ """
300
+
301
+ def __init__(
302
+ self,
303
+ *,
304
+ auto_init_swarm: bool = True,
305
+ overlay_network: Optional[str] = None,
306
+ ):
307
+ """
308
+ Args:
309
+ auto_init_swarm: Whether to call ``docker swarm init`` when Swarm
310
+ is not active. Otherwise, user must manually initialize Swarm.
311
+ overlay_network: Optional overlay network name for the service.
312
+ When provided, the network is created with
313
+ ``docker network create --driver overlay --attachable`` if it
314
+ does not already exist.
315
+ """
316
+ self._service_name: Optional[str] = None
317
+ self._service_id: Optional[str] = None
318
+ self._published_port: Optional[int] = None
319
+ self._overlay_network = overlay_network
320
+ self._auto_init_swarm = auto_init_swarm
321
+
322
+ self._ensure_docker_available()
323
+ self._ensure_swarm_initialized()
324
+ if self._overlay_network:
325
+ self._ensure_overlay_network(self._overlay_network)
326
+
327
+ def start_container(
328
+ self,
329
+ image: str,
330
+ port: Optional[int] = None,
331
+ env_vars: Optional[Dict[str, str]] = None,
332
+ **kwargs: Any,
333
+ ) -> str:
334
+ """
335
+ Start (or scale) a Swarm service for the given image.
336
+
337
+ Supported kwargs:
338
+ replicas (int): Number of container replicas (default: 2).
339
+ cpu_limit (float | str): CPU limit passed to ``--limit-cpu``.
340
+ memory_limit (str): Memory limit passed to ``--limit-memory``.
341
+ constraints (Sequence[str]): Placement constraints.
342
+ labels (Dict[str, str]): Service labels.
343
+ command (Sequence[str] | str): Override container command.
344
+ """
345
+ import shlex
346
+ import subprocess
347
+ import time
348
+
349
+ allowed_kwargs = {
350
+ "replicas",
351
+ "cpu_limit",
352
+ "memory_limit",
353
+ "constraints",
354
+ "labels",
355
+ "command",
356
+ }
357
+ unknown = set(kwargs) - allowed_kwargs
358
+ if unknown:
359
+ raise ValueError(f"Unsupported kwargs for DockerSwarmProvider: {unknown}")
360
+
361
+ replicas = int(kwargs.get("replicas", 2))
362
+ cpu_limit = kwargs.get("cpu_limit")
363
+ memory_limit = kwargs.get("memory_limit")
364
+ constraints: Optional[Sequence[str]] = kwargs.get("constraints")
365
+ labels: Optional[Dict[str, str]] = kwargs.get("labels")
366
+ command_override = kwargs.get("command")
367
+
368
+ if port is None:
369
+ port = self._find_available_port()
370
+
371
+ self._service_name = self._generate_service_name(image)
372
+ self._published_port = port
373
+
374
+ cmd = [
375
+ "docker",
376
+ "service",
377
+ "create",
378
+ "--detach",
379
+ "--name",
380
+ self._service_name,
381
+ "--replicas",
382
+ str(max(1, replicas)),
383
+ "--publish",
384
+ f"{port}:8000",
385
+ ]
386
+
387
+ if self._overlay_network:
388
+ cmd.extend(["--network", self._overlay_network])
389
+
390
+ if env_vars:
391
+ for key, value in env_vars.items():
392
+ cmd.extend(["--env", f"{key}={value}"])
393
+
394
+ if cpu_limit is not None:
395
+ cmd.extend(["--limit-cpu", str(cpu_limit)])
396
+
397
+ if memory_limit is not None:
398
+ cmd.extend(["--limit-memory", str(memory_limit)])
399
+
400
+ if constraints:
401
+ for constraint in constraints:
402
+ cmd.extend(["--constraint", constraint])
403
+
404
+ if labels:
405
+ for key, value in labels.items():
406
+ cmd.extend(["--label", f"{key}={value}"])
407
+
408
+ cmd.append(image)
409
+
410
+ if command_override:
411
+ if isinstance(command_override, str):
412
+ cmd.extend(shlex.split(command_override))
413
+ else:
414
+ cmd.extend(command_override)
415
+
416
+ try:
417
+ result = subprocess.run(
418
+ cmd,
419
+ capture_output=True,
420
+ text=True,
421
+ check=True,
422
+ )
423
+ self._service_id = result.stdout.strip()
424
+ except subprocess.CalledProcessError as e:
425
+ error_msg = (
426
+ "Failed to start Docker Swarm service.\n"
427
+ f"Command: {' '.join(cmd)}\n"
428
+ f"Exit code: {e.returncode}\n"
429
+ f"Stdout: {e.stdout}\n"
430
+ f"Stderr: {e.stderr}"
431
+ )
432
+ raise RuntimeError(error_msg) from e
433
+
434
+ # Give Swarm a brief moment to schedule the tasks.
435
+ time.sleep(1.0)
436
+
437
+ return f"http://localhost:{port}"
438
+
439
+ def stop_container(self) -> None:
440
+ """
441
+ Remove the Swarm service (and keep the Swarm manager running).
442
+ """
443
+ if not self._service_name:
444
+ return
445
+
446
+ import subprocess
447
+
448
+ try:
449
+ subprocess.run(
450
+ ["docker", "service", "rm", self._service_name],
451
+ capture_output=True,
452
+ check=True,
453
+ timeout=10,
454
+ )
455
+ except subprocess.CalledProcessError:
456
+ # Service may already be gone; ignore.
457
+ pass
458
+ finally:
459
+ self._service_name = None
460
+ self._service_id = None
461
+ self._published_port = None
462
+
463
+ def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None:
464
+ """
465
+ Wait for at least one replica to become healthy by polling /health.
466
+
467
+ Note: With Swarm's load balancer, requests round-robin across replicas,
468
+ so this only verifies that at least one replica is responding. Some
469
+ replicas may still be starting when this returns.
470
+ """
471
+ import time
472
+
473
+ import requests
474
+
475
+ deadline = time.time() + timeout_s
476
+ health_url = f"{base_url}/health"
477
+
478
+ # Bypass proxy for localhost to avoid proxy issues
479
+ proxies = {"http": None, "https": None}
480
+
481
+ while time.time() < deadline:
482
+ try:
483
+ response = requests.get(health_url, timeout=2.0, proxies=proxies)
484
+ if response.status_code == 200:
485
+ return
486
+ except requests.RequestException:
487
+ pass
488
+
489
+ time.sleep(0.5)
490
+
491
+ raise TimeoutError(
492
+ f"Swarm service at {base_url} did not become ready within {timeout_s}s"
493
+ )
494
+
495
+ def _ensure_docker_available(self) -> None:
496
+ import subprocess
497
+
498
+ try:
499
+ subprocess.run(
500
+ ["docker", "version"],
501
+ check=True,
502
+ capture_output=True,
503
+ timeout=5,
504
+ )
505
+ except (
506
+ subprocess.CalledProcessError,
507
+ FileNotFoundError,
508
+ subprocess.TimeoutExpired,
509
+ ) as exc:
510
+ raise RuntimeError(
511
+ "Docker is not available. Please install Docker Desktop or Docker Engine."
512
+ ) from exc
513
+
514
+ def _ensure_swarm_initialized(self) -> None:
515
+ import subprocess
516
+
517
+ try:
518
+ result = subprocess.run(
519
+ ["docker", "info", "--format", "{{.Swarm.LocalNodeState}}"],
520
+ capture_output=True,
521
+ text=True,
522
+ check=True,
523
+ timeout=5,
524
+ )
525
+ state = result.stdout.strip().lower()
526
+ if state == "active":
527
+ return
528
+ except subprocess.CalledProcessError:
529
+ state = "unknown"
530
+
531
+ if not self._auto_init_swarm:
532
+ raise RuntimeError(
533
+ f"Docker Swarm is not active (state={state}). Enable Swarm manually or pass auto_init_swarm=True."
534
+ )
535
+
536
+ try:
537
+ subprocess.run(
538
+ ["docker", "swarm", "init"],
539
+ check=True,
540
+ capture_output=True,
541
+ timeout=10,
542
+ )
543
+ except subprocess.CalledProcessError as e:
544
+ raise RuntimeError("Failed to initialize Docker Swarm") from e
545
+
546
+ def _ensure_overlay_network(self, network: str) -> None:
547
+ import subprocess
548
+
549
+ inspect = subprocess.run(
550
+ ["docker", "network", "inspect", network],
551
+ capture_output=True,
552
+ text=True,
553
+ check=False,
554
+ )
555
+ if inspect.returncode == 0:
556
+ return
557
+
558
+ try:
559
+ subprocess.run(
560
+ [
561
+ "docker",
562
+ "network",
563
+ "create",
564
+ "--driver",
565
+ "overlay",
566
+ "--attachable",
567
+ network,
568
+ ],
569
+ check=True,
570
+ capture_output=True,
571
+ timeout=10,
572
+ )
573
+ except subprocess.CalledProcessError as e:
574
+ raise RuntimeError(f"Failed to create overlay network '{network}'") from e
575
+
576
+ def _find_available_port(self) -> int:
577
+ import socket
578
+
579
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
580
+ s.bind(("", 0))
581
+ s.listen(1)
582
+ port = s.getsockname()[1]
583
+ return port
584
+
585
+ def _generate_service_name(self, image: str) -> str:
586
+ import time
587
+
588
+ clean_image = image.split("/")[-1].split(":")[0]
589
+ timestamp = int(time.time() * 1000)
590
+ return f"{clean_image}-swarm-{timestamp}"
591
+
592
+
593
  class KubernetesProvider(ContainerProvider):
594
  """
595
  Container provider for Kubernetes clusters.
 
603
  >>> # Pod running in k8s, accessible via service or port-forward
604
  >>> provider.stop_container()
605
  """
606
+
607
  pass
608
+
609
+
610
+ class RuntimeProvider(ABC):
611
+ """
612
+ Abstract base class for runtime providers that are not container providers.
613
+ Providers implement this interface to support different runtime platforms:
614
+ - UVProvider: Runs environments via `uv run`
615
+
616
+ The provider manages a single runtime lifecycle and provides the base URL
617
+ for connecting to it.
618
+
619
+ Example:
620
+ >>> provider = UVProvider(project_path="/path/to/env")
621
+ >>> base_url = provider.start()
622
+ >>> print(base_url) # http://localhost:8000
623
+ >>> provider.stop()
624
+ """
625
+
626
+ @abstractmethod
627
+ def start(
628
+ self,
629
+ port: Optional[int] = None,
630
+ env_vars: Optional[Dict[str, str]] = None,
631
+ **kwargs: Any,
632
+ ) -> str:
633
+ """
634
+ Start a runtime from the specified image.
635
+
636
+ Args:
637
+ image: Runtime image name
638
+ port: Port to expose (if None, provider chooses)
639
+ env_vars: Environment variables for the runtime
640
+ **kwargs: Additional runtime options
641
+ """
642
+
643
+ @abstractmethod
644
+ def stop(self) -> None:
645
+ """
646
+ Stop the runtime.
647
+ """
648
+ pass
649
+
650
+ @abstractmethod
651
+ def wait_for_ready(self, timeout_s: float = 30.0) -> None:
652
+ """
653
+ Wait for the runtime to be ready to accept requests.
654
+ """
655
+ pass
656
+
657
+ def __enter__(self) -> "RuntimeProvider":
658
+ """
659
+ Enter the runtime provider.
660
+ """
661
+ self.start()
662
+ return self
663
+
664
+ def __exit__(self, exc_type, exc, tb) -> None:
665
+ """
666
+ Exit the runtime provider.
667
+ """
668
+ self.stop()
669
+ return False
src/core/containers/runtime/uv_provider.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Providers for launching ASGI applications via ``uv run``."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import socket
7
+ import subprocess
8
+ import time
9
+ from typing import Dict, Optional
10
+
11
+ import requests
12
+
13
+ from .providers import RuntimeProvider
14
+
15
+
16
+ def _check_uv_installed() -> None:
17
+ try:
18
+ subprocess.check_output(["uv", "--version"])
19
+ except FileNotFoundError as exc:
20
+ raise RuntimeError(
21
+ "`uv` executable not found. Install uv from https://docs.astral.sh and ensure it is on PATH."
22
+ ) from exc
23
+
24
+
25
+ def _find_free_port() -> int:
26
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
27
+ sock.bind(("", 0))
28
+ sock.listen(1)
29
+ return sock.getsockname()[1]
30
+
31
+
32
+ def _create_uv_command(
33
+ *,
34
+ host: str,
35
+ port: int,
36
+ reload: bool,
37
+ workers: int,
38
+ app: str,
39
+ project_path: str,
40
+ ) -> list[str]:
41
+ command: list[str] = ["uv", "run", "--isolated", "--project", project_path]
42
+
43
+ command.append("--")
44
+ command.extend(
45
+ [
46
+ "uvicorn",
47
+ app,
48
+ "--host",
49
+ host,
50
+ "--port",
51
+ str(port),
52
+ "--workers",
53
+ str(workers),
54
+ ]
55
+ )
56
+
57
+ if reload:
58
+ command.append("--reload")
59
+
60
+ return command
61
+
62
+
63
+ def _poll_health(health_url: str, timeout_s: float) -> None:
64
+ """Poll a health endpoint until it returns HTTP 200 or times out."""
65
+
66
+ deadline = time.time() + timeout_s
67
+ while time.time() < deadline:
68
+ try:
69
+ timeout = max(0.0001, min(deadline - time.time(), 2.0))
70
+ response = requests.get(health_url, timeout=timeout)
71
+ if response.status_code == 200:
72
+ return
73
+ except requests.RequestException:
74
+ continue
75
+
76
+ time.sleep(0.5)
77
+
78
+ raise TimeoutError(f"Server did not become ready within {timeout_s:.1f} seconds")
79
+
80
+
81
+ class UVProvider(RuntimeProvider):
82
+ """
83
+ RuntimeProvider implementation backed by ``uv run``.
84
+
85
+ Args:
86
+ project_path: Local path to a uv project (passed to ``uv run --project``)
87
+ app: ASGI application path for uvicorn (defaults to ``server.app:app``)
88
+ host: Host interface to bind to (defaults to ``0.0.0.0``)
89
+ reload: Whether to enable uvicorn's reload mode
90
+ env_vars: Environment variables to pass through to the spawned process
91
+ context_timeout_s: How long to wait for the environment to become ready
92
+
93
+ Example:
94
+ >>> provider = UVProvider(project_path="/path/to/env")
95
+ >>> base_url = provider.start()
96
+ >>> print(base_url) # http://localhost:8000
97
+ >>> # Use the environment via base_url
98
+ >>> provider.stop()
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ *,
104
+ project_path: str,
105
+ app: str = "server.app:app",
106
+ host: str = "0.0.0.0",
107
+ reload: bool = False,
108
+ env_vars: Optional[Dict[str, str]] = None,
109
+ context_timeout_s: float = 60.0,
110
+ ):
111
+ """Initialize the UVProvider."""
112
+ self.project_path = os.path.abspath(project_path)
113
+ self.app = app
114
+ self.host = host
115
+ self.reload = reload
116
+ self.env_vars = env_vars
117
+ self.context_timeout_s = context_timeout_s
118
+ _check_uv_installed()
119
+ self._process = None
120
+ self._base_url = None
121
+
122
+ def start(
123
+ self,
124
+ port: Optional[int] = None,
125
+ env_vars: Optional[Dict[str, str]] = None,
126
+ workers: int = 1,
127
+ **_: Dict[str, str],
128
+ ) -> str:
129
+ """
130
+ Start the environment via `uv run`.
131
+
132
+ Args:
133
+ port: The port to bind the environment to
134
+ env_vars: Environment variables to pass to the environment
135
+ workers: The number of workers to use
136
+
137
+ Returns:
138
+ The base URL of the environment
139
+
140
+ Raises:
141
+ RuntimeError: If the environment is already running
142
+ """
143
+ if self._process is not None and self._process.poll() is None:
144
+ raise RuntimeError("UVProvider is already running")
145
+
146
+ bind_port = port or _find_free_port()
147
+
148
+ command = _create_uv_command(
149
+ host=self.host,
150
+ port=bind_port,
151
+ reload=self.reload,
152
+ workers=workers,
153
+ app=self.app,
154
+ project_path=self.project_path,
155
+ )
156
+
157
+ env = os.environ.copy()
158
+
159
+ if self.env_vars:
160
+ env.update(self.env_vars)
161
+ if env_vars:
162
+ env.update(env_vars)
163
+
164
+ try:
165
+ self._process = subprocess.Popen(command, env=env)
166
+ except OSError as exc:
167
+ raise RuntimeError(f"Failed to launch `uv run`: {exc}") from exc
168
+
169
+ client_host = "127.0.0.1" if self.host in {"0.0.0.0", "::"} else self.host
170
+ self._base_url = f"http://{client_host}:{bind_port}"
171
+ return self._base_url
172
+
173
+ def wait_for_ready(self, timeout_s: float = 60.0) -> None:
174
+ """
175
+ Wait for the environment to become ready.
176
+
177
+ Args:
178
+ timeout_s: The timeout to wait for the environment to become ready
179
+
180
+ Raises:
181
+ RuntimeError: If the environment is not running
182
+ TimeoutError: If the environment does not become ready within the timeout
183
+ """
184
+ if self._process and self._process.poll() is not None:
185
+ code = self._process.returncode
186
+ raise RuntimeError(f"uv process exited prematurely with code {code}")
187
+
188
+ _poll_health(f"{self._base_url}/health", timeout_s=timeout_s)
189
+
190
+ def stop(self) -> None:
191
+ """
192
+ Stop the environment.
193
+
194
+ Raises:
195
+ RuntimeError: If the environment is not running
196
+ """
197
+ if self._process is None:
198
+ return
199
+
200
+ if self._process.poll() is None:
201
+ self._process.terminate()
202
+ try:
203
+ self._process.wait(timeout=10.0)
204
+ except subprocess.TimeoutExpired:
205
+ self._process.kill()
206
+ self._process.wait(timeout=5.0)
207
+
208
+ self._process = None
209
+ self._base_url = None
210
+
211
+ @property
212
+ def base_url(self) -> str:
213
+ """
214
+ The base URL of the environment.
215
+
216
+ Returns:
217
+ The base URL of the environment
218
+
219
+ Raises:
220
+ RuntimeError: If the environment is not running
221
+ """
222
+ if self._base_url is None:
223
+ raise RuntimeError("UVProvider has not been started")
224
+ return self._base_url
src/core/containers/test_local_docker_provider.py CHANGED
@@ -16,8 +16,8 @@ from pathlib import Path
16
  sys.path.insert(0, str(Path(__file__).parent.parent.parent))
17
 
18
  import requests
 
19
 
20
- from core.containers.runtime import LocalDockerProvider
21
 
22
  # TODO: Remove this test or make it a functional test sicne this will be tested in e2e test for echo env
23
  def test_local_docker_provider():
@@ -87,7 +87,9 @@ def test_local_docker_provider():
87
  print(f" Length: {data['observation']['message_length']}")
88
  print(f" Reward: {data['reward']}")
89
  assert response.status_code == 200
90
- assert data["observation"]["echoed_message"] == "Hello from LocalDockerProvider!"
 
 
91
  assert data["observation"]["message_length"] == 31
92
  print("✓ Step test passed\n")
93
 
@@ -107,11 +109,11 @@ def test_local_docker_provider():
107
  for i in range(3):
108
  response = requests.post(
109
  f"{base_url}/step",
110
- json={"action": {"message": f"Message {i+1}"}},
111
  headers={"Content-Type": "application/json"},
112
  )
113
  assert response.status_code == 200
114
- print(f" Step {i+1}: ✓")
115
 
116
  # Check state updated
117
  response = requests.get(f"{base_url}/state")
@@ -130,6 +132,7 @@ def test_local_docker_provider():
130
  except Exception as e:
131
  print(f"\n❌ Test failed: {e}")
132
  import traceback
 
133
  traceback.print_exc()
134
  return False
135
 
@@ -197,8 +200,7 @@ def test_provider_with_env_vars():
197
 
198
  print("Starting container with environment variables...")
199
  base_url = provider.start_container(
200
- "echo-env:latest",
201
- env_vars={"DEBUG": "true", "LOG_LEVEL": "info"}
202
  )
203
  print(f"✓ Started at: {base_url}")
204
 
 
16
  sys.path.insert(0, str(Path(__file__).parent.parent.parent))
17
 
18
  import requests
19
+ from openenv.core.containers.runtime import LocalDockerProvider
20
 
 
21
 
22
  # TODO: Remove this test or make it a functional test sicne this will be tested in e2e test for echo env
23
  def test_local_docker_provider():
 
87
  print(f" Length: {data['observation']['message_length']}")
88
  print(f" Reward: {data['reward']}")
89
  assert response.status_code == 200
90
+ assert (
91
+ data["observation"]["echoed_message"] == "Hello from LocalDockerProvider!"
92
+ )
93
  assert data["observation"]["message_length"] == 31
94
  print("✓ Step test passed\n")
95
 
 
109
  for i in range(3):
110
  response = requests.post(
111
  f"{base_url}/step",
112
+ json={"action": {"message": f"Message {i + 1}"}},
113
  headers={"Content-Type": "application/json"},
114
  )
115
  assert response.status_code == 200
116
+ print(f" Step {i + 1}: ✓")
117
 
118
  # Check state updated
119
  response = requests.get(f"{base_url}/state")
 
132
  except Exception as e:
133
  print(f"\n❌ Test failed: {e}")
134
  import traceback
135
+
136
  traceback.print_exc()
137
  return False
138
 
 
200
 
201
  print("Starting container with environment variables...")
202
  base_url = provider.start_container(
203
+ "echo-env:latest", env_vars={"DEBUG": "true", "LOG_LEVEL": "info"}
 
204
  )
205
  print(f"✓ Started at: {base_url}")
206
 
src/core/env_client.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Environment client for persistent sessions.
9
+
10
+ This module provides a WebSocket-based client that maintains a persistent connection
11
+ to an environment server, enabling efficient multi-step interactions without
12
+ the overhead of HTTP request/response cycles.
13
+
14
+ The client is async by default. For synchronous usage, use the `.sync()` method
15
+ to get a `SyncEnvClient` wrapper.
16
+
17
+ Example (async):
18
+ >>> async with GenericEnvClient(base_url="ws://localhost:8000") as env:
19
+ ... result = await env.reset()
20
+ ... result = await env.step({"code": "print('hello')"})
21
+
22
+ Example (sync wrapper):
23
+ >>> env = GenericEnvClient(base_url="ws://localhost:8000").sync()
24
+ >>> with env:
25
+ ... result = env.reset()
26
+ ... result = env.step({"code": "print('hello')"})
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import asyncio
32
+ import json
33
+ import os
34
+ from abc import ABC, abstractmethod
35
+ from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar
36
+
37
+ from .client_types import StateT, StepResult
38
+ from .containers.runtime import LocalDockerProvider, UVProvider
39
+ from .utils import convert_to_ws_url
40
+
41
+ if TYPE_CHECKING:
42
+ from websockets.asyncio.client import ClientConnection
43
+
44
+ from .containers.runtime import ContainerProvider, RuntimeProvider
45
+ from .sync_client import SyncEnvClient
46
+
47
+ from websockets.asyncio.client import connect as ws_connect
48
+
49
+ ActT = TypeVar("ActT")
50
+ ObsT = TypeVar("ObsT")
51
+ EnvClientT = TypeVar("EnvClientT", bound="EnvClient")
52
+
53
+
54
+ class EnvClient(ABC, Generic[ActT, ObsT, StateT]):
55
+ """
56
+ Async environment client for persistent sessions.
57
+
58
+ This client maintains a persistent WebSocket connection to an environment
59
+ server, enabling efficient multi-step interactions. Each client instance
60
+ corresponds to a dedicated environment session on the server.
61
+
62
+ The client is async by default. For synchronous usage, use the `.sync()`
63
+ method to get a `SyncEnvClient` wrapper.
64
+
65
+ Features:
66
+ - Lower latency for sequential interactions
67
+ - Session state is maintained server-side
68
+ - Better suited for long-running episodes
69
+ - Async by default for modern Python async/await patterns
70
+
71
+ Example (async):
72
+ >>> from envs.coding_env.client import CodingEnv
73
+ >>>
74
+ >>> # Connect to a server using async context manager
75
+ >>> async with CodingEnv(base_url="ws://localhost:8000") as env:
76
+ ... result = await env.reset(seed=42)
77
+ ... while not result.done:
78
+ ... action = agent.predict(result.observation)
79
+ ... result = await env.step(action)
80
+
81
+ Example (sync wrapper):
82
+ >>> env = CodingEnv(base_url="ws://localhost:8000").sync()
83
+ >>> with env:
84
+ ... result = env.reset(seed=42)
85
+ ... result = env.step(action)
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ base_url: str,
91
+ connect_timeout_s: float = 10.0,
92
+ message_timeout_s: float = 60.0,
93
+ max_message_size_mb: float = 100.0,
94
+ provider: Optional["ContainerProvider | RuntimeProvider"] = None,
95
+ mode: Optional[str] = None,
96
+ ):
97
+ """
98
+ Initialize environment client.
99
+
100
+ Args:
101
+ base_url: Base URL of the environment server (http:// or ws://).
102
+ Will be converted to ws:// if http:// is provided.
103
+ connect_timeout_s: Timeout for establishing WebSocket connection
104
+ message_timeout_s: Timeout for receiving responses to messages
105
+ max_message_size_mb: Maximum WebSocket message size in megabytes.
106
+ Default 100MB to handle large observations (screenshots, DOM, etc.)
107
+ provider: Optional container/runtime provider for lifecycle management.
108
+ Can be a ContainerProvider (Docker) or RuntimeProvider (UV).
109
+ mode: Communication mode: 'simulation' for Gym-style API (default) or
110
+ 'production' for MCP JSON-RPC protocol. Can also be set via the
111
+ OPENENV_CLIENT_MODE environment variable. Constructor parameter
112
+ takes precedence over environment variable. Case-insensitive.
113
+ """
114
+ # Determine mode (constructor > env var > default)
115
+ if mode is None:
116
+ mode = os.environ.get("OPENENV_CLIENT_MODE", "simulation")
117
+
118
+ # Normalize and validate mode
119
+ mode = mode.lower()
120
+ if mode not in ("simulation", "production"):
121
+ raise ValueError(
122
+ f"Invalid mode: '{mode}'. Must be 'simulation' or 'production'. "
123
+ f"Set via constructor parameter or OPENENV_CLIENT_MODE environment variable."
124
+ )
125
+
126
+ # Store mode (use object.__setattr__ to bypass immutability)
127
+ object.__setattr__(self, "_mode", mode)
128
+
129
+ # Convert HTTP URL to WebSocket URL
130
+ ws_url = convert_to_ws_url(base_url)
131
+
132
+ self._ws_url = f"{ws_url}/ws"
133
+ self._connect_timeout = connect_timeout_s
134
+ self._message_timeout = message_timeout_s
135
+ self._max_message_size = int(
136
+ max_message_size_mb * 1024 * 1024
137
+ ) # Convert MB to bytes
138
+ self._provider = provider
139
+ self._ws: Optional[ClientConnection] = None
140
+
141
+ def __setattr__(self, name: str, value: Any) -> None:
142
+ """Prevent modification of _mode after initialization."""
143
+ if name == "_mode" and hasattr(self, "_mode"):
144
+ raise AttributeError("Cannot modify mode after initialization")
145
+ super().__setattr__(name, value)
146
+
147
+ async def connect(self) -> "EnvClient":
148
+ """
149
+ Establish WebSocket connection to the server.
150
+
151
+ Returns:
152
+ self for method chaining
153
+
154
+ Raises:
155
+ ConnectionError: If connection cannot be established
156
+ """
157
+ if self._ws is not None:
158
+ return self
159
+
160
+ # Bypass proxy for localhost connections
161
+ ws_url_lower = self._ws_url.lower()
162
+ is_localhost = "localhost" in ws_url_lower or "127.0.0.1" in ws_url_lower
163
+
164
+ old_no_proxy = os.environ.get("NO_PROXY")
165
+ if is_localhost:
166
+ # Set NO_PROXY to bypass proxy for localhost
167
+ current_no_proxy = old_no_proxy or ""
168
+ if "localhost" not in current_no_proxy.lower():
169
+ os.environ["NO_PROXY"] = (
170
+ f"{current_no_proxy},localhost,127.0.0.1"
171
+ if current_no_proxy
172
+ else "localhost,127.0.0.1"
173
+ )
174
+
175
+ try:
176
+ self._ws = await ws_connect(
177
+ self._ws_url,
178
+ open_timeout=self._connect_timeout,
179
+ max_size=self._max_message_size,
180
+ )
181
+ except Exception as e:
182
+ raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e
183
+ finally:
184
+ # Restore original NO_PROXY value
185
+ if is_localhost:
186
+ if old_no_proxy is None:
187
+ os.environ.pop("NO_PROXY", None)
188
+ else:
189
+ os.environ["NO_PROXY"] = old_no_proxy
190
+
191
+ return self
192
+
193
+ async def disconnect(self) -> None:
194
+ """Close the WebSocket connection."""
195
+ if self._ws is not None:
196
+ try:
197
+ # Send close message
198
+ await self._send({"type": "close"})
199
+ except Exception:
200
+ pass # Best effort
201
+ try:
202
+ await self._ws.close()
203
+ except Exception:
204
+ pass
205
+ self._ws = None
206
+
207
+ async def _ensure_connected(self) -> None:
208
+ """Ensure WebSocket connection is established."""
209
+ if self._ws is None:
210
+ await self.connect()
211
+
212
+ async def _send(self, message: Dict[str, Any]) -> None:
213
+ """Send a message over the WebSocket."""
214
+ await self._ensure_connected()
215
+ assert self._ws is not None
216
+ await self._ws.send(json.dumps(message))
217
+
218
+ async def _receive(self) -> Dict[str, Any]:
219
+ """Receive and parse a message from the WebSocket."""
220
+ assert self._ws is not None
221
+ raw = await asyncio.wait_for(self._ws.recv(), timeout=self._message_timeout)
222
+ return json.loads(raw)
223
+
224
+ async def _send_and_receive(self, message: Dict[str, Any]) -> Dict[str, Any]:
225
+ """Send a message and wait for response."""
226
+ await self._send(message)
227
+ response = await self._receive()
228
+
229
+ # Check for error response
230
+ if response.get("type") == "error":
231
+ error_data = response.get("data", {})
232
+ raise RuntimeError(
233
+ f"Server error: {error_data.get('message', 'Unknown error')} "
234
+ f"(code: {error_data.get('code', 'UNKNOWN')})"
235
+ )
236
+
237
+ return response
238
+
239
+ @classmethod
240
+ async def from_docker_image(
241
+ cls: Type[EnvClientT],
242
+ image: str,
243
+ provider: Optional["ContainerProvider"] = None,
244
+ **kwargs: Any,
245
+ ) -> EnvClientT:
246
+ """
247
+ Create an environment client by spinning up a Docker container.
248
+
249
+ Args:
250
+ image: Docker image name to run (e.g., "coding-env:latest")
251
+ provider: Container provider to use (defaults to LocalDockerProvider)
252
+ **kwargs: Additional arguments to pass to provider.start_container()
253
+
254
+ Returns:
255
+ Connected client instance
256
+ """
257
+ if provider is None:
258
+ provider = LocalDockerProvider()
259
+
260
+ # Start container
261
+ base_url = provider.start_container(image, **kwargs)
262
+
263
+ # Wait for server to be ready
264
+ provider.wait_for_ready(base_url)
265
+
266
+ # Create and connect client
267
+ client = cls(base_url=base_url, provider=provider)
268
+ await client.connect()
269
+
270
+ return client
271
+
272
+ @classmethod
273
+ async def from_env(
274
+ cls: Type[EnvClientT],
275
+ repo_id: str,
276
+ *,
277
+ use_docker: bool = True,
278
+ provider: Optional["ContainerProvider | RuntimeProvider"] = None,
279
+ **provider_kwargs: Any,
280
+ ) -> EnvClientT:
281
+ """
282
+ Create a client from a Hugging Face Space.
283
+
284
+ Args:
285
+ repo_id: Hugging Face space identifier ``{org}/{space}``.
286
+ use_docker: When ``True`` (default) pull from the HF registry and
287
+ launch via :class:`LocalDockerProvider`. When ``False`` run the
288
+ space locally with :class:`UVProvider`.
289
+ provider: Optional provider instance to reuse. Must be a
290
+ :class:`ContainerProvider` when ``use_docker=True`` and a
291
+ :class:`RuntimeProvider` otherwise.
292
+ provider_kwargs: Additional keyword arguments forwarded to
293
+ either the container provider's ``start_container`` (docker)
294
+ or to the ``UVProvider`` constructor/start (uv). When
295
+ ``use_docker=False``, the ``project_path`` argument can be
296
+ used to override the default git URL
297
+ (``git+https://huggingface.co/spaces/{repo_id}``).
298
+
299
+ Returns:
300
+ Connected client instance
301
+
302
+ Examples:
303
+ >>> # Pull and run from HF Docker registry
304
+ >>> env = await MyEnv.from_env("openenv/echo-env")
305
+ >>>
306
+ >>> # Run locally with UV (clones the space)
307
+ >>> env = await MyEnv.from_env("openenv/echo-env", use_docker=False)
308
+ >>>
309
+ >>> # Run from a local checkout
310
+ >>> env = await MyEnv.from_env(
311
+ ... "openenv/echo-env",
312
+ ... use_docker=False,
313
+ ... project_path="/path/to/local/checkout"
314
+ ... )
315
+ """
316
+ # Extract start args that apply to both providers
317
+ start_args = {}
318
+ for key in ("port", "env_vars", "workers"):
319
+ if key in provider_kwargs:
320
+ start_args[key] = provider_kwargs.pop(key)
321
+
322
+ if use_docker:
323
+ # Docker mode: pull from HF registry
324
+ docker_provider = provider or LocalDockerProvider()
325
+ tag = provider_kwargs.pop("tag", "latest")
326
+ image = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}"
327
+ base_url = docker_provider.start_container(
328
+ image, **start_args, **provider_kwargs
329
+ )
330
+ docker_provider.wait_for_ready(base_url)
331
+
332
+ client = cls(base_url=base_url, provider=docker_provider)
333
+ await client.connect()
334
+ return client
335
+ else:
336
+ # UV mode: clone and run with uv
337
+ if provider is None:
338
+ uv_kwargs = dict(provider_kwargs)
339
+ project_path = uv_kwargs.pop("project_path", None)
340
+ if project_path is None:
341
+ project_path = f"git+https://huggingface.co/spaces/{repo_id}"
342
+
343
+ provider = UVProvider(project_path=project_path, **uv_kwargs)
344
+ else:
345
+ if provider_kwargs:
346
+ raise ValueError(
347
+ "provider_kwargs cannot be used when supplying a provider instance"
348
+ )
349
+
350
+ base_url = provider.start(**start_args)
351
+ provider.wait_for_ready()
352
+
353
+ client = cls(base_url=base_url, provider=provider)
354
+ await client.connect()
355
+ return client
356
+
357
+ @abstractmethod
358
+ def _step_payload(self, action: ActT) -> Dict[str, Any]:
359
+ """Convert an Action object to the JSON data expected by the env server."""
360
+ raise NotImplementedError
361
+
362
+ @abstractmethod
363
+ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]:
364
+ """Convert a JSON response from the env server to StepResult[ObsT]."""
365
+ raise NotImplementedError
366
+
367
+ @abstractmethod
368
+ def _parse_state(self, payload: Dict[str, Any]) -> StateT:
369
+ """Convert a JSON response from the state endpoint to a State object."""
370
+ raise NotImplementedError
371
+
372
+ async def reset(self, **kwargs: Any) -> StepResult[ObsT]:
373
+ """
374
+ Reset the environment with optional parameters.
375
+
376
+ Args:
377
+ **kwargs: Optional parameters passed to the environment's reset method.
378
+ Common parameters include:
379
+ - seed: Random seed for reproducibility
380
+ - episode_id: Custom episode identifier
381
+
382
+ Returns:
383
+ StepResult containing initial observation
384
+ """
385
+ message = {
386
+ "type": "reset",
387
+ "data": kwargs,
388
+ }
389
+ response = await self._send_and_receive(message)
390
+ return self._parse_result(response.get("data", {}))
391
+
392
+ async def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]:
393
+ """
394
+ Execute an action in the environment.
395
+
396
+ Args:
397
+ action: The action to execute
398
+ **kwargs: Optional parameters (currently ignored)
399
+
400
+ Returns:
401
+ StepResult containing observation, reward, and done status
402
+ """
403
+ message = {
404
+ "type": "step",
405
+ "data": self._step_payload(action),
406
+ }
407
+ response = await self._send_and_receive(message)
408
+ return self._parse_result(response.get("data", {}))
409
+
410
+ async def state(self) -> StateT:
411
+ """
412
+ Get the current environment state from the server.
413
+
414
+ Returns:
415
+ State object with environment state information
416
+ """
417
+ message = {"type": "state"}
418
+ response = await self._send_and_receive(message)
419
+ return self._parse_state(response.get("data", {}))
420
+
421
+ async def close(self) -> None:
422
+ """
423
+ Close the WebSocket connection and clean up resources.
424
+
425
+ If this client was created via from_docker_image() or from_env(),
426
+ this will also stop and remove the associated container/process.
427
+ """
428
+ await self.disconnect()
429
+
430
+ if self._provider is not None:
431
+ # Handle both ContainerProvider and RuntimeProvider
432
+ if hasattr(self._provider, "stop_container"):
433
+ self._provider.stop_container()
434
+ elif hasattr(self._provider, "stop"):
435
+ self._provider.stop()
436
+
437
+ async def __aenter__(self) -> "EnvClient":
438
+ """Enter async context manager, ensuring connection is established."""
439
+ await self.connect()
440
+ return self
441
+
442
+ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
443
+ """Exit async context manager, closing connection."""
444
+ await self.close()
445
+
446
+ def __enter__(self) -> "EnvClient":
447
+ """Sync context manager entry - raises error suggesting async usage."""
448
+ raise TypeError(
449
+ "EnvClient is async by default. Use 'async with' instead of 'with', "
450
+ "or call .sync() to get a synchronous wrapper:\n"
451
+ " async with client: # async usage\n"
452
+ " with client.sync(): # sync wrapper"
453
+ )
454
+
455
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
456
+ """Sync context manager exit - should not be reached."""
457
+ pass # pragma: no cover
458
+
459
+ def sync(self) -> "SyncEnvClient":
460
+ """
461
+ Return a synchronous wrapper around this async client.
462
+
463
+ Use this method when you need synchronous access to the environment
464
+ without async/await syntax. This is useful for:
465
+ - Integration with synchronous codebases
466
+ - Interactive/REPL usage
467
+ - Stopping async from "infecting" the call stack
468
+
469
+ Returns:
470
+ SyncEnvClient wrapper that provides synchronous methods
471
+
472
+ Example:
473
+ >>> # Create async client and get sync wrapper
474
+ >>> async_client = GenericEnvClient(base_url="http://localhost:8000")
475
+ >>> sync_client = async_client.sync()
476
+ >>>
477
+ >>> # Use synchronous API
478
+ >>> with sync_client:
479
+ ... result = sync_client.reset()
480
+ ... result = sync_client.step({"code": "print('hello')"})
481
+ """
482
+ from .sync_client import SyncEnvClient
483
+
484
+ return SyncEnvClient(self)
src/core/env_server/__init__.py CHANGED
@@ -7,10 +7,74 @@
7
  """Core environment interfaces and types."""
8
 
9
  from .base_transforms import CompositeTransform, NullTransform
10
- from .http_server import HTTPEnvServer, create_app, create_fastapi_app
 
 
 
 
 
 
 
 
11
  from .interfaces import Environment, Message, ModelTokenizer, Transform
12
- from .types import Action, Observation, State
13
- from .web_interface import create_web_interface_app, WebInterfaceManager
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  __all__ = [
16
  # Core interfaces
@@ -22,6 +86,33 @@ __all__ = [
22
  "Action",
23
  "Observation",
24
  "State",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # Base transforms
26
  "CompositeTransform",
27
  "NullTransform",
@@ -32,4 +123,28 @@ __all__ = [
32
  # Web Interface
33
  "create_web_interface_app",
34
  "WebInterfaceManager",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  ]
 
7
  """Core environment interfaces and types."""
8
 
9
  from .base_transforms import CompositeTransform, NullTransform
10
+ from .exceptions import (
11
+ ConcurrencyConfigurationError,
12
+ EnvironmentFactoryError,
13
+ OpenEnvError,
14
+ SessionCapacityError,
15
+ SessionCreationError,
16
+ SessionNotFoundError,
17
+ )
18
+ from .http_server import create_app, create_fastapi_app, HTTPEnvServer
19
  from .interfaces import Environment, Message, ModelTokenizer, Transform
20
+
21
+ try:
22
+ from .mcp_environment import MCPEnvironment
23
+ except ModuleNotFoundError:
24
+ MCPEnvironment = None # type: ignore[assignment]
25
+
26
+ from .mcp_types import (
27
+ CallToolAction,
28
+ CallToolObservation,
29
+ JsonRpcError,
30
+ # JSON-RPC types
31
+ JsonRpcErrorCode,
32
+ JsonRpcRequest,
33
+ JsonRpcResponse,
34
+ ListToolsAction,
35
+ ListToolsObservation,
36
+ McpMethod,
37
+ RESERVED_TOOL_NAMES,
38
+ Tool,
39
+ ToolError,
40
+ ToolErrorType,
41
+ WSMCPMessage,
42
+ WSMCPResponse,
43
+ )
44
+ from .route_config import GetEndpointConfig
45
+ from .serialization import (
46
+ deserialize_action,
47
+ deserialize_action_with_preprocessing,
48
+ serialize_observation,
49
+ )
50
+ from .types import (
51
+ Action,
52
+ BaseMessage,
53
+ ConcurrencyConfig,
54
+ HealthResponse,
55
+ HealthStatus,
56
+ Observation,
57
+ SchemaResponse,
58
+ ServerCapacityStatus,
59
+ ServerMode,
60
+ SessionInfo,
61
+ State,
62
+ WSCloseMessage,
63
+ WSErrorCode,
64
+ WSErrorResponse,
65
+ WSIncomingMessage,
66
+ WSObservationResponse,
67
+ WSResetMessage,
68
+ WSStateMessage,
69
+ WSStateResponse,
70
+ WSStepMessage,
71
+ )
72
+
73
+ try:
74
+ from .web_interface import create_web_interface_app, WebInterfaceManager
75
+ except ModuleNotFoundError:
76
+ create_web_interface_app = None # type: ignore[assignment]
77
+ WebInterfaceManager = None # type: ignore[assignment]
78
 
79
  __all__ = [
80
  # Core interfaces
 
86
  "Action",
87
  "Observation",
88
  "State",
89
+ "SchemaResponse",
90
+ "HealthResponse",
91
+ # Enums
92
+ "HealthStatus",
93
+ "ServerMode",
94
+ "WSErrorCode",
95
+ # WebSocket message types
96
+ "BaseMessage",
97
+ "WSIncomingMessage",
98
+ "WSResetMessage",
99
+ "WSStepMessage",
100
+ "WSStateMessage",
101
+ "WSCloseMessage",
102
+ "WSObservationResponse",
103
+ "WSStateResponse",
104
+ "WSErrorResponse",
105
+ # Concurrency types
106
+ "ConcurrencyConfig",
107
+ "ServerCapacityStatus",
108
+ "SessionInfo",
109
+ # Exceptions
110
+ "OpenEnvError",
111
+ "ConcurrencyConfigurationError",
112
+ "SessionCapacityError",
113
+ "SessionNotFoundError",
114
+ "SessionCreationError",
115
+ "EnvironmentFactoryError",
116
  # Base transforms
117
  "CompositeTransform",
118
  "NullTransform",
 
123
  # Web Interface
124
  "create_web_interface_app",
125
  "WebInterfaceManager",
126
+ # Serialization utilities
127
+ "deserialize_action",
128
+ "deserialize_action_with_preprocessing",
129
+ "serialize_observation",
130
+ # Route configuration
131
+ "GetEndpointConfig",
132
+ # MCP types
133
+ "Tool",
134
+ "ToolError",
135
+ "ToolErrorType",
136
+ "ListToolsAction",
137
+ "CallToolAction",
138
+ "ListToolsObservation",
139
+ "CallToolObservation",
140
+ "WSMCPMessage",
141
+ "WSMCPResponse",
142
+ "RESERVED_TOOL_NAMES",
143
+ "MCPEnvironment",
144
+ # JSON-RPC types
145
+ "JsonRpcErrorCode",
146
+ "JsonRpcError",
147
+ "JsonRpcRequest",
148
+ "JsonRpcResponse",
149
+ "McpMethod",
150
  ]
src/core/env_server/base_transforms.py CHANGED
@@ -26,4 +26,4 @@ class NullTransform(Transform):
26
  """Default transform that passes through unchanged."""
27
 
28
  def __call__(self, observation: Observation) -> Observation:
29
- return observation
 
26
  """Default transform that passes through unchanged."""
27
 
28
  def __call__(self, observation: Observation) -> Observation:
29
+ return observation
src/core/env_server/exceptions.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Custom exceptions for environment server operations."""
8
+
9
+ from typing import Optional
10
+
11
+
12
+ class OpenEnvError(Exception):
13
+ """Base exception for all OpenEnv errors."""
14
+
15
+ pass
16
+
17
+
18
+ class ConcurrencyConfigurationError(OpenEnvError):
19
+ """
20
+ Raised when an environment is misconfigured for concurrent sessions.
21
+
22
+ This error is raised during server startup when max_concurrent_envs > 1
23
+ is specified for an environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ environment_name: str,
29
+ max_concurrent_envs: int,
30
+ message: Optional[str] = None,
31
+ ):
32
+ self.environment_name = environment_name
33
+ self.max_concurrent_envs = max_concurrent_envs
34
+
35
+ if message is None:
36
+ message = (
37
+ f"Environment '{environment_name}' is not marked as SUPPORTS_CONCURRENT_SESSIONS. "
38
+ f"Cannot run with max_concurrent_envs={max_concurrent_envs}. "
39
+ f"Either set max_concurrent_envs=1 or ensure the environment "
40
+ f"properly isolates session state and set SUPPORTS_CONCURRENT_SESSIONS=True."
41
+ )
42
+
43
+ super().__init__(message)
44
+
45
+
46
+ class SessionCapacityError(OpenEnvError):
47
+ """
48
+ Raised when the server cannot accept new sessions due to capacity limits.
49
+
50
+ This error is raised when a new WebSocket connection is attempted but
51
+ the server has already reached max_concurrent_envs active sessions.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ active_sessions: int,
57
+ max_sessions: int,
58
+ message: Optional[str] = None,
59
+ ):
60
+ self.active_sessions = active_sessions
61
+ self.max_sessions = max_sessions
62
+
63
+ if message is None:
64
+ message = (
65
+ f"Server at capacity: {active_sessions}/{max_sessions} sessions active. "
66
+ f"Cannot accept new connections."
67
+ )
68
+
69
+ super().__init__(message)
70
+
71
+
72
+ class SessionNotFoundError(OpenEnvError):
73
+ """Raised when attempting to access a session that does not exist."""
74
+
75
+ def __init__(self, session_id: str, message: Optional[str] = None):
76
+ self.session_id = session_id
77
+
78
+ if message is None:
79
+ message = f"Session '{session_id}' not found."
80
+
81
+ super().__init__(message)
82
+
83
+
84
+ class SessionCreationError(OpenEnvError):
85
+ """Raised when a session cannot be created."""
86
+
87
+ def __init__(self, reason: str, message: Optional[str] = None):
88
+ self.reason = reason
89
+
90
+ if message is None:
91
+ message = f"Failed to create session: {reason}"
92
+
93
+ super().__init__(message)
94
+
95
+
96
+ class EnvironmentFactoryError(OpenEnvError):
97
+ """Raised when the environment factory fails to create an instance."""
98
+
99
+ def __init__(self, factory_name: str, message: Optional[str] = None):
100
+ self.factory_name = factory_name
101
+
102
+ if message is None:
103
+ message = f"Environment factory '{factory_name}' failed to create instance."
104
+
105
+ super().__init__(message)
src/core/env_server/gradio_theme.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Unified terminal-style theme for OpenEnv Gradio UI (light/dark)."""
8
+
9
+ from __future__ import annotations
10
+
11
+ import gradio as gr
12
+
13
+ _MONO_FONTS = (
14
+ "JetBrains Mono",
15
+ "Fira Code",
16
+ "Cascadia Code",
17
+ "Consolas",
18
+ "ui-monospace",
19
+ "monospace",
20
+ )
21
+
22
+ _CORE_FONT = (
23
+ "Lato",
24
+ "Inter",
25
+ "Arial",
26
+ "Helvetica",
27
+ "sans-serif",
28
+ )
29
+
30
+ _ZERO_RADIUS = gr.themes.Size(
31
+ xxs="0px",
32
+ xs="0px",
33
+ sm="0px",
34
+ md="0px",
35
+ lg="0px",
36
+ xl="0px",
37
+ xxl="0px",
38
+ )
39
+
40
+ _GREEN_HUE = gr.themes.Color(
41
+ c50="#e6f4ea",
42
+ c100="#ceead6",
43
+ c200="#a8dab5",
44
+ c300="#6fcc8b",
45
+ c400="#3fb950",
46
+ c500="#238636",
47
+ c600="#1a7f37",
48
+ c700="#116329",
49
+ c800="#0a4620",
50
+ c900="#033a16",
51
+ c950="#04200d",
52
+ )
53
+
54
+ _NEUTRAL_HUE = gr.themes.Color(
55
+ c50="#f6f8fa",
56
+ c100="#eaeef2",
57
+ c200="#d0d7de",
58
+ c300="#afb8c1",
59
+ c400="#8c959f",
60
+ c500="#6e7781",
61
+ c600="#57606a",
62
+ c700="#424a53",
63
+ c800="#32383f",
64
+ c900="#24292f",
65
+ c950="#1b1f24",
66
+ )
67
+
68
+ OPENENV_GRADIO_THEME = gr.themes.Base(
69
+ primary_hue=_GREEN_HUE,
70
+ secondary_hue=_NEUTRAL_HUE,
71
+ neutral_hue=_NEUTRAL_HUE,
72
+ font=_CORE_FONT,
73
+ font_mono=_MONO_FONTS,
74
+ radius_size=_ZERO_RADIUS,
75
+ ).set(
76
+ body_background_fill="#ffffff",
77
+ background_fill_primary="#ffffff",
78
+ background_fill_secondary="#f6f8fa",
79
+ block_background_fill="#ffffff",
80
+ block_border_color="#ffffff",
81
+ block_label_text_color="#57606a",
82
+ block_title_text_color="#24292f",
83
+ border_color_primary="#d0d7de",
84
+ input_background_fill="#ffffff",
85
+ input_border_color="#d0d7de",
86
+ button_primary_background_fill="#1a7f37",
87
+ button_primary_background_fill_hover="#116329",
88
+ button_primary_text_color="#ffffff",
89
+ button_secondary_background_fill="#f6f8fa",
90
+ button_secondary_background_fill_hover="#eaeef2",
91
+ button_secondary_text_color="#24292f",
92
+ button_secondary_border_color="#d0d7de",
93
+ body_background_fill_dark="#0d1117",
94
+ background_fill_primary_dark="#0d1117",
95
+ background_fill_secondary_dark="#0d1117",
96
+ block_background_fill_dark="#0d1117",
97
+ block_border_color_dark="#0d1117",
98
+ block_label_text_color_dark="#8b949e",
99
+ block_title_text_color_dark="#c9d1d9",
100
+ border_color_primary_dark="#30363d",
101
+ input_background_fill_dark="#0d1117",
102
+ input_border_color_dark="#30363d",
103
+ button_primary_background_fill_dark="#30363d",
104
+ button_primary_background_fill_hover_dark="#484f58",
105
+ button_primary_text_color_dark="#c9d1d9",
106
+ button_secondary_background_fill_dark="#21262d",
107
+ button_secondary_background_fill_hover_dark="#30363d",
108
+ button_secondary_text_color_dark="#c9d1d9",
109
+ button_secondary_border_color_dark="#30363d",
110
+ )
111
+
112
+ OPENENV_GRADIO_CSS = """
113
+ * { border-radius: 0 !important; }
114
+ .col-left { padding: 16px !important; }
115
+ .col-right { padding: 16px !important; }
116
+ .prose, .markdown-text, .md,
117
+ .prose > *, .markdown-text > * {
118
+ background: transparent !important;
119
+ border: none !important;
120
+ box-shadow: none !important;
121
+ }
122
+ .dark .col-left {
123
+ border-left-color: rgba(139, 148, 158, 0.4) !important;
124
+ }
125
+ .dark .col-right {
126
+ border-left-color: rgba(201, 209, 217, 0.3) !important;
127
+ }
128
+ """
src/core/env_server/gradio_ui.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Gradio-based web UI for OpenEnv environments.
9
+
10
+ Replaces the legacy HTML/JavaScript interface when ENABLE_WEB_INTERFACE is set.
11
+ Mount at /web via gr.mount_gradio_app() from create_web_interface_app().
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import re
18
+ from typing import Any, Dict, List, Optional
19
+
20
+ import gradio as gr
21
+
22
+ from .types import EnvironmentMetadata
23
+
24
+
25
+ def _escape_md(text: str) -> str:
26
+ """Escape Markdown special characters in user-controlled content."""
27
+ return re.sub(r"([\\`*_\{\}\[\]()#+\-.!|~>])", r"\\\1", str(text))
28
+
29
+
30
+ def _format_observation(data: Dict[str, Any]) -> str:
31
+ """Format reset/step response for Markdown display."""
32
+ lines: List[str] = []
33
+ obs = data.get("observation", {})
34
+ if isinstance(obs, dict):
35
+ if obs.get("prompt"):
36
+ lines.append(f"**Prompt:**\n\n{_escape_md(obs['prompt'])}\n")
37
+ messages = obs.get("messages", [])
38
+ if messages:
39
+ lines.append("**Messages:**\n")
40
+ for msg in messages:
41
+ sender = _escape_md(str(msg.get("sender_id", "?")))
42
+ content = _escape_md(str(msg.get("content", "")))
43
+ cat = _escape_md(str(msg.get("category", "")))
44
+ lines.append(f"- `[{cat}]` Player {sender}: {content}")
45
+ lines.append("")
46
+ reward = data.get("reward")
47
+ done = data.get("done")
48
+ if reward is not None:
49
+ lines.append(f"**Reward:** `{reward}`")
50
+ if done is not None:
51
+ lines.append(f"**Done:** `{done}`")
52
+ return "\n".join(lines) if lines else "*No observation data*"
53
+
54
+
55
+ def _readme_section(metadata: Optional[EnvironmentMetadata]) -> str:
56
+ """README content for the left panel."""
57
+ if not metadata or not metadata.readme_content:
58
+ return "*No README available.*"
59
+ return metadata.readme_content
60
+
61
+
62
+ def get_gradio_display_title(
63
+ metadata: Optional[EnvironmentMetadata],
64
+ fallback: str = "OpenEnv Environment",
65
+ ) -> str:
66
+ """Return the title used for the Gradio app (browser tab and Blocks)."""
67
+ name = metadata.name if metadata else fallback
68
+ return f"OpenEnv Agentic Environment: {name}"
69
+
70
+
71
+ def build_gradio_app(
72
+ web_manager: Any,
73
+ action_fields: List[Dict[str, Any]],
74
+ metadata: Optional[EnvironmentMetadata],
75
+ is_chat_env: bool,
76
+ title: str = "OpenEnv Environment",
77
+ quick_start_md: Optional[str] = None,
78
+ ) -> gr.Blocks:
79
+ """
80
+ Build a Gradio Blocks app for the OpenEnv web interface.
81
+
82
+ Args:
83
+ web_manager: WebInterfaceManager (reset/step_environment, get_state).
84
+ action_fields: Field dicts from _extract_action_fields(action_cls).
85
+ metadata: Environment metadata for README/name.
86
+ is_chat_env: If True, single message textbox; else form from action_fields.
87
+ title: App title (overridden by metadata.name when present; see get_gradio_display_title).
88
+ quick_start_md: Optional Quick Start markdown (class names already replaced).
89
+
90
+ Returns:
91
+ gr.Blocks to mount with gr.mount_gradio_app(app, blocks, path="/web").
92
+ """
93
+ readme_content = _readme_section(metadata)
94
+ display_title = get_gradio_display_title(metadata, fallback=title)
95
+
96
+ async def reset_env():
97
+ try:
98
+ data = await web_manager.reset_environment()
99
+ obs_md = _format_observation(data)
100
+ return (
101
+ obs_md,
102
+ json.dumps(data, indent=2),
103
+ "Environment reset successfully.",
104
+ )
105
+ except Exception as e:
106
+ return ("", "", f"Error: {e}")
107
+
108
+ def _step_with_action(action_data: Dict[str, Any]):
109
+ async def _run():
110
+ try:
111
+ data = await web_manager.step_environment(action_data)
112
+ obs_md = _format_observation(data)
113
+ return (
114
+ obs_md,
115
+ json.dumps(data, indent=2),
116
+ "Step complete.",
117
+ )
118
+ except Exception as e:
119
+ return ("", "", f"Error: {e}")
120
+
121
+ return _run
122
+
123
+ async def step_chat(message: str):
124
+ if not (message or str(message).strip()):
125
+ return ("", "", "Please enter an action message.")
126
+ action = {"message": str(message).strip()}
127
+ return await _step_with_action(action)()
128
+
129
+ def get_state_sync():
130
+ try:
131
+ data = web_manager.get_state()
132
+ return json.dumps(data, indent=2)
133
+ except Exception as e:
134
+ return f"Error: {e}"
135
+
136
+ with gr.Blocks(title=display_title) as demo:
137
+ with gr.Row():
138
+ with gr.Column(scale=1, elem_classes="col-left"):
139
+ if quick_start_md:
140
+ with gr.Accordion("Quick Start", open=True):
141
+ gr.Markdown(quick_start_md)
142
+ with gr.Accordion("README", open=False):
143
+ gr.Markdown(readme_content)
144
+
145
+ with gr.Column(scale=2, elem_classes="col-right"):
146
+ obs_display = gr.Markdown(
147
+ value=("# Playground\n\nClick **Reset** to start a new episode."),
148
+ )
149
+ with gr.Group():
150
+ if is_chat_env:
151
+ action_input = gr.Textbox(
152
+ label="Action message",
153
+ placeholder="e.g. Enter your message...",
154
+ )
155
+ step_inputs = [action_input]
156
+ step_fn = step_chat
157
+ else:
158
+ step_inputs = []
159
+ for field in action_fields:
160
+ name = field["name"]
161
+ field_type = field.get("type", "text")
162
+ label = name.replace("_", " ").title()
163
+ placeholder = field.get("placeholder", "")
164
+ if field_type == "checkbox":
165
+ inp = gr.Checkbox(label=label)
166
+ elif field_type == "number":
167
+ inp = gr.Number(label=label)
168
+ elif field_type == "select":
169
+ choices = field.get("choices") or []
170
+ inp = gr.Dropdown(
171
+ choices=choices,
172
+ label=label,
173
+ allow_custom_value=False,
174
+ )
175
+ elif field_type in ("textarea", "tensor"):
176
+ inp = gr.Textbox(
177
+ label=label,
178
+ placeholder=placeholder,
179
+ lines=3,
180
+ )
181
+ else:
182
+ inp = gr.Textbox(
183
+ label=label,
184
+ placeholder=placeholder,
185
+ )
186
+ step_inputs.append(inp)
187
+
188
+ async def step_form(*values):
189
+ if not action_fields:
190
+ return await _step_with_action({})()
191
+ action_data = {}
192
+ for i, field in enumerate(action_fields):
193
+ if i >= len(values):
194
+ break
195
+ name = field["name"]
196
+ val = values[i]
197
+ if field.get("type") == "checkbox":
198
+ action_data[name] = bool(val)
199
+ elif val is not None and val != "":
200
+ action_data[name] = val
201
+ return await _step_with_action(action_data)()
202
+
203
+ step_fn = step_form
204
+
205
+ with gr.Row():
206
+ step_btn = gr.Button("Step", variant="primary")
207
+ reset_btn = gr.Button("Reset", variant="secondary")
208
+ state_btn = gr.Button("Get state", variant="secondary")
209
+ with gr.Row():
210
+ status = gr.Textbox(
211
+ label="Status",
212
+ interactive=False,
213
+ )
214
+ raw_json = gr.Code(
215
+ label="Raw JSON response",
216
+ language="json",
217
+ interactive=False,
218
+ )
219
+
220
+ reset_btn.click(
221
+ fn=reset_env,
222
+ outputs=[obs_display, raw_json, status],
223
+ )
224
+ step_btn.click(
225
+ fn=step_fn,
226
+ inputs=step_inputs,
227
+ outputs=[obs_display, raw_json, status],
228
+ )
229
+ if is_chat_env:
230
+ action_input.submit(
231
+ fn=step_fn,
232
+ inputs=step_inputs,
233
+ outputs=[obs_display, raw_json, status],
234
+ )
235
+ state_btn.click(
236
+ fn=get_state_sync,
237
+ outputs=[raw_json],
238
+ )
239
+
240
+ return demo
src/core/env_server/http_server.py CHANGED
@@ -8,25 +8,113 @@
8
  HTTP server wrapper for Environment instances.
9
 
10
  This module provides utilities to wrap any Environment subclass and expose it
11
- over HTTP endpoints that HTTPEnvClient can consume.
12
  """
13
 
14
  from __future__ import annotations
15
 
 
 
 
16
  import os
17
- from dataclasses import asdict
18
- from typing import Any, Dict, Type
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  from .interfaces import Environment
21
- from .types import Action, Observation
22
- from fastapi import Body, FastAPI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  class HTTPEnvServer:
25
  """
26
  HTTP server wrapper for Environment instances.
27
 
28
  This class wraps an Environment and exposes its reset(), step(), and state
29
- methods as HTTP endpoints compatible with HTTPEnvClient.
30
 
31
  The server expects:
32
  - Action deserialization: Converts JSON dict to Action subclass
@@ -35,9 +123,15 @@ class HTTPEnvServer:
35
  Example:
36
  >>> from core.env_server import HTTPEnvServer
37
  >>> from envs.coding_env.server import CodeExecutionEnvironment
 
38
  >>>
39
- >>> env = CodeExecutionEnvironment()
40
- >>> server = HTTPEnvServer(env)
 
 
 
 
 
41
  >>>
42
  >>> # Register routes with FastAPI
43
  >>> from fastapi import FastAPI
@@ -47,178 +141,1177 @@ class HTTPEnvServer:
47
 
48
  def __init__(
49
  self,
50
- env: Environment,
51
  action_cls: Type[Action],
52
  observation_cls: Type[Observation],
 
 
53
  ):
54
  """
55
  Initialize HTTP server wrapper.
56
 
57
  Args:
58
- env: The Environment instance to wrap
 
59
  action_cls: The Action subclass this environment expects
60
  observation_cls: The Observation subclass this environment returns
 
 
 
 
 
 
 
 
 
61
  """
62
- self.env = env
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  self.action_cls = action_cls
64
  self.observation_cls = observation_cls
65
 
66
- def register_routes(self, app: Any) -> None:
 
 
 
 
 
 
 
 
 
 
67
  """
68
- Register HTTP routes on a FastAPI application.
69
 
70
- Args:
71
- app: FastAPI application instance
 
72
  """
 
 
73
 
74
- if not isinstance(app, FastAPI):
75
- raise TypeError("app must be a FastAPI instance")
 
 
 
 
 
76
 
77
- @app.post("/reset")
78
- async def reset(request: Dict[str, Any] = Body(default={})) -> Dict[str, Any]:
79
- """Reset endpoint - returns initial observation."""
80
- # TODO: Handle seed, episode_id from request if provided
81
- observation = self.env.reset()
82
- return self._serialize_observation(observation)
83
 
84
- @app.post("/step")
85
- async def step(request: Dict[str, Any]) -> Dict[str, Any]:
86
- """Step endpoint - executes action and returns observation."""
87
- action_data = request.get("action", {})
88
- # TODO: Handle timeout_s, request_id, episode_id from request if provided
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- # Deserialize action
91
- action = self._deserialize_action(action_data)
 
 
 
 
 
 
 
 
92
 
93
- # Execute step
94
- observation = self.env.step(action)
95
 
96
- # Return serialized observation
97
- return self._serialize_observation(observation)
 
 
 
98
 
99
- @app.get("/state")
100
- async def get_state() -> Dict[str, Any]:
101
- """State endpoint - returns current environment state."""
102
- state = self.env.state
103
- return asdict(state)
 
 
 
 
 
 
 
 
104
 
105
- @app.get("/health")
106
- async def health() -> Dict[str, str]:
107
- """Health check endpoint."""
108
- return {"status": "healthy"}
 
 
 
 
 
109
 
 
110
 
111
- def _deserialize_action(self, action_data: Dict[str, Any]) -> Action:
112
  """
113
- Convert JSON dict to Action instance.
114
 
115
  Args:
116
- action_data: Dictionary containing action data
 
 
 
 
 
117
 
118
- Returns:
119
- Action instance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- Note:
122
- This is a simple implementation. Subclasses may need to override
123
- for more complex deserialization logic.
124
  """
125
- # Remove metadata if present (it will be set via kw_only field)
126
- metadata = action_data.pop("metadata", {})
127
- action = self.action_cls(**action_data)
128
- action.metadata = metadata
129
- return action
130
 
131
- def _serialize_observation(self, observation: Observation) -> Dict[str, Any]:
132
  """
133
- Convert Observation instance to JSON-compatible dict.
134
 
135
  Args:
136
- observation: Observation instance
137
 
138
  Returns:
139
- Dictionary compatible with HTTPEnvClient._parse_result()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- The format matches what HTTPEnvClient expects:
142
- {
143
- "observation": {...}, # Observation fields
144
- "reward": float | None,
145
- "done": bool,
146
- }
 
 
 
 
 
 
 
 
 
 
 
147
  """
148
- obs_dict = asdict(observation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
- # Extract reward and done (these are part of StepResult on client side)
151
- reward = obs_dict.pop("reward", None)
152
- done = obs_dict.pop("done", False)
153
- obs_dict.pop("metadata", None) # Remove metadata from observation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
- # Return in HTTPEnvClient expected format
156
- return {
157
- "observation": obs_dict,
158
- "reward": reward,
159
- "done": done,
160
- }
161
 
162
  def create_app(
163
- env: Environment,
164
  action_cls: Type[Action],
165
  observation_cls: Type[Observation],
166
  env_name: Optional[str] = None,
167
- ) -> Any:
 
 
 
168
  """
169
  Create a FastAPI application with or without web interface.
170
-
171
  This function creates a FastAPI app with the web interface enabled by default,
172
  including README integration for better user experience.
173
-
174
  Args:
175
- env: The Environment instance to serve
176
  action_cls: The Action subclass this environment expects
177
  observation_cls: The Observation subclass this environment returns
178
  env_name: Optional environment name for README loading
179
-
 
 
 
 
 
 
 
 
180
  Returns:
181
  FastAPI application instance with or without web interface and README integration
182
  """
183
  # Check if web interface should be enabled
184
  # This can be controlled via environment variable or build argument
185
- enable_web = (
186
- os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ("true", "1", "yes")
 
 
187
  )
188
 
189
  if enable_web:
190
- # Import web interface only when needed
191
  from .web_interface import create_web_interface_app
192
- return create_web_interface_app(env, action_cls, observation_cls, env_name)
 
 
 
 
 
 
 
 
 
193
  else:
194
  # Use standard FastAPI app without web interface
195
- return create_fastapi_app(env, action_cls, observation_cls)
196
-
 
 
197
 
198
  def create_fastapi_app(
199
- env: Environment,
200
  action_cls: Type[Action],
201
  observation_cls: Type[Observation],
202
- ) -> Any:
 
 
203
  """
204
- Create a FastAPI application with routes for the given environment.
205
 
206
  Args:
207
- env: The Environment instance to serve
208
  action_cls: The Action subclass this environment expects
209
  observation_cls: The Observation subclass this environment returns
 
 
 
 
210
 
211
  Returns:
212
- FastAPI application instance with routes registered
213
-
214
- Example:
215
- >>> from envs.coding_env.server import CodeExecutionEnvironment
216
- >>> from envs.coding_env.models import CodeAction, CodeObservation
217
- >>>
218
- >>> env = CodeExecutionEnvironment()
219
- >>> app = create_fastapi_app(env, CodeAction, CodeObservation)
220
- >>>
221
- >>> # Run with: uvicorn module:app --host 0.0.0.0 --port 8000
222
  """
223
  try:
224
  from fastapi import FastAPI
@@ -227,7 +1320,72 @@ def create_fastapi_app(
227
  "FastAPI is required. Install with: pip install fastapi uvicorn"
228
  )
229
 
230
- app = FastAPI(title="Environment HTTP Server")
231
- server = HTTPEnvServer(env, action_cls, observation_cls)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  server.register_routes(app)
233
  return app
 
8
  HTTP server wrapper for Environment instances.
9
 
10
  This module provides utilities to wrap any Environment subclass and expose it
11
+ over HTTP and WebSocket endpoints that EnvClient can consume.
12
  """
13
 
14
  from __future__ import annotations
15
 
16
+ import asyncio
17
+ import inspect
18
+ import json
19
  import os
20
+ import time
21
+ import uuid
22
+ from concurrent.futures import ThreadPoolExecutor
23
+ from typing import Any, Callable, Dict, Optional, Type
24
+
25
+ from fastapi import (
26
+ Body,
27
+ FastAPI,
28
+ HTTPException,
29
+ Request,
30
+ status,
31
+ WebSocket,
32
+ WebSocketDisconnect,
33
+ )
34
+ from pydantic import ValidationError
35
 
36
  from .interfaces import Environment
37
+ from .mcp_environment import get_server_tools
38
+ from .mcp_types import (
39
+ JsonRpcErrorCode,
40
+ JsonRpcRequest,
41
+ JsonRpcResponse,
42
+ McpMethod,
43
+ WSMCPMessage,
44
+ WSMCPResponse,
45
+ )
46
+ from .route_config import GetEndpointConfig, register_get_endpoints
47
+ from .serialization import deserialize_action, serialize_observation
48
+ from .types import (
49
+ Action,
50
+ ConcurrencyConfig,
51
+ EnvironmentMetadata,
52
+ HealthResponse,
53
+ HealthStatus,
54
+ Observation,
55
+ ResetRequest,
56
+ ResetResponse,
57
+ SchemaResponse,
58
+ ServerCapacityStatus,
59
+ ServerMode,
60
+ SessionInfo,
61
+ State,
62
+ StepRequest,
63
+ StepResponse,
64
+ WSCloseMessage,
65
+ WSErrorCode,
66
+ WSErrorResponse,
67
+ WSObservationResponse,
68
+ WSResetMessage,
69
+ WSStateMessage,
70
+ WSStateResponse,
71
+ WSStepMessage,
72
+ )
73
+
74
+
75
+ def _make_json_serializable(obj: Any) -> Any:
76
+ """
77
+ Convert an object to a JSON-serializable form.
78
+
79
+ Handles Pydantic models, dataclasses, and other common types.
80
+
81
+ Args:
82
+ obj: The object to convert
83
+
84
+ Returns:
85
+ A JSON-serializable representation of the object
86
+ """
87
+ if obj is None:
88
+ return None
89
+ if isinstance(obj, (str, int, float, bool)):
90
+ return obj
91
+ if isinstance(obj, (list, tuple)):
92
+ return [_make_json_serializable(item) for item in obj]
93
+ if isinstance(obj, dict):
94
+ return {k: _make_json_serializable(v) for k, v in obj.items()}
95
+ if hasattr(obj, "model_dump"):
96
+ # Pydantic model
97
+ return obj.model_dump()
98
+ if hasattr(obj, "__dict__"):
99
+ # Object with __dict__
100
+ return {k: _make_json_serializable(v) for k, v in obj.__dict__.items()}
101
+ # Fallback to string representation
102
+ return str(obj)
103
+
104
+
105
+ from .exceptions import (
106
+ ConcurrencyConfigurationError,
107
+ EnvironmentFactoryError,
108
+ SessionCapacityError,
109
+ )
110
+
111
 
112
  class HTTPEnvServer:
113
  """
114
  HTTP server wrapper for Environment instances.
115
 
116
  This class wraps an Environment and exposes its reset(), step(), and state
117
+ methods as HTTP and WebSocket endpoints compatible with EnvClient.
118
 
119
  The server expects:
120
  - Action deserialization: Converts JSON dict to Action subclass
 
123
  Example:
124
  >>> from core.env_server import HTTPEnvServer
125
  >>> from envs.coding_env.server import CodeExecutionEnvironment
126
+ >>> from envs.coding_env.models import CodeAction, CodeObservation
127
  >>>
128
+ >>> # Pass environment class (factory pattern)
129
+ >>> server = HTTPEnvServer(
130
+ ... env=CodeExecutionEnvironment,
131
+ ... action_cls=CodeAction,
132
+ ... observation_cls=CodeObservation,
133
+ ... max_concurrent_envs=4,
134
+ ... )
135
  >>>
136
  >>> # Register routes with FastAPI
137
  >>> from fastapi import FastAPI
 
141
 
142
  def __init__(
143
  self,
144
+ env: Callable[[], Environment],
145
  action_cls: Type[Action],
146
  observation_cls: Type[Observation],
147
+ max_concurrent_envs: Optional[int] = None,
148
+ concurrency_config: Optional[ConcurrencyConfig] = None,
149
  ):
150
  """
151
  Initialize HTTP server wrapper.
152
 
153
  Args:
154
+ env: Environment factory (callable) that creates new instances.
155
+ Will be called to create a new environment for each WebSocket session.
156
  action_cls: The Action subclass this environment expects
157
  observation_cls: The Observation subclass this environment returns
158
+ max_concurrent_envs: Maximum number of concurrent WebSocket sessions.
159
+ Mutually exclusive with concurrency_config.
160
+ concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings.
161
+ Mutually exclusive with max_concurrent_envs.
162
+
163
+ Raises:
164
+ ValueError: If both max_concurrent_envs and concurrency_config are provided.
165
+ ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an
166
+ environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS.
167
  """
168
+ # Validate that env is callable
169
+ if not callable(env):
170
+ raise TypeError(
171
+ f"env must be a callable (class or factory function), got {type(env)}. "
172
+ f"Pass the environment class (e.g., MyEnvironment) not an instance (e.g., MyEnvironment())."
173
+ )
174
+
175
+ self._env_factory: Callable[[], Environment] = env
176
+
177
+ # Handle concurrency configuration
178
+ if max_concurrent_envs is not None and concurrency_config is not None:
179
+ raise ValueError(
180
+ "Cannot specify both 'max_concurrent_envs' and 'concurrency_config'. "
181
+ "Please use only one method to configure concurrency."
182
+ )
183
+
184
+ if concurrency_config is not None:
185
+ self._concurrency_config = concurrency_config
186
+ elif max_concurrent_envs is not None:
187
+ self._concurrency_config = ConcurrencyConfig(
188
+ max_concurrent_envs=max_concurrent_envs,
189
+ session_timeout=None,
190
+ )
191
+ else:
192
+ # Default configuration
193
+ self._concurrency_config = ConcurrencyConfig(
194
+ max_concurrent_envs=1,
195
+ session_timeout=None,
196
+ )
197
+
198
+ self._max_concurrent_envs = self._concurrency_config.max_concurrent_envs
199
+
200
+ # Validate concurrency configuration
201
+ self._validate_concurrency_safety()
202
+
203
  self.action_cls = action_cls
204
  self.observation_cls = observation_cls
205
 
206
+ # Session management for WebSocket connections
207
+ self._sessions: Dict[str, Environment] = {}
208
+ self._session_executors: Dict[str, ThreadPoolExecutor] = {}
209
+ self._session_info: Dict[str, SessionInfo] = {}
210
+ self._session_lock = asyncio.Lock()
211
+
212
+ # Create thread pool for running sync code in async context
213
+ # This is needed for environments using sync libraries (e.g., Playwright)
214
+ self._executor = ThreadPoolExecutor(max_workers=32)
215
+
216
+ def _validate_concurrency_safety(self) -> None:
217
  """
218
+ Validate that the environment supports the configured concurrency level.
219
 
220
+ Raises:
221
+ ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an
222
+ environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS.
223
  """
224
+ if self._max_concurrent_envs <= 1:
225
+ return
226
 
227
+ if inspect.isclass(self._env_factory):
228
+ env_cls = self._env_factory
229
+ else:
230
+ _temp_env = self._env_factory()
231
+ env_cls = type(_temp_env)
232
+ _temp_env.close()
233
+ del _temp_env
234
 
235
+ if not getattr(env_cls, "SUPPORTS_CONCURRENT_SESSIONS", False):
236
+ raise ConcurrencyConfigurationError(
237
+ environment_name=env_cls.__name__,
238
+ max_concurrent_envs=self._max_concurrent_envs,
239
+ )
 
240
 
241
+ def get_capacity_status(self) -> ServerCapacityStatus:
242
+ """
243
+ Get the current capacity status of the server.
244
+
245
+ Returns:
246
+ ServerCapacityStatus with current session counts and availability.
247
+ """
248
+ return ServerCapacityStatus.from_counts(
249
+ active=len(self._sessions),
250
+ max_sessions=self._max_concurrent_envs,
251
+ )
252
+
253
+ async def _run_sync_in_thread_pool(
254
+ self, func: Callable[..., Observation], *args, **kwargs
255
+ ) -> Observation:
256
+ """Run a synchronous function in the thread pool executor."""
257
+ loop = asyncio.get_event_loop()
258
+ return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs))
259
+
260
+ def _get_valid_kwargs(
261
+ self,
262
+ sig: inspect.Signature,
263
+ kwargs: Dict[str, Any],
264
+ skip_params: Optional[set[str]] = None,
265
+ ) -> Dict[str, Any]:
266
+ """Filter kwargs to only include parameters accepted by the function signature."""
267
+ if skip_params is None:
268
+ skip_params = set()
269
+
270
+ valid_kwargs = {}
271
+
272
+ has_kwargs = any(
273
+ p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
274
+ )
275
+
276
+ for k, v in kwargs.items():
277
+ if k in sig.parameters or has_kwargs:
278
+ if k not in skip_params:
279
+ valid_kwargs[k] = v
280
+
281
+ return valid_kwargs
282
+
283
+ async def _create_session(self) -> tuple[str, Environment]:
284
+ """
285
+ Create a new WebSocket session with its own environment instance.
286
+
287
+ Returns:
288
+ Tuple of (session_id, environment)
289
 
290
+ Raises:
291
+ SessionCapacityError: If max concurrent sessions reached
292
+ EnvironmentFactoryError: If the factory fails to create an environment
293
+ """
294
+ async with self._session_lock:
295
+ if len(self._sessions) >= self._max_concurrent_envs:
296
+ raise SessionCapacityError(
297
+ active_sessions=len(self._sessions),
298
+ max_sessions=self._max_concurrent_envs,
299
+ )
300
 
301
+ session_id = str(uuid.uuid4())
302
+ current_time = time.time()
303
 
304
+ # Create executor and reserve slot so capacity is not exceeded while
305
+ # we create the env outside the lock (avoids blocking other sessions)
306
+ executor = ThreadPoolExecutor(max_workers=1)
307
+ self._session_executors[session_id] = executor
308
+ self._sessions[session_id] = None # placeholder until env is ready
309
 
310
+ try:
311
+ # Create environment in the executor thread (outside lock)
312
+ loop = asyncio.get_event_loop()
313
+ env = await loop.run_in_executor(executor, self._env_factory)
314
+ except Exception as e:
315
+ async with self._session_lock:
316
+ executor.shutdown(wait=False)
317
+ self._session_executors.pop(session_id, None)
318
+ self._sessions.pop(session_id, None)
319
+ factory_name = getattr(
320
+ self._env_factory, "__name__", str(self._env_factory)
321
+ )
322
+ raise EnvironmentFactoryError(factory_name) from e
323
 
324
+ async with self._session_lock:
325
+ self._sessions[session_id] = env
326
+ self._session_info[session_id] = SessionInfo(
327
+ session_id=session_id,
328
+ created_at=current_time,
329
+ last_activity_at=current_time,
330
+ step_count=0,
331
+ environment_type=type(env).__name__,
332
+ )
333
 
334
+ return session_id, env
335
 
336
+ async def _destroy_session(self, session_id: str) -> None:
337
  """
338
+ Destroy a WebSocket session and cleanup resources.
339
 
340
  Args:
341
+ session_id: The session ID to destroy
342
+ """
343
+ async with self._session_lock:
344
+ env = self._sessions.pop(session_id, None)
345
+ executor = self._session_executors.pop(session_id, None)
346
+ self._session_info.pop(session_id, None)
347
 
348
+ # Run close() in the same executor where the env was created
349
+ # This is required for thread-sensitive libraries like Playwright/greenlet
350
+ if env is not None:
351
+ if executor is not None:
352
+ try:
353
+ loop = asyncio.get_event_loop()
354
+ await loop.run_in_executor(executor, env.close)
355
+ except Exception:
356
+ # If executor close fails, try direct close as fallback
357
+ try:
358
+ env.close()
359
+ except Exception:
360
+ pass # Best effort cleanup
361
+ else:
362
+ try:
363
+ env.close()
364
+ except Exception:
365
+ pass # Best effort cleanup
366
+
367
+ # Shutdown executor after close is done
368
+ if executor is not None:
369
+ executor.shutdown(wait=False)
370
+
371
+ def _update_session_activity(
372
+ self, session_id: str, increment_step: bool = False
373
+ ) -> None:
374
+ """
375
+ Update session activity timestamp and optionally increment step count.
376
 
377
+ Args:
378
+ session_id: The session ID to update
379
+ increment_step: If True, increment the step count
380
  """
381
+ if session_id in self._session_info:
382
+ self._session_info[session_id].last_activity_at = time.time()
383
+ if increment_step:
384
+ self._session_info[session_id].step_count += 1
 
385
 
386
+ def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
387
  """
388
+ Get information about a specific session.
389
 
390
  Args:
391
+ session_id: The session ID to query
392
 
393
  Returns:
394
+ SessionInfo if the session exists, None otherwise
395
+ """
396
+ return self._session_info.get(session_id)
397
+
398
+ async def _run_in_session_executor(
399
+ self, session_id: str, func: Callable[..., Observation], *args, **kwargs
400
+ ) -> Observation:
401
+ """Run a synchronous function in the session's thread pool executor."""
402
+ executor = self._session_executors.get(session_id, self._executor)
403
+ loop = asyncio.get_event_loop()
404
+ return await loop.run_in_executor(executor, lambda: func(*args, **kwargs))
405
+
406
+ @property
407
+ def active_sessions(self) -> int:
408
+ """Return the number of active WebSocket sessions."""
409
+ return len(self._sessions)
410
+
411
+ @property
412
+ def max_concurrent_envs(self) -> int:
413
+ """Return the maximum number of concurrent environments."""
414
+ return self._max_concurrent_envs
415
+
416
+ @property
417
+ def is_concurrency_safe(self) -> bool:
418
+ """Return whether the environment is marked as concurrency safe."""
419
+ import inspect
420
 
421
+ if inspect.isclass(self._env_factory):
422
+ return getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False)
423
+ else:
424
+ _temp_env = self._env_factory()
425
+ result = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False)
426
+ _temp_env.close()
427
+ del _temp_env
428
+ return result
429
+
430
+ @property
431
+ def concurrency_config(self) -> ConcurrencyConfig:
432
+ """Return the concurrency configuration."""
433
+ return self._concurrency_config
434
+
435
+ def register_routes(
436
+ self, app: FastAPI, mode: ServerMode | str = ServerMode.SIMULATION
437
+ ) -> None:
438
  """
439
+ Register HTTP routes on a FastAPI application.
440
+
441
+ Args:
442
+ app: FastAPI application instance
443
+ mode: Server mode - either SIMULATION or PRODUCTION (or string equivalents).
444
+ In production mode, simulation control endpoints (/reset, /step, /state)
445
+ are NOT registered. Only safe endpoints (/health, /schema, /metadata, /ws)
446
+ are available. Defaults to SIMULATION for backwards compatibility.
447
+
448
+ Raises:
449
+ ValueError: If mode is not a valid ServerMode or string equivalent.
450
+ """
451
+ # Convert string to ServerMode enum for backwards compatibility
452
+ if isinstance(mode, str):
453
+ try:
454
+ mode = ServerMode(mode.lower())
455
+ except ValueError:
456
+ valid_modes = [m.value for m in ServerMode]
457
+ raise ValueError(
458
+ f"Invalid mode: '{mode}'. Must be one of: {valid_modes}"
459
+ )
460
+
461
+ # Helper function to handle reset endpoint
462
+ async def reset_handler(
463
+ request: ResetRequest = Body(default_factory=ResetRequest),
464
+ ) -> ResetResponse:
465
+ """Reset endpoint - returns initial observation."""
466
+ _env = self._env_factory()
467
+
468
+ try:
469
+ kwargs = request.model_dump(exclude_unset=True)
470
+
471
+ is_async = _env.reset_async.__func__ is not Environment.reset_async
472
+
473
+ if is_async:
474
+ sig = inspect.signature(_env.reset_async)
475
+ else:
476
+ sig = inspect.signature(_env.reset)
477
+ valid_kwargs = self._get_valid_kwargs(sig, kwargs)
478
+
479
+ if is_async:
480
+ observation = await _env.reset_async(**valid_kwargs)
481
+ else:
482
+ observation = await self._run_sync_in_thread_pool(
483
+ _env.reset, **valid_kwargs
484
+ )
485
+ return ResetResponse(**serialize_observation(observation))
486
+ finally:
487
+ _env.close()
488
+
489
+ # Helper function to handle step endpoint
490
+ async def step_handler(request: StepRequest) -> StepResponse:
491
+ """Step endpoint - executes action and returns observation."""
492
+ action_data = request.action
493
+
494
+ try:
495
+ action = deserialize_action(action_data, self.action_cls)
496
+ except ValidationError as e:
497
+ raise HTTPException(
498
+ status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors()
499
+ )
500
+
501
+ _env = self._env_factory()
502
+
503
+ try:
504
+ kwargs = request.model_dump(exclude_unset=True, exclude={"action"})
505
+
506
+ is_async = _env.step_async.__func__ is not Environment.step_async
507
+
508
+ if is_async:
509
+ sig = inspect.signature(_env.step_async)
510
+ else:
511
+ sig = inspect.signature(_env.step)
512
+ valid_kwargs = self._get_valid_kwargs(
513
+ sig, kwargs, skip_params={"action"}
514
+ )
515
+
516
+ if is_async:
517
+ observation = await _env.step_async(action, **valid_kwargs)
518
+ else:
519
+ observation = await self._run_sync_in_thread_pool(
520
+ _env.step, action, **valid_kwargs
521
+ )
522
+
523
+ return StepResponse(**serialize_observation(observation))
524
+ finally:
525
+ _env.close()
526
+
527
+ # Helper function to handle MCP endpoint
528
+ async def mcp_handler(
529
+ request: JsonRpcRequest, session_env: Optional[Environment] = None
530
+ ) -> JsonRpcResponse:
531
+ """
532
+ Handle MCP JSON-RPC requests.
533
+
534
+ Supports tools/list and tools/call methods in JSON-RPC 2.0 format.
535
+ """
536
+ method = request.method
537
+ request_id = request.id
538
+
539
+ # Use provided session environment or create temporary one
540
+ if session_env is not None:
541
+ _env = session_env
542
+ should_close = False
543
+ else:
544
+ _env = self._env_factory()
545
+ should_close = True
546
+ try:
547
+ if method == McpMethod.TOOLS_LIST:
548
+ # Check if environment is MCP-enabled
549
+ if not hasattr(_env, "mcp_client"):
550
+ return JsonRpcResponse.error_response(
551
+ JsonRpcErrorCode.INTERNAL_ERROR,
552
+ "Environment does not support MCP",
553
+ request_id=request_id,
554
+ )
555
+
556
+ # Use async context manager for MCP client
557
+ async with _env.mcp_client:
558
+ tools = await _env.mcp_client.list_tools()
559
+
560
+ return JsonRpcResponse.success(
561
+ result={
562
+ "tools": [
563
+ t.model_dump() if hasattr(t, "model_dump") else dict(t)
564
+ for t in tools
565
+ ]
566
+ },
567
+ request_id=request_id,
568
+ )
569
+
570
+ elif method == McpMethod.TOOLS_CALL:
571
+ params = request.params
572
+ tool_name = params.get("name")
573
+ arguments = params.get("arguments", {})
574
+
575
+ if not hasattr(_env, "mcp_client"):
576
+ return JsonRpcResponse.error_response(
577
+ JsonRpcErrorCode.INTERNAL_ERROR,
578
+ "Environment does not support MCP",
579
+ request_id=request_id,
580
+ )
581
+
582
+ if not tool_name:
583
+ return JsonRpcResponse.error_response(
584
+ JsonRpcErrorCode.INVALID_REQUEST,
585
+ "Missing 'name' in params",
586
+ request_id=request_id,
587
+ )
588
+
589
+ # Use async context manager for MCP client
590
+ async with _env.mcp_client:
591
+ result = await _env.mcp_client.call_tool(
592
+ name=tool_name, arguments=arguments
593
+ )
594
+
595
+ # Ensure result is JSON serializable
596
+ serializable_result = _make_json_serializable(result)
597
 
598
+ return JsonRpcResponse.success(
599
+ result=serializable_result,
600
+ request_id=request_id,
601
+ )
602
+
603
+ else:
604
+ return JsonRpcResponse.error_response(
605
+ JsonRpcErrorCode.METHOD_NOT_FOUND,
606
+ f"Method not found: {method}",
607
+ request_id=request_id,
608
+ )
609
+
610
+ except Exception as e:
611
+ return JsonRpcResponse.error_response(
612
+ JsonRpcErrorCode.INTERNAL_ERROR,
613
+ str(e),
614
+ request_id=request_id,
615
+ )
616
+ finally:
617
+ if should_close:
618
+ _env.close()
619
+
620
+ # Register MCP WebSocket endpoint (available in both production and simulation modes)
621
+ @app.websocket("/mcp")
622
+ async def mcp_websocket_endpoint(websocket: WebSocket):
623
+ """
624
+ WebSocket endpoint for MCP JSON-RPC requests.
625
+
626
+ Each WebSocket connection gets its own environment instance for MCP operations.
627
+
628
+ Message Protocol:
629
+ - Client sends: JSON-RPC 2.0 request (tools/list, tools/call)
630
+ - Server responds: JSON-RPC 2.0 response (result or error)
631
+ """
632
+ await websocket.accept()
633
+
634
+ session_id = None
635
+ session_env = None
636
+
637
+ try:
638
+ # Create session with dedicated environment
639
+ session_id, session_env = await self._create_session()
640
+
641
+ while True:
642
+ # Receive message from client
643
+ raw_message = await websocket.receive_text()
644
+
645
+ try:
646
+ jsonrpc_dict = json.loads(raw_message)
647
+ jsonrpc_request = JsonRpcRequest(**jsonrpc_dict)
648
+ except json.JSONDecodeError as e:
649
+ error_resp = JsonRpcResponse.error_response(
650
+ JsonRpcErrorCode.PARSE_ERROR,
651
+ f"Parse error: {e}",
652
+ )
653
+ await websocket.send_text(error_resp.model_dump_json())
654
+ continue
655
+ except ValidationError as e:
656
+ error_resp = JsonRpcResponse.error_response(
657
+ JsonRpcErrorCode.INVALID_REQUEST,
658
+ f"Invalid request: {e}",
659
+ )
660
+ await websocket.send_text(error_resp.model_dump_json())
661
+ continue
662
+
663
+ try:
664
+ # Call mcp_handler with session environment
665
+ response = await mcp_handler(
666
+ jsonrpc_request, session_env=session_env
667
+ )
668
+ await websocket.send_text(response.model_dump_json())
669
+ except Exception as e:
670
+ error_resp = JsonRpcResponse.error_response(
671
+ JsonRpcErrorCode.INTERNAL_ERROR,
672
+ str(e),
673
+ request_id=jsonrpc_request.id,
674
+ )
675
+ await websocket.send_text(error_resp.model_dump_json())
676
+
677
+ except WebSocketDisconnect:
678
+ pass
679
+ except SessionCapacityError as e:
680
+ error_resp = JsonRpcResponse.error_response(
681
+ JsonRpcErrorCode.SERVER_ERROR,
682
+ str(e),
683
+ data={
684
+ "active_sessions": e.active_sessions,
685
+ "max_sessions": e.max_sessions,
686
+ },
687
+ )
688
+ await websocket.send_text(error_resp.model_dump_json())
689
+ except EnvironmentFactoryError as e:
690
+ error_resp = JsonRpcResponse.error_response(
691
+ JsonRpcErrorCode.SERVER_ERROR,
692
+ str(e),
693
+ data={"factory_name": e.factory_name},
694
+ )
695
+ await websocket.send_text(error_resp.model_dump_json())
696
+ except Exception as e:
697
+ error_resp = JsonRpcResponse.error_response(
698
+ JsonRpcErrorCode.SERVER_ERROR,
699
+ str(e),
700
+ )
701
+ await websocket.send_text(error_resp.model_dump_json())
702
+ finally:
703
+ if session_id:
704
+ await self._destroy_session(session_id)
705
+ try:
706
+ await websocket.close()
707
+ except RuntimeError:
708
+ pass
709
+
710
+ # Register simulation control routes only in simulation mode
711
+ if mode == ServerMode.SIMULATION:
712
+
713
+ @app.post(
714
+ "/reset",
715
+ response_model=ResetResponse,
716
+ tags=["Environment Control"],
717
+ summary="Reset the environment",
718
+ description="""
719
+ Reset the environment to its initial state and return the first observation.
720
+
721
+ You can optionally provide a seed for reproducibility and an episode_id for tracking.
722
+ """,
723
+ responses={
724
+ 200: {
725
+ "description": "Environment reset successfully",
726
+ "content": {
727
+ "application/json": {
728
+ "example": {
729
+ "observation": {"status": "ready", "data": {}},
730
+ "reward": None,
731
+ "done": False,
732
+ }
733
+ }
734
+ },
735
+ }
736
+ },
737
+ )
738
+ async def reset(
739
+ request: ResetRequest = Body(default_factory=ResetRequest),
740
+ ) -> ResetResponse:
741
+ return await reset_handler(request)
742
+
743
+ @app.post(
744
+ "/step",
745
+ response_model=StepResponse,
746
+ tags=["Environment Control"],
747
+ summary="Execute an action in the environment",
748
+ description="""
749
+ Execute an action in the environment and receive the resulting observation.
750
+
751
+ The action must conform to the environment's action schema, which can be
752
+ retrieved from the `/schema` endpoint. If the action is invalid,
753
+ the endpoint will return HTTP 422 with detailed validation errors.
754
+
755
+ The response includes:
756
+ - **observation**: The environment's response to the action
757
+ - **reward**: Optional reward signal (float or None)
758
+ - **done**: Boolean indicating if the episode has terminated
759
+ """,
760
+ responses={
761
+ 200: {
762
+ "description": "Action executed successfully",
763
+ "content": {
764
+ "application/json": {
765
+ "example": {
766
+ "observation": {"status": "success", "data": {}},
767
+ "reward": 1.0,
768
+ "done": False,
769
+ }
770
+ }
771
+ },
772
+ },
773
+ 422: {
774
+ "description": "Validation error - invalid action format or values",
775
+ "content": {
776
+ "application/json": {
777
+ "example": {
778
+ "detail": [
779
+ {
780
+ "type": "string_too_short",
781
+ "loc": ["body", "action", "message"],
782
+ "msg": "String should have at least 1 character",
783
+ "input": "",
784
+ }
785
+ ]
786
+ }
787
+ }
788
+ },
789
+ },
790
+ 500: {
791
+ "description": "Internal server error during action execution"
792
+ },
793
+ },
794
+ )
795
+ async def step(request: StepRequest) -> StepResponse:
796
+ return await step_handler(request)
797
+
798
+ def get_state_handler() -> State:
799
+ _env = self._env_factory()
800
+ try:
801
+ return _env.state
802
+ finally:
803
+ _env.close()
804
+
805
+ def get_metadata_handler() -> EnvironmentMetadata:
806
+ _env = self._env_factory()
807
+ try:
808
+ return _env.get_metadata()
809
+ finally:
810
+ _env.close()
811
+
812
+ # Build list of GET endpoints based on mode
813
+ get_endpoints = [
814
+ GetEndpointConfig(
815
+ path="/metadata",
816
+ handler=get_metadata_handler,
817
+ response_model=EnvironmentMetadata,
818
+ tag="Environment Info",
819
+ summary="Get environment metadata",
820
+ description="""
821
+ Get metadata about this environment.
822
+
823
+ Returns information about the environment including name, description,
824
+ version, author, and documentation links.
825
+ """,
826
+ ),
827
+ GetEndpointConfig(
828
+ path="/health",
829
+ handler=lambda: HealthResponse(status=HealthStatus.HEALTHY),
830
+ response_model=HealthResponse,
831
+ tag="Health",
832
+ summary="Health check",
833
+ description="Check if the environment server is running and healthy.",
834
+ ),
835
+ ]
836
+
837
+ # Only register /state endpoint in simulation mode
838
+ if mode == ServerMode.SIMULATION:
839
+ get_endpoints.insert(
840
+ 0,
841
+ GetEndpointConfig(
842
+ path="/state",
843
+ handler=get_state_handler,
844
+ response_model=State,
845
+ tag="State Management",
846
+ summary="Get current environment state",
847
+ description="""
848
+ Retrieve the current internal state of the environment.
849
+
850
+ The structure of the state object is defined by the environment's State model.
851
+ """,
852
+ ),
853
+ )
854
+
855
+ register_get_endpoints(app, get_endpoints)
856
+
857
+ # Register combined schema endpoint
858
+ @app.get(
859
+ "/schema",
860
+ response_model=SchemaResponse,
861
+ tags=["Schema"],
862
+ summary="Get all JSON schemas",
863
+ description="""
864
+ Get JSON schemas for actions, observations, and state in a single response.
865
+
866
+ Returns a combined schema object containing:
867
+ - **action**: JSON schema for actions accepted by this environment
868
+ - **observation**: JSON schema for observations returned by this environment
869
+ - **state**: JSON schema for environment state objects
870
+
871
+ This is more efficient than calling individual schema endpoints and provides
872
+ all schema information needed to interact with the environment.
873
+ """,
874
+ responses={
875
+ 200: {
876
+ "description": "Combined schemas retrieved successfully",
877
+ "content": {
878
+ "application/json": {
879
+ "example": {
880
+ "action": {
881
+ "type": "object",
882
+ "properties": {"message": {"type": "string"}},
883
+ },
884
+ "observation": {
885
+ "type": "object",
886
+ "properties": {"response": {"type": "string"}},
887
+ },
888
+ "state": {
889
+ "type": "object",
890
+ "properties": {"step_count": {"type": "integer"}},
891
+ },
892
+ }
893
+ }
894
+ },
895
+ }
896
+ },
897
+ )
898
+ async def get_schemas() -> SchemaResponse:
899
+ """Return all schemas in one response."""
900
+ return SchemaResponse(
901
+ action=self.action_cls.model_json_schema(),
902
+ observation=self.observation_cls.model_json_schema(),
903
+ state=State.model_json_schema(),
904
+ )
905
+
906
+ # Register MCP endpoint for production mode (direct MCP access)
907
+ @app.post("/mcp")
908
+ async def mcp_endpoint(request_raw: Request) -> Dict[str, Any]:
909
+ """
910
+ MCP JSON-RPC endpoint for production mode.
911
+
912
+ Bypasses step() overhead and provides direct access to MCP tools.
913
+ Supports tools/list and tools/call methods.
914
+ """
915
+ # Parse JSON manually to handle parse errors gracefully
916
+ try:
917
+ body = await request_raw.body()
918
+ request_dict = json.loads(body)
919
+ request = JsonRpcRequest(**request_dict)
920
+ except json.JSONDecodeError:
921
+ return JsonRpcResponse.error_response(
922
+ JsonRpcErrorCode.PARSE_ERROR
923
+ ).model_dump()
924
+ except ValidationError as e:
925
+ return JsonRpcResponse.error_response(
926
+ JsonRpcErrorCode.INVALID_REQUEST,
927
+ f"Invalid request: {e}",
928
+ ).model_dump()
929
+ except Exception:
930
+ return JsonRpcResponse.error_response(
931
+ JsonRpcErrorCode.PARSE_ERROR
932
+ ).model_dump()
933
+
934
+ method = request.method
935
+ params = request.params
936
+ request_id = request.id
937
+
938
+ # Create a temporary environment for MCP access
939
+ _env = self._env_factory()
940
+
941
+ try:
942
+ # Check if environment supports MCP
943
+ if not hasattr(_env, "mcp_client") and not hasattr(_env, "mcp_server"):
944
+ return JsonRpcResponse.error_response(
945
+ JsonRpcErrorCode.INTERNAL_ERROR,
946
+ "Environment does not support MCP",
947
+ request_id=request_id,
948
+ ).model_dump()
949
+
950
+ if method == McpMethod.TOOLS_LIST:
951
+ # List tools from MCP server
952
+ if hasattr(_env, "mcp_client") and _env.mcp_client:
953
+ async with _env.mcp_client:
954
+ tools = await _env.mcp_client.list_tools()
955
+ return JsonRpcResponse.success(
956
+ result={
957
+ "tools": [
958
+ t.model_dump()
959
+ if hasattr(t, "model_dump")
960
+ else dict(t)
961
+ for t in tools
962
+ ]
963
+ },
964
+ request_id=request_id,
965
+ ).model_dump()
966
+ elif hasattr(_env, "mcp_server") and _env.mcp_server:
967
+ # Use server directly
968
+ tools = []
969
+ for tool_name, tool in get_server_tools(
970
+ _env.mcp_server
971
+ ).items():
972
+ tool_dict = {
973
+ "name": tool.name,
974
+ "description": tool.description or "",
975
+ "inputSchema": tool.parameters or {},
976
+ }
977
+ tools.append(tool_dict)
978
+ return JsonRpcResponse.success(
979
+ result={"tools": tools},
980
+ request_id=request_id,
981
+ ).model_dump()
982
+ else:
983
+ return JsonRpcResponse.error_response(
984
+ JsonRpcErrorCode.INTERNAL_ERROR,
985
+ "MCP server not available",
986
+ request_id=request_id,
987
+ ).model_dump()
988
+
989
+ elif method == McpMethod.TOOLS_CALL:
990
+ tool_name = params.get("name")
991
+ arguments = params.get("arguments", {})
992
+
993
+ if not tool_name:
994
+ return JsonRpcResponse.error_response(
995
+ JsonRpcErrorCode.INVALID_PARAMS,
996
+ "Invalid params - 'name' is required",
997
+ request_id=request_id,
998
+ ).model_dump()
999
+
1000
+ # Call tool via MCP
1001
+ if hasattr(_env, "mcp_client") and _env.mcp_client:
1002
+ async with _env.mcp_client:
1003
+ result = await _env.mcp_client.call_tool(
1004
+ name=tool_name, arguments=arguments
1005
+ )
1006
+ elif hasattr(_env, "mcp_server") and _env.mcp_server:
1007
+ # Call tool directly on FastMCP server
1008
+ server_tools = get_server_tools(_env.mcp_server)
1009
+ if tool_name in server_tools:
1010
+ tool = server_tools[tool_name]
1011
+ result = tool.fn(**arguments)
1012
+ else:
1013
+ return JsonRpcResponse.error_response(
1014
+ JsonRpcErrorCode.INVALID_PARAMS,
1015
+ f"Tool not found: {tool_name}",
1016
+ request_id=request_id,
1017
+ ).model_dump()
1018
+ else:
1019
+ return JsonRpcResponse.error_response(
1020
+ JsonRpcErrorCode.INTERNAL_ERROR,
1021
+ "MCP server not available",
1022
+ request_id=request_id,
1023
+ ).model_dump()
1024
+
1025
+ # Make result JSON serializable
1026
+ serializable_result = _make_json_serializable(result)
1027
+
1028
+ return JsonRpcResponse.success(
1029
+ result=serializable_result,
1030
+ request_id=request_id,
1031
+ ).model_dump()
1032
+
1033
+ else:
1034
+ return JsonRpcResponse.error_response(
1035
+ JsonRpcErrorCode.METHOD_NOT_FOUND,
1036
+ f"Method not found: {method}",
1037
+ request_id=request_id,
1038
+ ).model_dump()
1039
+
1040
+ except Exception as e:
1041
+ return JsonRpcResponse.error_response(
1042
+ JsonRpcErrorCode.INTERNAL_ERROR,
1043
+ str(e),
1044
+ request_id=request_id,
1045
+ ).model_dump()
1046
+ finally:
1047
+ _env.close()
1048
+
1049
+ # Register WebSocket endpoint for persistent sessions
1050
+ @app.websocket("/ws")
1051
+ async def websocket_endpoint(websocket: WebSocket):
1052
+ """
1053
+ WebSocket endpoint for persistent environment sessions.
1054
+
1055
+ Each WebSocket connection gets its own environment instance.
1056
+
1057
+ Message Protocol:
1058
+ - Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage
1059
+ - Server responds: WSObservationResponse | WSStateResponse | WSErrorResponse
1060
+ """
1061
+ await websocket.accept()
1062
+
1063
+ session_id = None
1064
+ session_env = None
1065
+
1066
+ try:
1067
+ # Create session with dedicated environment
1068
+ session_id, session_env = await self._create_session()
1069
+
1070
+ while True:
1071
+ # Receive message from client
1072
+ raw_message = await websocket.receive_text()
1073
+
1074
+ try:
1075
+ message_dict = json.loads(raw_message)
1076
+ except json.JSONDecodeError as e:
1077
+ error_resp = WSErrorResponse(
1078
+ data={
1079
+ "message": f"Invalid JSON: {e}",
1080
+ "code": WSErrorCode.INVALID_JSON,
1081
+ }
1082
+ )
1083
+ await websocket.send_text(error_resp.model_dump_json())
1084
+ continue
1085
+
1086
+ msg_type = message_dict.get("type", "")
1087
+
1088
+ try:
1089
+ match msg_type:
1090
+ case "reset":
1091
+ msg = WSResetMessage(**message_dict)
1092
+
1093
+ is_async = (
1094
+ session_env.reset_async.__func__
1095
+ is not Environment.reset_async
1096
+ )
1097
+
1098
+ if is_async:
1099
+ sig = inspect.signature(session_env.reset_async)
1100
+ valid_kwargs = self._get_valid_kwargs(sig, msg.data)
1101
+ observation = await session_env.reset_async(
1102
+ **valid_kwargs
1103
+ )
1104
+ else:
1105
+ sig = inspect.signature(session_env.reset)
1106
+ valid_kwargs = self._get_valid_kwargs(sig, msg.data)
1107
+ observation = await self._run_in_session_executor(
1108
+ session_id, session_env.reset, **valid_kwargs
1109
+ )
1110
+
1111
+ self._update_session_activity(session_id)
1112
+
1113
+ response = WSObservationResponse(
1114
+ data=serialize_observation(observation),
1115
+ )
1116
+
1117
+ case "step":
1118
+ msg = WSStepMessage(**message_dict)
1119
+ action = deserialize_action(msg.data, self.action_cls)
1120
+
1121
+ is_async = (
1122
+ session_env.step_async.__func__
1123
+ is not Environment.step_async
1124
+ )
1125
+
1126
+ if is_async:
1127
+ observation = await session_env.step_async(action)
1128
+ else:
1129
+ observation = await self._run_in_session_executor(
1130
+ session_id, session_env.step, action
1131
+ )
1132
+
1133
+ self._update_session_activity(
1134
+ session_id, increment_step=True
1135
+ )
1136
+
1137
+ response = WSObservationResponse(
1138
+ data=serialize_observation(observation)
1139
+ )
1140
+
1141
+ case "state":
1142
+ msg = WSStateMessage(**message_dict)
1143
+ state = session_env.state
1144
+ if hasattr(state, "model_dump"):
1145
+ state_data = state.model_dump()
1146
+ else:
1147
+ state_data = dict(state) if state else {}
1148
+
1149
+ response = WSStateResponse(data=state_data)
1150
+
1151
+ case "close":
1152
+ msg = WSCloseMessage(**message_dict)
1153
+ break
1154
+
1155
+ case "mcp":
1156
+ msg = WSMCPMessage(**message_dict)
1157
+ try:
1158
+ rpc_request = JsonRpcRequest(**msg.data)
1159
+ except (ValidationError, Exception) as e:
1160
+ rpc_response = JsonRpcResponse.error_response(
1161
+ JsonRpcErrorCode.INVALID_REQUEST,
1162
+ f"Invalid request: {e}",
1163
+ )
1164
+ else:
1165
+ rpc_response = await mcp_handler(
1166
+ rpc_request,
1167
+ session_env=session_env,
1168
+ )
1169
+ response = WSMCPResponse(data=rpc_response.model_dump())
1170
+
1171
+ case _:
1172
+ response = WSErrorResponse(
1173
+ data={
1174
+ "message": f"Unknown message type: {msg_type}",
1175
+ "code": WSErrorCode.UNKNOWN_TYPE,
1176
+ }
1177
+ )
1178
+
1179
+ await websocket.send_text(response.model_dump_json())
1180
+
1181
+ except ValidationError as e:
1182
+ error_resp = WSErrorResponse(
1183
+ data={
1184
+ "message": "Invalid message",
1185
+ "code": WSErrorCode.VALIDATION_ERROR,
1186
+ "errors": e.errors(),
1187
+ }
1188
+ )
1189
+ await websocket.send_text(error_resp.model_dump_json())
1190
+ except Exception as e:
1191
+ error_resp = WSErrorResponse(
1192
+ data={
1193
+ "message": str(e),
1194
+ "code": WSErrorCode.EXECUTION_ERROR,
1195
+ }
1196
+ )
1197
+ await websocket.send_text(error_resp.model_dump_json())
1198
+
1199
+ except WebSocketDisconnect:
1200
+ pass
1201
+ except SessionCapacityError as e:
1202
+ error_resp = WSErrorResponse(
1203
+ data={
1204
+ "message": str(e),
1205
+ "code": WSErrorCode.CAPACITY_REACHED,
1206
+ "active_sessions": e.active_sessions,
1207
+ "max_sessions": e.max_sessions,
1208
+ }
1209
+ )
1210
+ await websocket.send_text(error_resp.model_dump_json())
1211
+ except EnvironmentFactoryError as e:
1212
+ error_resp = WSErrorResponse(
1213
+ data={
1214
+ "message": str(e),
1215
+ "code": WSErrorCode.FACTORY_ERROR,
1216
+ "factory_name": e.factory_name,
1217
+ }
1218
+ )
1219
+ await websocket.send_text(error_resp.model_dump_json())
1220
+ except Exception as e:
1221
+ error_resp = WSErrorResponse(
1222
+ data={"message": str(e), "code": WSErrorCode.SESSION_ERROR}
1223
+ )
1224
+ await websocket.send_text(error_resp.model_dump_json())
1225
+ finally:
1226
+ if session_id:
1227
+ await self._destroy_session(session_id)
1228
+ try:
1229
+ await websocket.close()
1230
+ except RuntimeError:
1231
+ pass
1232
 
 
 
 
 
 
 
1233
 
1234
  def create_app(
1235
+ env: Callable[[], Environment],
1236
  action_cls: Type[Action],
1237
  observation_cls: Type[Observation],
1238
  env_name: Optional[str] = None,
1239
+ max_concurrent_envs: Optional[int] = None,
1240
+ concurrency_config: Optional[ConcurrencyConfig] = None,
1241
+ gradio_builder: Optional[Callable[..., Any]] = None,
1242
+ ) -> FastAPI:
1243
  """
1244
  Create a FastAPI application with or without web interface.
1245
+
1246
  This function creates a FastAPI app with the web interface enabled by default,
1247
  including README integration for better user experience.
1248
+
1249
  Args:
1250
+ env: Environment factory (callable) that creates new instances
1251
  action_cls: The Action subclass this environment expects
1252
  observation_cls: The Observation subclass this environment returns
1253
  env_name: Optional environment name for README loading
1254
+ max_concurrent_envs: Maximum concurrent WebSocket sessions.
1255
+ Mutually exclusive with concurrency_config.
1256
+ concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings.
1257
+ Mutually exclusive with max_concurrent_envs.
1258
+ gradio_builder: Optional callable to build a custom Gradio UI at /web.
1259
+ Signature: (web_manager, action_fields, metadata, is_chat_env, title,
1260
+ quick_start_md) -> gr.Blocks. When None, the default Gradio app is used.
1261
+ See docs/customizing-web-ui.md.
1262
+
1263
  Returns:
1264
  FastAPI application instance with or without web interface and README integration
1265
  """
1266
  # Check if web interface should be enabled
1267
  # This can be controlled via environment variable or build argument
1268
+ enable_web = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in (
1269
+ "true",
1270
+ "1",
1271
+ "yes",
1272
  )
1273
 
1274
  if enable_web:
1275
+ # Gradio-based web UI (gradio is a core dependency)
1276
  from .web_interface import create_web_interface_app
1277
+
1278
+ return create_web_interface_app(
1279
+ env,
1280
+ action_cls,
1281
+ observation_cls,
1282
+ env_name,
1283
+ max_concurrent_envs,
1284
+ concurrency_config,
1285
+ gradio_builder=gradio_builder,
1286
+ )
1287
  else:
1288
  # Use standard FastAPI app without web interface
1289
+ return create_fastapi_app(
1290
+ env, action_cls, observation_cls, max_concurrent_envs, concurrency_config
1291
+ )
1292
+
1293
 
1294
  def create_fastapi_app(
1295
+ env: Callable[[], Environment],
1296
  action_cls: Type[Action],
1297
  observation_cls: Type[Observation],
1298
+ max_concurrent_envs: Optional[int] = None,
1299
+ concurrency_config: Optional[ConcurrencyConfig] = None,
1300
+ ) -> FastAPI:
1301
  """
1302
+ Create a FastAPI application with comprehensive documentation.
1303
 
1304
  Args:
1305
+ env: Environment factory (callable) that creates new instances
1306
  action_cls: The Action subclass this environment expects
1307
  observation_cls: The Observation subclass this environment returns
1308
+ max_concurrent_envs: Maximum concurrent WebSocket sessions.
1309
+ Mutually exclusive with concurrency_config.
1310
+ concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings.
1311
+ Mutually exclusive with max_concurrent_envs.
1312
 
1313
  Returns:
1314
+ FastAPI application instance
 
 
 
 
 
 
 
 
 
1315
  """
1316
  try:
1317
  from fastapi import FastAPI
 
1320
  "FastAPI is required. Install with: pip install fastapi uvicorn"
1321
  )
1322
 
1323
+ app = FastAPI(
1324
+ title="OpenEnv Environment HTTP API",
1325
+ version="1.0.0",
1326
+ description="""
1327
+ # OpenEnv Environment HTTP API
1328
+
1329
+ HTTP API for interacting with OpenEnv environments through a standardized interface.
1330
+
1331
+ ## Features
1332
+
1333
+ * **Environment Reset**: Initialize or restart episodes
1334
+ * **Action Execution**: Send actions and receive observations
1335
+ * **State Inspection**: Query current environment state
1336
+ * **Schema Access**: Retrieve JSON schemas for actions and observations
1337
+
1338
+ ## Workflow
1339
+
1340
+ 1. Call `/reset` to start a new episode and get initial observation
1341
+ 2. Call `/step` repeatedly with actions to interact with environment
1342
+ 3. Episode ends when observation returns `done: true`
1343
+ 4. Call `/state` anytime to inspect current environment state
1344
+
1345
+ ## Documentation
1346
+
1347
+ * **Swagger UI**: Available at `/docs`
1348
+ * **ReDoc**: Available at `/redoc`
1349
+ * **OpenAPI Schema**: Available at `/openapi.json`
1350
+ """,
1351
+ openapi_tags=[
1352
+ {
1353
+ "name": "Environment Control",
1354
+ "description": "Core operations for environment interaction (reset, step)",
1355
+ },
1356
+ {
1357
+ "name": "State Management",
1358
+ "description": "Operations for inspecting environment state",
1359
+ },
1360
+ {
1361
+ "name": "Environment Info",
1362
+ "description": "Information about the environment",
1363
+ },
1364
+ {
1365
+ "name": "Schema",
1366
+ "description": "JSON Schema endpoints for actions, observations, and state",
1367
+ },
1368
+ {"name": "Health", "description": "Service health and status checks"},
1369
+ ],
1370
+ docs_url="/docs",
1371
+ redoc_url="/redoc",
1372
+ openapi_url="/openapi.json",
1373
+ contact={
1374
+ "name": "OpenEnv Team",
1375
+ "url": "https://github.com/meta-pytorch/OpenEnv",
1376
+ },
1377
+ license_info={
1378
+ "name": "BSD-3-Clause",
1379
+ "url": "https://github.com/meta-pytorch/OpenEnv/blob/main/LICENSE",
1380
+ },
1381
+ )
1382
+
1383
+ server = HTTPEnvServer(
1384
+ env,
1385
+ action_cls,
1386
+ observation_cls,
1387
+ max_concurrent_envs,
1388
+ concurrency_config=concurrency_config,
1389
+ )
1390
  server.register_routes(app)
1391
  return app
src/core/env_server/interfaces.py CHANGED
@@ -4,10 +4,20 @@
4
  # This source code is licensed under the BSD-style license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
 
7
  from abc import ABC, abstractmethod
8
- from typing import Any, Protocol, TypedDict
9
 
10
- from .types import Action, Observation, State
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  class Message(TypedDict):
@@ -64,7 +74,7 @@ class ModelTokenizer(Protocol):
64
  ...
65
 
66
 
67
- class Transform(ABC):
68
  """Transform observations to add rewards, metrics, or other modifications.
69
 
70
  Transforms follow the TorchRL pattern where they take an observation
@@ -73,7 +83,7 @@ class Transform(ABC):
73
  """
74
 
75
  @abstractmethod
76
- def __call__(self, observation: Observation) -> Observation:
77
  """Transform an observation.
78
 
79
  Args:
@@ -85,34 +95,203 @@ class Transform(ABC):
85
  pass
86
 
87
 
88
- class Environment(ABC):
89
  """Base class for all environment servers following Gym/Gymnasium API.
90
 
91
  Args:
92
  transform: Optional transform to apply to observations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  """
94
 
95
- def __init__(self, transform: Transform | None = None):
 
 
 
 
 
 
 
 
 
 
96
  self.transform = transform
 
97
 
98
  @abstractmethod
99
- def reset(self) -> Observation:
 
 
 
 
 
100
  """Reset the environment and return initial observation."""
101
  pass
102
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  @abstractmethod
104
- def step(self, action: Action) -> Observation:
 
 
 
 
 
105
  """Take a step in the environment."""
106
  pass
107
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  @property
109
  @abstractmethod
110
- def state(self) -> State:
111
  """Get the current environment state."""
112
  pass
113
 
114
- def _apply_transform(self, observation: Observation) -> Observation:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  """Apply transform if one is provided."""
116
  if self.transform is not None:
117
  return self.transform(observation)
118
  return observation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  # This source code is licensed under the BSD-style license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
+ import inspect
8
  from abc import ABC, abstractmethod
9
+ from typing import Any, Generic, Optional, Protocol, TYPE_CHECKING, TypeVar
10
 
11
+ from typing_extensions import TypedDict
12
+
13
+ from .types import Action, EnvironmentMetadata, Observation, State
14
+
15
+ if TYPE_CHECKING:
16
+ from openenv.core.rubrics import Rubric
17
+
18
+ ActT = TypeVar("ActT", bound=Action)
19
+ ObsT = TypeVar("ObsT", bound=Observation)
20
+ StateT = TypeVar("StateT", bound=State)
21
 
22
 
23
  class Message(TypedDict):
 
74
  ...
75
 
76
 
77
+ class Transform(ABC, Generic[ObsT]):
78
  """Transform observations to add rewards, metrics, or other modifications.
79
 
80
  Transforms follow the TorchRL pattern where they take an observation
 
83
  """
84
 
85
  @abstractmethod
86
+ def __call__(self, observation: ObsT) -> ObsT:
87
  """Transform an observation.
88
 
89
  Args:
 
95
  pass
96
 
97
 
98
+ class Environment(ABC, Generic[ActT, ObsT, StateT]):
99
  """Base class for all environment servers following Gym/Gymnasium API.
100
 
101
  Args:
102
  transform: Optional transform to apply to observations
103
+ rubric: Optional rubric for reward computation. When provided, the
104
+ rubric's output can be used to set the observation's reward in step().
105
+
106
+ Class Attributes:
107
+ SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions.
108
+ When True, multiple WebSocket connections can each have their own
109
+ environment instance (up to max_concurrent_envs). When False (default),
110
+ the environment should only be used with a single session at a time.
111
+
112
+ Set this to True in your Environment subclass if:
113
+ - The environment uses proper session isolation (e.g., unique working dirs)
114
+ - No shared mutable state exists between instances
115
+ - External resources (databases, APIs) can handle concurrent access
116
+
117
+ Attributes:
118
+ rubric: Optional rubric for computing rewards. Environments can set this
119
+ in __init__ and use it in step() to compute observation rewards.
120
+ Training infrastructure can access it for introspection:
121
+ for name, r in env.rubric.named_rubrics():
122
+ print(f"{name}: {r.last_score}")
123
+
124
+ See RFC 004 for rubric design: rfcs/004-rubrics.md
125
  """
126
 
127
+ # Class-level flag indicating whether this environment supports concurrent sessions
128
+ SUPPORTS_CONCURRENT_SESSIONS: bool = False
129
+
130
+ # Optional rubric for reward computation
131
+ rubric: Optional["Rubric"]
132
+
133
+ def __init__(
134
+ self,
135
+ transform: Optional[Transform[ObsT]] = None,
136
+ rubric: Optional["Rubric"] = None,
137
+ ):
138
  self.transform = transform
139
+ self.rubric = rubric
140
 
141
  @abstractmethod
142
+ def reset(
143
+ self,
144
+ seed: Optional[int] = None,
145
+ episode_id: Optional[str] = None,
146
+ **kwargs: Any,
147
+ ) -> ObsT:
148
  """Reset the environment and return initial observation."""
149
  pass
150
 
151
+ async def reset_async(
152
+ self,
153
+ seed: Optional[int] = None,
154
+ episode_id: Optional[str] = None,
155
+ **kwargs: Any,
156
+ ) -> ObsT:
157
+ """Async version of reset. Default implementation calls sync reset.
158
+
159
+ Override to provide true async implementation.
160
+ """
161
+ return self.reset(seed=seed, episode_id=episode_id, **kwargs)
162
+
163
  @abstractmethod
164
+ def step(
165
+ self,
166
+ action: ActT,
167
+ timeout_s: Optional[float] = None,
168
+ **kwargs: Any,
169
+ ) -> ObsT:
170
  """Take a step in the environment."""
171
  pass
172
 
173
+ async def step_async(
174
+ self,
175
+ action: ActT,
176
+ timeout_s: Optional[float] = None,
177
+ **kwargs: Any,
178
+ ) -> ObsT:
179
+ """Async version of step. Default implementation calls sync step.
180
+
181
+ Override to provide true async implementation.
182
+ """
183
+ return self.step(action, timeout_s=timeout_s, **kwargs)
184
+
185
  @property
186
  @abstractmethod
187
+ def state(self) -> StateT:
188
  """Get the current environment state."""
189
  pass
190
 
191
+ def get_metadata(self) -> EnvironmentMetadata:
192
+ """
193
+ Get metadata about this environment.
194
+
195
+ Override this method to provide custom metadata for the environment.
196
+ Default implementation returns basic metadata derived from class name.
197
+
198
+ Returns:
199
+ EnvironmentMetadata with environment information
200
+ """
201
+ return EnvironmentMetadata(
202
+ name=self.__class__.__name__,
203
+ description=f"{self.__class__.__name__} environment",
204
+ version="1.0.0",
205
+ )
206
+
207
+ def _apply_transform(self, observation: ObsT) -> ObsT:
208
  """Apply transform if one is provided."""
209
  if self.transform is not None:
210
  return self.transform(observation)
211
  return observation
212
+
213
+ def _apply_rubric(self, action: ActT, observation: ObsT) -> float:
214
+ """Apply rubric if one is provided.
215
+
216
+ Args:
217
+ action: The action taken by the agent.
218
+ observation: The resulting observation.
219
+
220
+ Returns:
221
+ Reward value from the rubric, or 0.0 if no rubric is set.
222
+
223
+ Usage in step():
224
+ def step(self, action: MyAction, ...) -> MyObservation:
225
+ # ... execute action and create observation ...
226
+ observation.reward = self._apply_rubric(action, observation)
227
+ return observation
228
+ """
229
+ if self.rubric is not None:
230
+ return self.rubric(action, observation)
231
+ return 0.0
232
+
233
+ async def _apply_rubric_async(self, action: ActT, observation: ObsT) -> float:
234
+ """Apply rubric asynchronously if one is provided.
235
+
236
+ Args:
237
+ action: The action taken by the agent.
238
+ observation: The resulting observation.
239
+
240
+ Returns:
241
+ Reward value from the rubric, or 0.0 if no rubric is set.
242
+
243
+ Usage in step_async():
244
+ async def step_async(self, action: MyAction, ...) -> MyObservation:
245
+ # ... execute action and create observation ...
246
+ observation.reward = await self._apply_rubric_async(action, observation)
247
+ return observation
248
+ """
249
+ if self.rubric is not None:
250
+ result = self.rubric(action, observation)
251
+ # If rubric returns a coroutine, await it
252
+ if inspect.iscoroutine(result):
253
+ return await result
254
+ return result
255
+ return 0.0
256
+
257
+ def _reset_rubric(self) -> None:
258
+ """Reset the rubric state if one is provided.
259
+
260
+ Call this in reset() to clear any trajectory state in the rubric.
261
+
262
+ Usage in reset():
263
+ def reset(self, ...) -> MyObservation:
264
+ self._reset_rubric()
265
+ # ... create initial observation ...
266
+ return observation
267
+ """
268
+ if self.rubric is not None:
269
+ self.rubric.reset()
270
+
271
+ async def _reset_rubric_async(self) -> None:
272
+ """Reset the rubric state asynchronously if one is provided.
273
+
274
+ Call this in reset_async() to clear any trajectory state in the rubric.
275
+
276
+ Usage in reset_async():
277
+ async def reset_async(self, ...) -> MyObservation:
278
+ await self._reset_rubric_async()
279
+ # ... create initial observation ...
280
+ return observation
281
+ """
282
+ if self.rubric is not None:
283
+ # Check if rubric has async reset method
284
+ if hasattr(self.rubric, "reset_async"):
285
+ result = self.rubric.reset_async()
286
+ if inspect.iscoroutine(result):
287
+ await result
288
+ else:
289
+ self.rubric.reset()
290
+
291
+ def close(self) -> None:
292
+ """Clean up resources used by the environment.
293
+
294
+ Override this method to implement custom cleanup logic.
295
+ Called when the environment is being destroyed or reset.
296
+ """
297
+ pass
src/core/env_server/mcp_environment.py ADDED
@@ -0,0 +1,624 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ MCP Environment base class for OpenEnv.
9
+
10
+ This module provides the MCPEnvironment base class that integrates FastMCP servers
11
+ with OpenEnv's Gym-style Environment interface. It handles MCP tool discovery
12
+ and invocation through the step() API, following RFC 003.
13
+
14
+ Key features:
15
+ - Automatic routing of ListToolsAction and CallToolAction to MCP server
16
+ - Reserved tool name validation (reset, step, state, close are protected)
17
+ - Timeout handling for tool calls
18
+ - Proper error categorization (tool not found, execution errors, timeouts)
19
+ - Mode-aware tool registration (production vs simulation)
20
+ - Code mode support via get_callables() and execute_code()
21
+
22
+ Usage:
23
+ from fastmcp import FastMCP
24
+ from openenv.core.env_server.mcp_environment import MCPEnvironment
25
+
26
+ class MyMCPEnv(MCPEnvironment):
27
+ def __init__(self):
28
+ mcp = FastMCP("my-server")
29
+
30
+ # Register mode-specific tools
31
+ @self.tool(mode="production")
32
+ def my_tool(arg: str) -> str:
33
+ return f"Production: {arg}"
34
+
35
+ @self.tool(mode="simulation")
36
+ def my_tool(arg: str) -> str:
37
+ return f"Simulation: {arg}"
38
+
39
+ super().__init__(mcp)
40
+
41
+ def reset(self, seed=None, episode_id=None, **kwargs):
42
+ # Reset logic here
43
+ ...
44
+
45
+ def _step_impl(self, action):
46
+ # Handle non-MCP actions
47
+ ...
48
+
49
+ @property
50
+ def state(self):
51
+ # Return current state
52
+ ...
53
+ """
54
+
55
+ import asyncio
56
+ import inspect
57
+ from abc import abstractmethod
58
+ from collections import defaultdict
59
+ from typing import Any, Callable, Dict, Optional
60
+
61
+ from fastmcp import Client
62
+ from fastmcp.client.client import CallToolResult
63
+ from mcp.types import TextContent
64
+
65
+ from ..utils import run_async_safely
66
+ from .interfaces import Environment
67
+ from .mcp_types import (
68
+ CallToolAction,
69
+ CallToolObservation,
70
+ ListToolsAction,
71
+ ListToolsObservation,
72
+ RESERVED_TOOL_NAMES,
73
+ Tool,
74
+ ToolError,
75
+ ToolErrorType,
76
+ )
77
+ from .types import Action, Observation
78
+
79
+
80
+ # Default timeout for MCP tool calls in seconds
81
+ MCP_TOOL_CALL_TIMEOUT = 30.0
82
+
83
+ # Valid modes for tool registration
84
+ VALID_MODES = {"production", "simulation"}
85
+
86
+
87
+ def get_server_tools(mcp_server: Any) -> Dict[str, Any]:
88
+ """
89
+ Get tools from a FastMCP server, compatible with both 2.x and 3.x.
90
+
91
+ Returns:
92
+ Dictionary mapping tool names to tool objects.
93
+ """
94
+ # FastMCP 2.x: get_tools() returns dict {name: Tool}
95
+ if hasattr(mcp_server, "get_tools"):
96
+ result = run_async_safely(mcp_server.get_tools())
97
+ if isinstance(result, dict):
98
+ return result
99
+ # FastMCP 3.x: list_tools() returns list of Tool objects
100
+ if hasattr(mcp_server, "list_tools"):
101
+ tools_list = run_async_safely(mcp_server.list_tools())
102
+ return {t.name: t for t in tools_list}
103
+ return {}
104
+
105
+
106
+ class MCPEnvironment(Environment):
107
+ """
108
+ Base class for environments that expose tools via MCP (Model Context Protocol).
109
+
110
+ MCPEnvironment bridges FastMCP servers with OpenEnv's Gym-style API, allowing
111
+ agents to discover and invoke MCP tools through the standard step() interface.
112
+
113
+ The class automatically handles:
114
+ - ListToolsAction: Returns available tools from the MCP server
115
+ - CallToolAction: Invokes a specific tool with arguments
116
+
117
+ All other actions are delegated to the abstract _step_impl() method,
118
+ which subclasses must implement.
119
+
120
+ Args:
121
+ mcp_server: A FastMCP server instance containing tool definitions.
122
+ The server's tools will be validated against reserved names.
123
+ transform: Optional transform to apply to observations (inherited from Environment).
124
+
125
+ Raises:
126
+ ValueError: If any tool in the MCP server uses a reserved name
127
+ (reset, step, state, close).
128
+
129
+ Example:
130
+ >>> from fastmcp import FastMCP
131
+ >>> mcp = FastMCP("calculator")
132
+ >>> @mcp.tool()
133
+ ... def add(a: int, b: int) -> int:
134
+ ... return a + b
135
+ >>> env = MyMCPEnvironment(mcp)
136
+ >>> obs = env.step(ListToolsAction())
137
+ >>> obs.tools[0].name
138
+ 'add'
139
+ """
140
+
141
+ def __init__(self, mcp_server: Any, transform: Optional[Any] = None) -> None:
142
+ """
143
+ Initialize the MCP environment.
144
+
145
+ Args:
146
+ mcp_server: A FastMCP server instance with tool definitions.
147
+ transform: Optional transform to apply to observations.
148
+
149
+ Raises:
150
+ ValueError: If any tool uses a reserved name (reset, step, state, close).
151
+ """
152
+ super().__init__(transform=transform)
153
+
154
+ # Validate tool names before storing
155
+ self._validate_tool_names(mcp_server)
156
+
157
+ self.mcp_server = mcp_server
158
+ self.mcp_client = Client(mcp_server)
159
+
160
+ # Track mode-specific tools: {tool_name: {mode: func}}
161
+ # mode can be "production", "simulation", or None (available in all modes)
162
+ self._mode_tools = defaultdict(dict)
163
+
164
+ # Track tool schemas for list_tools: {tool_name: {mode: schema}}
165
+ self._mode_tool_schemas = defaultdict(dict)
166
+
167
+ @property
168
+ def supports_code_mode(self) -> bool:
169
+ """Check if this environment supports code mode (execute_code)."""
170
+ return True
171
+
172
+ def _get_server_tools(self, mcp_server: Any) -> Dict[str, Any]:
173
+ """
174
+ Get tools from a FastMCP server, compatible with both 2.x and 3.x.
175
+
176
+ Returns:
177
+ Dictionary mapping tool names to tool objects.
178
+ """
179
+ return get_server_tools(mcp_server)
180
+
181
+ def get_callables(self) -> Dict[str, Callable]:
182
+ """
183
+ Get callable functions for code mode.
184
+
185
+ Returns tool functions as direct Python callables, enabling code mode
186
+ where agents write Python code that calls tools directly (no JSON-RPC
187
+ overhead). Mode-specific tools are filtered by the current mode.
188
+
189
+ Returns:
190
+ Dictionary mapping tool names to callables.
191
+ """
192
+ callables: Dict[str, Callable] = {}
193
+ current_mode = getattr(self, "_mode", None)
194
+
195
+ # Extract callables from FastMCP server using public API
196
+ for tool_name, tool in self._get_server_tools(self.mcp_server).items():
197
+ if hasattr(tool, "fn") and callable(tool.fn):
198
+ callables[tool_name] = tool.fn
199
+
200
+ # Add mode-specific tools available in current mode
201
+ for tool_name, mode_funcs in self._mode_tools.items():
202
+ if None in mode_funcs:
203
+ # Tool available in all modes (already in FastMCP if registered there)
204
+ if tool_name not in callables:
205
+ callables[tool_name] = mode_funcs[None]
206
+ elif current_mode in mode_funcs:
207
+ # Tool available in current mode only
208
+ callables[tool_name] = mode_funcs[current_mode]
209
+
210
+ return callables
211
+
212
+ def execute_code(self, code: str) -> Observation:
213
+ """
214
+ Execute Python code with tools available as callables.
215
+
216
+ This enables the CodeAct pattern where agents write Python code
217
+ that calls tools directly as functions, avoiding JSON-RPC overhead.
218
+
219
+ Args:
220
+ code: Python code to execute. Tools are available as functions
221
+ in the execution namespace. Set a variable named 'result'
222
+ to capture the return value.
223
+
224
+ Returns:
225
+ Observation with result in metadata["result"] or error in
226
+ metadata["error"].
227
+ """
228
+ namespace = self.get_callables()
229
+
230
+ result_dict: Dict[str, Any] = {}
231
+ try:
232
+ exec(code, namespace, result_dict)
233
+ result = result_dict.get("result")
234
+ return Observation(done=False, reward=0.0, metadata={"result": result})
235
+ except SyntaxError as e:
236
+ return Observation(
237
+ done=False, reward=0.0, metadata={"error": f"Syntax error: {str(e)}"}
238
+ )
239
+ except Exception as e:
240
+ return Observation(done=False, reward=0.0, metadata={"error": str(e)})
241
+
242
+ def _validate_tool_names(self, mcp_server: Any) -> None:
243
+ """
244
+ Validate that no tools use reserved names.
245
+
246
+ Reserved names (reset, step, state, close) are protected to maintain
247
+ the dual API boundary between infrastructure and agent APIs.
248
+
249
+ Args:
250
+ mcp_server: The FastMCP server to validate.
251
+
252
+ Raises:
253
+ ValueError: If any tool uses a reserved name.
254
+ """
255
+ tools_dict = self._get_server_tools(mcp_server)
256
+ if tools_dict:
257
+ tool_names = set(tools_dict.keys())
258
+ conflicts = tool_names & RESERVED_TOOL_NAMES
259
+ if conflicts:
260
+ raise ValueError(
261
+ f"MCP tools cannot use reserved names: {sorted(conflicts)}. "
262
+ f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}"
263
+ )
264
+
265
+ def tool(self, mode: Optional[str] = None) -> Callable:
266
+ """
267
+ Decorator for registering mode-aware tools.
268
+
269
+ Args:
270
+ mode: Optional mode for the tool ("production" or "simulation").
271
+ If None, tool is available in all modes.
272
+
273
+ Returns:
274
+ A decorator function for registering tools.
275
+
276
+ Raises:
277
+ ValueError: If mode is not None, "production", or "simulation".
278
+ """
279
+ if mode is not None and mode not in VALID_MODES:
280
+ raise ValueError(
281
+ f"Invalid mode '{mode}'. Mode must be 'production', 'simulation', or None."
282
+ )
283
+
284
+ def decorator(func: Callable) -> Callable:
285
+ tool_name = func.__name__
286
+ # Validate tool name is not reserved
287
+ if tool_name in RESERVED_TOOL_NAMES:
288
+ raise ValueError(
289
+ f"Tool name '{tool_name}' is reserved and cannot be used. "
290
+ f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}"
291
+ )
292
+
293
+ # If mode is None, register with FastMCP as usual
294
+ if mode is None:
295
+ decorated_func = self.mcp_server.tool()(func)
296
+ self._mode_tools[tool_name][None] = func
297
+ return decorated_func
298
+
299
+ # For mode-specific tools, don't register with FastMCP
300
+ # Instead, track them ourselves
301
+ self._mode_tools[tool_name][mode] = func
302
+
303
+ # Extract schema information from function signature
304
+ sig = inspect.signature(func)
305
+ schema = {
306
+ "type": "object",
307
+ "properties": {},
308
+ "required": [],
309
+ }
310
+
311
+ for param_name, param in sig.parameters.items():
312
+ # Get type annotation
313
+ param_type = param.annotation
314
+ json_type = "string" # default
315
+ if param_type in (int, "int"):
316
+ json_type = "integer"
317
+ elif param_type in (float, "float"):
318
+ json_type = "number"
319
+ elif param_type in (bool, "bool"):
320
+ json_type = "boolean"
321
+
322
+ schema["properties"][param_name] = {"type": json_type}
323
+
324
+ # If no default value, it's required
325
+ if param.default == inspect.Parameter.empty:
326
+ schema["required"].append(param_name)
327
+
328
+ # Store the schema for this mode-specific tool
329
+ self._mode_tool_schemas[tool_name][mode] = {
330
+ "name": tool_name,
331
+ "description": func.__doc__ or "",
332
+ "input_schema": schema,
333
+ }
334
+
335
+ return func
336
+
337
+ return decorator
338
+
339
+ def step(
340
+ self,
341
+ action: Action,
342
+ timeout_s: Optional[float] = None,
343
+ **kwargs: Any,
344
+ ) -> Observation:
345
+ """
346
+ Execute an action in the environment.
347
+
348
+ This method routes MCP-specific actions (ListToolsAction, CallToolAction)
349
+ to the appropriate handlers, while delegating all other actions to
350
+ the subclass's _step_impl() method.
351
+
352
+ Args:
353
+ action: The action to execute. Can be:
354
+ - ListToolsAction: Returns available MCP tools
355
+ - CallToolAction: Invokes a specific MCP tool
356
+ - Any other Action: Delegated to _step_impl()
357
+ timeout_s: Optional timeout in seconds for the action.
358
+ Defaults to MCP_TOOL_CALL_TIMEOUT (30s) for MCP actions.
359
+ **kwargs: Additional arguments passed to handlers.
360
+
361
+ Returns:
362
+ Observation appropriate to the action type:
363
+ - ListToolsObservation for ListToolsAction
364
+ - CallToolObservation for CallToolAction
365
+ - Subclass-defined Observation for other actions
366
+ """
367
+ if isinstance(action, ListToolsAction):
368
+ return self._handle_list_tools()
369
+ elif isinstance(action, CallToolAction):
370
+ return self._handle_call_tool(action, timeout_s=timeout_s)
371
+ else:
372
+ return self._step_impl(action, timeout_s=timeout_s, **kwargs)
373
+
374
+ def _handle_list_tools(self) -> ListToolsObservation:
375
+ """
376
+ Handle a ListToolsAction by querying the MCP server.
377
+
378
+ Returns:
379
+ ListToolsObservation containing all available tools with their
380
+ names, descriptions, and input schemas, filtered by current mode.
381
+ """
382
+ try:
383
+ # Get current mode
384
+ current_mode = getattr(self, "_mode", None)
385
+
386
+ # Start with tools from FastMCP server (mode=None tools)
387
+ tools_result = run_async_safely(self._async_list_tools())
388
+
389
+ # Build list of Tool objects
390
+ tools = []
391
+
392
+ # Add FastMCP tools that are not mode-specific
393
+ for tool in tools_result:
394
+ if tool.name not in self._mode_tool_schemas:
395
+ tools.append(
396
+ Tool(
397
+ name=tool.name,
398
+ description=tool.description or "",
399
+ input_schema=tool.inputSchema
400
+ if hasattr(tool, "inputSchema")
401
+ else {},
402
+ )
403
+ )
404
+
405
+ # Add mode-specific tools available in current mode
406
+ for tool_name, mode_schemas in self._mode_tool_schemas.items():
407
+ if None in mode_schemas:
408
+ # Tool available in all modes
409
+ schema = mode_schemas[None]
410
+ tools.append(
411
+ Tool(
412
+ name=schema["name"],
413
+ description=schema["description"],
414
+ input_schema=schema["input_schema"],
415
+ )
416
+ )
417
+ elif current_mode in mode_schemas:
418
+ # Tool available in current mode
419
+ schema = mode_schemas[current_mode]
420
+ tools.append(
421
+ Tool(
422
+ name=schema["name"],
423
+ description=schema["description"],
424
+ input_schema=schema["input_schema"],
425
+ )
426
+ )
427
+
428
+ return ListToolsObservation(tools=tools)
429
+
430
+ except Exception as e:
431
+ # Return an observation with error in metadata
432
+ return ListToolsObservation(
433
+ tools=[],
434
+ metadata={
435
+ "error": str(e),
436
+ "error_type": "list_tools_failed",
437
+ },
438
+ )
439
+
440
+ async def _async_list_tools(self) -> list:
441
+ """
442
+ Async helper to list tools from the MCP client.
443
+
444
+ Returns:
445
+ List of tool objects from the MCP server.
446
+ """
447
+ async with self.mcp_client:
448
+ return await self.mcp_client.list_tools()
449
+
450
+ def _handle_call_tool(
451
+ self,
452
+ action: CallToolAction,
453
+ timeout_s: Optional[float] = None,
454
+ ) -> CallToolObservation:
455
+ """
456
+ Handle a CallToolAction by invoking the specified tool.
457
+
458
+ Args:
459
+ action: The CallToolAction containing tool_name and arguments.
460
+ timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s).
461
+
462
+ Returns:
463
+ CallToolObservation with the tool's result or an error.
464
+ """
465
+ timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT
466
+
467
+ # Check if this is a mode-specific tool
468
+ tool_name = action.tool_name
469
+ current_mode = getattr(self, "_mode", None)
470
+
471
+ if tool_name in self._mode_tools:
472
+ mode_info = self._mode_tools[tool_name]
473
+
474
+ # Check if tool is available in current mode
475
+ # Tool is available if:
476
+ # 1. It has a None mode (available in all modes), OR
477
+ # 2. It has an implementation for the current mode
478
+ if None in mode_info:
479
+ # Use the mode-agnostic version
480
+ func = mode_info[None]
481
+ elif current_mode in mode_info:
482
+ # Use the mode-specific version
483
+ func = mode_info[current_mode]
484
+ else:
485
+ # Tool not available in current mode
486
+ return CallToolObservation(
487
+ tool_name=tool_name,
488
+ result=None,
489
+ error=ToolError(
490
+ error_type=ToolErrorType.TOOL_NOT_FOUND,
491
+ message=f"Tool '{tool_name}' not available in {current_mode} mode",
492
+ ),
493
+ )
494
+
495
+ # Call the mode-specific function directly
496
+ try:
497
+ # Check if function is async and await if necessary
498
+ if inspect.iscoroutinefunction(func):
499
+ result = run_async_safely(func(**action.arguments))
500
+ else:
501
+ result = func(**action.arguments)
502
+
503
+ # Wrap result in CallToolResult format to match FastMCP behavior
504
+ return CallToolObservation(
505
+ tool_name=tool_name,
506
+ result=CallToolResult(
507
+ content=[TextContent(type="text", text=str(result))],
508
+ structured_content={"result": result},
509
+ meta=None,
510
+ data=result,
511
+ is_error=False,
512
+ ),
513
+ )
514
+ except Exception as e:
515
+ return CallToolObservation(
516
+ tool_name=tool_name,
517
+ result=None,
518
+ error=ToolError(
519
+ error_type=ToolErrorType.EXECUTION_ERROR,
520
+ message=str(e),
521
+ ),
522
+ )
523
+
524
+ # Not a mode-specific tool, use FastMCP
525
+ try:
526
+ # Run the async call_tool with timeout
527
+ # Use run_async_safely to handle both sync and async contexts
528
+ result = run_async_safely(
529
+ asyncio.wait_for(
530
+ self._async_call_tool(action.tool_name, action.arguments),
531
+ timeout=timeout,
532
+ )
533
+ )
534
+
535
+ return CallToolObservation(
536
+ tool_name=action.tool_name,
537
+ result=result,
538
+ )
539
+
540
+ except asyncio.TimeoutError:
541
+ return CallToolObservation(
542
+ tool_name=action.tool_name,
543
+ result=None,
544
+ error=ToolError(
545
+ error_type=ToolErrorType.TIMEOUT,
546
+ message=f"Tool '{action.tool_name}' timed out after {timeout} seconds",
547
+ ),
548
+ )
549
+
550
+ except Exception as e:
551
+ error_message = str(e)
552
+
553
+ # Determine error type based on the exception
554
+ if (
555
+ "not found" in error_message.lower()
556
+ or "unknown tool" in error_message.lower()
557
+ ):
558
+ error_type = ToolErrorType.TOOL_NOT_FOUND
559
+ elif (
560
+ "invalid" in error_message.lower()
561
+ or "argument" in error_message.lower()
562
+ ):
563
+ error_type = ToolErrorType.INVALID_ARGS
564
+ else:
565
+ error_type = ToolErrorType.EXECUTION_ERROR
566
+
567
+ return CallToolObservation(
568
+ tool_name=action.tool_name,
569
+ result=None,
570
+ error=ToolError(
571
+ error_type=error_type,
572
+ message=error_message,
573
+ ),
574
+ )
575
+
576
+ async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any:
577
+ """
578
+ Async helper to call a tool on the MCP server.
579
+
580
+ Args:
581
+ tool_name: Name of the tool to invoke.
582
+ arguments: Dictionary of arguments to pass to the tool.
583
+
584
+ Returns:
585
+ The result from the tool execution.
586
+ """
587
+ async with self.mcp_client:
588
+ return await self.mcp_client.call_tool(tool_name, arguments)
589
+
590
+ @abstractmethod
591
+ def _step_impl(
592
+ self,
593
+ action: Action,
594
+ timeout_s: Optional[float] = None,
595
+ **kwargs: Any,
596
+ ) -> Observation:
597
+ """
598
+ Handle non-MCP actions in the environment.
599
+
600
+ Subclasses must implement this method to handle any actions that are
601
+ not ListToolsAction or CallToolAction. This is where environment-specific
602
+ action processing should occur.
603
+
604
+ Args:
605
+ action: The action to execute (guaranteed not to be an MCP action).
606
+ timeout_s: Optional timeout in seconds.
607
+ **kwargs: Additional arguments.
608
+
609
+ Returns:
610
+ An Observation appropriate for the action.
611
+ """
612
+ pass
613
+
614
+ def close(self) -> None:
615
+ """
616
+ Clean up resources used by the environment.
617
+
618
+ This method cleans up the MCP client and any other resources.
619
+ Subclasses should call super().close() if they override this method.
620
+ """
621
+ # The MCP client uses async context manager, so cleanup happens
622
+ # automatically when the context exits. We just clear references.
623
+ self.mcp_client = None
624
+ self.mcp_server = None
src/core/env_server/mcp_types.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ MCP (Model Context Protocol) type definitions for OpenEnv.
9
+
10
+ This module defines strongly typed models for MCP tool discovery and invocation,
11
+ following RFC 003. These types map MCP's REST-like API (tools/list, tools/call)
12
+ to Gym-style action types.
13
+
14
+ Key design decisions:
15
+ - Tool discovery (list_tools) does NOT require reset() first
16
+ - Reserved tool names (reset, step, state, close) are prohibited
17
+ - Both step() and WebSocket /mcp paths are supported
18
+ """
19
+
20
+ from enum import Enum
21
+ from typing import Any, Dict, List, Literal, Optional, Union
22
+
23
+ from pydantic import BaseModel, ConfigDict, Field
24
+
25
+ from .types import Action, BaseMessage, Observation
26
+
27
+
28
+ # =============================================================================
29
+ # JSON-RPC 2.0 Types
30
+ # =============================================================================
31
+
32
+
33
+ class JsonRpcErrorCode(int, Enum):
34
+ """
35
+ Standard JSON-RPC 2.0 error codes.
36
+
37
+ See: https://www.jsonrpc.org/specification#error_object
38
+ """
39
+
40
+ # Standard JSON-RPC errors
41
+ PARSE_ERROR = -32700 # Invalid JSON was received
42
+ INVALID_REQUEST = -32600 # JSON is not a valid Request object
43
+ METHOD_NOT_FOUND = -32601 # Method does not exist / is not available
44
+ INVALID_PARAMS = -32602 # Invalid method parameter(s)
45
+ INTERNAL_ERROR = -32603 # Internal JSON-RPC error
46
+
47
+ # Server errors (reserved for implementation-defined errors)
48
+ SERVER_ERROR = -32000 # Generic server error
49
+
50
+
51
+ class McpMethod(str, Enum):
52
+ """Supported MCP method names."""
53
+
54
+ TOOLS_LIST = "tools/list"
55
+ TOOLS_CALL = "tools/call"
56
+
57
+
58
+ class JsonRpcError(BaseModel):
59
+ """
60
+ JSON-RPC 2.0 error object.
61
+
62
+ See: https://www.jsonrpc.org/specification#error_object
63
+ """
64
+
65
+ model_config = ConfigDict(extra="forbid")
66
+
67
+ code: int = Field(description="Error code indicating the error type")
68
+ message: str = Field(description="Short description of the error")
69
+ data: Optional[Any] = Field(
70
+ default=None, description="Additional error information"
71
+ )
72
+
73
+ @classmethod
74
+ def from_code(
75
+ cls, code: JsonRpcErrorCode, message: Optional[str] = None, data: Any = None
76
+ ) -> "JsonRpcError":
77
+ """Create an error from a standard error code."""
78
+ default_messages = {
79
+ JsonRpcErrorCode.PARSE_ERROR: "Parse error",
80
+ JsonRpcErrorCode.INVALID_REQUEST: "Invalid Request",
81
+ JsonRpcErrorCode.METHOD_NOT_FOUND: "Method not found",
82
+ JsonRpcErrorCode.INVALID_PARAMS: "Invalid params",
83
+ JsonRpcErrorCode.INTERNAL_ERROR: "Internal error",
84
+ JsonRpcErrorCode.SERVER_ERROR: "Server error",
85
+ }
86
+ return cls(
87
+ code=code.value,
88
+ message=message or default_messages.get(code, "Unknown error"),
89
+ data=data,
90
+ )
91
+
92
+
93
+ class JsonRpcRequest(BaseModel):
94
+ """
95
+ JSON-RPC 2.0 request object.
96
+
97
+ See: https://www.jsonrpc.org/specification#request_object
98
+ """
99
+
100
+ model_config = ConfigDict(extra="forbid")
101
+
102
+ jsonrpc: Literal["2.0"] = Field(description="JSON-RPC version, must be '2.0'")
103
+ method: str = Field(description="Name of the method to be invoked")
104
+ params: Dict[str, Any] = Field(
105
+ default_factory=dict, description="Parameter values for the method"
106
+ )
107
+ id: Optional[Union[str, int]] = Field(
108
+ default=None, description="Request identifier established by the client"
109
+ )
110
+
111
+
112
+ class JsonRpcResponse(BaseModel):
113
+ """
114
+ JSON-RPC 2.0 response object.
115
+
116
+ Per JSON-RPC 2.0 spec, a response has either 'result' or 'error', not both.
117
+ This model excludes None values during serialization to comply with the spec.
118
+
119
+ See: https://www.jsonrpc.org/specification#response_object
120
+ """
121
+
122
+ model_config = ConfigDict(extra="forbid")
123
+
124
+ jsonrpc: Literal["2.0"] = Field(default="2.0", description="JSON-RPC version")
125
+ result: Optional[Any] = Field(
126
+ default=None, description="Result of the method invocation"
127
+ )
128
+ error: Optional[JsonRpcError] = Field(
129
+ default=None, description="Error object if method invocation failed"
130
+ )
131
+ id: Optional[Union[str, int]] = Field(
132
+ default=None, description="Request identifier from the request"
133
+ )
134
+
135
+ def model_dump(self, **kwargs) -> Dict[str, Any]:
136
+ """Serialize to dict, excluding result or error when None (JSON-RPC compliance)."""
137
+ # Always include jsonrpc and id, but only include result OR error
138
+ data: Dict[str, Any] = {"jsonrpc": self.jsonrpc, "id": self.id}
139
+ if self.error is not None:
140
+ data["error"] = (
141
+ self.error.model_dump()
142
+ if hasattr(self.error, "model_dump")
143
+ else self.error
144
+ )
145
+ else:
146
+ # Only include result if there's no error
147
+ data["result"] = self.result
148
+ return data
149
+
150
+ def model_dump_json(self, **kwargs) -> str:
151
+ """Serialize to JSON string, excluding result or error when None (JSON-RPC compliance)."""
152
+ import json
153
+
154
+ return json.dumps(self.model_dump())
155
+
156
+ @classmethod
157
+ def success(
158
+ cls, result: Any, request_id: Optional[Union[str, int]] = None
159
+ ) -> "JsonRpcResponse":
160
+ """Create a success response."""
161
+ return cls(result=result, id=request_id)
162
+
163
+ @classmethod
164
+ def error_response(
165
+ cls,
166
+ code: JsonRpcErrorCode,
167
+ message: Optional[str] = None,
168
+ data: Any = None,
169
+ request_id: Optional[Union[str, int]] = None,
170
+ ) -> "JsonRpcResponse":
171
+ """Create an error response from a standard error code."""
172
+ return cls(
173
+ error=JsonRpcError.from_code(code, message, data),
174
+ id=request_id,
175
+ )
176
+
177
+
178
+ # =============================================================================
179
+ # MCP Tool Types
180
+ # =============================================================================
181
+
182
+
183
+ class Tool(BaseModel):
184
+ """
185
+ Strongly typed MCP tool specification.
186
+
187
+ Follows the MCP ToolSpec format for tool discovery.
188
+ See: https://modelcontextprotocol.io/specification/2025-06-18/server/tools
189
+ """
190
+
191
+ model_config = ConfigDict(extra="forbid")
192
+
193
+ name: str = Field(description="Unique identifier for the tool")
194
+ description: str = Field(
195
+ description="Human-readable description of what the tool does"
196
+ )
197
+ input_schema: Dict[str, Any] = Field(
198
+ description="JSON Schema for the tool's input parameters"
199
+ )
200
+
201
+
202
+ class ToolErrorType(str, Enum):
203
+ """Types of errors that can occur during tool execution."""
204
+
205
+ EXECUTION_ERROR = "execution_error" # Tool ran but failed
206
+ INVALID_ARGS = "invalid_args" # Invalid arguments provided
207
+ TRANSPORT_ERROR = "transport_error" # Communication failure
208
+ TOOL_NOT_FOUND = "tool_not_found" # Tool doesn't exist
209
+ TIMEOUT = "timeout" # Operation timed out
210
+
211
+
212
+ class ToolError(BaseModel):
213
+ """
214
+ Structured error for tool execution failures.
215
+
216
+ This is used for transport/framework errors, NOT for errors returned
217
+ by the tool itself (those go in the result field).
218
+ """
219
+
220
+ model_config = ConfigDict(extra="forbid")
221
+
222
+ error_type: ToolErrorType = Field(description="Category of the error")
223
+ message: str = Field(description="Human-readable error message")
224
+
225
+
226
+ # --- MCP Actions ---
227
+
228
+
229
+ class ListToolsAction(Action):
230
+ """
231
+ Request list of available tools from the environment.
232
+
233
+ This action triggers MCP's tools/list operation and returns
234
+ all available tools with their schemas.
235
+
236
+ Note: Does NOT require reset() to be called first.
237
+ """
238
+
239
+ type: Literal["list_tools"] = Field(
240
+ default="list_tools", description="Action type discriminator"
241
+ )
242
+
243
+
244
+ class CallToolAction(Action):
245
+ """
246
+ Call a specific tool via MCP.
247
+
248
+ This action triggers MCP's tools/call operation with the
249
+ specified tool name and arguments.
250
+ """
251
+
252
+ type: Literal["call_tool"] = Field(
253
+ default="call_tool", description="Action type discriminator"
254
+ )
255
+ tool_name: str = Field(description="Name of the tool to call")
256
+ arguments: Dict[str, Any] = Field(
257
+ default_factory=dict, description="Arguments to pass to the tool"
258
+ )
259
+
260
+
261
+ # --- MCP Observations ---
262
+
263
+
264
+ class ListToolsObservation(Observation):
265
+ """
266
+ Response containing available tools.
267
+
268
+ Returned when processing a ListToolsAction.
269
+ """
270
+
271
+ tools: List[Tool] = Field(description="List of available tools with their schemas")
272
+
273
+
274
+ class CallToolObservation(Observation):
275
+ """
276
+ Response from tool execution.
277
+
278
+ Contains the tool's result or an error if the call failed.
279
+ Tool-specific errors (from the tool itself) are included in the result.
280
+ Transport/framework errors use the error field.
281
+ """
282
+
283
+ tool_name: str = Field(description="Name of the tool that was called")
284
+ result: Any = Field(
285
+ default=None, description="Tool-specific result (may include tool errors)"
286
+ )
287
+ error: Optional[ToolError] = Field(
288
+ default=None, description="Transport/framework error if call failed"
289
+ )
290
+
291
+
292
+ # --- WebSocket Message Types for MCP ---
293
+
294
+
295
+ class WSMCPMessage(BaseMessage):
296
+ """
297
+ WebSocket message for MCP JSON-RPC requests.
298
+
299
+ Allows direct MCP access via WebSocket for production inference,
300
+ bypassing the step() API.
301
+ """
302
+
303
+ type: Literal["mcp"] = Field(default="mcp", description="Message type")
304
+ data: Dict[str, Any] = Field(description="JSON-RPC payload (method, params, id)")
305
+
306
+
307
+ class WSMCPResponse(BaseModel):
308
+ """
309
+ WebSocket response for MCP JSON-RPC.
310
+
311
+ Contains the JSON-RPC response from the MCP server.
312
+ """
313
+
314
+ model_config = ConfigDict(extra="forbid")
315
+
316
+ type: str = Field(default="mcp", description="Response type")
317
+ data: Dict[str, Any] = Field(description="JSON-RPC response payload")
318
+
319
+
320
+ # Reserved tool names that cannot be used (protects dual API boundary)
321
+ RESERVED_TOOL_NAMES = frozenset(["reset", "step", "state", "close"])
src/core/env_server/route_config.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Route configuration utilities for declarative FastAPI route registration.
9
+
10
+ This module provides utilities to reduce boilerplate in route registration
11
+ by using configuration objects instead of repeated function calls.
12
+ """
13
+
14
+ from dataclasses import dataclass
15
+ from typing import Callable, List, Type
16
+
17
+ from fastapi import FastAPI
18
+ from pydantic import BaseModel
19
+
20
+
21
+ @dataclass
22
+ class GetEndpointConfig:
23
+ """Configuration for a simple GET endpoint."""
24
+
25
+ path: str
26
+ handler: Callable[[], BaseModel | dict]
27
+ response_model: Type[BaseModel] | type[dict]
28
+ tag: str
29
+ summary: str
30
+ description: str
31
+
32
+
33
+ def register_get_endpoints(app: FastAPI, configs: List[GetEndpointConfig]) -> None:
34
+ """
35
+ Register multiple GET endpoints from configuration.
36
+
37
+ Args:
38
+ app: FastAPI application instance
39
+ configs: List of GET endpoint configurations
40
+ """
41
+ for config in configs:
42
+ # Capture handler in a closure to avoid non-serializable default parameter
43
+ def make_endpoint(
44
+ handler: Callable[[], BaseModel | dict],
45
+ ) -> Callable[[], BaseModel | dict]:
46
+ async def endpoint() -> BaseModel | dict:
47
+ return handler()
48
+
49
+ return endpoint
50
+
51
+ app.get(
52
+ config.path,
53
+ response_model=config.response_model,
54
+ tags=[config.tag],
55
+ summary=config.summary,
56
+ description=config.description,
57
+ )(make_endpoint(config.handler))
src/core/env_server/serialization.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Shared serialization and deserialization utilities for OpenEnv HTTP servers.
9
+
10
+ This module provides common utilities for converting between JSON dictionaries
11
+ and Pydantic models (Action/Observation) to eliminate code duplication across
12
+ HTTP server and web interface implementations.
13
+ """
14
+
15
+ from typing import Any, Dict, Type
16
+
17
+ from .types import Action, Observation
18
+
19
+
20
+ def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action:
21
+ """
22
+ Convert JSON dict to Action instance using Pydantic validation.
23
+
24
+ This is a basic deserialization that works for most environments.
25
+ For special cases (e.g., tensor fields, custom type conversions),
26
+ use deserialize_action_with_preprocessing().
27
+
28
+ Args:
29
+ action_data: Dictionary containing action data
30
+ action_cls: The Action subclass to instantiate
31
+
32
+ Returns:
33
+ Action instance
34
+
35
+ Raises:
36
+ ValidationError: If action_data is invalid for the action class
37
+
38
+ Note:
39
+ This uses Pydantic's model_validate() for automatic validation.
40
+ """
41
+ return action_cls.model_validate(action_data)
42
+
43
+
44
+ def deserialize_action_with_preprocessing(
45
+ action_data: Dict[str, Any], action_cls: Type[Action]
46
+ ) -> Action:
47
+ """
48
+ Convert JSON dict to Action instance with preprocessing for special types.
49
+
50
+ This version handles common type conversions needed for web interfaces:
51
+ - Converting lists/strings to tensors for 'tokens' field
52
+ - Converting string action_id to int
53
+ - Other custom preprocessing as needed
54
+
55
+ Args:
56
+ action_data: Dictionary containing action data
57
+ action_cls: The Action subclass to instantiate
58
+
59
+ Returns:
60
+ Action instance
61
+
62
+ Raises:
63
+ ValidationError: If action_data is invalid for the action class
64
+ """
65
+ processed_data = {}
66
+
67
+ for key, value in action_data.items():
68
+ if key == "tokens" and isinstance(value, (list, str)):
69
+ # Convert list or string to tensor
70
+ if isinstance(value, str):
71
+ # If it's a string, try to parse it as a list of numbers
72
+ try:
73
+ import json
74
+
75
+ value = json.loads(value)
76
+ except Exception:
77
+ # If parsing fails, treat as empty list
78
+ value = []
79
+ if isinstance(value, list):
80
+ try:
81
+ import torch # type: ignore
82
+
83
+ processed_data[key] = torch.tensor(value, dtype=torch.long)
84
+ except ImportError:
85
+ # If torch not available, keep as list
86
+ processed_data[key] = value
87
+ else:
88
+ processed_data[key] = value
89
+ elif key == "action_id" and isinstance(value, str):
90
+ # Convert action_id from string to int
91
+ try:
92
+ processed_data[key] = int(value)
93
+ except ValueError:
94
+ # If conversion fails, keep original value
95
+ processed_data[key] = value
96
+ else:
97
+ processed_data[key] = value
98
+
99
+ return action_cls.model_validate(processed_data)
100
+
101
+
102
+ def serialize_observation(observation: Observation) -> Dict[str, Any]:
103
+ """
104
+ Convert Observation instance to JSON-compatible dict using Pydantic.
105
+
106
+ Args:
107
+ observation: Observation instance
108
+
109
+ Returns:
110
+ Dictionary compatible with EnvClient._parse_result()
111
+
112
+ The format matches what EnvClient expects:
113
+ {
114
+ "observation": {...}, # Observation fields
115
+ "reward": float | None,
116
+ "done": bool,
117
+ }
118
+ """
119
+ # Use Pydantic's model_dump() for serialization
120
+ obs_dict = observation.model_dump(
121
+ exclude={
122
+ "reward",
123
+ "done",
124
+ "metadata",
125
+ } # Exclude these from observation dict
126
+ )
127
+
128
+ # Extract reward and done directly from the observation
129
+ reward = observation.reward
130
+ done = observation.done
131
+
132
+ # Return in EnvClient expected format
133
+ return {
134
+ "observation": obs_dict,
135
+ "reward": reward,
136
+ "done": done,
137
+ }
src/core/env_server/types.py CHANGED
@@ -4,54 +4,384 @@
4
  # This source code is licensed under the BSD-style license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
- from dataclasses import dataclass, field
8
- from typing import Any, Dict, List, Optional, Union
 
 
9
 
10
 
11
  # Type aliases
12
  Scalar = Union[int, float, bool]
13
 
14
 
15
- @dataclass(kw_only=True)
16
- class Action:
17
- """Base class for all environment actions."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- metadata: Dict[str, Any] = field(default_factory=dict)
20
 
 
 
21
 
22
- @dataclass(kw_only=True)
23
- class Observation:
24
- """Base class for all environment observations."""
 
25
 
26
- done: bool = False
27
- reward: Union[bool, int, float, None] = None
28
- metadata: Dict[str, Any] = field(default_factory=dict)
 
 
 
29
 
30
 
31
- @dataclass
32
- class State:
33
- """Base class for environment state."""
34
 
35
- episode_id: Optional[str] = None
36
- step_count: int = 0
37
 
 
 
 
 
 
 
 
 
 
38
 
39
- @dataclass
40
- class CodeExecResult:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  """Result of code execution containing stdout, stderr, and exit code."""
42
 
43
- stdout: str
44
- stderr: str
45
- exit_code: int
46
 
47
 
48
- @dataclass
49
- class EnvironmentMetadata:
50
  """Metadata about an environment for documentation and UI purposes."""
51
-
52
- name: str
53
- description: str
54
- readme_content: Optional[str] = None
55
- version: Optional[str] = None
56
- author: Optional[str] = None
57
- documentation_url: Optional[str] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  # This source code is licensed under the BSD-style license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
+ from enum import Enum
8
+ from typing import Annotated, Any, Dict, Literal, Optional, Union
9
+
10
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
11
 
12
 
13
  # Type aliases
14
  Scalar = Union[int, float, bool]
15
 
16
 
17
+ # =============================================================================
18
+ # Enums for Type Safety
19
+ # =============================================================================
20
+
21
+
22
+ class ServerMode(str, Enum):
23
+ """Server operation mode."""
24
+
25
+ SIMULATION = "simulation"
26
+ PRODUCTION = "production"
27
+
28
+
29
+ class HealthStatus(str, Enum):
30
+ """Server health status values."""
31
+
32
+ HEALTHY = "healthy"
33
+ UNHEALTHY = "unhealthy"
34
+ DEGRADED = "degraded"
35
+
36
+
37
+ class WSErrorCode(str, Enum):
38
+ """WebSocket error codes for structured error handling."""
39
+
40
+ INVALID_JSON = "INVALID_JSON"
41
+ UNKNOWN_TYPE = "UNKNOWN_TYPE"
42
+ VALIDATION_ERROR = "VALIDATION_ERROR"
43
+ EXECUTION_ERROR = "EXECUTION_ERROR"
44
+ CAPACITY_REACHED = "CAPACITY_REACHED"
45
+ FACTORY_ERROR = "FACTORY_ERROR"
46
+ SESSION_ERROR = "SESSION_ERROR"
47
+
48
+
49
+ # =============================================================================
50
+ # Core Types
51
+ # =============================================================================
52
+
53
+
54
+ class Action(BaseModel):
55
+ """Base class for all environment actions.
56
+
57
+ All action subclasses should inherit from this base class.
58
+ Uses Pydantic for automatic validation and serialization.
59
+ """
60
+
61
+ model_config = ConfigDict(
62
+ extra="forbid", # Reject unknown fields
63
+ validate_assignment=True, # Validate on field assignment
64
+ arbitrary_types_allowed=True, # Allow numpy arrays, torch tensors, etc.
65
+ )
66
+
67
+ metadata: Dict[str, Any] = Field(
68
+ default_factory=dict, description="Additional metadata for the action"
69
+ )
70
+
71
+
72
+ class Observation(BaseModel):
73
+ """Base class for all environment observations.
74
+
75
+ All observation subclasses should inherit from this base class.
76
+ Uses Pydantic for automatic validation and serialization.
77
+ """
78
+
79
+ model_config = ConfigDict(
80
+ extra="forbid",
81
+ validate_assignment=True,
82
+ arbitrary_types_allowed=True,
83
+ )
84
+
85
+ done: bool = Field(default=False, description="Whether the episode has terminated")
86
+ reward: bool | int | float | None = Field(
87
+ default=None, description="Reward signal from the last action"
88
+ )
89
+ metadata: Dict[str, Any] = Field(
90
+ default_factory=dict, description="Additional metadata for the observation"
91
+ )
92
 
 
93
 
94
+ class ResetRequest(BaseModel):
95
+ """Request model for environment reset."""
96
 
97
+ model_config = ConfigDict(
98
+ extra="allow", # Allow extra fields for custom reset parameters
99
+ json_schema_extra={"examples": [{"seed": 42, "episode_id": "episode-001"}, {}]},
100
+ )
101
 
102
+ seed: Optional[int] = Field(
103
+ default=None, ge=0, description="Random seed for reproducible episodes"
104
+ )
105
+ episode_id: Optional[str] = Field(
106
+ default=None, max_length=255, description="Custom episode identifier"
107
+ )
108
 
109
 
110
+ class ResetResponse(BaseModel):
111
+ """Response model for environment reset."""
 
112
 
113
+ model_config = ConfigDict(extra="forbid")
 
114
 
115
+ observation: Dict[str, Any] = Field(
116
+ ..., description="Initial observation from the environment"
117
+ )
118
+ reward: Optional[float] = Field(
119
+ default=None, description="Initial reward (typically None at reset)"
120
+ )
121
+ done: bool = Field(
122
+ default=False, description="Whether episode is already done (typically False)"
123
+ )
124
 
125
+
126
+ class StepRequest(BaseModel):
127
+ """Request model for environment step."""
128
+
129
+ model_config = ConfigDict(
130
+ extra="allow", # Allow extra fields for custom step parameters
131
+ json_schema_extra={
132
+ "examples": [
133
+ {"action": {"value": 1}, "timeout_s": 30.0},
134
+ {"action": {"value": 1}, "render": True, "verbose": False},
135
+ ]
136
+ },
137
+ )
138
+
139
+ action: Dict[str, Any] = Field(
140
+ ...,
141
+ description="Action to execute, must conform to environment's action schema",
142
+ )
143
+ timeout_s: Optional[float] = Field(
144
+ default=None,
145
+ gt=0,
146
+ description="Optional timeout in seconds for action execution",
147
+ )
148
+ request_id: Optional[str] = Field(
149
+ default=None,
150
+ max_length=255,
151
+ description="Optional request identifier for tracking",
152
+ )
153
+
154
+
155
+ class StepResponse(BaseModel):
156
+ """Response model for environment step."""
157
+
158
+ model_config = ConfigDict(extra="forbid")
159
+
160
+ observation: Dict[str, Any] = Field(
161
+ ..., description="Observation resulting from the action"
162
+ )
163
+ reward: Optional[float] = Field(
164
+ default=None, description="Reward signal from the action"
165
+ )
166
+ done: bool = Field(default=False, description="Whether the episode has terminated")
167
+
168
+
169
+ class BaseMessage(BaseModel):
170
+ """Base class for WebSocket messages with shared configuration."""
171
+
172
+ model_config = ConfigDict(
173
+ extra="forbid",
174
+ validate_assignment=True,
175
+ )
176
+
177
+
178
+ class State(BaseModel):
179
+ """Base class for environment state.
180
+
181
+ Represents internal environment state, separate from observations.
182
+ """
183
+
184
+ model_config = ConfigDict(
185
+ extra="allow", # Allow extra fields for flexibility
186
+ validate_assignment=True,
187
+ arbitrary_types_allowed=True,
188
+ )
189
+
190
+ episode_id: Optional[str] = Field(
191
+ default=None, description="Unique identifier for the current episode"
192
+ )
193
+ step_count: int = Field(
194
+ default=0,
195
+ ge=0, # Greater than or equal to 0
196
+ description="Number of steps taken in the current episode",
197
+ )
198
+
199
+
200
+ class CodeExecResult(BaseMessage):
201
  """Result of code execution containing stdout, stderr, and exit code."""
202
 
203
+ stdout: str = Field(description="Standard output from code execution")
204
+ stderr: str = Field(description="Standard error from code execution")
205
+ exit_code: int = Field(description="Exit code from code execution")
206
 
207
 
208
+ class EnvironmentMetadata(BaseMessage):
 
209
  """Metadata about an environment for documentation and UI purposes."""
210
+
211
+ name: str = Field(description="Name of the environment")
212
+ description: str = Field(description="Description of what the environment does")
213
+ readme_content: Optional[str] = Field(
214
+ default=None, description="Content of the README file for the environment"
215
+ )
216
+ version: Optional[str] = Field(
217
+ default=None, description="Version of the environment"
218
+ )
219
+ author: Optional[str] = Field(default=None, description="Author of the environment")
220
+ documentation_url: Optional[str] = Field(
221
+ default=None, description="URL to the environment's documentation"
222
+ )
223
+
224
+
225
+ class SchemaResponse(BaseMessage):
226
+ """Response model for the combined schema endpoint."""
227
+
228
+ action: Dict[str, Any] = Field(
229
+ description="JSON schema for actions accepted by this environment"
230
+ )
231
+ observation: Dict[str, Any] = Field(
232
+ description="JSON schema for observations returned by this environment"
233
+ )
234
+ state: Dict[str, Any] = Field(
235
+ description="JSON schema for environment state objects"
236
+ )
237
+
238
+
239
+ class HealthResponse(BaseMessage):
240
+ """Response model for health check endpoint."""
241
+
242
+ status: HealthStatus = Field(
243
+ default=HealthStatus.HEALTHY,
244
+ description="Health status of the environment server",
245
+ )
246
+
247
+
248
+ class WSResetMessage(BaseMessage):
249
+ """WebSocket message to reset the environment."""
250
+
251
+ type: Literal["reset"] = Field(default="reset", description="Message type")
252
+ data: Dict[str, Any] = Field(
253
+ default_factory=dict,
254
+ description="Optional reset parameters (seed, episode_id, etc.)",
255
+ )
256
+
257
+
258
+ class WSStepMessage(BaseMessage):
259
+ """WebSocket message to execute a step."""
260
+
261
+ type: Literal["step"] = Field(default="step", description="Message type")
262
+ data: Dict[str, Any] = Field(
263
+ ..., description="Action data conforming to environment's action schema"
264
+ )
265
+
266
+
267
+ class WSStateMessage(BaseMessage):
268
+ """WebSocket message to request current state."""
269
+
270
+ type: Literal["state"] = Field(default="state", description="Message type")
271
+
272
+
273
+ class WSCloseMessage(BaseMessage):
274
+ """WebSocket message to close the session."""
275
+
276
+ type: Literal["close"] = Field(default="close", description="Message type")
277
+
278
+
279
+ # Discriminated union for incoming WebSocket messages
280
+ # Note: WSMCPMessage is defined in mcp_types.py to avoid circular imports
281
+ # The union here covers the core message types; MCP messages are handled separately
282
+ WSIncomingMessage = Annotated[
283
+ WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage,
284
+ Field(discriminator="type"),
285
+ ]
286
+
287
+
288
+ class WSObservationResponse(BaseModel):
289
+ """WebSocket response containing an observation."""
290
+
291
+ model_config = ConfigDict(extra="forbid")
292
+
293
+ type: Literal["observation"] = Field(
294
+ default="observation", description="Response type"
295
+ )
296
+ data: Dict[str, Any] = Field(description="Observation data")
297
+
298
+
299
+ class WSStateResponse(BaseModel):
300
+ """WebSocket response containing environment state."""
301
+
302
+ model_config = ConfigDict(extra="forbid")
303
+
304
+ type: Literal["state"] = Field(default="state", description="Response type")
305
+ data: Dict[str, Any] = Field(description="State data")
306
+
307
+
308
+ class WSErrorResponse(BaseModel):
309
+ """WebSocket response for errors."""
310
+
311
+ model_config = ConfigDict(extra="forbid")
312
+
313
+ type: Literal["error"] = Field(default="error", description="Response type")
314
+ data: Dict[str, Any] = Field(description="Error details including message and code")
315
+
316
+
317
+ class ConcurrencyConfig(BaseMessage):
318
+ """Configuration for concurrent environment sessions."""
319
+
320
+ max_concurrent_envs: int = Field(
321
+ default=1,
322
+ ge=1,
323
+ description="Maximum number of concurrent WebSocket sessions allowed",
324
+ )
325
+ session_timeout: Optional[float] = Field(
326
+ default=None,
327
+ gt=0,
328
+ description="Timeout in seconds for inactive sessions. None means no timeout.",
329
+ )
330
+
331
+
332
+ class ServerCapacityStatus(BaseMessage):
333
+ """Status of server capacity for concurrent sessions."""
334
+
335
+ active_sessions: int = Field(
336
+ ge=0,
337
+ description="Number of currently active sessions",
338
+ )
339
+ max_sessions: int = Field(
340
+ ge=1,
341
+ description="Maximum number of allowed sessions",
342
+ )
343
+
344
+ @model_validator(mode="after")
345
+ def check_capacity_bounds(self) -> "ServerCapacityStatus":
346
+ if self.active_sessions > self.max_sessions:
347
+ raise ValueError(
348
+ f"active_sessions ({self.active_sessions}) cannot exceed "
349
+ f"max_sessions ({self.max_sessions})"
350
+ )
351
+ return self
352
+
353
+ @property
354
+ def available_slots(self) -> int:
355
+ """Number of available session slots."""
356
+ return self.max_sessions - self.active_sessions
357
+
358
+ @property
359
+ def is_at_capacity(self) -> bool:
360
+ """Whether the server has reached maximum capacity."""
361
+ return self.available_slots == 0
362
+
363
+ @classmethod
364
+ def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus":
365
+ """Create status from active and max session counts."""
366
+ return cls(
367
+ active_sessions=active,
368
+ max_sessions=max_sessions,
369
+ )
370
+
371
+
372
+ class SessionInfo(BaseMessage):
373
+ """Information about an active session."""
374
+
375
+ session_id: str = Field(description="Unique identifier for the session")
376
+ created_at: float = Field(description="Unix timestamp when the session was created")
377
+ last_activity_at: float = Field(
378
+ description="Unix timestamp of the last activity in the session"
379
+ )
380
+ step_count: int = Field(
381
+ default=0,
382
+ ge=0,
383
+ description="Number of steps executed in this session",
384
+ )
385
+ environment_type: str = Field(
386
+ description="Environment type for this session (e.g. `CodingEnv`)"
387
+ )
src/core/env_server/web_interface.py CHANGED
@@ -7,61 +7,164 @@
7
  """
8
  Web interface for OpenEnv environments.
9
 
10
- This module provides a web-based interface for interacting with OpenEnv environments,
11
- including a two-pane layout for HumanAgent interaction and state observation.
 
12
  """
13
 
14
  from __future__ import annotations
15
 
 
16
  import json
17
- import time
18
- from dataclasses import asdict, dataclass
19
- from typing import Any, Dict, List, Optional, Type
20
  from datetime import datetime
 
21
 
22
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
23
- from fastapi.responses import HTMLResponse, FileResponse
24
- from fastapi.staticfiles import StaticFiles
25
- from pydantic import BaseModel
26
 
 
 
27
  from .interfaces import Environment
28
- from .types import Action, Observation, State, EnvironmentMetadata
 
29
 
 
 
 
30
 
31
- def load_environment_metadata(env: Environment, env_name: Optional[str] = None) -> EnvironmentMetadata:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  """
33
  Load environment metadata including README content.
34
-
35
  Args:
36
- env: The environment instance
 
 
 
37
  env_name: Optional environment name for README file lookup
38
-
39
  Returns:
40
  EnvironmentMetadata with loaded information
41
  """
42
- # Try to get metadata from environment if it has a method for it
43
- if hasattr(env, 'get_metadata'):
 
 
 
 
 
 
 
 
 
 
44
  return env.get_metadata()
45
-
 
 
 
 
 
 
 
 
 
 
 
46
  # Default metadata
47
  metadata = EnvironmentMetadata(
48
- name=env_name or env.__class__.__name__,
49
- description=f"{env.__class__.__name__} environment",
50
- version="1.0.0"
51
  )
52
-
53
  # Try to load README from file system
54
  readme_content = _load_readme_from_filesystem(env_name)
55
  if readme_content:
56
  metadata.readme_content = readme_content
57
-
58
  return metadata
59
 
60
 
61
  def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]:
62
  """
63
  Load README content from the filesystem.
64
-
65
  Tries multiple locations:
66
  1. Container filesystem: /app/README.md
67
  2. Local development: src/envs/{env_name}/README.md
@@ -69,59 +172,73 @@ def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]:
69
  """
70
  import os
71
  from pathlib import Path
72
-
73
  # Try container filesystem first
74
  container_readme = Path("/app/README.md")
75
  if container_readme.exists():
76
  try:
77
- return container_readme.read_text(encoding='utf-8')
78
  except Exception:
79
  pass
80
-
81
  # Try environment variable path
82
  custom_path = os.environ.get("ENV_README_PATH")
83
  if custom_path and Path(custom_path).exists():
84
  try:
85
- return Path(custom_path).read_text(encoding='utf-8')
86
  except Exception:
87
  pass
88
-
89
  # Try local development path
90
  if env_name:
91
  local_readme = Path(f"src/envs/{env_name}/README.md")
92
  if local_readme.exists():
93
  try:
94
- return local_readme.read_text(encoding='utf-8')
95
  except Exception:
96
  pass
97
-
98
  return None
99
 
100
 
101
- @dataclass
102
- class ActionLog:
103
  """Log entry for an action taken."""
104
- timestamp: str
105
- action: Dict[str, Any]
106
- observation: Dict[str, Any]
107
- reward: Optional[float]
108
- done: bool
109
- step_count: int
 
 
 
 
 
110
 
111
 
112
- @dataclass
113
- class EpisodeState:
114
  """Current episode state for the web interface."""
115
- episode_id: Optional[str]
116
- step_count: int
117
- current_observation: Optional[Dict[str, Any]]
118
- action_logs: List[ActionLog]
119
- is_reset: bool = True
 
 
 
 
 
 
 
 
 
120
 
121
 
122
  class WebInterfaceManager:
123
  """Manages the web interface for an environment."""
124
-
 
 
125
  def __init__(
126
  self,
127
  env: Environment,
@@ -129,152 +246,146 @@ class WebInterfaceManager:
129
  observation_cls: Type[Observation],
130
  metadata: Optional[EnvironmentMetadata] = None,
131
  ):
132
- self.env = env
 
 
 
 
 
 
133
  self.action_cls = action_cls
134
  self.observation_cls = observation_cls
135
  self.metadata = metadata or EnvironmentMetadata(
136
  name=env.__class__.__name__,
137
- description=f"{env.__class__.__name__} environment"
138
  )
139
  self.episode_state = EpisodeState(
140
  episode_id=None,
141
  step_count=0,
142
  current_observation=None,
143
- action_logs=[]
144
  )
145
  self.connected_clients: List[WebSocket] = []
146
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  async def connect_websocket(self, websocket: WebSocket):
148
  """Connect a new WebSocket client."""
149
  await websocket.accept()
150
  self.connected_clients.append(websocket)
151
-
152
  # Send current state to the new client
153
  await self._send_state_update()
154
-
155
  async def disconnect_websocket(self, websocket: WebSocket):
156
  """Disconnect a WebSocket client."""
157
  if websocket in self.connected_clients:
158
  self.connected_clients.remove(websocket)
159
-
160
  async def _send_state_update(self):
161
  """Send current state to all connected clients."""
162
  if not self.connected_clients:
163
  return
164
-
165
  state_data = {
166
  "type": "state_update",
167
- "episode_state": asdict(self.episode_state)
168
  }
169
-
170
  # Send to all connected clients
171
  disconnected_clients = []
172
  for client in self.connected_clients:
173
  try:
174
  await client.send_text(json.dumps(state_data))
175
- except:
176
  disconnected_clients.append(client)
177
-
178
  # Remove disconnected clients
179
  for client in disconnected_clients:
180
  self.connected_clients.remove(client)
181
-
182
  async def reset_environment(self) -> Dict[str, Any]:
183
  """Reset the environment and update state."""
184
- observation = self.env.reset()
185
- state = self.env.state
186
-
 
 
 
 
 
187
  # Update episode state
188
  self.episode_state.episode_id = state.episode_id
189
  self.episode_state.step_count = 0
190
- self.episode_state.current_observation = asdict(observation)
191
  self.episode_state.action_logs = []
192
  self.episode_state.is_reset = True
193
-
194
  # Send state update
195
  await self._send_state_update()
196
-
197
- return {
198
- "observation": asdict(observation),
199
- "reward": observation.reward,
200
- "done": observation.done,
201
- }
202
-
203
  async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]:
204
  """Execute a step in the environment and update state."""
205
- # Deserialize action
206
- action = self._deserialize_action(action_data)
207
-
208
- # Execute step
209
- observation = self.env.step(action)
210
- state = self.env.state
211
-
 
 
 
 
 
 
 
 
212
  # Create action log
213
  action_log = ActionLog(
214
  timestamp=datetime.now().isoformat(),
215
- action=asdict(action),
216
- observation=asdict(observation),
217
  reward=observation.reward,
218
  done=observation.done,
219
- step_count=state.step_count
220
  )
221
-
222
  # Update episode state
223
  self.episode_state.episode_id = state.episode_id
224
  self.episode_state.step_count = state.step_count
225
- self.episode_state.current_observation = asdict(observation)
226
  self.episode_state.action_logs.append(action_log)
 
 
 
 
227
  self.episode_state.is_reset = False
228
-
229
  # Send state update
230
  await self._send_state_update()
231
-
232
- return {
233
- "observation": asdict(observation),
234
- "reward": observation.reward,
235
- "done": observation.done,
236
- }
237
-
238
  def get_state(self) -> Dict[str, Any]:
239
  """Get current environment state."""
240
- state = self.env.state
241
- return asdict(state)
242
-
243
- def _deserialize_action(self, action_data: Dict[str, Any]) -> Action:
244
- """Convert JSON dict to Action instance."""
245
- metadata = action_data.pop("metadata", {})
246
-
247
- # Handle tensor fields that come from JSON as lists
248
- processed_data = {}
249
- for key, value in action_data.items():
250
- if key == "tokens" and isinstance(value, (list, str)):
251
- # Convert list or string to tensor
252
- if isinstance(value, str):
253
- # If it's a string, try to parse it as a list of numbers
254
- try:
255
- import json
256
- value = json.loads(value)
257
- except:
258
- # If parsing fails, treat as empty list
259
- value = []
260
- if isinstance(value, list):
261
- import torch
262
- processed_data[key] = torch.tensor(value, dtype=torch.long)
263
- else:
264
- processed_data[key] = value
265
- elif key == "action_id" and isinstance(value, str):
266
- # Convert action_id from string to int
267
- try:
268
- processed_data[key] = int(value)
269
- except ValueError:
270
- # If conversion fails, keep original value
271
- processed_data[key] = value
272
- else:
273
- processed_data[key] = value
274
-
275
- action = self.action_cls(**processed_data)
276
- action.metadata = metadata
277
- return action
278
 
279
 
280
  def create_web_interface_app(
@@ -282,44 +393,53 @@ def create_web_interface_app(
282
  action_cls: Type[Action],
283
  observation_cls: Type[Observation],
284
  env_name: Optional[str] = None,
 
 
 
285
  ) -> FastAPI:
286
  """
287
  Create a FastAPI application with web interface for the given environment.
288
-
289
  Args:
290
  env: The Environment instance to serve
291
  action_cls: The Action subclass this environment expects
292
  observation_cls: The Observation subclass this environment returns
293
  env_name: Optional environment name for README loading
294
-
 
 
 
 
 
295
  Returns:
296
  FastAPI application instance with web interface
297
  """
298
  from .http_server import create_fastapi_app
299
-
300
  # Create the base environment app
301
- app = create_fastapi_app(env, action_cls, observation_cls)
302
-
 
 
303
  # Load environment metadata
304
  metadata = load_environment_metadata(env, env_name)
305
-
306
  # Create web interface manager
307
  web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)
308
-
309
- # Add web interface routes
310
- @app.get("/web", response_class=HTMLResponse)
311
- async def web_interface():
312
- """Serve the web interface."""
313
- return get_web_interface_html(action_cls, web_manager.metadata)
314
-
315
  @app.get("/web/metadata")
316
  async def web_metadata():
317
  """Get environment metadata."""
318
- return asdict(web_manager.metadata)
319
-
320
- @app.websocket("/ws")
321
- async def websocket_endpoint(websocket: WebSocket):
322
- """WebSocket endpoint for real-time updates."""
 
 
 
 
323
  await web_manager.connect_websocket(websocket)
324
  try:
325
  while True:
@@ -327,1287 +447,198 @@ def create_web_interface_app(
327
  await websocket.receive_text()
328
  except WebSocketDisconnect:
329
  await web_manager.disconnect_websocket(websocket)
330
-
331
  @app.post("/web/reset")
332
  async def web_reset():
333
  """Reset endpoint for web interface."""
334
  return await web_manager.reset_environment()
335
-
336
  @app.post("/web/step")
337
  async def web_step(request: Dict[str, Any]):
338
  """Step endpoint for web interface."""
339
  # Check if this is a message-based request (chat environment)
340
  if "message" in request:
341
  message = request["message"]
342
- # Convert message to action using the environment's message_to_action method
343
- action = web_manager.env.message_to_action(message)
344
- action_data = {"tokens": action.tokens.tolist()}
 
 
 
 
 
345
  else:
346
  action_data = request.get("action", {})
347
-
348
  return await web_manager.step_environment(action_data)
349
-
350
  @app.get("/web/state")
351
  async def web_state():
352
  """State endpoint for web interface."""
353
  return web_manager.get_state()
354
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
  return app
356
 
357
 
358
- def get_web_interface_html(action_cls: Type[Action], metadata: Optional[EnvironmentMetadata] = None) -> str:
359
- """Generate the HTML for the web interface."""
360
-
361
- # Check if this is a chat environment by looking for tokens field
362
- is_chat_env = False
363
- if hasattr(action_cls, '__dataclass_fields__'):
364
- for field_name, field_info in action_cls.__dataclass_fields__.items():
365
- if field_name == 'tokens' and hasattr(field_info.type, '__name__') and 'Tensor' in field_info.type.__name__:
366
- is_chat_env = True
367
- break
368
-
369
- # Get action fields for dynamic form generation with enhanced metadata
370
- action_fields = _extract_action_fields(action_cls)
371
-
372
- return f"""
373
- <!DOCTYPE html>
374
- <html lang="en">
375
- <head>
376
- <meta charset="UTF-8">
377
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
378
- <title>OpenEnv Web Interface</title>
379
- <style>
380
- * {{
381
- margin: 0;
382
- padding: 0;
383
- box-sizing: border-box;
384
- }}
385
-
386
- body {{
387
- font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
388
- background-color: #f5f5f5;
389
- height: 100vh;
390
- overflow: hidden;
391
- }}
392
-
393
- .container {{
394
- display: flex;
395
- height: 100vh;
396
- }}
397
-
398
- .left-pane {{
399
- width: 50%;
400
- background: white;
401
- border-right: 1px solid #e0e0e0;
402
- display: flex;
403
- flex-direction: column;
404
- }}
405
-
406
- .right-pane {{
407
- width: 50%;
408
- background: #fafafa;
409
- display: flex;
410
- flex-direction: column;
411
- }}
412
-
413
- .pane-header {{
414
- padding: 20px;
415
- border-bottom: 1px solid #e0e0e0;
416
- background: #f8f9fa;
417
- font-weight: 600;
418
- font-size: 16px;
419
- }}
420
-
421
- .pane-content {{
422
- flex: 1;
423
- padding: 20px;
424
- overflow-y: auto;
425
- }}
426
-
427
- .action-form {{
428
- background: white;
429
- border: 1px solid #e0e0e0;
430
- border-radius: 8px;
431
- padding: 20px;
432
- margin-bottom: 20px;
433
- }}
434
-
435
- .form-group {{
436
- margin-bottom: 15px;
437
- }}
438
-
439
- .form-group label {{
440
- display: block;
441
- margin-bottom: 5px;
442
- font-weight: 500;
443
- color: #333;
444
- }}
445
-
446
- .form-group input, .form-group textarea {{
447
- width: 100%;
448
- padding: 8px 12px;
449
- border: 1px solid #ddd;
450
- border-radius: 4px;
451
- font-size: 14px;
452
- }}
453
-
454
- .form-group input:focus, .form-group textarea:focus {{
455
- outline: none;
456
- border-color: #007bff;
457
- box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
458
- }}
459
-
460
- .btn {{
461
- background: #007bff;
462
- color: white;
463
- border: none;
464
- padding: 10px 20px;
465
- border-radius: 4px;
466
- cursor: pointer;
467
- font-size: 14px;
468
- margin-right: 10px;
469
- margin-bottom: 10px;
470
- }}
471
-
472
- .btn:hover {{
473
- background: #0056b3;
474
- }}
475
-
476
- .btn:disabled {{
477
- background: #6c757d;
478
- cursor: not-allowed;
479
- }}
480
-
481
- .btn-secondary {{
482
- background: #6c757d;
483
- }}
484
-
485
- .btn-secondary:hover {{
486
- background: #545b62;
487
- }}
488
-
489
- .state-display {{
490
- background: white;
491
- border: 1px solid #e0e0e0;
492
- border-radius: 8px;
493
- padding: 15px;
494
- margin-bottom: 20px;
495
- }}
496
-
497
- .state-item {{
498
- margin-bottom: 8px;
499
- }}
500
-
501
- .state-label {{
502
- font-weight: 500;
503
- color: #666;
504
- }}
505
-
506
- .state-value {{
507
- color: #333;
508
- font-family: monospace;
509
- }}
510
-
511
- .logs-container {{
512
- background: white;
513
- border: 1px solid #e0e0e0;
514
- border-radius: 8px;
515
- padding: 15px;
516
- max-height: 400px;
517
- overflow-y: auto;
518
- }}
519
-
520
- .log-entry {{
521
- border-bottom: 1px solid #f0f0f0;
522
- padding: 10px 0;
523
- }}
524
-
525
- .log-entry:last-child {{
526
- border-bottom: none;
527
- }}
528
-
529
- .log-timestamp {{
530
- font-size: 12px;
531
- color: #666;
532
- margin-bottom: 5px;
533
- }}
534
-
535
- .log-action {{
536
- background: #e3f2fd;
537
- padding: 8px;
538
- border-radius: 4px;
539
- margin-bottom: 5px;
540
- font-family: monospace;
541
- font-size: 12px;
542
- }}
543
-
544
- .log-observation {{
545
- background: #f3e5f5;
546
- padding: 8px;
547
- border-radius: 4px;
548
- font-family: monospace;
549
- font-size: 12px;
550
- }}
551
-
552
- .log-reward {{
553
- font-weight: 600;
554
- color: #28a745;
555
- }}
556
-
557
- .log-done {{
558
- font-weight: 600;
559
- color: #dc3545;
560
- }}
561
-
562
- .status-indicator {{
563
- display: inline-block;
564
- width: 8px;
565
- height: 8px;
566
- border-radius: 50%;
567
- margin-right: 8px;
568
- }}
569
-
570
- .status-connected {{
571
- background: #28a745;
572
- }}
573
-
574
- .status-disconnected {{
575
- background: #dc3545;
576
- }}
577
-
578
- .json-display {{
579
- background: #f8f9fa;
580
- border: 1px solid #e9ecef;
581
- border-radius: 4px;
582
- padding: 10px;
583
- font-family: monospace;
584
- font-size: 12px;
585
- white-space: pre-wrap;
586
- max-height: 200px;
587
- overflow-y: auto;
588
- }}
589
-
590
- /* Chat Interface Styles */
591
- .chat-interface {{
592
- background: white;
593
- border: 1px solid #e0e0e0;
594
- border-radius: 8px;
595
- padding: 20px;
596
- margin-bottom: 20px;
597
- }}
598
-
599
- .chat-messages {{
600
- background: #f8f9fa;
601
- border: 1px solid #e0e0e0;
602
- border-radius: 8px;
603
- padding: 15px;
604
- margin-bottom: 15px;
605
- max-height: 400px;
606
- overflow-y: auto;
607
- }}
608
-
609
- .chat-message {{
610
- margin-bottom: 15px;
611
- padding: 10px;
612
- border-radius: 8px;
613
- }}
614
-
615
- .chat-message:last-child {{
616
- margin-bottom: 0;
617
- }}
618
-
619
- .chat-message.user {{
620
- background: #e3f2fd;
621
- margin-left: 20px;
622
- }}
623
-
624
- .chat-message.assistant {{
625
- background: #f3e5f5;
626
- margin-right: 20px;
627
- }}
628
-
629
- .chat-message.system {{
630
- background: #e8f5e8;
631
- font-style: italic;
632
- }}
633
-
634
- .message-role {{
635
- font-weight: 600;
636
- font-size: 12px;
637
- color: #666;
638
- margin-bottom: 5px;
639
- }}
640
-
641
- .message-content {{
642
- font-size: 14px;
643
- line-height: 1.4;
644
- }}
645
-
646
- .chat-input-container {{
647
- border-top: 1px solid #e0e0e0;
648
- padding-top: 15px;
649
- }}
650
-
651
- .role-selector {{
652
- margin-bottom: 10px;
653
- }}
654
-
655
- .role-selector label {{
656
- font-weight: 500;
657
- margin-right: 10px;
658
- }}
659
-
660
- .role-selector select {{
661
- padding: 5px 10px;
662
- border: 1px solid #ddd;
663
- border-radius: 4px;
664
- }}
665
-
666
- .message-input {{
667
- display: flex;
668
- gap: 10px;
669
- align-items: flex-end;
670
- }}
671
-
672
- .message-input textarea {{
673
- flex: 1;
674
- padding: 10px;
675
- border: 1px solid #ddd;
676
- border-radius: 4px;
677
- resize: vertical;
678
- font-family: inherit;
679
- }}
680
-
681
- .message-input textarea:focus {{
682
- outline: none;
683
- border-color: #007bff;
684
- box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
685
- }}
686
-
687
- /* Instructions Section Styles */
688
- .instructions-section {{
689
- background: white;
690
- border: 1px solid #e0e0e0;
691
- border-radius: 8px;
692
- padding: 20px;
693
- margin-bottom: 20px;
694
- }}
695
-
696
- .instructions-header {{
697
- display: flex;
698
- justify-content: space-between;
699
- align-items: center;
700
- margin-bottom: 15px;
701
- }}
702
-
703
- .instructions-title {{
704
- font-size: 18px;
705
- font-weight: 600;
706
- color: #333;
707
- margin: 0;
708
- }}
709
-
710
- .instructions-toggle {{
711
- background: #f8f9fa;
712
- border: 1px solid #dee2e6;
713
- border-radius: 4px;
714
- padding: 5px 10px;
715
- cursor: pointer;
716
- font-size: 12px;
717
- color: #6c757d;
718
- }}
719
-
720
- .instructions-toggle:hover {{
721
- background: #e9ecef;
722
- }}
723
-
724
- .instructions-content {{
725
- display: none;
726
- max-height: 400px;
727
- overflow-y: auto;
728
- border-top: 1px solid #e0e0e0;
729
- padding-top: 15px;
730
- }}
731
-
732
- .instructions-content.expanded {{
733
- display: block;
734
- }}
735
-
736
- .instructions-content h1,
737
- .instructions-content h2,
738
- .instructions-content h3 {{
739
- color: #333;
740
- margin-top: 20px;
741
- margin-bottom: 10px;
742
- }}
743
-
744
- .instructions-content h1 {{
745
- font-size: 24px;
746
- border-bottom: 2px solid #007bff;
747
- padding-bottom: 10px;
748
- }}
749
-
750
- .instructions-content h2 {{
751
- font-size: 20px;
752
- }}
753
-
754
- .instructions-content h3 {{
755
- font-size: 16px;
756
- }}
757
-
758
- .instructions-content p {{
759
- margin-bottom: 10px;
760
- line-height: 1.6;
761
- }}
762
-
763
- .instructions-content code {{
764
- background: #f8f9fa;
765
- padding: 2px 4px;
766
- border-radius: 3px;
767
- font-family: monospace;
768
- font-size: 14px;
769
- }}
770
-
771
- .instructions-content pre {{
772
- background: #f8f9fa;
773
- border: 1px solid #e9ecef;
774
- border-radius: 4px;
775
- padding: 15px;
776
- overflow-x: auto;
777
- margin: 10px 0;
778
- }}
779
-
780
- .instructions-content pre code {{
781
- background: none;
782
- padding: 0;
783
- }}
784
-
785
- .instructions-content ul,
786
- .instructions-content ol {{
787
- margin: 10px 0;
788
- padding-left: 20px;
789
- }}
790
-
791
- .instructions-content li {{
792
- margin-bottom: 5px;
793
- }}
794
-
795
- .instructions-content table {{
796
- border-collapse: collapse;
797
- width: 100%;
798
- margin: 15px 0;
799
- }}
800
-
801
- .instructions-content th,
802
- .instructions-content td {{
803
- border: 1px solid #dee2e6;
804
- padding: 8px 12px;
805
- text-align: left;
806
- }}
807
-
808
- .instructions-content th {{
809
- background: #f8f9fa;
810
- font-weight: 600;
811
- }}
812
-
813
- /* Enhanced Form Styles */
814
- .help-text {{
815
- display: block;
816
- margin-top: 5px;
817
- font-size: 12px;
818
- color: #6c757d;
819
- font-style: italic;
820
- }}
821
-
822
- .form-group label {{
823
- font-weight: 500;
824
- color: #333;
825
- margin-bottom: 5px;
826
- }}
827
-
828
- .form-group select {{
829
- width: 100%;
830
- padding: 8px 12px;
831
- border: 1px solid #ddd;
832
- border-radius: 4px;
833
- font-size: 14px;
834
- background-color: white;
835
- }}
836
-
837
- .form-group select:focus {{
838
- outline: none;
839
- border-color: #007bff;
840
- box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
841
- }}
842
-
843
- .form-group textarea {{
844
- width: 100%;
845
- padding: 8px 12px;
846
- border: 1px solid #ddd;
847
- border-radius: 4px;
848
- font-size: 14px;
849
- font-family: inherit;
850
- resize: vertical;
851
- }}
852
-
853
- .form-group textarea:focus {{
854
- outline: none;
855
- border-color: #007bff;
856
- box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
857
- }}
858
-
859
- .form-group input[type="number"] {{
860
- width: 100%;
861
- padding: 8px 12px;
862
- border: 1px solid #ddd;
863
- border-radius: 4px;
864
- font-size: 14px;
865
- }}
866
-
867
- .form-group input[type="number"]:focus {{
868
- outline: none;
869
- border-color: #007bff;
870
- box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
871
- }}
872
-
873
- .form-group input[type="text"]:focus {{
874
- outline: none;
875
- border-color: #007bff;
876
- box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
877
- }}
878
-
879
- .required-indicator {{
880
- color: #dc3545;
881
- font-weight: bold;
882
- }}
883
-
884
- .form-group .field-description {{
885
- font-size: 11px;
886
- color: #666;
887
- margin-top: 2px;
888
- font-style: italic;
889
- }}
890
- </style>
891
- </head>
892
- <body>
893
- <div class="container">
894
- <!-- Left Pane: HumanAgent Interface -->
895
- <div class="left-pane">
896
- <div class="pane-header">
897
- <span class="status-indicator status-disconnected" id="connection-status"></span>
898
- HumanAgent Interface
899
- </div>
900
- <div class="pane-content">
901
- <!-- Instructions Section -->
902
- {_generate_instructions_section(metadata)}
903
-
904
- <!-- Action Form or Chat Interface -->
905
- {_generate_action_interface(action_fields, is_chat_env)}
906
-
907
- <!-- Control Buttons -->
908
- <div style="margin-bottom: 20px;">
909
- <button class="btn btn-secondary" id="reset-btn">Reset Environment</button>
910
- <button class="btn btn-secondary" id="state-btn">Get State</button>
911
- </div>
912
-
913
- <!-- Current State Display -->
914
- <div class="state-display">
915
- <h3>Current State</h3>
916
- <div id="current-state">
917
- <div class="state-item">
918
- <span class="state-label">Status:</span>
919
- <span class="state-value" id="env-status">Not initialized</span>
920
- </div>
921
- <div class="state-item">
922
- <span class="state-label">Episode ID:</span>
923
- <span class="state-value" id="episode-id">-</span>
924
- </div>
925
- <div class="state-item">
926
- <span class="state-label">Step Count:</span>
927
- <span class="state-value" id="step-count">0</span>
928
- </div>
929
- </div>
930
- </div>
931
- </div>
932
- </div>
933
-
934
- <!-- Right Pane: State Observer -->
935
- <div class="right-pane">
936
- <div class="pane-header">
937
- State Observer
938
- </div>
939
- <div class="pane-content">
940
- <!-- Current Observation -->
941
- <div class="state-display">
942
- <h3>Current Observation</h3>
943
- <div id="current-observation" class="json-display">
944
- No observation yet
945
- </div>
946
- </div>
947
-
948
- <!-- Action Logs -->
949
- <div class="logs-container">
950
- <h3>Action History</h3>
951
- <div id="action-logs">
952
- No actions taken yet
953
- </div>
954
- </div>
955
- </div>
956
- </div>
957
- </div>
958
-
959
- <script>
960
- class OpenEnvWebInterface {{
961
- constructor() {{
962
- this.ws = null;
963
- this.isConnected = false;
964
- this.init();
965
- }}
966
-
967
- init() {{
968
- this.connectWebSocket();
969
- this.setupEventListeners();
970
- }}
971
-
972
- connectWebSocket() {{
973
- const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
974
- const wsUrl = `${{protocol}}//${{window.location.host}}/ws`;
975
-
976
- this.ws = new WebSocket(wsUrl);
977
-
978
- this.ws.onopen = () => {{
979
- this.isConnected = true;
980
- this.updateConnectionStatus(true);
981
- console.log('WebSocket connected');
982
- }};
983
-
984
- this.ws.onmessage = (event) => {{
985
- const data = JSON.parse(event.data);
986
- if (data.type === 'state_update') {{
987
- this.updateUI(data.episode_state);
988
- }}
989
- }};
990
-
991
- this.ws.onclose = () => {{
992
- this.isConnected = false;
993
- this.updateConnectionStatus(false);
994
- console.log('WebSocket disconnected');
995
- // Attempt to reconnect after 3 seconds
996
- setTimeout(() => this.connectWebSocket(), 3000);
997
- }};
998
-
999
- this.ws.onerror = (error) => {{
1000
- console.error('WebSocket error:', error);
1001
- }};
1002
- }}
1003
-
1004
- setupEventListeners() {{
1005
- // Instructions toggle
1006
- const instructionsToggle = document.getElementById('instructions-toggle');
1007
- const instructionsContent = document.getElementById('instructions-content');
1008
- if (instructionsToggle && instructionsContent) {{
1009
- instructionsToggle.addEventListener('click', () => {{
1010
- instructionsContent.classList.toggle('expanded');
1011
- instructionsToggle.textContent = instructionsContent.classList.contains('expanded')
1012
- ? 'Hide Instructions' : 'Show Instructions';
1013
- }});
1014
- }}
1015
-
1016
- // Check if this is a chat environment
1017
- const isChatEnv = document.getElementById('chat-messages') !== null;
1018
-
1019
- if (isChatEnv) {{
1020
- // Chat environment event listeners
1021
- document.getElementById('send-message-btn').addEventListener('click', () => {{
1022
- this.sendMessage();
1023
- }});
1024
-
1025
- // Send message on Enter (but allow Shift+Enter for new lines)
1026
- document.getElementById('message-input').addEventListener('keydown', (e) => {{
1027
- if (e.key === 'Enter' && !e.shiftKey) {{
1028
- e.preventDefault();
1029
- this.sendMessage();
1030
- }}
1031
- }});
1032
- }} else {{
1033
- // Traditional action form submission
1034
- const actionForm = document.getElementById('action-form');
1035
- if (actionForm) {{
1036
- actionForm.addEventListener('submit', (e) => {{
1037
- e.preventDefault();
1038
- this.submitAction();
1039
- }});
1040
- }}
1041
- }}
1042
-
1043
- // Reset button
1044
- document.getElementById('reset-btn').addEventListener('click', () => {{
1045
- this.resetEnvironment();
1046
- }});
1047
-
1048
- // State button
1049
- document.getElementById('state-btn').addEventListener('click', () => {{
1050
- this.getState();
1051
- }});
1052
- }}
1053
-
1054
- async sendMessage() {{
1055
- const messageInput = document.getElementById('message-input');
1056
- const roleSelect = document.getElementById('message-role');
1057
- const message = messageInput.value.trim();
1058
- const role = roleSelect.value;
1059
-
1060
- if (!message) {{
1061
- return;
1062
- }}
1063
-
1064
- // Add message to chat display immediately
1065
- this.addMessageToChat(role, message);
1066
-
1067
- // Clear input
1068
- messageInput.value = '';
1069
-
1070
- try {{
1071
- // Send message to server to convert to action and step
1072
- const response = await fetch('/web/step', {{
1073
- method: 'POST',
1074
- headers: {{ 'Content-Type': 'application/json' }},
1075
- body: JSON.stringify({{
1076
- message: {{
1077
- role: role,
1078
- content: message
1079
- }}
1080
- }})
1081
- }});
1082
-
1083
- if (!response.ok) {{
1084
- throw new Error(`HTTP error! status: ${{response.status}}`);
1085
- }}
1086
-
1087
- const result = await response.json();
1088
- console.log('Message sent:', result);
1089
- }} catch (error) {{
1090
- console.error('Error sending message:', error);
1091
- alert('Error sending message: ' + error.message);
1092
- }}
1093
- }}
1094
-
1095
- addMessageToChat(role, content) {{
1096
- const chatMessages = document.getElementById('chat-messages');
1097
- const messageDiv = document.createElement('div');
1098
- messageDiv.className = `chat-message ${{role}}`;
1099
-
1100
- messageDiv.innerHTML = `
1101
- <div class="message-role">${{role.charAt(0).toUpperCase() + role.slice(1)}}</div>
1102
- <div class="message-content">${{content}}</div>
1103
- `;
1104
-
1105
- chatMessages.appendChild(messageDiv);
1106
- chatMessages.scrollTop = chatMessages.scrollHeight;
1107
- }}
1108
-
1109
- async submitAction() {{
1110
- const formData = new FormData(document.getElementById('action-form'));
1111
- const action = {{}};
1112
-
1113
- // Collect form data
1114
- for (const [key, value] of formData.entries()) {{
1115
- if (value !== '') {{
1116
- // Handle tensor fields (tokens) - convert comma-separated string to array
1117
- if (key === 'tokens') {{
1118
- try {{
1119
- action[key] = value.split(',').map(x => parseInt(x.trim())).filter(x => !isNaN(x));
1120
- }} catch (e) {{
1121
- console.error('Error parsing tokens:', e);
1122
- action[key] = [];
1123
- }}
1124
- }} else {{
1125
- action[key] = value;
1126
- }}
1127
- }}
1128
- }}
1129
-
1130
- try {{
1131
- const response = await fetch('/web/step', {{
1132
- method: 'POST',
1133
- headers: {{ 'Content-Type': 'application/json' }},
1134
- body: JSON.stringify({{ action }})
1135
- }});
1136
-
1137
- if (!response.ok) {{
1138
- throw new Error(`HTTP error! status: ${{response.status}}`);
1139
- }}
1140
-
1141
- const result = await response.json();
1142
- console.log('Step result:', result);
1143
- }} catch (error) {{
1144
- console.error('Error submitting action:', error);
1145
- alert('Error submitting action: ' + error.message);
1146
- }}
1147
- }}
1148
-
1149
- async resetEnvironment() {{
1150
- try {{
1151
- const response = await fetch('/web/reset', {{
1152
- method: 'POST',
1153
- headers: {{ 'Content-Type': 'application/json' }}
1154
- }});
1155
-
1156
- if (!response.ok) {{
1157
- throw new Error(`HTTP error! status: ${{response.status}}`);
1158
- }}
1159
-
1160
- const result = await response.json();
1161
- console.log('Reset result:', result);
1162
- }} catch (error) {{
1163
- console.error('Error resetting environment:', error);
1164
- alert('Error resetting environment: ' + error.message);
1165
- }}
1166
- }}
1167
-
1168
- async getState() {{
1169
- try {{
1170
- const response = await fetch('/web/state');
1171
- const state = await response.json();
1172
- console.log('Current state:', state);
1173
- alert('Current state: ' + JSON.stringify(state, null, 2));
1174
- }} catch (error) {{
1175
- console.error('Error getting state:', error);
1176
- alert('Error getting state: ' + error.message);
1177
- }}
1178
- }}
1179
-
1180
- updateConnectionStatus(connected) {{
1181
- const indicator = document.getElementById('connection-status');
1182
- if (connected) {{
1183
- indicator.className = 'status-indicator status-connected';
1184
- }} else {{
1185
- indicator.className = 'status-indicator status-disconnected';
1186
- }}
1187
- }}
1188
-
1189
- updateUI(episodeState) {{
1190
- // Check if this is a chat environment
1191
- const isChatEnv = document.getElementById('chat-messages') !== null;
1192
-
1193
- // Update current state
1194
- document.getElementById('env-status').textContent =
1195
- episodeState.is_reset ? 'Reset' : 'Running';
1196
- document.getElementById('episode-id').textContent =
1197
- episodeState.episode_id || '-';
1198
- document.getElementById('step-count').textContent =
1199
- episodeState.step_count.toString();
1200
-
1201
- if (isChatEnv) {{
1202
- // Update chat interface
1203
- this.updateChatInterface(episodeState);
1204
- }} else {{
1205
- // Update traditional observation display
1206
- const observationDiv = document.getElementById('current-observation');
1207
- if (episodeState.current_observation) {{
1208
- observationDiv.textContent = JSON.stringify(
1209
- episodeState.current_observation, null, 2
1210
- );
1211
- }} else {{
1212
- observationDiv.textContent = 'No observation yet';
1213
- }}
1214
- }}
1215
-
1216
- // Update action logs
1217
- const logsDiv = document.getElementById('action-logs');
1218
- if (episodeState.action_logs.length === 0) {{
1219
- logsDiv.innerHTML = 'No actions taken yet';
1220
- }} else {{
1221
- logsDiv.innerHTML = episodeState.action_logs.map(log => `
1222
- <div class="log-entry">
1223
- <div class="log-timestamp">${{log.timestamp}} (Step ${{log.step_count}})</div>
1224
- <div class="log-action">Action: ${{JSON.stringify(log.action, null, 2)}}</div>
1225
- <div class="log-observation">Observation: ${{JSON.stringify(log.observation, null, 2)}}</div>
1226
- <div>
1227
- <span class="log-reward">Reward: ${{log.reward !== null ? log.reward : 'None'}}</span>
1228
- ${{log.done ? '<span class="log-done">DONE</span>' : ''}}
1229
- </div>
1230
- </div>
1231
- `).join('');
1232
- }}
1233
- }}
1234
-
1235
- updateChatInterface(episodeState) {{
1236
- const chatMessages = document.getElementById('chat-messages');
1237
- if (!chatMessages) return;
1238
-
1239
- // Clear existing messages (except system message)
1240
- const systemMessage = chatMessages.querySelector('.chat-message.system');
1241
- chatMessages.innerHTML = '';
1242
- if (systemMessage) {{
1243
- chatMessages.appendChild(systemMessage);
1244
- }}
1245
-
1246
- // Add messages from current observation
1247
- if (episodeState.current_observation && episodeState.current_observation.messages) {{
1248
- episodeState.current_observation.messages.forEach(msg => {{
1249
- this.addMessageToChat(msg.role, msg.content);
1250
- }});
1251
- }}
1252
- }}
1253
- }}
1254
-
1255
- // Initialize the web interface when the page loads
1256
- document.addEventListener('DOMContentLoaded', () => {{
1257
- new OpenEnvWebInterface();
1258
- }});
1259
- </script>
1260
- </body>
1261
- </html>
1262
- """.replace('{_generate_action_form_fields(action_fields)}', _generate_action_form_fields(action_fields))
1263
-
1264
-
1265
- def _generate_instructions_section(metadata: Optional[EnvironmentMetadata]) -> str:
1266
- """Generate the instructions section with environment documentation."""
1267
- if not metadata or not metadata.readme_content:
1268
- return ''
1269
-
1270
- # Convert markdown to HTML (basic conversion)
1271
- import re
1272
- html_content = _markdown_to_html(metadata.readme_content)
1273
-
1274
- return f'''
1275
- <!-- Instructions Section -->
1276
- <div class="instructions-section">
1277
- <div class="instructions-header">
1278
- <h3 class="instructions-title">{metadata.name}</h3>
1279
- <button class="instructions-toggle" id="instructions-toggle">Show Instructions</button>
1280
- </div>
1281
- <div class="instructions-content" id="instructions-content">
1282
- <div class="instructions-readme">
1283
- {html_content}
1284
- </div>
1285
- </div>
1286
- </div>
1287
- '''
1288
 
1289
 
1290
  def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]:
1291
  """Extract enhanced field metadata from Action class for form generation."""
1292
- import typing
1293
- from typing import get_origin, get_args
1294
-
 
 
 
 
 
 
 
1295
  action_fields = []
1296
- if not hasattr(action_cls, '__dataclass_fields__'):
1297
- return action_fields
1298
-
1299
- for field_name, field_info in action_cls.__dataclass_fields__.items():
1300
- if field_name == 'metadata':
1301
  continue
1302
-
1303
- field_type = field_info.type
1304
- field_metadata = _extract_field_metadata(field_name, field_info)
1305
-
1306
- # Determine input type based on field type
1307
- input_type = _determine_input_type(field_type)
1308
-
1309
- # Check if field is required
1310
- is_required = field_info.default is field_info.default_factory
1311
-
1312
- action_fields.append({
1313
- 'name': field_name,
1314
- 'type': input_type,
1315
- 'required': is_required,
1316
- 'description': field_metadata.get('description', ''),
1317
- 'default_value': field_metadata.get('default_value'),
1318
- 'choices': field_metadata.get('choices', []),
1319
- 'min_value': field_metadata.get('min_value'),
1320
- 'max_value': field_metadata.get('max_value'),
1321
- 'placeholder': field_metadata.get('placeholder', ''),
1322
- 'help_text': field_metadata.get('help_text', ''),
1323
- })
1324
-
 
 
1325
  return action_fields
1326
 
1327
 
1328
- def _extract_field_metadata(field_name: str, field_info) -> Dict[str, Any]:
1329
- """Extract metadata from dataclass field including docstring and type hints."""
1330
- import typing
1331
- from typing import get_origin, get_args, Literal, Union, Optional
1332
-
1333
- metadata = {}
1334
-
1335
- # Extract description from field docstring or annotation
1336
- if hasattr(field_info, 'metadata') and field_info.metadata:
1337
- # Check for custom metadata
1338
- for meta in field_info.metadata:
1339
- if isinstance(meta, dict):
1340
- metadata.update(meta)
1341
-
1342
- # Extract type information
1343
- field_type = field_info.type
1344
- origin = get_origin(field_type)
1345
-
1346
- # Handle Literal types for dropdown choices
1347
- if origin is Literal:
1348
- args = get_args(field_type)
1349
- metadata['choices'] = list(args)
1350
-
1351
- # Handle Optional types
1352
- if origin is Union:
1353
- args = get_args(field_type)
1354
- if len(args) == 2 and type(None) in args:
1355
- # This is Optional[SomeType]
1356
- non_none_type = args[0] if args[1] is type(None) else args[1]
1357
- metadata['optional'] = True
1358
- # Recursively check the non-None type for choices
1359
- if get_origin(non_none_type) is Literal:
1360
- metadata['choices'] = list(get_args(non_none_type))
1361
- else:
1362
- # Regular Union type
1363
- metadata['choices'] = [str(arg) for arg in args if arg is not type(None)]
1364
-
1365
- # Handle numeric constraints
1366
- if field_type in (int, float):
1367
- # Check for common constraint patterns in field name
1368
- if 'count' in field_name.lower() or 'num' in field_name.lower():
1369
- metadata['min_value'] = 0
1370
- if 'id' in field_name.lower():
1371
- metadata['min_value'] = 0
1372
-
1373
- # Generate placeholder text
1374
- if 'message' in field_name.lower():
1375
- metadata['placeholder'] = f'Enter {field_name.replace("_", " ")}...'
1376
- elif 'code' in field_name.lower():
1377
- metadata['placeholder'] = 'Enter Python code here...'
1378
- elif 'tokens' in field_name.lower():
1379
- metadata['placeholder'] = 'Enter comma-separated token IDs (e.g., 1,2,3,4,5)'
1380
- else:
1381
- metadata['placeholder'] = f'Enter {field_name.replace("_", " ")}...'
1382
-
1383
- # Generate help text based on field name and type
1384
- if 'action_id' in field_name.lower():
1385
- metadata['help_text'] = 'The action ID to execute in the environment'
1386
- elif 'game_name' in field_name.lower():
1387
- metadata['help_text'] = 'Name of the game or environment'
1388
- elif 'tokens' in field_name.lower():
1389
- metadata['help_text'] = 'Token IDs as a comma-separated list of integers'
1390
- elif 'code' in field_name.lower():
1391
- metadata['help_text'] = 'Python code to execute in the environment'
1392
- elif 'message' in field_name.lower():
1393
- metadata['help_text'] = 'Text message to send'
1394
-
1395
- return metadata
1396
 
 
 
 
1397
 
1398
- def _determine_input_type(field_type) -> str:
1399
- """Determine the appropriate HTML input type for a field type."""
1400
- import typing
1401
- from typing import get_origin, get_args, Literal, Union
1402
-
1403
- # Handle direct types
1404
- if field_type == str:
1405
- return "text"
1406
- elif field_type == int:
1407
- return "number"
1408
- elif field_type == float:
1409
- return "number"
1410
- elif field_type == bool:
1411
- return "checkbox"
1412
-
1413
- # Handle complex types
1414
- origin = get_origin(field_type)
1415
-
1416
- if origin is Literal:
1417
  return "select"
1418
- elif origin is Union:
1419
- args = get_args(field_type)
1420
- if len(args) == 2 and type(None) in args:
1421
- # Optional type - use the non-None type
1422
- non_none_type = args[0] if args[1] is type(None) else args[1]
1423
- return _determine_input_type(non_none_type)
1424
- elif all(isinstance(arg, str) for arg in args if arg is not type(None)):
1425
- return "select"
1426
- else:
1427
- return "text"
1428
- elif hasattr(field_type, '__name__') and 'Tensor' in field_type.__name__:
1429
- return "tensor"
1430
- else:
 
 
1431
  return "text"
1432
 
 
 
1433
 
1434
- def _markdown_to_html(markdown: str) -> str:
1435
- """Convert basic markdown to HTML for README display."""
1436
- import html
1437
- import re
1438
-
1439
- # Escape HTML first
1440
- html_content = html.escape(markdown)
1441
-
1442
- # Convert headers
1443
- html_content = re.sub(r'^# (.*?)$', r'<h1>\1</h1>', html_content, flags=re.MULTILINE)
1444
- html_content = re.sub(r'^## (.*?)$', r'<h2>\1</h2>', html_content, flags=re.MULTILINE)
1445
- html_content = re.sub(r'^### (.*?)$', r'<h3>\1</h3>', html_content, flags=re.MULTILINE)
1446
-
1447
- # Convert code blocks
1448
- html_content = re.sub(r'```(.*?)\n(.*?)\n```', r'<pre><code>\2</code></pre>', html_content, flags=re.DOTALL)
1449
- html_content = re.sub(r'`([^`]+)`', r'<code>\1</code>', html_content)
1450
-
1451
- # Convert bold and italic
1452
- html_content = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', html_content)
1453
- html_content = re.sub(r'\*(.*?)\*', r'<em>\1</em>', html_content)
1454
-
1455
- # Convert lists
1456
- html_content = re.sub(r'^- (.*?)$', r'<li>\1</li>', html_content, flags=re.MULTILINE)
1457
- html_content = re.sub(r'(<li>.*</li>)', r'<ul>\1</ul>', html_content, flags=re.DOTALL)
1458
-
1459
- # Convert line breaks
1460
- html_content = html_content.replace('\n', '<br>')
1461
-
1462
- return html_content
1463
-
1464
-
1465
- def _generate_action_interface(action_fields: List[Dict[str, Any]], is_chat_env: bool) -> str:
1466
- """Generate either a chat interface or action form based on environment type."""
1467
- if is_chat_env:
1468
- return _generate_chat_interface()
1469
- else:
1470
- return _generate_action_form(action_fields)
1471
-
1472
- def _generate_chat_interface() -> str:
1473
- """Generate a chat-style interface for chat environments."""
1474
- return '''
1475
- <!-- Chat Interface -->
1476
- <div class="chat-interface">
1477
- <h3>Chat Interface</h3>
1478
- <div class="chat-messages" id="chat-messages">
1479
- <div class="chat-message system">
1480
- <div class="message-role">System</div>
1481
- <div class="message-content">Chat environment ready. Send a message to start the conversation.</div>
1482
- </div>
1483
- </div>
1484
- <div class="chat-input-container">
1485
- <div class="role-selector">
1486
- <label for="message-role">Role:</label>
1487
- <select id="message-role">
1488
- <option value="user">User</option>
1489
- <option value="assistant">Assistant</option>
1490
- </select>
1491
- </div>
1492
- <div class="message-input">
1493
- <textarea id="message-input" placeholder="Type your message here..." rows="3"></textarea>
1494
- <button class="btn" id="send-message-btn">Send Message</button>
1495
- </div>
1496
- </div>
1497
- </div>
1498
- '''
1499
-
1500
- def _generate_action_form(action_fields: List[Dict[str, Any]]) -> str:
1501
- """Generate a traditional action form for non-chat environments."""
1502
- return f'''
1503
- <!-- Action Form -->
1504
- <div class="action-form">
1505
- <h3>Take Action</h3>
1506
- <form id="action-form">
1507
- {_generate_action_form_fields(action_fields)}
1508
- <button type="submit" class="btn" id="step-btn">Step</button>
1509
- </form>
1510
- </div>
1511
- '''
1512
-
1513
- def _generate_action_form_fields(action_fields: List[Dict[str, Any]]) -> str:
1514
- """Generate HTML form fields for action input with enhanced metadata."""
1515
- if not action_fields:
1516
- return '<p>No action fields available</p>'
1517
-
1518
- fields_html = []
1519
- for field in action_fields:
1520
- field_html = _generate_single_field(field)
1521
- fields_html.append(field_html)
1522
-
1523
- return '\n'.join(fields_html)
1524
-
1525
-
1526
- def _generate_single_field(field: Dict[str, Any]) -> str:
1527
- """Generate HTML for a single form field with enhanced metadata."""
1528
- field_name = field['name']
1529
- field_type = field['type']
1530
- required = field['required']
1531
- placeholder = field.get('placeholder', '')
1532
- help_text = field.get('help_text', '')
1533
- choices = field.get('choices', [])
1534
- min_value = field.get('min_value')
1535
- max_value = field.get('max_value')
1536
- default_value = field.get('default_value')
1537
-
1538
- # Build label with required indicator
1539
- label_text = field_name.replace('_', ' ').title()
1540
- if required:
1541
- label_text += ' <span style="color: red;">*</span>'
1542
-
1543
- # Build input attributes
1544
- input_attrs = []
1545
- if required:
1546
- input_attrs.append('required')
1547
- if placeholder:
1548
- input_attrs.append(f'placeholder="{placeholder}"')
1549
- if min_value is not None:
1550
- input_attrs.append(f'min="{min_value}"')
1551
- if max_value is not None:
1552
- input_attrs.append(f'max="{max_value}"')
1553
- if default_value is not None:
1554
- input_attrs.append(f'value="{default_value}"')
1555
-
1556
- attrs_str = ' '.join(input_attrs)
1557
-
1558
- if field_type == 'checkbox':
1559
- return f'''
1560
- <div class="form-group">
1561
- <label>
1562
- <input type="checkbox" name="{field_name}" value="true" {attrs_str}>
1563
- {label_text}
1564
- </label>
1565
- {f'<small class="help-text">{help_text}</small>' if help_text else ''}
1566
- </div>
1567
- '''
1568
-
1569
- elif field_type == 'select':
1570
- options_html = []
1571
- if not required:
1572
- options_html.append(f'<option value="">-- Select {label_text} --</option>')
1573
-
1574
- for choice in choices:
1575
- selected = 'selected' if str(choice) == str(default_value) else ''
1576
- options_html.append(f'<option value="{choice}" {selected}>{choice}</option>')
1577
-
1578
- return f'''
1579
- <div class="form-group">
1580
- <label for="{field_name}">{label_text}:</label>
1581
- <select name="{field_name}" id="{field_name}" {attrs_str}>
1582
- {''.join(options_html)}
1583
- </select>
1584
- {f'<small class="help-text">{help_text}</small>' if help_text else ''}
1585
- </div>
1586
- '''
1587
-
1588
- elif field_type == 'tensor':
1589
- return f'''
1590
- <div class="form-group">
1591
- <label for="{field_name}">{label_text} (comma-separated integers):</label>
1592
- <input type="text" name="{field_name}" id="{field_name}" {attrs_str}>
1593
- <small class="help-text">{help_text or 'Enter token IDs as comma-separated integers (e.g., 1,2,3,4,5)'}</small>
1594
- </div>
1595
- '''
1596
-
1597
- elif field_type == 'text' and ('message' in field_name.lower() or 'code' in field_name.lower()):
1598
- return f'''
1599
- <div class="form-group">
1600
- <label for="{field_name}">{label_text}:</label>
1601
- <textarea name="{field_name}" id="{field_name}" rows="3" {attrs_str}></textarea>
1602
- {f'<small class="help-text">{help_text}</small>' if help_text else ''}
1603
- </div>
1604
- '''
1605
-
1606
  else:
1607
- return f'''
1608
- <div class="form-group">
1609
- <label for="{field_name}">{label_text}:</label>
1610
- <input type="{field_type}" name="{field_name}" id="{field_name}" {attrs_str}>
1611
- {f'<small class="help-text">{help_text}</small>' if help_text else ''}
1612
- </div>
1613
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  """
8
  Web interface for OpenEnv environments.
9
 
10
+ When ENABLE_WEB_INTERFACE is set, the server exposes a Gradio UI at /web for
11
+ reset, step, and state observation. Controlled by the CLI enable_interface
12
+ option (e.g. openenv push --enable-interface) or ENABLE_WEB_INTERFACE env var.
13
  """
14
 
15
  from __future__ import annotations
16
 
17
+ import asyncio
18
  import json
19
+ from concurrent.futures import ThreadPoolExecutor
 
 
20
  from datetime import datetime
21
+ from typing import Any, Callable, Dict, List, Optional, Type
22
 
23
+ import gradio as gr
24
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
25
+ from pydantic import BaseModel, ConfigDict, Field
 
26
 
27
+ from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME
28
+ from .gradio_ui import build_gradio_app, get_gradio_display_title
29
  from .interfaces import Environment
30
+ from .serialization import deserialize_action_with_preprocessing, serialize_observation
31
+ from .types import Action, EnvironmentMetadata, Observation, State
32
 
33
+ # Quick Start markdown template; placeholders match init suffixes (__ENV_NAME__, __ENV_CLASS_NAME__*).
34
+ DEFAULT_QUICK_START_MARKDOWN = """
35
+ ### Connect to this environment
36
 
37
+ Connect from Python using `__ENV_CLASS_NAME__Env`:
38
+
39
+ ```python
40
+ from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env
41
+
42
+ with __ENV_CLASS_NAME__Env.from_env("<SPACE_ID>") as env:
43
+ result = await env.step(__ENV_CLASS_NAME__Action(message="..."))
44
+ ```
45
+
46
+ Or connect directly to a running server:
47
+
48
+ ```python
49
+ env = __ENV_CLASS_NAME__Env(base_url="http://localhost:8000")
50
+ ```
51
+
52
+ ### Contribute to this environment
53
+
54
+ Submit improvements via pull request on the Hugging Face Hub.
55
+
56
+ ```bash
57
+ openenv fork <SPACE_ID> --repo-id <your-username>/<your-repo-name>
58
+ ```
59
+
60
+ Then make your changes and submit a pull request:
61
+
62
+ ```bash
63
+ cd <forked-repo>
64
+ openenv push <SPACE_ID> --create-pr
65
+ ```
66
+
67
+ For more information, see the [OpenEnv documentation](https://meta-pytorch.org/OpenEnv/).
68
+ """
69
+
70
+
71
+ def get_quick_start_markdown(
72
+ metadata: Optional[EnvironmentMetadata],
73
+ action_cls: Type[Action],
74
+ observation_cls: Type[Observation],
75
+ ) -> str:
76
+ """
77
+ Build Quick Start markdown with class names replaced from current env (init-style suffixes).
78
+
79
+ Uses the same placeholder names as the init template so that __ENV_CLASS_NAME__Env,
80
+ __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation and __ENV_NAME__ are
81
+ replaced with the actual class/package names.
82
+ """
83
+ import os
84
+
85
+ # Prefix from action class (e.g. EchoAction -> Echo)
86
+ action_name = getattr(action_cls, "__name__", "Action")
87
+ if action_name.endswith("Action"):
88
+ prefix = action_name[: -len("Action")]
89
+ else:
90
+ prefix = action_name.replace("Action", "").strip() or "Env"
91
+
92
+ env_client_name = f"{prefix}Env"
93
+ obs_name = getattr(observation_cls, "__name__", "Observation")
94
+ pkg_name = (metadata.name if metadata else "env").replace(" ", "_").lower()
95
+
96
+ space_id = os.environ.get("SPACE_ID", "<hf-username>/<hf-repo-name>")
97
+
98
+ content = DEFAULT_QUICK_START_MARKDOWN
99
+ content = content.replace("__ENV_CLASS_NAME__Env", env_client_name)
100
+ content = content.replace("__ENV_CLASS_NAME__Action", action_name)
101
+ content = content.replace("__ENV_CLASS_NAME__Observation", obs_name)
102
+ content = content.replace("__ENV_CLASS_NAME__", prefix)
103
+ content = content.replace("__ENV_NAME__", pkg_name)
104
+ content = content.replace("<SPACE_ID>", space_id)
105
+ return content.strip()
106
+
107
+
108
+ def load_environment_metadata(
109
+ env: Environment, env_name: Optional[str] = None
110
+ ) -> EnvironmentMetadata:
111
  """
112
  Load environment metadata including README content.
113
+
114
  Args:
115
+ env: The environment instance, class, or factory function.
116
+ - If a class: used as a factory, won't call instance methods
117
+ - If a function: used as a factory, won't call instance methods
118
+ - If an instance: may call get_metadata() if available
119
  env_name: Optional environment name for README file lookup
120
+
121
  Returns:
122
  EnvironmentMetadata with loaded information
123
  """
124
+ import inspect
125
+
126
+ # Determine what type of env we received:
127
+ # 1. A class (used as factory) - e.g., PythonCodeActEnv
128
+ # 2. A function (factory function) - e.g., create_chat_environment
129
+ # 3. An actual instance - e.g., SnakeEnvironment()
130
+ is_class = inspect.isclass(env)
131
+ is_function = inspect.isfunction(env) or inspect.ismethod(env)
132
+ is_factory = is_class or is_function
133
+
134
+ # Try to get metadata from environment if it's an instance with get_metadata
135
+ if not is_factory and hasattr(env, "get_metadata"):
136
  return env.get_metadata()
137
+
138
+ # Determine the class name for default metadata
139
+ if is_class:
140
+ # env is the class itself
141
+ class_name = env.__name__
142
+ elif is_function:
143
+ # env is a factory function - use its name or derive from env_name
144
+ class_name = env_name or env.__name__
145
+ else:
146
+ # env is an instance
147
+ class_name = env.__class__.__name__
148
+
149
  # Default metadata
150
  metadata = EnvironmentMetadata(
151
+ name=env_name or class_name,
152
+ description=f"{class_name} environment",
153
+ version="1.0.0",
154
  )
155
+
156
  # Try to load README from file system
157
  readme_content = _load_readme_from_filesystem(env_name)
158
  if readme_content:
159
  metadata.readme_content = readme_content
160
+
161
  return metadata
162
 
163
 
164
  def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]:
165
  """
166
  Load README content from the filesystem.
167
+
168
  Tries multiple locations:
169
  1. Container filesystem: /app/README.md
170
  2. Local development: src/envs/{env_name}/README.md
 
172
  """
173
  import os
174
  from pathlib import Path
175
+
176
  # Try container filesystem first
177
  container_readme = Path("/app/README.md")
178
  if container_readme.exists():
179
  try:
180
+ return container_readme.read_text(encoding="utf-8")
181
  except Exception:
182
  pass
183
+
184
  # Try environment variable path
185
  custom_path = os.environ.get("ENV_README_PATH")
186
  if custom_path and Path(custom_path).exists():
187
  try:
188
+ return Path(custom_path).read_text(encoding="utf-8")
189
  except Exception:
190
  pass
191
+
192
  # Try local development path
193
  if env_name:
194
  local_readme = Path(f"src/envs/{env_name}/README.md")
195
  if local_readme.exists():
196
  try:
197
+ return local_readme.read_text(encoding="utf-8")
198
  except Exception:
199
  pass
200
+
201
  return None
202
 
203
 
204
+ class ActionLog(BaseModel):
 
205
  """Log entry for an action taken."""
206
+
207
+ model_config = ConfigDict(extra="forbid", validate_assignment=True)
208
+
209
+ timestamp: str = Field(description="Timestamp when action was taken")
210
+ action: Dict[str, Any] = Field(description="Action that was taken")
211
+ observation: Dict[str, Any] = Field(description="Observation returned from action")
212
+ reward: Optional[float] = Field(
213
+ default=None, description="Reward received from action"
214
+ )
215
+ done: bool = Field(description="Whether the episode is done after this action")
216
+ step_count: int = Field(description="Step count when this action was taken")
217
 
218
 
219
+ class EpisodeState(BaseModel):
 
220
  """Current episode state for the web interface."""
221
+
222
+ model_config = ConfigDict(extra="forbid", validate_assignment=True)
223
+
224
+ episode_id: Optional[str] = Field(default=None, description="Current episode ID")
225
+ step_count: int = Field(description="Current step count in episode")
226
+ current_observation: Optional[Dict[str, Any]] = Field(
227
+ default=None, description="Current observation"
228
+ )
229
+ action_logs: List[ActionLog] = Field(
230
+ default_factory=list, description="List of action logs"
231
+ )
232
+ is_reset: bool = Field(
233
+ default=True, description="Whether the episode has been reset"
234
+ )
235
 
236
 
237
  class WebInterfaceManager:
238
  """Manages the web interface for an environment."""
239
+
240
+ MAX_ACTION_LOGS = 1000
241
+
242
  def __init__(
243
  self,
244
  env: Environment,
 
246
  observation_cls: Type[Observation],
247
  metadata: Optional[EnvironmentMetadata] = None,
248
  ):
249
+ import inspect
250
+
251
+ # If env is a class or factory function, instantiate it
252
+ if inspect.isclass(env) or inspect.isfunction(env):
253
+ self.env = env()
254
+ else:
255
+ self.env = env
256
  self.action_cls = action_cls
257
  self.observation_cls = observation_cls
258
  self.metadata = metadata or EnvironmentMetadata(
259
  name=env.__class__.__name__,
260
+ description=f"{env.__class__.__name__} environment",
261
  )
262
  self.episode_state = EpisodeState(
263
  episode_id=None,
264
  step_count=0,
265
  current_observation=None,
266
+ action_logs=[],
267
  )
268
  self.connected_clients: List[WebSocket] = []
269
+ # Thread pool for running sync code (e.g., Playwright sync API) in async context
270
+ self._executor = ThreadPoolExecutor(max_workers=1)
271
+
272
+ async def _run_sync_in_thread_pool(self, func, *args, **kwargs):
273
+ """Run a synchronous function in the thread pool executor.
274
+
275
+ This is needed for environments using sync libraries (e.g., Playwright sync API)
276
+ that cannot be called directly from an async context.
277
+ """
278
+ loop = asyncio.get_event_loop()
279
+ # Use default arguments to capture values at lambda definition time
280
+ # to avoid closure issues with late binding
281
+ return await loop.run_in_executor(
282
+ self._executor, lambda f=func, a=args, kw=kwargs: f(*a, **kw)
283
+ )
284
+
285
  async def connect_websocket(self, websocket: WebSocket):
286
  """Connect a new WebSocket client."""
287
  await websocket.accept()
288
  self.connected_clients.append(websocket)
289
+
290
  # Send current state to the new client
291
  await self._send_state_update()
292
+
293
  async def disconnect_websocket(self, websocket: WebSocket):
294
  """Disconnect a WebSocket client."""
295
  if websocket in self.connected_clients:
296
  self.connected_clients.remove(websocket)
297
+
298
  async def _send_state_update(self):
299
  """Send current state to all connected clients."""
300
  if not self.connected_clients:
301
  return
302
+
303
  state_data = {
304
  "type": "state_update",
305
+ "episode_state": self.episode_state.model_dump(),
306
  }
307
+
308
  # Send to all connected clients
309
  disconnected_clients = []
310
  for client in self.connected_clients:
311
  try:
312
  await client.send_text(json.dumps(state_data))
313
+ except Exception:
314
  disconnected_clients.append(client)
315
+
316
  # Remove disconnected clients
317
  for client in disconnected_clients:
318
  self.connected_clients.remove(client)
319
+
320
  async def reset_environment(self) -> Dict[str, Any]:
321
  """Reset the environment and update state."""
322
+ # Run sync reset in thread pool to avoid blocking event loop
323
+ # and to support environments using sync libraries (e.g., Playwright)
324
+ observation: Observation = await self._run_sync_in_thread_pool(self.env.reset)
325
+ state: State = self.env.state
326
+
327
+ # Serialize observation once using shared utility
328
+ serialized = serialize_observation(observation)
329
+
330
  # Update episode state
331
  self.episode_state.episode_id = state.episode_id
332
  self.episode_state.step_count = 0
333
+ self.episode_state.current_observation = serialized["observation"]
334
  self.episode_state.action_logs = []
335
  self.episode_state.is_reset = True
336
+
337
  # Send state update
338
  await self._send_state_update()
339
+
340
+ return serialized
341
+
 
 
 
 
342
  async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]:
343
  """Execute a step in the environment and update state."""
344
+ # Deserialize action with preprocessing for web interface special cases
345
+ action: Action = deserialize_action_with_preprocessing(
346
+ action_data, self.action_cls
347
+ )
348
+
349
+ # Run sync step in thread pool to avoid blocking event loop
350
+ # and to support environments using sync libraries (e.g., Playwright)
351
+ observation: Observation = await self._run_sync_in_thread_pool(
352
+ self.env.step, action
353
+ )
354
+ state: State = self.env.state
355
+
356
+ # Serialize observation once using shared utility
357
+ serialized = serialize_observation(observation)
358
+
359
  # Create action log
360
  action_log = ActionLog(
361
  timestamp=datetime.now().isoformat(),
362
+ action=action.model_dump(exclude={"metadata"}),
363
+ observation=serialized["observation"],
364
  reward=observation.reward,
365
  done=observation.done,
366
+ step_count=state.step_count,
367
  )
368
+
369
  # Update episode state
370
  self.episode_state.episode_id = state.episode_id
371
  self.episode_state.step_count = state.step_count
372
+ self.episode_state.current_observation = serialized["observation"]
373
  self.episode_state.action_logs.append(action_log)
374
+ if len(self.episode_state.action_logs) > self.MAX_ACTION_LOGS:
375
+ self.episode_state.action_logs = self.episode_state.action_logs[
376
+ -self.MAX_ACTION_LOGS :
377
+ ]
378
  self.episode_state.is_reset = False
379
+
380
  # Send state update
381
  await self._send_state_update()
382
+
383
+ return serialized
384
+
 
 
 
 
385
  def get_state(self) -> Dict[str, Any]:
386
  """Get current environment state."""
387
+ state: State = self.env.state
388
+ return state.model_dump()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
 
390
 
391
  def create_web_interface_app(
 
393
  action_cls: Type[Action],
394
  observation_cls: Type[Observation],
395
  env_name: Optional[str] = None,
396
+ max_concurrent_envs: Optional[int] = None,
397
+ concurrency_config: Optional[Any] = None,
398
+ gradio_builder: Optional[Callable[..., Any]] = None,
399
  ) -> FastAPI:
400
  """
401
  Create a FastAPI application with web interface for the given environment.
402
+
403
  Args:
404
  env: The Environment instance to serve
405
  action_cls: The Action subclass this environment expects
406
  observation_cls: The Observation subclass this environment returns
407
  env_name: Optional environment name for README loading
408
+ max_concurrent_envs: Maximum concurrent WebSocket sessions
409
+ concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings
410
+ gradio_builder: Optional callable (web_manager, action_fields, metadata,
411
+ is_chat_env, title, quick_start_md) -> gr.Blocks to use instead of the
412
+ default Gradio UI. Lets envs replace or customize the /web interface.
413
+
414
  Returns:
415
  FastAPI application instance with web interface
416
  """
417
  from .http_server import create_fastapi_app
418
+
419
  # Create the base environment app
420
+ app = create_fastapi_app(
421
+ env, action_cls, observation_cls, max_concurrent_envs, concurrency_config
422
+ )
423
+
424
  # Load environment metadata
425
  metadata = load_environment_metadata(env, env_name)
426
+
427
  # Create web interface manager
428
  web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)
429
+
430
+ # Web API routes first (so they take precedence over Gradio mount at /web)
 
 
 
 
 
431
  @app.get("/web/metadata")
432
  async def web_metadata():
433
  """Get environment metadata."""
434
+ return web_manager.metadata.model_dump()
435
+
436
+ @app.websocket("/ws/ui")
437
+ async def websocket_ui_endpoint(websocket: WebSocket):
438
+ """WebSocket endpoint for web UI real-time updates.
439
+
440
+ Note: Uses /ws/ui to avoid conflict with /ws in http_server.py
441
+ which is used for concurrent environment sessions.
442
+ """
443
  await web_manager.connect_websocket(websocket)
444
  try:
445
  while True:
 
447
  await websocket.receive_text()
448
  except WebSocketDisconnect:
449
  await web_manager.disconnect_websocket(websocket)
450
+
451
  @app.post("/web/reset")
452
  async def web_reset():
453
  """Reset endpoint for web interface."""
454
  return await web_manager.reset_environment()
455
+
456
  @app.post("/web/step")
457
  async def web_step(request: Dict[str, Any]):
458
  """Step endpoint for web interface."""
459
  # Check if this is a message-based request (chat environment)
460
  if "message" in request:
461
  message = request["message"]
462
+ if hasattr(web_manager.env, "message_to_action"):
463
+ action = web_manager.env.message_to_action(message)
464
+ if hasattr(action, "tokens"):
465
+ action_data = {"tokens": action.tokens.tolist()}
466
+ else:
467
+ action_data = action.model_dump(exclude={"metadata"})
468
+ else:
469
+ action_data = {"message": message}
470
  else:
471
  action_data = request.get("action", {})
472
+
473
  return await web_manager.step_environment(action_data)
474
+
475
  @app.get("/web/state")
476
  async def web_state():
477
  """State endpoint for web interface."""
478
  return web_manager.get_state()
479
+
480
+ action_fields = _extract_action_fields(action_cls)
481
+ is_chat_env = _is_chat_env(action_cls)
482
+ quick_start_md = get_quick_start_markdown(metadata, action_cls, observation_cls)
483
+
484
+ default_blocks = build_gradio_app(
485
+ web_manager,
486
+ action_fields,
487
+ metadata,
488
+ is_chat_env,
489
+ title=metadata.name,
490
+ quick_start_md=quick_start_md,
491
+ )
492
+ if gradio_builder is not None:
493
+ custom_blocks = gradio_builder(
494
+ web_manager,
495
+ action_fields,
496
+ metadata,
497
+ is_chat_env,
498
+ metadata.name,
499
+ quick_start_md,
500
+ )
501
+ if not isinstance(custom_blocks, gr.Blocks):
502
+ raise TypeError(
503
+ f"gradio_builder must return a gr.Blocks instance, "
504
+ f"got {type(custom_blocks).__name__}"
505
+ )
506
+ gradio_blocks = gr.TabbedInterface(
507
+ [default_blocks, custom_blocks],
508
+ tab_names=["Playground", "Visualization"],
509
+ title=get_gradio_display_title(metadata),
510
+ )
511
+ else:
512
+ gradio_blocks = default_blocks
513
+ app = gr.mount_gradio_app(
514
+ app,
515
+ gradio_blocks,
516
+ path="/web",
517
+ theme=OPENENV_GRADIO_THEME,
518
+ css=OPENENV_GRADIO_CSS,
519
+ )
520
+
521
  return app
522
 
523
 
524
+ def _is_chat_env(action_cls: Type[Action]) -> bool:
525
+ """Return True if the action class is a chat-style env (tokens field)."""
526
+ if hasattr(action_cls, "model_fields"):
527
+ for field_name, field_info in action_cls.model_fields.items():
528
+ if (
529
+ field_name == "tokens"
530
+ and hasattr(field_info.annotation, "__name__")
531
+ and "Tensor" in str(field_info.annotation)
532
+ ):
533
+ return True
534
+ return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
535
 
536
 
537
  def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]:
538
  """Extract enhanced field metadata from Action class for form generation."""
539
+ # Use Pydantic's JSON schema generation for robust metadata extraction
540
+ try:
541
+ schema = action_cls.model_json_schema()
542
+ except AttributeError:
543
+ # Fallback for non-Pydantic v2 models or if something goes wrong
544
+ return []
545
+
546
+ properties = schema.get("properties", {})
547
+ required_fields = schema.get("required", [])
548
+
549
  action_fields = []
550
+
551
+ for field_name, field_info in properties.items():
552
+ if field_name == "metadata":
 
 
553
  continue
554
+
555
+ # JSON schema "type" can be a string or list/undefined
556
+ # Determine our internal input type
557
+ input_type = _determine_input_type_from_schema(field_info, field_name)
558
+
559
+ is_required = field_name in required_fields
560
+
561
+ action_fields.append(
562
+ {
563
+ "name": field_name,
564
+ "type": input_type,
565
+ "required": is_required,
566
+ "description": field_info.get("description", ""),
567
+ "default_value": field_info.get("default"),
568
+ "choices": field_info.get("enum"),
569
+ "min_value": field_info.get("minimum"),
570
+ "max_value": field_info.get("maximum"),
571
+ "min_length": field_info.get("minLength"),
572
+ "max_length": field_info.get("maxLength"),
573
+ "pattern": field_info.get("pattern"),
574
+ "placeholder": _generate_placeholder(field_name, field_info),
575
+ "help_text": _generate_help_text(field_name, field_info),
576
+ }
577
+ )
578
+
579
  return action_fields
580
 
581
 
582
+ def _determine_input_type_from_schema(
583
+ field_info: Dict[str, Any], field_name: str
584
+ ) -> str:
585
+ """Determine input type from JSON schema for form generation (Gradio UI)."""
586
+ schema_type = field_info.get("type")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587
 
588
+ # Check for specific tensor field convention
589
+ if "tokens" in field_name.lower():
590
+ return "tensor"
591
 
592
+ if "enum" in field_info:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593
  return "select"
594
+
595
+ if schema_type == "boolean":
596
+ return "checkbox"
597
+
598
+ if schema_type == "integer" or schema_type == "number":
599
+ return "number"
600
+
601
+ if schema_type == "string":
602
+ # Check if it should be a textarea
603
+ if (
604
+ field_info.get("maxLength", 0) > 100
605
+ or "message" in field_name.lower()
606
+ or "code" in field_name.lower()
607
+ ):
608
+ return "textarea"
609
  return "text"
610
 
611
+ # Default fallback
612
+ return "text"
613
 
614
+
615
+ def _generate_placeholder(field_name: str, field_info: Dict[str, Any]) -> str:
616
+ """Generate placeholder text."""
617
+ if "message" in field_name.lower():
618
+ return f"Enter {field_name.replace('_', ' ')}..."
619
+ elif "code" in field_name.lower():
620
+ return "Enter Python code here..."
621
+ elif "tokens" in field_name.lower():
622
+ return "Enter comma-separated token IDs (e.g., 1,2,3,4,5)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
623
  else:
624
+ return f"Enter {field_name.replace('_', ' ')}..."
625
+
626
+
627
+ def _generate_help_text(field_name: str, field_info: Dict[str, Any]) -> str:
628
+ """Generate help text."""
629
+ description = field_info.get("description", "")
630
+ if description:
631
+ return description
632
+
633
+ if "action_id" in field_name.lower():
634
+ return "The action ID to execute in environment"
635
+ elif "game_name" in field_name.lower():
636
+ return "Name of game or environment"
637
+ elif "tokens" in field_name.lower():
638
+ return "Token IDs as a comma-separated list of integers"
639
+ elif "code" in field_name.lower():
640
+ return "Python code to execute in environment"
641
+ elif "message" in field_name.lower():
642
+ return "Text message to send"
643
+
644
+ return ""
src/core/evals/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Evaluation harness support for OpenEnv."""
8
+
9
+ from openenv.core.evals.base import EvalHarness
10
+ from openenv.core.evals.inspect_harness import InspectAIHarness
11
+ from openenv.core.evals.types import EvalConfig, EvalResult
12
+
13
+ __all__ = [
14
+ "EvalHarness",
15
+ "EvalConfig",
16
+ "EvalResult",
17
+ "InspectAIHarness",
18
+ ]
src/core/evals/base.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Base class for evaluation harnesses."""
8
+
9
+ from abc import ABC, abstractmethod
10
+ from typing import Any, Dict
11
+
12
+ from openenv.core.evals.types import EvalConfig, EvalResult
13
+
14
+
15
+ class EvalHarness(ABC):
16
+ """Abstract base class for evaluation harnesses.
17
+
18
+ Subclasses implement run() to define evaluation logic.
19
+ """
20
+
21
+ @abstractmethod
22
+ def run(
23
+ self,
24
+ harness_version: str,
25
+ library_versions: Dict[str, str],
26
+ dataset: str,
27
+ eval_parameters: Dict[str, Any],
28
+ ) -> Dict[str, Any]:
29
+ """Run the evaluation and return scores.
30
+
31
+ Args:
32
+ harness_version: Version of the evaluation harness.
33
+ library_versions: Versions of libraries used in the evaluation.
34
+ dataset: Name of the dataset to evaluate on.
35
+ eval_parameters: Parameters for the evaluation.
36
+
37
+ Returns:
38
+ Dictionary of scores from the evaluation.
39
+ """
40
+ raise NotImplementedError
41
+
42
+ def run_from_config(self, config: EvalConfig) -> EvalResult:
43
+ """Run evaluation from an EvalConfig and return an EvalResult.
44
+
45
+ Args:
46
+ config: Configuration for the evaluation.
47
+
48
+ Returns:
49
+ EvalResult containing the config and scores.
50
+ """
51
+ scores = self.run(
52
+ harness_version=config.harness_version,
53
+ library_versions=config.library_versions,
54
+ dataset=config.dataset,
55
+ eval_parameters=config.eval_parameters,
56
+ )
57
+ return EvalResult(config=config, scores=scores)
58
+
59
+ @property
60
+ def name(self) -> str:
61
+ """Return the name of the harness (class name)."""
62
+ return self.__class__.__name__
src/core/evals/inspect_harness.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Inspect AI harness integration for OpenEnv.
8
+
9
+ Requires the ``inspect-ai`` package: ``pip install 'inspect-ai>=0.3.0'``
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import Any, Dict, Optional
15
+
16
+ from openenv.core.evals.base import EvalHarness
17
+
18
+
19
+ class InspectAIHarness(EvalHarness):
20
+ """Evaluation harness wrapping Inspect AI's ``eval()`` function.
21
+
22
+ All ``inspect_ai`` imports are deferred to :meth:`run` so this class is
23
+ importable without inspect-ai installed. An ``ImportError`` with a clear
24
+ message is raised at call time if the dependency is missing.
25
+
26
+ Args:
27
+ log_dir: Directory for evaluation log output. Defaults to None
28
+ (Inspect AI writes logs to its default location).
29
+
30
+ ``eval_parameters`` keys accepted by :meth:`run`:
31
+
32
+ +--------------------------+----------+-----------------+-----------------------------------+
33
+ | Key | Type | Default | Purpose |
34
+ +==========================+==========+=================+===================================+
35
+ | ``model`` | str | *required* | Model string, e.g. "openai/gpt-4o"|
36
+ | ``task`` | str|None | ``dataset`` arg | Task file path or task string |
37
+ | ``task_args`` | dict | ``{}`` | Arguments to pass to the task |
38
+ | ``max_samples`` | int|None | None | Limit samples per task |
39
+ | ``temperature`` | float|None| None | Model generation temperature |
40
+ | ``max_tokens`` | int|None | None | Max generation tokens |
41
+ | ``epochs`` | int|None | None | Number of evaluation epochs |
42
+ | ``solver`` | list|None| None | Solver pipeline override |
43
+ | ``scorer`` | list|None| None | Scorer override |
44
+ | ``model_args`` | dict | ``{}`` | Provider-specific model kwargs |
45
+ +--------------------------+----------+-----------------+-----------------------------------+
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ *,
51
+ log_dir: Optional[str] = None,
52
+ ):
53
+ self.log_dir = log_dir
54
+
55
+ def run(
56
+ self,
57
+ harness_version: str,
58
+ library_versions: Dict[str, str],
59
+ dataset: str,
60
+ eval_parameters: Dict[str, Any],
61
+ ) -> Dict[str, Any]:
62
+ """Run an Inspect AI evaluation.
63
+
64
+ Args:
65
+ harness_version: Version of inspect-ai being used.
66
+ library_versions: Versions of supporting libraries.
67
+ dataset: Default task string (used when ``task`` is not specified
68
+ in *eval_parameters*).
69
+ eval_parameters: See class docstring for accepted keys.
70
+
71
+ Returns:
72
+ Dictionary mapping metric names to scores.
73
+
74
+ Raises:
75
+ ImportError: If ``inspect-ai`` is not installed.
76
+ ValueError: If ``model`` is missing from *eval_parameters*.
77
+ RuntimeError: If the evaluation fails (log status is not "success").
78
+ """
79
+ try:
80
+ from inspect_ai import eval as inspect_eval
81
+ except ImportError:
82
+ raise ImportError(
83
+ "inspect-ai is required for InspectAIHarness. "
84
+ "Install it with: pip install 'inspect-ai>=0.3.0'"
85
+ )
86
+
87
+ # Extract required model parameter
88
+ model = eval_parameters.get("model")
89
+ if model is None:
90
+ raise ValueError(
91
+ "eval_parameters must include 'model' "
92
+ "(e.g. 'openai/gpt-4o', 'hf/meta-llama/...')."
93
+ )
94
+
95
+ # Task: explicit parameter or fall back to dataset
96
+ task = eval_parameters.get("task", dataset)
97
+
98
+ # Build eval kwargs
99
+ eval_kwargs: Dict[str, Any] = {}
100
+
101
+ task_args = eval_parameters.get("task_args", {})
102
+ if task_args:
103
+ eval_kwargs["task_args"] = task_args
104
+
105
+ model_args = eval_parameters.get("model_args", {})
106
+ if model_args:
107
+ eval_kwargs["model_args"] = model_args
108
+
109
+ for key in ("max_samples", "temperature", "max_tokens", "epochs"):
110
+ value = eval_parameters.get(key)
111
+ if value is not None:
112
+ eval_kwargs[key] = value
113
+
114
+ if eval_parameters.get("solver") is not None:
115
+ eval_kwargs["solver"] = eval_parameters["solver"]
116
+
117
+ if eval_parameters.get("scorer") is not None:
118
+ eval_kwargs["scorer"] = eval_parameters["scorer"]
119
+
120
+ if self.log_dir is not None:
121
+ eval_kwargs["log_dir"] = self.log_dir
122
+
123
+ # Run evaluation
124
+ logs = inspect_eval(task, model=model, **eval_kwargs)
125
+
126
+ # Extract results from the first log
127
+ if not logs:
128
+ raise RuntimeError(
129
+ "Inspect AI evaluation returned no logs. "
130
+ "Check that the task and model arguments are valid."
131
+ )
132
+ log = logs[0]
133
+ if log.status != "success":
134
+ raise RuntimeError(
135
+ f"Inspect AI evaluation failed with status: {log.status}"
136
+ )
137
+
138
+ return self._extract_scores(log)
139
+
140
+ def _extract_scores(self, log: Any) -> Dict[str, Any]:
141
+ """Parse an EvalLog's results into a flat score dictionary.
142
+
143
+ Iterates over ``log.results.scores`` (a list of ``EvalScore``),
144
+ flattening each scorer's ``metrics`` dict into a single output dict.
145
+
146
+ Args:
147
+ log: An ``inspect_ai`` ``EvalLog`` object.
148
+
149
+ Returns:
150
+ Dictionary mapping metric names to their values.
151
+ """
152
+ scores: Dict[str, Any] = {}
153
+ if log.results is None:
154
+ return scores
155
+
156
+ for eval_score in log.results.scores:
157
+ for metric_name, metric in eval_score.metrics.items():
158
+ scores[metric_name] = metric.value
159
+
160
+ return scores