Spaces:
Running
Running
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Dockerfile +36 -15
- README.md +401 -27
- __init__.py +31 -0
- client.py +120 -0
- envs/atari_env/README.md +408 -0
- envs/atari_env/__init__.py +31 -0
- envs/atari_env/client.py +120 -0
- envs/atari_env/models.py +85 -0
- envs/atari_env/server/Dockerfile +43 -0
- envs/atari_env/server/__init__.py +15 -0
- envs/atari_env/server/app.py +80 -0
- envs/atari_env/server/atari_environment.py +246 -0
- envs/atari_env/server/requirements.txt +3 -0
- envs/atari_env/test_atari_docker.sh +333 -0
- models.py +85 -0
- pyproject.toml +147 -0
- server/Dockerfile +43 -0
- server/__init__.py +15 -0
- server/app.py +80 -0
- server/atari_environment.py +246 -0
- server/requirements.txt +3 -0
- src/__init__.py +7 -0
- src/core/README.md +212 -0
- src/core/__init__.py +70 -8
- src/core/client_types.py +23 -0
- src/core/containers/__init__.py +1 -1
- src/core/containers/images/Dockerfile +29 -11
- src/core/containers/images/README.md +8 -8
- src/core/containers/runtime/__init__.py +12 -2
- src/core/containers/runtime/daytona_provider.py +572 -0
- src/core/containers/runtime/providers.py +389 -9
- src/core/containers/runtime/uv_provider.py +224 -0
- src/core/containers/test_local_docker_provider.py +8 -6
- src/core/env_client.py +484 -0
- src/core/env_server/__init__.py +118 -3
- src/core/env_server/base_transforms.py +1 -1
- src/core/env_server/exceptions.py +105 -0
- src/core/env_server/gradio_theme.py +128 -0
- src/core/env_server/gradio_ui.py +240 -0
- src/core/env_server/http_server.py +1263 -105
- src/core/env_server/interfaces.py +189 -10
- src/core/env_server/mcp_environment.py +624 -0
- src/core/env_server/mcp_types.py +321 -0
- src/core/env_server/route_config.py +57 -0
- src/core/env_server/serialization.py +137 -0
- src/core/env_server/types.py +361 -31
- src/core/env_server/web_interface.py +426 -1395
- src/core/evals/__init__.py +18 -0
- src/core/evals/base.py +62 -0
- src/core/evals/inspect_harness.py +160 -0
Dockerfile
CHANGED
|
@@ -1,20 +1,39 @@
|
|
| 1 |
-
#
|
| 2 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
#
|
| 4 |
-
#
|
| 5 |
-
#
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
# Install
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
# Copy only what's needed for this environment
|
| 16 |
COPY src/core/ /app/src/core/
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
# Health check
|
| 20 |
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
|
@@ -22,4 +41,6 @@ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
|
| 22 |
|
| 23 |
# Run the FastAPI server
|
| 24 |
CMD ["uvicorn", "envs.atari_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
|
|
|
|
|
|
| 25 |
ENV ENABLE_WEB_INTERFACE=true
|
|
|
|
| 1 |
+
# Dockerfile for Atari Environment
|
| 2 |
+
# This image provides Atari 2600 games via the Arcade Learning Environment (ALE)
|
| 3 |
+
|
| 4 |
+
# Configurable base image - defaults to local build, can be overridden for CI/CD
|
| 5 |
+
# Base image provides: fastapi, uvicorn, requests, curl, PYTHONPATH=/app/src
|
| 6 |
+
#
|
| 7 |
+
# Local build: docker build -t envtorch-base:latest -f src/core/containers/images/Dockerfile .
|
| 8 |
+
# docker build -f envs/atari_env/server/Dockerfile -t atari-env:latest .
|
| 9 |
#
|
| 10 |
+
# CI/CD build: docker build --build-arg BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest \
|
| 11 |
+
# -f envs/atari_env/server/Dockerfile -t atari-env:latest .
|
| 12 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 13 |
+
FROM ghcr.io/meta-pytorch/openenv-base:latest
|
| 14 |
+
|
| 15 |
+
# Install dependencies
|
| 16 |
+
COPY envs/atari_env/server/requirements.txt /tmp/requirements.txt
|
| 17 |
+
RUN pip install --no-cache-dir -r /tmp/requirements.txt && rm /tmp/requirements.txt
|
| 18 |
+
|
| 19 |
+
# Copy OpenEnv core (base image already set WORKDIR=/app)
|
|
|
|
|
|
|
| 20 |
COPY src/core/ /app/src/core/
|
| 21 |
+
|
| 22 |
+
# Copy Atari environment code
|
| 23 |
+
COPY envs/atari_env/ /app/envs/atari_env/
|
| 24 |
+
|
| 25 |
+
# Copy README for web interface documentation
|
| 26 |
+
COPY envs/atari_env/README.md /app/README.md
|
| 27 |
+
|
| 28 |
+
# Atari-specific environment variables (can be overridden at runtime)
|
| 29 |
+
ENV ATARI_GAME=pong
|
| 30 |
+
ENV ATARI_OBS_TYPE=rgb
|
| 31 |
+
ENV ATARI_FULL_ACTION_SPACE=false
|
| 32 |
+
ENV ATARI_REPEAT_ACTION_PROB=0.0
|
| 33 |
+
ENV ATARI_FRAMESKIP=4
|
| 34 |
+
|
| 35 |
+
# Expose port
|
| 36 |
+
EXPOSE 8000
|
| 37 |
|
| 38 |
# Health check
|
| 39 |
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
|
|
|
| 41 |
|
| 42 |
# Run the FastAPI server
|
| 43 |
CMD ["uvicorn", "envs.atari_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
| 44 |
+
ENV PYTHONPATH=/app/src/core:/app/src:${PYTHONPATH}
|
| 45 |
+
|
| 46 |
ENV ENABLE_WEB_INTERFACE=true
|
README.md
CHANGED
|
@@ -1,51 +1,425 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
emoji: 🕹️
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
app_port: 8000
|
| 9 |
base_path: /web
|
| 10 |
tags:
|
|
|
|
| 11 |
- openenv
|
| 12 |
---
|
| 13 |
|
| 14 |
-
#
|
| 15 |
|
| 16 |
-
|
| 17 |
|
| 18 |
-
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
-
Built with FastAPI and OpenEnv framework.
|
| 22 |
|
| 23 |
-
|
|
|
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
- **State Observer**: Real-time view of environment state and action history
|
| 28 |
-
- **Live Updates**: WebSocket-based real-time updates
|
| 29 |
|
| 30 |
-
|
| 31 |
|
| 32 |
-
|
| 33 |
|
| 34 |
-
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
```
|
| 44 |
|
| 45 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
|
|
|
|
| 48 |
|
| 49 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Atari Environment Server
|
| 3 |
emoji: 🕹️
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
app_port: 8000
|
| 9 |
base_path: /web
|
| 10 |
tags:
|
| 11 |
+
- openenv-main
|
| 12 |
- openenv
|
| 13 |
---
|
| 14 |
|
| 15 |
+
## Hugging Face Space Deployment
|
| 16 |
|
| 17 |
+
This Space is built from OpenEnv environment `atari_env`.
|
| 18 |
|
| 19 |
+
- Space URL: `https://huggingface.co/spaces/openenv/atari_env`
|
| 20 |
+
- OpenEnv pinned ref: `main`
|
| 21 |
+
- Hub tag: `openenv`
|
| 22 |
|
| 23 |
+
### Connecting from Code
|
|
|
|
| 24 |
|
| 25 |
+
```python
|
| 26 |
+
from envs.atari_env import AtariEnv
|
| 27 |
|
| 28 |
+
env = AtariEnv(base_url="https://huggingface.co/spaces/openenv/atari_env")
|
| 29 |
+
```
|
|
|
|
|
|
|
| 30 |
|
| 31 |
+
# Atari Environment
|
| 32 |
|
| 33 |
+
Integration of Atari 2600 games with the OpenEnv framework via the Arcade Learning Environment (ALE). ALE provides access to 100+ classic Atari games for RL research.
|
| 34 |
|
| 35 |
+
## Supported Games
|
| 36 |
|
| 37 |
+
ALE supports 100+ Atari 2600 games including:
|
| 38 |
+
|
| 39 |
+
### Popular Games
|
| 40 |
+
- **Pong** - Classic two-player tennis
|
| 41 |
+
- **Breakout** - Break bricks with a ball
|
| 42 |
+
- **Space Invaders** - Shoot descending aliens
|
| 43 |
+
- **Pac-Man / Ms. Pac-Man** - Navigate mazes and eat pellets
|
| 44 |
+
- **Asteroids** - Destroy asteroids in space
|
| 45 |
+
- **Defender** - Side-scrolling space shooter
|
| 46 |
+
- **Centipede** - Shoot segmented centipede
|
| 47 |
+
- **Donkey Kong** - Jump over barrels to save princess
|
| 48 |
+
- **Frogger** - Cross road and river safely
|
| 49 |
+
- **Q*bert** - Jump on pyramid cubes
|
| 50 |
+
|
| 51 |
+
And many more! For a complete list, see [ALE documentation](https://ale.farama.org/environments/complete_list/).
|
| 52 |
+
|
| 53 |
+
## Architecture
|
| 54 |
+
|
| 55 |
+
```
|
| 56 |
+
┌────────────────────────────────────┐
|
| 57 |
+
│ RL Training Code (Client) │
|
| 58 |
+
│ AtariEnv.step(action) │
|
| 59 |
+
└──────────────┬─────────────────────┘
|
| 60 |
+
│ HTTP
|
| 61 |
+
┌──────────────▼─────────────────────┐
|
| 62 |
+
│ FastAPI Server (Docker) │
|
| 63 |
+
│ AtariEnvironment │
|
| 64 |
+
│ ├─ Wraps ALEInterface │
|
| 65 |
+
│ ├─ Handles observations │
|
| 66 |
+
│ └─ Action execution │
|
| 67 |
+
└────────────────────────────────────┘
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
## Installation & Usage
|
| 71 |
+
|
| 72 |
+
### Option 1: Local Development (without Docker)
|
| 73 |
+
|
| 74 |
+
**Requirements:**
|
| 75 |
+
- Python 3.11+
|
| 76 |
+
- ale-py installed: `pip install ale-py`
|
| 77 |
+
|
| 78 |
+
The client is **async by default**:
|
| 79 |
+
|
| 80 |
+
```python
|
| 81 |
+
import asyncio
|
| 82 |
+
from atari_env import AtariEnv, AtariAction
|
| 83 |
+
|
| 84 |
+
async def main():
|
| 85 |
+
# Start local server manually: python -m atari_env.server.app
|
| 86 |
+
async with AtariEnv(base_url="http://localhost:8000") as env:
|
| 87 |
+
# Reset environment
|
| 88 |
+
result = await env.reset()
|
| 89 |
+
print(f"Screen shape: {result.observation.screen_shape}")
|
| 90 |
+
print(f"Legal actions: {result.observation.legal_actions}")
|
| 91 |
+
|
| 92 |
+
# Take actions
|
| 93 |
+
for _ in range(10):
|
| 94 |
+
result = await env.step(AtariAction(action_id=2, game_name="pong"))
|
| 95 |
+
print(f"Reward: {result.reward}, Done: {result.done}")
|
| 96 |
+
if result.done:
|
| 97 |
+
break
|
| 98 |
+
|
| 99 |
+
asyncio.run(main())
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
For **synchronous usage**, use the `.sync()` wrapper:
|
| 103 |
+
|
| 104 |
+
```python
|
| 105 |
+
from atari_env import AtariEnv, AtariAction
|
| 106 |
+
|
| 107 |
+
with AtariEnv(base_url="http://localhost:8000").sync() as env:
|
| 108 |
+
result = env.reset()
|
| 109 |
+
result = env.step(AtariAction(action_id=2, game_name="pong"))
|
| 110 |
+
print(f"Reward: {result.reward}")
|
| 111 |
```
|
| 112 |
|
| 113 |
+
### Option 2: Docker (Recommended)
|
| 114 |
+
|
| 115 |
+
**Build Atari image:**
|
| 116 |
+
|
| 117 |
+
```bash
|
| 118 |
+
cd OpenEnv
|
| 119 |
+
|
| 120 |
+
# Build the image
|
| 121 |
+
docker build \
|
| 122 |
+
-f envs/atari_env/server/Dockerfile \
|
| 123 |
+
-t atari-env:latest \
|
| 124 |
+
.
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
**Run specific games:**
|
| 128 |
+
|
| 129 |
+
```bash
|
| 130 |
+
# Pong (default)
|
| 131 |
+
docker run -p 8000:8000 atari-env:latest
|
| 132 |
|
| 133 |
+
# Breakout
|
| 134 |
+
docker run -p 8000:8000 -e ATARI_GAME=breakout atari-env:latest
|
| 135 |
|
| 136 |
+
# Space Invaders with grayscale observation
|
| 137 |
+
docker run -p 8000:8000 \
|
| 138 |
+
-e ATARI_GAME=space_invaders \
|
| 139 |
+
-e ATARI_OBS_TYPE=grayscale \
|
| 140 |
+
atari-env:latest
|
| 141 |
|
| 142 |
+
# Ms. Pac-Man with full action space
|
| 143 |
+
docker run -p 8000:8000 \
|
| 144 |
+
-e ATARI_GAME=ms_pacman \
|
| 145 |
+
-e ATARI_FULL_ACTION_SPACE=true \
|
| 146 |
+
atari-env:latest
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
**Use with from_docker_image():**
|
| 150 |
+
|
| 151 |
+
```python
|
| 152 |
+
import asyncio
|
| 153 |
+
import numpy as np
|
| 154 |
+
from atari_env import AtariEnv, AtariAction
|
| 155 |
+
|
| 156 |
+
async def main():
|
| 157 |
+
# Automatically starts container
|
| 158 |
+
client = await AtariEnv.from_docker_image("atari-env:latest")
|
| 159 |
+
|
| 160 |
+
async with client:
|
| 161 |
+
result = await client.reset()
|
| 162 |
+
result = await client.step(AtariAction(action_id=2)) # UP
|
| 163 |
+
|
| 164 |
+
# Reshape screen for visualization
|
| 165 |
+
screen = np.array(result.observation.screen).reshape(result.observation.screen_shape)
|
| 166 |
+
print(f"Screen shape: {screen.shape}") # (210, 160, 3) for RGB
|
| 167 |
+
|
| 168 |
+
asyncio.run(main())
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
## Observation Types
|
| 172 |
+
|
| 173 |
+
### 1. RGB (Default)
|
| 174 |
+
- **Shape**: [210, 160, 3]
|
| 175 |
+
- **Description**: Full-color screen observation
|
| 176 |
+
- **Usage**: Most realistic, good for vision-based learning
|
| 177 |
+
|
| 178 |
+
```python
|
| 179 |
+
docker run -p 8000:8000 -e ATARI_OBS_TYPE=rgb atari-env:latest
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
### 2. Grayscale
|
| 183 |
+
- **Shape**: [210, 160]
|
| 184 |
+
- **Description**: Grayscale screen observation
|
| 185 |
+
- **Usage**: Reduced dimensionality, faster processing
|
| 186 |
+
|
| 187 |
+
```python
|
| 188 |
+
docker run -p 8000:8000 -e ATARI_OBS_TYPE=grayscale atari-env:latest
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
### 3. RAM
|
| 192 |
+
- **Shape**: [128]
|
| 193 |
+
- **Description**: Raw 128-byte Atari 2600 RAM contents
|
| 194 |
+
- **Usage**: Compact representation, useful for specific research
|
| 195 |
+
|
| 196 |
+
```python
|
| 197 |
+
docker run -p 8000:8000 -e ATARI_OBS_TYPE=ram atari-env:latest
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
## Action Spaces
|
| 201 |
+
|
| 202 |
+
### Minimal Action Set (Default)
|
| 203 |
+
Game-specific minimal actions (typically 4-9 actions).
|
| 204 |
+
- Pong: 6 actions (NOOP, FIRE, UP, DOWN, etc.)
|
| 205 |
+
- Breakout: 4 actions (NOOP, FIRE, LEFT, RIGHT)
|
| 206 |
+
|
| 207 |
+
```python
|
| 208 |
+
docker run -p 8000:8000 -e ATARI_FULL_ACTION_SPACE=false atari-env:latest
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
### Full Action Set
|
| 212 |
+
All 18 possible Atari 2600 actions:
|
| 213 |
+
0. NOOP
|
| 214 |
+
1. FIRE
|
| 215 |
+
2. UP
|
| 216 |
+
3. RIGHT
|
| 217 |
+
4. LEFT
|
| 218 |
+
5. DOWN
|
| 219 |
+
6. UPRIGHT
|
| 220 |
+
7. UPLEFT
|
| 221 |
+
8. DOWNRIGHT
|
| 222 |
+
9. DOWNLEFT
|
| 223 |
+
10. UPFIRE
|
| 224 |
+
11. RIGHTFIRE
|
| 225 |
+
12. LEFTFIRE
|
| 226 |
+
13. DOWNFIRE
|
| 227 |
+
14. UPRIGHTFIRE
|
| 228 |
+
15. UPLEFTFIRE
|
| 229 |
+
16. DOWNRIGHTFIRE
|
| 230 |
+
17. DOWNLEFTFIRE
|
| 231 |
+
|
| 232 |
+
```python
|
| 233 |
+
docker run -p 8000:8000 -e ATARI_FULL_ACTION_SPACE=true atari-env:latest
|
| 234 |
+
```
|
| 235 |
+
|
| 236 |
+
## Configuration
|
| 237 |
+
|
| 238 |
+
### Environment Variables
|
| 239 |
+
|
| 240 |
+
- `ATARI_GAME`: Game name (default: "pong")
|
| 241 |
+
- `ATARI_OBS_TYPE`: Observation type - "rgb", "grayscale", "ram" (default: "rgb")
|
| 242 |
+
- `ATARI_FULL_ACTION_SPACE`: Use full action space - "true"/"false" (default: "false")
|
| 243 |
+
- `ATARI_MODE`: Game mode (optional, game-specific)
|
| 244 |
+
- `ATARI_DIFFICULTY`: Game difficulty (optional, game-specific)
|
| 245 |
+
- `ATARI_REPEAT_ACTION_PROB`: Sticky action probability 0.0-1.0 (default: "0.0")
|
| 246 |
+
- `ATARI_FRAMESKIP`: Frames to skip per action (default: "4")
|
| 247 |
+
|
| 248 |
+
### Example: Breakout with Custom Settings
|
| 249 |
+
|
| 250 |
+
```bash
|
| 251 |
+
docker run -p 8000:8000 \
|
| 252 |
+
-e ATARI_GAME=breakout \
|
| 253 |
+
-e ATARI_OBS_TYPE=grayscale \
|
| 254 |
+
-e ATARI_FULL_ACTION_SPACE=true \
|
| 255 |
+
-e ATARI_REPEAT_ACTION_PROB=0.25 \
|
| 256 |
+
-e ATARI_FRAMESKIP=4 \
|
| 257 |
+
atari-env:latest
|
| 258 |
+
```
|
| 259 |
+
|
| 260 |
+
## API Reference
|
| 261 |
+
|
| 262 |
+
### AtariAction
|
| 263 |
+
|
| 264 |
+
```python
|
| 265 |
+
@dataclass
|
| 266 |
+
class AtariAction(Action):
|
| 267 |
+
action_id: int # Action index to execute
|
| 268 |
+
game_name: str = "pong" # Game name
|
| 269 |
+
obs_type: str = "rgb" # Observation type
|
| 270 |
+
full_action_space: bool = False # Full or minimal action space
|
| 271 |
+
```
|
| 272 |
+
|
| 273 |
+
### AtariObservation
|
| 274 |
+
|
| 275 |
+
```python
|
| 276 |
+
@dataclass
|
| 277 |
+
class AtariObservation(Observation):
|
| 278 |
+
screen: List[int] # Flattened screen pixels
|
| 279 |
+
screen_shape: List[int] # Original screen shape
|
| 280 |
+
legal_actions: List[int] # Legal action indices
|
| 281 |
+
lives: int # Lives remaining
|
| 282 |
+
episode_frame_number: int # Frame # in episode
|
| 283 |
+
frame_number: int # Total frame #
|
| 284 |
+
done: bool # Episode finished
|
| 285 |
+
reward: Optional[float] # Reward from last action
|
| 286 |
+
```
|
| 287 |
+
|
| 288 |
+
### AtariState
|
| 289 |
+
|
| 290 |
+
```python
|
| 291 |
+
@dataclass
|
| 292 |
+
class AtariState(State):
|
| 293 |
+
episode_id: str # Unique episode ID
|
| 294 |
+
step_count: int # Number of steps
|
| 295 |
+
game_name: str # Game name
|
| 296 |
+
obs_type: str # Observation type
|
| 297 |
+
full_action_space: bool # Action space type
|
| 298 |
+
mode: Optional[int] # Game mode
|
| 299 |
+
difficulty: Optional[int] # Game difficulty
|
| 300 |
+
repeat_action_probability: float # Sticky action prob
|
| 301 |
+
frameskip: int # Frameskip setting
|
| 302 |
+
```
|
| 303 |
+
|
| 304 |
+
## Example Script
|
| 305 |
+
|
| 306 |
+
```python
|
| 307 |
+
#!/usr/bin/env python3
|
| 308 |
+
"""Example training loop with Atari environment."""
|
| 309 |
+
|
| 310 |
+
import asyncio
|
| 311 |
+
import numpy as np
|
| 312 |
+
from atari_env import AtariEnv, AtariAction
|
| 313 |
+
|
| 314 |
+
async def train():
|
| 315 |
+
# Start environment
|
| 316 |
+
client = await AtariEnv.from_docker_image("atari-env:latest")
|
| 317 |
+
|
| 318 |
+
async with client:
|
| 319 |
+
# Training loop
|
| 320 |
+
for episode in range(10):
|
| 321 |
+
result = await client.reset()
|
| 322 |
+
episode_reward = 0
|
| 323 |
+
steps = 0
|
| 324 |
+
|
| 325 |
+
while not result.done:
|
| 326 |
+
# Random policy (replace with your RL agent)
|
| 327 |
+
action_id = np.random.choice(result.observation.legal_actions)
|
| 328 |
+
|
| 329 |
+
# Take action
|
| 330 |
+
result = await client.step(AtariAction(action_id=action_id))
|
| 331 |
+
|
| 332 |
+
episode_reward += result.reward or 0
|
| 333 |
+
steps += 1
|
| 334 |
+
|
| 335 |
+
# Reshape screen for processing
|
| 336 |
+
screen = np.array(result.observation.screen).reshape(
|
| 337 |
+
result.observation.screen_shape
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
# Your RL training code here
|
| 341 |
+
# ...
|
| 342 |
+
|
| 343 |
+
print(f"Episode {episode}: reward={episode_reward:.2f}, steps={steps}")
|
| 344 |
+
|
| 345 |
+
asyncio.run(train())
|
| 346 |
+
```
|
| 347 |
+
|
| 348 |
+
## Testing
|
| 349 |
+
|
| 350 |
+
### Local Testing
|
| 351 |
+
|
| 352 |
+
```bash
|
| 353 |
+
# Install dependencies
|
| 354 |
+
pip install ale-py fastapi uvicorn requests
|
| 355 |
+
|
| 356 |
+
# Start server
|
| 357 |
+
export PYTHONPATH=src:envs
|
| 358 |
+
python -m atari_env.server.app
|
| 359 |
+
|
| 360 |
+
# Test from another terminal (using sync wrapper for simplicity)
|
| 361 |
+
python -c "
|
| 362 |
+
from atari_env import AtariEnv, AtariAction
|
| 363 |
+
with AtariEnv(base_url='http://localhost:8000').sync() as env:
|
| 364 |
+
result = env.reset()
|
| 365 |
+
print(f'Initial obs: {result.observation.screen_shape}')
|
| 366 |
+
result = env.step(AtariAction(action_id=2))
|
| 367 |
+
print(f'After step: reward={result.reward}, done={result.done}')
|
| 368 |
+
"
|
| 369 |
+
```
|
| 370 |
+
|
| 371 |
+
### Docker Testing
|
| 372 |
+
|
| 373 |
+
```bash
|
| 374 |
+
# Build and run
|
| 375 |
+
docker build -f envs/atari_env/server/Dockerfile -t atari-env:latest .
|
| 376 |
+
docker run -p 8000:8000 atari-env:latest
|
| 377 |
+
|
| 378 |
+
# Test in another terminal
|
| 379 |
+
curl http://localhost:8000/health
|
| 380 |
+
curl -X POST http://localhost:8000/reset
|
| 381 |
+
```
|
| 382 |
+
|
| 383 |
+
## Popular Games and Their Characteristics
|
| 384 |
+
|
| 385 |
+
| Game | Minimal Actions | Lives | Difficulty | Notes |
|
| 386 |
+
|------|----------------|-------|-----------|-------|
|
| 387 |
+
| Pong | 6 | 1 | Low | Good for learning basics |
|
| 388 |
+
| Breakout | 4 | 5 | Medium | Classic RL benchmark |
|
| 389 |
+
| Space Invaders | 6 | 3 | Medium | Shooting game |
|
| 390 |
+
| Ms. Pac-Man | 9 | 3 | High | Complex navigation |
|
| 391 |
+
| Asteroids | 14 | 3 | Medium | Continuous shooting |
|
| 392 |
+
| Montezuma's Revenge | 18 | 5 | Very High | Exploration challenge |
|
| 393 |
+
| Pitfall | 18 | 1 | High | Platformer |
|
| 394 |
+
| Seaquest | 18 | 3 | High | Submarine rescue |
|
| 395 |
+
|
| 396 |
+
## Limitations & Notes
|
| 397 |
+
|
| 398 |
+
- **Frame perfect timing**: Some games require precise timing
|
| 399 |
+
- **Exploration**: Games like Montezuma's Revenge are notoriously difficult
|
| 400 |
+
- **Observation delay**: HTTP adds minimal latency vs local gym
|
| 401 |
+
- **Determinism**: Set `ATARI_REPEAT_ACTION_PROB=0.0` for deterministic behavior
|
| 402 |
+
- **ROMs**: All ROMs are bundled with ale-py package
|
| 403 |
+
|
| 404 |
+
## References
|
| 405 |
+
|
| 406 |
+
- [Arcade Learning Environment Paper (2013)](https://jair.org/index.php/jair/article/view/10819)
|
| 407 |
+
- [ALE GitHub](https://github.com/Farama-Foundation/Arcade-Learning-Environment)
|
| 408 |
+
- [ALE Documentation](https://ale.farama.org/)
|
| 409 |
+
- [Gymnasium Atari Environments](https://gymnasium.farama.org/environments/atari/)
|
| 410 |
+
|
| 411 |
+
## Citation
|
| 412 |
+
|
| 413 |
+
If you use ALE in your research, please cite:
|
| 414 |
+
|
| 415 |
+
```bibtex
|
| 416 |
+
@Article{bellemare13arcade,
|
| 417 |
+
author = {{Bellemare}, M.~G. and {Naddaf}, Y. and {Veness}, J. and {Bowling}, M.},
|
| 418 |
+
title = {The Arcade Learning Environment: An Evaluation Platform for General Agents},
|
| 419 |
+
journal = {Journal of Artificial Intelligence Research},
|
| 420 |
+
year = "2013",
|
| 421 |
+
month = "jun",
|
| 422 |
+
volume = "47",
|
| 423 |
+
pages = "253--279",
|
| 424 |
+
}
|
| 425 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Atari Environment for OpenEnv.
|
| 9 |
+
|
| 10 |
+
This module provides OpenEnv integration for Atari 2600 games via the
|
| 11 |
+
Arcade Learning Environment (ALE).
|
| 12 |
+
|
| 13 |
+
Example:
|
| 14 |
+
>>> from envs.atari_env import AtariEnv, AtariAction
|
| 15 |
+
>>>
|
| 16 |
+
>>> # Connect to a running server or start via Docker
|
| 17 |
+
>>> env = AtariEnv.from_docker_image("atari-env:latest")
|
| 18 |
+
>>>
|
| 19 |
+
>>> # Reset and interact
|
| 20 |
+
>>> result = env.reset()
|
| 21 |
+
>>> result = env.step(AtariAction(action_id=2)) # UP
|
| 22 |
+
>>> print(result.reward, result.done)
|
| 23 |
+
>>>
|
| 24 |
+
>>> # Cleanup
|
| 25 |
+
>>> env.close()
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from .client import AtariEnv
|
| 29 |
+
from .models import AtariAction, AtariObservation, AtariState
|
| 30 |
+
|
| 31 |
+
__all__ = ["AtariEnv", "AtariAction", "AtariObservation", "AtariState"]
|
client.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Atari Environment Client.
|
| 9 |
+
|
| 10 |
+
This module provides the client for connecting to an Atari Environment server
|
| 11 |
+
via WebSocket for persistent sessions.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from typing import Any, Dict, TYPE_CHECKING
|
| 17 |
+
|
| 18 |
+
from openenv.core.client_types import StepResult
|
| 19 |
+
from openenv.core.env_client import EnvClient
|
| 20 |
+
|
| 21 |
+
from .models import AtariAction, AtariObservation, AtariState
|
| 22 |
+
|
| 23 |
+
if TYPE_CHECKING:
|
| 24 |
+
from openenv.core.containers.runtime import ContainerProvider
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class AtariEnv(EnvClient[AtariAction, AtariObservation, AtariState]):
|
| 28 |
+
"""
|
| 29 |
+
Client for Atari Environment.
|
| 30 |
+
|
| 31 |
+
This client maintains a persistent WebSocket connection to the environment
|
| 32 |
+
server, enabling efficient multi-step interactions with lower latency.
|
| 33 |
+
|
| 34 |
+
Example:
|
| 35 |
+
>>> # Connect to a running server
|
| 36 |
+
>>> with AtariEnv(base_url="http://localhost:8000") as client:
|
| 37 |
+
... result = client.reset()
|
| 38 |
+
... print(result.observation.screen_shape)
|
| 39 |
+
...
|
| 40 |
+
... result = client.step(AtariAction(action_id=2)) # UP
|
| 41 |
+
... print(result.reward, result.done)
|
| 42 |
+
|
| 43 |
+
Example with Docker:
|
| 44 |
+
>>> # Automatically start container and connect
|
| 45 |
+
>>> client = AtariEnv.from_docker_image("atari-env:latest")
|
| 46 |
+
>>> try:
|
| 47 |
+
... result = client.reset()
|
| 48 |
+
... result = client.step(AtariAction(action_id=0)) # NOOP
|
| 49 |
+
... finally:
|
| 50 |
+
... client.close()
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def _step_payload(self, action: AtariAction) -> Dict[str, Any]:
|
| 54 |
+
"""
|
| 55 |
+
Convert AtariAction to JSON payload for step request.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
action: AtariAction instance.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
Dictionary representation suitable for JSON encoding.
|
| 62 |
+
"""
|
| 63 |
+
return {
|
| 64 |
+
"action_id": action.action_id,
|
| 65 |
+
"game_name": action.game_name,
|
| 66 |
+
"obs_type": action.obs_type,
|
| 67 |
+
"full_action_space": action.full_action_space,
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
def _parse_result(self, payload: Dict[str, Any]) -> StepResult[AtariObservation]:
|
| 71 |
+
"""
|
| 72 |
+
Parse server response into StepResult[AtariObservation].
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
payload: JSON response from server.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
StepResult with AtariObservation.
|
| 79 |
+
"""
|
| 80 |
+
obs_data = payload.get("observation", {})
|
| 81 |
+
|
| 82 |
+
observation = AtariObservation(
|
| 83 |
+
screen=obs_data.get("screen", []),
|
| 84 |
+
screen_shape=obs_data.get("screen_shape", []),
|
| 85 |
+
legal_actions=obs_data.get("legal_actions", []),
|
| 86 |
+
lives=obs_data.get("lives", 0),
|
| 87 |
+
episode_frame_number=obs_data.get("episode_frame_number", 0),
|
| 88 |
+
frame_number=obs_data.get("frame_number", 0),
|
| 89 |
+
done=payload.get("done", False),
|
| 90 |
+
reward=payload.get("reward"),
|
| 91 |
+
metadata=obs_data.get("metadata", {}),
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
return StepResult(
|
| 95 |
+
observation=observation,
|
| 96 |
+
reward=payload.get("reward"),
|
| 97 |
+
done=payload.get("done", False),
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
def _parse_state(self, payload: Dict[str, Any]) -> AtariState:
|
| 101 |
+
"""
|
| 102 |
+
Parse server response into AtariState object.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
payload: JSON response from /state endpoint.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
AtariState object with environment state information.
|
| 109 |
+
"""
|
| 110 |
+
return AtariState(
|
| 111 |
+
episode_id=payload.get("episode_id"),
|
| 112 |
+
step_count=payload.get("step_count", 0),
|
| 113 |
+
game_name=payload.get("game_name", "unknown"),
|
| 114 |
+
obs_type=payload.get("obs_type", "rgb"),
|
| 115 |
+
full_action_space=payload.get("full_action_space", False),
|
| 116 |
+
mode=payload.get("mode"),
|
| 117 |
+
difficulty=payload.get("difficulty"),
|
| 118 |
+
repeat_action_probability=payload.get("repeat_action_probability", 0.0),
|
| 119 |
+
frameskip=payload.get("frameskip", 4),
|
| 120 |
+
)
|
envs/atari_env/README.md
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Atari Environment Server
|
| 3 |
+
emoji: 🕹️
|
| 4 |
+
colorFrom: '#FF6200'
|
| 5 |
+
colorTo: '#D4151B'
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
app_port: 8000
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# Atari Environment
|
| 15 |
+
|
| 16 |
+
Integration of Atari 2600 games with the OpenEnv framework via the Arcade Learning Environment (ALE). ALE provides access to 100+ classic Atari games for RL research.
|
| 17 |
+
|
| 18 |
+
## Supported Games
|
| 19 |
+
|
| 20 |
+
ALE supports 100+ Atari 2600 games including:
|
| 21 |
+
|
| 22 |
+
### Popular Games
|
| 23 |
+
- **Pong** - Classic two-player tennis
|
| 24 |
+
- **Breakout** - Break bricks with a ball
|
| 25 |
+
- **Space Invaders** - Shoot descending aliens
|
| 26 |
+
- **Pac-Man / Ms. Pac-Man** - Navigate mazes and eat pellets
|
| 27 |
+
- **Asteroids** - Destroy asteroids in space
|
| 28 |
+
- **Defender** - Side-scrolling space shooter
|
| 29 |
+
- **Centipede** - Shoot segmented centipede
|
| 30 |
+
- **Donkey Kong** - Jump over barrels to save princess
|
| 31 |
+
- **Frogger** - Cross road and river safely
|
| 32 |
+
- **Q*bert** - Jump on pyramid cubes
|
| 33 |
+
|
| 34 |
+
And many more! For a complete list, see [ALE documentation](https://ale.farama.org/environments/complete_list/).
|
| 35 |
+
|
| 36 |
+
## Architecture
|
| 37 |
+
|
| 38 |
+
```
|
| 39 |
+
┌────────────────────────────────────┐
|
| 40 |
+
│ RL Training Code (Client) │
|
| 41 |
+
│ AtariEnv.step(action) │
|
| 42 |
+
└──────────────┬─────────────────────┘
|
| 43 |
+
│ HTTP
|
| 44 |
+
┌──────────────▼─────────────────────┐
|
| 45 |
+
│ FastAPI Server (Docker) │
|
| 46 |
+
│ AtariEnvironment │
|
| 47 |
+
│ ├─ Wraps ALEInterface │
|
| 48 |
+
│ ├─ Handles observations │
|
| 49 |
+
│ └─ Action execution │
|
| 50 |
+
└────────────────────────────────────┘
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## Installation & Usage
|
| 54 |
+
|
| 55 |
+
### Option 1: Local Development (without Docker)
|
| 56 |
+
|
| 57 |
+
**Requirements:**
|
| 58 |
+
- Python 3.11+
|
| 59 |
+
- ale-py installed: `pip install ale-py`
|
| 60 |
+
|
| 61 |
+
The client is **async by default**:
|
| 62 |
+
|
| 63 |
+
```python
|
| 64 |
+
import asyncio
|
| 65 |
+
from atari_env import AtariEnv, AtariAction
|
| 66 |
+
|
| 67 |
+
async def main():
|
| 68 |
+
# Start local server manually: python -m atari_env.server.app
|
| 69 |
+
async with AtariEnv(base_url="http://localhost:8000") as env:
|
| 70 |
+
# Reset environment
|
| 71 |
+
result = await env.reset()
|
| 72 |
+
print(f"Screen shape: {result.observation.screen_shape}")
|
| 73 |
+
print(f"Legal actions: {result.observation.legal_actions}")
|
| 74 |
+
|
| 75 |
+
# Take actions
|
| 76 |
+
for _ in range(10):
|
| 77 |
+
result = await env.step(AtariAction(action_id=2, game_name="pong"))
|
| 78 |
+
print(f"Reward: {result.reward}, Done: {result.done}")
|
| 79 |
+
if result.done:
|
| 80 |
+
break
|
| 81 |
+
|
| 82 |
+
asyncio.run(main())
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
For **synchronous usage**, use the `.sync()` wrapper:
|
| 86 |
+
|
| 87 |
+
```python
|
| 88 |
+
from atari_env import AtariEnv, AtariAction
|
| 89 |
+
|
| 90 |
+
with AtariEnv(base_url="http://localhost:8000").sync() as env:
|
| 91 |
+
result = env.reset()
|
| 92 |
+
result = env.step(AtariAction(action_id=2, game_name="pong"))
|
| 93 |
+
print(f"Reward: {result.reward}")
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
### Option 2: Docker (Recommended)
|
| 97 |
+
|
| 98 |
+
**Build Atari image:**
|
| 99 |
+
|
| 100 |
+
```bash
|
| 101 |
+
cd OpenEnv
|
| 102 |
+
|
| 103 |
+
# Build the image
|
| 104 |
+
docker build \
|
| 105 |
+
-f envs/atari_env/server/Dockerfile \
|
| 106 |
+
-t atari-env:latest \
|
| 107 |
+
.
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
**Run specific games:**
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
# Pong (default)
|
| 114 |
+
docker run -p 8000:8000 atari-env:latest
|
| 115 |
+
|
| 116 |
+
# Breakout
|
| 117 |
+
docker run -p 8000:8000 -e ATARI_GAME=breakout atari-env:latest
|
| 118 |
+
|
| 119 |
+
# Space Invaders with grayscale observation
|
| 120 |
+
docker run -p 8000:8000 \
|
| 121 |
+
-e ATARI_GAME=space_invaders \
|
| 122 |
+
-e ATARI_OBS_TYPE=grayscale \
|
| 123 |
+
atari-env:latest
|
| 124 |
+
|
| 125 |
+
# Ms. Pac-Man with full action space
|
| 126 |
+
docker run -p 8000:8000 \
|
| 127 |
+
-e ATARI_GAME=ms_pacman \
|
| 128 |
+
-e ATARI_FULL_ACTION_SPACE=true \
|
| 129 |
+
atari-env:latest
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
**Use with from_docker_image():**
|
| 133 |
+
|
| 134 |
+
```python
|
| 135 |
+
import asyncio
|
| 136 |
+
import numpy as np
|
| 137 |
+
from atari_env import AtariEnv, AtariAction
|
| 138 |
+
|
| 139 |
+
async def main():
|
| 140 |
+
# Automatically starts container
|
| 141 |
+
client = await AtariEnv.from_docker_image("atari-env:latest")
|
| 142 |
+
|
| 143 |
+
async with client:
|
| 144 |
+
result = await client.reset()
|
| 145 |
+
result = await client.step(AtariAction(action_id=2)) # UP
|
| 146 |
+
|
| 147 |
+
# Reshape screen for visualization
|
| 148 |
+
screen = np.array(result.observation.screen).reshape(result.observation.screen_shape)
|
| 149 |
+
print(f"Screen shape: {screen.shape}") # (210, 160, 3) for RGB
|
| 150 |
+
|
| 151 |
+
asyncio.run(main())
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
## Observation Types
|
| 155 |
+
|
| 156 |
+
### 1. RGB (Default)
|
| 157 |
+
- **Shape**: [210, 160, 3]
|
| 158 |
+
- **Description**: Full-color screen observation
|
| 159 |
+
- **Usage**: Most realistic, good for vision-based learning
|
| 160 |
+
|
| 161 |
+
```python
|
| 162 |
+
docker run -p 8000:8000 -e ATARI_OBS_TYPE=rgb atari-env:latest
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
### 2. Grayscale
|
| 166 |
+
- **Shape**: [210, 160]
|
| 167 |
+
- **Description**: Grayscale screen observation
|
| 168 |
+
- **Usage**: Reduced dimensionality, faster processing
|
| 169 |
+
|
| 170 |
+
```python
|
| 171 |
+
docker run -p 8000:8000 -e ATARI_OBS_TYPE=grayscale atari-env:latest
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
### 3. RAM
|
| 175 |
+
- **Shape**: [128]
|
| 176 |
+
- **Description**: Raw 128-byte Atari 2600 RAM contents
|
| 177 |
+
- **Usage**: Compact representation, useful for specific research
|
| 178 |
+
|
| 179 |
+
```python
|
| 180 |
+
docker run -p 8000:8000 -e ATARI_OBS_TYPE=ram atari-env:latest
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
## Action Spaces
|
| 184 |
+
|
| 185 |
+
### Minimal Action Set (Default)
|
| 186 |
+
Game-specific minimal actions (typically 4-9 actions).
|
| 187 |
+
- Pong: 6 actions (NOOP, FIRE, UP, DOWN, etc.)
|
| 188 |
+
- Breakout: 4 actions (NOOP, FIRE, LEFT, RIGHT)
|
| 189 |
+
|
| 190 |
+
```python
|
| 191 |
+
docker run -p 8000:8000 -e ATARI_FULL_ACTION_SPACE=false atari-env:latest
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
### Full Action Set
|
| 195 |
+
All 18 possible Atari 2600 actions:
|
| 196 |
+
0. NOOP
|
| 197 |
+
1. FIRE
|
| 198 |
+
2. UP
|
| 199 |
+
3. RIGHT
|
| 200 |
+
4. LEFT
|
| 201 |
+
5. DOWN
|
| 202 |
+
6. UPRIGHT
|
| 203 |
+
7. UPLEFT
|
| 204 |
+
8. DOWNRIGHT
|
| 205 |
+
9. DOWNLEFT
|
| 206 |
+
10. UPFIRE
|
| 207 |
+
11. RIGHTFIRE
|
| 208 |
+
12. LEFTFIRE
|
| 209 |
+
13. DOWNFIRE
|
| 210 |
+
14. UPRIGHTFIRE
|
| 211 |
+
15. UPLEFTFIRE
|
| 212 |
+
16. DOWNRIGHTFIRE
|
| 213 |
+
17. DOWNLEFTFIRE
|
| 214 |
+
|
| 215 |
+
```python
|
| 216 |
+
docker run -p 8000:8000 -e ATARI_FULL_ACTION_SPACE=true atari-env:latest
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
## Configuration
|
| 220 |
+
|
| 221 |
+
### Environment Variables
|
| 222 |
+
|
| 223 |
+
- `ATARI_GAME`: Game name (default: "pong")
|
| 224 |
+
- `ATARI_OBS_TYPE`: Observation type - "rgb", "grayscale", "ram" (default: "rgb")
|
| 225 |
+
- `ATARI_FULL_ACTION_SPACE`: Use full action space - "true"/"false" (default: "false")
|
| 226 |
+
- `ATARI_MODE`: Game mode (optional, game-specific)
|
| 227 |
+
- `ATARI_DIFFICULTY`: Game difficulty (optional, game-specific)
|
| 228 |
+
- `ATARI_REPEAT_ACTION_PROB`: Sticky action probability 0.0-1.0 (default: "0.0")
|
| 229 |
+
- `ATARI_FRAMESKIP`: Frames to skip per action (default: "4")
|
| 230 |
+
|
| 231 |
+
### Example: Breakout with Custom Settings
|
| 232 |
+
|
| 233 |
+
```bash
|
| 234 |
+
docker run -p 8000:8000 \
|
| 235 |
+
-e ATARI_GAME=breakout \
|
| 236 |
+
-e ATARI_OBS_TYPE=grayscale \
|
| 237 |
+
-e ATARI_FULL_ACTION_SPACE=true \
|
| 238 |
+
-e ATARI_REPEAT_ACTION_PROB=0.25 \
|
| 239 |
+
-e ATARI_FRAMESKIP=4 \
|
| 240 |
+
atari-env:latest
|
| 241 |
+
```
|
| 242 |
+
|
| 243 |
+
## API Reference
|
| 244 |
+
|
| 245 |
+
### AtariAction
|
| 246 |
+
|
| 247 |
+
```python
|
| 248 |
+
@dataclass
|
| 249 |
+
class AtariAction(Action):
|
| 250 |
+
action_id: int # Action index to execute
|
| 251 |
+
game_name: str = "pong" # Game name
|
| 252 |
+
obs_type: str = "rgb" # Observation type
|
| 253 |
+
full_action_space: bool = False # Full or minimal action space
|
| 254 |
+
```
|
| 255 |
+
|
| 256 |
+
### AtariObservation
|
| 257 |
+
|
| 258 |
+
```python
|
| 259 |
+
@dataclass
|
| 260 |
+
class AtariObservation(Observation):
|
| 261 |
+
screen: List[int] # Flattened screen pixels
|
| 262 |
+
screen_shape: List[int] # Original screen shape
|
| 263 |
+
legal_actions: List[int] # Legal action indices
|
| 264 |
+
lives: int # Lives remaining
|
| 265 |
+
episode_frame_number: int # Frame # in episode
|
| 266 |
+
frame_number: int # Total frame #
|
| 267 |
+
done: bool # Episode finished
|
| 268 |
+
reward: Optional[float] # Reward from last action
|
| 269 |
+
```
|
| 270 |
+
|
| 271 |
+
### AtariState
|
| 272 |
+
|
| 273 |
+
```python
|
| 274 |
+
@dataclass
|
| 275 |
+
class AtariState(State):
|
| 276 |
+
episode_id: str # Unique episode ID
|
| 277 |
+
step_count: int # Number of steps
|
| 278 |
+
game_name: str # Game name
|
| 279 |
+
obs_type: str # Observation type
|
| 280 |
+
full_action_space: bool # Action space type
|
| 281 |
+
mode: Optional[int] # Game mode
|
| 282 |
+
difficulty: Optional[int] # Game difficulty
|
| 283 |
+
repeat_action_probability: float # Sticky action prob
|
| 284 |
+
frameskip: int # Frameskip setting
|
| 285 |
+
```
|
| 286 |
+
|
| 287 |
+
## Example Script
|
| 288 |
+
|
| 289 |
+
```python
|
| 290 |
+
#!/usr/bin/env python3
|
| 291 |
+
"""Example training loop with Atari environment."""
|
| 292 |
+
|
| 293 |
+
import asyncio
|
| 294 |
+
import numpy as np
|
| 295 |
+
from atari_env import AtariEnv, AtariAction
|
| 296 |
+
|
| 297 |
+
async def train():
|
| 298 |
+
# Start environment
|
| 299 |
+
client = await AtariEnv.from_docker_image("atari-env:latest")
|
| 300 |
+
|
| 301 |
+
async with client:
|
| 302 |
+
# Training loop
|
| 303 |
+
for episode in range(10):
|
| 304 |
+
result = await client.reset()
|
| 305 |
+
episode_reward = 0
|
| 306 |
+
steps = 0
|
| 307 |
+
|
| 308 |
+
while not result.done:
|
| 309 |
+
# Random policy (replace with your RL agent)
|
| 310 |
+
action_id = np.random.choice(result.observation.legal_actions)
|
| 311 |
+
|
| 312 |
+
# Take action
|
| 313 |
+
result = await client.step(AtariAction(action_id=action_id))
|
| 314 |
+
|
| 315 |
+
episode_reward += result.reward or 0
|
| 316 |
+
steps += 1
|
| 317 |
+
|
| 318 |
+
# Reshape screen for processing
|
| 319 |
+
screen = np.array(result.observation.screen).reshape(
|
| 320 |
+
result.observation.screen_shape
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
# Your RL training code here
|
| 324 |
+
# ...
|
| 325 |
+
|
| 326 |
+
print(f"Episode {episode}: reward={episode_reward:.2f}, steps={steps}")
|
| 327 |
+
|
| 328 |
+
asyncio.run(train())
|
| 329 |
+
```
|
| 330 |
+
|
| 331 |
+
## Testing
|
| 332 |
+
|
| 333 |
+
### Local Testing
|
| 334 |
+
|
| 335 |
+
```bash
|
| 336 |
+
# Install dependencies
|
| 337 |
+
pip install ale-py fastapi uvicorn requests
|
| 338 |
+
|
| 339 |
+
# Start server
|
| 340 |
+
export PYTHONPATH=src:envs
|
| 341 |
+
python -m atari_env.server.app
|
| 342 |
+
|
| 343 |
+
# Test from another terminal (using sync wrapper for simplicity)
|
| 344 |
+
python -c "
|
| 345 |
+
from atari_env import AtariEnv, AtariAction
|
| 346 |
+
with AtariEnv(base_url='http://localhost:8000').sync() as env:
|
| 347 |
+
result = env.reset()
|
| 348 |
+
print(f'Initial obs: {result.observation.screen_shape}')
|
| 349 |
+
result = env.step(AtariAction(action_id=2))
|
| 350 |
+
print(f'After step: reward={result.reward}, done={result.done}')
|
| 351 |
+
"
|
| 352 |
+
```
|
| 353 |
+
|
| 354 |
+
### Docker Testing
|
| 355 |
+
|
| 356 |
+
```bash
|
| 357 |
+
# Build and run
|
| 358 |
+
docker build -f envs/atari_env/server/Dockerfile -t atari-env:latest .
|
| 359 |
+
docker run -p 8000:8000 atari-env:latest
|
| 360 |
+
|
| 361 |
+
# Test in another terminal
|
| 362 |
+
curl http://localhost:8000/health
|
| 363 |
+
curl -X POST http://localhost:8000/reset
|
| 364 |
+
```
|
| 365 |
+
|
| 366 |
+
## Popular Games and Their Characteristics
|
| 367 |
+
|
| 368 |
+
| Game | Minimal Actions | Lives | Difficulty | Notes |
|
| 369 |
+
|------|----------------|-------|-----------|-------|
|
| 370 |
+
| Pong | 6 | 1 | Low | Good for learning basics |
|
| 371 |
+
| Breakout | 4 | 5 | Medium | Classic RL benchmark |
|
| 372 |
+
| Space Invaders | 6 | 3 | Medium | Shooting game |
|
| 373 |
+
| Ms. Pac-Man | 9 | 3 | High | Complex navigation |
|
| 374 |
+
| Asteroids | 14 | 3 | Medium | Continuous shooting |
|
| 375 |
+
| Montezuma's Revenge | 18 | 5 | Very High | Exploration challenge |
|
| 376 |
+
| Pitfall | 18 | 1 | High | Platformer |
|
| 377 |
+
| Seaquest | 18 | 3 | High | Submarine rescue |
|
| 378 |
+
|
| 379 |
+
## Limitations & Notes
|
| 380 |
+
|
| 381 |
+
- **Frame perfect timing**: Some games require precise timing
|
| 382 |
+
- **Exploration**: Games like Montezuma's Revenge are notoriously difficult
|
| 383 |
+
- **Observation delay**: HTTP adds minimal latency vs local gym
|
| 384 |
+
- **Determinism**: Set `ATARI_REPEAT_ACTION_PROB=0.0` for deterministic behavior
|
| 385 |
+
- **ROMs**: All ROMs are bundled with ale-py package
|
| 386 |
+
|
| 387 |
+
## References
|
| 388 |
+
|
| 389 |
+
- [Arcade Learning Environment Paper (2013)](https://jair.org/index.php/jair/article/view/10819)
|
| 390 |
+
- [ALE GitHub](https://github.com/Farama-Foundation/Arcade-Learning-Environment)
|
| 391 |
+
- [ALE Documentation](https://ale.farama.org/)
|
| 392 |
+
- [Gymnasium Atari Environments](https://gymnasium.farama.org/environments/atari/)
|
| 393 |
+
|
| 394 |
+
## Citation
|
| 395 |
+
|
| 396 |
+
If you use ALE in your research, please cite:
|
| 397 |
+
|
| 398 |
+
```bibtex
|
| 399 |
+
@Article{bellemare13arcade,
|
| 400 |
+
author = {{Bellemare}, M.~G. and {Naddaf}, Y. and {Veness}, J. and {Bowling}, M.},
|
| 401 |
+
title = {The Arcade Learning Environment: An Evaluation Platform for General Agents},
|
| 402 |
+
journal = {Journal of Artificial Intelligence Research},
|
| 403 |
+
year = "2013",
|
| 404 |
+
month = "jun",
|
| 405 |
+
volume = "47",
|
| 406 |
+
pages = "253--279",
|
| 407 |
+
}
|
| 408 |
+
```
|
envs/atari_env/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Atari Environment for OpenEnv.
|
| 9 |
+
|
| 10 |
+
This module provides OpenEnv integration for Atari 2600 games via the
|
| 11 |
+
Arcade Learning Environment (ALE).
|
| 12 |
+
|
| 13 |
+
Example:
|
| 14 |
+
>>> from envs.atari_env import AtariEnv, AtariAction
|
| 15 |
+
>>>
|
| 16 |
+
>>> # Connect to a running server or start via Docker
|
| 17 |
+
>>> env = AtariEnv.from_docker_image("atari-env:latest")
|
| 18 |
+
>>>
|
| 19 |
+
>>> # Reset and interact
|
| 20 |
+
>>> result = env.reset()
|
| 21 |
+
>>> result = env.step(AtariAction(action_id=2)) # UP
|
| 22 |
+
>>> print(result.reward, result.done)
|
| 23 |
+
>>>
|
| 24 |
+
>>> # Cleanup
|
| 25 |
+
>>> env.close()
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from .client import AtariEnv
|
| 29 |
+
from .models import AtariAction, AtariObservation, AtariState
|
| 30 |
+
|
| 31 |
+
__all__ = ["AtariEnv", "AtariAction", "AtariObservation", "AtariState"]
|
envs/atari_env/client.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Atari Environment Client.
|
| 9 |
+
|
| 10 |
+
This module provides the client for connecting to an Atari Environment server
|
| 11 |
+
via WebSocket for persistent sessions.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from typing import Any, Dict, TYPE_CHECKING
|
| 17 |
+
|
| 18 |
+
from openenv.core.client_types import StepResult
|
| 19 |
+
from openenv.core.env_client import EnvClient
|
| 20 |
+
|
| 21 |
+
from .models import AtariAction, AtariObservation, AtariState
|
| 22 |
+
|
| 23 |
+
if TYPE_CHECKING:
|
| 24 |
+
from openenv.core.containers.runtime import ContainerProvider
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class AtariEnv(EnvClient[AtariAction, AtariObservation, AtariState]):
|
| 28 |
+
"""
|
| 29 |
+
Client for Atari Environment.
|
| 30 |
+
|
| 31 |
+
This client maintains a persistent WebSocket connection to the environment
|
| 32 |
+
server, enabling efficient multi-step interactions with lower latency.
|
| 33 |
+
|
| 34 |
+
Example:
|
| 35 |
+
>>> # Connect to a running server
|
| 36 |
+
>>> with AtariEnv(base_url="http://localhost:8000") as client:
|
| 37 |
+
... result = client.reset()
|
| 38 |
+
... print(result.observation.screen_shape)
|
| 39 |
+
...
|
| 40 |
+
... result = client.step(AtariAction(action_id=2)) # UP
|
| 41 |
+
... print(result.reward, result.done)
|
| 42 |
+
|
| 43 |
+
Example with Docker:
|
| 44 |
+
>>> # Automatically start container and connect
|
| 45 |
+
>>> client = AtariEnv.from_docker_image("atari-env:latest")
|
| 46 |
+
>>> try:
|
| 47 |
+
... result = client.reset()
|
| 48 |
+
... result = client.step(AtariAction(action_id=0)) # NOOP
|
| 49 |
+
... finally:
|
| 50 |
+
... client.close()
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def _step_payload(self, action: AtariAction) -> Dict[str, Any]:
|
| 54 |
+
"""
|
| 55 |
+
Convert AtariAction to JSON payload for step request.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
action: AtariAction instance.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
Dictionary representation suitable for JSON encoding.
|
| 62 |
+
"""
|
| 63 |
+
return {
|
| 64 |
+
"action_id": action.action_id,
|
| 65 |
+
"game_name": action.game_name,
|
| 66 |
+
"obs_type": action.obs_type,
|
| 67 |
+
"full_action_space": action.full_action_space,
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
def _parse_result(self, payload: Dict[str, Any]) -> StepResult[AtariObservation]:
|
| 71 |
+
"""
|
| 72 |
+
Parse server response into StepResult[AtariObservation].
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
payload: JSON response from server.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
StepResult with AtariObservation.
|
| 79 |
+
"""
|
| 80 |
+
obs_data = payload.get("observation", {})
|
| 81 |
+
|
| 82 |
+
observation = AtariObservation(
|
| 83 |
+
screen=obs_data.get("screen", []),
|
| 84 |
+
screen_shape=obs_data.get("screen_shape", []),
|
| 85 |
+
legal_actions=obs_data.get("legal_actions", []),
|
| 86 |
+
lives=obs_data.get("lives", 0),
|
| 87 |
+
episode_frame_number=obs_data.get("episode_frame_number", 0),
|
| 88 |
+
frame_number=obs_data.get("frame_number", 0),
|
| 89 |
+
done=payload.get("done", False),
|
| 90 |
+
reward=payload.get("reward"),
|
| 91 |
+
metadata=obs_data.get("metadata", {}),
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
return StepResult(
|
| 95 |
+
observation=observation,
|
| 96 |
+
reward=payload.get("reward"),
|
| 97 |
+
done=payload.get("done", False),
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
def _parse_state(self, payload: Dict[str, Any]) -> AtariState:
|
| 101 |
+
"""
|
| 102 |
+
Parse server response into AtariState object.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
payload: JSON response from /state endpoint.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
AtariState object with environment state information.
|
| 109 |
+
"""
|
| 110 |
+
return AtariState(
|
| 111 |
+
episode_id=payload.get("episode_id"),
|
| 112 |
+
step_count=payload.get("step_count", 0),
|
| 113 |
+
game_name=payload.get("game_name", "unknown"),
|
| 114 |
+
obs_type=payload.get("obs_type", "rgb"),
|
| 115 |
+
full_action_space=payload.get("full_action_space", False),
|
| 116 |
+
mode=payload.get("mode"),
|
| 117 |
+
difficulty=payload.get("difficulty"),
|
| 118 |
+
repeat_action_probability=payload.get("repeat_action_probability", 0.0),
|
| 119 |
+
frameskip=payload.get("frameskip", 4),
|
| 120 |
+
)
|
envs/atari_env/models.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Data models for Atari Environment.
|
| 9 |
+
|
| 10 |
+
This module defines the Action, Observation, and State types for Atari games
|
| 11 |
+
via the Arcade Learning Environment (ALE).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 17 |
+
|
| 18 |
+
from openenv.core.env_server import Action, Observation, State
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class AtariAction(Action):
|
| 22 |
+
"""
|
| 23 |
+
Action for Atari environments.
|
| 24 |
+
|
| 25 |
+
Attributes:
|
| 26 |
+
action_id: The integer action ID to take (from legal_actions).
|
| 27 |
+
game_name: Name of the Atari game (e.g., "pong", "breakout", "space_invaders").
|
| 28 |
+
obs_type: Observation type ("rgb", "grayscale", or "ram").
|
| 29 |
+
full_action_space: Whether to use full (18 actions) or minimal action space.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
action_id: int
|
| 33 |
+
game_name: str = "pong"
|
| 34 |
+
obs_type: Literal["rgb", "grayscale", "ram"] = "rgb"
|
| 35 |
+
full_action_space: bool = False
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class AtariObservation(Observation):
|
| 39 |
+
"""
|
| 40 |
+
Observation from Atari environment.
|
| 41 |
+
|
| 42 |
+
This represents what the agent sees after taking an action.
|
| 43 |
+
|
| 44 |
+
Attributes:
|
| 45 |
+
screen: Screen observation as a flattened list of pixels.
|
| 46 |
+
Shape depends on obs_type:
|
| 47 |
+
- rgb: [210, 160, 3] flattened
|
| 48 |
+
- grayscale: [210, 160] flattened
|
| 49 |
+
- ram: [128] (RAM contents)
|
| 50 |
+
screen_shape: Original shape of the screen before flattening.
|
| 51 |
+
legal_actions: List of legal action IDs the agent can take.
|
| 52 |
+
lives: Number of lives remaining.
|
| 53 |
+
episode_frame_number: Frame number within current episode.
|
| 54 |
+
frame_number: Total frame number since environment creation.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
screen: List[int]
|
| 58 |
+
screen_shape: List[int]
|
| 59 |
+
legal_actions: List[int]
|
| 60 |
+
lives: int = 0
|
| 61 |
+
episode_frame_number: int = 0
|
| 62 |
+
frame_number: int = 0
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class AtariState(State):
|
| 66 |
+
"""
|
| 67 |
+
State for Atari environment.
|
| 68 |
+
|
| 69 |
+
Attributes:
|
| 70 |
+
game_name: Name of the Atari game.
|
| 71 |
+
obs_type: Observation type ("rgb", "grayscale", or "ram").
|
| 72 |
+
full_action_space: Whether using full or minimal action space.
|
| 73 |
+
mode: Game mode (if applicable).
|
| 74 |
+
difficulty: Game difficulty (if applicable).
|
| 75 |
+
repeat_action_probability: Probability of repeating previous action (sticky actions).
|
| 76 |
+
frameskip: Number of frames to skip per action.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
game_name: str = "pong"
|
| 80 |
+
obs_type: Literal["rgb", "grayscale", "ram"] = "rgb"
|
| 81 |
+
full_action_space: bool = False
|
| 82 |
+
mode: Optional[int] = None
|
| 83 |
+
difficulty: Optional[int] = None
|
| 84 |
+
repeat_action_probability: float = 0.0
|
| 85 |
+
frameskip: int = 4
|
envs/atari_env/server/Dockerfile
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dockerfile for Atari Environment
|
| 2 |
+
# This image provides Atari 2600 games via the Arcade Learning Environment (ALE)
|
| 3 |
+
|
| 4 |
+
# Configurable base image - defaults to local build, can be overridden for CI/CD
|
| 5 |
+
# Base image provides: fastapi, uvicorn, requests, curl, PYTHONPATH=/app/src
|
| 6 |
+
#
|
| 7 |
+
# Local build: docker build -t envtorch-base:latest -f src/core/containers/images/Dockerfile .
|
| 8 |
+
# docker build -f envs/atari_env/server/Dockerfile -t atari-env:latest .
|
| 9 |
+
#
|
| 10 |
+
# CI/CD build: docker build --build-arg BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest \
|
| 11 |
+
# -f envs/atari_env/server/Dockerfile -t atari-env:latest .
|
| 12 |
+
ARG BASE_IMAGE=openenv-base:latest
|
| 13 |
+
FROM ${BASE_IMAGE}
|
| 14 |
+
|
| 15 |
+
# Install dependencies
|
| 16 |
+
COPY envs/atari_env/server/requirements.txt /tmp/requirements.txt
|
| 17 |
+
RUN pip install --no-cache-dir -r /tmp/requirements.txt && rm /tmp/requirements.txt
|
| 18 |
+
|
| 19 |
+
# Copy OpenEnv core (base image already set WORKDIR=/app)
|
| 20 |
+
COPY src/core/ /app/src/core/
|
| 21 |
+
|
| 22 |
+
# Copy Atari environment code
|
| 23 |
+
COPY envs/atari_env/ /app/envs/atari_env/
|
| 24 |
+
|
| 25 |
+
# Copy README for web interface documentation
|
| 26 |
+
COPY envs/atari_env/README.md /app/README.md
|
| 27 |
+
|
| 28 |
+
# Atari-specific environment variables (can be overridden at runtime)
|
| 29 |
+
ENV ATARI_GAME=pong
|
| 30 |
+
ENV ATARI_OBS_TYPE=rgb
|
| 31 |
+
ENV ATARI_FULL_ACTION_SPACE=false
|
| 32 |
+
ENV ATARI_REPEAT_ACTION_PROB=0.0
|
| 33 |
+
ENV ATARI_FRAMESKIP=4
|
| 34 |
+
|
| 35 |
+
# Expose port
|
| 36 |
+
EXPOSE 8000
|
| 37 |
+
|
| 38 |
+
# Health check
|
| 39 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 40 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 41 |
+
|
| 42 |
+
# Run the FastAPI server
|
| 43 |
+
CMD ["uvicorn", "envs.atari_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
envs/atari_env/server/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Atari Environment Server.
|
| 9 |
+
|
| 10 |
+
Server-side implementation of Atari environment for OpenEnv.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from .atari_environment import AtariEnvironment
|
| 14 |
+
|
| 15 |
+
__all__ = ["AtariEnvironment"]
|
envs/atari_env/server/app.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
FastAPI application for the Atari Environment.
|
| 9 |
+
|
| 10 |
+
This module creates an HTTP server that exposes Atari games
|
| 11 |
+
over HTTP and WebSocket endpoints, compatible with EnvClient.
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
# Development (with auto-reload):
|
| 15 |
+
uvicorn envs.atari_env.server.app:app --reload --host 0.0.0.0 --port 8000
|
| 16 |
+
|
| 17 |
+
# Production:
|
| 18 |
+
uvicorn envs.atari_env.server.app:app --host 0.0.0.0 --port 8000 --workers 4
|
| 19 |
+
|
| 20 |
+
# Or run directly:
|
| 21 |
+
python -m envs.atari_env.server.app
|
| 22 |
+
|
| 23 |
+
Environment variables:
|
| 24 |
+
ATARI_GAME: Game name to serve (default: "pong")
|
| 25 |
+
ATARI_OBS_TYPE: Observation type (default: "rgb")
|
| 26 |
+
ATARI_FULL_ACTION_SPACE: Use full action space (default: "false")
|
| 27 |
+
ATARI_MODE: Game mode (optional)
|
| 28 |
+
ATARI_DIFFICULTY: Game difficulty (optional)
|
| 29 |
+
ATARI_REPEAT_ACTION_PROB: Sticky action probability (default: "0.0")
|
| 30 |
+
ATARI_FRAMESKIP: Frameskip (default: "4")
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
import os
|
| 34 |
+
|
| 35 |
+
from openenv.core.env_server import create_app
|
| 36 |
+
|
| 37 |
+
from ..models import AtariAction, AtariObservation
|
| 38 |
+
from .atari_environment import AtariEnvironment
|
| 39 |
+
|
| 40 |
+
# Get configuration from environment variables
|
| 41 |
+
game_name = os.getenv("ATARI_GAME", "pong")
|
| 42 |
+
obs_type = os.getenv("ATARI_OBS_TYPE", "rgb")
|
| 43 |
+
full_action_space = os.getenv("ATARI_FULL_ACTION_SPACE", "false").lower() == "true"
|
| 44 |
+
repeat_action_prob = float(os.getenv("ATARI_REPEAT_ACTION_PROB", "0.0"))
|
| 45 |
+
frameskip = int(os.getenv("ATARI_FRAMESKIP", "4"))
|
| 46 |
+
|
| 47 |
+
# Optional parameters
|
| 48 |
+
mode = os.getenv("ATARI_MODE")
|
| 49 |
+
difficulty = os.getenv("ATARI_DIFFICULTY")
|
| 50 |
+
|
| 51 |
+
# Convert to int if specified
|
| 52 |
+
mode = int(mode) if mode is not None else None
|
| 53 |
+
difficulty = int(difficulty) if difficulty is not None else None
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Factory function to create AtariEnvironment instances
|
| 57 |
+
def create_atari_environment():
|
| 58 |
+
"""Factory function that creates AtariEnvironment with config."""
|
| 59 |
+
return AtariEnvironment(
|
| 60 |
+
game_name=game_name,
|
| 61 |
+
obs_type=obs_type,
|
| 62 |
+
full_action_space=full_action_space,
|
| 63 |
+
mode=mode,
|
| 64 |
+
difficulty=difficulty,
|
| 65 |
+
repeat_action_probability=repeat_action_prob,
|
| 66 |
+
frameskip=frameskip,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Create the FastAPI app with web interface and README integration
|
| 71 |
+
# Pass the factory function instead of an instance for WebSocket session support
|
| 72 |
+
app = create_app(
|
| 73 |
+
create_atari_environment, AtariAction, AtariObservation, env_name="atari_env"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
import uvicorn
|
| 79 |
+
|
| 80 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
envs/atari_env/server/atari_environment.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Atari Environment Server Implementation.
|
| 9 |
+
|
| 10 |
+
This module wraps ALE's ALEInterface and exposes it
|
| 11 |
+
via the OpenEnv Environment interface.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import uuid
|
| 15 |
+
from typing import Any, Dict, Literal, Optional
|
| 16 |
+
|
| 17 |
+
from openenv.core.env_server import Action, Environment, Observation
|
| 18 |
+
|
| 19 |
+
from ..models import AtariAction, AtariObservation, AtariState
|
| 20 |
+
|
| 21 |
+
# Import ALE
|
| 22 |
+
try:
|
| 23 |
+
import numpy as np
|
| 24 |
+
from ale_py import ALEInterface, roms
|
| 25 |
+
except ImportError as e:
|
| 26 |
+
raise ImportError(
|
| 27 |
+
"ALE (Arcade Learning Environment) is not installed. "
|
| 28 |
+
"Please install it with: pip install ale-py"
|
| 29 |
+
) from e
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class AtariEnvironment(Environment):
|
| 33 |
+
"""
|
| 34 |
+
Atari Environment wrapper for OpenEnv.
|
| 35 |
+
|
| 36 |
+
This environment wraps Atari 2600 games via the Arcade Learning Environment (ALE)
|
| 37 |
+
and provides a clean interface for RL training.
|
| 38 |
+
|
| 39 |
+
Supported games include: pong, breakout, space_invaders, and 100+ others.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
game_name: Name of the Atari game (e.g., "pong", "breakout").
|
| 43 |
+
obs_type: Observation type - "rgb", "grayscale", or "ram".
|
| 44 |
+
full_action_space: Use full action space (18 actions) vs minimal.
|
| 45 |
+
mode: Game mode (if applicable).
|
| 46 |
+
difficulty: Game difficulty (if applicable).
|
| 47 |
+
repeat_action_probability: Sticky action probability (default 0.0).
|
| 48 |
+
frameskip: Number of frames to skip per action (default 4).
|
| 49 |
+
|
| 50 |
+
Example:
|
| 51 |
+
>>> env = AtariEnvironment("pong")
|
| 52 |
+
>>> obs = env.reset()
|
| 53 |
+
>>> print(obs.screen_shape) # [210, 160, 3]
|
| 54 |
+
>>> obs = env.step(AtariAction(action_id=2)) # UP
|
| 55 |
+
>>> print(obs.reward, obs.done)
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
game_name: str = "pong",
|
| 61 |
+
obs_type: Literal["rgb", "grayscale", "ram"] = "rgb",
|
| 62 |
+
full_action_space: bool = False,
|
| 63 |
+
mode: Optional[int] = None,
|
| 64 |
+
difficulty: Optional[int] = None,
|
| 65 |
+
repeat_action_probability: float = 0.0,
|
| 66 |
+
frameskip: int = 4,
|
| 67 |
+
):
|
| 68 |
+
"""Initialize Atari environment."""
|
| 69 |
+
super().__init__()
|
| 70 |
+
|
| 71 |
+
self.game_name = game_name
|
| 72 |
+
self.obs_type = obs_type
|
| 73 |
+
self.full_action_space = full_action_space
|
| 74 |
+
self.mode = mode
|
| 75 |
+
self.difficulty = difficulty
|
| 76 |
+
self.repeat_action_probability = repeat_action_probability
|
| 77 |
+
self.frameskip = frameskip
|
| 78 |
+
|
| 79 |
+
# Create ALE interface
|
| 80 |
+
self.ale = ALEInterface()
|
| 81 |
+
|
| 82 |
+
# Configure ALE
|
| 83 |
+
from ale_py import LoggerMode
|
| 84 |
+
|
| 85 |
+
self.ale.setLoggerMode(LoggerMode.Error) # Error mode only
|
| 86 |
+
self.ale.setFloat("repeat_action_probability", repeat_action_probability)
|
| 87 |
+
|
| 88 |
+
# Load ROM
|
| 89 |
+
try:
|
| 90 |
+
rom_path = roms.get_rom_path(game_name)
|
| 91 |
+
self.ale.loadROM(rom_path)
|
| 92 |
+
except Exception as e:
|
| 93 |
+
raise ValueError(
|
| 94 |
+
f"Failed to load Atari game '{game_name}': {e}\n"
|
| 95 |
+
f"Available games can be found via: ale_py.roms.list_roms()"
|
| 96 |
+
) from e
|
| 97 |
+
|
| 98 |
+
# Set mode and difficulty if specified
|
| 99 |
+
if mode is not None:
|
| 100 |
+
self.ale.setMode(mode)
|
| 101 |
+
if difficulty is not None:
|
| 102 |
+
self.ale.setDifficulty(difficulty)
|
| 103 |
+
|
| 104 |
+
# Get action set
|
| 105 |
+
if full_action_space:
|
| 106 |
+
self._action_set = self.ale.getLegalActionSet()
|
| 107 |
+
else:
|
| 108 |
+
self._action_set = self.ale.getMinimalActionSet()
|
| 109 |
+
|
| 110 |
+
# Get screen dimensions for observation space
|
| 111 |
+
self.screen_height, self.screen_width = self.ale.getScreenDims()
|
| 112 |
+
if obs_type == "rgb":
|
| 113 |
+
self.screen_shape = [self.screen_height, self.screen_width, 3]
|
| 114 |
+
elif obs_type == "grayscale":
|
| 115 |
+
self.screen_shape = [self.screen_height, self.screen_width]
|
| 116 |
+
elif obs_type == "ram":
|
| 117 |
+
self.screen_shape = [self.ale.getRAMSize()]
|
| 118 |
+
else:
|
| 119 |
+
raise ValueError(f"Invalid obs_type: {obs_type}")
|
| 120 |
+
|
| 121 |
+
# Initialize state
|
| 122 |
+
self._state = AtariState(
|
| 123 |
+
game_name=game_name,
|
| 124 |
+
obs_type=obs_type,
|
| 125 |
+
full_action_space=full_action_space,
|
| 126 |
+
mode=mode,
|
| 127 |
+
difficulty=difficulty,
|
| 128 |
+
repeat_action_probability=repeat_action_probability,
|
| 129 |
+
frameskip=frameskip,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
def reset(self) -> Observation:
|
| 133 |
+
"""
|
| 134 |
+
Reset the environment and return initial observation.
|
| 135 |
+
|
| 136 |
+
Returns:
|
| 137 |
+
Initial observation for the agent.
|
| 138 |
+
"""
|
| 139 |
+
# Reset ALE
|
| 140 |
+
self.ale.reset_game()
|
| 141 |
+
|
| 142 |
+
# Reset state tracking
|
| 143 |
+
self._state.episode_id = str(uuid.uuid4())
|
| 144 |
+
self._state.step_count = 0
|
| 145 |
+
|
| 146 |
+
# Get initial observation
|
| 147 |
+
return self._make_observation()
|
| 148 |
+
|
| 149 |
+
def step(self, action: Action) -> Observation:
|
| 150 |
+
"""
|
| 151 |
+
Execute agent's action and return resulting observation.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
action: AtariAction containing the action_id to execute.
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
Observation after action execution.
|
| 158 |
+
|
| 159 |
+
Raises:
|
| 160 |
+
ValueError: If action is not an AtariAction.
|
| 161 |
+
"""
|
| 162 |
+
if not isinstance(action, AtariAction):
|
| 163 |
+
raise ValueError(f"Expected AtariAction, got {type(action)}")
|
| 164 |
+
|
| 165 |
+
# Validate action_id
|
| 166 |
+
if action.action_id < 0 or action.action_id >= len(self._action_set):
|
| 167 |
+
raise ValueError(
|
| 168 |
+
f"Invalid action_id: {action.action_id}. "
|
| 169 |
+
f"Valid range: [0, {len(self._action_set) - 1}]"
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
# Get actual ALE action
|
| 173 |
+
ale_action = self._action_set[action.action_id]
|
| 174 |
+
|
| 175 |
+
# Execute action with frameskip
|
| 176 |
+
total_reward = 0.0
|
| 177 |
+
for _ in range(self.frameskip):
|
| 178 |
+
total_reward += self.ale.act(ale_action)
|
| 179 |
+
if self.ale.game_over():
|
| 180 |
+
break
|
| 181 |
+
|
| 182 |
+
self._state.step_count += 1
|
| 183 |
+
|
| 184 |
+
# Get observation
|
| 185 |
+
obs = self._make_observation()
|
| 186 |
+
obs.reward = total_reward
|
| 187 |
+
|
| 188 |
+
return obs
|
| 189 |
+
|
| 190 |
+
@property
|
| 191 |
+
def state(self) -> AtariState:
|
| 192 |
+
"""Get current environment state."""
|
| 193 |
+
return self._state
|
| 194 |
+
|
| 195 |
+
def _make_observation(self) -> AtariObservation:
|
| 196 |
+
"""
|
| 197 |
+
Create an AtariObservation from current ALE state.
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
AtariObservation for the agent.
|
| 201 |
+
"""
|
| 202 |
+
# Get screen observation
|
| 203 |
+
if self.obs_type == "rgb":
|
| 204 |
+
screen = self.ale.getScreenRGB()
|
| 205 |
+
elif self.obs_type == "grayscale":
|
| 206 |
+
screen = self.ale.getScreenGrayscale()
|
| 207 |
+
elif self.obs_type == "ram":
|
| 208 |
+
screen = self.ale.getRAM()
|
| 209 |
+
else:
|
| 210 |
+
raise ValueError(f"Invalid obs_type: {self.obs_type}")
|
| 211 |
+
|
| 212 |
+
# Flatten screen for JSON serialization
|
| 213 |
+
# Handle both numpy arrays and lists
|
| 214 |
+
if hasattr(screen, "flatten"):
|
| 215 |
+
screen_flat = screen.flatten().tolist()
|
| 216 |
+
elif hasattr(screen, "tolist"):
|
| 217 |
+
screen_flat = screen.tolist()
|
| 218 |
+
else:
|
| 219 |
+
screen_flat = list(screen)
|
| 220 |
+
|
| 221 |
+
# Get game info
|
| 222 |
+
lives = self.ale.lives()
|
| 223 |
+
episode_frame_number = self.ale.getEpisodeFrameNumber()
|
| 224 |
+
frame_number = self.ale.getFrameNumber()
|
| 225 |
+
done = self.ale.game_over()
|
| 226 |
+
|
| 227 |
+
# Create legal actions list (indices into action_set)
|
| 228 |
+
legal_actions = list(range(len(self._action_set)))
|
| 229 |
+
|
| 230 |
+
# Create observation
|
| 231 |
+
obs = AtariObservation(
|
| 232 |
+
screen=screen_flat,
|
| 233 |
+
screen_shape=self.screen_shape,
|
| 234 |
+
legal_actions=legal_actions,
|
| 235 |
+
lives=lives,
|
| 236 |
+
episode_frame_number=episode_frame_number,
|
| 237 |
+
frame_number=frame_number,
|
| 238 |
+
done=done,
|
| 239 |
+
reward=0.0, # Will be filled in by step()
|
| 240 |
+
metadata={
|
| 241 |
+
"game_name": self.game_name,
|
| 242 |
+
"action_meanings": [str(a) for a in self._action_set],
|
| 243 |
+
},
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
return obs
|
envs/atari_env/server/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gymnasium>=0.29.0
|
| 2 |
+
ale-py>=0.8.0
|
| 3 |
+
numpy>=1.24.0
|
envs/atari_env/test_atari_docker.sh
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Comprehensive Docker test for Atari environment
|
| 3 |
+
# Tests: Build, Start, Health, Reset, Step, State, Cleanup
|
| 4 |
+
|
| 5 |
+
set -e # Exit on error
|
| 6 |
+
|
| 7 |
+
# Colors for output
|
| 8 |
+
RED='\033[0;31m'
|
| 9 |
+
GREEN='\033[0;32m'
|
| 10 |
+
YELLOW='\033[1;33m'
|
| 11 |
+
BLUE='\033[0;34m'
|
| 12 |
+
NC='\033[0m' # No Color
|
| 13 |
+
|
| 14 |
+
# Configuration
|
| 15 |
+
IMAGE_NAME="atari-env"
|
| 16 |
+
IMAGE_TAG="test"
|
| 17 |
+
CONTAINER_NAME="atari-env-test"
|
| 18 |
+
PORT="8765" # Use non-standard port to avoid conflicts
|
| 19 |
+
HEALTH_RETRIES=30
|
| 20 |
+
HEALTH_DELAY=2
|
| 21 |
+
|
| 22 |
+
# Cleanup function
|
| 23 |
+
cleanup() {
|
| 24 |
+
echo -e "\n${BLUE}Cleaning up...${NC}"
|
| 25 |
+
docker stop ${CONTAINER_NAME} 2>/dev/null || true
|
| 26 |
+
docker rm ${CONTAINER_NAME} 2>/dev/null || true
|
| 27 |
+
echo -e "${GREEN}✓${NC} Cleanup complete"
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
# Set trap to cleanup on exit
|
| 31 |
+
trap cleanup EXIT
|
| 32 |
+
|
| 33 |
+
# Header
|
| 34 |
+
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
| 35 |
+
echo " ATARI ENVIRONMENT DOCKER TEST"
|
| 36 |
+
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
| 37 |
+
echo ""
|
| 38 |
+
|
| 39 |
+
# Check prerequisites
|
| 40 |
+
echo -e "${BLUE}Checking prerequisites...${NC}"
|
| 41 |
+
if ! command -v docker &> /dev/null; then
|
| 42 |
+
echo -e "${RED}✗${NC} Docker is not installed"
|
| 43 |
+
exit 1
|
| 44 |
+
fi
|
| 45 |
+
echo -e "${GREEN}✓${NC} Docker is installed"
|
| 46 |
+
|
| 47 |
+
if ! command -v curl &> /dev/null; then
|
| 48 |
+
echo -e "${RED}✗${NC} curl is not installed"
|
| 49 |
+
exit 1
|
| 50 |
+
fi
|
| 51 |
+
echo -e "${GREEN}✓${NC} curl is installed"
|
| 52 |
+
|
| 53 |
+
# Check if we're in the right directory
|
| 54 |
+
if [ ! -f "envs/atari_env/server/Dockerfile" ]; then
|
| 55 |
+
echo -e "${RED}✗${NC} Must run from OpenEnv root directory"
|
| 56 |
+
exit 1
|
| 57 |
+
fi
|
| 58 |
+
echo -e "${GREEN}✓${NC} In correct directory"
|
| 59 |
+
|
| 60 |
+
# Step 1: Build Docker image
|
| 61 |
+
echo ""
|
| 62 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 63 |
+
echo -e "${BLUE}STEP 1: Building Docker Image${NC}"
|
| 64 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 65 |
+
|
| 66 |
+
echo "Building ${IMAGE_NAME}:${IMAGE_TAG}..."
|
| 67 |
+
if docker build -f envs/atari_env/server/Dockerfile -t ${IMAGE_NAME}:${IMAGE_TAG} . 2>&1 | tee /tmp/atari_build.log | tail -n 20; then
|
| 68 |
+
echo -e "${GREEN}✓${NC} Docker image built successfully"
|
| 69 |
+
else
|
| 70 |
+
echo -e "${RED}✗${NC} Docker build failed"
|
| 71 |
+
echo "See /tmp/atari_build.log for full output"
|
| 72 |
+
exit 1
|
| 73 |
+
fi
|
| 74 |
+
|
| 75 |
+
# Check image exists
|
| 76 |
+
if docker image inspect ${IMAGE_NAME}:${IMAGE_TAG} &> /dev/null; then
|
| 77 |
+
IMAGE_SIZE=$(docker image inspect ${IMAGE_NAME}:${IMAGE_TAG} --format='{{.Size}}' | awk '{print $1/1024/1024}')
|
| 78 |
+
echo -e "${GREEN}✓${NC} Image size: ${IMAGE_SIZE} MB"
|
| 79 |
+
else
|
| 80 |
+
echo -e "${RED}✗${NC} Image not found after build"
|
| 81 |
+
exit 1
|
| 82 |
+
fi
|
| 83 |
+
|
| 84 |
+
# Step 2: Start container
|
| 85 |
+
echo ""
|
| 86 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 87 |
+
echo -e "${BLUE}STEP 2: Starting Container${NC}"
|
| 88 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 89 |
+
|
| 90 |
+
# Clean up any existing container
|
| 91 |
+
docker rm -f ${CONTAINER_NAME} 2>/dev/null || true
|
| 92 |
+
|
| 93 |
+
echo "Starting container on port ${PORT}..."
|
| 94 |
+
docker run -d \
|
| 95 |
+
--name ${CONTAINER_NAME} \
|
| 96 |
+
-p ${PORT}:8000 \
|
| 97 |
+
-e ATARI_GAME=pong \
|
| 98 |
+
-e ATARI_OBS_TYPE=ram \
|
| 99 |
+
-e ATARI_FRAMESKIP=4 \
|
| 100 |
+
${IMAGE_NAME}:${IMAGE_TAG}
|
| 101 |
+
|
| 102 |
+
if [ $? -eq 0 ]; then
|
| 103 |
+
echo -e "${GREEN}✓${NC} Container started: ${CONTAINER_NAME}"
|
| 104 |
+
else
|
| 105 |
+
echo -e "${RED}✗${NC} Failed to start container"
|
| 106 |
+
exit 1
|
| 107 |
+
fi
|
| 108 |
+
|
| 109 |
+
# Wait for container to be running
|
| 110 |
+
sleep 2
|
| 111 |
+
if docker ps | grep -q ${CONTAINER_NAME}; then
|
| 112 |
+
echo -e "${GREEN}✓${NC} Container is running"
|
| 113 |
+
else
|
| 114 |
+
echo -e "${RED}✗${NC} Container is not running"
|
| 115 |
+
docker logs ${CONTAINER_NAME}
|
| 116 |
+
exit 1
|
| 117 |
+
fi
|
| 118 |
+
|
| 119 |
+
# Step 3: Wait for health check
|
| 120 |
+
echo ""
|
| 121 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 122 |
+
echo -e "${BLUE}STEP 3: Waiting for Server${NC}"
|
| 123 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 124 |
+
|
| 125 |
+
echo "Waiting for server to be ready (timeout: ${HEALTH_RETRIES}s)..."
|
| 126 |
+
for i in $(seq 1 ${HEALTH_RETRIES}); do
|
| 127 |
+
if curl -s http://localhost:${PORT}/health > /dev/null 2>&1; then
|
| 128 |
+
echo -e "${GREEN}✓${NC} Server is ready (${i}s)"
|
| 129 |
+
break
|
| 130 |
+
fi
|
| 131 |
+
|
| 132 |
+
if [ $i -eq ${HEALTH_RETRIES} ]; then
|
| 133 |
+
echo -e "${RED}✗${NC} Server did not become ready in time"
|
| 134 |
+
echo "Container logs:"
|
| 135 |
+
docker logs ${CONTAINER_NAME}
|
| 136 |
+
exit 1
|
| 137 |
+
fi
|
| 138 |
+
|
| 139 |
+
echo -n "."
|
| 140 |
+
sleep ${HEALTH_DELAY}
|
| 141 |
+
done
|
| 142 |
+
|
| 143 |
+
# Step 4: Test health endpoint
|
| 144 |
+
echo ""
|
| 145 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 146 |
+
echo -e "${BLUE}STEP 4: Testing Health Endpoint${NC}"
|
| 147 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 148 |
+
|
| 149 |
+
HEALTH_RESPONSE=$(curl -s http://localhost:${PORT}/health)
|
| 150 |
+
echo "Response: ${HEALTH_RESPONSE}"
|
| 151 |
+
|
| 152 |
+
if echo "${HEALTH_RESPONSE}" | grep -q "healthy"; then
|
| 153 |
+
echo -e "${GREEN}✓${NC} Health endpoint working"
|
| 154 |
+
else
|
| 155 |
+
echo -e "${RED}✗${NC} Health endpoint failed"
|
| 156 |
+
exit 1
|
| 157 |
+
fi
|
| 158 |
+
|
| 159 |
+
# Step 5: Test reset endpoint
|
| 160 |
+
echo ""
|
| 161 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 162 |
+
echo -e "${BLUE}STEP 5: Testing Reset Endpoint${NC}"
|
| 163 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 164 |
+
|
| 165 |
+
RESET_RESPONSE=$(curl -s -X POST http://localhost:${PORT}/reset -H "Content-Type: application/json" -d '{}')
|
| 166 |
+
|
| 167 |
+
if [ -z "${RESET_RESPONSE}" ]; then
|
| 168 |
+
echo -e "${RED}✗${NC} Reset endpoint returned empty response"
|
| 169 |
+
docker logs ${CONTAINER_NAME} | tail -20
|
| 170 |
+
exit 1
|
| 171 |
+
fi
|
| 172 |
+
|
| 173 |
+
echo "Response (first 200 chars): ${RESET_RESPONSE:0:200}..."
|
| 174 |
+
|
| 175 |
+
# Check if response contains expected fields
|
| 176 |
+
if echo "${RESET_RESPONSE}" | grep -q "observation" && \
|
| 177 |
+
echo "${RESET_RESPONSE}" | grep -q "screen" && \
|
| 178 |
+
echo "${RESET_RESPONSE}" | grep -q "legal_actions"; then
|
| 179 |
+
echo -e "${GREEN}✓${NC} Reset endpoint working"
|
| 180 |
+
|
| 181 |
+
# Extract some info
|
| 182 |
+
SCREEN_LEN=$(echo "${RESET_RESPONSE}" | grep -o '"screen":\[[^]]*\]' | wc -c)
|
| 183 |
+
echo " Screen data length: ${SCREEN_LEN} chars"
|
| 184 |
+
else
|
| 185 |
+
echo -e "${RED}✗${NC} Reset response missing required fields"
|
| 186 |
+
echo "Full response: ${RESET_RESPONSE}"
|
| 187 |
+
exit 1
|
| 188 |
+
fi
|
| 189 |
+
|
| 190 |
+
# Step 6: Test step endpoint
|
| 191 |
+
echo ""
|
| 192 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 193 |
+
echo -e "${BLUE}STEP 6: Testing Step Endpoint${NC}"
|
| 194 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 195 |
+
|
| 196 |
+
STEP_PAYLOAD='{"action": {"action_id": 0, "game_name": "pong"}}'
|
| 197 |
+
STEP_RESPONSE=$(curl -s -X POST http://localhost:${PORT}/step -H "Content-Type: application/json" -d "${STEP_PAYLOAD}")
|
| 198 |
+
|
| 199 |
+
if [ -z "${STEP_RESPONSE}" ]; then
|
| 200 |
+
echo -e "${RED}✗${NC} Step endpoint returned empty response"
|
| 201 |
+
docker logs ${CONTAINER_NAME} | tail -20
|
| 202 |
+
exit 1
|
| 203 |
+
fi
|
| 204 |
+
|
| 205 |
+
echo "Response (first 200 chars): ${STEP_RESPONSE:0:200}..."
|
| 206 |
+
|
| 207 |
+
# Check if response contains expected fields
|
| 208 |
+
if echo "${STEP_RESPONSE}" | grep -q "observation" && \
|
| 209 |
+
echo "${STEP_RESPONSE}" | grep -q "reward" && \
|
| 210 |
+
echo "${STEP_RESPONSE}" | grep -q "done"; then
|
| 211 |
+
echo -e "${GREEN}✓${NC} Step endpoint working"
|
| 212 |
+
|
| 213 |
+
# Extract reward and done
|
| 214 |
+
REWARD=$(echo "${STEP_RESPONSE}" | grep -o '"reward":[^,}]*' | cut -d: -f2)
|
| 215 |
+
DONE=$(echo "${STEP_RESPONSE}" | grep -o '"done":[^,}]*' | cut -d: -f2)
|
| 216 |
+
echo " Reward: ${REWARD}"
|
| 217 |
+
echo " Done: ${DONE}"
|
| 218 |
+
else
|
| 219 |
+
echo -e "${RED}✗${NC} Step response missing required fields"
|
| 220 |
+
echo "Full response: ${STEP_RESPONSE}"
|
| 221 |
+
exit 1
|
| 222 |
+
fi
|
| 223 |
+
|
| 224 |
+
# Step 7: Test state endpoint
|
| 225 |
+
echo ""
|
| 226 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 227 |
+
echo -e "${BLUE}STEP 7: Testing State Endpoint${NC}"
|
| 228 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 229 |
+
|
| 230 |
+
STATE_RESPONSE=$(curl -s http://localhost:${PORT}/state)
|
| 231 |
+
|
| 232 |
+
if [ -z "${STATE_RESPONSE}" ]; then
|
| 233 |
+
echo -e "${RED}✗${NC} State endpoint returned empty response"
|
| 234 |
+
docker logs ${CONTAINER_NAME} | tail -20
|
| 235 |
+
exit 1
|
| 236 |
+
fi
|
| 237 |
+
|
| 238 |
+
echo "Response: ${STATE_RESPONSE}"
|
| 239 |
+
|
| 240 |
+
# Check if response contains expected fields
|
| 241 |
+
if echo "${STATE_RESPONSE}" | grep -q "episode_id" && \
|
| 242 |
+
echo "${STATE_RESPONSE}" | grep -q "step_count" && \
|
| 243 |
+
echo "${STATE_RESPONSE}" | grep -q "game_name"; then
|
| 244 |
+
echo -e "${GREEN}✓${NC} State endpoint working"
|
| 245 |
+
|
| 246 |
+
# Extract info
|
| 247 |
+
GAME_NAME=$(echo "${STATE_RESPONSE}" | grep -o '"game_name":"[^"]*"' | cut -d'"' -f4)
|
| 248 |
+
STEP_COUNT=$(echo "${STATE_RESPONSE}" | grep -o '"step_count":[^,}]*' | cut -d: -f2)
|
| 249 |
+
echo " Game: ${GAME_NAME}"
|
| 250 |
+
echo " Steps: ${STEP_COUNT}"
|
| 251 |
+
else
|
| 252 |
+
echo -e "${RED}✗${NC} State response missing required fields"
|
| 253 |
+
echo "Full response: ${STATE_RESPONSE}"
|
| 254 |
+
exit 1
|
| 255 |
+
fi
|
| 256 |
+
|
| 257 |
+
# Step 8: Test multiple steps
|
| 258 |
+
echo ""
|
| 259 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 260 |
+
echo -e "${BLUE}STEP 8: Testing Multiple Steps${NC}"
|
| 261 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 262 |
+
|
| 263 |
+
echo "Taking 10 steps..."
|
| 264 |
+
TOTAL_REWARD=0
|
| 265 |
+
for i in {1..10}; do
|
| 266 |
+
ACTION_ID=$((RANDOM % 3)) # Random action 0-2
|
| 267 |
+
STEP_PAYLOAD="{\"action\": {\"action_id\": ${ACTION_ID}, \"game_name\": \"pong\"}}"
|
| 268 |
+
STEP_RESPONSE=$(curl -s -X POST http://localhost:${PORT}/step -H "Content-Type: application/json" -d "${STEP_PAYLOAD}")
|
| 269 |
+
|
| 270 |
+
if ! echo "${STEP_RESPONSE}" | grep -q "observation"; then
|
| 271 |
+
echo -e "${RED}✗${NC} Step ${i} failed"
|
| 272 |
+
exit 1
|
| 273 |
+
fi
|
| 274 |
+
|
| 275 |
+
REWARD=$(echo "${STEP_RESPONSE}" | grep -o '"reward":[^,}]*' | cut -d: -f2 | sed 's/null/0/')
|
| 276 |
+
DONE=$(echo "${STEP_RESPONSE}" | grep -o '"done":[^,}]*' | cut -d: -f2)
|
| 277 |
+
|
| 278 |
+
echo " Step ${i}: action=${ACTION_ID}, reward=${REWARD}, done=${DONE}"
|
| 279 |
+
|
| 280 |
+
if [ "${DONE}" = "true" ]; then
|
| 281 |
+
echo " Episode completed early at step ${i}"
|
| 282 |
+
break
|
| 283 |
+
fi
|
| 284 |
+
done
|
| 285 |
+
|
| 286 |
+
echo -e "${GREEN}✓${NC} Multiple steps completed successfully"
|
| 287 |
+
|
| 288 |
+
# Step 9: Check container logs for errors
|
| 289 |
+
echo ""
|
| 290 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 291 |
+
echo -e "${BLUE}STEP 9: Checking Container Logs${NC}"
|
| 292 |
+
echo -e "${BLUE}━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━${NC}"
|
| 293 |
+
|
| 294 |
+
LOGS=$(docker logs ${CONTAINER_NAME} 2>&1)
|
| 295 |
+
|
| 296 |
+
if echo "${LOGS}" | grep -i "error" | grep -v "LoggerMode.Error"; then
|
| 297 |
+
echo -e "${YELLOW}⚠${NC} Found errors in logs:"
|
| 298 |
+
echo "${LOGS}" | grep -i "error" | head -5
|
| 299 |
+
else
|
| 300 |
+
echo -e "${GREEN}✓${NC} No errors in container logs"
|
| 301 |
+
fi
|
| 302 |
+
|
| 303 |
+
if echo "${LOGS}" | grep -i "exception"; then
|
| 304 |
+
echo -e "${RED}✗${NC} Found exceptions in logs:"
|
| 305 |
+
echo "${LOGS}" | grep -i "exception" | head -5
|
| 306 |
+
exit 1
|
| 307 |
+
else
|
| 308 |
+
echo -e "${GREEN}✓${NC} No exceptions in container logs"
|
| 309 |
+
fi
|
| 310 |
+
|
| 311 |
+
# Final Summary
|
| 312 |
+
echo ""
|
| 313 |
+
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
| 314 |
+
echo -e "${GREEN}✅ ALL DOCKER TESTS PASSED${NC}"
|
| 315 |
+
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
| 316 |
+
echo ""
|
| 317 |
+
echo "Summary:"
|
| 318 |
+
echo " ✓ Docker image built successfully"
|
| 319 |
+
echo " ✓ Container started and ran"
|
| 320 |
+
echo " ✓ Health endpoint working"
|
| 321 |
+
echo " ✓ Reset endpoint working"
|
| 322 |
+
echo " ✓ Step endpoint working"
|
| 323 |
+
echo " ✓ State endpoint working"
|
| 324 |
+
echo " ✓ Multiple steps working"
|
| 325 |
+
echo " ✓ No errors or exceptions"
|
| 326 |
+
echo ""
|
| 327 |
+
echo "Image: ${IMAGE_NAME}:${IMAGE_TAG}"
|
| 328 |
+
echo "Container: ${CONTAINER_NAME}"
|
| 329 |
+
echo "Port: ${PORT}"
|
| 330 |
+
echo ""
|
| 331 |
+
echo "To keep container running: docker start ${CONTAINER_NAME}"
|
| 332 |
+
echo "To view logs: docker logs ${CONTAINER_NAME}"
|
| 333 |
+
echo ""
|
models.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Data models for Atari Environment.
|
| 9 |
+
|
| 10 |
+
This module defines the Action, Observation, and State types for Atari games
|
| 11 |
+
via the Arcade Learning Environment (ALE).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 17 |
+
|
| 18 |
+
from openenv.core.env_server import Action, Observation, State
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class AtariAction(Action):
|
| 22 |
+
"""
|
| 23 |
+
Action for Atari environments.
|
| 24 |
+
|
| 25 |
+
Attributes:
|
| 26 |
+
action_id: The integer action ID to take (from legal_actions).
|
| 27 |
+
game_name: Name of the Atari game (e.g., "pong", "breakout", "space_invaders").
|
| 28 |
+
obs_type: Observation type ("rgb", "grayscale", or "ram").
|
| 29 |
+
full_action_space: Whether to use full (18 actions) or minimal action space.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
action_id: int
|
| 33 |
+
game_name: str = "pong"
|
| 34 |
+
obs_type: Literal["rgb", "grayscale", "ram"] = "rgb"
|
| 35 |
+
full_action_space: bool = False
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class AtariObservation(Observation):
|
| 39 |
+
"""
|
| 40 |
+
Observation from Atari environment.
|
| 41 |
+
|
| 42 |
+
This represents what the agent sees after taking an action.
|
| 43 |
+
|
| 44 |
+
Attributes:
|
| 45 |
+
screen: Screen observation as a flattened list of pixels.
|
| 46 |
+
Shape depends on obs_type:
|
| 47 |
+
- rgb: [210, 160, 3] flattened
|
| 48 |
+
- grayscale: [210, 160] flattened
|
| 49 |
+
- ram: [128] (RAM contents)
|
| 50 |
+
screen_shape: Original shape of the screen before flattening.
|
| 51 |
+
legal_actions: List of legal action IDs the agent can take.
|
| 52 |
+
lives: Number of lives remaining.
|
| 53 |
+
episode_frame_number: Frame number within current episode.
|
| 54 |
+
frame_number: Total frame number since environment creation.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
screen: List[int]
|
| 58 |
+
screen_shape: List[int]
|
| 59 |
+
legal_actions: List[int]
|
| 60 |
+
lives: int = 0
|
| 61 |
+
episode_frame_number: int = 0
|
| 62 |
+
frame_number: int = 0
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class AtariState(State):
|
| 66 |
+
"""
|
| 67 |
+
State for Atari environment.
|
| 68 |
+
|
| 69 |
+
Attributes:
|
| 70 |
+
game_name: Name of the Atari game.
|
| 71 |
+
obs_type: Observation type ("rgb", "grayscale", or "ram").
|
| 72 |
+
full_action_space: Whether using full or minimal action space.
|
| 73 |
+
mode: Game mode (if applicable).
|
| 74 |
+
difficulty: Game difficulty (if applicable).
|
| 75 |
+
repeat_action_probability: Probability of repeating previous action (sticky actions).
|
| 76 |
+
frameskip: Number of frames to skip per action.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
game_name: str = "pong"
|
| 80 |
+
obs_type: Literal["rgb", "grayscale", "ram"] = "rgb"
|
| 81 |
+
full_action_space: bool = False
|
| 82 |
+
mode: Optional[int] = None
|
| 83 |
+
difficulty: Optional[int] = None
|
| 84 |
+
repeat_action_probability: float = 0.0
|
| 85 |
+
frameskip: int = 4
|
pyproject.toml
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=45", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv-core"
|
| 7 |
+
version = "0.2.2.dev0"
|
| 8 |
+
description = "A unified framework for reinforcement learning environments"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
dependencies = [
|
| 12 |
+
# Core shared dependencies - minimal set required for all environments
|
| 13 |
+
# Heavy dependencies (torch, numpy, smolagents, etc.) should be in
|
| 14 |
+
# individual environment pyproject.toml files
|
| 15 |
+
"fastapi>=0.104.0",
|
| 16 |
+
"pydantic>=2.0.0",
|
| 17 |
+
"uvicorn>=0.24.0",
|
| 18 |
+
"requests>=2.25.0",
|
| 19 |
+
# CLI dependencies
|
| 20 |
+
"typer>=0.9.0",
|
| 21 |
+
"rich>=13.0.0",
|
| 22 |
+
"pyyaml>=6.0",
|
| 23 |
+
"huggingface_hub>=0.20.0",
|
| 24 |
+
"openai>=2.7.2",
|
| 25 |
+
"tomli>=2.3.0",
|
| 26 |
+
"tomli-w>=1.2.0",
|
| 27 |
+
"websockets>=15.0.1",
|
| 28 |
+
# MCP support
|
| 29 |
+
"fastmcp>=3.0.0",
|
| 30 |
+
# Web UI dependencies
|
| 31 |
+
"gradio>=4.0.0",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
[project.optional-dependencies]
|
| 35 |
+
core = [
|
| 36 |
+
"fastapi>=0.104.0",
|
| 37 |
+
"pydantic>=2.0.0",
|
| 38 |
+
"uvicorn>=0.24.0",
|
| 39 |
+
"requests>=2.25.0",
|
| 40 |
+
"websockets>=15.0.1",
|
| 41 |
+
]
|
| 42 |
+
cli = [
|
| 43 |
+
"typer>=0.9.0",
|
| 44 |
+
"rich>=13.0.0",
|
| 45 |
+
"pyyaml>=6.0",
|
| 46 |
+
"huggingface_hub>=0.20.0",
|
| 47 |
+
"openai>=2.7.2",
|
| 48 |
+
"tomli>=2.3.0",
|
| 49 |
+
"tomli-w>=1.2.0",
|
| 50 |
+
]
|
| 51 |
+
docs = [
|
| 52 |
+
"sphinx==7.2.6",
|
| 53 |
+
"pytorch-sphinx-theme2",
|
| 54 |
+
"sphinxcontrib.katex==0.9.10",
|
| 55 |
+
"docutils>=0.18.1,<0.21",
|
| 56 |
+
"sphinx-design==0.6.1",
|
| 57 |
+
"sphinxcontrib-mermaid==1.0.0",
|
| 58 |
+
"myst-parser",
|
| 59 |
+
"sphinxext-opengraph",
|
| 60 |
+
"sphinx-sitemap==2.7.1",
|
| 61 |
+
"sphinx-gallery>=0.14.0",
|
| 62 |
+
"matplotlib",
|
| 63 |
+
"nest-asyncio",
|
| 64 |
+
"smolagents",
|
| 65 |
+
]
|
| 66 |
+
all = [
|
| 67 |
+
"openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@main",
|
| 68 |
+
"openenv-core[cli]",
|
| 69 |
+
]
|
| 70 |
+
daytona = [
|
| 71 |
+
"daytona>=0.136.0",
|
| 72 |
+
"pyyaml>=6.0",
|
| 73 |
+
]
|
| 74 |
+
inspect = [
|
| 75 |
+
"inspect-ai>=0.3.0",
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
[project.scripts]
|
| 79 |
+
openenv = "openenv.cli.__main__:main"
|
| 80 |
+
|
| 81 |
+
[tool.setuptools]
|
| 82 |
+
package-dir = {"" = "src"}
|
| 83 |
+
include-package-data = true
|
| 84 |
+
|
| 85 |
+
[tool.setuptools.package-data]
|
| 86 |
+
"openenv.cli" = ["templates/**/*"]
|
| 87 |
+
|
| 88 |
+
[tool.setuptools.packages.find]
|
| 89 |
+
where = ["src"]
|
| 90 |
+
|
| 91 |
+
[tool.coverage.run]
|
| 92 |
+
omit = [
|
| 93 |
+
"openenv/cli/templates/**",
|
| 94 |
+
"**/templates/**",
|
| 95 |
+
"openenv/cli/__main__.py",
|
| 96 |
+
]
|
| 97 |
+
|
| 98 |
+
[tool.coverage.report]
|
| 99 |
+
exclude_lines = [
|
| 100 |
+
"pragma: no cover",
|
| 101 |
+
"def __repr__",
|
| 102 |
+
"raise AssertionError",
|
| 103 |
+
"raise NotImplementedError",
|
| 104 |
+
"if __name__ == .__main__.:",
|
| 105 |
+
"if TYPE_CHECKING:",
|
| 106 |
+
]
|
| 107 |
+
|
| 108 |
+
[tool.pytest.ini_options]
|
| 109 |
+
asyncio_mode = "auto"
|
| 110 |
+
asyncio_default_fixture_loop_scope = "function"
|
| 111 |
+
markers = [
|
| 112 |
+
"docker: Tests that require Docker to be running",
|
| 113 |
+
"network: Tests that require network access (HuggingFace, etc.)",
|
| 114 |
+
"integration: Integration tests with external resources",
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
[dependency-groups]
|
| 118 |
+
dev = [
|
| 119 |
+
"ruff>=0.14.0",
|
| 120 |
+
"usort>=1.1.0",
|
| 121 |
+
"pytest>=7.0",
|
| 122 |
+
"pytest-asyncio>=0.21",
|
| 123 |
+
]
|
| 124 |
+
|
| 125 |
+
[tool.usort]
|
| 126 |
+
# Disable first_party auto-detection so all non-stdlib imports land in
|
| 127 |
+
# the same "third_party" bucket (the default_category). This matches
|
| 128 |
+
# pyfmt's usort behavior inside arc f, which groups openenv.* and env
|
| 129 |
+
# package imports together without blank-line separators.
|
| 130 |
+
first_party_detection = false
|
| 131 |
+
|
| 132 |
+
[tool.ruff]
|
| 133 |
+
line-length = 88
|
| 134 |
+
|
| 135 |
+
[tool.ruff.lint]
|
| 136 |
+
select = ["E", "F", "W"]
|
| 137 |
+
ignore = [
|
| 138 |
+
"E402", # Module level import not at top of file (needed for pytest.importorskip patterns)
|
| 139 |
+
"E501", # Line too long (not enforced previously, would require large refactor)
|
| 140 |
+
]
|
| 141 |
+
|
| 142 |
+
[tool.ruff.lint.per-file-ignores]
|
| 143 |
+
# Context manager variables that are intentionally unused
|
| 144 |
+
"tests/envs/test_websockets.py" = ["F841"]
|
| 145 |
+
"tests/test_cli/test_push.py" = ["F841"]
|
| 146 |
+
# Compatibility shim module
|
| 147 |
+
"src/openenv_core/__init__.py" = ["F401"]
|
server/Dockerfile
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dockerfile for Atari Environment
|
| 2 |
+
# This image provides Atari 2600 games via the Arcade Learning Environment (ALE)
|
| 3 |
+
|
| 4 |
+
# Configurable base image - defaults to local build, can be overridden for CI/CD
|
| 5 |
+
# Base image provides: fastapi, uvicorn, requests, curl, PYTHONPATH=/app/src
|
| 6 |
+
#
|
| 7 |
+
# Local build: docker build -t envtorch-base:latest -f src/core/containers/images/Dockerfile .
|
| 8 |
+
# docker build -f envs/atari_env/server/Dockerfile -t atari-env:latest .
|
| 9 |
+
#
|
| 10 |
+
# CI/CD build: docker build --build-arg BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest \
|
| 11 |
+
# -f envs/atari_env/server/Dockerfile -t atari-env:latest .
|
| 12 |
+
ARG BASE_IMAGE=openenv-base:latest
|
| 13 |
+
FROM ${BASE_IMAGE}
|
| 14 |
+
|
| 15 |
+
# Install dependencies
|
| 16 |
+
COPY envs/atari_env/server/requirements.txt /tmp/requirements.txt
|
| 17 |
+
RUN pip install --no-cache-dir -r /tmp/requirements.txt && rm /tmp/requirements.txt
|
| 18 |
+
|
| 19 |
+
# Copy OpenEnv core (base image already set WORKDIR=/app)
|
| 20 |
+
COPY src/core/ /app/src/core/
|
| 21 |
+
|
| 22 |
+
# Copy Atari environment code
|
| 23 |
+
COPY envs/atari_env/ /app/envs/atari_env/
|
| 24 |
+
|
| 25 |
+
# Copy README for web interface documentation
|
| 26 |
+
COPY envs/atari_env/README.md /app/README.md
|
| 27 |
+
|
| 28 |
+
# Atari-specific environment variables (can be overridden at runtime)
|
| 29 |
+
ENV ATARI_GAME=pong
|
| 30 |
+
ENV ATARI_OBS_TYPE=rgb
|
| 31 |
+
ENV ATARI_FULL_ACTION_SPACE=false
|
| 32 |
+
ENV ATARI_REPEAT_ACTION_PROB=0.0
|
| 33 |
+
ENV ATARI_FRAMESKIP=4
|
| 34 |
+
|
| 35 |
+
# Expose port
|
| 36 |
+
EXPOSE 8000
|
| 37 |
+
|
| 38 |
+
# Health check
|
| 39 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 40 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 41 |
+
|
| 42 |
+
# Run the FastAPI server
|
| 43 |
+
CMD ["uvicorn", "envs.atari_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
server/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Atari Environment Server.
|
| 9 |
+
|
| 10 |
+
Server-side implementation of Atari environment for OpenEnv.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from .atari_environment import AtariEnvironment
|
| 14 |
+
|
| 15 |
+
__all__ = ["AtariEnvironment"]
|
server/app.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
FastAPI application for the Atari Environment.
|
| 9 |
+
|
| 10 |
+
This module creates an HTTP server that exposes Atari games
|
| 11 |
+
over HTTP and WebSocket endpoints, compatible with EnvClient.
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
# Development (with auto-reload):
|
| 15 |
+
uvicorn envs.atari_env.server.app:app --reload --host 0.0.0.0 --port 8000
|
| 16 |
+
|
| 17 |
+
# Production:
|
| 18 |
+
uvicorn envs.atari_env.server.app:app --host 0.0.0.0 --port 8000 --workers 4
|
| 19 |
+
|
| 20 |
+
# Or run directly:
|
| 21 |
+
python -m envs.atari_env.server.app
|
| 22 |
+
|
| 23 |
+
Environment variables:
|
| 24 |
+
ATARI_GAME: Game name to serve (default: "pong")
|
| 25 |
+
ATARI_OBS_TYPE: Observation type (default: "rgb")
|
| 26 |
+
ATARI_FULL_ACTION_SPACE: Use full action space (default: "false")
|
| 27 |
+
ATARI_MODE: Game mode (optional)
|
| 28 |
+
ATARI_DIFFICULTY: Game difficulty (optional)
|
| 29 |
+
ATARI_REPEAT_ACTION_PROB: Sticky action probability (default: "0.0")
|
| 30 |
+
ATARI_FRAMESKIP: Frameskip (default: "4")
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
import os
|
| 34 |
+
|
| 35 |
+
from openenv.core.env_server import create_app
|
| 36 |
+
|
| 37 |
+
from ..models import AtariAction, AtariObservation
|
| 38 |
+
from .atari_environment import AtariEnvironment
|
| 39 |
+
|
| 40 |
+
# Get configuration from environment variables
|
| 41 |
+
game_name = os.getenv("ATARI_GAME", "pong")
|
| 42 |
+
obs_type = os.getenv("ATARI_OBS_TYPE", "rgb")
|
| 43 |
+
full_action_space = os.getenv("ATARI_FULL_ACTION_SPACE", "false").lower() == "true"
|
| 44 |
+
repeat_action_prob = float(os.getenv("ATARI_REPEAT_ACTION_PROB", "0.0"))
|
| 45 |
+
frameskip = int(os.getenv("ATARI_FRAMESKIP", "4"))
|
| 46 |
+
|
| 47 |
+
# Optional parameters
|
| 48 |
+
mode = os.getenv("ATARI_MODE")
|
| 49 |
+
difficulty = os.getenv("ATARI_DIFFICULTY")
|
| 50 |
+
|
| 51 |
+
# Convert to int if specified
|
| 52 |
+
mode = int(mode) if mode is not None else None
|
| 53 |
+
difficulty = int(difficulty) if difficulty is not None else None
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Factory function to create AtariEnvironment instances
|
| 57 |
+
def create_atari_environment():
|
| 58 |
+
"""Factory function that creates AtariEnvironment with config."""
|
| 59 |
+
return AtariEnvironment(
|
| 60 |
+
game_name=game_name,
|
| 61 |
+
obs_type=obs_type,
|
| 62 |
+
full_action_space=full_action_space,
|
| 63 |
+
mode=mode,
|
| 64 |
+
difficulty=difficulty,
|
| 65 |
+
repeat_action_probability=repeat_action_prob,
|
| 66 |
+
frameskip=frameskip,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Create the FastAPI app with web interface and README integration
|
| 71 |
+
# Pass the factory function instead of an instance for WebSocket session support
|
| 72 |
+
app = create_app(
|
| 73 |
+
create_atari_environment, AtariAction, AtariObservation, env_name="atari_env"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
import uvicorn
|
| 79 |
+
|
| 80 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
server/atari_environment.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Atari Environment Server Implementation.
|
| 9 |
+
|
| 10 |
+
This module wraps ALE's ALEInterface and exposes it
|
| 11 |
+
via the OpenEnv Environment interface.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import uuid
|
| 15 |
+
from typing import Any, Dict, Literal, Optional
|
| 16 |
+
|
| 17 |
+
from openenv.core.env_server import Action, Environment, Observation
|
| 18 |
+
|
| 19 |
+
from ..models import AtariAction, AtariObservation, AtariState
|
| 20 |
+
|
| 21 |
+
# Import ALE
|
| 22 |
+
try:
|
| 23 |
+
import numpy as np
|
| 24 |
+
from ale_py import ALEInterface, roms
|
| 25 |
+
except ImportError as e:
|
| 26 |
+
raise ImportError(
|
| 27 |
+
"ALE (Arcade Learning Environment) is not installed. "
|
| 28 |
+
"Please install it with: pip install ale-py"
|
| 29 |
+
) from e
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class AtariEnvironment(Environment):
|
| 33 |
+
"""
|
| 34 |
+
Atari Environment wrapper for OpenEnv.
|
| 35 |
+
|
| 36 |
+
This environment wraps Atari 2600 games via the Arcade Learning Environment (ALE)
|
| 37 |
+
and provides a clean interface for RL training.
|
| 38 |
+
|
| 39 |
+
Supported games include: pong, breakout, space_invaders, and 100+ others.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
game_name: Name of the Atari game (e.g., "pong", "breakout").
|
| 43 |
+
obs_type: Observation type - "rgb", "grayscale", or "ram".
|
| 44 |
+
full_action_space: Use full action space (18 actions) vs minimal.
|
| 45 |
+
mode: Game mode (if applicable).
|
| 46 |
+
difficulty: Game difficulty (if applicable).
|
| 47 |
+
repeat_action_probability: Sticky action probability (default 0.0).
|
| 48 |
+
frameskip: Number of frames to skip per action (default 4).
|
| 49 |
+
|
| 50 |
+
Example:
|
| 51 |
+
>>> env = AtariEnvironment("pong")
|
| 52 |
+
>>> obs = env.reset()
|
| 53 |
+
>>> print(obs.screen_shape) # [210, 160, 3]
|
| 54 |
+
>>> obs = env.step(AtariAction(action_id=2)) # UP
|
| 55 |
+
>>> print(obs.reward, obs.done)
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
game_name: str = "pong",
|
| 61 |
+
obs_type: Literal["rgb", "grayscale", "ram"] = "rgb",
|
| 62 |
+
full_action_space: bool = False,
|
| 63 |
+
mode: Optional[int] = None,
|
| 64 |
+
difficulty: Optional[int] = None,
|
| 65 |
+
repeat_action_probability: float = 0.0,
|
| 66 |
+
frameskip: int = 4,
|
| 67 |
+
):
|
| 68 |
+
"""Initialize Atari environment."""
|
| 69 |
+
super().__init__()
|
| 70 |
+
|
| 71 |
+
self.game_name = game_name
|
| 72 |
+
self.obs_type = obs_type
|
| 73 |
+
self.full_action_space = full_action_space
|
| 74 |
+
self.mode = mode
|
| 75 |
+
self.difficulty = difficulty
|
| 76 |
+
self.repeat_action_probability = repeat_action_probability
|
| 77 |
+
self.frameskip = frameskip
|
| 78 |
+
|
| 79 |
+
# Create ALE interface
|
| 80 |
+
self.ale = ALEInterface()
|
| 81 |
+
|
| 82 |
+
# Configure ALE
|
| 83 |
+
from ale_py import LoggerMode
|
| 84 |
+
|
| 85 |
+
self.ale.setLoggerMode(LoggerMode.Error) # Error mode only
|
| 86 |
+
self.ale.setFloat("repeat_action_probability", repeat_action_probability)
|
| 87 |
+
|
| 88 |
+
# Load ROM
|
| 89 |
+
try:
|
| 90 |
+
rom_path = roms.get_rom_path(game_name)
|
| 91 |
+
self.ale.loadROM(rom_path)
|
| 92 |
+
except Exception as e:
|
| 93 |
+
raise ValueError(
|
| 94 |
+
f"Failed to load Atari game '{game_name}': {e}\n"
|
| 95 |
+
f"Available games can be found via: ale_py.roms.list_roms()"
|
| 96 |
+
) from e
|
| 97 |
+
|
| 98 |
+
# Set mode and difficulty if specified
|
| 99 |
+
if mode is not None:
|
| 100 |
+
self.ale.setMode(mode)
|
| 101 |
+
if difficulty is not None:
|
| 102 |
+
self.ale.setDifficulty(difficulty)
|
| 103 |
+
|
| 104 |
+
# Get action set
|
| 105 |
+
if full_action_space:
|
| 106 |
+
self._action_set = self.ale.getLegalActionSet()
|
| 107 |
+
else:
|
| 108 |
+
self._action_set = self.ale.getMinimalActionSet()
|
| 109 |
+
|
| 110 |
+
# Get screen dimensions for observation space
|
| 111 |
+
self.screen_height, self.screen_width = self.ale.getScreenDims()
|
| 112 |
+
if obs_type == "rgb":
|
| 113 |
+
self.screen_shape = [self.screen_height, self.screen_width, 3]
|
| 114 |
+
elif obs_type == "grayscale":
|
| 115 |
+
self.screen_shape = [self.screen_height, self.screen_width]
|
| 116 |
+
elif obs_type == "ram":
|
| 117 |
+
self.screen_shape = [self.ale.getRAMSize()]
|
| 118 |
+
else:
|
| 119 |
+
raise ValueError(f"Invalid obs_type: {obs_type}")
|
| 120 |
+
|
| 121 |
+
# Initialize state
|
| 122 |
+
self._state = AtariState(
|
| 123 |
+
game_name=game_name,
|
| 124 |
+
obs_type=obs_type,
|
| 125 |
+
full_action_space=full_action_space,
|
| 126 |
+
mode=mode,
|
| 127 |
+
difficulty=difficulty,
|
| 128 |
+
repeat_action_probability=repeat_action_probability,
|
| 129 |
+
frameskip=frameskip,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
def reset(self) -> Observation:
|
| 133 |
+
"""
|
| 134 |
+
Reset the environment and return initial observation.
|
| 135 |
+
|
| 136 |
+
Returns:
|
| 137 |
+
Initial observation for the agent.
|
| 138 |
+
"""
|
| 139 |
+
# Reset ALE
|
| 140 |
+
self.ale.reset_game()
|
| 141 |
+
|
| 142 |
+
# Reset state tracking
|
| 143 |
+
self._state.episode_id = str(uuid.uuid4())
|
| 144 |
+
self._state.step_count = 0
|
| 145 |
+
|
| 146 |
+
# Get initial observation
|
| 147 |
+
return self._make_observation()
|
| 148 |
+
|
| 149 |
+
def step(self, action: Action) -> Observation:
|
| 150 |
+
"""
|
| 151 |
+
Execute agent's action and return resulting observation.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
action: AtariAction containing the action_id to execute.
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
Observation after action execution.
|
| 158 |
+
|
| 159 |
+
Raises:
|
| 160 |
+
ValueError: If action is not an AtariAction.
|
| 161 |
+
"""
|
| 162 |
+
if not isinstance(action, AtariAction):
|
| 163 |
+
raise ValueError(f"Expected AtariAction, got {type(action)}")
|
| 164 |
+
|
| 165 |
+
# Validate action_id
|
| 166 |
+
if action.action_id < 0 or action.action_id >= len(self._action_set):
|
| 167 |
+
raise ValueError(
|
| 168 |
+
f"Invalid action_id: {action.action_id}. "
|
| 169 |
+
f"Valid range: [0, {len(self._action_set) - 1}]"
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
# Get actual ALE action
|
| 173 |
+
ale_action = self._action_set[action.action_id]
|
| 174 |
+
|
| 175 |
+
# Execute action with frameskip
|
| 176 |
+
total_reward = 0.0
|
| 177 |
+
for _ in range(self.frameskip):
|
| 178 |
+
total_reward += self.ale.act(ale_action)
|
| 179 |
+
if self.ale.game_over():
|
| 180 |
+
break
|
| 181 |
+
|
| 182 |
+
self._state.step_count += 1
|
| 183 |
+
|
| 184 |
+
# Get observation
|
| 185 |
+
obs = self._make_observation()
|
| 186 |
+
obs.reward = total_reward
|
| 187 |
+
|
| 188 |
+
return obs
|
| 189 |
+
|
| 190 |
+
@property
|
| 191 |
+
def state(self) -> AtariState:
|
| 192 |
+
"""Get current environment state."""
|
| 193 |
+
return self._state
|
| 194 |
+
|
| 195 |
+
def _make_observation(self) -> AtariObservation:
|
| 196 |
+
"""
|
| 197 |
+
Create an AtariObservation from current ALE state.
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
AtariObservation for the agent.
|
| 201 |
+
"""
|
| 202 |
+
# Get screen observation
|
| 203 |
+
if self.obs_type == "rgb":
|
| 204 |
+
screen = self.ale.getScreenRGB()
|
| 205 |
+
elif self.obs_type == "grayscale":
|
| 206 |
+
screen = self.ale.getScreenGrayscale()
|
| 207 |
+
elif self.obs_type == "ram":
|
| 208 |
+
screen = self.ale.getRAM()
|
| 209 |
+
else:
|
| 210 |
+
raise ValueError(f"Invalid obs_type: {self.obs_type}")
|
| 211 |
+
|
| 212 |
+
# Flatten screen for JSON serialization
|
| 213 |
+
# Handle both numpy arrays and lists
|
| 214 |
+
if hasattr(screen, "flatten"):
|
| 215 |
+
screen_flat = screen.flatten().tolist()
|
| 216 |
+
elif hasattr(screen, "tolist"):
|
| 217 |
+
screen_flat = screen.tolist()
|
| 218 |
+
else:
|
| 219 |
+
screen_flat = list(screen)
|
| 220 |
+
|
| 221 |
+
# Get game info
|
| 222 |
+
lives = self.ale.lives()
|
| 223 |
+
episode_frame_number = self.ale.getEpisodeFrameNumber()
|
| 224 |
+
frame_number = self.ale.getFrameNumber()
|
| 225 |
+
done = self.ale.game_over()
|
| 226 |
+
|
| 227 |
+
# Create legal actions list (indices into action_set)
|
| 228 |
+
legal_actions = list(range(len(self._action_set)))
|
| 229 |
+
|
| 230 |
+
# Create observation
|
| 231 |
+
obs = AtariObservation(
|
| 232 |
+
screen=screen_flat,
|
| 233 |
+
screen_shape=self.screen_shape,
|
| 234 |
+
legal_actions=legal_actions,
|
| 235 |
+
lives=lives,
|
| 236 |
+
episode_frame_number=episode_frame_number,
|
| 237 |
+
frame_number=frame_number,
|
| 238 |
+
done=done,
|
| 239 |
+
reward=0.0, # Will be filled in by step()
|
| 240 |
+
metadata={
|
| 241 |
+
"game_name": self.game_name,
|
| 242 |
+
"action_meanings": [str(a) for a in self._action_set],
|
| 243 |
+
},
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
return obs
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gymnasium>=0.29.0
|
| 2 |
+
ale-py>=0.8.0
|
| 3 |
+
numpy>=1.24.0
|
src/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""EnvTorch: Standardized agentic execution environments."""
|
src/core/README.md
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# <img width="35" height="35" alt="image" src="https://github.com/user-attachments/assets/2700a971-e5d6-4036-b03f-2f89c9791609" /> OpenEnv: Agentic Execution Environments
|
| 2 |
+
|
| 3 |
+
An e2e framework for creating, deploying and using isolated execution environments for agentic RL training, built using Gymnasium style simple APIs. OpenEnv provides a standard for interacting with agentic execution environments via simple Gymnasium style APIs - step(), reset(), state(). Users of agentic execution environments can interact with the environment during RL training loops using these simple APIs.
|
| 4 |
+
|
| 5 |
+
In addition to making it easier for researchers and RL framework writers, we also provide tools for environment creators making it easier for them to create richer environments and make them available over familiar protocols like HTTP and packaged using canonical technologies like docker. Environment creators can use the OpenEnv framework to create environments that are isolated, secure, and easy to deploy and use.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
## Overview
|
| 9 |
+
`openenv.core` provides the foundational building blocks for creating and interacting with containerized environments over HTTP. It enables you to build agent environments that can be deployed as Docker containers and accessed via a simple HTTP API.
|
| 10 |
+
|
| 11 |
+
> ⚠️ **Early Development Warning** OpenEnv is currently in an experimental
|
| 12 |
+
> stage. You should expect bugs, incomplete features, and APIs that may change
|
| 13 |
+
> in future versions. The project welcomes bugfixes, but to make sure things are
|
| 14 |
+
> well coordinated you should discuss any significant change before starting the
|
| 15 |
+
> work. It's recommended that you signal your intention to contribute in the
|
| 16 |
+
> issue tracker, either by filing a new issue or by claiming an existing one.
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# OpenEnv Core
|
| 20 |
+
|
| 21 |
+
Core components for OpenEnv - a framework for building HTTP-based agentic environments.
|
| 22 |
+
|
| 23 |
+
## Features
|
| 24 |
+
|
| 25 |
+
- **EnvClient**: Async-first client for interacting with remote environments
|
| 26 |
+
- **SyncEnvClient**: Synchronous wrapper via `.sync()` for sync codebases
|
| 27 |
+
- **HTTPEnvServer**: FastAPI-based server wrapper for exposing environments over HTTP/WebSocket
|
| 28 |
+
- **Container Providers**: Pluggable architecture for running containers (Docker, Kubernetes, etc.)
|
| 29 |
+
- **Type System**: Strongly-typed Action/Observation/State interfaces
|
| 30 |
+
- **Web Interface**: Optional web UI for interacting with environments
|
| 31 |
+
|
| 32 |
+
## Installation
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
pip install "openenv[core]"
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
For development:
|
| 39 |
+
```bash
|
| 40 |
+
pip install "openenv[core]"
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
## Quick Start
|
| 44 |
+
|
| 45 |
+
### Creating an Environment Client
|
| 46 |
+
|
| 47 |
+
EnvClient is **async by default**. Use `async with` and `await` for all operations:
|
| 48 |
+
|
| 49 |
+
```python
|
| 50 |
+
import asyncio
|
| 51 |
+
from openenv.core import EnvClient, StepResult
|
| 52 |
+
from dataclasses import dataclass
|
| 53 |
+
from typing import Any
|
| 54 |
+
|
| 55 |
+
@dataclass
|
| 56 |
+
class MyAction:
|
| 57 |
+
text: str
|
| 58 |
+
|
| 59 |
+
@dataclass
|
| 60 |
+
class MyObservation:
|
| 61 |
+
response: str
|
| 62 |
+
|
| 63 |
+
class MyEnvClient(EnvClient[MyAction, MyObservation, Any]):
|
| 64 |
+
def _step_payload(self, action: MyAction) -> dict:
|
| 65 |
+
return {"text": action.text}
|
| 66 |
+
|
| 67 |
+
def _parse_result(self, payload: dict) -> StepResult[MyObservation]:
|
| 68 |
+
obs_data = payload["observation"]
|
| 69 |
+
return StepResult(
|
| 70 |
+
observation=MyObservation(**obs_data),
|
| 71 |
+
reward=payload.get("reward"),
|
| 72 |
+
done=payload.get("done", False)
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def _parse_state(self, payload: dict) -> Any:
|
| 76 |
+
return payload
|
| 77 |
+
|
| 78 |
+
# Async usage (recommended)
|
| 79 |
+
async def main():
|
| 80 |
+
client = await MyEnvClient.from_docker_image("my-env:latest")
|
| 81 |
+
async with client:
|
| 82 |
+
result = await client.reset()
|
| 83 |
+
step_result = await client.step(MyAction(text="hello"))
|
| 84 |
+
|
| 85 |
+
asyncio.run(main())
|
| 86 |
+
|
| 87 |
+
# Sync usage (via .sync() wrapper)
|
| 88 |
+
with MyEnvClient(base_url="http://localhost:8000").sync() as client:
|
| 89 |
+
result = client.reset()
|
| 90 |
+
step_result = client.step(MyAction(text="hello"))
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
### Creating an Environment Server
|
| 94 |
+
|
| 95 |
+
```python
|
| 96 |
+
from openenv.core.env_server import Environment, HTTPEnvServer, create_app
|
| 97 |
+
from dataclasses import dataclass
|
| 98 |
+
|
| 99 |
+
@dataclass
|
| 100 |
+
class MyAction:
|
| 101 |
+
text: str
|
| 102 |
+
|
| 103 |
+
@dataclass
|
| 104 |
+
class MyObservation:
|
| 105 |
+
response: str
|
| 106 |
+
reward: float = 0.0
|
| 107 |
+
done: bool = False
|
| 108 |
+
|
| 109 |
+
class MyEnvironment(Environment):
|
| 110 |
+
def reset(self) -> MyObservation:
|
| 111 |
+
return MyObservation(response="Ready")
|
| 112 |
+
|
| 113 |
+
def step(self, action: MyAction) -> MyObservation:
|
| 114 |
+
return MyObservation(
|
| 115 |
+
response=f"Echo: {action.text}",
|
| 116 |
+
reward=1.0,
|
| 117 |
+
done=False
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Create FastAPI app
|
| 121 |
+
env = MyEnvironment()
|
| 122 |
+
app = create_app(env, MyAction, MyObservation)
|
| 123 |
+
|
| 124 |
+
# Run with: uvicorn module:app --host 0.0.0.0 --port 8000
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
## Container Providers
|
| 128 |
+
|
| 129 |
+
OpenEnv Core supports multiple container providers:
|
| 130 |
+
|
| 131 |
+
### Local Docker Provider
|
| 132 |
+
|
| 133 |
+
```python
|
| 134 |
+
from openenv.core.containers.runtime import LocalDockerProvider
|
| 135 |
+
|
| 136 |
+
provider = LocalDockerProvider()
|
| 137 |
+
base_url = provider.start_container("my-env:latest")
|
| 138 |
+
provider.wait_for_ready(base_url)
|
| 139 |
+
# Use environment...
|
| 140 |
+
provider.stop_container()
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
### Kubernetes Provider (Coming Soon)
|
| 144 |
+
|
| 145 |
+
```python
|
| 146 |
+
from openenv.core.containers.runtime import KubernetesProvider
|
| 147 |
+
|
| 148 |
+
provider = KubernetesProvider(namespace="envs")
|
| 149 |
+
base_url = provider.start_container("my-env:latest")
|
| 150 |
+
# Use environment...
|
| 151 |
+
provider.stop_container()
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
## API Reference
|
| 156 |
+
|
| 157 |
+
### EnvClient
|
| 158 |
+
|
| 159 |
+
Async base class for environment clients. Key methods:
|
| 160 |
+
|
| 161 |
+
- `async connect()`: Establish WebSocket connection
|
| 162 |
+
- `async reset(**kwargs)`: Reset environment
|
| 163 |
+
- `async step(action)`: Execute action
|
| 164 |
+
- `async state()`: Get current state
|
| 165 |
+
- `async close()`: Close connection and cleanup
|
| 166 |
+
- `sync()`: Return a SyncEnvClient wrapper for synchronous usage
|
| 167 |
+
|
| 168 |
+
Abstract methods to implement:
|
| 169 |
+
- `_step_payload(action)`: Convert action to JSON
|
| 170 |
+
- `_parse_result(payload)`: Parse response to StepResult
|
| 171 |
+
- `_parse_state(payload)`: Parse state response
|
| 172 |
+
|
| 173 |
+
### SyncEnvClient
|
| 174 |
+
|
| 175 |
+
Synchronous wrapper around EnvClient. Use `client.sync()` to get one:
|
| 176 |
+
|
| 177 |
+
```python
|
| 178 |
+
sync_client = async_client.sync()
|
| 179 |
+
with sync_client:
|
| 180 |
+
result = sync_client.reset()
|
| 181 |
+
result = sync_client.step(action)
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
### HTTPEnvServer
|
| 185 |
+
|
| 186 |
+
Server wrapper with these methods:
|
| 187 |
+
|
| 188 |
+
- `register_routes(app)`: Register endpoints on FastAPI app
|
| 189 |
+
- `_deserialize_action(data)`: Convert JSON to Action
|
| 190 |
+
- `_serialize_observation(obs)`: Convert Observation to JSON
|
| 191 |
+
|
| 192 |
+
### Environment Interface
|
| 193 |
+
|
| 194 |
+
Base interface for environment implementations:
|
| 195 |
+
|
| 196 |
+
- `reset()`: Reset environment and return initial observation
|
| 197 |
+
- `step(action)`: Execute action and return observation
|
| 198 |
+
- `state`: Property returning current environment state
|
| 199 |
+
|
| 200 |
+
## License
|
| 201 |
+
|
| 202 |
+
This project is licensed under the BSD-3-Clause License - see the LICENSE file for details.
|
| 203 |
+
|
| 204 |
+
## Contributing
|
| 205 |
+
|
| 206 |
+
Contributions are welcome! Please see the main OpenEnv repository for contribution guidelines.
|
| 207 |
+
|
| 208 |
+
## Links
|
| 209 |
+
|
| 210 |
+
- **Homepage**: https://github.com/meta-pytorch/OpenEnv
|
| 211 |
+
- **Documentation**: https://github.com/meta-pytorch/OpenEnv/blob/main/README.md
|
| 212 |
+
- **Bug Tracker**: https://github.com/meta-pytorch/OpenEnv/issues
|
src/core/__init__.py
CHANGED
|
@@ -6,14 +6,76 @@
|
|
| 6 |
|
| 7 |
"""Core components for agentic environments."""
|
| 8 |
|
| 9 |
-
|
| 10 |
-
from .env_server import *
|
| 11 |
-
from .http_env_client import HTTPEnvClient
|
| 12 |
-
from .types import StepResult
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
__all__ = [
|
| 17 |
-
"
|
| 18 |
-
"
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
"""Core components for agentic environments."""
|
| 8 |
|
| 9 |
+
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
from importlib import import_module
|
| 12 |
+
from typing import TYPE_CHECKING
|
| 13 |
+
|
| 14 |
+
from . import env_server
|
| 15 |
+
from .env_server import * # noqa: F403
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from .env_client import EnvClient
|
| 19 |
+
from .generic_client import GenericAction, GenericEnvClient
|
| 20 |
+
from .llm_client import (
|
| 21 |
+
AnthropicClient,
|
| 22 |
+
create_llm_client,
|
| 23 |
+
LLMClient,
|
| 24 |
+
LLMResponse,
|
| 25 |
+
OpenAIClient,
|
| 26 |
+
ToolCall,
|
| 27 |
+
)
|
| 28 |
+
from .mcp_client import MCPClientBase, MCPToolClient
|
| 29 |
+
from .sync_client import SyncEnvClient
|
| 30 |
|
| 31 |
__all__ = [
|
| 32 |
+
"EnvClient",
|
| 33 |
+
"SyncEnvClient",
|
| 34 |
+
"GenericEnvClient",
|
| 35 |
+
"GenericAction",
|
| 36 |
+
"MCPClientBase",
|
| 37 |
+
"MCPToolClient",
|
| 38 |
+
"AnthropicClient",
|
| 39 |
+
"LLMClient",
|
| 40 |
+
"LLMResponse",
|
| 41 |
+
"OpenAIClient",
|
| 42 |
+
"ToolCall",
|
| 43 |
+
"create_llm_client",
|
| 44 |
+
] + env_server.__all__ # type: ignore
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
_LAZY_ATTRS = {
|
| 48 |
+
"EnvClient": (".env_client", "EnvClient"),
|
| 49 |
+
"SyncEnvClient": (".sync_client", "SyncEnvClient"),
|
| 50 |
+
"GenericEnvClient": (".generic_client", "GenericEnvClient"),
|
| 51 |
+
"GenericAction": (".generic_client", "GenericAction"),
|
| 52 |
+
"MCPClientBase": (".mcp_client", "MCPClientBase"),
|
| 53 |
+
"MCPToolClient": (".mcp_client", "MCPToolClient"),
|
| 54 |
+
"AnthropicClient": (".llm_client", "AnthropicClient"),
|
| 55 |
+
"LLMClient": (".llm_client", "LLMClient"),
|
| 56 |
+
"LLMResponse": (".llm_client", "LLMResponse"),
|
| 57 |
+
"OpenAIClient": (".llm_client", "OpenAIClient"),
|
| 58 |
+
"ToolCall": (".llm_client", "ToolCall"),
|
| 59 |
+
"create_llm_client": (".llm_client", "create_llm_client"),
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def __getattr__(name: str):
|
| 64 |
+
if name in _LAZY_ATTRS:
|
| 65 |
+
module_path, attr_name = _LAZY_ATTRS[name]
|
| 66 |
+
module = import_module(module_path, __name__)
|
| 67 |
+
value = getattr(module, attr_name)
|
| 68 |
+
globals()[name] = value
|
| 69 |
+
return value
|
| 70 |
+
|
| 71 |
+
try:
|
| 72 |
+
value = getattr(env_server, name)
|
| 73 |
+
except AttributeError as exc:
|
| 74 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from exc
|
| 75 |
+
|
| 76 |
+
globals()[name] = value
|
| 77 |
+
return value
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def __dir__() -> list[str]:
|
| 81 |
+
return sorted(set(globals().keys()) | set(__all__))
|
src/core/client_types.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Type definitions for EnvTorch
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Generic, Optional, TypeVar
|
| 4 |
+
|
| 5 |
+
# Generic type for observations
|
| 6 |
+
ObsT = TypeVar("ObsT")
|
| 7 |
+
StateT = TypeVar("StateT")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class StepResult(Generic[ObsT]):
|
| 12 |
+
"""
|
| 13 |
+
Represents the result of one environment step.
|
| 14 |
+
|
| 15 |
+
Attributes:
|
| 16 |
+
observation: The environment's observation after the action.
|
| 17 |
+
reward: Scalar reward for this step (optional).
|
| 18 |
+
done: Whether the episode is finished.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
observation: ObsT
|
| 22 |
+
reward: Optional[float] = None
|
| 23 |
+
done: bool = False
|
src/core/containers/__init__.py
CHANGED
|
@@ -4,4 +4,4 @@
|
|
| 4 |
# This source code is licensed under the BSD-style license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
-
"""Container management for environment servers."""
|
|
|
|
| 4 |
# This source code is licensed under the BSD-style license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
+
"""Container management for environment servers."""
|
src/core/containers/images/Dockerfile
CHANGED
|
@@ -8,30 +8,47 @@
|
|
| 8 |
# OpenEnv Base Image
|
| 9 |
#
|
| 10 |
# This is the standard base image for all OpenEnv environment servers.
|
| 11 |
-
# It includes the minimal dependencies needed to run HTTP environment servers
|
|
|
|
| 12 |
#
|
| 13 |
-
# Build: docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
|
| 14 |
-
# Tag: docker tag openenv-base:latest openenv-base:0.
|
| 15 |
#
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
FROM python:3.11-slim
|
| 18 |
|
| 19 |
# Set metadata
|
| 20 |
LABEL maintainer="OpenEnv Team"
|
| 21 |
-
LABEL description="Base image for OpenEnv based environment servers"
|
| 22 |
-
LABEL version="0.
|
| 23 |
|
| 24 |
# Install system dependencies
|
| 25 |
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 26 |
curl \
|
|
|
|
| 27 |
&& rm -rf /var/lib/apt/lists/*
|
| 28 |
|
| 29 |
-
#
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
| 35 |
|
| 36 |
# Set working directory
|
| 37 |
WORKDIR /app
|
|
@@ -39,6 +56,7 @@ WORKDIR /app
|
|
| 39 |
# Default environment variables
|
| 40 |
ENV PYTHONPATH=/app/src
|
| 41 |
ENV PYTHONUNBUFFERED=1
|
|
|
|
| 42 |
|
| 43 |
# Default expose port (can be overridden)
|
| 44 |
EXPOSE 8000
|
|
|
|
| 8 |
# OpenEnv Base Image
|
| 9 |
#
|
| 10 |
# This is the standard base image for all OpenEnv environment servers.
|
| 11 |
+
# It includes the minimal dependencies needed to run HTTP environment servers
|
| 12 |
+
# and uv for fast dependency management.
|
| 13 |
#
|
| 14 |
+
# Build from repo root: docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
|
| 15 |
+
# Tag: docker tag openenv-base:latest openenv-base:0.2.0
|
| 16 |
#
|
| 17 |
|
| 18 |
+
FROM ghcr.io/astral-sh/uv:0.5.27-python3.11-bookworm-slim AS builder
|
| 19 |
+
|
| 20 |
+
# Set working directory
|
| 21 |
+
WORKDIR /app
|
| 22 |
+
|
| 23 |
+
# Copy core pyproject.toml and lockfile for dependency installation
|
| 24 |
+
COPY pyproject.toml uv.lock* ./
|
| 25 |
+
|
| 26 |
+
# Install core dependencies using uv with cache mount
|
| 27 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 28 |
+
uv pip install --system -r pyproject.toml
|
| 29 |
+
|
| 30 |
+
# Final runtime stage
|
| 31 |
FROM python:3.11-slim
|
| 32 |
|
| 33 |
# Set metadata
|
| 34 |
LABEL maintainer="OpenEnv Team"
|
| 35 |
+
LABEL description="Base image for OpenEnv based environment servers with uv"
|
| 36 |
+
LABEL version="0.2.0"
|
| 37 |
|
| 38 |
# Install system dependencies
|
| 39 |
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 40 |
curl \
|
| 41 |
+
ca-certificates \
|
| 42 |
&& rm -rf /var/lib/apt/lists/*
|
| 43 |
|
| 44 |
+
# Copy uv from builder
|
| 45 |
+
COPY --from=builder /usr/local/bin/uv /usr/local/bin/uvx /usr/local/bin/
|
| 46 |
+
|
| 47 |
+
# Copy installed Python packages from builder
|
| 48 |
+
COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
|
| 49 |
+
|
| 50 |
+
# Copy console scripts installed by pip (uvicorn, fastapi, etc.)
|
| 51 |
+
COPY --from=builder /usr/local/bin/uvicorn /usr/local/bin/fastapi /usr/local/bin/
|
| 52 |
|
| 53 |
# Set working directory
|
| 54 |
WORKDIR /app
|
|
|
|
| 56 |
# Default environment variables
|
| 57 |
ENV PYTHONPATH=/app/src
|
| 58 |
ENV PYTHONUNBUFFERED=1
|
| 59 |
+
ENV UV_SYSTEM_PYTHON=1
|
| 60 |
|
| 61 |
# Default expose port (can be overridden)
|
| 62 |
EXPOSE 8000
|
src/core/containers/images/README.md
CHANGED
|
@@ -36,7 +36,7 @@ Total: 465 MB (base shared, minimal duplication)
|
|
| 36 |
|
| 37 |
```bash
|
| 38 |
# From project root
|
| 39 |
-
docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
|
| 40 |
```
|
| 41 |
|
| 42 |
## Usage in Environment Dockerfiles
|
|
@@ -47,8 +47,8 @@ Each environment Dockerfile should start with:
|
|
| 47 |
FROM openenv-base:latest
|
| 48 |
|
| 49 |
# Copy only environment-specific files
|
| 50 |
-
COPY src/core/ /app/src/core/
|
| 51 |
-
COPY
|
| 52 |
|
| 53 |
# Run the server
|
| 54 |
CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
|
@@ -66,10 +66,10 @@ CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "80
|
|
| 66 |
|
| 67 |
```bash
|
| 68 |
# Step 1: Build base image (do this once)
|
| 69 |
-
docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
|
| 70 |
|
| 71 |
# Step 2: Build echo environment (uses base)
|
| 72 |
-
docker build -t echo-env:latest -f
|
| 73 |
|
| 74 |
# Step 3: Run echo environment
|
| 75 |
docker run -p 8000:8000 echo-env:latest
|
|
@@ -79,14 +79,14 @@ docker run -p 8000:8000 echo-env:latest
|
|
| 79 |
|
| 80 |
When dependencies need updating:
|
| 81 |
|
| 82 |
-
1. Update `src/core/containers/images/Dockerfile`
|
| 83 |
2. Rebuild base image
|
| 84 |
3. Rebuild all environment images (they'll use new base)
|
| 85 |
|
| 86 |
```bash
|
| 87 |
# Update base
|
| 88 |
-
docker build -t openenv-base:latest -f src/core/containers/images/Dockerfile .
|
| 89 |
|
| 90 |
# Rebuild environments (they automatically use new base)
|
| 91 |
-
docker build -t echo-env:latest -f
|
| 92 |
```
|
|
|
|
| 36 |
|
| 37 |
```bash
|
| 38 |
# From project root
|
| 39 |
+
docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
|
| 40 |
```
|
| 41 |
|
| 42 |
## Usage in Environment Dockerfiles
|
|
|
|
| 47 |
FROM openenv-base:latest
|
| 48 |
|
| 49 |
# Copy only environment-specific files
|
| 50 |
+
COPY src/openenv/core/ /app/src/openenv/core/
|
| 51 |
+
COPY envs/my_env/ /app/envs/my_env/
|
| 52 |
|
| 53 |
# Run the server
|
| 54 |
CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
|
|
|
| 66 |
|
| 67 |
```bash
|
| 68 |
# Step 1: Build base image (do this once)
|
| 69 |
+
docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
|
| 70 |
|
| 71 |
# Step 2: Build echo environment (uses base)
|
| 72 |
+
docker build -t echo-env:latest -f envs/echo_env/server/Dockerfile .
|
| 73 |
|
| 74 |
# Step 3: Run echo environment
|
| 75 |
docker run -p 8000:8000 echo-env:latest
|
|
|
|
| 79 |
|
| 80 |
When dependencies need updating:
|
| 81 |
|
| 82 |
+
1. Update `src/openenv/core/containers/images/Dockerfile`
|
| 83 |
2. Rebuild base image
|
| 84 |
3. Rebuild all environment images (they'll use new base)
|
| 85 |
|
| 86 |
```bash
|
| 87 |
# Update base
|
| 88 |
+
docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
|
| 89 |
|
| 90 |
# Rebuild environments (they automatically use new base)
|
| 91 |
+
docker build -t echo-env:latest -f envs/echo_env/server/Dockerfile .
|
| 92 |
```
|
src/core/containers/runtime/__init__.py
CHANGED
|
@@ -6,10 +6,20 @@
|
|
| 6 |
|
| 7 |
"""Container runtime providers."""
|
| 8 |
|
| 9 |
-
from .providers import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
__all__ = [
|
| 12 |
"ContainerProvider",
|
|
|
|
| 13 |
"LocalDockerProvider",
|
| 14 |
"KubernetesProvider",
|
| 15 |
-
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
"""Container runtime providers."""
|
| 8 |
|
| 9 |
+
from .providers import (
|
| 10 |
+
ContainerProvider,
|
| 11 |
+
DockerSwarmProvider,
|
| 12 |
+
KubernetesProvider,
|
| 13 |
+
LocalDockerProvider,
|
| 14 |
+
RuntimeProvider,
|
| 15 |
+
)
|
| 16 |
+
from .uv_provider import UVProvider
|
| 17 |
|
| 18 |
__all__ = [
|
| 19 |
"ContainerProvider",
|
| 20 |
+
"DockerSwarmProvider",
|
| 21 |
"LocalDockerProvider",
|
| 22 |
"KubernetesProvider",
|
| 23 |
+
"RuntimeProvider",
|
| 24 |
+
"UVProvider",
|
| 25 |
+
]
|
src/core/containers/runtime/daytona_provider.py
ADDED
|
@@ -0,0 +1,572 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Daytona container provider for running OpenEnv environments in Daytona cloud sandboxes.
|
| 9 |
+
|
| 10 |
+
Requires the ``daytona`` SDK: ``pip install daytona>=0.10``
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
import shlex
|
| 18 |
+
import time
|
| 19 |
+
from typing import Any, Callable, Dict, Optional
|
| 20 |
+
|
| 21 |
+
import yaml
|
| 22 |
+
|
| 23 |
+
from .providers import ContainerProvider
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DaytonaProvider(ContainerProvider):
|
| 27 |
+
"""
|
| 28 |
+
Container provider that runs environments in Daytona cloud sandboxes.
|
| 29 |
+
|
| 30 |
+
Example:
|
| 31 |
+
>>> provider = DaytonaProvider(api_key="your-key")
|
| 32 |
+
>>> image = DaytonaProvider.image_from_dockerfile("envs/echo_env/server/Dockerfile")
|
| 33 |
+
>>> base_url = provider.start_container(image)
|
| 34 |
+
>>> provider.wait_for_ready(base_url)
|
| 35 |
+
>>> provider.stop_container()
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
_dockerfile_registry: Dict[str, Dict[str, Any]] = {}
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
*,
|
| 43 |
+
api_key: Optional[str] = None,
|
| 44 |
+
public: bool = False,
|
| 45 |
+
resources: Optional[Any] = None,
|
| 46 |
+
auto_stop_interval: int = 15,
|
| 47 |
+
target: Optional[str] = None,
|
| 48 |
+
on_snapshot_create_logs: Optional[Callable[[str], None]] = None,
|
| 49 |
+
cmd: Optional[str] = None,
|
| 50 |
+
create_timeout: float = 300,
|
| 51 |
+
):
|
| 52 |
+
"""
|
| 53 |
+
Args:
|
| 54 |
+
api_key: Daytona API key. Falls back to ``DAYTONA_API_KEY`` env var.
|
| 55 |
+
public: If True, the sandbox preview is publicly accessible.
|
| 56 |
+
resources: Optional ``daytona.Resources`` instance for CPU/memory.
|
| 57 |
+
auto_stop_interval: Minutes of inactivity before auto-stop (0 disables).
|
| 58 |
+
target: Daytona target region (e.g. "us").
|
| 59 |
+
on_snapshot_create_logs: Callback for snapshot build log lines.
|
| 60 |
+
cmd: Shell command to start the server inside the sandbox.
|
| 61 |
+
create_timeout: Seconds to wait for sandbox creation (default 300).
|
| 62 |
+
Heavy images (e.g. with Playwright/Chromium) may need more.
|
| 63 |
+
"""
|
| 64 |
+
from daytona import Daytona, DaytonaConfig
|
| 65 |
+
|
| 66 |
+
config_kwargs: Dict[str, Any] = {}
|
| 67 |
+
resolved_key = api_key or os.environ.get("DAYTONA_API_KEY")
|
| 68 |
+
if resolved_key:
|
| 69 |
+
config_kwargs["api_key"] = resolved_key
|
| 70 |
+
if target:
|
| 71 |
+
config_kwargs["target"] = target
|
| 72 |
+
|
| 73 |
+
self._daytona = Daytona(DaytonaConfig(**config_kwargs))
|
| 74 |
+
self._public = public
|
| 75 |
+
self._resources = resources
|
| 76 |
+
self._auto_stop_interval = auto_stop_interval
|
| 77 |
+
self._on_snapshot_create_logs = on_snapshot_create_logs
|
| 78 |
+
self._cmd = cmd
|
| 79 |
+
self._create_timeout = create_timeout
|
| 80 |
+
self._sandbox: Any = None
|
| 81 |
+
self._preview_url: Optional[str] = None
|
| 82 |
+
|
| 83 |
+
def _discover_server_cmd(self, sandbox: Any, port: int = 8000) -> str:
|
| 84 |
+
"""Discover the server command from ``openenv.yaml`` inside *sandbox*.
|
| 85 |
+
|
| 86 |
+
Finds the file, reads the ``app`` field, and constructs a command
|
| 87 |
+
of the form ``cd <env_root> && python -m uvicorn <app> --host 0.0.0.0 --port <port>``.
|
| 88 |
+
|
| 89 |
+
Raises:
|
| 90 |
+
ValueError: If ``openenv.yaml`` is not found or lacks an ``app`` field.
|
| 91 |
+
"""
|
| 92 |
+
yaml_path = self._find_openenv_yaml(sandbox)
|
| 93 |
+
if yaml_path is None:
|
| 94 |
+
raise ValueError(
|
| 95 |
+
"Could not find openenv.yaml inside the sandbox. "
|
| 96 |
+
"Pass an explicit cmd= to DaytonaProvider or start_container()."
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
cat_resp = sandbox.process.exec(f"cat {shlex.quote(yaml_path)}", timeout=10)
|
| 100 |
+
content = cat_resp.result if hasattr(cat_resp, "result") else str(cat_resp)
|
| 101 |
+
app = self._parse_app_field(content)
|
| 102 |
+
if app is None:
|
| 103 |
+
raise ValueError(
|
| 104 |
+
f"openenv.yaml at {yaml_path} does not contain an 'app' field. "
|
| 105 |
+
"Pass an explicit cmd= to DaytonaProvider or start_container()."
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# The directory containing openenv.yaml is the env root
|
| 109 |
+
env_root = yaml_path.rsplit("/", 1)[0]
|
| 110 |
+
return (
|
| 111 |
+
f"cd {shlex.quote(env_root)} && "
|
| 112 |
+
f"python -m uvicorn {shlex.quote(app)} --host 0.0.0.0 --port {port}"
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def _find_openenv_yaml(self, sandbox: Any) -> Optional[str]:
|
| 116 |
+
"""Locate ``openenv.yaml`` inside the sandbox.
|
| 117 |
+
|
| 118 |
+
Tries the modern layout path ``/app/env/openenv.yaml`` first,
|
| 119 |
+
then falls back to a ``find`` command for the old layout.
|
| 120 |
+
"""
|
| 121 |
+
# Fast path: modern Dockerfile layout
|
| 122 |
+
resp = sandbox.process.exec(
|
| 123 |
+
"test -f /app/env/openenv.yaml && echo found", timeout=10
|
| 124 |
+
)
|
| 125 |
+
out = resp.result if hasattr(resp, "result") else str(resp)
|
| 126 |
+
if "found" in (out or ""):
|
| 127 |
+
return "/app/env/openenv.yaml"
|
| 128 |
+
|
| 129 |
+
# Fallback: search for it (redirect stderr so error messages
|
| 130 |
+
# like "No such file or directory" don't get mistaken for paths).
|
| 131 |
+
resp = sandbox.process.exec(
|
| 132 |
+
"find /app -maxdepth 4 -name openenv.yaml -print -quit 2>/dev/null",
|
| 133 |
+
timeout=10,
|
| 134 |
+
)
|
| 135 |
+
path = (resp.result if hasattr(resp, "result") else str(resp) or "").strip()
|
| 136 |
+
if path and path.startswith("/"):
|
| 137 |
+
return path
|
| 138 |
+
|
| 139 |
+
return None
|
| 140 |
+
|
| 141 |
+
@staticmethod
|
| 142 |
+
def _parse_app_field(yaml_content: str) -> Optional[str]:
|
| 143 |
+
"""Extract the ``app`` value from raw openenv.yaml content.
|
| 144 |
+
|
| 145 |
+
Uses PyYAML to handle comments, quotes, and nested keys correctly.
|
| 146 |
+
"""
|
| 147 |
+
try:
|
| 148 |
+
data = yaml.safe_load(yaml_content) or {}
|
| 149 |
+
except Exception:
|
| 150 |
+
return None
|
| 151 |
+
|
| 152 |
+
if not isinstance(data, dict):
|
| 153 |
+
return None
|
| 154 |
+
|
| 155 |
+
value = data.get("app")
|
| 156 |
+
if isinstance(value, str):
|
| 157 |
+
value = value.strip()
|
| 158 |
+
return value if value else None
|
| 159 |
+
return None
|
| 160 |
+
|
| 161 |
+
@staticmethod
|
| 162 |
+
def _parse_dockerfile_cmd(dockerfile_content: str) -> Optional[str]:
|
| 163 |
+
"""Extract the server command from the last ``CMD`` in a Dockerfile.
|
| 164 |
+
|
| 165 |
+
Handles exec form (``CMD ["prog", "arg"]``) and shell form
|
| 166 |
+
(``CMD prog arg``). When a Dockerfile has multiple ``CMD``
|
| 167 |
+
instructions (e.g. multi-stage builds), the last one wins - same
|
| 168 |
+
semantics as Docker itself. Lines where ``CMD`` appears inside a
|
| 169 |
+
comment are ignored.
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
The command as a single string, or ``None`` if no ``CMD`` found.
|
| 173 |
+
"""
|
| 174 |
+
import re
|
| 175 |
+
|
| 176 |
+
last_cmd: Optional[str] = None
|
| 177 |
+
for line in dockerfile_content.splitlines():
|
| 178 |
+
stripped = line.strip()
|
| 179 |
+
# Skip comments
|
| 180 |
+
if stripped.startswith("#"):
|
| 181 |
+
continue
|
| 182 |
+
match = re.match(r"CMD\s+(.+)", stripped, flags=re.IGNORECASE)
|
| 183 |
+
if match:
|
| 184 |
+
last_cmd = match.group(1).strip()
|
| 185 |
+
|
| 186 |
+
if last_cmd is None:
|
| 187 |
+
return None
|
| 188 |
+
|
| 189 |
+
# Exec form: CMD ["executable", "param1", ...]
|
| 190 |
+
if last_cmd.startswith("["):
|
| 191 |
+
try:
|
| 192 |
+
parts = json.loads(last_cmd)
|
| 193 |
+
if isinstance(parts, list) and all(isinstance(p, str) for p in parts):
|
| 194 |
+
return " ".join(parts)
|
| 195 |
+
except (json.JSONDecodeError, TypeError):
|
| 196 |
+
pass
|
| 197 |
+
|
| 198 |
+
# Shell form: CMD executable param1 ...
|
| 199 |
+
return last_cmd if last_cmd else None
|
| 200 |
+
|
| 201 |
+
@staticmethod
|
| 202 |
+
def strip_buildkit_syntax(dockerfile_content: str) -> str:
|
| 203 |
+
"""Remove BuildKit ``--mount=...`` flags from ``RUN`` instructions.
|
| 204 |
+
|
| 205 |
+
Handles single-line flags, multi-line continuations, and multiple
|
| 206 |
+
``--mount`` flags spread across continuation lines. Only leading
|
| 207 |
+
``--mount`` flags are removed (before the actual command starts).
|
| 208 |
+
|
| 209 |
+
Daytona's ``Image.from_dockerfile`` does not support BuildKit
|
| 210 |
+
``--mount`` syntax. This helper strips the flags so that standard
|
| 211 |
+
Dockerfiles (like the ones generated by ``openenv build``) can
|
| 212 |
+
be used directly.
|
| 213 |
+
"""
|
| 214 |
+
import re
|
| 215 |
+
|
| 216 |
+
def strip_leading_mounts(text: str) -> str:
|
| 217 |
+
remaining = text
|
| 218 |
+
while True:
|
| 219 |
+
match = re.match(r"\s*--mount=\S+\s*", remaining)
|
| 220 |
+
if not match:
|
| 221 |
+
return remaining
|
| 222 |
+
remaining = remaining[match.end() :]
|
| 223 |
+
|
| 224 |
+
lines = dockerfile_content.split("\n")
|
| 225 |
+
result: list[str] = []
|
| 226 |
+
in_run = False
|
| 227 |
+
in_mount_prefix = False
|
| 228 |
+
|
| 229 |
+
for line in lines:
|
| 230 |
+
line_out = line
|
| 231 |
+
run_start = False
|
| 232 |
+
if re.match(r"\s*RUN(\s+|$)", line, flags=re.IGNORECASE):
|
| 233 |
+
in_run = True
|
| 234 |
+
in_mount_prefix = True
|
| 235 |
+
run_start = True
|
| 236 |
+
|
| 237 |
+
if in_run and in_mount_prefix:
|
| 238 |
+
original_ends_with_slash = line_out.rstrip().endswith("\\")
|
| 239 |
+
if run_start:
|
| 240 |
+
match = re.match(r"(\s*RUN\s+)(.*)$", line_out, flags=re.IGNORECASE)
|
| 241 |
+
if match:
|
| 242 |
+
run_prefix, remainder = match.group(1), match.group(2)
|
| 243 |
+
else:
|
| 244 |
+
run_prefix, remainder = line_out, ""
|
| 245 |
+
new_remainder = strip_leading_mounts(remainder)
|
| 246 |
+
line_out = run_prefix + new_remainder
|
| 247 |
+
content_for_check = new_remainder
|
| 248 |
+
else:
|
| 249 |
+
new_remainder = strip_leading_mounts(line_out)
|
| 250 |
+
line_out = new_remainder
|
| 251 |
+
content_for_check = new_remainder
|
| 252 |
+
|
| 253 |
+
if original_ends_with_slash and not line_out.rstrip().endswith("\\"):
|
| 254 |
+
line_out = line_out.rstrip() + " \\"
|
| 255 |
+
|
| 256 |
+
if content_for_check.strip() not in ("", "\\"):
|
| 257 |
+
in_mount_prefix = False
|
| 258 |
+
|
| 259 |
+
if in_run and not line_out.rstrip().endswith("\\"):
|
| 260 |
+
in_run = False
|
| 261 |
+
in_mount_prefix = False
|
| 262 |
+
|
| 263 |
+
result.append(line_out)
|
| 264 |
+
|
| 265 |
+
return "\n".join(result)
|
| 266 |
+
|
| 267 |
+
@classmethod
|
| 268 |
+
def image_from_dockerfile(
|
| 269 |
+
cls,
|
| 270 |
+
dockerfile_path: str,
|
| 271 |
+
context_dir: str | None = None,
|
| 272 |
+
) -> str:
|
| 273 |
+
"""Validate a Dockerfile and return a ``dockerfile:`` URI for
|
| 274 |
+
:meth:`start_container`.
|
| 275 |
+
|
| 276 |
+
Eagerly validates the Dockerfile (existence, COPY sources,
|
| 277 |
+
BuildKit stripping) and stores the processed content in an
|
| 278 |
+
internal registry. The actual ``daytona.Image`` is created
|
| 279 |
+
later inside ``start_container``.
|
| 280 |
+
|
| 281 |
+
Args:
|
| 282 |
+
dockerfile_path: Path to the Dockerfile on disk.
|
| 283 |
+
context_dir: Build context directory. Defaults to the
|
| 284 |
+
Dockerfile's grandparent directory, matching the
|
| 285 |
+
``openenv init`` convention where Dockerfiles live in
|
| 286 |
+
``<env>/server/Dockerfile`` and the build context is
|
| 287 |
+
``<env>/``. Pass explicitly for non-standard layouts
|
| 288 |
+
(e.g. ``context_dir="."`` for repo-root contexts).
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
A ``"dockerfile:<abs_path>"`` string to pass to
|
| 292 |
+
``start_container``.
|
| 293 |
+
|
| 294 |
+
Raises:
|
| 295 |
+
FileNotFoundError: If *dockerfile_path* does not exist.
|
| 296 |
+
ValueError: If *context_dir* is given but does not exist,
|
| 297 |
+
or if COPY sources in the Dockerfile cannot be found
|
| 298 |
+
under the resolved context directory.
|
| 299 |
+
"""
|
| 300 |
+
import pathlib
|
| 301 |
+
import re
|
| 302 |
+
|
| 303 |
+
src = pathlib.Path(dockerfile_path).resolve()
|
| 304 |
+
if not src.is_file():
|
| 305 |
+
raise FileNotFoundError(f"Dockerfile not found: {dockerfile_path}")
|
| 306 |
+
|
| 307 |
+
if context_dir is not None:
|
| 308 |
+
ctx = pathlib.Path(context_dir)
|
| 309 |
+
if not ctx.is_dir():
|
| 310 |
+
raise ValueError(f"context_dir does not exist: {context_dir}")
|
| 311 |
+
else:
|
| 312 |
+
# Default: grandparent of the Dockerfile, matching the
|
| 313 |
+
# openenv init layout (<env>/server/Dockerfile -> <env>/).
|
| 314 |
+
ctx = src.parent.parent
|
| 315 |
+
|
| 316 |
+
content = src.read_text()
|
| 317 |
+
stripped = cls.strip_buildkit_syntax(content)
|
| 318 |
+
|
| 319 |
+
# Validate that COPY sources exist under the context directory.
|
| 320 |
+
# This catches mismatches early (e.g. a Dockerfile expecting repo
|
| 321 |
+
# root as context when we defaulted to the env directory).
|
| 322 |
+
for line in stripped.splitlines():
|
| 323 |
+
m = re.match(r"^\s*COPY\s+(?!--from=)(\S+)\s+", line, re.IGNORECASE)
|
| 324 |
+
if not m:
|
| 325 |
+
continue
|
| 326 |
+
copy_src = m.group(1)
|
| 327 |
+
if copy_src.startswith("/"):
|
| 328 |
+
continue
|
| 329 |
+
resolved = ctx / copy_src
|
| 330 |
+
if not resolved.exists() and not any(ctx.glob(copy_src)):
|
| 331 |
+
raise ValueError(
|
| 332 |
+
f"Dockerfile COPY source '{copy_src}' not found "
|
| 333 |
+
f"under context_dir '{ctx}'. This Dockerfile may "
|
| 334 |
+
f"expect a different build context (e.g. the repo "
|
| 335 |
+
f"root). Pass context_dir explicitly."
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
# Parse CMD from the original Dockerfile so start_container can
|
| 339 |
+
# use it as a fallback when openenv.yaml is unavailable.
|
| 340 |
+
parsed_cmd = cls._parse_dockerfile_cmd(content)
|
| 341 |
+
|
| 342 |
+
cls._dockerfile_registry[str(src)] = {
|
| 343 |
+
"stripped_content": stripped,
|
| 344 |
+
"context_dir": str(ctx),
|
| 345 |
+
"server_cmd": parsed_cmd,
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
return f"dockerfile:{src}"
|
| 349 |
+
|
| 350 |
+
def start_container(
|
| 351 |
+
self,
|
| 352 |
+
image: str,
|
| 353 |
+
port: Optional[int] = None,
|
| 354 |
+
env_vars: Optional[Dict[str, str]] = None,
|
| 355 |
+
**kwargs: Any,
|
| 356 |
+
) -> str:
|
| 357 |
+
"""
|
| 358 |
+
Create a Daytona sandbox from a Docker image or snapshot.
|
| 359 |
+
|
| 360 |
+
Daytona does not execute the image's CMD (known bug — ENTRYPOINT
|
| 361 |
+
runs, CMD does not). The server command is resolved in order:
|
| 362 |
+
|
| 363 |
+
1. Explicit ``cmd`` passed to the constructor.
|
| 364 |
+
2. ``cmd`` key in ``**kwargs`` (popped before forwarding).
|
| 365 |
+
3. Auto-discovered from ``openenv.yaml`` inside the sandbox.
|
| 366 |
+
4. ``CMD`` parsed from the Dockerfile (when *image* came from
|
| 367 |
+
``image_from_dockerfile``).
|
| 368 |
+
|
| 369 |
+
Args:
|
| 370 |
+
image: Docker image name (e.g. ``"echo-env:latest"``),
|
| 371 |
+
``"snapshot:<name>"`` to create from a pre-built snapshot,
|
| 372 |
+
or ``"dockerfile:<path>"`` returned by
|
| 373 |
+
:meth:`image_from_dockerfile`.
|
| 374 |
+
port: Must be ``None`` or ``8000``. Daytona exposes port 8000
|
| 375 |
+
via its preview proxy; other ports raise ``ValueError``.
|
| 376 |
+
env_vars: Environment variables forwarded to the sandbox.
|
| 377 |
+
**kwargs: ``cmd`` (str) to override the server command;
|
| 378 |
+
remaining kwargs passed through to ``Daytona.create()``.
|
| 379 |
+
|
| 380 |
+
Returns:
|
| 381 |
+
HTTPS preview URL for the sandbox (base_url).
|
| 382 |
+
"""
|
| 383 |
+
if port is not None and port != 8000:
|
| 384 |
+
raise ValueError(
|
| 385 |
+
f"DaytonaProvider only supports port 8000 (got {port}). "
|
| 386 |
+
"The Daytona preview proxy routes to port 8000 inside the sandbox."
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
# Resolve the server command (may be None; discovery happens after
|
| 390 |
+
# sandbox creation when we can inspect the filesystem).
|
| 391 |
+
cmd = kwargs.pop("cmd", None) or self._cmd
|
| 392 |
+
|
| 393 |
+
# CMD parsed from Dockerfile (populated for "dockerfile:" images).
|
| 394 |
+
parsed_cmd: Optional[str] = None
|
| 395 |
+
|
| 396 |
+
# Build creation params
|
| 397 |
+
create_kwargs: Dict[str, Any] = {}
|
| 398 |
+
if env_vars:
|
| 399 |
+
create_kwargs["env_vars"] = env_vars
|
| 400 |
+
if self._public:
|
| 401 |
+
create_kwargs["public"] = True
|
| 402 |
+
if self._auto_stop_interval != 15:
|
| 403 |
+
create_kwargs["auto_stop_interval"] = self._auto_stop_interval
|
| 404 |
+
|
| 405 |
+
if image.startswith("snapshot:"):
|
| 406 |
+
from daytona import CreateSandboxFromSnapshotParams
|
| 407 |
+
|
| 408 |
+
snapshot_name = image[len("snapshot:") :]
|
| 409 |
+
params = CreateSandboxFromSnapshotParams(
|
| 410 |
+
snapshot=snapshot_name, **create_kwargs
|
| 411 |
+
)
|
| 412 |
+
elif image.startswith("dockerfile:"):
|
| 413 |
+
from daytona import CreateSandboxFromImageParams, Image
|
| 414 |
+
|
| 415 |
+
dockerfile_path = image[len("dockerfile:") :]
|
| 416 |
+
meta = self._dockerfile_registry.get(dockerfile_path)
|
| 417 |
+
if meta is None:
|
| 418 |
+
raise ValueError(
|
| 419 |
+
f"No registered Dockerfile metadata for {dockerfile_path}. "
|
| 420 |
+
"Call DaytonaProvider.image_from_dockerfile() first."
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
parsed_cmd = meta.get("server_cmd")
|
| 424 |
+
|
| 425 |
+
# Build the daytona Image from the pre-stripped content.
|
| 426 |
+
import pathlib
|
| 427 |
+
import uuid
|
| 428 |
+
|
| 429 |
+
ctx = pathlib.Path(meta["context_dir"])
|
| 430 |
+
tmp_name = f".daytona-{uuid.uuid4().hex[:8]}.dockerfile"
|
| 431 |
+
tmp_path = ctx / tmp_name
|
| 432 |
+
try:
|
| 433 |
+
tmp_path.write_text(meta["stripped_content"])
|
| 434 |
+
daytona_image = Image.from_dockerfile(str(tmp_path))
|
| 435 |
+
finally:
|
| 436 |
+
tmp_path.unlink(missing_ok=True)
|
| 437 |
+
|
| 438 |
+
img_kwargs: Dict[str, Any] = {
|
| 439 |
+
"image": daytona_image,
|
| 440 |
+
**create_kwargs,
|
| 441 |
+
}
|
| 442 |
+
if self._resources is not None:
|
| 443 |
+
img_kwargs["resources"] = self._resources
|
| 444 |
+
params = CreateSandboxFromImageParams(**img_kwargs)
|
| 445 |
+
else:
|
| 446 |
+
from daytona import CreateSandboxFromImageParams
|
| 447 |
+
|
| 448 |
+
img_kwargs = {"image": image, **create_kwargs}
|
| 449 |
+
if self._resources is not None:
|
| 450 |
+
img_kwargs["resources"] = self._resources
|
| 451 |
+
params = CreateSandboxFromImageParams(**img_kwargs)
|
| 452 |
+
|
| 453 |
+
# Create sandbox
|
| 454 |
+
extra: Dict[str, Any] = dict(kwargs)
|
| 455 |
+
if self._on_snapshot_create_logs is not None:
|
| 456 |
+
extra["on_snapshot_create_logs"] = self._on_snapshot_create_logs
|
| 457 |
+
|
| 458 |
+
self._sandbox = self._daytona.create(
|
| 459 |
+
params, timeout=self._create_timeout, **extra
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
try:
|
| 463 |
+
# Discover server command from openenv.yaml if not explicitly set.
|
| 464 |
+
if cmd is None:
|
| 465 |
+
try:
|
| 466 |
+
cmd = self._discover_server_cmd(self._sandbox)
|
| 467 |
+
except ValueError:
|
| 468 |
+
# Fall back to CMD parsed from Dockerfile (if available).
|
| 469 |
+
if parsed_cmd:
|
| 470 |
+
cmd = parsed_cmd
|
| 471 |
+
else:
|
| 472 |
+
raise
|
| 473 |
+
|
| 474 |
+
# Wrap in bash -c so compound commands (cd ... && uvicorn ...)
|
| 475 |
+
# are handled correctly by nohup. Write PID so we can check
|
| 476 |
+
# if the process crashed later in wait_for_ready().
|
| 477 |
+
escaped_cmd = shlex.quote(cmd)
|
| 478 |
+
self._sandbox.process.exec(
|
| 479 |
+
f"nohup bash -c {escaped_cmd} > /tmp/openenv-server.log 2>&1 &"
|
| 480 |
+
" echo $! > /tmp/openenv-server.pid",
|
| 481 |
+
timeout=10,
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
# Get a signed preview URL for port 8000. The token is
|
| 485 |
+
# embedded in the URL itself so no extra headers are needed.
|
| 486 |
+
signed = self._sandbox.create_signed_preview_url(
|
| 487 |
+
8000, expires_in_seconds=86400
|
| 488 |
+
)
|
| 489 |
+
self._preview_url = signed.url
|
| 490 |
+
except Exception:
|
| 491 |
+
self.stop_container()
|
| 492 |
+
raise
|
| 493 |
+
|
| 494 |
+
return self._preview_url
|
| 495 |
+
|
| 496 |
+
def refresh_preview_url(self) -> str:
|
| 497 |
+
"""Get a fresh signed preview URL (valid for 24h).
|
| 498 |
+
|
| 499 |
+
Daytona signed URLs expire after at most 24 hours. Call this to
|
| 500 |
+
get a new one for long-running sessions. The returned URL points
|
| 501 |
+
to the same sandbox — clients will need to reconnect using it.
|
| 502 |
+
"""
|
| 503 |
+
if self._sandbox is None:
|
| 504 |
+
raise RuntimeError("No active sandbox to refresh URL for.")
|
| 505 |
+
signed = self._sandbox.create_signed_preview_url(8000, expires_in_seconds=86400)
|
| 506 |
+
self._preview_url = signed.url
|
| 507 |
+
return self._preview_url
|
| 508 |
+
|
| 509 |
+
def stop_container(self) -> None:
|
| 510 |
+
"""Delete the Daytona sandbox."""
|
| 511 |
+
if self._sandbox is None:
|
| 512 |
+
return
|
| 513 |
+
|
| 514 |
+
try:
|
| 515 |
+
self._daytona.delete(self._sandbox)
|
| 516 |
+
finally:
|
| 517 |
+
self._sandbox = None
|
| 518 |
+
self._preview_url = None
|
| 519 |
+
|
| 520 |
+
def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None:
|
| 521 |
+
"""
|
| 522 |
+
Poll the /health endpoint until the sandbox is ready.
|
| 523 |
+
|
| 524 |
+
Uses a longer default timeout (120s) than Docker providers because
|
| 525 |
+
Daytona sandboxes may have cold-start latency.
|
| 526 |
+
|
| 527 |
+
Args:
|
| 528 |
+
base_url: Preview URL returned by ``start_container()``.
|
| 529 |
+
timeout_s: Maximum seconds to wait.
|
| 530 |
+
|
| 531 |
+
Raises:
|
| 532 |
+
TimeoutError: If the sandbox doesn't become ready in time.
|
| 533 |
+
RuntimeError: If the server process died (detected via PID check).
|
| 534 |
+
"""
|
| 535 |
+
import requests
|
| 536 |
+
|
| 537 |
+
health_url = f"{base_url}/health"
|
| 538 |
+
|
| 539 |
+
deadline = time.time() + timeout_s
|
| 540 |
+
while time.time() < deadline:
|
| 541 |
+
try:
|
| 542 |
+
response = requests.get(health_url, timeout=5.0)
|
| 543 |
+
if response.status_code == 200:
|
| 544 |
+
return
|
| 545 |
+
except requests.RequestException:
|
| 546 |
+
pass
|
| 547 |
+
|
| 548 |
+
# Early exit: if the server process died, raise immediately
|
| 549 |
+
# instead of waiting for the full health-check timeout.
|
| 550 |
+
if self._sandbox is not None:
|
| 551 |
+
resp = self._sandbox.process.exec(
|
| 552 |
+
"kill -0 $(cat /tmp/openenv-server.pid) 2>/dev/null"
|
| 553 |
+
" && echo RUNNING || echo DEAD",
|
| 554 |
+
timeout=10,
|
| 555 |
+
)
|
| 556 |
+
out = resp.result if hasattr(resp, "result") else str(resp)
|
| 557 |
+
if "DEAD" in (out or ""):
|
| 558 |
+
log_resp = self._sandbox.process.exec(
|
| 559 |
+
"cat /tmp/openenv-server.log 2>/dev/null", timeout=10
|
| 560 |
+
)
|
| 561 |
+
log = (
|
| 562 |
+
log_resp.result
|
| 563 |
+
if hasattr(log_resp, "result")
|
| 564 |
+
else str(log_resp)
|
| 565 |
+
)
|
| 566 |
+
raise RuntimeError(f"Server process died.\nLog:\n{log}")
|
| 567 |
+
|
| 568 |
+
time.sleep(1.0)
|
| 569 |
+
|
| 570 |
+
raise TimeoutError(
|
| 571 |
+
f"Daytona sandbox at {base_url} did not become ready within {timeout_s}s"
|
| 572 |
+
)
|
src/core/containers/runtime/providers.py
CHANGED
|
@@ -8,13 +8,13 @@
|
|
| 8 |
Container provider abstractions for running environment servers.
|
| 9 |
|
| 10 |
This module provides a pluggable architecture for different container providers
|
| 11 |
-
(local Docker, Kubernetes, cloud providers, etc.) to be used with
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
|
| 16 |
from abc import ABC, abstractmethod
|
| 17 |
-
from typing import Any, Dict, Optional
|
| 18 |
|
| 19 |
|
| 20 |
class ContainerProvider(ABC):
|
|
@@ -118,7 +118,11 @@ class LocalDockerProvider(ContainerProvider):
|
|
| 118 |
capture_output=True,
|
| 119 |
timeout=5,
|
| 120 |
)
|
| 121 |
-
except (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
raise RuntimeError(
|
| 123 |
"Docker is not available. Please install Docker Desktop or Docker Engine."
|
| 124 |
)
|
|
@@ -154,10 +158,13 @@ class LocalDockerProvider(ContainerProvider):
|
|
| 154 |
|
| 155 |
# Build docker run command
|
| 156 |
cmd = [
|
| 157 |
-
"docker",
|
|
|
|
| 158 |
"-d", # Detached
|
| 159 |
-
"--name",
|
| 160 |
-
|
|
|
|
|
|
|
| 161 |
]
|
| 162 |
|
| 163 |
# Add environment variables
|
|
@@ -169,8 +176,12 @@ class LocalDockerProvider(ContainerProvider):
|
|
| 169 |
cmd.append(image)
|
| 170 |
|
| 171 |
# Run container
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
# Wait a moment for container to start
|
| 176 |
time.sleep(1)
|
|
@@ -222,14 +233,18 @@ class LocalDockerProvider(ContainerProvider):
|
|
| 222 |
TimeoutError: If container doesn't become ready
|
| 223 |
"""
|
| 224 |
import time
|
|
|
|
| 225 |
import requests
|
| 226 |
|
| 227 |
start_time = time.time()
|
| 228 |
health_url = f"{base_url}/health"
|
| 229 |
|
|
|
|
|
|
|
|
|
|
| 230 |
while time.time() - start_time < timeout_s:
|
| 231 |
try:
|
| 232 |
-
response = requests.get(health_url, timeout=2.0)
|
| 233 |
if response.status_code == 200:
|
| 234 |
return
|
| 235 |
except requests.RequestException:
|
|
@@ -273,6 +288,308 @@ class LocalDockerProvider(ContainerProvider):
|
|
| 273 |
return f"{clean_image}-{timestamp}"
|
| 274 |
|
| 275 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
class KubernetesProvider(ContainerProvider):
|
| 277 |
"""
|
| 278 |
Container provider for Kubernetes clusters.
|
|
@@ -286,4 +603,67 @@ class KubernetesProvider(ContainerProvider):
|
|
| 286 |
>>> # Pod running in k8s, accessible via service or port-forward
|
| 287 |
>>> provider.stop_container()
|
| 288 |
"""
|
|
|
|
| 289 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
Container provider abstractions for running environment servers.
|
| 9 |
|
| 10 |
This module provides a pluggable architecture for different container providers
|
| 11 |
+
(local Docker, Kubernetes, cloud providers, etc.) to be used with EnvClient.
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
|
| 16 |
from abc import ABC, abstractmethod
|
| 17 |
+
from typing import Any, Dict, Optional, Sequence
|
| 18 |
|
| 19 |
|
| 20 |
class ContainerProvider(ABC):
|
|
|
|
| 118 |
capture_output=True,
|
| 119 |
timeout=5,
|
| 120 |
)
|
| 121 |
+
except (
|
| 122 |
+
subprocess.CalledProcessError,
|
| 123 |
+
FileNotFoundError,
|
| 124 |
+
subprocess.TimeoutExpired,
|
| 125 |
+
):
|
| 126 |
raise RuntimeError(
|
| 127 |
"Docker is not available. Please install Docker Desktop or Docker Engine."
|
| 128 |
)
|
|
|
|
| 158 |
|
| 159 |
# Build docker run command
|
| 160 |
cmd = [
|
| 161 |
+
"docker",
|
| 162 |
+
"run",
|
| 163 |
"-d", # Detached
|
| 164 |
+
"--name",
|
| 165 |
+
self._container_name,
|
| 166 |
+
"-p",
|
| 167 |
+
f"{port}:8000", # Map port
|
| 168 |
]
|
| 169 |
|
| 170 |
# Add environment variables
|
|
|
|
| 176 |
cmd.append(image)
|
| 177 |
|
| 178 |
# Run container
|
| 179 |
+
try:
|
| 180 |
+
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
| 181 |
+
self._container_id = result.stdout.strip()
|
| 182 |
+
except subprocess.CalledProcessError as e:
|
| 183 |
+
error_msg = f"Failed to start Docker container.\nCommand: {' '.join(cmd)}\nExit code: {e.returncode}\nStderr: {e.stderr}\nStdout: {e.stdout}"
|
| 184 |
+
raise RuntimeError(error_msg) from e
|
| 185 |
|
| 186 |
# Wait a moment for container to start
|
| 187 |
time.sleep(1)
|
|
|
|
| 233 |
TimeoutError: If container doesn't become ready
|
| 234 |
"""
|
| 235 |
import time
|
| 236 |
+
|
| 237 |
import requests
|
| 238 |
|
| 239 |
start_time = time.time()
|
| 240 |
health_url = f"{base_url}/health"
|
| 241 |
|
| 242 |
+
# Bypass proxy for localhost to avoid proxy issues
|
| 243 |
+
proxies = {"http": None, "https": None}
|
| 244 |
+
|
| 245 |
while time.time() - start_time < timeout_s:
|
| 246 |
try:
|
| 247 |
+
response = requests.get(health_url, timeout=2.0, proxies=proxies)
|
| 248 |
if response.status_code == 200:
|
| 249 |
return
|
| 250 |
except requests.RequestException:
|
|
|
|
| 288 |
return f"{clean_image}-{timestamp}"
|
| 289 |
|
| 290 |
|
| 291 |
+
class DockerSwarmProvider(ContainerProvider):
|
| 292 |
+
"""
|
| 293 |
+
Container provider that uses Docker Swarm services for local concurrency.
|
| 294 |
+
|
| 295 |
+
This provider creates a replicated Swarm service backed by the local Docker
|
| 296 |
+
engine. The built-in load-balancer fans requests across the replicas,
|
| 297 |
+
allowing multiple container instances to run concurrently on the developer
|
| 298 |
+
workstation (mirroring the workflow described in the Docker stack docs).
|
| 299 |
+
"""
|
| 300 |
+
|
| 301 |
+
def __init__(
|
| 302 |
+
self,
|
| 303 |
+
*,
|
| 304 |
+
auto_init_swarm: bool = True,
|
| 305 |
+
overlay_network: Optional[str] = None,
|
| 306 |
+
):
|
| 307 |
+
"""
|
| 308 |
+
Args:
|
| 309 |
+
auto_init_swarm: Whether to call ``docker swarm init`` when Swarm
|
| 310 |
+
is not active. Otherwise, user must manually initialize Swarm.
|
| 311 |
+
overlay_network: Optional overlay network name for the service.
|
| 312 |
+
When provided, the network is created with
|
| 313 |
+
``docker network create --driver overlay --attachable`` if it
|
| 314 |
+
does not already exist.
|
| 315 |
+
"""
|
| 316 |
+
self._service_name: Optional[str] = None
|
| 317 |
+
self._service_id: Optional[str] = None
|
| 318 |
+
self._published_port: Optional[int] = None
|
| 319 |
+
self._overlay_network = overlay_network
|
| 320 |
+
self._auto_init_swarm = auto_init_swarm
|
| 321 |
+
|
| 322 |
+
self._ensure_docker_available()
|
| 323 |
+
self._ensure_swarm_initialized()
|
| 324 |
+
if self._overlay_network:
|
| 325 |
+
self._ensure_overlay_network(self._overlay_network)
|
| 326 |
+
|
| 327 |
+
def start_container(
|
| 328 |
+
self,
|
| 329 |
+
image: str,
|
| 330 |
+
port: Optional[int] = None,
|
| 331 |
+
env_vars: Optional[Dict[str, str]] = None,
|
| 332 |
+
**kwargs: Any,
|
| 333 |
+
) -> str:
|
| 334 |
+
"""
|
| 335 |
+
Start (or scale) a Swarm service for the given image.
|
| 336 |
+
|
| 337 |
+
Supported kwargs:
|
| 338 |
+
replicas (int): Number of container replicas (default: 2).
|
| 339 |
+
cpu_limit (float | str): CPU limit passed to ``--limit-cpu``.
|
| 340 |
+
memory_limit (str): Memory limit passed to ``--limit-memory``.
|
| 341 |
+
constraints (Sequence[str]): Placement constraints.
|
| 342 |
+
labels (Dict[str, str]): Service labels.
|
| 343 |
+
command (Sequence[str] | str): Override container command.
|
| 344 |
+
"""
|
| 345 |
+
import shlex
|
| 346 |
+
import subprocess
|
| 347 |
+
import time
|
| 348 |
+
|
| 349 |
+
allowed_kwargs = {
|
| 350 |
+
"replicas",
|
| 351 |
+
"cpu_limit",
|
| 352 |
+
"memory_limit",
|
| 353 |
+
"constraints",
|
| 354 |
+
"labels",
|
| 355 |
+
"command",
|
| 356 |
+
}
|
| 357 |
+
unknown = set(kwargs) - allowed_kwargs
|
| 358 |
+
if unknown:
|
| 359 |
+
raise ValueError(f"Unsupported kwargs for DockerSwarmProvider: {unknown}")
|
| 360 |
+
|
| 361 |
+
replicas = int(kwargs.get("replicas", 2))
|
| 362 |
+
cpu_limit = kwargs.get("cpu_limit")
|
| 363 |
+
memory_limit = kwargs.get("memory_limit")
|
| 364 |
+
constraints: Optional[Sequence[str]] = kwargs.get("constraints")
|
| 365 |
+
labels: Optional[Dict[str, str]] = kwargs.get("labels")
|
| 366 |
+
command_override = kwargs.get("command")
|
| 367 |
+
|
| 368 |
+
if port is None:
|
| 369 |
+
port = self._find_available_port()
|
| 370 |
+
|
| 371 |
+
self._service_name = self._generate_service_name(image)
|
| 372 |
+
self._published_port = port
|
| 373 |
+
|
| 374 |
+
cmd = [
|
| 375 |
+
"docker",
|
| 376 |
+
"service",
|
| 377 |
+
"create",
|
| 378 |
+
"--detach",
|
| 379 |
+
"--name",
|
| 380 |
+
self._service_name,
|
| 381 |
+
"--replicas",
|
| 382 |
+
str(max(1, replicas)),
|
| 383 |
+
"--publish",
|
| 384 |
+
f"{port}:8000",
|
| 385 |
+
]
|
| 386 |
+
|
| 387 |
+
if self._overlay_network:
|
| 388 |
+
cmd.extend(["--network", self._overlay_network])
|
| 389 |
+
|
| 390 |
+
if env_vars:
|
| 391 |
+
for key, value in env_vars.items():
|
| 392 |
+
cmd.extend(["--env", f"{key}={value}"])
|
| 393 |
+
|
| 394 |
+
if cpu_limit is not None:
|
| 395 |
+
cmd.extend(["--limit-cpu", str(cpu_limit)])
|
| 396 |
+
|
| 397 |
+
if memory_limit is not None:
|
| 398 |
+
cmd.extend(["--limit-memory", str(memory_limit)])
|
| 399 |
+
|
| 400 |
+
if constraints:
|
| 401 |
+
for constraint in constraints:
|
| 402 |
+
cmd.extend(["--constraint", constraint])
|
| 403 |
+
|
| 404 |
+
if labels:
|
| 405 |
+
for key, value in labels.items():
|
| 406 |
+
cmd.extend(["--label", f"{key}={value}"])
|
| 407 |
+
|
| 408 |
+
cmd.append(image)
|
| 409 |
+
|
| 410 |
+
if command_override:
|
| 411 |
+
if isinstance(command_override, str):
|
| 412 |
+
cmd.extend(shlex.split(command_override))
|
| 413 |
+
else:
|
| 414 |
+
cmd.extend(command_override)
|
| 415 |
+
|
| 416 |
+
try:
|
| 417 |
+
result = subprocess.run(
|
| 418 |
+
cmd,
|
| 419 |
+
capture_output=True,
|
| 420 |
+
text=True,
|
| 421 |
+
check=True,
|
| 422 |
+
)
|
| 423 |
+
self._service_id = result.stdout.strip()
|
| 424 |
+
except subprocess.CalledProcessError as e:
|
| 425 |
+
error_msg = (
|
| 426 |
+
"Failed to start Docker Swarm service.\n"
|
| 427 |
+
f"Command: {' '.join(cmd)}\n"
|
| 428 |
+
f"Exit code: {e.returncode}\n"
|
| 429 |
+
f"Stdout: {e.stdout}\n"
|
| 430 |
+
f"Stderr: {e.stderr}"
|
| 431 |
+
)
|
| 432 |
+
raise RuntimeError(error_msg) from e
|
| 433 |
+
|
| 434 |
+
# Give Swarm a brief moment to schedule the tasks.
|
| 435 |
+
time.sleep(1.0)
|
| 436 |
+
|
| 437 |
+
return f"http://localhost:{port}"
|
| 438 |
+
|
| 439 |
+
def stop_container(self) -> None:
|
| 440 |
+
"""
|
| 441 |
+
Remove the Swarm service (and keep the Swarm manager running).
|
| 442 |
+
"""
|
| 443 |
+
if not self._service_name:
|
| 444 |
+
return
|
| 445 |
+
|
| 446 |
+
import subprocess
|
| 447 |
+
|
| 448 |
+
try:
|
| 449 |
+
subprocess.run(
|
| 450 |
+
["docker", "service", "rm", self._service_name],
|
| 451 |
+
capture_output=True,
|
| 452 |
+
check=True,
|
| 453 |
+
timeout=10,
|
| 454 |
+
)
|
| 455 |
+
except subprocess.CalledProcessError:
|
| 456 |
+
# Service may already be gone; ignore.
|
| 457 |
+
pass
|
| 458 |
+
finally:
|
| 459 |
+
self._service_name = None
|
| 460 |
+
self._service_id = None
|
| 461 |
+
self._published_port = None
|
| 462 |
+
|
| 463 |
+
def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None:
|
| 464 |
+
"""
|
| 465 |
+
Wait for at least one replica to become healthy by polling /health.
|
| 466 |
+
|
| 467 |
+
Note: With Swarm's load balancer, requests round-robin across replicas,
|
| 468 |
+
so this only verifies that at least one replica is responding. Some
|
| 469 |
+
replicas may still be starting when this returns.
|
| 470 |
+
"""
|
| 471 |
+
import time
|
| 472 |
+
|
| 473 |
+
import requests
|
| 474 |
+
|
| 475 |
+
deadline = time.time() + timeout_s
|
| 476 |
+
health_url = f"{base_url}/health"
|
| 477 |
+
|
| 478 |
+
# Bypass proxy for localhost to avoid proxy issues
|
| 479 |
+
proxies = {"http": None, "https": None}
|
| 480 |
+
|
| 481 |
+
while time.time() < deadline:
|
| 482 |
+
try:
|
| 483 |
+
response = requests.get(health_url, timeout=2.0, proxies=proxies)
|
| 484 |
+
if response.status_code == 200:
|
| 485 |
+
return
|
| 486 |
+
except requests.RequestException:
|
| 487 |
+
pass
|
| 488 |
+
|
| 489 |
+
time.sleep(0.5)
|
| 490 |
+
|
| 491 |
+
raise TimeoutError(
|
| 492 |
+
f"Swarm service at {base_url} did not become ready within {timeout_s}s"
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
def _ensure_docker_available(self) -> None:
|
| 496 |
+
import subprocess
|
| 497 |
+
|
| 498 |
+
try:
|
| 499 |
+
subprocess.run(
|
| 500 |
+
["docker", "version"],
|
| 501 |
+
check=True,
|
| 502 |
+
capture_output=True,
|
| 503 |
+
timeout=5,
|
| 504 |
+
)
|
| 505 |
+
except (
|
| 506 |
+
subprocess.CalledProcessError,
|
| 507 |
+
FileNotFoundError,
|
| 508 |
+
subprocess.TimeoutExpired,
|
| 509 |
+
) as exc:
|
| 510 |
+
raise RuntimeError(
|
| 511 |
+
"Docker is not available. Please install Docker Desktop or Docker Engine."
|
| 512 |
+
) from exc
|
| 513 |
+
|
| 514 |
+
def _ensure_swarm_initialized(self) -> None:
|
| 515 |
+
import subprocess
|
| 516 |
+
|
| 517 |
+
try:
|
| 518 |
+
result = subprocess.run(
|
| 519 |
+
["docker", "info", "--format", "{{.Swarm.LocalNodeState}}"],
|
| 520 |
+
capture_output=True,
|
| 521 |
+
text=True,
|
| 522 |
+
check=True,
|
| 523 |
+
timeout=5,
|
| 524 |
+
)
|
| 525 |
+
state = result.stdout.strip().lower()
|
| 526 |
+
if state == "active":
|
| 527 |
+
return
|
| 528 |
+
except subprocess.CalledProcessError:
|
| 529 |
+
state = "unknown"
|
| 530 |
+
|
| 531 |
+
if not self._auto_init_swarm:
|
| 532 |
+
raise RuntimeError(
|
| 533 |
+
f"Docker Swarm is not active (state={state}). Enable Swarm manually or pass auto_init_swarm=True."
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
try:
|
| 537 |
+
subprocess.run(
|
| 538 |
+
["docker", "swarm", "init"],
|
| 539 |
+
check=True,
|
| 540 |
+
capture_output=True,
|
| 541 |
+
timeout=10,
|
| 542 |
+
)
|
| 543 |
+
except subprocess.CalledProcessError as e:
|
| 544 |
+
raise RuntimeError("Failed to initialize Docker Swarm") from e
|
| 545 |
+
|
| 546 |
+
def _ensure_overlay_network(self, network: str) -> None:
|
| 547 |
+
import subprocess
|
| 548 |
+
|
| 549 |
+
inspect = subprocess.run(
|
| 550 |
+
["docker", "network", "inspect", network],
|
| 551 |
+
capture_output=True,
|
| 552 |
+
text=True,
|
| 553 |
+
check=False,
|
| 554 |
+
)
|
| 555 |
+
if inspect.returncode == 0:
|
| 556 |
+
return
|
| 557 |
+
|
| 558 |
+
try:
|
| 559 |
+
subprocess.run(
|
| 560 |
+
[
|
| 561 |
+
"docker",
|
| 562 |
+
"network",
|
| 563 |
+
"create",
|
| 564 |
+
"--driver",
|
| 565 |
+
"overlay",
|
| 566 |
+
"--attachable",
|
| 567 |
+
network,
|
| 568 |
+
],
|
| 569 |
+
check=True,
|
| 570 |
+
capture_output=True,
|
| 571 |
+
timeout=10,
|
| 572 |
+
)
|
| 573 |
+
except subprocess.CalledProcessError as e:
|
| 574 |
+
raise RuntimeError(f"Failed to create overlay network '{network}'") from e
|
| 575 |
+
|
| 576 |
+
def _find_available_port(self) -> int:
|
| 577 |
+
import socket
|
| 578 |
+
|
| 579 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 580 |
+
s.bind(("", 0))
|
| 581 |
+
s.listen(1)
|
| 582 |
+
port = s.getsockname()[1]
|
| 583 |
+
return port
|
| 584 |
+
|
| 585 |
+
def _generate_service_name(self, image: str) -> str:
|
| 586 |
+
import time
|
| 587 |
+
|
| 588 |
+
clean_image = image.split("/")[-1].split(":")[0]
|
| 589 |
+
timestamp = int(time.time() * 1000)
|
| 590 |
+
return f"{clean_image}-swarm-{timestamp}"
|
| 591 |
+
|
| 592 |
+
|
| 593 |
class KubernetesProvider(ContainerProvider):
|
| 594 |
"""
|
| 595 |
Container provider for Kubernetes clusters.
|
|
|
|
| 603 |
>>> # Pod running in k8s, accessible via service or port-forward
|
| 604 |
>>> provider.stop_container()
|
| 605 |
"""
|
| 606 |
+
|
| 607 |
pass
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
class RuntimeProvider(ABC):
|
| 611 |
+
"""
|
| 612 |
+
Abstract base class for runtime providers that are not container providers.
|
| 613 |
+
Providers implement this interface to support different runtime platforms:
|
| 614 |
+
- UVProvider: Runs environments via `uv run`
|
| 615 |
+
|
| 616 |
+
The provider manages a single runtime lifecycle and provides the base URL
|
| 617 |
+
for connecting to it.
|
| 618 |
+
|
| 619 |
+
Example:
|
| 620 |
+
>>> provider = UVProvider(project_path="/path/to/env")
|
| 621 |
+
>>> base_url = provider.start()
|
| 622 |
+
>>> print(base_url) # http://localhost:8000
|
| 623 |
+
>>> provider.stop()
|
| 624 |
+
"""
|
| 625 |
+
|
| 626 |
+
@abstractmethod
|
| 627 |
+
def start(
|
| 628 |
+
self,
|
| 629 |
+
port: Optional[int] = None,
|
| 630 |
+
env_vars: Optional[Dict[str, str]] = None,
|
| 631 |
+
**kwargs: Any,
|
| 632 |
+
) -> str:
|
| 633 |
+
"""
|
| 634 |
+
Start a runtime from the specified image.
|
| 635 |
+
|
| 636 |
+
Args:
|
| 637 |
+
image: Runtime image name
|
| 638 |
+
port: Port to expose (if None, provider chooses)
|
| 639 |
+
env_vars: Environment variables for the runtime
|
| 640 |
+
**kwargs: Additional runtime options
|
| 641 |
+
"""
|
| 642 |
+
|
| 643 |
+
@abstractmethod
|
| 644 |
+
def stop(self) -> None:
|
| 645 |
+
"""
|
| 646 |
+
Stop the runtime.
|
| 647 |
+
"""
|
| 648 |
+
pass
|
| 649 |
+
|
| 650 |
+
@abstractmethod
|
| 651 |
+
def wait_for_ready(self, timeout_s: float = 30.0) -> None:
|
| 652 |
+
"""
|
| 653 |
+
Wait for the runtime to be ready to accept requests.
|
| 654 |
+
"""
|
| 655 |
+
pass
|
| 656 |
+
|
| 657 |
+
def __enter__(self) -> "RuntimeProvider":
|
| 658 |
+
"""
|
| 659 |
+
Enter the runtime provider.
|
| 660 |
+
"""
|
| 661 |
+
self.start()
|
| 662 |
+
return self
|
| 663 |
+
|
| 664 |
+
def __exit__(self, exc_type, exc, tb) -> None:
|
| 665 |
+
"""
|
| 666 |
+
Exit the runtime provider.
|
| 667 |
+
"""
|
| 668 |
+
self.stop()
|
| 669 |
+
return False
|
src/core/containers/runtime/uv_provider.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Providers for launching ASGI applications via ``uv run``."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import socket
|
| 7 |
+
import subprocess
|
| 8 |
+
import time
|
| 9 |
+
from typing import Dict, Optional
|
| 10 |
+
|
| 11 |
+
import requests
|
| 12 |
+
|
| 13 |
+
from .providers import RuntimeProvider
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _check_uv_installed() -> None:
|
| 17 |
+
try:
|
| 18 |
+
subprocess.check_output(["uv", "--version"])
|
| 19 |
+
except FileNotFoundError as exc:
|
| 20 |
+
raise RuntimeError(
|
| 21 |
+
"`uv` executable not found. Install uv from https://docs.astral.sh and ensure it is on PATH."
|
| 22 |
+
) from exc
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _find_free_port() -> int:
|
| 26 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
| 27 |
+
sock.bind(("", 0))
|
| 28 |
+
sock.listen(1)
|
| 29 |
+
return sock.getsockname()[1]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _create_uv_command(
|
| 33 |
+
*,
|
| 34 |
+
host: str,
|
| 35 |
+
port: int,
|
| 36 |
+
reload: bool,
|
| 37 |
+
workers: int,
|
| 38 |
+
app: str,
|
| 39 |
+
project_path: str,
|
| 40 |
+
) -> list[str]:
|
| 41 |
+
command: list[str] = ["uv", "run", "--isolated", "--project", project_path]
|
| 42 |
+
|
| 43 |
+
command.append("--")
|
| 44 |
+
command.extend(
|
| 45 |
+
[
|
| 46 |
+
"uvicorn",
|
| 47 |
+
app,
|
| 48 |
+
"--host",
|
| 49 |
+
host,
|
| 50 |
+
"--port",
|
| 51 |
+
str(port),
|
| 52 |
+
"--workers",
|
| 53 |
+
str(workers),
|
| 54 |
+
]
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
if reload:
|
| 58 |
+
command.append("--reload")
|
| 59 |
+
|
| 60 |
+
return command
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _poll_health(health_url: str, timeout_s: float) -> None:
|
| 64 |
+
"""Poll a health endpoint until it returns HTTP 200 or times out."""
|
| 65 |
+
|
| 66 |
+
deadline = time.time() + timeout_s
|
| 67 |
+
while time.time() < deadline:
|
| 68 |
+
try:
|
| 69 |
+
timeout = max(0.0001, min(deadline - time.time(), 2.0))
|
| 70 |
+
response = requests.get(health_url, timeout=timeout)
|
| 71 |
+
if response.status_code == 200:
|
| 72 |
+
return
|
| 73 |
+
except requests.RequestException:
|
| 74 |
+
continue
|
| 75 |
+
|
| 76 |
+
time.sleep(0.5)
|
| 77 |
+
|
| 78 |
+
raise TimeoutError(f"Server did not become ready within {timeout_s:.1f} seconds")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class UVProvider(RuntimeProvider):
|
| 82 |
+
"""
|
| 83 |
+
RuntimeProvider implementation backed by ``uv run``.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
project_path: Local path to a uv project (passed to ``uv run --project``)
|
| 87 |
+
app: ASGI application path for uvicorn (defaults to ``server.app:app``)
|
| 88 |
+
host: Host interface to bind to (defaults to ``0.0.0.0``)
|
| 89 |
+
reload: Whether to enable uvicorn's reload mode
|
| 90 |
+
env_vars: Environment variables to pass through to the spawned process
|
| 91 |
+
context_timeout_s: How long to wait for the environment to become ready
|
| 92 |
+
|
| 93 |
+
Example:
|
| 94 |
+
>>> provider = UVProvider(project_path="/path/to/env")
|
| 95 |
+
>>> base_url = provider.start()
|
| 96 |
+
>>> print(base_url) # http://localhost:8000
|
| 97 |
+
>>> # Use the environment via base_url
|
| 98 |
+
>>> provider.stop()
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
def __init__(
|
| 102 |
+
self,
|
| 103 |
+
*,
|
| 104 |
+
project_path: str,
|
| 105 |
+
app: str = "server.app:app",
|
| 106 |
+
host: str = "0.0.0.0",
|
| 107 |
+
reload: bool = False,
|
| 108 |
+
env_vars: Optional[Dict[str, str]] = None,
|
| 109 |
+
context_timeout_s: float = 60.0,
|
| 110 |
+
):
|
| 111 |
+
"""Initialize the UVProvider."""
|
| 112 |
+
self.project_path = os.path.abspath(project_path)
|
| 113 |
+
self.app = app
|
| 114 |
+
self.host = host
|
| 115 |
+
self.reload = reload
|
| 116 |
+
self.env_vars = env_vars
|
| 117 |
+
self.context_timeout_s = context_timeout_s
|
| 118 |
+
_check_uv_installed()
|
| 119 |
+
self._process = None
|
| 120 |
+
self._base_url = None
|
| 121 |
+
|
| 122 |
+
def start(
|
| 123 |
+
self,
|
| 124 |
+
port: Optional[int] = None,
|
| 125 |
+
env_vars: Optional[Dict[str, str]] = None,
|
| 126 |
+
workers: int = 1,
|
| 127 |
+
**_: Dict[str, str],
|
| 128 |
+
) -> str:
|
| 129 |
+
"""
|
| 130 |
+
Start the environment via `uv run`.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
port: The port to bind the environment to
|
| 134 |
+
env_vars: Environment variables to pass to the environment
|
| 135 |
+
workers: The number of workers to use
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
The base URL of the environment
|
| 139 |
+
|
| 140 |
+
Raises:
|
| 141 |
+
RuntimeError: If the environment is already running
|
| 142 |
+
"""
|
| 143 |
+
if self._process is not None and self._process.poll() is None:
|
| 144 |
+
raise RuntimeError("UVProvider is already running")
|
| 145 |
+
|
| 146 |
+
bind_port = port or _find_free_port()
|
| 147 |
+
|
| 148 |
+
command = _create_uv_command(
|
| 149 |
+
host=self.host,
|
| 150 |
+
port=bind_port,
|
| 151 |
+
reload=self.reload,
|
| 152 |
+
workers=workers,
|
| 153 |
+
app=self.app,
|
| 154 |
+
project_path=self.project_path,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
env = os.environ.copy()
|
| 158 |
+
|
| 159 |
+
if self.env_vars:
|
| 160 |
+
env.update(self.env_vars)
|
| 161 |
+
if env_vars:
|
| 162 |
+
env.update(env_vars)
|
| 163 |
+
|
| 164 |
+
try:
|
| 165 |
+
self._process = subprocess.Popen(command, env=env)
|
| 166 |
+
except OSError as exc:
|
| 167 |
+
raise RuntimeError(f"Failed to launch `uv run`: {exc}") from exc
|
| 168 |
+
|
| 169 |
+
client_host = "127.0.0.1" if self.host in {"0.0.0.0", "::"} else self.host
|
| 170 |
+
self._base_url = f"http://{client_host}:{bind_port}"
|
| 171 |
+
return self._base_url
|
| 172 |
+
|
| 173 |
+
def wait_for_ready(self, timeout_s: float = 60.0) -> None:
|
| 174 |
+
"""
|
| 175 |
+
Wait for the environment to become ready.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
timeout_s: The timeout to wait for the environment to become ready
|
| 179 |
+
|
| 180 |
+
Raises:
|
| 181 |
+
RuntimeError: If the environment is not running
|
| 182 |
+
TimeoutError: If the environment does not become ready within the timeout
|
| 183 |
+
"""
|
| 184 |
+
if self._process and self._process.poll() is not None:
|
| 185 |
+
code = self._process.returncode
|
| 186 |
+
raise RuntimeError(f"uv process exited prematurely with code {code}")
|
| 187 |
+
|
| 188 |
+
_poll_health(f"{self._base_url}/health", timeout_s=timeout_s)
|
| 189 |
+
|
| 190 |
+
def stop(self) -> None:
|
| 191 |
+
"""
|
| 192 |
+
Stop the environment.
|
| 193 |
+
|
| 194 |
+
Raises:
|
| 195 |
+
RuntimeError: If the environment is not running
|
| 196 |
+
"""
|
| 197 |
+
if self._process is None:
|
| 198 |
+
return
|
| 199 |
+
|
| 200 |
+
if self._process.poll() is None:
|
| 201 |
+
self._process.terminate()
|
| 202 |
+
try:
|
| 203 |
+
self._process.wait(timeout=10.0)
|
| 204 |
+
except subprocess.TimeoutExpired:
|
| 205 |
+
self._process.kill()
|
| 206 |
+
self._process.wait(timeout=5.0)
|
| 207 |
+
|
| 208 |
+
self._process = None
|
| 209 |
+
self._base_url = None
|
| 210 |
+
|
| 211 |
+
@property
|
| 212 |
+
def base_url(self) -> str:
|
| 213 |
+
"""
|
| 214 |
+
The base URL of the environment.
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
The base URL of the environment
|
| 218 |
+
|
| 219 |
+
Raises:
|
| 220 |
+
RuntimeError: If the environment is not running
|
| 221 |
+
"""
|
| 222 |
+
if self._base_url is None:
|
| 223 |
+
raise RuntimeError("UVProvider has not been started")
|
| 224 |
+
return self._base_url
|
src/core/containers/test_local_docker_provider.py
CHANGED
|
@@ -16,8 +16,8 @@ from pathlib import Path
|
|
| 16 |
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 17 |
|
| 18 |
import requests
|
|
|
|
| 19 |
|
| 20 |
-
from core.containers.runtime import LocalDockerProvider
|
| 21 |
|
| 22 |
# TODO: Remove this test or make it a functional test sicne this will be tested in e2e test for echo env
|
| 23 |
def test_local_docker_provider():
|
|
@@ -87,7 +87,9 @@ def test_local_docker_provider():
|
|
| 87 |
print(f" Length: {data['observation']['message_length']}")
|
| 88 |
print(f" Reward: {data['reward']}")
|
| 89 |
assert response.status_code == 200
|
| 90 |
-
assert
|
|
|
|
|
|
|
| 91 |
assert data["observation"]["message_length"] == 31
|
| 92 |
print("✓ Step test passed\n")
|
| 93 |
|
|
@@ -107,11 +109,11 @@ def test_local_docker_provider():
|
|
| 107 |
for i in range(3):
|
| 108 |
response = requests.post(
|
| 109 |
f"{base_url}/step",
|
| 110 |
-
json={"action": {"message": f"Message {i+1}"}},
|
| 111 |
headers={"Content-Type": "application/json"},
|
| 112 |
)
|
| 113 |
assert response.status_code == 200
|
| 114 |
-
print(f" Step {i+1}: ✓")
|
| 115 |
|
| 116 |
# Check state updated
|
| 117 |
response = requests.get(f"{base_url}/state")
|
|
@@ -130,6 +132,7 @@ def test_local_docker_provider():
|
|
| 130 |
except Exception as e:
|
| 131 |
print(f"\n❌ Test failed: {e}")
|
| 132 |
import traceback
|
|
|
|
| 133 |
traceback.print_exc()
|
| 134 |
return False
|
| 135 |
|
|
@@ -197,8 +200,7 @@ def test_provider_with_env_vars():
|
|
| 197 |
|
| 198 |
print("Starting container with environment variables...")
|
| 199 |
base_url = provider.start_container(
|
| 200 |
-
"echo-env:latest",
|
| 201 |
-
env_vars={"DEBUG": "true", "LOG_LEVEL": "info"}
|
| 202 |
)
|
| 203 |
print(f"✓ Started at: {base_url}")
|
| 204 |
|
|
|
|
| 16 |
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 17 |
|
| 18 |
import requests
|
| 19 |
+
from openenv.core.containers.runtime import LocalDockerProvider
|
| 20 |
|
|
|
|
| 21 |
|
| 22 |
# TODO: Remove this test or make it a functional test sicne this will be tested in e2e test for echo env
|
| 23 |
def test_local_docker_provider():
|
|
|
|
| 87 |
print(f" Length: {data['observation']['message_length']}")
|
| 88 |
print(f" Reward: {data['reward']}")
|
| 89 |
assert response.status_code == 200
|
| 90 |
+
assert (
|
| 91 |
+
data["observation"]["echoed_message"] == "Hello from LocalDockerProvider!"
|
| 92 |
+
)
|
| 93 |
assert data["observation"]["message_length"] == 31
|
| 94 |
print("✓ Step test passed\n")
|
| 95 |
|
|
|
|
| 109 |
for i in range(3):
|
| 110 |
response = requests.post(
|
| 111 |
f"{base_url}/step",
|
| 112 |
+
json={"action": {"message": f"Message {i + 1}"}},
|
| 113 |
headers={"Content-Type": "application/json"},
|
| 114 |
)
|
| 115 |
assert response.status_code == 200
|
| 116 |
+
print(f" Step {i + 1}: ✓")
|
| 117 |
|
| 118 |
# Check state updated
|
| 119 |
response = requests.get(f"{base_url}/state")
|
|
|
|
| 132 |
except Exception as e:
|
| 133 |
print(f"\n❌ Test failed: {e}")
|
| 134 |
import traceback
|
| 135 |
+
|
| 136 |
traceback.print_exc()
|
| 137 |
return False
|
| 138 |
|
|
|
|
| 200 |
|
| 201 |
print("Starting container with environment variables...")
|
| 202 |
base_url = provider.start_container(
|
| 203 |
+
"echo-env:latest", env_vars={"DEBUG": "true", "LOG_LEVEL": "info"}
|
|
|
|
| 204 |
)
|
| 205 |
print(f"✓ Started at: {base_url}")
|
| 206 |
|
src/core/env_client.py
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Environment client for persistent sessions.
|
| 9 |
+
|
| 10 |
+
This module provides a WebSocket-based client that maintains a persistent connection
|
| 11 |
+
to an environment server, enabling efficient multi-step interactions without
|
| 12 |
+
the overhead of HTTP request/response cycles.
|
| 13 |
+
|
| 14 |
+
The client is async by default. For synchronous usage, use the `.sync()` method
|
| 15 |
+
to get a `SyncEnvClient` wrapper.
|
| 16 |
+
|
| 17 |
+
Example (async):
|
| 18 |
+
>>> async with GenericEnvClient(base_url="ws://localhost:8000") as env:
|
| 19 |
+
... result = await env.reset()
|
| 20 |
+
... result = await env.step({"code": "print('hello')"})
|
| 21 |
+
|
| 22 |
+
Example (sync wrapper):
|
| 23 |
+
>>> env = GenericEnvClient(base_url="ws://localhost:8000").sync()
|
| 24 |
+
>>> with env:
|
| 25 |
+
... result = env.reset()
|
| 26 |
+
... result = env.step({"code": "print('hello')"})
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from __future__ import annotations
|
| 30 |
+
|
| 31 |
+
import asyncio
|
| 32 |
+
import json
|
| 33 |
+
import os
|
| 34 |
+
from abc import ABC, abstractmethod
|
| 35 |
+
from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar
|
| 36 |
+
|
| 37 |
+
from .client_types import StateT, StepResult
|
| 38 |
+
from .containers.runtime import LocalDockerProvider, UVProvider
|
| 39 |
+
from .utils import convert_to_ws_url
|
| 40 |
+
|
| 41 |
+
if TYPE_CHECKING:
|
| 42 |
+
from websockets.asyncio.client import ClientConnection
|
| 43 |
+
|
| 44 |
+
from .containers.runtime import ContainerProvider, RuntimeProvider
|
| 45 |
+
from .sync_client import SyncEnvClient
|
| 46 |
+
|
| 47 |
+
from websockets.asyncio.client import connect as ws_connect
|
| 48 |
+
|
| 49 |
+
ActT = TypeVar("ActT")
|
| 50 |
+
ObsT = TypeVar("ObsT")
|
| 51 |
+
EnvClientT = TypeVar("EnvClientT", bound="EnvClient")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class EnvClient(ABC, Generic[ActT, ObsT, StateT]):
|
| 55 |
+
"""
|
| 56 |
+
Async environment client for persistent sessions.
|
| 57 |
+
|
| 58 |
+
This client maintains a persistent WebSocket connection to an environment
|
| 59 |
+
server, enabling efficient multi-step interactions. Each client instance
|
| 60 |
+
corresponds to a dedicated environment session on the server.
|
| 61 |
+
|
| 62 |
+
The client is async by default. For synchronous usage, use the `.sync()`
|
| 63 |
+
method to get a `SyncEnvClient` wrapper.
|
| 64 |
+
|
| 65 |
+
Features:
|
| 66 |
+
- Lower latency for sequential interactions
|
| 67 |
+
- Session state is maintained server-side
|
| 68 |
+
- Better suited for long-running episodes
|
| 69 |
+
- Async by default for modern Python async/await patterns
|
| 70 |
+
|
| 71 |
+
Example (async):
|
| 72 |
+
>>> from envs.coding_env.client import CodingEnv
|
| 73 |
+
>>>
|
| 74 |
+
>>> # Connect to a server using async context manager
|
| 75 |
+
>>> async with CodingEnv(base_url="ws://localhost:8000") as env:
|
| 76 |
+
... result = await env.reset(seed=42)
|
| 77 |
+
... while not result.done:
|
| 78 |
+
... action = agent.predict(result.observation)
|
| 79 |
+
... result = await env.step(action)
|
| 80 |
+
|
| 81 |
+
Example (sync wrapper):
|
| 82 |
+
>>> env = CodingEnv(base_url="ws://localhost:8000").sync()
|
| 83 |
+
>>> with env:
|
| 84 |
+
... result = env.reset(seed=42)
|
| 85 |
+
... result = env.step(action)
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
base_url: str,
|
| 91 |
+
connect_timeout_s: float = 10.0,
|
| 92 |
+
message_timeout_s: float = 60.0,
|
| 93 |
+
max_message_size_mb: float = 100.0,
|
| 94 |
+
provider: Optional["ContainerProvider | RuntimeProvider"] = None,
|
| 95 |
+
mode: Optional[str] = None,
|
| 96 |
+
):
|
| 97 |
+
"""
|
| 98 |
+
Initialize environment client.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
base_url: Base URL of the environment server (http:// or ws://).
|
| 102 |
+
Will be converted to ws:// if http:// is provided.
|
| 103 |
+
connect_timeout_s: Timeout for establishing WebSocket connection
|
| 104 |
+
message_timeout_s: Timeout for receiving responses to messages
|
| 105 |
+
max_message_size_mb: Maximum WebSocket message size in megabytes.
|
| 106 |
+
Default 100MB to handle large observations (screenshots, DOM, etc.)
|
| 107 |
+
provider: Optional container/runtime provider for lifecycle management.
|
| 108 |
+
Can be a ContainerProvider (Docker) or RuntimeProvider (UV).
|
| 109 |
+
mode: Communication mode: 'simulation' for Gym-style API (default) or
|
| 110 |
+
'production' for MCP JSON-RPC protocol. Can also be set via the
|
| 111 |
+
OPENENV_CLIENT_MODE environment variable. Constructor parameter
|
| 112 |
+
takes precedence over environment variable. Case-insensitive.
|
| 113 |
+
"""
|
| 114 |
+
# Determine mode (constructor > env var > default)
|
| 115 |
+
if mode is None:
|
| 116 |
+
mode = os.environ.get("OPENENV_CLIENT_MODE", "simulation")
|
| 117 |
+
|
| 118 |
+
# Normalize and validate mode
|
| 119 |
+
mode = mode.lower()
|
| 120 |
+
if mode not in ("simulation", "production"):
|
| 121 |
+
raise ValueError(
|
| 122 |
+
f"Invalid mode: '{mode}'. Must be 'simulation' or 'production'. "
|
| 123 |
+
f"Set via constructor parameter or OPENENV_CLIENT_MODE environment variable."
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# Store mode (use object.__setattr__ to bypass immutability)
|
| 127 |
+
object.__setattr__(self, "_mode", mode)
|
| 128 |
+
|
| 129 |
+
# Convert HTTP URL to WebSocket URL
|
| 130 |
+
ws_url = convert_to_ws_url(base_url)
|
| 131 |
+
|
| 132 |
+
self._ws_url = f"{ws_url}/ws"
|
| 133 |
+
self._connect_timeout = connect_timeout_s
|
| 134 |
+
self._message_timeout = message_timeout_s
|
| 135 |
+
self._max_message_size = int(
|
| 136 |
+
max_message_size_mb * 1024 * 1024
|
| 137 |
+
) # Convert MB to bytes
|
| 138 |
+
self._provider = provider
|
| 139 |
+
self._ws: Optional[ClientConnection] = None
|
| 140 |
+
|
| 141 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
| 142 |
+
"""Prevent modification of _mode after initialization."""
|
| 143 |
+
if name == "_mode" and hasattr(self, "_mode"):
|
| 144 |
+
raise AttributeError("Cannot modify mode after initialization")
|
| 145 |
+
super().__setattr__(name, value)
|
| 146 |
+
|
| 147 |
+
async def connect(self) -> "EnvClient":
|
| 148 |
+
"""
|
| 149 |
+
Establish WebSocket connection to the server.
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
self for method chaining
|
| 153 |
+
|
| 154 |
+
Raises:
|
| 155 |
+
ConnectionError: If connection cannot be established
|
| 156 |
+
"""
|
| 157 |
+
if self._ws is not None:
|
| 158 |
+
return self
|
| 159 |
+
|
| 160 |
+
# Bypass proxy for localhost connections
|
| 161 |
+
ws_url_lower = self._ws_url.lower()
|
| 162 |
+
is_localhost = "localhost" in ws_url_lower or "127.0.0.1" in ws_url_lower
|
| 163 |
+
|
| 164 |
+
old_no_proxy = os.environ.get("NO_PROXY")
|
| 165 |
+
if is_localhost:
|
| 166 |
+
# Set NO_PROXY to bypass proxy for localhost
|
| 167 |
+
current_no_proxy = old_no_proxy or ""
|
| 168 |
+
if "localhost" not in current_no_proxy.lower():
|
| 169 |
+
os.environ["NO_PROXY"] = (
|
| 170 |
+
f"{current_no_proxy},localhost,127.0.0.1"
|
| 171 |
+
if current_no_proxy
|
| 172 |
+
else "localhost,127.0.0.1"
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
try:
|
| 176 |
+
self._ws = await ws_connect(
|
| 177 |
+
self._ws_url,
|
| 178 |
+
open_timeout=self._connect_timeout,
|
| 179 |
+
max_size=self._max_message_size,
|
| 180 |
+
)
|
| 181 |
+
except Exception as e:
|
| 182 |
+
raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e
|
| 183 |
+
finally:
|
| 184 |
+
# Restore original NO_PROXY value
|
| 185 |
+
if is_localhost:
|
| 186 |
+
if old_no_proxy is None:
|
| 187 |
+
os.environ.pop("NO_PROXY", None)
|
| 188 |
+
else:
|
| 189 |
+
os.environ["NO_PROXY"] = old_no_proxy
|
| 190 |
+
|
| 191 |
+
return self
|
| 192 |
+
|
| 193 |
+
async def disconnect(self) -> None:
|
| 194 |
+
"""Close the WebSocket connection."""
|
| 195 |
+
if self._ws is not None:
|
| 196 |
+
try:
|
| 197 |
+
# Send close message
|
| 198 |
+
await self._send({"type": "close"})
|
| 199 |
+
except Exception:
|
| 200 |
+
pass # Best effort
|
| 201 |
+
try:
|
| 202 |
+
await self._ws.close()
|
| 203 |
+
except Exception:
|
| 204 |
+
pass
|
| 205 |
+
self._ws = None
|
| 206 |
+
|
| 207 |
+
async def _ensure_connected(self) -> None:
|
| 208 |
+
"""Ensure WebSocket connection is established."""
|
| 209 |
+
if self._ws is None:
|
| 210 |
+
await self.connect()
|
| 211 |
+
|
| 212 |
+
async def _send(self, message: Dict[str, Any]) -> None:
|
| 213 |
+
"""Send a message over the WebSocket."""
|
| 214 |
+
await self._ensure_connected()
|
| 215 |
+
assert self._ws is not None
|
| 216 |
+
await self._ws.send(json.dumps(message))
|
| 217 |
+
|
| 218 |
+
async def _receive(self) -> Dict[str, Any]:
|
| 219 |
+
"""Receive and parse a message from the WebSocket."""
|
| 220 |
+
assert self._ws is not None
|
| 221 |
+
raw = await asyncio.wait_for(self._ws.recv(), timeout=self._message_timeout)
|
| 222 |
+
return json.loads(raw)
|
| 223 |
+
|
| 224 |
+
async def _send_and_receive(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
| 225 |
+
"""Send a message and wait for response."""
|
| 226 |
+
await self._send(message)
|
| 227 |
+
response = await self._receive()
|
| 228 |
+
|
| 229 |
+
# Check for error response
|
| 230 |
+
if response.get("type") == "error":
|
| 231 |
+
error_data = response.get("data", {})
|
| 232 |
+
raise RuntimeError(
|
| 233 |
+
f"Server error: {error_data.get('message', 'Unknown error')} "
|
| 234 |
+
f"(code: {error_data.get('code', 'UNKNOWN')})"
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
return response
|
| 238 |
+
|
| 239 |
+
@classmethod
|
| 240 |
+
async def from_docker_image(
|
| 241 |
+
cls: Type[EnvClientT],
|
| 242 |
+
image: str,
|
| 243 |
+
provider: Optional["ContainerProvider"] = None,
|
| 244 |
+
**kwargs: Any,
|
| 245 |
+
) -> EnvClientT:
|
| 246 |
+
"""
|
| 247 |
+
Create an environment client by spinning up a Docker container.
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
image: Docker image name to run (e.g., "coding-env:latest")
|
| 251 |
+
provider: Container provider to use (defaults to LocalDockerProvider)
|
| 252 |
+
**kwargs: Additional arguments to pass to provider.start_container()
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
Connected client instance
|
| 256 |
+
"""
|
| 257 |
+
if provider is None:
|
| 258 |
+
provider = LocalDockerProvider()
|
| 259 |
+
|
| 260 |
+
# Start container
|
| 261 |
+
base_url = provider.start_container(image, **kwargs)
|
| 262 |
+
|
| 263 |
+
# Wait for server to be ready
|
| 264 |
+
provider.wait_for_ready(base_url)
|
| 265 |
+
|
| 266 |
+
# Create and connect client
|
| 267 |
+
client = cls(base_url=base_url, provider=provider)
|
| 268 |
+
await client.connect()
|
| 269 |
+
|
| 270 |
+
return client
|
| 271 |
+
|
| 272 |
+
@classmethod
|
| 273 |
+
async def from_env(
|
| 274 |
+
cls: Type[EnvClientT],
|
| 275 |
+
repo_id: str,
|
| 276 |
+
*,
|
| 277 |
+
use_docker: bool = True,
|
| 278 |
+
provider: Optional["ContainerProvider | RuntimeProvider"] = None,
|
| 279 |
+
**provider_kwargs: Any,
|
| 280 |
+
) -> EnvClientT:
|
| 281 |
+
"""
|
| 282 |
+
Create a client from a Hugging Face Space.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
repo_id: Hugging Face space identifier ``{org}/{space}``.
|
| 286 |
+
use_docker: When ``True`` (default) pull from the HF registry and
|
| 287 |
+
launch via :class:`LocalDockerProvider`. When ``False`` run the
|
| 288 |
+
space locally with :class:`UVProvider`.
|
| 289 |
+
provider: Optional provider instance to reuse. Must be a
|
| 290 |
+
:class:`ContainerProvider` when ``use_docker=True`` and a
|
| 291 |
+
:class:`RuntimeProvider` otherwise.
|
| 292 |
+
provider_kwargs: Additional keyword arguments forwarded to
|
| 293 |
+
either the container provider's ``start_container`` (docker)
|
| 294 |
+
or to the ``UVProvider`` constructor/start (uv). When
|
| 295 |
+
``use_docker=False``, the ``project_path`` argument can be
|
| 296 |
+
used to override the default git URL
|
| 297 |
+
(``git+https://huggingface.co/spaces/{repo_id}``).
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
Connected client instance
|
| 301 |
+
|
| 302 |
+
Examples:
|
| 303 |
+
>>> # Pull and run from HF Docker registry
|
| 304 |
+
>>> env = await MyEnv.from_env("openenv/echo-env")
|
| 305 |
+
>>>
|
| 306 |
+
>>> # Run locally with UV (clones the space)
|
| 307 |
+
>>> env = await MyEnv.from_env("openenv/echo-env", use_docker=False)
|
| 308 |
+
>>>
|
| 309 |
+
>>> # Run from a local checkout
|
| 310 |
+
>>> env = await MyEnv.from_env(
|
| 311 |
+
... "openenv/echo-env",
|
| 312 |
+
... use_docker=False,
|
| 313 |
+
... project_path="/path/to/local/checkout"
|
| 314 |
+
... )
|
| 315 |
+
"""
|
| 316 |
+
# Extract start args that apply to both providers
|
| 317 |
+
start_args = {}
|
| 318 |
+
for key in ("port", "env_vars", "workers"):
|
| 319 |
+
if key in provider_kwargs:
|
| 320 |
+
start_args[key] = provider_kwargs.pop(key)
|
| 321 |
+
|
| 322 |
+
if use_docker:
|
| 323 |
+
# Docker mode: pull from HF registry
|
| 324 |
+
docker_provider = provider or LocalDockerProvider()
|
| 325 |
+
tag = provider_kwargs.pop("tag", "latest")
|
| 326 |
+
image = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}"
|
| 327 |
+
base_url = docker_provider.start_container(
|
| 328 |
+
image, **start_args, **provider_kwargs
|
| 329 |
+
)
|
| 330 |
+
docker_provider.wait_for_ready(base_url)
|
| 331 |
+
|
| 332 |
+
client = cls(base_url=base_url, provider=docker_provider)
|
| 333 |
+
await client.connect()
|
| 334 |
+
return client
|
| 335 |
+
else:
|
| 336 |
+
# UV mode: clone and run with uv
|
| 337 |
+
if provider is None:
|
| 338 |
+
uv_kwargs = dict(provider_kwargs)
|
| 339 |
+
project_path = uv_kwargs.pop("project_path", None)
|
| 340 |
+
if project_path is None:
|
| 341 |
+
project_path = f"git+https://huggingface.co/spaces/{repo_id}"
|
| 342 |
+
|
| 343 |
+
provider = UVProvider(project_path=project_path, **uv_kwargs)
|
| 344 |
+
else:
|
| 345 |
+
if provider_kwargs:
|
| 346 |
+
raise ValueError(
|
| 347 |
+
"provider_kwargs cannot be used when supplying a provider instance"
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
base_url = provider.start(**start_args)
|
| 351 |
+
provider.wait_for_ready()
|
| 352 |
+
|
| 353 |
+
client = cls(base_url=base_url, provider=provider)
|
| 354 |
+
await client.connect()
|
| 355 |
+
return client
|
| 356 |
+
|
| 357 |
+
@abstractmethod
|
| 358 |
+
def _step_payload(self, action: ActT) -> Dict[str, Any]:
|
| 359 |
+
"""Convert an Action object to the JSON data expected by the env server."""
|
| 360 |
+
raise NotImplementedError
|
| 361 |
+
|
| 362 |
+
@abstractmethod
|
| 363 |
+
def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]:
|
| 364 |
+
"""Convert a JSON response from the env server to StepResult[ObsT]."""
|
| 365 |
+
raise NotImplementedError
|
| 366 |
+
|
| 367 |
+
@abstractmethod
|
| 368 |
+
def _parse_state(self, payload: Dict[str, Any]) -> StateT:
|
| 369 |
+
"""Convert a JSON response from the state endpoint to a State object."""
|
| 370 |
+
raise NotImplementedError
|
| 371 |
+
|
| 372 |
+
async def reset(self, **kwargs: Any) -> StepResult[ObsT]:
|
| 373 |
+
"""
|
| 374 |
+
Reset the environment with optional parameters.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
**kwargs: Optional parameters passed to the environment's reset method.
|
| 378 |
+
Common parameters include:
|
| 379 |
+
- seed: Random seed for reproducibility
|
| 380 |
+
- episode_id: Custom episode identifier
|
| 381 |
+
|
| 382 |
+
Returns:
|
| 383 |
+
StepResult containing initial observation
|
| 384 |
+
"""
|
| 385 |
+
message = {
|
| 386 |
+
"type": "reset",
|
| 387 |
+
"data": kwargs,
|
| 388 |
+
}
|
| 389 |
+
response = await self._send_and_receive(message)
|
| 390 |
+
return self._parse_result(response.get("data", {}))
|
| 391 |
+
|
| 392 |
+
async def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]:
|
| 393 |
+
"""
|
| 394 |
+
Execute an action in the environment.
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
action: The action to execute
|
| 398 |
+
**kwargs: Optional parameters (currently ignored)
|
| 399 |
+
|
| 400 |
+
Returns:
|
| 401 |
+
StepResult containing observation, reward, and done status
|
| 402 |
+
"""
|
| 403 |
+
message = {
|
| 404 |
+
"type": "step",
|
| 405 |
+
"data": self._step_payload(action),
|
| 406 |
+
}
|
| 407 |
+
response = await self._send_and_receive(message)
|
| 408 |
+
return self._parse_result(response.get("data", {}))
|
| 409 |
+
|
| 410 |
+
async def state(self) -> StateT:
|
| 411 |
+
"""
|
| 412 |
+
Get the current environment state from the server.
|
| 413 |
+
|
| 414 |
+
Returns:
|
| 415 |
+
State object with environment state information
|
| 416 |
+
"""
|
| 417 |
+
message = {"type": "state"}
|
| 418 |
+
response = await self._send_and_receive(message)
|
| 419 |
+
return self._parse_state(response.get("data", {}))
|
| 420 |
+
|
| 421 |
+
async def close(self) -> None:
|
| 422 |
+
"""
|
| 423 |
+
Close the WebSocket connection and clean up resources.
|
| 424 |
+
|
| 425 |
+
If this client was created via from_docker_image() or from_env(),
|
| 426 |
+
this will also stop and remove the associated container/process.
|
| 427 |
+
"""
|
| 428 |
+
await self.disconnect()
|
| 429 |
+
|
| 430 |
+
if self._provider is not None:
|
| 431 |
+
# Handle both ContainerProvider and RuntimeProvider
|
| 432 |
+
if hasattr(self._provider, "stop_container"):
|
| 433 |
+
self._provider.stop_container()
|
| 434 |
+
elif hasattr(self._provider, "stop"):
|
| 435 |
+
self._provider.stop()
|
| 436 |
+
|
| 437 |
+
async def __aenter__(self) -> "EnvClient":
|
| 438 |
+
"""Enter async context manager, ensuring connection is established."""
|
| 439 |
+
await self.connect()
|
| 440 |
+
return self
|
| 441 |
+
|
| 442 |
+
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
| 443 |
+
"""Exit async context manager, closing connection."""
|
| 444 |
+
await self.close()
|
| 445 |
+
|
| 446 |
+
def __enter__(self) -> "EnvClient":
|
| 447 |
+
"""Sync context manager entry - raises error suggesting async usage."""
|
| 448 |
+
raise TypeError(
|
| 449 |
+
"EnvClient is async by default. Use 'async with' instead of 'with', "
|
| 450 |
+
"or call .sync() to get a synchronous wrapper:\n"
|
| 451 |
+
" async with client: # async usage\n"
|
| 452 |
+
" with client.sync(): # sync wrapper"
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
| 456 |
+
"""Sync context manager exit - should not be reached."""
|
| 457 |
+
pass # pragma: no cover
|
| 458 |
+
|
| 459 |
+
def sync(self) -> "SyncEnvClient":
|
| 460 |
+
"""
|
| 461 |
+
Return a synchronous wrapper around this async client.
|
| 462 |
+
|
| 463 |
+
Use this method when you need synchronous access to the environment
|
| 464 |
+
without async/await syntax. This is useful for:
|
| 465 |
+
- Integration with synchronous codebases
|
| 466 |
+
- Interactive/REPL usage
|
| 467 |
+
- Stopping async from "infecting" the call stack
|
| 468 |
+
|
| 469 |
+
Returns:
|
| 470 |
+
SyncEnvClient wrapper that provides synchronous methods
|
| 471 |
+
|
| 472 |
+
Example:
|
| 473 |
+
>>> # Create async client and get sync wrapper
|
| 474 |
+
>>> async_client = GenericEnvClient(base_url="http://localhost:8000")
|
| 475 |
+
>>> sync_client = async_client.sync()
|
| 476 |
+
>>>
|
| 477 |
+
>>> # Use synchronous API
|
| 478 |
+
>>> with sync_client:
|
| 479 |
+
... result = sync_client.reset()
|
| 480 |
+
... result = sync_client.step({"code": "print('hello')"})
|
| 481 |
+
"""
|
| 482 |
+
from .sync_client import SyncEnvClient
|
| 483 |
+
|
| 484 |
+
return SyncEnvClient(self)
|
src/core/env_server/__init__.py
CHANGED
|
@@ -7,10 +7,74 @@
|
|
| 7 |
"""Core environment interfaces and types."""
|
| 8 |
|
| 9 |
from .base_transforms import CompositeTransform, NullTransform
|
| 10 |
-
from .
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
from .interfaces import Environment, Message, ModelTokenizer, Transform
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
__all__ = [
|
| 16 |
# Core interfaces
|
|
@@ -22,6 +86,33 @@ __all__ = [
|
|
| 22 |
"Action",
|
| 23 |
"Observation",
|
| 24 |
"State",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# Base transforms
|
| 26 |
"CompositeTransform",
|
| 27 |
"NullTransform",
|
|
@@ -32,4 +123,28 @@ __all__ = [
|
|
| 32 |
# Web Interface
|
| 33 |
"create_web_interface_app",
|
| 34 |
"WebInterfaceManager",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
]
|
|
|
|
| 7 |
"""Core environment interfaces and types."""
|
| 8 |
|
| 9 |
from .base_transforms import CompositeTransform, NullTransform
|
| 10 |
+
from .exceptions import (
|
| 11 |
+
ConcurrencyConfigurationError,
|
| 12 |
+
EnvironmentFactoryError,
|
| 13 |
+
OpenEnvError,
|
| 14 |
+
SessionCapacityError,
|
| 15 |
+
SessionCreationError,
|
| 16 |
+
SessionNotFoundError,
|
| 17 |
+
)
|
| 18 |
+
from .http_server import create_app, create_fastapi_app, HTTPEnvServer
|
| 19 |
from .interfaces import Environment, Message, ModelTokenizer, Transform
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
from .mcp_environment import MCPEnvironment
|
| 23 |
+
except ModuleNotFoundError:
|
| 24 |
+
MCPEnvironment = None # type: ignore[assignment]
|
| 25 |
+
|
| 26 |
+
from .mcp_types import (
|
| 27 |
+
CallToolAction,
|
| 28 |
+
CallToolObservation,
|
| 29 |
+
JsonRpcError,
|
| 30 |
+
# JSON-RPC types
|
| 31 |
+
JsonRpcErrorCode,
|
| 32 |
+
JsonRpcRequest,
|
| 33 |
+
JsonRpcResponse,
|
| 34 |
+
ListToolsAction,
|
| 35 |
+
ListToolsObservation,
|
| 36 |
+
McpMethod,
|
| 37 |
+
RESERVED_TOOL_NAMES,
|
| 38 |
+
Tool,
|
| 39 |
+
ToolError,
|
| 40 |
+
ToolErrorType,
|
| 41 |
+
WSMCPMessage,
|
| 42 |
+
WSMCPResponse,
|
| 43 |
+
)
|
| 44 |
+
from .route_config import GetEndpointConfig
|
| 45 |
+
from .serialization import (
|
| 46 |
+
deserialize_action,
|
| 47 |
+
deserialize_action_with_preprocessing,
|
| 48 |
+
serialize_observation,
|
| 49 |
+
)
|
| 50 |
+
from .types import (
|
| 51 |
+
Action,
|
| 52 |
+
BaseMessage,
|
| 53 |
+
ConcurrencyConfig,
|
| 54 |
+
HealthResponse,
|
| 55 |
+
HealthStatus,
|
| 56 |
+
Observation,
|
| 57 |
+
SchemaResponse,
|
| 58 |
+
ServerCapacityStatus,
|
| 59 |
+
ServerMode,
|
| 60 |
+
SessionInfo,
|
| 61 |
+
State,
|
| 62 |
+
WSCloseMessage,
|
| 63 |
+
WSErrorCode,
|
| 64 |
+
WSErrorResponse,
|
| 65 |
+
WSIncomingMessage,
|
| 66 |
+
WSObservationResponse,
|
| 67 |
+
WSResetMessage,
|
| 68 |
+
WSStateMessage,
|
| 69 |
+
WSStateResponse,
|
| 70 |
+
WSStepMessage,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
try:
|
| 74 |
+
from .web_interface import create_web_interface_app, WebInterfaceManager
|
| 75 |
+
except ModuleNotFoundError:
|
| 76 |
+
create_web_interface_app = None # type: ignore[assignment]
|
| 77 |
+
WebInterfaceManager = None # type: ignore[assignment]
|
| 78 |
|
| 79 |
__all__ = [
|
| 80 |
# Core interfaces
|
|
|
|
| 86 |
"Action",
|
| 87 |
"Observation",
|
| 88 |
"State",
|
| 89 |
+
"SchemaResponse",
|
| 90 |
+
"HealthResponse",
|
| 91 |
+
# Enums
|
| 92 |
+
"HealthStatus",
|
| 93 |
+
"ServerMode",
|
| 94 |
+
"WSErrorCode",
|
| 95 |
+
# WebSocket message types
|
| 96 |
+
"BaseMessage",
|
| 97 |
+
"WSIncomingMessage",
|
| 98 |
+
"WSResetMessage",
|
| 99 |
+
"WSStepMessage",
|
| 100 |
+
"WSStateMessage",
|
| 101 |
+
"WSCloseMessage",
|
| 102 |
+
"WSObservationResponse",
|
| 103 |
+
"WSStateResponse",
|
| 104 |
+
"WSErrorResponse",
|
| 105 |
+
# Concurrency types
|
| 106 |
+
"ConcurrencyConfig",
|
| 107 |
+
"ServerCapacityStatus",
|
| 108 |
+
"SessionInfo",
|
| 109 |
+
# Exceptions
|
| 110 |
+
"OpenEnvError",
|
| 111 |
+
"ConcurrencyConfigurationError",
|
| 112 |
+
"SessionCapacityError",
|
| 113 |
+
"SessionNotFoundError",
|
| 114 |
+
"SessionCreationError",
|
| 115 |
+
"EnvironmentFactoryError",
|
| 116 |
# Base transforms
|
| 117 |
"CompositeTransform",
|
| 118 |
"NullTransform",
|
|
|
|
| 123 |
# Web Interface
|
| 124 |
"create_web_interface_app",
|
| 125 |
"WebInterfaceManager",
|
| 126 |
+
# Serialization utilities
|
| 127 |
+
"deserialize_action",
|
| 128 |
+
"deserialize_action_with_preprocessing",
|
| 129 |
+
"serialize_observation",
|
| 130 |
+
# Route configuration
|
| 131 |
+
"GetEndpointConfig",
|
| 132 |
+
# MCP types
|
| 133 |
+
"Tool",
|
| 134 |
+
"ToolError",
|
| 135 |
+
"ToolErrorType",
|
| 136 |
+
"ListToolsAction",
|
| 137 |
+
"CallToolAction",
|
| 138 |
+
"ListToolsObservation",
|
| 139 |
+
"CallToolObservation",
|
| 140 |
+
"WSMCPMessage",
|
| 141 |
+
"WSMCPResponse",
|
| 142 |
+
"RESERVED_TOOL_NAMES",
|
| 143 |
+
"MCPEnvironment",
|
| 144 |
+
# JSON-RPC types
|
| 145 |
+
"JsonRpcErrorCode",
|
| 146 |
+
"JsonRpcError",
|
| 147 |
+
"JsonRpcRequest",
|
| 148 |
+
"JsonRpcResponse",
|
| 149 |
+
"McpMethod",
|
| 150 |
]
|
src/core/env_server/base_transforms.py
CHANGED
|
@@ -26,4 +26,4 @@ class NullTransform(Transform):
|
|
| 26 |
"""Default transform that passes through unchanged."""
|
| 27 |
|
| 28 |
def __call__(self, observation: Observation) -> Observation:
|
| 29 |
-
return observation
|
|
|
|
| 26 |
"""Default transform that passes through unchanged."""
|
| 27 |
|
| 28 |
def __call__(self, observation: Observation) -> Observation:
|
| 29 |
+
return observation
|
src/core/env_server/exceptions.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Custom exceptions for environment server operations."""
|
| 8 |
+
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class OpenEnvError(Exception):
|
| 13 |
+
"""Base exception for all OpenEnv errors."""
|
| 14 |
+
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ConcurrencyConfigurationError(OpenEnvError):
|
| 19 |
+
"""
|
| 20 |
+
Raised when an environment is misconfigured for concurrent sessions.
|
| 21 |
+
|
| 22 |
+
This error is raised during server startup when max_concurrent_envs > 1
|
| 23 |
+
is specified for an environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
environment_name: str,
|
| 29 |
+
max_concurrent_envs: int,
|
| 30 |
+
message: Optional[str] = None,
|
| 31 |
+
):
|
| 32 |
+
self.environment_name = environment_name
|
| 33 |
+
self.max_concurrent_envs = max_concurrent_envs
|
| 34 |
+
|
| 35 |
+
if message is None:
|
| 36 |
+
message = (
|
| 37 |
+
f"Environment '{environment_name}' is not marked as SUPPORTS_CONCURRENT_SESSIONS. "
|
| 38 |
+
f"Cannot run with max_concurrent_envs={max_concurrent_envs}. "
|
| 39 |
+
f"Either set max_concurrent_envs=1 or ensure the environment "
|
| 40 |
+
f"properly isolates session state and set SUPPORTS_CONCURRENT_SESSIONS=True."
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
super().__init__(message)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class SessionCapacityError(OpenEnvError):
|
| 47 |
+
"""
|
| 48 |
+
Raised when the server cannot accept new sessions due to capacity limits.
|
| 49 |
+
|
| 50 |
+
This error is raised when a new WebSocket connection is attempted but
|
| 51 |
+
the server has already reached max_concurrent_envs active sessions.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
active_sessions: int,
|
| 57 |
+
max_sessions: int,
|
| 58 |
+
message: Optional[str] = None,
|
| 59 |
+
):
|
| 60 |
+
self.active_sessions = active_sessions
|
| 61 |
+
self.max_sessions = max_sessions
|
| 62 |
+
|
| 63 |
+
if message is None:
|
| 64 |
+
message = (
|
| 65 |
+
f"Server at capacity: {active_sessions}/{max_sessions} sessions active. "
|
| 66 |
+
f"Cannot accept new connections."
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
super().__init__(message)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class SessionNotFoundError(OpenEnvError):
|
| 73 |
+
"""Raised when attempting to access a session that does not exist."""
|
| 74 |
+
|
| 75 |
+
def __init__(self, session_id: str, message: Optional[str] = None):
|
| 76 |
+
self.session_id = session_id
|
| 77 |
+
|
| 78 |
+
if message is None:
|
| 79 |
+
message = f"Session '{session_id}' not found."
|
| 80 |
+
|
| 81 |
+
super().__init__(message)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class SessionCreationError(OpenEnvError):
|
| 85 |
+
"""Raised when a session cannot be created."""
|
| 86 |
+
|
| 87 |
+
def __init__(self, reason: str, message: Optional[str] = None):
|
| 88 |
+
self.reason = reason
|
| 89 |
+
|
| 90 |
+
if message is None:
|
| 91 |
+
message = f"Failed to create session: {reason}"
|
| 92 |
+
|
| 93 |
+
super().__init__(message)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class EnvironmentFactoryError(OpenEnvError):
|
| 97 |
+
"""Raised when the environment factory fails to create an instance."""
|
| 98 |
+
|
| 99 |
+
def __init__(self, factory_name: str, message: Optional[str] = None):
|
| 100 |
+
self.factory_name = factory_name
|
| 101 |
+
|
| 102 |
+
if message is None:
|
| 103 |
+
message = f"Environment factory '{factory_name}' failed to create instance."
|
| 104 |
+
|
| 105 |
+
super().__init__(message)
|
src/core/env_server/gradio_theme.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Unified terminal-style theme for OpenEnv Gradio UI (light/dark)."""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import gradio as gr
|
| 12 |
+
|
| 13 |
+
_MONO_FONTS = (
|
| 14 |
+
"JetBrains Mono",
|
| 15 |
+
"Fira Code",
|
| 16 |
+
"Cascadia Code",
|
| 17 |
+
"Consolas",
|
| 18 |
+
"ui-monospace",
|
| 19 |
+
"monospace",
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
_CORE_FONT = (
|
| 23 |
+
"Lato",
|
| 24 |
+
"Inter",
|
| 25 |
+
"Arial",
|
| 26 |
+
"Helvetica",
|
| 27 |
+
"sans-serif",
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
_ZERO_RADIUS = gr.themes.Size(
|
| 31 |
+
xxs="0px",
|
| 32 |
+
xs="0px",
|
| 33 |
+
sm="0px",
|
| 34 |
+
md="0px",
|
| 35 |
+
lg="0px",
|
| 36 |
+
xl="0px",
|
| 37 |
+
xxl="0px",
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
_GREEN_HUE = gr.themes.Color(
|
| 41 |
+
c50="#e6f4ea",
|
| 42 |
+
c100="#ceead6",
|
| 43 |
+
c200="#a8dab5",
|
| 44 |
+
c300="#6fcc8b",
|
| 45 |
+
c400="#3fb950",
|
| 46 |
+
c500="#238636",
|
| 47 |
+
c600="#1a7f37",
|
| 48 |
+
c700="#116329",
|
| 49 |
+
c800="#0a4620",
|
| 50 |
+
c900="#033a16",
|
| 51 |
+
c950="#04200d",
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
_NEUTRAL_HUE = gr.themes.Color(
|
| 55 |
+
c50="#f6f8fa",
|
| 56 |
+
c100="#eaeef2",
|
| 57 |
+
c200="#d0d7de",
|
| 58 |
+
c300="#afb8c1",
|
| 59 |
+
c400="#8c959f",
|
| 60 |
+
c500="#6e7781",
|
| 61 |
+
c600="#57606a",
|
| 62 |
+
c700="#424a53",
|
| 63 |
+
c800="#32383f",
|
| 64 |
+
c900="#24292f",
|
| 65 |
+
c950="#1b1f24",
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
OPENENV_GRADIO_THEME = gr.themes.Base(
|
| 69 |
+
primary_hue=_GREEN_HUE,
|
| 70 |
+
secondary_hue=_NEUTRAL_HUE,
|
| 71 |
+
neutral_hue=_NEUTRAL_HUE,
|
| 72 |
+
font=_CORE_FONT,
|
| 73 |
+
font_mono=_MONO_FONTS,
|
| 74 |
+
radius_size=_ZERO_RADIUS,
|
| 75 |
+
).set(
|
| 76 |
+
body_background_fill="#ffffff",
|
| 77 |
+
background_fill_primary="#ffffff",
|
| 78 |
+
background_fill_secondary="#f6f8fa",
|
| 79 |
+
block_background_fill="#ffffff",
|
| 80 |
+
block_border_color="#ffffff",
|
| 81 |
+
block_label_text_color="#57606a",
|
| 82 |
+
block_title_text_color="#24292f",
|
| 83 |
+
border_color_primary="#d0d7de",
|
| 84 |
+
input_background_fill="#ffffff",
|
| 85 |
+
input_border_color="#d0d7de",
|
| 86 |
+
button_primary_background_fill="#1a7f37",
|
| 87 |
+
button_primary_background_fill_hover="#116329",
|
| 88 |
+
button_primary_text_color="#ffffff",
|
| 89 |
+
button_secondary_background_fill="#f6f8fa",
|
| 90 |
+
button_secondary_background_fill_hover="#eaeef2",
|
| 91 |
+
button_secondary_text_color="#24292f",
|
| 92 |
+
button_secondary_border_color="#d0d7de",
|
| 93 |
+
body_background_fill_dark="#0d1117",
|
| 94 |
+
background_fill_primary_dark="#0d1117",
|
| 95 |
+
background_fill_secondary_dark="#0d1117",
|
| 96 |
+
block_background_fill_dark="#0d1117",
|
| 97 |
+
block_border_color_dark="#0d1117",
|
| 98 |
+
block_label_text_color_dark="#8b949e",
|
| 99 |
+
block_title_text_color_dark="#c9d1d9",
|
| 100 |
+
border_color_primary_dark="#30363d",
|
| 101 |
+
input_background_fill_dark="#0d1117",
|
| 102 |
+
input_border_color_dark="#30363d",
|
| 103 |
+
button_primary_background_fill_dark="#30363d",
|
| 104 |
+
button_primary_background_fill_hover_dark="#484f58",
|
| 105 |
+
button_primary_text_color_dark="#c9d1d9",
|
| 106 |
+
button_secondary_background_fill_dark="#21262d",
|
| 107 |
+
button_secondary_background_fill_hover_dark="#30363d",
|
| 108 |
+
button_secondary_text_color_dark="#c9d1d9",
|
| 109 |
+
button_secondary_border_color_dark="#30363d",
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
OPENENV_GRADIO_CSS = """
|
| 113 |
+
* { border-radius: 0 !important; }
|
| 114 |
+
.col-left { padding: 16px !important; }
|
| 115 |
+
.col-right { padding: 16px !important; }
|
| 116 |
+
.prose, .markdown-text, .md,
|
| 117 |
+
.prose > *, .markdown-text > * {
|
| 118 |
+
background: transparent !important;
|
| 119 |
+
border: none !important;
|
| 120 |
+
box-shadow: none !important;
|
| 121 |
+
}
|
| 122 |
+
.dark .col-left {
|
| 123 |
+
border-left-color: rgba(139, 148, 158, 0.4) !important;
|
| 124 |
+
}
|
| 125 |
+
.dark .col-right {
|
| 126 |
+
border-left-color: rgba(201, 209, 217, 0.3) !important;
|
| 127 |
+
}
|
| 128 |
+
"""
|
src/core/env_server/gradio_ui.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Gradio-based web UI for OpenEnv environments.
|
| 9 |
+
|
| 10 |
+
Replaces the legacy HTML/JavaScript interface when ENABLE_WEB_INTERFACE is set.
|
| 11 |
+
Mount at /web via gr.mount_gradio_app() from create_web_interface_app().
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import re
|
| 18 |
+
from typing import Any, Dict, List, Optional
|
| 19 |
+
|
| 20 |
+
import gradio as gr
|
| 21 |
+
|
| 22 |
+
from .types import EnvironmentMetadata
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _escape_md(text: str) -> str:
|
| 26 |
+
"""Escape Markdown special characters in user-controlled content."""
|
| 27 |
+
return re.sub(r"([\\`*_\{\}\[\]()#+\-.!|~>])", r"\\\1", str(text))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _format_observation(data: Dict[str, Any]) -> str:
|
| 31 |
+
"""Format reset/step response for Markdown display."""
|
| 32 |
+
lines: List[str] = []
|
| 33 |
+
obs = data.get("observation", {})
|
| 34 |
+
if isinstance(obs, dict):
|
| 35 |
+
if obs.get("prompt"):
|
| 36 |
+
lines.append(f"**Prompt:**\n\n{_escape_md(obs['prompt'])}\n")
|
| 37 |
+
messages = obs.get("messages", [])
|
| 38 |
+
if messages:
|
| 39 |
+
lines.append("**Messages:**\n")
|
| 40 |
+
for msg in messages:
|
| 41 |
+
sender = _escape_md(str(msg.get("sender_id", "?")))
|
| 42 |
+
content = _escape_md(str(msg.get("content", "")))
|
| 43 |
+
cat = _escape_md(str(msg.get("category", "")))
|
| 44 |
+
lines.append(f"- `[{cat}]` Player {sender}: {content}")
|
| 45 |
+
lines.append("")
|
| 46 |
+
reward = data.get("reward")
|
| 47 |
+
done = data.get("done")
|
| 48 |
+
if reward is not None:
|
| 49 |
+
lines.append(f"**Reward:** `{reward}`")
|
| 50 |
+
if done is not None:
|
| 51 |
+
lines.append(f"**Done:** `{done}`")
|
| 52 |
+
return "\n".join(lines) if lines else "*No observation data*"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _readme_section(metadata: Optional[EnvironmentMetadata]) -> str:
|
| 56 |
+
"""README content for the left panel."""
|
| 57 |
+
if not metadata or not metadata.readme_content:
|
| 58 |
+
return "*No README available.*"
|
| 59 |
+
return metadata.readme_content
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_gradio_display_title(
|
| 63 |
+
metadata: Optional[EnvironmentMetadata],
|
| 64 |
+
fallback: str = "OpenEnv Environment",
|
| 65 |
+
) -> str:
|
| 66 |
+
"""Return the title used for the Gradio app (browser tab and Blocks)."""
|
| 67 |
+
name = metadata.name if metadata else fallback
|
| 68 |
+
return f"OpenEnv Agentic Environment: {name}"
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def build_gradio_app(
|
| 72 |
+
web_manager: Any,
|
| 73 |
+
action_fields: List[Dict[str, Any]],
|
| 74 |
+
metadata: Optional[EnvironmentMetadata],
|
| 75 |
+
is_chat_env: bool,
|
| 76 |
+
title: str = "OpenEnv Environment",
|
| 77 |
+
quick_start_md: Optional[str] = None,
|
| 78 |
+
) -> gr.Blocks:
|
| 79 |
+
"""
|
| 80 |
+
Build a Gradio Blocks app for the OpenEnv web interface.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
web_manager: WebInterfaceManager (reset/step_environment, get_state).
|
| 84 |
+
action_fields: Field dicts from _extract_action_fields(action_cls).
|
| 85 |
+
metadata: Environment metadata for README/name.
|
| 86 |
+
is_chat_env: If True, single message textbox; else form from action_fields.
|
| 87 |
+
title: App title (overridden by metadata.name when present; see get_gradio_display_title).
|
| 88 |
+
quick_start_md: Optional Quick Start markdown (class names already replaced).
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
gr.Blocks to mount with gr.mount_gradio_app(app, blocks, path="/web").
|
| 92 |
+
"""
|
| 93 |
+
readme_content = _readme_section(metadata)
|
| 94 |
+
display_title = get_gradio_display_title(metadata, fallback=title)
|
| 95 |
+
|
| 96 |
+
async def reset_env():
|
| 97 |
+
try:
|
| 98 |
+
data = await web_manager.reset_environment()
|
| 99 |
+
obs_md = _format_observation(data)
|
| 100 |
+
return (
|
| 101 |
+
obs_md,
|
| 102 |
+
json.dumps(data, indent=2),
|
| 103 |
+
"Environment reset successfully.",
|
| 104 |
+
)
|
| 105 |
+
except Exception as e:
|
| 106 |
+
return ("", "", f"Error: {e}")
|
| 107 |
+
|
| 108 |
+
def _step_with_action(action_data: Dict[str, Any]):
|
| 109 |
+
async def _run():
|
| 110 |
+
try:
|
| 111 |
+
data = await web_manager.step_environment(action_data)
|
| 112 |
+
obs_md = _format_observation(data)
|
| 113 |
+
return (
|
| 114 |
+
obs_md,
|
| 115 |
+
json.dumps(data, indent=2),
|
| 116 |
+
"Step complete.",
|
| 117 |
+
)
|
| 118 |
+
except Exception as e:
|
| 119 |
+
return ("", "", f"Error: {e}")
|
| 120 |
+
|
| 121 |
+
return _run
|
| 122 |
+
|
| 123 |
+
async def step_chat(message: str):
|
| 124 |
+
if not (message or str(message).strip()):
|
| 125 |
+
return ("", "", "Please enter an action message.")
|
| 126 |
+
action = {"message": str(message).strip()}
|
| 127 |
+
return await _step_with_action(action)()
|
| 128 |
+
|
| 129 |
+
def get_state_sync():
|
| 130 |
+
try:
|
| 131 |
+
data = web_manager.get_state()
|
| 132 |
+
return json.dumps(data, indent=2)
|
| 133 |
+
except Exception as e:
|
| 134 |
+
return f"Error: {e}"
|
| 135 |
+
|
| 136 |
+
with gr.Blocks(title=display_title) as demo:
|
| 137 |
+
with gr.Row():
|
| 138 |
+
with gr.Column(scale=1, elem_classes="col-left"):
|
| 139 |
+
if quick_start_md:
|
| 140 |
+
with gr.Accordion("Quick Start", open=True):
|
| 141 |
+
gr.Markdown(quick_start_md)
|
| 142 |
+
with gr.Accordion("README", open=False):
|
| 143 |
+
gr.Markdown(readme_content)
|
| 144 |
+
|
| 145 |
+
with gr.Column(scale=2, elem_classes="col-right"):
|
| 146 |
+
obs_display = gr.Markdown(
|
| 147 |
+
value=("# Playground\n\nClick **Reset** to start a new episode."),
|
| 148 |
+
)
|
| 149 |
+
with gr.Group():
|
| 150 |
+
if is_chat_env:
|
| 151 |
+
action_input = gr.Textbox(
|
| 152 |
+
label="Action message",
|
| 153 |
+
placeholder="e.g. Enter your message...",
|
| 154 |
+
)
|
| 155 |
+
step_inputs = [action_input]
|
| 156 |
+
step_fn = step_chat
|
| 157 |
+
else:
|
| 158 |
+
step_inputs = []
|
| 159 |
+
for field in action_fields:
|
| 160 |
+
name = field["name"]
|
| 161 |
+
field_type = field.get("type", "text")
|
| 162 |
+
label = name.replace("_", " ").title()
|
| 163 |
+
placeholder = field.get("placeholder", "")
|
| 164 |
+
if field_type == "checkbox":
|
| 165 |
+
inp = gr.Checkbox(label=label)
|
| 166 |
+
elif field_type == "number":
|
| 167 |
+
inp = gr.Number(label=label)
|
| 168 |
+
elif field_type == "select":
|
| 169 |
+
choices = field.get("choices") or []
|
| 170 |
+
inp = gr.Dropdown(
|
| 171 |
+
choices=choices,
|
| 172 |
+
label=label,
|
| 173 |
+
allow_custom_value=False,
|
| 174 |
+
)
|
| 175 |
+
elif field_type in ("textarea", "tensor"):
|
| 176 |
+
inp = gr.Textbox(
|
| 177 |
+
label=label,
|
| 178 |
+
placeholder=placeholder,
|
| 179 |
+
lines=3,
|
| 180 |
+
)
|
| 181 |
+
else:
|
| 182 |
+
inp = gr.Textbox(
|
| 183 |
+
label=label,
|
| 184 |
+
placeholder=placeholder,
|
| 185 |
+
)
|
| 186 |
+
step_inputs.append(inp)
|
| 187 |
+
|
| 188 |
+
async def step_form(*values):
|
| 189 |
+
if not action_fields:
|
| 190 |
+
return await _step_with_action({})()
|
| 191 |
+
action_data = {}
|
| 192 |
+
for i, field in enumerate(action_fields):
|
| 193 |
+
if i >= len(values):
|
| 194 |
+
break
|
| 195 |
+
name = field["name"]
|
| 196 |
+
val = values[i]
|
| 197 |
+
if field.get("type") == "checkbox":
|
| 198 |
+
action_data[name] = bool(val)
|
| 199 |
+
elif val is not None and val != "":
|
| 200 |
+
action_data[name] = val
|
| 201 |
+
return await _step_with_action(action_data)()
|
| 202 |
+
|
| 203 |
+
step_fn = step_form
|
| 204 |
+
|
| 205 |
+
with gr.Row():
|
| 206 |
+
step_btn = gr.Button("Step", variant="primary")
|
| 207 |
+
reset_btn = gr.Button("Reset", variant="secondary")
|
| 208 |
+
state_btn = gr.Button("Get state", variant="secondary")
|
| 209 |
+
with gr.Row():
|
| 210 |
+
status = gr.Textbox(
|
| 211 |
+
label="Status",
|
| 212 |
+
interactive=False,
|
| 213 |
+
)
|
| 214 |
+
raw_json = gr.Code(
|
| 215 |
+
label="Raw JSON response",
|
| 216 |
+
language="json",
|
| 217 |
+
interactive=False,
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
reset_btn.click(
|
| 221 |
+
fn=reset_env,
|
| 222 |
+
outputs=[obs_display, raw_json, status],
|
| 223 |
+
)
|
| 224 |
+
step_btn.click(
|
| 225 |
+
fn=step_fn,
|
| 226 |
+
inputs=step_inputs,
|
| 227 |
+
outputs=[obs_display, raw_json, status],
|
| 228 |
+
)
|
| 229 |
+
if is_chat_env:
|
| 230 |
+
action_input.submit(
|
| 231 |
+
fn=step_fn,
|
| 232 |
+
inputs=step_inputs,
|
| 233 |
+
outputs=[obs_display, raw_json, status],
|
| 234 |
+
)
|
| 235 |
+
state_btn.click(
|
| 236 |
+
fn=get_state_sync,
|
| 237 |
+
outputs=[raw_json],
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
return demo
|
src/core/env_server/http_server.py
CHANGED
|
@@ -8,25 +8,113 @@
|
|
| 8 |
HTTP server wrapper for Environment instances.
|
| 9 |
|
| 10 |
This module provides utilities to wrap any Environment subclass and expose it
|
| 11 |
-
over HTTP endpoints that
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
|
|
|
|
|
|
|
|
|
|
| 16 |
import os
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
from .interfaces import Environment
|
| 21 |
-
from .
|
| 22 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
class HTTPEnvServer:
|
| 25 |
"""
|
| 26 |
HTTP server wrapper for Environment instances.
|
| 27 |
|
| 28 |
This class wraps an Environment and exposes its reset(), step(), and state
|
| 29 |
-
methods as HTTP endpoints compatible with
|
| 30 |
|
| 31 |
The server expects:
|
| 32 |
- Action deserialization: Converts JSON dict to Action subclass
|
|
@@ -35,9 +123,15 @@ class HTTPEnvServer:
|
|
| 35 |
Example:
|
| 36 |
>>> from core.env_server import HTTPEnvServer
|
| 37 |
>>> from envs.coding_env.server import CodeExecutionEnvironment
|
|
|
|
| 38 |
>>>
|
| 39 |
-
>>>
|
| 40 |
-
>>> server = HTTPEnvServer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
>>>
|
| 42 |
>>> # Register routes with FastAPI
|
| 43 |
>>> from fastapi import FastAPI
|
|
@@ -47,178 +141,1177 @@ class HTTPEnvServer:
|
|
| 47 |
|
| 48 |
def __init__(
|
| 49 |
self,
|
| 50 |
-
env: Environment,
|
| 51 |
action_cls: Type[Action],
|
| 52 |
observation_cls: Type[Observation],
|
|
|
|
|
|
|
| 53 |
):
|
| 54 |
"""
|
| 55 |
Initialize HTTP server wrapper.
|
| 56 |
|
| 57 |
Args:
|
| 58 |
-
env:
|
|
|
|
| 59 |
action_cls: The Action subclass this environment expects
|
| 60 |
observation_cls: The Observation subclass this environment returns
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
"""
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
self.action_cls = action_cls
|
| 64 |
self.observation_cls = observation_cls
|
| 65 |
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
"""
|
| 68 |
-
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
|
|
|
| 72 |
"""
|
|
|
|
|
|
|
| 73 |
|
| 74 |
-
if
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
return self._serialize_observation(observation)
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
|
| 96 |
-
#
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
|
|
|
| 110 |
|
| 111 |
-
def
|
| 112 |
"""
|
| 113 |
-
|
| 114 |
|
| 115 |
Args:
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
"""
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
return action
|
| 130 |
|
| 131 |
-
def
|
| 132 |
"""
|
| 133 |
-
|
| 134 |
|
| 135 |
Args:
|
| 136 |
-
|
| 137 |
|
| 138 |
Returns:
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
"
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
"""
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
-
# Return in HTTPEnvClient expected format
|
| 156 |
-
return {
|
| 157 |
-
"observation": obs_dict,
|
| 158 |
-
"reward": reward,
|
| 159 |
-
"done": done,
|
| 160 |
-
}
|
| 161 |
|
| 162 |
def create_app(
|
| 163 |
-
env: Environment,
|
| 164 |
action_cls: Type[Action],
|
| 165 |
observation_cls: Type[Observation],
|
| 166 |
env_name: Optional[str] = None,
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
| 168 |
"""
|
| 169 |
Create a FastAPI application with or without web interface.
|
| 170 |
-
|
| 171 |
This function creates a FastAPI app with the web interface enabled by default,
|
| 172 |
including README integration for better user experience.
|
| 173 |
-
|
| 174 |
Args:
|
| 175 |
-
env:
|
| 176 |
action_cls: The Action subclass this environment expects
|
| 177 |
observation_cls: The Observation subclass this environment returns
|
| 178 |
env_name: Optional environment name for README loading
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
Returns:
|
| 181 |
FastAPI application instance with or without web interface and README integration
|
| 182 |
"""
|
| 183 |
# Check if web interface should be enabled
|
| 184 |
# This can be controlled via environment variable or build argument
|
| 185 |
-
enable_web = (
|
| 186 |
-
|
|
|
|
|
|
|
| 187 |
)
|
| 188 |
|
| 189 |
if enable_web:
|
| 190 |
-
#
|
| 191 |
from .web_interface import create_web_interface_app
|
| 192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
else:
|
| 194 |
# Use standard FastAPI app without web interface
|
| 195 |
-
return create_fastapi_app(
|
| 196 |
-
|
|
|
|
|
|
|
| 197 |
|
| 198 |
def create_fastapi_app(
|
| 199 |
-
env: Environment,
|
| 200 |
action_cls: Type[Action],
|
| 201 |
observation_cls: Type[Observation],
|
| 202 |
-
|
|
|
|
|
|
|
| 203 |
"""
|
| 204 |
-
Create a FastAPI application with
|
| 205 |
|
| 206 |
Args:
|
| 207 |
-
env:
|
| 208 |
action_cls: The Action subclass this environment expects
|
| 209 |
observation_cls: The Observation subclass this environment returns
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
Returns:
|
| 212 |
-
FastAPI application instance
|
| 213 |
-
|
| 214 |
-
Example:
|
| 215 |
-
>>> from envs.coding_env.server import CodeExecutionEnvironment
|
| 216 |
-
>>> from envs.coding_env.models import CodeAction, CodeObservation
|
| 217 |
-
>>>
|
| 218 |
-
>>> env = CodeExecutionEnvironment()
|
| 219 |
-
>>> app = create_fastapi_app(env, CodeAction, CodeObservation)
|
| 220 |
-
>>>
|
| 221 |
-
>>> # Run with: uvicorn module:app --host 0.0.0.0 --port 8000
|
| 222 |
"""
|
| 223 |
try:
|
| 224 |
from fastapi import FastAPI
|
|
@@ -227,7 +1320,72 @@ def create_fastapi_app(
|
|
| 227 |
"FastAPI is required. Install with: pip install fastapi uvicorn"
|
| 228 |
)
|
| 229 |
|
| 230 |
-
app = FastAPI(
|
| 231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
server.register_routes(app)
|
| 233 |
return app
|
|
|
|
| 8 |
HTTP server wrapper for Environment instances.
|
| 9 |
|
| 10 |
This module provides utilities to wrap any Environment subclass and expose it
|
| 11 |
+
over HTTP and WebSocket endpoints that EnvClient can consume.
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
|
| 16 |
+
import asyncio
|
| 17 |
+
import inspect
|
| 18 |
+
import json
|
| 19 |
import os
|
| 20 |
+
import time
|
| 21 |
+
import uuid
|
| 22 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 23 |
+
from typing import Any, Callable, Dict, Optional, Type
|
| 24 |
+
|
| 25 |
+
from fastapi import (
|
| 26 |
+
Body,
|
| 27 |
+
FastAPI,
|
| 28 |
+
HTTPException,
|
| 29 |
+
Request,
|
| 30 |
+
status,
|
| 31 |
+
WebSocket,
|
| 32 |
+
WebSocketDisconnect,
|
| 33 |
+
)
|
| 34 |
+
from pydantic import ValidationError
|
| 35 |
|
| 36 |
from .interfaces import Environment
|
| 37 |
+
from .mcp_environment import get_server_tools
|
| 38 |
+
from .mcp_types import (
|
| 39 |
+
JsonRpcErrorCode,
|
| 40 |
+
JsonRpcRequest,
|
| 41 |
+
JsonRpcResponse,
|
| 42 |
+
McpMethod,
|
| 43 |
+
WSMCPMessage,
|
| 44 |
+
WSMCPResponse,
|
| 45 |
+
)
|
| 46 |
+
from .route_config import GetEndpointConfig, register_get_endpoints
|
| 47 |
+
from .serialization import deserialize_action, serialize_observation
|
| 48 |
+
from .types import (
|
| 49 |
+
Action,
|
| 50 |
+
ConcurrencyConfig,
|
| 51 |
+
EnvironmentMetadata,
|
| 52 |
+
HealthResponse,
|
| 53 |
+
HealthStatus,
|
| 54 |
+
Observation,
|
| 55 |
+
ResetRequest,
|
| 56 |
+
ResetResponse,
|
| 57 |
+
SchemaResponse,
|
| 58 |
+
ServerCapacityStatus,
|
| 59 |
+
ServerMode,
|
| 60 |
+
SessionInfo,
|
| 61 |
+
State,
|
| 62 |
+
StepRequest,
|
| 63 |
+
StepResponse,
|
| 64 |
+
WSCloseMessage,
|
| 65 |
+
WSErrorCode,
|
| 66 |
+
WSErrorResponse,
|
| 67 |
+
WSObservationResponse,
|
| 68 |
+
WSResetMessage,
|
| 69 |
+
WSStateMessage,
|
| 70 |
+
WSStateResponse,
|
| 71 |
+
WSStepMessage,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _make_json_serializable(obj: Any) -> Any:
|
| 76 |
+
"""
|
| 77 |
+
Convert an object to a JSON-serializable form.
|
| 78 |
+
|
| 79 |
+
Handles Pydantic models, dataclasses, and other common types.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
obj: The object to convert
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
A JSON-serializable representation of the object
|
| 86 |
+
"""
|
| 87 |
+
if obj is None:
|
| 88 |
+
return None
|
| 89 |
+
if isinstance(obj, (str, int, float, bool)):
|
| 90 |
+
return obj
|
| 91 |
+
if isinstance(obj, (list, tuple)):
|
| 92 |
+
return [_make_json_serializable(item) for item in obj]
|
| 93 |
+
if isinstance(obj, dict):
|
| 94 |
+
return {k: _make_json_serializable(v) for k, v in obj.items()}
|
| 95 |
+
if hasattr(obj, "model_dump"):
|
| 96 |
+
# Pydantic model
|
| 97 |
+
return obj.model_dump()
|
| 98 |
+
if hasattr(obj, "__dict__"):
|
| 99 |
+
# Object with __dict__
|
| 100 |
+
return {k: _make_json_serializable(v) for k, v in obj.__dict__.items()}
|
| 101 |
+
# Fallback to string representation
|
| 102 |
+
return str(obj)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
from .exceptions import (
|
| 106 |
+
ConcurrencyConfigurationError,
|
| 107 |
+
EnvironmentFactoryError,
|
| 108 |
+
SessionCapacityError,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
|
| 112 |
class HTTPEnvServer:
|
| 113 |
"""
|
| 114 |
HTTP server wrapper for Environment instances.
|
| 115 |
|
| 116 |
This class wraps an Environment and exposes its reset(), step(), and state
|
| 117 |
+
methods as HTTP and WebSocket endpoints compatible with EnvClient.
|
| 118 |
|
| 119 |
The server expects:
|
| 120 |
- Action deserialization: Converts JSON dict to Action subclass
|
|
|
|
| 123 |
Example:
|
| 124 |
>>> from core.env_server import HTTPEnvServer
|
| 125 |
>>> from envs.coding_env.server import CodeExecutionEnvironment
|
| 126 |
+
>>> from envs.coding_env.models import CodeAction, CodeObservation
|
| 127 |
>>>
|
| 128 |
+
>>> # Pass environment class (factory pattern)
|
| 129 |
+
>>> server = HTTPEnvServer(
|
| 130 |
+
... env=CodeExecutionEnvironment,
|
| 131 |
+
... action_cls=CodeAction,
|
| 132 |
+
... observation_cls=CodeObservation,
|
| 133 |
+
... max_concurrent_envs=4,
|
| 134 |
+
... )
|
| 135 |
>>>
|
| 136 |
>>> # Register routes with FastAPI
|
| 137 |
>>> from fastapi import FastAPI
|
|
|
|
| 141 |
|
| 142 |
def __init__(
|
| 143 |
self,
|
| 144 |
+
env: Callable[[], Environment],
|
| 145 |
action_cls: Type[Action],
|
| 146 |
observation_cls: Type[Observation],
|
| 147 |
+
max_concurrent_envs: Optional[int] = None,
|
| 148 |
+
concurrency_config: Optional[ConcurrencyConfig] = None,
|
| 149 |
):
|
| 150 |
"""
|
| 151 |
Initialize HTTP server wrapper.
|
| 152 |
|
| 153 |
Args:
|
| 154 |
+
env: Environment factory (callable) that creates new instances.
|
| 155 |
+
Will be called to create a new environment for each WebSocket session.
|
| 156 |
action_cls: The Action subclass this environment expects
|
| 157 |
observation_cls: The Observation subclass this environment returns
|
| 158 |
+
max_concurrent_envs: Maximum number of concurrent WebSocket sessions.
|
| 159 |
+
Mutually exclusive with concurrency_config.
|
| 160 |
+
concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings.
|
| 161 |
+
Mutually exclusive with max_concurrent_envs.
|
| 162 |
+
|
| 163 |
+
Raises:
|
| 164 |
+
ValueError: If both max_concurrent_envs and concurrency_config are provided.
|
| 165 |
+
ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an
|
| 166 |
+
environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS.
|
| 167 |
"""
|
| 168 |
+
# Validate that env is callable
|
| 169 |
+
if not callable(env):
|
| 170 |
+
raise TypeError(
|
| 171 |
+
f"env must be a callable (class or factory function), got {type(env)}. "
|
| 172 |
+
f"Pass the environment class (e.g., MyEnvironment) not an instance (e.g., MyEnvironment())."
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
self._env_factory: Callable[[], Environment] = env
|
| 176 |
+
|
| 177 |
+
# Handle concurrency configuration
|
| 178 |
+
if max_concurrent_envs is not None and concurrency_config is not None:
|
| 179 |
+
raise ValueError(
|
| 180 |
+
"Cannot specify both 'max_concurrent_envs' and 'concurrency_config'. "
|
| 181 |
+
"Please use only one method to configure concurrency."
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
if concurrency_config is not None:
|
| 185 |
+
self._concurrency_config = concurrency_config
|
| 186 |
+
elif max_concurrent_envs is not None:
|
| 187 |
+
self._concurrency_config = ConcurrencyConfig(
|
| 188 |
+
max_concurrent_envs=max_concurrent_envs,
|
| 189 |
+
session_timeout=None,
|
| 190 |
+
)
|
| 191 |
+
else:
|
| 192 |
+
# Default configuration
|
| 193 |
+
self._concurrency_config = ConcurrencyConfig(
|
| 194 |
+
max_concurrent_envs=1,
|
| 195 |
+
session_timeout=None,
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
self._max_concurrent_envs = self._concurrency_config.max_concurrent_envs
|
| 199 |
+
|
| 200 |
+
# Validate concurrency configuration
|
| 201 |
+
self._validate_concurrency_safety()
|
| 202 |
+
|
| 203 |
self.action_cls = action_cls
|
| 204 |
self.observation_cls = observation_cls
|
| 205 |
|
| 206 |
+
# Session management for WebSocket connections
|
| 207 |
+
self._sessions: Dict[str, Environment] = {}
|
| 208 |
+
self._session_executors: Dict[str, ThreadPoolExecutor] = {}
|
| 209 |
+
self._session_info: Dict[str, SessionInfo] = {}
|
| 210 |
+
self._session_lock = asyncio.Lock()
|
| 211 |
+
|
| 212 |
+
# Create thread pool for running sync code in async context
|
| 213 |
+
# This is needed for environments using sync libraries (e.g., Playwright)
|
| 214 |
+
self._executor = ThreadPoolExecutor(max_workers=32)
|
| 215 |
+
|
| 216 |
+
def _validate_concurrency_safety(self) -> None:
|
| 217 |
"""
|
| 218 |
+
Validate that the environment supports the configured concurrency level.
|
| 219 |
|
| 220 |
+
Raises:
|
| 221 |
+
ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an
|
| 222 |
+
environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS.
|
| 223 |
"""
|
| 224 |
+
if self._max_concurrent_envs <= 1:
|
| 225 |
+
return
|
| 226 |
|
| 227 |
+
if inspect.isclass(self._env_factory):
|
| 228 |
+
env_cls = self._env_factory
|
| 229 |
+
else:
|
| 230 |
+
_temp_env = self._env_factory()
|
| 231 |
+
env_cls = type(_temp_env)
|
| 232 |
+
_temp_env.close()
|
| 233 |
+
del _temp_env
|
| 234 |
|
| 235 |
+
if not getattr(env_cls, "SUPPORTS_CONCURRENT_SESSIONS", False):
|
| 236 |
+
raise ConcurrencyConfigurationError(
|
| 237 |
+
environment_name=env_cls.__name__,
|
| 238 |
+
max_concurrent_envs=self._max_concurrent_envs,
|
| 239 |
+
)
|
|
|
|
| 240 |
|
| 241 |
+
def get_capacity_status(self) -> ServerCapacityStatus:
|
| 242 |
+
"""
|
| 243 |
+
Get the current capacity status of the server.
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
ServerCapacityStatus with current session counts and availability.
|
| 247 |
+
"""
|
| 248 |
+
return ServerCapacityStatus.from_counts(
|
| 249 |
+
active=len(self._sessions),
|
| 250 |
+
max_sessions=self._max_concurrent_envs,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
async def _run_sync_in_thread_pool(
|
| 254 |
+
self, func: Callable[..., Observation], *args, **kwargs
|
| 255 |
+
) -> Observation:
|
| 256 |
+
"""Run a synchronous function in the thread pool executor."""
|
| 257 |
+
loop = asyncio.get_event_loop()
|
| 258 |
+
return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs))
|
| 259 |
+
|
| 260 |
+
def _get_valid_kwargs(
|
| 261 |
+
self,
|
| 262 |
+
sig: inspect.Signature,
|
| 263 |
+
kwargs: Dict[str, Any],
|
| 264 |
+
skip_params: Optional[set[str]] = None,
|
| 265 |
+
) -> Dict[str, Any]:
|
| 266 |
+
"""Filter kwargs to only include parameters accepted by the function signature."""
|
| 267 |
+
if skip_params is None:
|
| 268 |
+
skip_params = set()
|
| 269 |
+
|
| 270 |
+
valid_kwargs = {}
|
| 271 |
+
|
| 272 |
+
has_kwargs = any(
|
| 273 |
+
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
for k, v in kwargs.items():
|
| 277 |
+
if k in sig.parameters or has_kwargs:
|
| 278 |
+
if k not in skip_params:
|
| 279 |
+
valid_kwargs[k] = v
|
| 280 |
+
|
| 281 |
+
return valid_kwargs
|
| 282 |
+
|
| 283 |
+
async def _create_session(self) -> tuple[str, Environment]:
|
| 284 |
+
"""
|
| 285 |
+
Create a new WebSocket session with its own environment instance.
|
| 286 |
+
|
| 287 |
+
Returns:
|
| 288 |
+
Tuple of (session_id, environment)
|
| 289 |
|
| 290 |
+
Raises:
|
| 291 |
+
SessionCapacityError: If max concurrent sessions reached
|
| 292 |
+
EnvironmentFactoryError: If the factory fails to create an environment
|
| 293 |
+
"""
|
| 294 |
+
async with self._session_lock:
|
| 295 |
+
if len(self._sessions) >= self._max_concurrent_envs:
|
| 296 |
+
raise SessionCapacityError(
|
| 297 |
+
active_sessions=len(self._sessions),
|
| 298 |
+
max_sessions=self._max_concurrent_envs,
|
| 299 |
+
)
|
| 300 |
|
| 301 |
+
session_id = str(uuid.uuid4())
|
| 302 |
+
current_time = time.time()
|
| 303 |
|
| 304 |
+
# Create executor and reserve slot so capacity is not exceeded while
|
| 305 |
+
# we create the env outside the lock (avoids blocking other sessions)
|
| 306 |
+
executor = ThreadPoolExecutor(max_workers=1)
|
| 307 |
+
self._session_executors[session_id] = executor
|
| 308 |
+
self._sessions[session_id] = None # placeholder until env is ready
|
| 309 |
|
| 310 |
+
try:
|
| 311 |
+
# Create environment in the executor thread (outside lock)
|
| 312 |
+
loop = asyncio.get_event_loop()
|
| 313 |
+
env = await loop.run_in_executor(executor, self._env_factory)
|
| 314 |
+
except Exception as e:
|
| 315 |
+
async with self._session_lock:
|
| 316 |
+
executor.shutdown(wait=False)
|
| 317 |
+
self._session_executors.pop(session_id, None)
|
| 318 |
+
self._sessions.pop(session_id, None)
|
| 319 |
+
factory_name = getattr(
|
| 320 |
+
self._env_factory, "__name__", str(self._env_factory)
|
| 321 |
+
)
|
| 322 |
+
raise EnvironmentFactoryError(factory_name) from e
|
| 323 |
|
| 324 |
+
async with self._session_lock:
|
| 325 |
+
self._sessions[session_id] = env
|
| 326 |
+
self._session_info[session_id] = SessionInfo(
|
| 327 |
+
session_id=session_id,
|
| 328 |
+
created_at=current_time,
|
| 329 |
+
last_activity_at=current_time,
|
| 330 |
+
step_count=0,
|
| 331 |
+
environment_type=type(env).__name__,
|
| 332 |
+
)
|
| 333 |
|
| 334 |
+
return session_id, env
|
| 335 |
|
| 336 |
+
async def _destroy_session(self, session_id: str) -> None:
|
| 337 |
"""
|
| 338 |
+
Destroy a WebSocket session and cleanup resources.
|
| 339 |
|
| 340 |
Args:
|
| 341 |
+
session_id: The session ID to destroy
|
| 342 |
+
"""
|
| 343 |
+
async with self._session_lock:
|
| 344 |
+
env = self._sessions.pop(session_id, None)
|
| 345 |
+
executor = self._session_executors.pop(session_id, None)
|
| 346 |
+
self._session_info.pop(session_id, None)
|
| 347 |
|
| 348 |
+
# Run close() in the same executor where the env was created
|
| 349 |
+
# This is required for thread-sensitive libraries like Playwright/greenlet
|
| 350 |
+
if env is not None:
|
| 351 |
+
if executor is not None:
|
| 352 |
+
try:
|
| 353 |
+
loop = asyncio.get_event_loop()
|
| 354 |
+
await loop.run_in_executor(executor, env.close)
|
| 355 |
+
except Exception:
|
| 356 |
+
# If executor close fails, try direct close as fallback
|
| 357 |
+
try:
|
| 358 |
+
env.close()
|
| 359 |
+
except Exception:
|
| 360 |
+
pass # Best effort cleanup
|
| 361 |
+
else:
|
| 362 |
+
try:
|
| 363 |
+
env.close()
|
| 364 |
+
except Exception:
|
| 365 |
+
pass # Best effort cleanup
|
| 366 |
+
|
| 367 |
+
# Shutdown executor after close is done
|
| 368 |
+
if executor is not None:
|
| 369 |
+
executor.shutdown(wait=False)
|
| 370 |
+
|
| 371 |
+
def _update_session_activity(
|
| 372 |
+
self, session_id: str, increment_step: bool = False
|
| 373 |
+
) -> None:
|
| 374 |
+
"""
|
| 375 |
+
Update session activity timestamp and optionally increment step count.
|
| 376 |
|
| 377 |
+
Args:
|
| 378 |
+
session_id: The session ID to update
|
| 379 |
+
increment_step: If True, increment the step count
|
| 380 |
"""
|
| 381 |
+
if session_id in self._session_info:
|
| 382 |
+
self._session_info[session_id].last_activity_at = time.time()
|
| 383 |
+
if increment_step:
|
| 384 |
+
self._session_info[session_id].step_count += 1
|
|
|
|
| 385 |
|
| 386 |
+
def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
|
| 387 |
"""
|
| 388 |
+
Get information about a specific session.
|
| 389 |
|
| 390 |
Args:
|
| 391 |
+
session_id: The session ID to query
|
| 392 |
|
| 393 |
Returns:
|
| 394 |
+
SessionInfo if the session exists, None otherwise
|
| 395 |
+
"""
|
| 396 |
+
return self._session_info.get(session_id)
|
| 397 |
+
|
| 398 |
+
async def _run_in_session_executor(
|
| 399 |
+
self, session_id: str, func: Callable[..., Observation], *args, **kwargs
|
| 400 |
+
) -> Observation:
|
| 401 |
+
"""Run a synchronous function in the session's thread pool executor."""
|
| 402 |
+
executor = self._session_executors.get(session_id, self._executor)
|
| 403 |
+
loop = asyncio.get_event_loop()
|
| 404 |
+
return await loop.run_in_executor(executor, lambda: func(*args, **kwargs))
|
| 405 |
+
|
| 406 |
+
@property
|
| 407 |
+
def active_sessions(self) -> int:
|
| 408 |
+
"""Return the number of active WebSocket sessions."""
|
| 409 |
+
return len(self._sessions)
|
| 410 |
+
|
| 411 |
+
@property
|
| 412 |
+
def max_concurrent_envs(self) -> int:
|
| 413 |
+
"""Return the maximum number of concurrent environments."""
|
| 414 |
+
return self._max_concurrent_envs
|
| 415 |
+
|
| 416 |
+
@property
|
| 417 |
+
def is_concurrency_safe(self) -> bool:
|
| 418 |
+
"""Return whether the environment is marked as concurrency safe."""
|
| 419 |
+
import inspect
|
| 420 |
|
| 421 |
+
if inspect.isclass(self._env_factory):
|
| 422 |
+
return getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False)
|
| 423 |
+
else:
|
| 424 |
+
_temp_env = self._env_factory()
|
| 425 |
+
result = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False)
|
| 426 |
+
_temp_env.close()
|
| 427 |
+
del _temp_env
|
| 428 |
+
return result
|
| 429 |
+
|
| 430 |
+
@property
|
| 431 |
+
def concurrency_config(self) -> ConcurrencyConfig:
|
| 432 |
+
"""Return the concurrency configuration."""
|
| 433 |
+
return self._concurrency_config
|
| 434 |
+
|
| 435 |
+
def register_routes(
|
| 436 |
+
self, app: FastAPI, mode: ServerMode | str = ServerMode.SIMULATION
|
| 437 |
+
) -> None:
|
| 438 |
"""
|
| 439 |
+
Register HTTP routes on a FastAPI application.
|
| 440 |
+
|
| 441 |
+
Args:
|
| 442 |
+
app: FastAPI application instance
|
| 443 |
+
mode: Server mode - either SIMULATION or PRODUCTION (or string equivalents).
|
| 444 |
+
In production mode, simulation control endpoints (/reset, /step, /state)
|
| 445 |
+
are NOT registered. Only safe endpoints (/health, /schema, /metadata, /ws)
|
| 446 |
+
are available. Defaults to SIMULATION for backwards compatibility.
|
| 447 |
+
|
| 448 |
+
Raises:
|
| 449 |
+
ValueError: If mode is not a valid ServerMode or string equivalent.
|
| 450 |
+
"""
|
| 451 |
+
# Convert string to ServerMode enum for backwards compatibility
|
| 452 |
+
if isinstance(mode, str):
|
| 453 |
+
try:
|
| 454 |
+
mode = ServerMode(mode.lower())
|
| 455 |
+
except ValueError:
|
| 456 |
+
valid_modes = [m.value for m in ServerMode]
|
| 457 |
+
raise ValueError(
|
| 458 |
+
f"Invalid mode: '{mode}'. Must be one of: {valid_modes}"
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
# Helper function to handle reset endpoint
|
| 462 |
+
async def reset_handler(
|
| 463 |
+
request: ResetRequest = Body(default_factory=ResetRequest),
|
| 464 |
+
) -> ResetResponse:
|
| 465 |
+
"""Reset endpoint - returns initial observation."""
|
| 466 |
+
_env = self._env_factory()
|
| 467 |
+
|
| 468 |
+
try:
|
| 469 |
+
kwargs = request.model_dump(exclude_unset=True)
|
| 470 |
+
|
| 471 |
+
is_async = _env.reset_async.__func__ is not Environment.reset_async
|
| 472 |
+
|
| 473 |
+
if is_async:
|
| 474 |
+
sig = inspect.signature(_env.reset_async)
|
| 475 |
+
else:
|
| 476 |
+
sig = inspect.signature(_env.reset)
|
| 477 |
+
valid_kwargs = self._get_valid_kwargs(sig, kwargs)
|
| 478 |
+
|
| 479 |
+
if is_async:
|
| 480 |
+
observation = await _env.reset_async(**valid_kwargs)
|
| 481 |
+
else:
|
| 482 |
+
observation = await self._run_sync_in_thread_pool(
|
| 483 |
+
_env.reset, **valid_kwargs
|
| 484 |
+
)
|
| 485 |
+
return ResetResponse(**serialize_observation(observation))
|
| 486 |
+
finally:
|
| 487 |
+
_env.close()
|
| 488 |
+
|
| 489 |
+
# Helper function to handle step endpoint
|
| 490 |
+
async def step_handler(request: StepRequest) -> StepResponse:
|
| 491 |
+
"""Step endpoint - executes action and returns observation."""
|
| 492 |
+
action_data = request.action
|
| 493 |
+
|
| 494 |
+
try:
|
| 495 |
+
action = deserialize_action(action_data, self.action_cls)
|
| 496 |
+
except ValidationError as e:
|
| 497 |
+
raise HTTPException(
|
| 498 |
+
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors()
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
_env = self._env_factory()
|
| 502 |
+
|
| 503 |
+
try:
|
| 504 |
+
kwargs = request.model_dump(exclude_unset=True, exclude={"action"})
|
| 505 |
+
|
| 506 |
+
is_async = _env.step_async.__func__ is not Environment.step_async
|
| 507 |
+
|
| 508 |
+
if is_async:
|
| 509 |
+
sig = inspect.signature(_env.step_async)
|
| 510 |
+
else:
|
| 511 |
+
sig = inspect.signature(_env.step)
|
| 512 |
+
valid_kwargs = self._get_valid_kwargs(
|
| 513 |
+
sig, kwargs, skip_params={"action"}
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
if is_async:
|
| 517 |
+
observation = await _env.step_async(action, **valid_kwargs)
|
| 518 |
+
else:
|
| 519 |
+
observation = await self._run_sync_in_thread_pool(
|
| 520 |
+
_env.step, action, **valid_kwargs
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
return StepResponse(**serialize_observation(observation))
|
| 524 |
+
finally:
|
| 525 |
+
_env.close()
|
| 526 |
+
|
| 527 |
+
# Helper function to handle MCP endpoint
|
| 528 |
+
async def mcp_handler(
|
| 529 |
+
request: JsonRpcRequest, session_env: Optional[Environment] = None
|
| 530 |
+
) -> JsonRpcResponse:
|
| 531 |
+
"""
|
| 532 |
+
Handle MCP JSON-RPC requests.
|
| 533 |
+
|
| 534 |
+
Supports tools/list and tools/call methods in JSON-RPC 2.0 format.
|
| 535 |
+
"""
|
| 536 |
+
method = request.method
|
| 537 |
+
request_id = request.id
|
| 538 |
+
|
| 539 |
+
# Use provided session environment or create temporary one
|
| 540 |
+
if session_env is not None:
|
| 541 |
+
_env = session_env
|
| 542 |
+
should_close = False
|
| 543 |
+
else:
|
| 544 |
+
_env = self._env_factory()
|
| 545 |
+
should_close = True
|
| 546 |
+
try:
|
| 547 |
+
if method == McpMethod.TOOLS_LIST:
|
| 548 |
+
# Check if environment is MCP-enabled
|
| 549 |
+
if not hasattr(_env, "mcp_client"):
|
| 550 |
+
return JsonRpcResponse.error_response(
|
| 551 |
+
JsonRpcErrorCode.INTERNAL_ERROR,
|
| 552 |
+
"Environment does not support MCP",
|
| 553 |
+
request_id=request_id,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
# Use async context manager for MCP client
|
| 557 |
+
async with _env.mcp_client:
|
| 558 |
+
tools = await _env.mcp_client.list_tools()
|
| 559 |
+
|
| 560 |
+
return JsonRpcResponse.success(
|
| 561 |
+
result={
|
| 562 |
+
"tools": [
|
| 563 |
+
t.model_dump() if hasattr(t, "model_dump") else dict(t)
|
| 564 |
+
for t in tools
|
| 565 |
+
]
|
| 566 |
+
},
|
| 567 |
+
request_id=request_id,
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
elif method == McpMethod.TOOLS_CALL:
|
| 571 |
+
params = request.params
|
| 572 |
+
tool_name = params.get("name")
|
| 573 |
+
arguments = params.get("arguments", {})
|
| 574 |
+
|
| 575 |
+
if not hasattr(_env, "mcp_client"):
|
| 576 |
+
return JsonRpcResponse.error_response(
|
| 577 |
+
JsonRpcErrorCode.INTERNAL_ERROR,
|
| 578 |
+
"Environment does not support MCP",
|
| 579 |
+
request_id=request_id,
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
if not tool_name:
|
| 583 |
+
return JsonRpcResponse.error_response(
|
| 584 |
+
JsonRpcErrorCode.INVALID_REQUEST,
|
| 585 |
+
"Missing 'name' in params",
|
| 586 |
+
request_id=request_id,
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
# Use async context manager for MCP client
|
| 590 |
+
async with _env.mcp_client:
|
| 591 |
+
result = await _env.mcp_client.call_tool(
|
| 592 |
+
name=tool_name, arguments=arguments
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
# Ensure result is JSON serializable
|
| 596 |
+
serializable_result = _make_json_serializable(result)
|
| 597 |
|
| 598 |
+
return JsonRpcResponse.success(
|
| 599 |
+
result=serializable_result,
|
| 600 |
+
request_id=request_id,
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
else:
|
| 604 |
+
return JsonRpcResponse.error_response(
|
| 605 |
+
JsonRpcErrorCode.METHOD_NOT_FOUND,
|
| 606 |
+
f"Method not found: {method}",
|
| 607 |
+
request_id=request_id,
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
except Exception as e:
|
| 611 |
+
return JsonRpcResponse.error_response(
|
| 612 |
+
JsonRpcErrorCode.INTERNAL_ERROR,
|
| 613 |
+
str(e),
|
| 614 |
+
request_id=request_id,
|
| 615 |
+
)
|
| 616 |
+
finally:
|
| 617 |
+
if should_close:
|
| 618 |
+
_env.close()
|
| 619 |
+
|
| 620 |
+
# Register MCP WebSocket endpoint (available in both production and simulation modes)
|
| 621 |
+
@app.websocket("/mcp")
|
| 622 |
+
async def mcp_websocket_endpoint(websocket: WebSocket):
|
| 623 |
+
"""
|
| 624 |
+
WebSocket endpoint for MCP JSON-RPC requests.
|
| 625 |
+
|
| 626 |
+
Each WebSocket connection gets its own environment instance for MCP operations.
|
| 627 |
+
|
| 628 |
+
Message Protocol:
|
| 629 |
+
- Client sends: JSON-RPC 2.0 request (tools/list, tools/call)
|
| 630 |
+
- Server responds: JSON-RPC 2.0 response (result or error)
|
| 631 |
+
"""
|
| 632 |
+
await websocket.accept()
|
| 633 |
+
|
| 634 |
+
session_id = None
|
| 635 |
+
session_env = None
|
| 636 |
+
|
| 637 |
+
try:
|
| 638 |
+
# Create session with dedicated environment
|
| 639 |
+
session_id, session_env = await self._create_session()
|
| 640 |
+
|
| 641 |
+
while True:
|
| 642 |
+
# Receive message from client
|
| 643 |
+
raw_message = await websocket.receive_text()
|
| 644 |
+
|
| 645 |
+
try:
|
| 646 |
+
jsonrpc_dict = json.loads(raw_message)
|
| 647 |
+
jsonrpc_request = JsonRpcRequest(**jsonrpc_dict)
|
| 648 |
+
except json.JSONDecodeError as e:
|
| 649 |
+
error_resp = JsonRpcResponse.error_response(
|
| 650 |
+
JsonRpcErrorCode.PARSE_ERROR,
|
| 651 |
+
f"Parse error: {e}",
|
| 652 |
+
)
|
| 653 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 654 |
+
continue
|
| 655 |
+
except ValidationError as e:
|
| 656 |
+
error_resp = JsonRpcResponse.error_response(
|
| 657 |
+
JsonRpcErrorCode.INVALID_REQUEST,
|
| 658 |
+
f"Invalid request: {e}",
|
| 659 |
+
)
|
| 660 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 661 |
+
continue
|
| 662 |
+
|
| 663 |
+
try:
|
| 664 |
+
# Call mcp_handler with session environment
|
| 665 |
+
response = await mcp_handler(
|
| 666 |
+
jsonrpc_request, session_env=session_env
|
| 667 |
+
)
|
| 668 |
+
await websocket.send_text(response.model_dump_json())
|
| 669 |
+
except Exception as e:
|
| 670 |
+
error_resp = JsonRpcResponse.error_response(
|
| 671 |
+
JsonRpcErrorCode.INTERNAL_ERROR,
|
| 672 |
+
str(e),
|
| 673 |
+
request_id=jsonrpc_request.id,
|
| 674 |
+
)
|
| 675 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 676 |
+
|
| 677 |
+
except WebSocketDisconnect:
|
| 678 |
+
pass
|
| 679 |
+
except SessionCapacityError as e:
|
| 680 |
+
error_resp = JsonRpcResponse.error_response(
|
| 681 |
+
JsonRpcErrorCode.SERVER_ERROR,
|
| 682 |
+
str(e),
|
| 683 |
+
data={
|
| 684 |
+
"active_sessions": e.active_sessions,
|
| 685 |
+
"max_sessions": e.max_sessions,
|
| 686 |
+
},
|
| 687 |
+
)
|
| 688 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 689 |
+
except EnvironmentFactoryError as e:
|
| 690 |
+
error_resp = JsonRpcResponse.error_response(
|
| 691 |
+
JsonRpcErrorCode.SERVER_ERROR,
|
| 692 |
+
str(e),
|
| 693 |
+
data={"factory_name": e.factory_name},
|
| 694 |
+
)
|
| 695 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 696 |
+
except Exception as e:
|
| 697 |
+
error_resp = JsonRpcResponse.error_response(
|
| 698 |
+
JsonRpcErrorCode.SERVER_ERROR,
|
| 699 |
+
str(e),
|
| 700 |
+
)
|
| 701 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 702 |
+
finally:
|
| 703 |
+
if session_id:
|
| 704 |
+
await self._destroy_session(session_id)
|
| 705 |
+
try:
|
| 706 |
+
await websocket.close()
|
| 707 |
+
except RuntimeError:
|
| 708 |
+
pass
|
| 709 |
+
|
| 710 |
+
# Register simulation control routes only in simulation mode
|
| 711 |
+
if mode == ServerMode.SIMULATION:
|
| 712 |
+
|
| 713 |
+
@app.post(
|
| 714 |
+
"/reset",
|
| 715 |
+
response_model=ResetResponse,
|
| 716 |
+
tags=["Environment Control"],
|
| 717 |
+
summary="Reset the environment",
|
| 718 |
+
description="""
|
| 719 |
+
Reset the environment to its initial state and return the first observation.
|
| 720 |
+
|
| 721 |
+
You can optionally provide a seed for reproducibility and an episode_id for tracking.
|
| 722 |
+
""",
|
| 723 |
+
responses={
|
| 724 |
+
200: {
|
| 725 |
+
"description": "Environment reset successfully",
|
| 726 |
+
"content": {
|
| 727 |
+
"application/json": {
|
| 728 |
+
"example": {
|
| 729 |
+
"observation": {"status": "ready", "data": {}},
|
| 730 |
+
"reward": None,
|
| 731 |
+
"done": False,
|
| 732 |
+
}
|
| 733 |
+
}
|
| 734 |
+
},
|
| 735 |
+
}
|
| 736 |
+
},
|
| 737 |
+
)
|
| 738 |
+
async def reset(
|
| 739 |
+
request: ResetRequest = Body(default_factory=ResetRequest),
|
| 740 |
+
) -> ResetResponse:
|
| 741 |
+
return await reset_handler(request)
|
| 742 |
+
|
| 743 |
+
@app.post(
|
| 744 |
+
"/step",
|
| 745 |
+
response_model=StepResponse,
|
| 746 |
+
tags=["Environment Control"],
|
| 747 |
+
summary="Execute an action in the environment",
|
| 748 |
+
description="""
|
| 749 |
+
Execute an action in the environment and receive the resulting observation.
|
| 750 |
+
|
| 751 |
+
The action must conform to the environment's action schema, which can be
|
| 752 |
+
retrieved from the `/schema` endpoint. If the action is invalid,
|
| 753 |
+
the endpoint will return HTTP 422 with detailed validation errors.
|
| 754 |
+
|
| 755 |
+
The response includes:
|
| 756 |
+
- **observation**: The environment's response to the action
|
| 757 |
+
- **reward**: Optional reward signal (float or None)
|
| 758 |
+
- **done**: Boolean indicating if the episode has terminated
|
| 759 |
+
""",
|
| 760 |
+
responses={
|
| 761 |
+
200: {
|
| 762 |
+
"description": "Action executed successfully",
|
| 763 |
+
"content": {
|
| 764 |
+
"application/json": {
|
| 765 |
+
"example": {
|
| 766 |
+
"observation": {"status": "success", "data": {}},
|
| 767 |
+
"reward": 1.0,
|
| 768 |
+
"done": False,
|
| 769 |
+
}
|
| 770 |
+
}
|
| 771 |
+
},
|
| 772 |
+
},
|
| 773 |
+
422: {
|
| 774 |
+
"description": "Validation error - invalid action format or values",
|
| 775 |
+
"content": {
|
| 776 |
+
"application/json": {
|
| 777 |
+
"example": {
|
| 778 |
+
"detail": [
|
| 779 |
+
{
|
| 780 |
+
"type": "string_too_short",
|
| 781 |
+
"loc": ["body", "action", "message"],
|
| 782 |
+
"msg": "String should have at least 1 character",
|
| 783 |
+
"input": "",
|
| 784 |
+
}
|
| 785 |
+
]
|
| 786 |
+
}
|
| 787 |
+
}
|
| 788 |
+
},
|
| 789 |
+
},
|
| 790 |
+
500: {
|
| 791 |
+
"description": "Internal server error during action execution"
|
| 792 |
+
},
|
| 793 |
+
},
|
| 794 |
+
)
|
| 795 |
+
async def step(request: StepRequest) -> StepResponse:
|
| 796 |
+
return await step_handler(request)
|
| 797 |
+
|
| 798 |
+
def get_state_handler() -> State:
|
| 799 |
+
_env = self._env_factory()
|
| 800 |
+
try:
|
| 801 |
+
return _env.state
|
| 802 |
+
finally:
|
| 803 |
+
_env.close()
|
| 804 |
+
|
| 805 |
+
def get_metadata_handler() -> EnvironmentMetadata:
|
| 806 |
+
_env = self._env_factory()
|
| 807 |
+
try:
|
| 808 |
+
return _env.get_metadata()
|
| 809 |
+
finally:
|
| 810 |
+
_env.close()
|
| 811 |
+
|
| 812 |
+
# Build list of GET endpoints based on mode
|
| 813 |
+
get_endpoints = [
|
| 814 |
+
GetEndpointConfig(
|
| 815 |
+
path="/metadata",
|
| 816 |
+
handler=get_metadata_handler,
|
| 817 |
+
response_model=EnvironmentMetadata,
|
| 818 |
+
tag="Environment Info",
|
| 819 |
+
summary="Get environment metadata",
|
| 820 |
+
description="""
|
| 821 |
+
Get metadata about this environment.
|
| 822 |
+
|
| 823 |
+
Returns information about the environment including name, description,
|
| 824 |
+
version, author, and documentation links.
|
| 825 |
+
""",
|
| 826 |
+
),
|
| 827 |
+
GetEndpointConfig(
|
| 828 |
+
path="/health",
|
| 829 |
+
handler=lambda: HealthResponse(status=HealthStatus.HEALTHY),
|
| 830 |
+
response_model=HealthResponse,
|
| 831 |
+
tag="Health",
|
| 832 |
+
summary="Health check",
|
| 833 |
+
description="Check if the environment server is running and healthy.",
|
| 834 |
+
),
|
| 835 |
+
]
|
| 836 |
+
|
| 837 |
+
# Only register /state endpoint in simulation mode
|
| 838 |
+
if mode == ServerMode.SIMULATION:
|
| 839 |
+
get_endpoints.insert(
|
| 840 |
+
0,
|
| 841 |
+
GetEndpointConfig(
|
| 842 |
+
path="/state",
|
| 843 |
+
handler=get_state_handler,
|
| 844 |
+
response_model=State,
|
| 845 |
+
tag="State Management",
|
| 846 |
+
summary="Get current environment state",
|
| 847 |
+
description="""
|
| 848 |
+
Retrieve the current internal state of the environment.
|
| 849 |
+
|
| 850 |
+
The structure of the state object is defined by the environment's State model.
|
| 851 |
+
""",
|
| 852 |
+
),
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
register_get_endpoints(app, get_endpoints)
|
| 856 |
+
|
| 857 |
+
# Register combined schema endpoint
|
| 858 |
+
@app.get(
|
| 859 |
+
"/schema",
|
| 860 |
+
response_model=SchemaResponse,
|
| 861 |
+
tags=["Schema"],
|
| 862 |
+
summary="Get all JSON schemas",
|
| 863 |
+
description="""
|
| 864 |
+
Get JSON schemas for actions, observations, and state in a single response.
|
| 865 |
+
|
| 866 |
+
Returns a combined schema object containing:
|
| 867 |
+
- **action**: JSON schema for actions accepted by this environment
|
| 868 |
+
- **observation**: JSON schema for observations returned by this environment
|
| 869 |
+
- **state**: JSON schema for environment state objects
|
| 870 |
+
|
| 871 |
+
This is more efficient than calling individual schema endpoints and provides
|
| 872 |
+
all schema information needed to interact with the environment.
|
| 873 |
+
""",
|
| 874 |
+
responses={
|
| 875 |
+
200: {
|
| 876 |
+
"description": "Combined schemas retrieved successfully",
|
| 877 |
+
"content": {
|
| 878 |
+
"application/json": {
|
| 879 |
+
"example": {
|
| 880 |
+
"action": {
|
| 881 |
+
"type": "object",
|
| 882 |
+
"properties": {"message": {"type": "string"}},
|
| 883 |
+
},
|
| 884 |
+
"observation": {
|
| 885 |
+
"type": "object",
|
| 886 |
+
"properties": {"response": {"type": "string"}},
|
| 887 |
+
},
|
| 888 |
+
"state": {
|
| 889 |
+
"type": "object",
|
| 890 |
+
"properties": {"step_count": {"type": "integer"}},
|
| 891 |
+
},
|
| 892 |
+
}
|
| 893 |
+
}
|
| 894 |
+
},
|
| 895 |
+
}
|
| 896 |
+
},
|
| 897 |
+
)
|
| 898 |
+
async def get_schemas() -> SchemaResponse:
|
| 899 |
+
"""Return all schemas in one response."""
|
| 900 |
+
return SchemaResponse(
|
| 901 |
+
action=self.action_cls.model_json_schema(),
|
| 902 |
+
observation=self.observation_cls.model_json_schema(),
|
| 903 |
+
state=State.model_json_schema(),
|
| 904 |
+
)
|
| 905 |
+
|
| 906 |
+
# Register MCP endpoint for production mode (direct MCP access)
|
| 907 |
+
@app.post("/mcp")
|
| 908 |
+
async def mcp_endpoint(request_raw: Request) -> Dict[str, Any]:
|
| 909 |
+
"""
|
| 910 |
+
MCP JSON-RPC endpoint for production mode.
|
| 911 |
+
|
| 912 |
+
Bypasses step() overhead and provides direct access to MCP tools.
|
| 913 |
+
Supports tools/list and tools/call methods.
|
| 914 |
+
"""
|
| 915 |
+
# Parse JSON manually to handle parse errors gracefully
|
| 916 |
+
try:
|
| 917 |
+
body = await request_raw.body()
|
| 918 |
+
request_dict = json.loads(body)
|
| 919 |
+
request = JsonRpcRequest(**request_dict)
|
| 920 |
+
except json.JSONDecodeError:
|
| 921 |
+
return JsonRpcResponse.error_response(
|
| 922 |
+
JsonRpcErrorCode.PARSE_ERROR
|
| 923 |
+
).model_dump()
|
| 924 |
+
except ValidationError as e:
|
| 925 |
+
return JsonRpcResponse.error_response(
|
| 926 |
+
JsonRpcErrorCode.INVALID_REQUEST,
|
| 927 |
+
f"Invalid request: {e}",
|
| 928 |
+
).model_dump()
|
| 929 |
+
except Exception:
|
| 930 |
+
return JsonRpcResponse.error_response(
|
| 931 |
+
JsonRpcErrorCode.PARSE_ERROR
|
| 932 |
+
).model_dump()
|
| 933 |
+
|
| 934 |
+
method = request.method
|
| 935 |
+
params = request.params
|
| 936 |
+
request_id = request.id
|
| 937 |
+
|
| 938 |
+
# Create a temporary environment for MCP access
|
| 939 |
+
_env = self._env_factory()
|
| 940 |
+
|
| 941 |
+
try:
|
| 942 |
+
# Check if environment supports MCP
|
| 943 |
+
if not hasattr(_env, "mcp_client") and not hasattr(_env, "mcp_server"):
|
| 944 |
+
return JsonRpcResponse.error_response(
|
| 945 |
+
JsonRpcErrorCode.INTERNAL_ERROR,
|
| 946 |
+
"Environment does not support MCP",
|
| 947 |
+
request_id=request_id,
|
| 948 |
+
).model_dump()
|
| 949 |
+
|
| 950 |
+
if method == McpMethod.TOOLS_LIST:
|
| 951 |
+
# List tools from MCP server
|
| 952 |
+
if hasattr(_env, "mcp_client") and _env.mcp_client:
|
| 953 |
+
async with _env.mcp_client:
|
| 954 |
+
tools = await _env.mcp_client.list_tools()
|
| 955 |
+
return JsonRpcResponse.success(
|
| 956 |
+
result={
|
| 957 |
+
"tools": [
|
| 958 |
+
t.model_dump()
|
| 959 |
+
if hasattr(t, "model_dump")
|
| 960 |
+
else dict(t)
|
| 961 |
+
for t in tools
|
| 962 |
+
]
|
| 963 |
+
},
|
| 964 |
+
request_id=request_id,
|
| 965 |
+
).model_dump()
|
| 966 |
+
elif hasattr(_env, "mcp_server") and _env.mcp_server:
|
| 967 |
+
# Use server directly
|
| 968 |
+
tools = []
|
| 969 |
+
for tool_name, tool in get_server_tools(
|
| 970 |
+
_env.mcp_server
|
| 971 |
+
).items():
|
| 972 |
+
tool_dict = {
|
| 973 |
+
"name": tool.name,
|
| 974 |
+
"description": tool.description or "",
|
| 975 |
+
"inputSchema": tool.parameters or {},
|
| 976 |
+
}
|
| 977 |
+
tools.append(tool_dict)
|
| 978 |
+
return JsonRpcResponse.success(
|
| 979 |
+
result={"tools": tools},
|
| 980 |
+
request_id=request_id,
|
| 981 |
+
).model_dump()
|
| 982 |
+
else:
|
| 983 |
+
return JsonRpcResponse.error_response(
|
| 984 |
+
JsonRpcErrorCode.INTERNAL_ERROR,
|
| 985 |
+
"MCP server not available",
|
| 986 |
+
request_id=request_id,
|
| 987 |
+
).model_dump()
|
| 988 |
+
|
| 989 |
+
elif method == McpMethod.TOOLS_CALL:
|
| 990 |
+
tool_name = params.get("name")
|
| 991 |
+
arguments = params.get("arguments", {})
|
| 992 |
+
|
| 993 |
+
if not tool_name:
|
| 994 |
+
return JsonRpcResponse.error_response(
|
| 995 |
+
JsonRpcErrorCode.INVALID_PARAMS,
|
| 996 |
+
"Invalid params - 'name' is required",
|
| 997 |
+
request_id=request_id,
|
| 998 |
+
).model_dump()
|
| 999 |
+
|
| 1000 |
+
# Call tool via MCP
|
| 1001 |
+
if hasattr(_env, "mcp_client") and _env.mcp_client:
|
| 1002 |
+
async with _env.mcp_client:
|
| 1003 |
+
result = await _env.mcp_client.call_tool(
|
| 1004 |
+
name=tool_name, arguments=arguments
|
| 1005 |
+
)
|
| 1006 |
+
elif hasattr(_env, "mcp_server") and _env.mcp_server:
|
| 1007 |
+
# Call tool directly on FastMCP server
|
| 1008 |
+
server_tools = get_server_tools(_env.mcp_server)
|
| 1009 |
+
if tool_name in server_tools:
|
| 1010 |
+
tool = server_tools[tool_name]
|
| 1011 |
+
result = tool.fn(**arguments)
|
| 1012 |
+
else:
|
| 1013 |
+
return JsonRpcResponse.error_response(
|
| 1014 |
+
JsonRpcErrorCode.INVALID_PARAMS,
|
| 1015 |
+
f"Tool not found: {tool_name}",
|
| 1016 |
+
request_id=request_id,
|
| 1017 |
+
).model_dump()
|
| 1018 |
+
else:
|
| 1019 |
+
return JsonRpcResponse.error_response(
|
| 1020 |
+
JsonRpcErrorCode.INTERNAL_ERROR,
|
| 1021 |
+
"MCP server not available",
|
| 1022 |
+
request_id=request_id,
|
| 1023 |
+
).model_dump()
|
| 1024 |
+
|
| 1025 |
+
# Make result JSON serializable
|
| 1026 |
+
serializable_result = _make_json_serializable(result)
|
| 1027 |
+
|
| 1028 |
+
return JsonRpcResponse.success(
|
| 1029 |
+
result=serializable_result,
|
| 1030 |
+
request_id=request_id,
|
| 1031 |
+
).model_dump()
|
| 1032 |
+
|
| 1033 |
+
else:
|
| 1034 |
+
return JsonRpcResponse.error_response(
|
| 1035 |
+
JsonRpcErrorCode.METHOD_NOT_FOUND,
|
| 1036 |
+
f"Method not found: {method}",
|
| 1037 |
+
request_id=request_id,
|
| 1038 |
+
).model_dump()
|
| 1039 |
+
|
| 1040 |
+
except Exception as e:
|
| 1041 |
+
return JsonRpcResponse.error_response(
|
| 1042 |
+
JsonRpcErrorCode.INTERNAL_ERROR,
|
| 1043 |
+
str(e),
|
| 1044 |
+
request_id=request_id,
|
| 1045 |
+
).model_dump()
|
| 1046 |
+
finally:
|
| 1047 |
+
_env.close()
|
| 1048 |
+
|
| 1049 |
+
# Register WebSocket endpoint for persistent sessions
|
| 1050 |
+
@app.websocket("/ws")
|
| 1051 |
+
async def websocket_endpoint(websocket: WebSocket):
|
| 1052 |
+
"""
|
| 1053 |
+
WebSocket endpoint for persistent environment sessions.
|
| 1054 |
+
|
| 1055 |
+
Each WebSocket connection gets its own environment instance.
|
| 1056 |
+
|
| 1057 |
+
Message Protocol:
|
| 1058 |
+
- Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage
|
| 1059 |
+
- Server responds: WSObservationResponse | WSStateResponse | WSErrorResponse
|
| 1060 |
+
"""
|
| 1061 |
+
await websocket.accept()
|
| 1062 |
+
|
| 1063 |
+
session_id = None
|
| 1064 |
+
session_env = None
|
| 1065 |
+
|
| 1066 |
+
try:
|
| 1067 |
+
# Create session with dedicated environment
|
| 1068 |
+
session_id, session_env = await self._create_session()
|
| 1069 |
+
|
| 1070 |
+
while True:
|
| 1071 |
+
# Receive message from client
|
| 1072 |
+
raw_message = await websocket.receive_text()
|
| 1073 |
+
|
| 1074 |
+
try:
|
| 1075 |
+
message_dict = json.loads(raw_message)
|
| 1076 |
+
except json.JSONDecodeError as e:
|
| 1077 |
+
error_resp = WSErrorResponse(
|
| 1078 |
+
data={
|
| 1079 |
+
"message": f"Invalid JSON: {e}",
|
| 1080 |
+
"code": WSErrorCode.INVALID_JSON,
|
| 1081 |
+
}
|
| 1082 |
+
)
|
| 1083 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 1084 |
+
continue
|
| 1085 |
+
|
| 1086 |
+
msg_type = message_dict.get("type", "")
|
| 1087 |
+
|
| 1088 |
+
try:
|
| 1089 |
+
match msg_type:
|
| 1090 |
+
case "reset":
|
| 1091 |
+
msg = WSResetMessage(**message_dict)
|
| 1092 |
+
|
| 1093 |
+
is_async = (
|
| 1094 |
+
session_env.reset_async.__func__
|
| 1095 |
+
is not Environment.reset_async
|
| 1096 |
+
)
|
| 1097 |
+
|
| 1098 |
+
if is_async:
|
| 1099 |
+
sig = inspect.signature(session_env.reset_async)
|
| 1100 |
+
valid_kwargs = self._get_valid_kwargs(sig, msg.data)
|
| 1101 |
+
observation = await session_env.reset_async(
|
| 1102 |
+
**valid_kwargs
|
| 1103 |
+
)
|
| 1104 |
+
else:
|
| 1105 |
+
sig = inspect.signature(session_env.reset)
|
| 1106 |
+
valid_kwargs = self._get_valid_kwargs(sig, msg.data)
|
| 1107 |
+
observation = await self._run_in_session_executor(
|
| 1108 |
+
session_id, session_env.reset, **valid_kwargs
|
| 1109 |
+
)
|
| 1110 |
+
|
| 1111 |
+
self._update_session_activity(session_id)
|
| 1112 |
+
|
| 1113 |
+
response = WSObservationResponse(
|
| 1114 |
+
data=serialize_observation(observation),
|
| 1115 |
+
)
|
| 1116 |
+
|
| 1117 |
+
case "step":
|
| 1118 |
+
msg = WSStepMessage(**message_dict)
|
| 1119 |
+
action = deserialize_action(msg.data, self.action_cls)
|
| 1120 |
+
|
| 1121 |
+
is_async = (
|
| 1122 |
+
session_env.step_async.__func__
|
| 1123 |
+
is not Environment.step_async
|
| 1124 |
+
)
|
| 1125 |
+
|
| 1126 |
+
if is_async:
|
| 1127 |
+
observation = await session_env.step_async(action)
|
| 1128 |
+
else:
|
| 1129 |
+
observation = await self._run_in_session_executor(
|
| 1130 |
+
session_id, session_env.step, action
|
| 1131 |
+
)
|
| 1132 |
+
|
| 1133 |
+
self._update_session_activity(
|
| 1134 |
+
session_id, increment_step=True
|
| 1135 |
+
)
|
| 1136 |
+
|
| 1137 |
+
response = WSObservationResponse(
|
| 1138 |
+
data=serialize_observation(observation)
|
| 1139 |
+
)
|
| 1140 |
+
|
| 1141 |
+
case "state":
|
| 1142 |
+
msg = WSStateMessage(**message_dict)
|
| 1143 |
+
state = session_env.state
|
| 1144 |
+
if hasattr(state, "model_dump"):
|
| 1145 |
+
state_data = state.model_dump()
|
| 1146 |
+
else:
|
| 1147 |
+
state_data = dict(state) if state else {}
|
| 1148 |
+
|
| 1149 |
+
response = WSStateResponse(data=state_data)
|
| 1150 |
+
|
| 1151 |
+
case "close":
|
| 1152 |
+
msg = WSCloseMessage(**message_dict)
|
| 1153 |
+
break
|
| 1154 |
+
|
| 1155 |
+
case "mcp":
|
| 1156 |
+
msg = WSMCPMessage(**message_dict)
|
| 1157 |
+
try:
|
| 1158 |
+
rpc_request = JsonRpcRequest(**msg.data)
|
| 1159 |
+
except (ValidationError, Exception) as e:
|
| 1160 |
+
rpc_response = JsonRpcResponse.error_response(
|
| 1161 |
+
JsonRpcErrorCode.INVALID_REQUEST,
|
| 1162 |
+
f"Invalid request: {e}",
|
| 1163 |
+
)
|
| 1164 |
+
else:
|
| 1165 |
+
rpc_response = await mcp_handler(
|
| 1166 |
+
rpc_request,
|
| 1167 |
+
session_env=session_env,
|
| 1168 |
+
)
|
| 1169 |
+
response = WSMCPResponse(data=rpc_response.model_dump())
|
| 1170 |
+
|
| 1171 |
+
case _:
|
| 1172 |
+
response = WSErrorResponse(
|
| 1173 |
+
data={
|
| 1174 |
+
"message": f"Unknown message type: {msg_type}",
|
| 1175 |
+
"code": WSErrorCode.UNKNOWN_TYPE,
|
| 1176 |
+
}
|
| 1177 |
+
)
|
| 1178 |
+
|
| 1179 |
+
await websocket.send_text(response.model_dump_json())
|
| 1180 |
+
|
| 1181 |
+
except ValidationError as e:
|
| 1182 |
+
error_resp = WSErrorResponse(
|
| 1183 |
+
data={
|
| 1184 |
+
"message": "Invalid message",
|
| 1185 |
+
"code": WSErrorCode.VALIDATION_ERROR,
|
| 1186 |
+
"errors": e.errors(),
|
| 1187 |
+
}
|
| 1188 |
+
)
|
| 1189 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 1190 |
+
except Exception as e:
|
| 1191 |
+
error_resp = WSErrorResponse(
|
| 1192 |
+
data={
|
| 1193 |
+
"message": str(e),
|
| 1194 |
+
"code": WSErrorCode.EXECUTION_ERROR,
|
| 1195 |
+
}
|
| 1196 |
+
)
|
| 1197 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 1198 |
+
|
| 1199 |
+
except WebSocketDisconnect:
|
| 1200 |
+
pass
|
| 1201 |
+
except SessionCapacityError as e:
|
| 1202 |
+
error_resp = WSErrorResponse(
|
| 1203 |
+
data={
|
| 1204 |
+
"message": str(e),
|
| 1205 |
+
"code": WSErrorCode.CAPACITY_REACHED,
|
| 1206 |
+
"active_sessions": e.active_sessions,
|
| 1207 |
+
"max_sessions": e.max_sessions,
|
| 1208 |
+
}
|
| 1209 |
+
)
|
| 1210 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 1211 |
+
except EnvironmentFactoryError as e:
|
| 1212 |
+
error_resp = WSErrorResponse(
|
| 1213 |
+
data={
|
| 1214 |
+
"message": str(e),
|
| 1215 |
+
"code": WSErrorCode.FACTORY_ERROR,
|
| 1216 |
+
"factory_name": e.factory_name,
|
| 1217 |
+
}
|
| 1218 |
+
)
|
| 1219 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 1220 |
+
except Exception as e:
|
| 1221 |
+
error_resp = WSErrorResponse(
|
| 1222 |
+
data={"message": str(e), "code": WSErrorCode.SESSION_ERROR}
|
| 1223 |
+
)
|
| 1224 |
+
await websocket.send_text(error_resp.model_dump_json())
|
| 1225 |
+
finally:
|
| 1226 |
+
if session_id:
|
| 1227 |
+
await self._destroy_session(session_id)
|
| 1228 |
+
try:
|
| 1229 |
+
await websocket.close()
|
| 1230 |
+
except RuntimeError:
|
| 1231 |
+
pass
|
| 1232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1233 |
|
| 1234 |
def create_app(
|
| 1235 |
+
env: Callable[[], Environment],
|
| 1236 |
action_cls: Type[Action],
|
| 1237 |
observation_cls: Type[Observation],
|
| 1238 |
env_name: Optional[str] = None,
|
| 1239 |
+
max_concurrent_envs: Optional[int] = None,
|
| 1240 |
+
concurrency_config: Optional[ConcurrencyConfig] = None,
|
| 1241 |
+
gradio_builder: Optional[Callable[..., Any]] = None,
|
| 1242 |
+
) -> FastAPI:
|
| 1243 |
"""
|
| 1244 |
Create a FastAPI application with or without web interface.
|
| 1245 |
+
|
| 1246 |
This function creates a FastAPI app with the web interface enabled by default,
|
| 1247 |
including README integration for better user experience.
|
| 1248 |
+
|
| 1249 |
Args:
|
| 1250 |
+
env: Environment factory (callable) that creates new instances
|
| 1251 |
action_cls: The Action subclass this environment expects
|
| 1252 |
observation_cls: The Observation subclass this environment returns
|
| 1253 |
env_name: Optional environment name for README loading
|
| 1254 |
+
max_concurrent_envs: Maximum concurrent WebSocket sessions.
|
| 1255 |
+
Mutually exclusive with concurrency_config.
|
| 1256 |
+
concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings.
|
| 1257 |
+
Mutually exclusive with max_concurrent_envs.
|
| 1258 |
+
gradio_builder: Optional callable to build a custom Gradio UI at /web.
|
| 1259 |
+
Signature: (web_manager, action_fields, metadata, is_chat_env, title,
|
| 1260 |
+
quick_start_md) -> gr.Blocks. When None, the default Gradio app is used.
|
| 1261 |
+
See docs/customizing-web-ui.md.
|
| 1262 |
+
|
| 1263 |
Returns:
|
| 1264 |
FastAPI application instance with or without web interface and README integration
|
| 1265 |
"""
|
| 1266 |
# Check if web interface should be enabled
|
| 1267 |
# This can be controlled via environment variable or build argument
|
| 1268 |
+
enable_web = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in (
|
| 1269 |
+
"true",
|
| 1270 |
+
"1",
|
| 1271 |
+
"yes",
|
| 1272 |
)
|
| 1273 |
|
| 1274 |
if enable_web:
|
| 1275 |
+
# Gradio-based web UI (gradio is a core dependency)
|
| 1276 |
from .web_interface import create_web_interface_app
|
| 1277 |
+
|
| 1278 |
+
return create_web_interface_app(
|
| 1279 |
+
env,
|
| 1280 |
+
action_cls,
|
| 1281 |
+
observation_cls,
|
| 1282 |
+
env_name,
|
| 1283 |
+
max_concurrent_envs,
|
| 1284 |
+
concurrency_config,
|
| 1285 |
+
gradio_builder=gradio_builder,
|
| 1286 |
+
)
|
| 1287 |
else:
|
| 1288 |
# Use standard FastAPI app without web interface
|
| 1289 |
+
return create_fastapi_app(
|
| 1290 |
+
env, action_cls, observation_cls, max_concurrent_envs, concurrency_config
|
| 1291 |
+
)
|
| 1292 |
+
|
| 1293 |
|
| 1294 |
def create_fastapi_app(
|
| 1295 |
+
env: Callable[[], Environment],
|
| 1296 |
action_cls: Type[Action],
|
| 1297 |
observation_cls: Type[Observation],
|
| 1298 |
+
max_concurrent_envs: Optional[int] = None,
|
| 1299 |
+
concurrency_config: Optional[ConcurrencyConfig] = None,
|
| 1300 |
+
) -> FastAPI:
|
| 1301 |
"""
|
| 1302 |
+
Create a FastAPI application with comprehensive documentation.
|
| 1303 |
|
| 1304 |
Args:
|
| 1305 |
+
env: Environment factory (callable) that creates new instances
|
| 1306 |
action_cls: The Action subclass this environment expects
|
| 1307 |
observation_cls: The Observation subclass this environment returns
|
| 1308 |
+
max_concurrent_envs: Maximum concurrent WebSocket sessions.
|
| 1309 |
+
Mutually exclusive with concurrency_config.
|
| 1310 |
+
concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings.
|
| 1311 |
+
Mutually exclusive with max_concurrent_envs.
|
| 1312 |
|
| 1313 |
Returns:
|
| 1314 |
+
FastAPI application instance
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1315 |
"""
|
| 1316 |
try:
|
| 1317 |
from fastapi import FastAPI
|
|
|
|
| 1320 |
"FastAPI is required. Install with: pip install fastapi uvicorn"
|
| 1321 |
)
|
| 1322 |
|
| 1323 |
+
app = FastAPI(
|
| 1324 |
+
title="OpenEnv Environment HTTP API",
|
| 1325 |
+
version="1.0.0",
|
| 1326 |
+
description="""
|
| 1327 |
+
# OpenEnv Environment HTTP API
|
| 1328 |
+
|
| 1329 |
+
HTTP API for interacting with OpenEnv environments through a standardized interface.
|
| 1330 |
+
|
| 1331 |
+
## Features
|
| 1332 |
+
|
| 1333 |
+
* **Environment Reset**: Initialize or restart episodes
|
| 1334 |
+
* **Action Execution**: Send actions and receive observations
|
| 1335 |
+
* **State Inspection**: Query current environment state
|
| 1336 |
+
* **Schema Access**: Retrieve JSON schemas for actions and observations
|
| 1337 |
+
|
| 1338 |
+
## Workflow
|
| 1339 |
+
|
| 1340 |
+
1. Call `/reset` to start a new episode and get initial observation
|
| 1341 |
+
2. Call `/step` repeatedly with actions to interact with environment
|
| 1342 |
+
3. Episode ends when observation returns `done: true`
|
| 1343 |
+
4. Call `/state` anytime to inspect current environment state
|
| 1344 |
+
|
| 1345 |
+
## Documentation
|
| 1346 |
+
|
| 1347 |
+
* **Swagger UI**: Available at `/docs`
|
| 1348 |
+
* **ReDoc**: Available at `/redoc`
|
| 1349 |
+
* **OpenAPI Schema**: Available at `/openapi.json`
|
| 1350 |
+
""",
|
| 1351 |
+
openapi_tags=[
|
| 1352 |
+
{
|
| 1353 |
+
"name": "Environment Control",
|
| 1354 |
+
"description": "Core operations for environment interaction (reset, step)",
|
| 1355 |
+
},
|
| 1356 |
+
{
|
| 1357 |
+
"name": "State Management",
|
| 1358 |
+
"description": "Operations for inspecting environment state",
|
| 1359 |
+
},
|
| 1360 |
+
{
|
| 1361 |
+
"name": "Environment Info",
|
| 1362 |
+
"description": "Information about the environment",
|
| 1363 |
+
},
|
| 1364 |
+
{
|
| 1365 |
+
"name": "Schema",
|
| 1366 |
+
"description": "JSON Schema endpoints for actions, observations, and state",
|
| 1367 |
+
},
|
| 1368 |
+
{"name": "Health", "description": "Service health and status checks"},
|
| 1369 |
+
],
|
| 1370 |
+
docs_url="/docs",
|
| 1371 |
+
redoc_url="/redoc",
|
| 1372 |
+
openapi_url="/openapi.json",
|
| 1373 |
+
contact={
|
| 1374 |
+
"name": "OpenEnv Team",
|
| 1375 |
+
"url": "https://github.com/meta-pytorch/OpenEnv",
|
| 1376 |
+
},
|
| 1377 |
+
license_info={
|
| 1378 |
+
"name": "BSD-3-Clause",
|
| 1379 |
+
"url": "https://github.com/meta-pytorch/OpenEnv/blob/main/LICENSE",
|
| 1380 |
+
},
|
| 1381 |
+
)
|
| 1382 |
+
|
| 1383 |
+
server = HTTPEnvServer(
|
| 1384 |
+
env,
|
| 1385 |
+
action_cls,
|
| 1386 |
+
observation_cls,
|
| 1387 |
+
max_concurrent_envs,
|
| 1388 |
+
concurrency_config=concurrency_config,
|
| 1389 |
+
)
|
| 1390 |
server.register_routes(app)
|
| 1391 |
return app
|
src/core/env_server/interfaces.py
CHANGED
|
@@ -4,10 +4,20 @@
|
|
| 4 |
# This source code is licensed under the BSD-style license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
|
|
|
| 7 |
from abc import ABC, abstractmethod
|
| 8 |
-
from typing import Any, Protocol,
|
| 9 |
|
| 10 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class Message(TypedDict):
|
|
@@ -64,7 +74,7 @@ class ModelTokenizer(Protocol):
|
|
| 64 |
...
|
| 65 |
|
| 66 |
|
| 67 |
-
class Transform(ABC):
|
| 68 |
"""Transform observations to add rewards, metrics, or other modifications.
|
| 69 |
|
| 70 |
Transforms follow the TorchRL pattern where they take an observation
|
|
@@ -73,7 +83,7 @@ class Transform(ABC):
|
|
| 73 |
"""
|
| 74 |
|
| 75 |
@abstractmethod
|
| 76 |
-
def __call__(self, observation:
|
| 77 |
"""Transform an observation.
|
| 78 |
|
| 79 |
Args:
|
|
@@ -85,34 +95,203 @@ class Transform(ABC):
|
|
| 85 |
pass
|
| 86 |
|
| 87 |
|
| 88 |
-
class Environment(ABC):
|
| 89 |
"""Base class for all environment servers following Gym/Gymnasium API.
|
| 90 |
|
| 91 |
Args:
|
| 92 |
transform: Optional transform to apply to observations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
"""
|
| 94 |
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
self.transform = transform
|
|
|
|
| 97 |
|
| 98 |
@abstractmethod
|
| 99 |
-
def reset(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
"""Reset the environment and return initial observation."""
|
| 101 |
pass
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
@abstractmethod
|
| 104 |
-
def step(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
"""Take a step in the environment."""
|
| 106 |
pass
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
@property
|
| 109 |
@abstractmethod
|
| 110 |
-
def state(self) ->
|
| 111 |
"""Get the current environment state."""
|
| 112 |
pass
|
| 113 |
|
| 114 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
"""Apply transform if one is provided."""
|
| 116 |
if self.transform is not None:
|
| 117 |
return self.transform(observation)
|
| 118 |
return observation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
# This source code is licensed under the BSD-style license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
+
import inspect
|
| 8 |
from abc import ABC, abstractmethod
|
| 9 |
+
from typing import Any, Generic, Optional, Protocol, TYPE_CHECKING, TypeVar
|
| 10 |
|
| 11 |
+
from typing_extensions import TypedDict
|
| 12 |
+
|
| 13 |
+
from .types import Action, EnvironmentMetadata, Observation, State
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
from openenv.core.rubrics import Rubric
|
| 17 |
+
|
| 18 |
+
ActT = TypeVar("ActT", bound=Action)
|
| 19 |
+
ObsT = TypeVar("ObsT", bound=Observation)
|
| 20 |
+
StateT = TypeVar("StateT", bound=State)
|
| 21 |
|
| 22 |
|
| 23 |
class Message(TypedDict):
|
|
|
|
| 74 |
...
|
| 75 |
|
| 76 |
|
| 77 |
+
class Transform(ABC, Generic[ObsT]):
|
| 78 |
"""Transform observations to add rewards, metrics, or other modifications.
|
| 79 |
|
| 80 |
Transforms follow the TorchRL pattern where they take an observation
|
|
|
|
| 83 |
"""
|
| 84 |
|
| 85 |
@abstractmethod
|
| 86 |
+
def __call__(self, observation: ObsT) -> ObsT:
|
| 87 |
"""Transform an observation.
|
| 88 |
|
| 89 |
Args:
|
|
|
|
| 95 |
pass
|
| 96 |
|
| 97 |
|
| 98 |
+
class Environment(ABC, Generic[ActT, ObsT, StateT]):
|
| 99 |
"""Base class for all environment servers following Gym/Gymnasium API.
|
| 100 |
|
| 101 |
Args:
|
| 102 |
transform: Optional transform to apply to observations
|
| 103 |
+
rubric: Optional rubric for reward computation. When provided, the
|
| 104 |
+
rubric's output can be used to set the observation's reward in step().
|
| 105 |
+
|
| 106 |
+
Class Attributes:
|
| 107 |
+
SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions.
|
| 108 |
+
When True, multiple WebSocket connections can each have their own
|
| 109 |
+
environment instance (up to max_concurrent_envs). When False (default),
|
| 110 |
+
the environment should only be used with a single session at a time.
|
| 111 |
+
|
| 112 |
+
Set this to True in your Environment subclass if:
|
| 113 |
+
- The environment uses proper session isolation (e.g., unique working dirs)
|
| 114 |
+
- No shared mutable state exists between instances
|
| 115 |
+
- External resources (databases, APIs) can handle concurrent access
|
| 116 |
+
|
| 117 |
+
Attributes:
|
| 118 |
+
rubric: Optional rubric for computing rewards. Environments can set this
|
| 119 |
+
in __init__ and use it in step() to compute observation rewards.
|
| 120 |
+
Training infrastructure can access it for introspection:
|
| 121 |
+
for name, r in env.rubric.named_rubrics():
|
| 122 |
+
print(f"{name}: {r.last_score}")
|
| 123 |
+
|
| 124 |
+
See RFC 004 for rubric design: rfcs/004-rubrics.md
|
| 125 |
"""
|
| 126 |
|
| 127 |
+
# Class-level flag indicating whether this environment supports concurrent sessions
|
| 128 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = False
|
| 129 |
+
|
| 130 |
+
# Optional rubric for reward computation
|
| 131 |
+
rubric: Optional["Rubric"]
|
| 132 |
+
|
| 133 |
+
def __init__(
|
| 134 |
+
self,
|
| 135 |
+
transform: Optional[Transform[ObsT]] = None,
|
| 136 |
+
rubric: Optional["Rubric"] = None,
|
| 137 |
+
):
|
| 138 |
self.transform = transform
|
| 139 |
+
self.rubric = rubric
|
| 140 |
|
| 141 |
@abstractmethod
|
| 142 |
+
def reset(
|
| 143 |
+
self,
|
| 144 |
+
seed: Optional[int] = None,
|
| 145 |
+
episode_id: Optional[str] = None,
|
| 146 |
+
**kwargs: Any,
|
| 147 |
+
) -> ObsT:
|
| 148 |
"""Reset the environment and return initial observation."""
|
| 149 |
pass
|
| 150 |
|
| 151 |
+
async def reset_async(
|
| 152 |
+
self,
|
| 153 |
+
seed: Optional[int] = None,
|
| 154 |
+
episode_id: Optional[str] = None,
|
| 155 |
+
**kwargs: Any,
|
| 156 |
+
) -> ObsT:
|
| 157 |
+
"""Async version of reset. Default implementation calls sync reset.
|
| 158 |
+
|
| 159 |
+
Override to provide true async implementation.
|
| 160 |
+
"""
|
| 161 |
+
return self.reset(seed=seed, episode_id=episode_id, **kwargs)
|
| 162 |
+
|
| 163 |
@abstractmethod
|
| 164 |
+
def step(
|
| 165 |
+
self,
|
| 166 |
+
action: ActT,
|
| 167 |
+
timeout_s: Optional[float] = None,
|
| 168 |
+
**kwargs: Any,
|
| 169 |
+
) -> ObsT:
|
| 170 |
"""Take a step in the environment."""
|
| 171 |
pass
|
| 172 |
|
| 173 |
+
async def step_async(
|
| 174 |
+
self,
|
| 175 |
+
action: ActT,
|
| 176 |
+
timeout_s: Optional[float] = None,
|
| 177 |
+
**kwargs: Any,
|
| 178 |
+
) -> ObsT:
|
| 179 |
+
"""Async version of step. Default implementation calls sync step.
|
| 180 |
+
|
| 181 |
+
Override to provide true async implementation.
|
| 182 |
+
"""
|
| 183 |
+
return self.step(action, timeout_s=timeout_s, **kwargs)
|
| 184 |
+
|
| 185 |
@property
|
| 186 |
@abstractmethod
|
| 187 |
+
def state(self) -> StateT:
|
| 188 |
"""Get the current environment state."""
|
| 189 |
pass
|
| 190 |
|
| 191 |
+
def get_metadata(self) -> EnvironmentMetadata:
|
| 192 |
+
"""
|
| 193 |
+
Get metadata about this environment.
|
| 194 |
+
|
| 195 |
+
Override this method to provide custom metadata for the environment.
|
| 196 |
+
Default implementation returns basic metadata derived from class name.
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
EnvironmentMetadata with environment information
|
| 200 |
+
"""
|
| 201 |
+
return EnvironmentMetadata(
|
| 202 |
+
name=self.__class__.__name__,
|
| 203 |
+
description=f"{self.__class__.__name__} environment",
|
| 204 |
+
version="1.0.0",
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
def _apply_transform(self, observation: ObsT) -> ObsT:
|
| 208 |
"""Apply transform if one is provided."""
|
| 209 |
if self.transform is not None:
|
| 210 |
return self.transform(observation)
|
| 211 |
return observation
|
| 212 |
+
|
| 213 |
+
def _apply_rubric(self, action: ActT, observation: ObsT) -> float:
|
| 214 |
+
"""Apply rubric if one is provided.
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
action: The action taken by the agent.
|
| 218 |
+
observation: The resulting observation.
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
Reward value from the rubric, or 0.0 if no rubric is set.
|
| 222 |
+
|
| 223 |
+
Usage in step():
|
| 224 |
+
def step(self, action: MyAction, ...) -> MyObservation:
|
| 225 |
+
# ... execute action and create observation ...
|
| 226 |
+
observation.reward = self._apply_rubric(action, observation)
|
| 227 |
+
return observation
|
| 228 |
+
"""
|
| 229 |
+
if self.rubric is not None:
|
| 230 |
+
return self.rubric(action, observation)
|
| 231 |
+
return 0.0
|
| 232 |
+
|
| 233 |
+
async def _apply_rubric_async(self, action: ActT, observation: ObsT) -> float:
|
| 234 |
+
"""Apply rubric asynchronously if one is provided.
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
action: The action taken by the agent.
|
| 238 |
+
observation: The resulting observation.
|
| 239 |
+
|
| 240 |
+
Returns:
|
| 241 |
+
Reward value from the rubric, or 0.0 if no rubric is set.
|
| 242 |
+
|
| 243 |
+
Usage in step_async():
|
| 244 |
+
async def step_async(self, action: MyAction, ...) -> MyObservation:
|
| 245 |
+
# ... execute action and create observation ...
|
| 246 |
+
observation.reward = await self._apply_rubric_async(action, observation)
|
| 247 |
+
return observation
|
| 248 |
+
"""
|
| 249 |
+
if self.rubric is not None:
|
| 250 |
+
result = self.rubric(action, observation)
|
| 251 |
+
# If rubric returns a coroutine, await it
|
| 252 |
+
if inspect.iscoroutine(result):
|
| 253 |
+
return await result
|
| 254 |
+
return result
|
| 255 |
+
return 0.0
|
| 256 |
+
|
| 257 |
+
def _reset_rubric(self) -> None:
|
| 258 |
+
"""Reset the rubric state if one is provided.
|
| 259 |
+
|
| 260 |
+
Call this in reset() to clear any trajectory state in the rubric.
|
| 261 |
+
|
| 262 |
+
Usage in reset():
|
| 263 |
+
def reset(self, ...) -> MyObservation:
|
| 264 |
+
self._reset_rubric()
|
| 265 |
+
# ... create initial observation ...
|
| 266 |
+
return observation
|
| 267 |
+
"""
|
| 268 |
+
if self.rubric is not None:
|
| 269 |
+
self.rubric.reset()
|
| 270 |
+
|
| 271 |
+
async def _reset_rubric_async(self) -> None:
|
| 272 |
+
"""Reset the rubric state asynchronously if one is provided.
|
| 273 |
+
|
| 274 |
+
Call this in reset_async() to clear any trajectory state in the rubric.
|
| 275 |
+
|
| 276 |
+
Usage in reset_async():
|
| 277 |
+
async def reset_async(self, ...) -> MyObservation:
|
| 278 |
+
await self._reset_rubric_async()
|
| 279 |
+
# ... create initial observation ...
|
| 280 |
+
return observation
|
| 281 |
+
"""
|
| 282 |
+
if self.rubric is not None:
|
| 283 |
+
# Check if rubric has async reset method
|
| 284 |
+
if hasattr(self.rubric, "reset_async"):
|
| 285 |
+
result = self.rubric.reset_async()
|
| 286 |
+
if inspect.iscoroutine(result):
|
| 287 |
+
await result
|
| 288 |
+
else:
|
| 289 |
+
self.rubric.reset()
|
| 290 |
+
|
| 291 |
+
def close(self) -> None:
|
| 292 |
+
"""Clean up resources used by the environment.
|
| 293 |
+
|
| 294 |
+
Override this method to implement custom cleanup logic.
|
| 295 |
+
Called when the environment is being destroyed or reset.
|
| 296 |
+
"""
|
| 297 |
+
pass
|
src/core/env_server/mcp_environment.py
ADDED
|
@@ -0,0 +1,624 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
MCP Environment base class for OpenEnv.
|
| 9 |
+
|
| 10 |
+
This module provides the MCPEnvironment base class that integrates FastMCP servers
|
| 11 |
+
with OpenEnv's Gym-style Environment interface. It handles MCP tool discovery
|
| 12 |
+
and invocation through the step() API, following RFC 003.
|
| 13 |
+
|
| 14 |
+
Key features:
|
| 15 |
+
- Automatic routing of ListToolsAction and CallToolAction to MCP server
|
| 16 |
+
- Reserved tool name validation (reset, step, state, close are protected)
|
| 17 |
+
- Timeout handling for tool calls
|
| 18 |
+
- Proper error categorization (tool not found, execution errors, timeouts)
|
| 19 |
+
- Mode-aware tool registration (production vs simulation)
|
| 20 |
+
- Code mode support via get_callables() and execute_code()
|
| 21 |
+
|
| 22 |
+
Usage:
|
| 23 |
+
from fastmcp import FastMCP
|
| 24 |
+
from openenv.core.env_server.mcp_environment import MCPEnvironment
|
| 25 |
+
|
| 26 |
+
class MyMCPEnv(MCPEnvironment):
|
| 27 |
+
def __init__(self):
|
| 28 |
+
mcp = FastMCP("my-server")
|
| 29 |
+
|
| 30 |
+
# Register mode-specific tools
|
| 31 |
+
@self.tool(mode="production")
|
| 32 |
+
def my_tool(arg: str) -> str:
|
| 33 |
+
return f"Production: {arg}"
|
| 34 |
+
|
| 35 |
+
@self.tool(mode="simulation")
|
| 36 |
+
def my_tool(arg: str) -> str:
|
| 37 |
+
return f"Simulation: {arg}"
|
| 38 |
+
|
| 39 |
+
super().__init__(mcp)
|
| 40 |
+
|
| 41 |
+
def reset(self, seed=None, episode_id=None, **kwargs):
|
| 42 |
+
# Reset logic here
|
| 43 |
+
...
|
| 44 |
+
|
| 45 |
+
def _step_impl(self, action):
|
| 46 |
+
# Handle non-MCP actions
|
| 47 |
+
...
|
| 48 |
+
|
| 49 |
+
@property
|
| 50 |
+
def state(self):
|
| 51 |
+
# Return current state
|
| 52 |
+
...
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
import asyncio
|
| 56 |
+
import inspect
|
| 57 |
+
from abc import abstractmethod
|
| 58 |
+
from collections import defaultdict
|
| 59 |
+
from typing import Any, Callable, Dict, Optional
|
| 60 |
+
|
| 61 |
+
from fastmcp import Client
|
| 62 |
+
from fastmcp.client.client import CallToolResult
|
| 63 |
+
from mcp.types import TextContent
|
| 64 |
+
|
| 65 |
+
from ..utils import run_async_safely
|
| 66 |
+
from .interfaces import Environment
|
| 67 |
+
from .mcp_types import (
|
| 68 |
+
CallToolAction,
|
| 69 |
+
CallToolObservation,
|
| 70 |
+
ListToolsAction,
|
| 71 |
+
ListToolsObservation,
|
| 72 |
+
RESERVED_TOOL_NAMES,
|
| 73 |
+
Tool,
|
| 74 |
+
ToolError,
|
| 75 |
+
ToolErrorType,
|
| 76 |
+
)
|
| 77 |
+
from .types import Action, Observation
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# Default timeout for MCP tool calls in seconds
|
| 81 |
+
MCP_TOOL_CALL_TIMEOUT = 30.0
|
| 82 |
+
|
| 83 |
+
# Valid modes for tool registration
|
| 84 |
+
VALID_MODES = {"production", "simulation"}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_server_tools(mcp_server: Any) -> Dict[str, Any]:
|
| 88 |
+
"""
|
| 89 |
+
Get tools from a FastMCP server, compatible with both 2.x and 3.x.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
Dictionary mapping tool names to tool objects.
|
| 93 |
+
"""
|
| 94 |
+
# FastMCP 2.x: get_tools() returns dict {name: Tool}
|
| 95 |
+
if hasattr(mcp_server, "get_tools"):
|
| 96 |
+
result = run_async_safely(mcp_server.get_tools())
|
| 97 |
+
if isinstance(result, dict):
|
| 98 |
+
return result
|
| 99 |
+
# FastMCP 3.x: list_tools() returns list of Tool objects
|
| 100 |
+
if hasattr(mcp_server, "list_tools"):
|
| 101 |
+
tools_list = run_async_safely(mcp_server.list_tools())
|
| 102 |
+
return {t.name: t for t in tools_list}
|
| 103 |
+
return {}
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class MCPEnvironment(Environment):
|
| 107 |
+
"""
|
| 108 |
+
Base class for environments that expose tools via MCP (Model Context Protocol).
|
| 109 |
+
|
| 110 |
+
MCPEnvironment bridges FastMCP servers with OpenEnv's Gym-style API, allowing
|
| 111 |
+
agents to discover and invoke MCP tools through the standard step() interface.
|
| 112 |
+
|
| 113 |
+
The class automatically handles:
|
| 114 |
+
- ListToolsAction: Returns available tools from the MCP server
|
| 115 |
+
- CallToolAction: Invokes a specific tool with arguments
|
| 116 |
+
|
| 117 |
+
All other actions are delegated to the abstract _step_impl() method,
|
| 118 |
+
which subclasses must implement.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
mcp_server: A FastMCP server instance containing tool definitions.
|
| 122 |
+
The server's tools will be validated against reserved names.
|
| 123 |
+
transform: Optional transform to apply to observations (inherited from Environment).
|
| 124 |
+
|
| 125 |
+
Raises:
|
| 126 |
+
ValueError: If any tool in the MCP server uses a reserved name
|
| 127 |
+
(reset, step, state, close).
|
| 128 |
+
|
| 129 |
+
Example:
|
| 130 |
+
>>> from fastmcp import FastMCP
|
| 131 |
+
>>> mcp = FastMCP("calculator")
|
| 132 |
+
>>> @mcp.tool()
|
| 133 |
+
... def add(a: int, b: int) -> int:
|
| 134 |
+
... return a + b
|
| 135 |
+
>>> env = MyMCPEnvironment(mcp)
|
| 136 |
+
>>> obs = env.step(ListToolsAction())
|
| 137 |
+
>>> obs.tools[0].name
|
| 138 |
+
'add'
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
def __init__(self, mcp_server: Any, transform: Optional[Any] = None) -> None:
|
| 142 |
+
"""
|
| 143 |
+
Initialize the MCP environment.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
mcp_server: A FastMCP server instance with tool definitions.
|
| 147 |
+
transform: Optional transform to apply to observations.
|
| 148 |
+
|
| 149 |
+
Raises:
|
| 150 |
+
ValueError: If any tool uses a reserved name (reset, step, state, close).
|
| 151 |
+
"""
|
| 152 |
+
super().__init__(transform=transform)
|
| 153 |
+
|
| 154 |
+
# Validate tool names before storing
|
| 155 |
+
self._validate_tool_names(mcp_server)
|
| 156 |
+
|
| 157 |
+
self.mcp_server = mcp_server
|
| 158 |
+
self.mcp_client = Client(mcp_server)
|
| 159 |
+
|
| 160 |
+
# Track mode-specific tools: {tool_name: {mode: func}}
|
| 161 |
+
# mode can be "production", "simulation", or None (available in all modes)
|
| 162 |
+
self._mode_tools = defaultdict(dict)
|
| 163 |
+
|
| 164 |
+
# Track tool schemas for list_tools: {tool_name: {mode: schema}}
|
| 165 |
+
self._mode_tool_schemas = defaultdict(dict)
|
| 166 |
+
|
| 167 |
+
@property
|
| 168 |
+
def supports_code_mode(self) -> bool:
|
| 169 |
+
"""Check if this environment supports code mode (execute_code)."""
|
| 170 |
+
return True
|
| 171 |
+
|
| 172 |
+
def _get_server_tools(self, mcp_server: Any) -> Dict[str, Any]:
|
| 173 |
+
"""
|
| 174 |
+
Get tools from a FastMCP server, compatible with both 2.x and 3.x.
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
Dictionary mapping tool names to tool objects.
|
| 178 |
+
"""
|
| 179 |
+
return get_server_tools(mcp_server)
|
| 180 |
+
|
| 181 |
+
def get_callables(self) -> Dict[str, Callable]:
|
| 182 |
+
"""
|
| 183 |
+
Get callable functions for code mode.
|
| 184 |
+
|
| 185 |
+
Returns tool functions as direct Python callables, enabling code mode
|
| 186 |
+
where agents write Python code that calls tools directly (no JSON-RPC
|
| 187 |
+
overhead). Mode-specific tools are filtered by the current mode.
|
| 188 |
+
|
| 189 |
+
Returns:
|
| 190 |
+
Dictionary mapping tool names to callables.
|
| 191 |
+
"""
|
| 192 |
+
callables: Dict[str, Callable] = {}
|
| 193 |
+
current_mode = getattr(self, "_mode", None)
|
| 194 |
+
|
| 195 |
+
# Extract callables from FastMCP server using public API
|
| 196 |
+
for tool_name, tool in self._get_server_tools(self.mcp_server).items():
|
| 197 |
+
if hasattr(tool, "fn") and callable(tool.fn):
|
| 198 |
+
callables[tool_name] = tool.fn
|
| 199 |
+
|
| 200 |
+
# Add mode-specific tools available in current mode
|
| 201 |
+
for tool_name, mode_funcs in self._mode_tools.items():
|
| 202 |
+
if None in mode_funcs:
|
| 203 |
+
# Tool available in all modes (already in FastMCP if registered there)
|
| 204 |
+
if tool_name not in callables:
|
| 205 |
+
callables[tool_name] = mode_funcs[None]
|
| 206 |
+
elif current_mode in mode_funcs:
|
| 207 |
+
# Tool available in current mode only
|
| 208 |
+
callables[tool_name] = mode_funcs[current_mode]
|
| 209 |
+
|
| 210 |
+
return callables
|
| 211 |
+
|
| 212 |
+
def execute_code(self, code: str) -> Observation:
|
| 213 |
+
"""
|
| 214 |
+
Execute Python code with tools available as callables.
|
| 215 |
+
|
| 216 |
+
This enables the CodeAct pattern where agents write Python code
|
| 217 |
+
that calls tools directly as functions, avoiding JSON-RPC overhead.
|
| 218 |
+
|
| 219 |
+
Args:
|
| 220 |
+
code: Python code to execute. Tools are available as functions
|
| 221 |
+
in the execution namespace. Set a variable named 'result'
|
| 222 |
+
to capture the return value.
|
| 223 |
+
|
| 224 |
+
Returns:
|
| 225 |
+
Observation with result in metadata["result"] or error in
|
| 226 |
+
metadata["error"].
|
| 227 |
+
"""
|
| 228 |
+
namespace = self.get_callables()
|
| 229 |
+
|
| 230 |
+
result_dict: Dict[str, Any] = {}
|
| 231 |
+
try:
|
| 232 |
+
exec(code, namespace, result_dict)
|
| 233 |
+
result = result_dict.get("result")
|
| 234 |
+
return Observation(done=False, reward=0.0, metadata={"result": result})
|
| 235 |
+
except SyntaxError as e:
|
| 236 |
+
return Observation(
|
| 237 |
+
done=False, reward=0.0, metadata={"error": f"Syntax error: {str(e)}"}
|
| 238 |
+
)
|
| 239 |
+
except Exception as e:
|
| 240 |
+
return Observation(done=False, reward=0.0, metadata={"error": str(e)})
|
| 241 |
+
|
| 242 |
+
def _validate_tool_names(self, mcp_server: Any) -> None:
|
| 243 |
+
"""
|
| 244 |
+
Validate that no tools use reserved names.
|
| 245 |
+
|
| 246 |
+
Reserved names (reset, step, state, close) are protected to maintain
|
| 247 |
+
the dual API boundary between infrastructure and agent APIs.
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
mcp_server: The FastMCP server to validate.
|
| 251 |
+
|
| 252 |
+
Raises:
|
| 253 |
+
ValueError: If any tool uses a reserved name.
|
| 254 |
+
"""
|
| 255 |
+
tools_dict = self._get_server_tools(mcp_server)
|
| 256 |
+
if tools_dict:
|
| 257 |
+
tool_names = set(tools_dict.keys())
|
| 258 |
+
conflicts = tool_names & RESERVED_TOOL_NAMES
|
| 259 |
+
if conflicts:
|
| 260 |
+
raise ValueError(
|
| 261 |
+
f"MCP tools cannot use reserved names: {sorted(conflicts)}. "
|
| 262 |
+
f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}"
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
def tool(self, mode: Optional[str] = None) -> Callable:
|
| 266 |
+
"""
|
| 267 |
+
Decorator for registering mode-aware tools.
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
mode: Optional mode for the tool ("production" or "simulation").
|
| 271 |
+
If None, tool is available in all modes.
|
| 272 |
+
|
| 273 |
+
Returns:
|
| 274 |
+
A decorator function for registering tools.
|
| 275 |
+
|
| 276 |
+
Raises:
|
| 277 |
+
ValueError: If mode is not None, "production", or "simulation".
|
| 278 |
+
"""
|
| 279 |
+
if mode is not None and mode not in VALID_MODES:
|
| 280 |
+
raise ValueError(
|
| 281 |
+
f"Invalid mode '{mode}'. Mode must be 'production', 'simulation', or None."
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
def decorator(func: Callable) -> Callable:
|
| 285 |
+
tool_name = func.__name__
|
| 286 |
+
# Validate tool name is not reserved
|
| 287 |
+
if tool_name in RESERVED_TOOL_NAMES:
|
| 288 |
+
raise ValueError(
|
| 289 |
+
f"Tool name '{tool_name}' is reserved and cannot be used. "
|
| 290 |
+
f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}"
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# If mode is None, register with FastMCP as usual
|
| 294 |
+
if mode is None:
|
| 295 |
+
decorated_func = self.mcp_server.tool()(func)
|
| 296 |
+
self._mode_tools[tool_name][None] = func
|
| 297 |
+
return decorated_func
|
| 298 |
+
|
| 299 |
+
# For mode-specific tools, don't register with FastMCP
|
| 300 |
+
# Instead, track them ourselves
|
| 301 |
+
self._mode_tools[tool_name][mode] = func
|
| 302 |
+
|
| 303 |
+
# Extract schema information from function signature
|
| 304 |
+
sig = inspect.signature(func)
|
| 305 |
+
schema = {
|
| 306 |
+
"type": "object",
|
| 307 |
+
"properties": {},
|
| 308 |
+
"required": [],
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
for param_name, param in sig.parameters.items():
|
| 312 |
+
# Get type annotation
|
| 313 |
+
param_type = param.annotation
|
| 314 |
+
json_type = "string" # default
|
| 315 |
+
if param_type in (int, "int"):
|
| 316 |
+
json_type = "integer"
|
| 317 |
+
elif param_type in (float, "float"):
|
| 318 |
+
json_type = "number"
|
| 319 |
+
elif param_type in (bool, "bool"):
|
| 320 |
+
json_type = "boolean"
|
| 321 |
+
|
| 322 |
+
schema["properties"][param_name] = {"type": json_type}
|
| 323 |
+
|
| 324 |
+
# If no default value, it's required
|
| 325 |
+
if param.default == inspect.Parameter.empty:
|
| 326 |
+
schema["required"].append(param_name)
|
| 327 |
+
|
| 328 |
+
# Store the schema for this mode-specific tool
|
| 329 |
+
self._mode_tool_schemas[tool_name][mode] = {
|
| 330 |
+
"name": tool_name,
|
| 331 |
+
"description": func.__doc__ or "",
|
| 332 |
+
"input_schema": schema,
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
return func
|
| 336 |
+
|
| 337 |
+
return decorator
|
| 338 |
+
|
| 339 |
+
def step(
|
| 340 |
+
self,
|
| 341 |
+
action: Action,
|
| 342 |
+
timeout_s: Optional[float] = None,
|
| 343 |
+
**kwargs: Any,
|
| 344 |
+
) -> Observation:
|
| 345 |
+
"""
|
| 346 |
+
Execute an action in the environment.
|
| 347 |
+
|
| 348 |
+
This method routes MCP-specific actions (ListToolsAction, CallToolAction)
|
| 349 |
+
to the appropriate handlers, while delegating all other actions to
|
| 350 |
+
the subclass's _step_impl() method.
|
| 351 |
+
|
| 352 |
+
Args:
|
| 353 |
+
action: The action to execute. Can be:
|
| 354 |
+
- ListToolsAction: Returns available MCP tools
|
| 355 |
+
- CallToolAction: Invokes a specific MCP tool
|
| 356 |
+
- Any other Action: Delegated to _step_impl()
|
| 357 |
+
timeout_s: Optional timeout in seconds for the action.
|
| 358 |
+
Defaults to MCP_TOOL_CALL_TIMEOUT (30s) for MCP actions.
|
| 359 |
+
**kwargs: Additional arguments passed to handlers.
|
| 360 |
+
|
| 361 |
+
Returns:
|
| 362 |
+
Observation appropriate to the action type:
|
| 363 |
+
- ListToolsObservation for ListToolsAction
|
| 364 |
+
- CallToolObservation for CallToolAction
|
| 365 |
+
- Subclass-defined Observation for other actions
|
| 366 |
+
"""
|
| 367 |
+
if isinstance(action, ListToolsAction):
|
| 368 |
+
return self._handle_list_tools()
|
| 369 |
+
elif isinstance(action, CallToolAction):
|
| 370 |
+
return self._handle_call_tool(action, timeout_s=timeout_s)
|
| 371 |
+
else:
|
| 372 |
+
return self._step_impl(action, timeout_s=timeout_s, **kwargs)
|
| 373 |
+
|
| 374 |
+
def _handle_list_tools(self) -> ListToolsObservation:
|
| 375 |
+
"""
|
| 376 |
+
Handle a ListToolsAction by querying the MCP server.
|
| 377 |
+
|
| 378 |
+
Returns:
|
| 379 |
+
ListToolsObservation containing all available tools with their
|
| 380 |
+
names, descriptions, and input schemas, filtered by current mode.
|
| 381 |
+
"""
|
| 382 |
+
try:
|
| 383 |
+
# Get current mode
|
| 384 |
+
current_mode = getattr(self, "_mode", None)
|
| 385 |
+
|
| 386 |
+
# Start with tools from FastMCP server (mode=None tools)
|
| 387 |
+
tools_result = run_async_safely(self._async_list_tools())
|
| 388 |
+
|
| 389 |
+
# Build list of Tool objects
|
| 390 |
+
tools = []
|
| 391 |
+
|
| 392 |
+
# Add FastMCP tools that are not mode-specific
|
| 393 |
+
for tool in tools_result:
|
| 394 |
+
if tool.name not in self._mode_tool_schemas:
|
| 395 |
+
tools.append(
|
| 396 |
+
Tool(
|
| 397 |
+
name=tool.name,
|
| 398 |
+
description=tool.description or "",
|
| 399 |
+
input_schema=tool.inputSchema
|
| 400 |
+
if hasattr(tool, "inputSchema")
|
| 401 |
+
else {},
|
| 402 |
+
)
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
# Add mode-specific tools available in current mode
|
| 406 |
+
for tool_name, mode_schemas in self._mode_tool_schemas.items():
|
| 407 |
+
if None in mode_schemas:
|
| 408 |
+
# Tool available in all modes
|
| 409 |
+
schema = mode_schemas[None]
|
| 410 |
+
tools.append(
|
| 411 |
+
Tool(
|
| 412 |
+
name=schema["name"],
|
| 413 |
+
description=schema["description"],
|
| 414 |
+
input_schema=schema["input_schema"],
|
| 415 |
+
)
|
| 416 |
+
)
|
| 417 |
+
elif current_mode in mode_schemas:
|
| 418 |
+
# Tool available in current mode
|
| 419 |
+
schema = mode_schemas[current_mode]
|
| 420 |
+
tools.append(
|
| 421 |
+
Tool(
|
| 422 |
+
name=schema["name"],
|
| 423 |
+
description=schema["description"],
|
| 424 |
+
input_schema=schema["input_schema"],
|
| 425 |
+
)
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
return ListToolsObservation(tools=tools)
|
| 429 |
+
|
| 430 |
+
except Exception as e:
|
| 431 |
+
# Return an observation with error in metadata
|
| 432 |
+
return ListToolsObservation(
|
| 433 |
+
tools=[],
|
| 434 |
+
metadata={
|
| 435 |
+
"error": str(e),
|
| 436 |
+
"error_type": "list_tools_failed",
|
| 437 |
+
},
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
async def _async_list_tools(self) -> list:
|
| 441 |
+
"""
|
| 442 |
+
Async helper to list tools from the MCP client.
|
| 443 |
+
|
| 444 |
+
Returns:
|
| 445 |
+
List of tool objects from the MCP server.
|
| 446 |
+
"""
|
| 447 |
+
async with self.mcp_client:
|
| 448 |
+
return await self.mcp_client.list_tools()
|
| 449 |
+
|
| 450 |
+
def _handle_call_tool(
|
| 451 |
+
self,
|
| 452 |
+
action: CallToolAction,
|
| 453 |
+
timeout_s: Optional[float] = None,
|
| 454 |
+
) -> CallToolObservation:
|
| 455 |
+
"""
|
| 456 |
+
Handle a CallToolAction by invoking the specified tool.
|
| 457 |
+
|
| 458 |
+
Args:
|
| 459 |
+
action: The CallToolAction containing tool_name and arguments.
|
| 460 |
+
timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s).
|
| 461 |
+
|
| 462 |
+
Returns:
|
| 463 |
+
CallToolObservation with the tool's result or an error.
|
| 464 |
+
"""
|
| 465 |
+
timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT
|
| 466 |
+
|
| 467 |
+
# Check if this is a mode-specific tool
|
| 468 |
+
tool_name = action.tool_name
|
| 469 |
+
current_mode = getattr(self, "_mode", None)
|
| 470 |
+
|
| 471 |
+
if tool_name in self._mode_tools:
|
| 472 |
+
mode_info = self._mode_tools[tool_name]
|
| 473 |
+
|
| 474 |
+
# Check if tool is available in current mode
|
| 475 |
+
# Tool is available if:
|
| 476 |
+
# 1. It has a None mode (available in all modes), OR
|
| 477 |
+
# 2. It has an implementation for the current mode
|
| 478 |
+
if None in mode_info:
|
| 479 |
+
# Use the mode-agnostic version
|
| 480 |
+
func = mode_info[None]
|
| 481 |
+
elif current_mode in mode_info:
|
| 482 |
+
# Use the mode-specific version
|
| 483 |
+
func = mode_info[current_mode]
|
| 484 |
+
else:
|
| 485 |
+
# Tool not available in current mode
|
| 486 |
+
return CallToolObservation(
|
| 487 |
+
tool_name=tool_name,
|
| 488 |
+
result=None,
|
| 489 |
+
error=ToolError(
|
| 490 |
+
error_type=ToolErrorType.TOOL_NOT_FOUND,
|
| 491 |
+
message=f"Tool '{tool_name}' not available in {current_mode} mode",
|
| 492 |
+
),
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
# Call the mode-specific function directly
|
| 496 |
+
try:
|
| 497 |
+
# Check if function is async and await if necessary
|
| 498 |
+
if inspect.iscoroutinefunction(func):
|
| 499 |
+
result = run_async_safely(func(**action.arguments))
|
| 500 |
+
else:
|
| 501 |
+
result = func(**action.arguments)
|
| 502 |
+
|
| 503 |
+
# Wrap result in CallToolResult format to match FastMCP behavior
|
| 504 |
+
return CallToolObservation(
|
| 505 |
+
tool_name=tool_name,
|
| 506 |
+
result=CallToolResult(
|
| 507 |
+
content=[TextContent(type="text", text=str(result))],
|
| 508 |
+
structured_content={"result": result},
|
| 509 |
+
meta=None,
|
| 510 |
+
data=result,
|
| 511 |
+
is_error=False,
|
| 512 |
+
),
|
| 513 |
+
)
|
| 514 |
+
except Exception as e:
|
| 515 |
+
return CallToolObservation(
|
| 516 |
+
tool_name=tool_name,
|
| 517 |
+
result=None,
|
| 518 |
+
error=ToolError(
|
| 519 |
+
error_type=ToolErrorType.EXECUTION_ERROR,
|
| 520 |
+
message=str(e),
|
| 521 |
+
),
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
# Not a mode-specific tool, use FastMCP
|
| 525 |
+
try:
|
| 526 |
+
# Run the async call_tool with timeout
|
| 527 |
+
# Use run_async_safely to handle both sync and async contexts
|
| 528 |
+
result = run_async_safely(
|
| 529 |
+
asyncio.wait_for(
|
| 530 |
+
self._async_call_tool(action.tool_name, action.arguments),
|
| 531 |
+
timeout=timeout,
|
| 532 |
+
)
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
return CallToolObservation(
|
| 536 |
+
tool_name=action.tool_name,
|
| 537 |
+
result=result,
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
except asyncio.TimeoutError:
|
| 541 |
+
return CallToolObservation(
|
| 542 |
+
tool_name=action.tool_name,
|
| 543 |
+
result=None,
|
| 544 |
+
error=ToolError(
|
| 545 |
+
error_type=ToolErrorType.TIMEOUT,
|
| 546 |
+
message=f"Tool '{action.tool_name}' timed out after {timeout} seconds",
|
| 547 |
+
),
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
except Exception as e:
|
| 551 |
+
error_message = str(e)
|
| 552 |
+
|
| 553 |
+
# Determine error type based on the exception
|
| 554 |
+
if (
|
| 555 |
+
"not found" in error_message.lower()
|
| 556 |
+
or "unknown tool" in error_message.lower()
|
| 557 |
+
):
|
| 558 |
+
error_type = ToolErrorType.TOOL_NOT_FOUND
|
| 559 |
+
elif (
|
| 560 |
+
"invalid" in error_message.lower()
|
| 561 |
+
or "argument" in error_message.lower()
|
| 562 |
+
):
|
| 563 |
+
error_type = ToolErrorType.INVALID_ARGS
|
| 564 |
+
else:
|
| 565 |
+
error_type = ToolErrorType.EXECUTION_ERROR
|
| 566 |
+
|
| 567 |
+
return CallToolObservation(
|
| 568 |
+
tool_name=action.tool_name,
|
| 569 |
+
result=None,
|
| 570 |
+
error=ToolError(
|
| 571 |
+
error_type=error_type,
|
| 572 |
+
message=error_message,
|
| 573 |
+
),
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any:
|
| 577 |
+
"""
|
| 578 |
+
Async helper to call a tool on the MCP server.
|
| 579 |
+
|
| 580 |
+
Args:
|
| 581 |
+
tool_name: Name of the tool to invoke.
|
| 582 |
+
arguments: Dictionary of arguments to pass to the tool.
|
| 583 |
+
|
| 584 |
+
Returns:
|
| 585 |
+
The result from the tool execution.
|
| 586 |
+
"""
|
| 587 |
+
async with self.mcp_client:
|
| 588 |
+
return await self.mcp_client.call_tool(tool_name, arguments)
|
| 589 |
+
|
| 590 |
+
@abstractmethod
|
| 591 |
+
def _step_impl(
|
| 592 |
+
self,
|
| 593 |
+
action: Action,
|
| 594 |
+
timeout_s: Optional[float] = None,
|
| 595 |
+
**kwargs: Any,
|
| 596 |
+
) -> Observation:
|
| 597 |
+
"""
|
| 598 |
+
Handle non-MCP actions in the environment.
|
| 599 |
+
|
| 600 |
+
Subclasses must implement this method to handle any actions that are
|
| 601 |
+
not ListToolsAction or CallToolAction. This is where environment-specific
|
| 602 |
+
action processing should occur.
|
| 603 |
+
|
| 604 |
+
Args:
|
| 605 |
+
action: The action to execute (guaranteed not to be an MCP action).
|
| 606 |
+
timeout_s: Optional timeout in seconds.
|
| 607 |
+
**kwargs: Additional arguments.
|
| 608 |
+
|
| 609 |
+
Returns:
|
| 610 |
+
An Observation appropriate for the action.
|
| 611 |
+
"""
|
| 612 |
+
pass
|
| 613 |
+
|
| 614 |
+
def close(self) -> None:
|
| 615 |
+
"""
|
| 616 |
+
Clean up resources used by the environment.
|
| 617 |
+
|
| 618 |
+
This method cleans up the MCP client and any other resources.
|
| 619 |
+
Subclasses should call super().close() if they override this method.
|
| 620 |
+
"""
|
| 621 |
+
# The MCP client uses async context manager, so cleanup happens
|
| 622 |
+
# automatically when the context exits. We just clear references.
|
| 623 |
+
self.mcp_client = None
|
| 624 |
+
self.mcp_server = None
|
src/core/env_server/mcp_types.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
MCP (Model Context Protocol) type definitions for OpenEnv.
|
| 9 |
+
|
| 10 |
+
This module defines strongly typed models for MCP tool discovery and invocation,
|
| 11 |
+
following RFC 003. These types map MCP's REST-like API (tools/list, tools/call)
|
| 12 |
+
to Gym-style action types.
|
| 13 |
+
|
| 14 |
+
Key design decisions:
|
| 15 |
+
- Tool discovery (list_tools) does NOT require reset() first
|
| 16 |
+
- Reserved tool names (reset, step, state, close) are prohibited
|
| 17 |
+
- Both step() and WebSocket /mcp paths are supported
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from enum import Enum
|
| 21 |
+
from typing import Any, Dict, List, Literal, Optional, Union
|
| 22 |
+
|
| 23 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 24 |
+
|
| 25 |
+
from .types import Action, BaseMessage, Observation
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# =============================================================================
|
| 29 |
+
# JSON-RPC 2.0 Types
|
| 30 |
+
# =============================================================================
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class JsonRpcErrorCode(int, Enum):
|
| 34 |
+
"""
|
| 35 |
+
Standard JSON-RPC 2.0 error codes.
|
| 36 |
+
|
| 37 |
+
See: https://www.jsonrpc.org/specification#error_object
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
# Standard JSON-RPC errors
|
| 41 |
+
PARSE_ERROR = -32700 # Invalid JSON was received
|
| 42 |
+
INVALID_REQUEST = -32600 # JSON is not a valid Request object
|
| 43 |
+
METHOD_NOT_FOUND = -32601 # Method does not exist / is not available
|
| 44 |
+
INVALID_PARAMS = -32602 # Invalid method parameter(s)
|
| 45 |
+
INTERNAL_ERROR = -32603 # Internal JSON-RPC error
|
| 46 |
+
|
| 47 |
+
# Server errors (reserved for implementation-defined errors)
|
| 48 |
+
SERVER_ERROR = -32000 # Generic server error
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class McpMethod(str, Enum):
|
| 52 |
+
"""Supported MCP method names."""
|
| 53 |
+
|
| 54 |
+
TOOLS_LIST = "tools/list"
|
| 55 |
+
TOOLS_CALL = "tools/call"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class JsonRpcError(BaseModel):
|
| 59 |
+
"""
|
| 60 |
+
JSON-RPC 2.0 error object.
|
| 61 |
+
|
| 62 |
+
See: https://www.jsonrpc.org/specification#error_object
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
model_config = ConfigDict(extra="forbid")
|
| 66 |
+
|
| 67 |
+
code: int = Field(description="Error code indicating the error type")
|
| 68 |
+
message: str = Field(description="Short description of the error")
|
| 69 |
+
data: Optional[Any] = Field(
|
| 70 |
+
default=None, description="Additional error information"
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
@classmethod
|
| 74 |
+
def from_code(
|
| 75 |
+
cls, code: JsonRpcErrorCode, message: Optional[str] = None, data: Any = None
|
| 76 |
+
) -> "JsonRpcError":
|
| 77 |
+
"""Create an error from a standard error code."""
|
| 78 |
+
default_messages = {
|
| 79 |
+
JsonRpcErrorCode.PARSE_ERROR: "Parse error",
|
| 80 |
+
JsonRpcErrorCode.INVALID_REQUEST: "Invalid Request",
|
| 81 |
+
JsonRpcErrorCode.METHOD_NOT_FOUND: "Method not found",
|
| 82 |
+
JsonRpcErrorCode.INVALID_PARAMS: "Invalid params",
|
| 83 |
+
JsonRpcErrorCode.INTERNAL_ERROR: "Internal error",
|
| 84 |
+
JsonRpcErrorCode.SERVER_ERROR: "Server error",
|
| 85 |
+
}
|
| 86 |
+
return cls(
|
| 87 |
+
code=code.value,
|
| 88 |
+
message=message or default_messages.get(code, "Unknown error"),
|
| 89 |
+
data=data,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class JsonRpcRequest(BaseModel):
|
| 94 |
+
"""
|
| 95 |
+
JSON-RPC 2.0 request object.
|
| 96 |
+
|
| 97 |
+
See: https://www.jsonrpc.org/specification#request_object
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
model_config = ConfigDict(extra="forbid")
|
| 101 |
+
|
| 102 |
+
jsonrpc: Literal["2.0"] = Field(description="JSON-RPC version, must be '2.0'")
|
| 103 |
+
method: str = Field(description="Name of the method to be invoked")
|
| 104 |
+
params: Dict[str, Any] = Field(
|
| 105 |
+
default_factory=dict, description="Parameter values for the method"
|
| 106 |
+
)
|
| 107 |
+
id: Optional[Union[str, int]] = Field(
|
| 108 |
+
default=None, description="Request identifier established by the client"
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class JsonRpcResponse(BaseModel):
|
| 113 |
+
"""
|
| 114 |
+
JSON-RPC 2.0 response object.
|
| 115 |
+
|
| 116 |
+
Per JSON-RPC 2.0 spec, a response has either 'result' or 'error', not both.
|
| 117 |
+
This model excludes None values during serialization to comply with the spec.
|
| 118 |
+
|
| 119 |
+
See: https://www.jsonrpc.org/specification#response_object
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
model_config = ConfigDict(extra="forbid")
|
| 123 |
+
|
| 124 |
+
jsonrpc: Literal["2.0"] = Field(default="2.0", description="JSON-RPC version")
|
| 125 |
+
result: Optional[Any] = Field(
|
| 126 |
+
default=None, description="Result of the method invocation"
|
| 127 |
+
)
|
| 128 |
+
error: Optional[JsonRpcError] = Field(
|
| 129 |
+
default=None, description="Error object if method invocation failed"
|
| 130 |
+
)
|
| 131 |
+
id: Optional[Union[str, int]] = Field(
|
| 132 |
+
default=None, description="Request identifier from the request"
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
def model_dump(self, **kwargs) -> Dict[str, Any]:
|
| 136 |
+
"""Serialize to dict, excluding result or error when None (JSON-RPC compliance)."""
|
| 137 |
+
# Always include jsonrpc and id, but only include result OR error
|
| 138 |
+
data: Dict[str, Any] = {"jsonrpc": self.jsonrpc, "id": self.id}
|
| 139 |
+
if self.error is not None:
|
| 140 |
+
data["error"] = (
|
| 141 |
+
self.error.model_dump()
|
| 142 |
+
if hasattr(self.error, "model_dump")
|
| 143 |
+
else self.error
|
| 144 |
+
)
|
| 145 |
+
else:
|
| 146 |
+
# Only include result if there's no error
|
| 147 |
+
data["result"] = self.result
|
| 148 |
+
return data
|
| 149 |
+
|
| 150 |
+
def model_dump_json(self, **kwargs) -> str:
|
| 151 |
+
"""Serialize to JSON string, excluding result or error when None (JSON-RPC compliance)."""
|
| 152 |
+
import json
|
| 153 |
+
|
| 154 |
+
return json.dumps(self.model_dump())
|
| 155 |
+
|
| 156 |
+
@classmethod
|
| 157 |
+
def success(
|
| 158 |
+
cls, result: Any, request_id: Optional[Union[str, int]] = None
|
| 159 |
+
) -> "JsonRpcResponse":
|
| 160 |
+
"""Create a success response."""
|
| 161 |
+
return cls(result=result, id=request_id)
|
| 162 |
+
|
| 163 |
+
@classmethod
|
| 164 |
+
def error_response(
|
| 165 |
+
cls,
|
| 166 |
+
code: JsonRpcErrorCode,
|
| 167 |
+
message: Optional[str] = None,
|
| 168 |
+
data: Any = None,
|
| 169 |
+
request_id: Optional[Union[str, int]] = None,
|
| 170 |
+
) -> "JsonRpcResponse":
|
| 171 |
+
"""Create an error response from a standard error code."""
|
| 172 |
+
return cls(
|
| 173 |
+
error=JsonRpcError.from_code(code, message, data),
|
| 174 |
+
id=request_id,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# =============================================================================
|
| 179 |
+
# MCP Tool Types
|
| 180 |
+
# =============================================================================
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class Tool(BaseModel):
|
| 184 |
+
"""
|
| 185 |
+
Strongly typed MCP tool specification.
|
| 186 |
+
|
| 187 |
+
Follows the MCP ToolSpec format for tool discovery.
|
| 188 |
+
See: https://modelcontextprotocol.io/specification/2025-06-18/server/tools
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
model_config = ConfigDict(extra="forbid")
|
| 192 |
+
|
| 193 |
+
name: str = Field(description="Unique identifier for the tool")
|
| 194 |
+
description: str = Field(
|
| 195 |
+
description="Human-readable description of what the tool does"
|
| 196 |
+
)
|
| 197 |
+
input_schema: Dict[str, Any] = Field(
|
| 198 |
+
description="JSON Schema for the tool's input parameters"
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class ToolErrorType(str, Enum):
|
| 203 |
+
"""Types of errors that can occur during tool execution."""
|
| 204 |
+
|
| 205 |
+
EXECUTION_ERROR = "execution_error" # Tool ran but failed
|
| 206 |
+
INVALID_ARGS = "invalid_args" # Invalid arguments provided
|
| 207 |
+
TRANSPORT_ERROR = "transport_error" # Communication failure
|
| 208 |
+
TOOL_NOT_FOUND = "tool_not_found" # Tool doesn't exist
|
| 209 |
+
TIMEOUT = "timeout" # Operation timed out
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class ToolError(BaseModel):
|
| 213 |
+
"""
|
| 214 |
+
Structured error for tool execution failures.
|
| 215 |
+
|
| 216 |
+
This is used for transport/framework errors, NOT for errors returned
|
| 217 |
+
by the tool itself (those go in the result field).
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
model_config = ConfigDict(extra="forbid")
|
| 221 |
+
|
| 222 |
+
error_type: ToolErrorType = Field(description="Category of the error")
|
| 223 |
+
message: str = Field(description="Human-readable error message")
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
# --- MCP Actions ---
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class ListToolsAction(Action):
|
| 230 |
+
"""
|
| 231 |
+
Request list of available tools from the environment.
|
| 232 |
+
|
| 233 |
+
This action triggers MCP's tools/list operation and returns
|
| 234 |
+
all available tools with their schemas.
|
| 235 |
+
|
| 236 |
+
Note: Does NOT require reset() to be called first.
|
| 237 |
+
"""
|
| 238 |
+
|
| 239 |
+
type: Literal["list_tools"] = Field(
|
| 240 |
+
default="list_tools", description="Action type discriminator"
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class CallToolAction(Action):
|
| 245 |
+
"""
|
| 246 |
+
Call a specific tool via MCP.
|
| 247 |
+
|
| 248 |
+
This action triggers MCP's tools/call operation with the
|
| 249 |
+
specified tool name and arguments.
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
type: Literal["call_tool"] = Field(
|
| 253 |
+
default="call_tool", description="Action type discriminator"
|
| 254 |
+
)
|
| 255 |
+
tool_name: str = Field(description="Name of the tool to call")
|
| 256 |
+
arguments: Dict[str, Any] = Field(
|
| 257 |
+
default_factory=dict, description="Arguments to pass to the tool"
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# --- MCP Observations ---
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class ListToolsObservation(Observation):
|
| 265 |
+
"""
|
| 266 |
+
Response containing available tools.
|
| 267 |
+
|
| 268 |
+
Returned when processing a ListToolsAction.
|
| 269 |
+
"""
|
| 270 |
+
|
| 271 |
+
tools: List[Tool] = Field(description="List of available tools with their schemas")
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class CallToolObservation(Observation):
|
| 275 |
+
"""
|
| 276 |
+
Response from tool execution.
|
| 277 |
+
|
| 278 |
+
Contains the tool's result or an error if the call failed.
|
| 279 |
+
Tool-specific errors (from the tool itself) are included in the result.
|
| 280 |
+
Transport/framework errors use the error field.
|
| 281 |
+
"""
|
| 282 |
+
|
| 283 |
+
tool_name: str = Field(description="Name of the tool that was called")
|
| 284 |
+
result: Any = Field(
|
| 285 |
+
default=None, description="Tool-specific result (may include tool errors)"
|
| 286 |
+
)
|
| 287 |
+
error: Optional[ToolError] = Field(
|
| 288 |
+
default=None, description="Transport/framework error if call failed"
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# --- WebSocket Message Types for MCP ---
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class WSMCPMessage(BaseMessage):
|
| 296 |
+
"""
|
| 297 |
+
WebSocket message for MCP JSON-RPC requests.
|
| 298 |
+
|
| 299 |
+
Allows direct MCP access via WebSocket for production inference,
|
| 300 |
+
bypassing the step() API.
|
| 301 |
+
"""
|
| 302 |
+
|
| 303 |
+
type: Literal["mcp"] = Field(default="mcp", description="Message type")
|
| 304 |
+
data: Dict[str, Any] = Field(description="JSON-RPC payload (method, params, id)")
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
class WSMCPResponse(BaseModel):
|
| 308 |
+
"""
|
| 309 |
+
WebSocket response for MCP JSON-RPC.
|
| 310 |
+
|
| 311 |
+
Contains the JSON-RPC response from the MCP server.
|
| 312 |
+
"""
|
| 313 |
+
|
| 314 |
+
model_config = ConfigDict(extra="forbid")
|
| 315 |
+
|
| 316 |
+
type: str = Field(default="mcp", description="Response type")
|
| 317 |
+
data: Dict[str, Any] = Field(description="JSON-RPC response payload")
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
# Reserved tool names that cannot be used (protects dual API boundary)
|
| 321 |
+
RESERVED_TOOL_NAMES = frozenset(["reset", "step", "state", "close"])
|
src/core/env_server/route_config.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Route configuration utilities for declarative FastAPI route registration.
|
| 9 |
+
|
| 10 |
+
This module provides utilities to reduce boilerplate in route registration
|
| 11 |
+
by using configuration objects instead of repeated function calls.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from typing import Callable, List, Type
|
| 16 |
+
|
| 17 |
+
from fastapi import FastAPI
|
| 18 |
+
from pydantic import BaseModel
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class GetEndpointConfig:
|
| 23 |
+
"""Configuration for a simple GET endpoint."""
|
| 24 |
+
|
| 25 |
+
path: str
|
| 26 |
+
handler: Callable[[], BaseModel | dict]
|
| 27 |
+
response_model: Type[BaseModel] | type[dict]
|
| 28 |
+
tag: str
|
| 29 |
+
summary: str
|
| 30 |
+
description: str
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def register_get_endpoints(app: FastAPI, configs: List[GetEndpointConfig]) -> None:
|
| 34 |
+
"""
|
| 35 |
+
Register multiple GET endpoints from configuration.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
app: FastAPI application instance
|
| 39 |
+
configs: List of GET endpoint configurations
|
| 40 |
+
"""
|
| 41 |
+
for config in configs:
|
| 42 |
+
# Capture handler in a closure to avoid non-serializable default parameter
|
| 43 |
+
def make_endpoint(
|
| 44 |
+
handler: Callable[[], BaseModel | dict],
|
| 45 |
+
) -> Callable[[], BaseModel | dict]:
|
| 46 |
+
async def endpoint() -> BaseModel | dict:
|
| 47 |
+
return handler()
|
| 48 |
+
|
| 49 |
+
return endpoint
|
| 50 |
+
|
| 51 |
+
app.get(
|
| 52 |
+
config.path,
|
| 53 |
+
response_model=config.response_model,
|
| 54 |
+
tags=[config.tag],
|
| 55 |
+
summary=config.summary,
|
| 56 |
+
description=config.description,
|
| 57 |
+
)(make_endpoint(config.handler))
|
src/core/env_server/serialization.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Shared serialization and deserialization utilities for OpenEnv HTTP servers.
|
| 9 |
+
|
| 10 |
+
This module provides common utilities for converting between JSON dictionaries
|
| 11 |
+
and Pydantic models (Action/Observation) to eliminate code duplication across
|
| 12 |
+
HTTP server and web interface implementations.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from typing import Any, Dict, Type
|
| 16 |
+
|
| 17 |
+
from .types import Action, Observation
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action:
|
| 21 |
+
"""
|
| 22 |
+
Convert JSON dict to Action instance using Pydantic validation.
|
| 23 |
+
|
| 24 |
+
This is a basic deserialization that works for most environments.
|
| 25 |
+
For special cases (e.g., tensor fields, custom type conversions),
|
| 26 |
+
use deserialize_action_with_preprocessing().
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
action_data: Dictionary containing action data
|
| 30 |
+
action_cls: The Action subclass to instantiate
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
Action instance
|
| 34 |
+
|
| 35 |
+
Raises:
|
| 36 |
+
ValidationError: If action_data is invalid for the action class
|
| 37 |
+
|
| 38 |
+
Note:
|
| 39 |
+
This uses Pydantic's model_validate() for automatic validation.
|
| 40 |
+
"""
|
| 41 |
+
return action_cls.model_validate(action_data)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def deserialize_action_with_preprocessing(
|
| 45 |
+
action_data: Dict[str, Any], action_cls: Type[Action]
|
| 46 |
+
) -> Action:
|
| 47 |
+
"""
|
| 48 |
+
Convert JSON dict to Action instance with preprocessing for special types.
|
| 49 |
+
|
| 50 |
+
This version handles common type conversions needed for web interfaces:
|
| 51 |
+
- Converting lists/strings to tensors for 'tokens' field
|
| 52 |
+
- Converting string action_id to int
|
| 53 |
+
- Other custom preprocessing as needed
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
action_data: Dictionary containing action data
|
| 57 |
+
action_cls: The Action subclass to instantiate
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
Action instance
|
| 61 |
+
|
| 62 |
+
Raises:
|
| 63 |
+
ValidationError: If action_data is invalid for the action class
|
| 64 |
+
"""
|
| 65 |
+
processed_data = {}
|
| 66 |
+
|
| 67 |
+
for key, value in action_data.items():
|
| 68 |
+
if key == "tokens" and isinstance(value, (list, str)):
|
| 69 |
+
# Convert list or string to tensor
|
| 70 |
+
if isinstance(value, str):
|
| 71 |
+
# If it's a string, try to parse it as a list of numbers
|
| 72 |
+
try:
|
| 73 |
+
import json
|
| 74 |
+
|
| 75 |
+
value = json.loads(value)
|
| 76 |
+
except Exception:
|
| 77 |
+
# If parsing fails, treat as empty list
|
| 78 |
+
value = []
|
| 79 |
+
if isinstance(value, list):
|
| 80 |
+
try:
|
| 81 |
+
import torch # type: ignore
|
| 82 |
+
|
| 83 |
+
processed_data[key] = torch.tensor(value, dtype=torch.long)
|
| 84 |
+
except ImportError:
|
| 85 |
+
# If torch not available, keep as list
|
| 86 |
+
processed_data[key] = value
|
| 87 |
+
else:
|
| 88 |
+
processed_data[key] = value
|
| 89 |
+
elif key == "action_id" and isinstance(value, str):
|
| 90 |
+
# Convert action_id from string to int
|
| 91 |
+
try:
|
| 92 |
+
processed_data[key] = int(value)
|
| 93 |
+
except ValueError:
|
| 94 |
+
# If conversion fails, keep original value
|
| 95 |
+
processed_data[key] = value
|
| 96 |
+
else:
|
| 97 |
+
processed_data[key] = value
|
| 98 |
+
|
| 99 |
+
return action_cls.model_validate(processed_data)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def serialize_observation(observation: Observation) -> Dict[str, Any]:
|
| 103 |
+
"""
|
| 104 |
+
Convert Observation instance to JSON-compatible dict using Pydantic.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
observation: Observation instance
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
Dictionary compatible with EnvClient._parse_result()
|
| 111 |
+
|
| 112 |
+
The format matches what EnvClient expects:
|
| 113 |
+
{
|
| 114 |
+
"observation": {...}, # Observation fields
|
| 115 |
+
"reward": float | None,
|
| 116 |
+
"done": bool,
|
| 117 |
+
}
|
| 118 |
+
"""
|
| 119 |
+
# Use Pydantic's model_dump() for serialization
|
| 120 |
+
obs_dict = observation.model_dump(
|
| 121 |
+
exclude={
|
| 122 |
+
"reward",
|
| 123 |
+
"done",
|
| 124 |
+
"metadata",
|
| 125 |
+
} # Exclude these from observation dict
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Extract reward and done directly from the observation
|
| 129 |
+
reward = observation.reward
|
| 130 |
+
done = observation.done
|
| 131 |
+
|
| 132 |
+
# Return in EnvClient expected format
|
| 133 |
+
return {
|
| 134 |
+
"observation": obs_dict,
|
| 135 |
+
"reward": reward,
|
| 136 |
+
"done": done,
|
| 137 |
+
}
|
src/core/env_server/types.py
CHANGED
|
@@ -4,54 +4,384 @@
|
|
| 4 |
# This source code is licensed under the BSD-style license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
-
from
|
| 8 |
-
from typing import Any, Dict,
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
# Type aliases
|
| 12 |
Scalar = Union[int, float, bool]
|
| 13 |
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 20 |
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
"""Base class for environment state."""
|
| 34 |
|
| 35 |
-
|
| 36 |
-
step_count: int = 0
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
|
| 40 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
"""Result of code execution containing stdout, stderr, and exit code."""
|
| 42 |
|
| 43 |
-
stdout: str
|
| 44 |
-
stderr: str
|
| 45 |
-
exit_code: int
|
| 46 |
|
| 47 |
|
| 48 |
-
|
| 49 |
-
class EnvironmentMetadata:
|
| 50 |
"""Metadata about an environment for documentation and UI purposes."""
|
| 51 |
-
|
| 52 |
-
name: str
|
| 53 |
-
description: str
|
| 54 |
-
readme_content: Optional[str] =
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
# This source code is licensed under the BSD-style license found in the
|
| 5 |
# LICENSE file in the root directory of this source tree.
|
| 6 |
|
| 7 |
+
from enum import Enum
|
| 8 |
+
from typing import Annotated, Any, Dict, Literal, Optional, Union
|
| 9 |
+
|
| 10 |
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
| 11 |
|
| 12 |
|
| 13 |
# Type aliases
|
| 14 |
Scalar = Union[int, float, bool]
|
| 15 |
|
| 16 |
|
| 17 |
+
# =============================================================================
|
| 18 |
+
# Enums for Type Safety
|
| 19 |
+
# =============================================================================
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ServerMode(str, Enum):
|
| 23 |
+
"""Server operation mode."""
|
| 24 |
+
|
| 25 |
+
SIMULATION = "simulation"
|
| 26 |
+
PRODUCTION = "production"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class HealthStatus(str, Enum):
|
| 30 |
+
"""Server health status values."""
|
| 31 |
+
|
| 32 |
+
HEALTHY = "healthy"
|
| 33 |
+
UNHEALTHY = "unhealthy"
|
| 34 |
+
DEGRADED = "degraded"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class WSErrorCode(str, Enum):
|
| 38 |
+
"""WebSocket error codes for structured error handling."""
|
| 39 |
+
|
| 40 |
+
INVALID_JSON = "INVALID_JSON"
|
| 41 |
+
UNKNOWN_TYPE = "UNKNOWN_TYPE"
|
| 42 |
+
VALIDATION_ERROR = "VALIDATION_ERROR"
|
| 43 |
+
EXECUTION_ERROR = "EXECUTION_ERROR"
|
| 44 |
+
CAPACITY_REACHED = "CAPACITY_REACHED"
|
| 45 |
+
FACTORY_ERROR = "FACTORY_ERROR"
|
| 46 |
+
SESSION_ERROR = "SESSION_ERROR"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# =============================================================================
|
| 50 |
+
# Core Types
|
| 51 |
+
# =============================================================================
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class Action(BaseModel):
|
| 55 |
+
"""Base class for all environment actions.
|
| 56 |
+
|
| 57 |
+
All action subclasses should inherit from this base class.
|
| 58 |
+
Uses Pydantic for automatic validation and serialization.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
model_config = ConfigDict(
|
| 62 |
+
extra="forbid", # Reject unknown fields
|
| 63 |
+
validate_assignment=True, # Validate on field assignment
|
| 64 |
+
arbitrary_types_allowed=True, # Allow numpy arrays, torch tensors, etc.
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
metadata: Dict[str, Any] = Field(
|
| 68 |
+
default_factory=dict, description="Additional metadata for the action"
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class Observation(BaseModel):
|
| 73 |
+
"""Base class for all environment observations.
|
| 74 |
+
|
| 75 |
+
All observation subclasses should inherit from this base class.
|
| 76 |
+
Uses Pydantic for automatic validation and serialization.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
model_config = ConfigDict(
|
| 80 |
+
extra="forbid",
|
| 81 |
+
validate_assignment=True,
|
| 82 |
+
arbitrary_types_allowed=True,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
done: bool = Field(default=False, description="Whether the episode has terminated")
|
| 86 |
+
reward: bool | int | float | None = Field(
|
| 87 |
+
default=None, description="Reward signal from the last action"
|
| 88 |
+
)
|
| 89 |
+
metadata: Dict[str, Any] = Field(
|
| 90 |
+
default_factory=dict, description="Additional metadata for the observation"
|
| 91 |
+
)
|
| 92 |
|
|
|
|
| 93 |
|
| 94 |
+
class ResetRequest(BaseModel):
|
| 95 |
+
"""Request model for environment reset."""
|
| 96 |
|
| 97 |
+
model_config = ConfigDict(
|
| 98 |
+
extra="allow", # Allow extra fields for custom reset parameters
|
| 99 |
+
json_schema_extra={"examples": [{"seed": 42, "episode_id": "episode-001"}, {}]},
|
| 100 |
+
)
|
| 101 |
|
| 102 |
+
seed: Optional[int] = Field(
|
| 103 |
+
default=None, ge=0, description="Random seed for reproducible episodes"
|
| 104 |
+
)
|
| 105 |
+
episode_id: Optional[str] = Field(
|
| 106 |
+
default=None, max_length=255, description="Custom episode identifier"
|
| 107 |
+
)
|
| 108 |
|
| 109 |
|
| 110 |
+
class ResetResponse(BaseModel):
|
| 111 |
+
"""Response model for environment reset."""
|
|
|
|
| 112 |
|
| 113 |
+
model_config = ConfigDict(extra="forbid")
|
|
|
|
| 114 |
|
| 115 |
+
observation: Dict[str, Any] = Field(
|
| 116 |
+
..., description="Initial observation from the environment"
|
| 117 |
+
)
|
| 118 |
+
reward: Optional[float] = Field(
|
| 119 |
+
default=None, description="Initial reward (typically None at reset)"
|
| 120 |
+
)
|
| 121 |
+
done: bool = Field(
|
| 122 |
+
default=False, description="Whether episode is already done (typically False)"
|
| 123 |
+
)
|
| 124 |
|
| 125 |
+
|
| 126 |
+
class StepRequest(BaseModel):
|
| 127 |
+
"""Request model for environment step."""
|
| 128 |
+
|
| 129 |
+
model_config = ConfigDict(
|
| 130 |
+
extra="allow", # Allow extra fields for custom step parameters
|
| 131 |
+
json_schema_extra={
|
| 132 |
+
"examples": [
|
| 133 |
+
{"action": {"value": 1}, "timeout_s": 30.0},
|
| 134 |
+
{"action": {"value": 1}, "render": True, "verbose": False},
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
action: Dict[str, Any] = Field(
|
| 140 |
+
...,
|
| 141 |
+
description="Action to execute, must conform to environment's action schema",
|
| 142 |
+
)
|
| 143 |
+
timeout_s: Optional[float] = Field(
|
| 144 |
+
default=None,
|
| 145 |
+
gt=0,
|
| 146 |
+
description="Optional timeout in seconds for action execution",
|
| 147 |
+
)
|
| 148 |
+
request_id: Optional[str] = Field(
|
| 149 |
+
default=None,
|
| 150 |
+
max_length=255,
|
| 151 |
+
description="Optional request identifier for tracking",
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class StepResponse(BaseModel):
|
| 156 |
+
"""Response model for environment step."""
|
| 157 |
+
|
| 158 |
+
model_config = ConfigDict(extra="forbid")
|
| 159 |
+
|
| 160 |
+
observation: Dict[str, Any] = Field(
|
| 161 |
+
..., description="Observation resulting from the action"
|
| 162 |
+
)
|
| 163 |
+
reward: Optional[float] = Field(
|
| 164 |
+
default=None, description="Reward signal from the action"
|
| 165 |
+
)
|
| 166 |
+
done: bool = Field(default=False, description="Whether the episode has terminated")
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class BaseMessage(BaseModel):
|
| 170 |
+
"""Base class for WebSocket messages with shared configuration."""
|
| 171 |
+
|
| 172 |
+
model_config = ConfigDict(
|
| 173 |
+
extra="forbid",
|
| 174 |
+
validate_assignment=True,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class State(BaseModel):
|
| 179 |
+
"""Base class for environment state.
|
| 180 |
+
|
| 181 |
+
Represents internal environment state, separate from observations.
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
model_config = ConfigDict(
|
| 185 |
+
extra="allow", # Allow extra fields for flexibility
|
| 186 |
+
validate_assignment=True,
|
| 187 |
+
arbitrary_types_allowed=True,
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
episode_id: Optional[str] = Field(
|
| 191 |
+
default=None, description="Unique identifier for the current episode"
|
| 192 |
+
)
|
| 193 |
+
step_count: int = Field(
|
| 194 |
+
default=0,
|
| 195 |
+
ge=0, # Greater than or equal to 0
|
| 196 |
+
description="Number of steps taken in the current episode",
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class CodeExecResult(BaseMessage):
|
| 201 |
"""Result of code execution containing stdout, stderr, and exit code."""
|
| 202 |
|
| 203 |
+
stdout: str = Field(description="Standard output from code execution")
|
| 204 |
+
stderr: str = Field(description="Standard error from code execution")
|
| 205 |
+
exit_code: int = Field(description="Exit code from code execution")
|
| 206 |
|
| 207 |
|
| 208 |
+
class EnvironmentMetadata(BaseMessage):
|
|
|
|
| 209 |
"""Metadata about an environment for documentation and UI purposes."""
|
| 210 |
+
|
| 211 |
+
name: str = Field(description="Name of the environment")
|
| 212 |
+
description: str = Field(description="Description of what the environment does")
|
| 213 |
+
readme_content: Optional[str] = Field(
|
| 214 |
+
default=None, description="Content of the README file for the environment"
|
| 215 |
+
)
|
| 216 |
+
version: Optional[str] = Field(
|
| 217 |
+
default=None, description="Version of the environment"
|
| 218 |
+
)
|
| 219 |
+
author: Optional[str] = Field(default=None, description="Author of the environment")
|
| 220 |
+
documentation_url: Optional[str] = Field(
|
| 221 |
+
default=None, description="URL to the environment's documentation"
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class SchemaResponse(BaseMessage):
|
| 226 |
+
"""Response model for the combined schema endpoint."""
|
| 227 |
+
|
| 228 |
+
action: Dict[str, Any] = Field(
|
| 229 |
+
description="JSON schema for actions accepted by this environment"
|
| 230 |
+
)
|
| 231 |
+
observation: Dict[str, Any] = Field(
|
| 232 |
+
description="JSON schema for observations returned by this environment"
|
| 233 |
+
)
|
| 234 |
+
state: Dict[str, Any] = Field(
|
| 235 |
+
description="JSON schema for environment state objects"
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class HealthResponse(BaseMessage):
|
| 240 |
+
"""Response model for health check endpoint."""
|
| 241 |
+
|
| 242 |
+
status: HealthStatus = Field(
|
| 243 |
+
default=HealthStatus.HEALTHY,
|
| 244 |
+
description="Health status of the environment server",
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class WSResetMessage(BaseMessage):
|
| 249 |
+
"""WebSocket message to reset the environment."""
|
| 250 |
+
|
| 251 |
+
type: Literal["reset"] = Field(default="reset", description="Message type")
|
| 252 |
+
data: Dict[str, Any] = Field(
|
| 253 |
+
default_factory=dict,
|
| 254 |
+
description="Optional reset parameters (seed, episode_id, etc.)",
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class WSStepMessage(BaseMessage):
|
| 259 |
+
"""WebSocket message to execute a step."""
|
| 260 |
+
|
| 261 |
+
type: Literal["step"] = Field(default="step", description="Message type")
|
| 262 |
+
data: Dict[str, Any] = Field(
|
| 263 |
+
..., description="Action data conforming to environment's action schema"
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class WSStateMessage(BaseMessage):
|
| 268 |
+
"""WebSocket message to request current state."""
|
| 269 |
+
|
| 270 |
+
type: Literal["state"] = Field(default="state", description="Message type")
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class WSCloseMessage(BaseMessage):
|
| 274 |
+
"""WebSocket message to close the session."""
|
| 275 |
+
|
| 276 |
+
type: Literal["close"] = Field(default="close", description="Message type")
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# Discriminated union for incoming WebSocket messages
|
| 280 |
+
# Note: WSMCPMessage is defined in mcp_types.py to avoid circular imports
|
| 281 |
+
# The union here covers the core message types; MCP messages are handled separately
|
| 282 |
+
WSIncomingMessage = Annotated[
|
| 283 |
+
WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage,
|
| 284 |
+
Field(discriminator="type"),
|
| 285 |
+
]
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class WSObservationResponse(BaseModel):
|
| 289 |
+
"""WebSocket response containing an observation."""
|
| 290 |
+
|
| 291 |
+
model_config = ConfigDict(extra="forbid")
|
| 292 |
+
|
| 293 |
+
type: Literal["observation"] = Field(
|
| 294 |
+
default="observation", description="Response type"
|
| 295 |
+
)
|
| 296 |
+
data: Dict[str, Any] = Field(description="Observation data")
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class WSStateResponse(BaseModel):
|
| 300 |
+
"""WebSocket response containing environment state."""
|
| 301 |
+
|
| 302 |
+
model_config = ConfigDict(extra="forbid")
|
| 303 |
+
|
| 304 |
+
type: Literal["state"] = Field(default="state", description="Response type")
|
| 305 |
+
data: Dict[str, Any] = Field(description="State data")
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
class WSErrorResponse(BaseModel):
|
| 309 |
+
"""WebSocket response for errors."""
|
| 310 |
+
|
| 311 |
+
model_config = ConfigDict(extra="forbid")
|
| 312 |
+
|
| 313 |
+
type: Literal["error"] = Field(default="error", description="Response type")
|
| 314 |
+
data: Dict[str, Any] = Field(description="Error details including message and code")
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
class ConcurrencyConfig(BaseMessage):
|
| 318 |
+
"""Configuration for concurrent environment sessions."""
|
| 319 |
+
|
| 320 |
+
max_concurrent_envs: int = Field(
|
| 321 |
+
default=1,
|
| 322 |
+
ge=1,
|
| 323 |
+
description="Maximum number of concurrent WebSocket sessions allowed",
|
| 324 |
+
)
|
| 325 |
+
session_timeout: Optional[float] = Field(
|
| 326 |
+
default=None,
|
| 327 |
+
gt=0,
|
| 328 |
+
description="Timeout in seconds for inactive sessions. None means no timeout.",
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
class ServerCapacityStatus(BaseMessage):
|
| 333 |
+
"""Status of server capacity for concurrent sessions."""
|
| 334 |
+
|
| 335 |
+
active_sessions: int = Field(
|
| 336 |
+
ge=0,
|
| 337 |
+
description="Number of currently active sessions",
|
| 338 |
+
)
|
| 339 |
+
max_sessions: int = Field(
|
| 340 |
+
ge=1,
|
| 341 |
+
description="Maximum number of allowed sessions",
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
@model_validator(mode="after")
|
| 345 |
+
def check_capacity_bounds(self) -> "ServerCapacityStatus":
|
| 346 |
+
if self.active_sessions > self.max_sessions:
|
| 347 |
+
raise ValueError(
|
| 348 |
+
f"active_sessions ({self.active_sessions}) cannot exceed "
|
| 349 |
+
f"max_sessions ({self.max_sessions})"
|
| 350 |
+
)
|
| 351 |
+
return self
|
| 352 |
+
|
| 353 |
+
@property
|
| 354 |
+
def available_slots(self) -> int:
|
| 355 |
+
"""Number of available session slots."""
|
| 356 |
+
return self.max_sessions - self.active_sessions
|
| 357 |
+
|
| 358 |
+
@property
|
| 359 |
+
def is_at_capacity(self) -> bool:
|
| 360 |
+
"""Whether the server has reached maximum capacity."""
|
| 361 |
+
return self.available_slots == 0
|
| 362 |
+
|
| 363 |
+
@classmethod
|
| 364 |
+
def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus":
|
| 365 |
+
"""Create status from active and max session counts."""
|
| 366 |
+
return cls(
|
| 367 |
+
active_sessions=active,
|
| 368 |
+
max_sessions=max_sessions,
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class SessionInfo(BaseMessage):
|
| 373 |
+
"""Information about an active session."""
|
| 374 |
+
|
| 375 |
+
session_id: str = Field(description="Unique identifier for the session")
|
| 376 |
+
created_at: float = Field(description="Unix timestamp when the session was created")
|
| 377 |
+
last_activity_at: float = Field(
|
| 378 |
+
description="Unix timestamp of the last activity in the session"
|
| 379 |
+
)
|
| 380 |
+
step_count: int = Field(
|
| 381 |
+
default=0,
|
| 382 |
+
ge=0,
|
| 383 |
+
description="Number of steps executed in this session",
|
| 384 |
+
)
|
| 385 |
+
environment_type: str = Field(
|
| 386 |
+
description="Environment type for this session (e.g. `CodingEnv`)"
|
| 387 |
+
)
|
src/core/env_server/web_interface.py
CHANGED
|
@@ -7,61 +7,164 @@
|
|
| 7 |
"""
|
| 8 |
Web interface for OpenEnv environments.
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
|
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
| 15 |
|
|
|
|
| 16 |
import json
|
| 17 |
-
import
|
| 18 |
-
from dataclasses import asdict, dataclass
|
| 19 |
-
from typing import Any, Dict, List, Optional, Type
|
| 20 |
from datetime import datetime
|
|
|
|
| 21 |
|
| 22 |
-
|
| 23 |
-
from fastapi
|
| 24 |
-
from
|
| 25 |
-
from pydantic import BaseModel
|
| 26 |
|
|
|
|
|
|
|
| 27 |
from .interfaces import Environment
|
| 28 |
-
from .
|
|
|
|
| 29 |
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
"""
|
| 33 |
Load environment metadata including README content.
|
| 34 |
-
|
| 35 |
Args:
|
| 36 |
-
env: The environment instance
|
|
|
|
|
|
|
|
|
|
| 37 |
env_name: Optional environment name for README file lookup
|
| 38 |
-
|
| 39 |
Returns:
|
| 40 |
EnvironmentMetadata with loaded information
|
| 41 |
"""
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
return env.get_metadata()
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
# Default metadata
|
| 47 |
metadata = EnvironmentMetadata(
|
| 48 |
-
name=env_name or
|
| 49 |
-
description=f"{
|
| 50 |
-
version="1.0.0"
|
| 51 |
)
|
| 52 |
-
|
| 53 |
# Try to load README from file system
|
| 54 |
readme_content = _load_readme_from_filesystem(env_name)
|
| 55 |
if readme_content:
|
| 56 |
metadata.readme_content = readme_content
|
| 57 |
-
|
| 58 |
return metadata
|
| 59 |
|
| 60 |
|
| 61 |
def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]:
|
| 62 |
"""
|
| 63 |
Load README content from the filesystem.
|
| 64 |
-
|
| 65 |
Tries multiple locations:
|
| 66 |
1. Container filesystem: /app/README.md
|
| 67 |
2. Local development: src/envs/{env_name}/README.md
|
|
@@ -69,59 +172,73 @@ def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]:
|
|
| 69 |
"""
|
| 70 |
import os
|
| 71 |
from pathlib import Path
|
| 72 |
-
|
| 73 |
# Try container filesystem first
|
| 74 |
container_readme = Path("/app/README.md")
|
| 75 |
if container_readme.exists():
|
| 76 |
try:
|
| 77 |
-
return container_readme.read_text(encoding=
|
| 78 |
except Exception:
|
| 79 |
pass
|
| 80 |
-
|
| 81 |
# Try environment variable path
|
| 82 |
custom_path = os.environ.get("ENV_README_PATH")
|
| 83 |
if custom_path and Path(custom_path).exists():
|
| 84 |
try:
|
| 85 |
-
return Path(custom_path).read_text(encoding=
|
| 86 |
except Exception:
|
| 87 |
pass
|
| 88 |
-
|
| 89 |
# Try local development path
|
| 90 |
if env_name:
|
| 91 |
local_readme = Path(f"src/envs/{env_name}/README.md")
|
| 92 |
if local_readme.exists():
|
| 93 |
try:
|
| 94 |
-
return local_readme.read_text(encoding=
|
| 95 |
except Exception:
|
| 96 |
pass
|
| 97 |
-
|
| 98 |
return None
|
| 99 |
|
| 100 |
|
| 101 |
-
|
| 102 |
-
class ActionLog:
|
| 103 |
"""Log entry for an action taken."""
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
|
| 112 |
-
|
| 113 |
-
class EpisodeState:
|
| 114 |
"""Current episode state for the web interface."""
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
|
| 122 |
class WebInterfaceManager:
|
| 123 |
"""Manages the web interface for an environment."""
|
| 124 |
-
|
|
|
|
|
|
|
| 125 |
def __init__(
|
| 126 |
self,
|
| 127 |
env: Environment,
|
|
@@ -129,152 +246,146 @@ class WebInterfaceManager:
|
|
| 129 |
observation_cls: Type[Observation],
|
| 130 |
metadata: Optional[EnvironmentMetadata] = None,
|
| 131 |
):
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
self.action_cls = action_cls
|
| 134 |
self.observation_cls = observation_cls
|
| 135 |
self.metadata = metadata or EnvironmentMetadata(
|
| 136 |
name=env.__class__.__name__,
|
| 137 |
-
description=f"{env.__class__.__name__} environment"
|
| 138 |
)
|
| 139 |
self.episode_state = EpisodeState(
|
| 140 |
episode_id=None,
|
| 141 |
step_count=0,
|
| 142 |
current_observation=None,
|
| 143 |
-
action_logs=[]
|
| 144 |
)
|
| 145 |
self.connected_clients: List[WebSocket] = []
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
async def connect_websocket(self, websocket: WebSocket):
|
| 148 |
"""Connect a new WebSocket client."""
|
| 149 |
await websocket.accept()
|
| 150 |
self.connected_clients.append(websocket)
|
| 151 |
-
|
| 152 |
# Send current state to the new client
|
| 153 |
await self._send_state_update()
|
| 154 |
-
|
| 155 |
async def disconnect_websocket(self, websocket: WebSocket):
|
| 156 |
"""Disconnect a WebSocket client."""
|
| 157 |
if websocket in self.connected_clients:
|
| 158 |
self.connected_clients.remove(websocket)
|
| 159 |
-
|
| 160 |
async def _send_state_update(self):
|
| 161 |
"""Send current state to all connected clients."""
|
| 162 |
if not self.connected_clients:
|
| 163 |
return
|
| 164 |
-
|
| 165 |
state_data = {
|
| 166 |
"type": "state_update",
|
| 167 |
-
"episode_state":
|
| 168 |
}
|
| 169 |
-
|
| 170 |
# Send to all connected clients
|
| 171 |
disconnected_clients = []
|
| 172 |
for client in self.connected_clients:
|
| 173 |
try:
|
| 174 |
await client.send_text(json.dumps(state_data))
|
| 175 |
-
except:
|
| 176 |
disconnected_clients.append(client)
|
| 177 |
-
|
| 178 |
# Remove disconnected clients
|
| 179 |
for client in disconnected_clients:
|
| 180 |
self.connected_clients.remove(client)
|
| 181 |
-
|
| 182 |
async def reset_environment(self) -> Dict[str, Any]:
|
| 183 |
"""Reset the environment and update state."""
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
# Update episode state
|
| 188 |
self.episode_state.episode_id = state.episode_id
|
| 189 |
self.episode_state.step_count = 0
|
| 190 |
-
self.episode_state.current_observation =
|
| 191 |
self.episode_state.action_logs = []
|
| 192 |
self.episode_state.is_reset = True
|
| 193 |
-
|
| 194 |
# Send state update
|
| 195 |
await self._send_state_update()
|
| 196 |
-
|
| 197 |
-
return
|
| 198 |
-
|
| 199 |
-
"reward": observation.reward,
|
| 200 |
-
"done": observation.done,
|
| 201 |
-
}
|
| 202 |
-
|
| 203 |
async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 204 |
"""Execute a step in the environment and update state."""
|
| 205 |
-
# Deserialize action
|
| 206 |
-
action =
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
# Create action log
|
| 213 |
action_log = ActionLog(
|
| 214 |
timestamp=datetime.now().isoformat(),
|
| 215 |
-
action=
|
| 216 |
-
observation=
|
| 217 |
reward=observation.reward,
|
| 218 |
done=observation.done,
|
| 219 |
-
step_count=state.step_count
|
| 220 |
)
|
| 221 |
-
|
| 222 |
# Update episode state
|
| 223 |
self.episode_state.episode_id = state.episode_id
|
| 224 |
self.episode_state.step_count = state.step_count
|
| 225 |
-
self.episode_state.current_observation =
|
| 226 |
self.episode_state.action_logs.append(action_log)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
self.episode_state.is_reset = False
|
| 228 |
-
|
| 229 |
# Send state update
|
| 230 |
await self._send_state_update()
|
| 231 |
-
|
| 232 |
-
return
|
| 233 |
-
|
| 234 |
-
"reward": observation.reward,
|
| 235 |
-
"done": observation.done,
|
| 236 |
-
}
|
| 237 |
-
|
| 238 |
def get_state(self) -> Dict[str, Any]:
|
| 239 |
"""Get current environment state."""
|
| 240 |
-
state = self.env.state
|
| 241 |
-
return
|
| 242 |
-
|
| 243 |
-
def _deserialize_action(self, action_data: Dict[str, Any]) -> Action:
|
| 244 |
-
"""Convert JSON dict to Action instance."""
|
| 245 |
-
metadata = action_data.pop("metadata", {})
|
| 246 |
-
|
| 247 |
-
# Handle tensor fields that come from JSON as lists
|
| 248 |
-
processed_data = {}
|
| 249 |
-
for key, value in action_data.items():
|
| 250 |
-
if key == "tokens" and isinstance(value, (list, str)):
|
| 251 |
-
# Convert list or string to tensor
|
| 252 |
-
if isinstance(value, str):
|
| 253 |
-
# If it's a string, try to parse it as a list of numbers
|
| 254 |
-
try:
|
| 255 |
-
import json
|
| 256 |
-
value = json.loads(value)
|
| 257 |
-
except:
|
| 258 |
-
# If parsing fails, treat as empty list
|
| 259 |
-
value = []
|
| 260 |
-
if isinstance(value, list):
|
| 261 |
-
import torch
|
| 262 |
-
processed_data[key] = torch.tensor(value, dtype=torch.long)
|
| 263 |
-
else:
|
| 264 |
-
processed_data[key] = value
|
| 265 |
-
elif key == "action_id" and isinstance(value, str):
|
| 266 |
-
# Convert action_id from string to int
|
| 267 |
-
try:
|
| 268 |
-
processed_data[key] = int(value)
|
| 269 |
-
except ValueError:
|
| 270 |
-
# If conversion fails, keep original value
|
| 271 |
-
processed_data[key] = value
|
| 272 |
-
else:
|
| 273 |
-
processed_data[key] = value
|
| 274 |
-
|
| 275 |
-
action = self.action_cls(**processed_data)
|
| 276 |
-
action.metadata = metadata
|
| 277 |
-
return action
|
| 278 |
|
| 279 |
|
| 280 |
def create_web_interface_app(
|
|
@@ -282,44 +393,53 @@ def create_web_interface_app(
|
|
| 282 |
action_cls: Type[Action],
|
| 283 |
observation_cls: Type[Observation],
|
| 284 |
env_name: Optional[str] = None,
|
|
|
|
|
|
|
|
|
|
| 285 |
) -> FastAPI:
|
| 286 |
"""
|
| 287 |
Create a FastAPI application with web interface for the given environment.
|
| 288 |
-
|
| 289 |
Args:
|
| 290 |
env: The Environment instance to serve
|
| 291 |
action_cls: The Action subclass this environment expects
|
| 292 |
observation_cls: The Observation subclass this environment returns
|
| 293 |
env_name: Optional environment name for README loading
|
| 294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
Returns:
|
| 296 |
FastAPI application instance with web interface
|
| 297 |
"""
|
| 298 |
from .http_server import create_fastapi_app
|
| 299 |
-
|
| 300 |
# Create the base environment app
|
| 301 |
-
app = create_fastapi_app(
|
| 302 |
-
|
|
|
|
|
|
|
| 303 |
# Load environment metadata
|
| 304 |
metadata = load_environment_metadata(env, env_name)
|
| 305 |
-
|
| 306 |
# Create web interface manager
|
| 307 |
web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)
|
| 308 |
-
|
| 309 |
-
#
|
| 310 |
-
@app.get("/web", response_class=HTMLResponse)
|
| 311 |
-
async def web_interface():
|
| 312 |
-
"""Serve the web interface."""
|
| 313 |
-
return get_web_interface_html(action_cls, web_manager.metadata)
|
| 314 |
-
|
| 315 |
@app.get("/web/metadata")
|
| 316 |
async def web_metadata():
|
| 317 |
"""Get environment metadata."""
|
| 318 |
-
return
|
| 319 |
-
|
| 320 |
-
@app.websocket("/ws")
|
| 321 |
-
async def
|
| 322 |
-
"""WebSocket endpoint for real-time updates.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
await web_manager.connect_websocket(websocket)
|
| 324 |
try:
|
| 325 |
while True:
|
|
@@ -327,1287 +447,198 @@ def create_web_interface_app(
|
|
| 327 |
await websocket.receive_text()
|
| 328 |
except WebSocketDisconnect:
|
| 329 |
await web_manager.disconnect_websocket(websocket)
|
| 330 |
-
|
| 331 |
@app.post("/web/reset")
|
| 332 |
async def web_reset():
|
| 333 |
"""Reset endpoint for web interface."""
|
| 334 |
return await web_manager.reset_environment()
|
| 335 |
-
|
| 336 |
@app.post("/web/step")
|
| 337 |
async def web_step(request: Dict[str, Any]):
|
| 338 |
"""Step endpoint for web interface."""
|
| 339 |
# Check if this is a message-based request (chat environment)
|
| 340 |
if "message" in request:
|
| 341 |
message = request["message"]
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
else:
|
| 346 |
action_data = request.get("action", {})
|
| 347 |
-
|
| 348 |
return await web_manager.step_environment(action_data)
|
| 349 |
-
|
| 350 |
@app.get("/web/state")
|
| 351 |
async def web_state():
|
| 352 |
"""State endpoint for web interface."""
|
| 353 |
return web_manager.get_state()
|
| 354 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
return app
|
| 356 |
|
| 357 |
|
| 358 |
-
def
|
| 359 |
-
"""
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
# Get action fields for dynamic form generation with enhanced metadata
|
| 370 |
-
action_fields = _extract_action_fields(action_cls)
|
| 371 |
-
|
| 372 |
-
return f"""
|
| 373 |
-
<!DOCTYPE html>
|
| 374 |
-
<html lang="en">
|
| 375 |
-
<head>
|
| 376 |
-
<meta charset="UTF-8">
|
| 377 |
-
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 378 |
-
<title>OpenEnv Web Interface</title>
|
| 379 |
-
<style>
|
| 380 |
-
* {{
|
| 381 |
-
margin: 0;
|
| 382 |
-
padding: 0;
|
| 383 |
-
box-sizing: border-box;
|
| 384 |
-
}}
|
| 385 |
-
|
| 386 |
-
body {{
|
| 387 |
-
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
| 388 |
-
background-color: #f5f5f5;
|
| 389 |
-
height: 100vh;
|
| 390 |
-
overflow: hidden;
|
| 391 |
-
}}
|
| 392 |
-
|
| 393 |
-
.container {{
|
| 394 |
-
display: flex;
|
| 395 |
-
height: 100vh;
|
| 396 |
-
}}
|
| 397 |
-
|
| 398 |
-
.left-pane {{
|
| 399 |
-
width: 50%;
|
| 400 |
-
background: white;
|
| 401 |
-
border-right: 1px solid #e0e0e0;
|
| 402 |
-
display: flex;
|
| 403 |
-
flex-direction: column;
|
| 404 |
-
}}
|
| 405 |
-
|
| 406 |
-
.right-pane {{
|
| 407 |
-
width: 50%;
|
| 408 |
-
background: #fafafa;
|
| 409 |
-
display: flex;
|
| 410 |
-
flex-direction: column;
|
| 411 |
-
}}
|
| 412 |
-
|
| 413 |
-
.pane-header {{
|
| 414 |
-
padding: 20px;
|
| 415 |
-
border-bottom: 1px solid #e0e0e0;
|
| 416 |
-
background: #f8f9fa;
|
| 417 |
-
font-weight: 600;
|
| 418 |
-
font-size: 16px;
|
| 419 |
-
}}
|
| 420 |
-
|
| 421 |
-
.pane-content {{
|
| 422 |
-
flex: 1;
|
| 423 |
-
padding: 20px;
|
| 424 |
-
overflow-y: auto;
|
| 425 |
-
}}
|
| 426 |
-
|
| 427 |
-
.action-form {{
|
| 428 |
-
background: white;
|
| 429 |
-
border: 1px solid #e0e0e0;
|
| 430 |
-
border-radius: 8px;
|
| 431 |
-
padding: 20px;
|
| 432 |
-
margin-bottom: 20px;
|
| 433 |
-
}}
|
| 434 |
-
|
| 435 |
-
.form-group {{
|
| 436 |
-
margin-bottom: 15px;
|
| 437 |
-
}}
|
| 438 |
-
|
| 439 |
-
.form-group label {{
|
| 440 |
-
display: block;
|
| 441 |
-
margin-bottom: 5px;
|
| 442 |
-
font-weight: 500;
|
| 443 |
-
color: #333;
|
| 444 |
-
}}
|
| 445 |
-
|
| 446 |
-
.form-group input, .form-group textarea {{
|
| 447 |
-
width: 100%;
|
| 448 |
-
padding: 8px 12px;
|
| 449 |
-
border: 1px solid #ddd;
|
| 450 |
-
border-radius: 4px;
|
| 451 |
-
font-size: 14px;
|
| 452 |
-
}}
|
| 453 |
-
|
| 454 |
-
.form-group input:focus, .form-group textarea:focus {{
|
| 455 |
-
outline: none;
|
| 456 |
-
border-color: #007bff;
|
| 457 |
-
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
|
| 458 |
-
}}
|
| 459 |
-
|
| 460 |
-
.btn {{
|
| 461 |
-
background: #007bff;
|
| 462 |
-
color: white;
|
| 463 |
-
border: none;
|
| 464 |
-
padding: 10px 20px;
|
| 465 |
-
border-radius: 4px;
|
| 466 |
-
cursor: pointer;
|
| 467 |
-
font-size: 14px;
|
| 468 |
-
margin-right: 10px;
|
| 469 |
-
margin-bottom: 10px;
|
| 470 |
-
}}
|
| 471 |
-
|
| 472 |
-
.btn:hover {{
|
| 473 |
-
background: #0056b3;
|
| 474 |
-
}}
|
| 475 |
-
|
| 476 |
-
.btn:disabled {{
|
| 477 |
-
background: #6c757d;
|
| 478 |
-
cursor: not-allowed;
|
| 479 |
-
}}
|
| 480 |
-
|
| 481 |
-
.btn-secondary {{
|
| 482 |
-
background: #6c757d;
|
| 483 |
-
}}
|
| 484 |
-
|
| 485 |
-
.btn-secondary:hover {{
|
| 486 |
-
background: #545b62;
|
| 487 |
-
}}
|
| 488 |
-
|
| 489 |
-
.state-display {{
|
| 490 |
-
background: white;
|
| 491 |
-
border: 1px solid #e0e0e0;
|
| 492 |
-
border-radius: 8px;
|
| 493 |
-
padding: 15px;
|
| 494 |
-
margin-bottom: 20px;
|
| 495 |
-
}}
|
| 496 |
-
|
| 497 |
-
.state-item {{
|
| 498 |
-
margin-bottom: 8px;
|
| 499 |
-
}}
|
| 500 |
-
|
| 501 |
-
.state-label {{
|
| 502 |
-
font-weight: 500;
|
| 503 |
-
color: #666;
|
| 504 |
-
}}
|
| 505 |
-
|
| 506 |
-
.state-value {{
|
| 507 |
-
color: #333;
|
| 508 |
-
font-family: monospace;
|
| 509 |
-
}}
|
| 510 |
-
|
| 511 |
-
.logs-container {{
|
| 512 |
-
background: white;
|
| 513 |
-
border: 1px solid #e0e0e0;
|
| 514 |
-
border-radius: 8px;
|
| 515 |
-
padding: 15px;
|
| 516 |
-
max-height: 400px;
|
| 517 |
-
overflow-y: auto;
|
| 518 |
-
}}
|
| 519 |
-
|
| 520 |
-
.log-entry {{
|
| 521 |
-
border-bottom: 1px solid #f0f0f0;
|
| 522 |
-
padding: 10px 0;
|
| 523 |
-
}}
|
| 524 |
-
|
| 525 |
-
.log-entry:last-child {{
|
| 526 |
-
border-bottom: none;
|
| 527 |
-
}}
|
| 528 |
-
|
| 529 |
-
.log-timestamp {{
|
| 530 |
-
font-size: 12px;
|
| 531 |
-
color: #666;
|
| 532 |
-
margin-bottom: 5px;
|
| 533 |
-
}}
|
| 534 |
-
|
| 535 |
-
.log-action {{
|
| 536 |
-
background: #e3f2fd;
|
| 537 |
-
padding: 8px;
|
| 538 |
-
border-radius: 4px;
|
| 539 |
-
margin-bottom: 5px;
|
| 540 |
-
font-family: monospace;
|
| 541 |
-
font-size: 12px;
|
| 542 |
-
}}
|
| 543 |
-
|
| 544 |
-
.log-observation {{
|
| 545 |
-
background: #f3e5f5;
|
| 546 |
-
padding: 8px;
|
| 547 |
-
border-radius: 4px;
|
| 548 |
-
font-family: monospace;
|
| 549 |
-
font-size: 12px;
|
| 550 |
-
}}
|
| 551 |
-
|
| 552 |
-
.log-reward {{
|
| 553 |
-
font-weight: 600;
|
| 554 |
-
color: #28a745;
|
| 555 |
-
}}
|
| 556 |
-
|
| 557 |
-
.log-done {{
|
| 558 |
-
font-weight: 600;
|
| 559 |
-
color: #dc3545;
|
| 560 |
-
}}
|
| 561 |
-
|
| 562 |
-
.status-indicator {{
|
| 563 |
-
display: inline-block;
|
| 564 |
-
width: 8px;
|
| 565 |
-
height: 8px;
|
| 566 |
-
border-radius: 50%;
|
| 567 |
-
margin-right: 8px;
|
| 568 |
-
}}
|
| 569 |
-
|
| 570 |
-
.status-connected {{
|
| 571 |
-
background: #28a745;
|
| 572 |
-
}}
|
| 573 |
-
|
| 574 |
-
.status-disconnected {{
|
| 575 |
-
background: #dc3545;
|
| 576 |
-
}}
|
| 577 |
-
|
| 578 |
-
.json-display {{
|
| 579 |
-
background: #f8f9fa;
|
| 580 |
-
border: 1px solid #e9ecef;
|
| 581 |
-
border-radius: 4px;
|
| 582 |
-
padding: 10px;
|
| 583 |
-
font-family: monospace;
|
| 584 |
-
font-size: 12px;
|
| 585 |
-
white-space: pre-wrap;
|
| 586 |
-
max-height: 200px;
|
| 587 |
-
overflow-y: auto;
|
| 588 |
-
}}
|
| 589 |
-
|
| 590 |
-
/* Chat Interface Styles */
|
| 591 |
-
.chat-interface {{
|
| 592 |
-
background: white;
|
| 593 |
-
border: 1px solid #e0e0e0;
|
| 594 |
-
border-radius: 8px;
|
| 595 |
-
padding: 20px;
|
| 596 |
-
margin-bottom: 20px;
|
| 597 |
-
}}
|
| 598 |
-
|
| 599 |
-
.chat-messages {{
|
| 600 |
-
background: #f8f9fa;
|
| 601 |
-
border: 1px solid #e0e0e0;
|
| 602 |
-
border-radius: 8px;
|
| 603 |
-
padding: 15px;
|
| 604 |
-
margin-bottom: 15px;
|
| 605 |
-
max-height: 400px;
|
| 606 |
-
overflow-y: auto;
|
| 607 |
-
}}
|
| 608 |
-
|
| 609 |
-
.chat-message {{
|
| 610 |
-
margin-bottom: 15px;
|
| 611 |
-
padding: 10px;
|
| 612 |
-
border-radius: 8px;
|
| 613 |
-
}}
|
| 614 |
-
|
| 615 |
-
.chat-message:last-child {{
|
| 616 |
-
margin-bottom: 0;
|
| 617 |
-
}}
|
| 618 |
-
|
| 619 |
-
.chat-message.user {{
|
| 620 |
-
background: #e3f2fd;
|
| 621 |
-
margin-left: 20px;
|
| 622 |
-
}}
|
| 623 |
-
|
| 624 |
-
.chat-message.assistant {{
|
| 625 |
-
background: #f3e5f5;
|
| 626 |
-
margin-right: 20px;
|
| 627 |
-
}}
|
| 628 |
-
|
| 629 |
-
.chat-message.system {{
|
| 630 |
-
background: #e8f5e8;
|
| 631 |
-
font-style: italic;
|
| 632 |
-
}}
|
| 633 |
-
|
| 634 |
-
.message-role {{
|
| 635 |
-
font-weight: 600;
|
| 636 |
-
font-size: 12px;
|
| 637 |
-
color: #666;
|
| 638 |
-
margin-bottom: 5px;
|
| 639 |
-
}}
|
| 640 |
-
|
| 641 |
-
.message-content {{
|
| 642 |
-
font-size: 14px;
|
| 643 |
-
line-height: 1.4;
|
| 644 |
-
}}
|
| 645 |
-
|
| 646 |
-
.chat-input-container {{
|
| 647 |
-
border-top: 1px solid #e0e0e0;
|
| 648 |
-
padding-top: 15px;
|
| 649 |
-
}}
|
| 650 |
-
|
| 651 |
-
.role-selector {{
|
| 652 |
-
margin-bottom: 10px;
|
| 653 |
-
}}
|
| 654 |
-
|
| 655 |
-
.role-selector label {{
|
| 656 |
-
font-weight: 500;
|
| 657 |
-
margin-right: 10px;
|
| 658 |
-
}}
|
| 659 |
-
|
| 660 |
-
.role-selector select {{
|
| 661 |
-
padding: 5px 10px;
|
| 662 |
-
border: 1px solid #ddd;
|
| 663 |
-
border-radius: 4px;
|
| 664 |
-
}}
|
| 665 |
-
|
| 666 |
-
.message-input {{
|
| 667 |
-
display: flex;
|
| 668 |
-
gap: 10px;
|
| 669 |
-
align-items: flex-end;
|
| 670 |
-
}}
|
| 671 |
-
|
| 672 |
-
.message-input textarea {{
|
| 673 |
-
flex: 1;
|
| 674 |
-
padding: 10px;
|
| 675 |
-
border: 1px solid #ddd;
|
| 676 |
-
border-radius: 4px;
|
| 677 |
-
resize: vertical;
|
| 678 |
-
font-family: inherit;
|
| 679 |
-
}}
|
| 680 |
-
|
| 681 |
-
.message-input textarea:focus {{
|
| 682 |
-
outline: none;
|
| 683 |
-
border-color: #007bff;
|
| 684 |
-
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
|
| 685 |
-
}}
|
| 686 |
-
|
| 687 |
-
/* Instructions Section Styles */
|
| 688 |
-
.instructions-section {{
|
| 689 |
-
background: white;
|
| 690 |
-
border: 1px solid #e0e0e0;
|
| 691 |
-
border-radius: 8px;
|
| 692 |
-
padding: 20px;
|
| 693 |
-
margin-bottom: 20px;
|
| 694 |
-
}}
|
| 695 |
-
|
| 696 |
-
.instructions-header {{
|
| 697 |
-
display: flex;
|
| 698 |
-
justify-content: space-between;
|
| 699 |
-
align-items: center;
|
| 700 |
-
margin-bottom: 15px;
|
| 701 |
-
}}
|
| 702 |
-
|
| 703 |
-
.instructions-title {{
|
| 704 |
-
font-size: 18px;
|
| 705 |
-
font-weight: 600;
|
| 706 |
-
color: #333;
|
| 707 |
-
margin: 0;
|
| 708 |
-
}}
|
| 709 |
-
|
| 710 |
-
.instructions-toggle {{
|
| 711 |
-
background: #f8f9fa;
|
| 712 |
-
border: 1px solid #dee2e6;
|
| 713 |
-
border-radius: 4px;
|
| 714 |
-
padding: 5px 10px;
|
| 715 |
-
cursor: pointer;
|
| 716 |
-
font-size: 12px;
|
| 717 |
-
color: #6c757d;
|
| 718 |
-
}}
|
| 719 |
-
|
| 720 |
-
.instructions-toggle:hover {{
|
| 721 |
-
background: #e9ecef;
|
| 722 |
-
}}
|
| 723 |
-
|
| 724 |
-
.instructions-content {{
|
| 725 |
-
display: none;
|
| 726 |
-
max-height: 400px;
|
| 727 |
-
overflow-y: auto;
|
| 728 |
-
border-top: 1px solid #e0e0e0;
|
| 729 |
-
padding-top: 15px;
|
| 730 |
-
}}
|
| 731 |
-
|
| 732 |
-
.instructions-content.expanded {{
|
| 733 |
-
display: block;
|
| 734 |
-
}}
|
| 735 |
-
|
| 736 |
-
.instructions-content h1,
|
| 737 |
-
.instructions-content h2,
|
| 738 |
-
.instructions-content h3 {{
|
| 739 |
-
color: #333;
|
| 740 |
-
margin-top: 20px;
|
| 741 |
-
margin-bottom: 10px;
|
| 742 |
-
}}
|
| 743 |
-
|
| 744 |
-
.instructions-content h1 {{
|
| 745 |
-
font-size: 24px;
|
| 746 |
-
border-bottom: 2px solid #007bff;
|
| 747 |
-
padding-bottom: 10px;
|
| 748 |
-
}}
|
| 749 |
-
|
| 750 |
-
.instructions-content h2 {{
|
| 751 |
-
font-size: 20px;
|
| 752 |
-
}}
|
| 753 |
-
|
| 754 |
-
.instructions-content h3 {{
|
| 755 |
-
font-size: 16px;
|
| 756 |
-
}}
|
| 757 |
-
|
| 758 |
-
.instructions-content p {{
|
| 759 |
-
margin-bottom: 10px;
|
| 760 |
-
line-height: 1.6;
|
| 761 |
-
}}
|
| 762 |
-
|
| 763 |
-
.instructions-content code {{
|
| 764 |
-
background: #f8f9fa;
|
| 765 |
-
padding: 2px 4px;
|
| 766 |
-
border-radius: 3px;
|
| 767 |
-
font-family: monospace;
|
| 768 |
-
font-size: 14px;
|
| 769 |
-
}}
|
| 770 |
-
|
| 771 |
-
.instructions-content pre {{
|
| 772 |
-
background: #f8f9fa;
|
| 773 |
-
border: 1px solid #e9ecef;
|
| 774 |
-
border-radius: 4px;
|
| 775 |
-
padding: 15px;
|
| 776 |
-
overflow-x: auto;
|
| 777 |
-
margin: 10px 0;
|
| 778 |
-
}}
|
| 779 |
-
|
| 780 |
-
.instructions-content pre code {{
|
| 781 |
-
background: none;
|
| 782 |
-
padding: 0;
|
| 783 |
-
}}
|
| 784 |
-
|
| 785 |
-
.instructions-content ul,
|
| 786 |
-
.instructions-content ol {{
|
| 787 |
-
margin: 10px 0;
|
| 788 |
-
padding-left: 20px;
|
| 789 |
-
}}
|
| 790 |
-
|
| 791 |
-
.instructions-content li {{
|
| 792 |
-
margin-bottom: 5px;
|
| 793 |
-
}}
|
| 794 |
-
|
| 795 |
-
.instructions-content table {{
|
| 796 |
-
border-collapse: collapse;
|
| 797 |
-
width: 100%;
|
| 798 |
-
margin: 15px 0;
|
| 799 |
-
}}
|
| 800 |
-
|
| 801 |
-
.instructions-content th,
|
| 802 |
-
.instructions-content td {{
|
| 803 |
-
border: 1px solid #dee2e6;
|
| 804 |
-
padding: 8px 12px;
|
| 805 |
-
text-align: left;
|
| 806 |
-
}}
|
| 807 |
-
|
| 808 |
-
.instructions-content th {{
|
| 809 |
-
background: #f8f9fa;
|
| 810 |
-
font-weight: 600;
|
| 811 |
-
}}
|
| 812 |
-
|
| 813 |
-
/* Enhanced Form Styles */
|
| 814 |
-
.help-text {{
|
| 815 |
-
display: block;
|
| 816 |
-
margin-top: 5px;
|
| 817 |
-
font-size: 12px;
|
| 818 |
-
color: #6c757d;
|
| 819 |
-
font-style: italic;
|
| 820 |
-
}}
|
| 821 |
-
|
| 822 |
-
.form-group label {{
|
| 823 |
-
font-weight: 500;
|
| 824 |
-
color: #333;
|
| 825 |
-
margin-bottom: 5px;
|
| 826 |
-
}}
|
| 827 |
-
|
| 828 |
-
.form-group select {{
|
| 829 |
-
width: 100%;
|
| 830 |
-
padding: 8px 12px;
|
| 831 |
-
border: 1px solid #ddd;
|
| 832 |
-
border-radius: 4px;
|
| 833 |
-
font-size: 14px;
|
| 834 |
-
background-color: white;
|
| 835 |
-
}}
|
| 836 |
-
|
| 837 |
-
.form-group select:focus {{
|
| 838 |
-
outline: none;
|
| 839 |
-
border-color: #007bff;
|
| 840 |
-
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
|
| 841 |
-
}}
|
| 842 |
-
|
| 843 |
-
.form-group textarea {{
|
| 844 |
-
width: 100%;
|
| 845 |
-
padding: 8px 12px;
|
| 846 |
-
border: 1px solid #ddd;
|
| 847 |
-
border-radius: 4px;
|
| 848 |
-
font-size: 14px;
|
| 849 |
-
font-family: inherit;
|
| 850 |
-
resize: vertical;
|
| 851 |
-
}}
|
| 852 |
-
|
| 853 |
-
.form-group textarea:focus {{
|
| 854 |
-
outline: none;
|
| 855 |
-
border-color: #007bff;
|
| 856 |
-
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
|
| 857 |
-
}}
|
| 858 |
-
|
| 859 |
-
.form-group input[type="number"] {{
|
| 860 |
-
width: 100%;
|
| 861 |
-
padding: 8px 12px;
|
| 862 |
-
border: 1px solid #ddd;
|
| 863 |
-
border-radius: 4px;
|
| 864 |
-
font-size: 14px;
|
| 865 |
-
}}
|
| 866 |
-
|
| 867 |
-
.form-group input[type="number"]:focus {{
|
| 868 |
-
outline: none;
|
| 869 |
-
border-color: #007bff;
|
| 870 |
-
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
|
| 871 |
-
}}
|
| 872 |
-
|
| 873 |
-
.form-group input[type="text"]:focus {{
|
| 874 |
-
outline: none;
|
| 875 |
-
border-color: #007bff;
|
| 876 |
-
box-shadow: 0 0 0 2px rgba(0, 123, 255, 0.25);
|
| 877 |
-
}}
|
| 878 |
-
|
| 879 |
-
.required-indicator {{
|
| 880 |
-
color: #dc3545;
|
| 881 |
-
font-weight: bold;
|
| 882 |
-
}}
|
| 883 |
-
|
| 884 |
-
.form-group .field-description {{
|
| 885 |
-
font-size: 11px;
|
| 886 |
-
color: #666;
|
| 887 |
-
margin-top: 2px;
|
| 888 |
-
font-style: italic;
|
| 889 |
-
}}
|
| 890 |
-
</style>
|
| 891 |
-
</head>
|
| 892 |
-
<body>
|
| 893 |
-
<div class="container">
|
| 894 |
-
<!-- Left Pane: HumanAgent Interface -->
|
| 895 |
-
<div class="left-pane">
|
| 896 |
-
<div class="pane-header">
|
| 897 |
-
<span class="status-indicator status-disconnected" id="connection-status"></span>
|
| 898 |
-
HumanAgent Interface
|
| 899 |
-
</div>
|
| 900 |
-
<div class="pane-content">
|
| 901 |
-
<!-- Instructions Section -->
|
| 902 |
-
{_generate_instructions_section(metadata)}
|
| 903 |
-
|
| 904 |
-
<!-- Action Form or Chat Interface -->
|
| 905 |
-
{_generate_action_interface(action_fields, is_chat_env)}
|
| 906 |
-
|
| 907 |
-
<!-- Control Buttons -->
|
| 908 |
-
<div style="margin-bottom: 20px;">
|
| 909 |
-
<button class="btn btn-secondary" id="reset-btn">Reset Environment</button>
|
| 910 |
-
<button class="btn btn-secondary" id="state-btn">Get State</button>
|
| 911 |
-
</div>
|
| 912 |
-
|
| 913 |
-
<!-- Current State Display -->
|
| 914 |
-
<div class="state-display">
|
| 915 |
-
<h3>Current State</h3>
|
| 916 |
-
<div id="current-state">
|
| 917 |
-
<div class="state-item">
|
| 918 |
-
<span class="state-label">Status:</span>
|
| 919 |
-
<span class="state-value" id="env-status">Not initialized</span>
|
| 920 |
-
</div>
|
| 921 |
-
<div class="state-item">
|
| 922 |
-
<span class="state-label">Episode ID:</span>
|
| 923 |
-
<span class="state-value" id="episode-id">-</span>
|
| 924 |
-
</div>
|
| 925 |
-
<div class="state-item">
|
| 926 |
-
<span class="state-label">Step Count:</span>
|
| 927 |
-
<span class="state-value" id="step-count">0</span>
|
| 928 |
-
</div>
|
| 929 |
-
</div>
|
| 930 |
-
</div>
|
| 931 |
-
</div>
|
| 932 |
-
</div>
|
| 933 |
-
|
| 934 |
-
<!-- Right Pane: State Observer -->
|
| 935 |
-
<div class="right-pane">
|
| 936 |
-
<div class="pane-header">
|
| 937 |
-
State Observer
|
| 938 |
-
</div>
|
| 939 |
-
<div class="pane-content">
|
| 940 |
-
<!-- Current Observation -->
|
| 941 |
-
<div class="state-display">
|
| 942 |
-
<h3>Current Observation</h3>
|
| 943 |
-
<div id="current-observation" class="json-display">
|
| 944 |
-
No observation yet
|
| 945 |
-
</div>
|
| 946 |
-
</div>
|
| 947 |
-
|
| 948 |
-
<!-- Action Logs -->
|
| 949 |
-
<div class="logs-container">
|
| 950 |
-
<h3>Action History</h3>
|
| 951 |
-
<div id="action-logs">
|
| 952 |
-
No actions taken yet
|
| 953 |
-
</div>
|
| 954 |
-
</div>
|
| 955 |
-
</div>
|
| 956 |
-
</div>
|
| 957 |
-
</div>
|
| 958 |
-
|
| 959 |
-
<script>
|
| 960 |
-
class OpenEnvWebInterface {{
|
| 961 |
-
constructor() {{
|
| 962 |
-
this.ws = null;
|
| 963 |
-
this.isConnected = false;
|
| 964 |
-
this.init();
|
| 965 |
-
}}
|
| 966 |
-
|
| 967 |
-
init() {{
|
| 968 |
-
this.connectWebSocket();
|
| 969 |
-
this.setupEventListeners();
|
| 970 |
-
}}
|
| 971 |
-
|
| 972 |
-
connectWebSocket() {{
|
| 973 |
-
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
| 974 |
-
const wsUrl = `${{protocol}}//${{window.location.host}}/ws`;
|
| 975 |
-
|
| 976 |
-
this.ws = new WebSocket(wsUrl);
|
| 977 |
-
|
| 978 |
-
this.ws.onopen = () => {{
|
| 979 |
-
this.isConnected = true;
|
| 980 |
-
this.updateConnectionStatus(true);
|
| 981 |
-
console.log('WebSocket connected');
|
| 982 |
-
}};
|
| 983 |
-
|
| 984 |
-
this.ws.onmessage = (event) => {{
|
| 985 |
-
const data = JSON.parse(event.data);
|
| 986 |
-
if (data.type === 'state_update') {{
|
| 987 |
-
this.updateUI(data.episode_state);
|
| 988 |
-
}}
|
| 989 |
-
}};
|
| 990 |
-
|
| 991 |
-
this.ws.onclose = () => {{
|
| 992 |
-
this.isConnected = false;
|
| 993 |
-
this.updateConnectionStatus(false);
|
| 994 |
-
console.log('WebSocket disconnected');
|
| 995 |
-
// Attempt to reconnect after 3 seconds
|
| 996 |
-
setTimeout(() => this.connectWebSocket(), 3000);
|
| 997 |
-
}};
|
| 998 |
-
|
| 999 |
-
this.ws.onerror = (error) => {{
|
| 1000 |
-
console.error('WebSocket error:', error);
|
| 1001 |
-
}};
|
| 1002 |
-
}}
|
| 1003 |
-
|
| 1004 |
-
setupEventListeners() {{
|
| 1005 |
-
// Instructions toggle
|
| 1006 |
-
const instructionsToggle = document.getElementById('instructions-toggle');
|
| 1007 |
-
const instructionsContent = document.getElementById('instructions-content');
|
| 1008 |
-
if (instructionsToggle && instructionsContent) {{
|
| 1009 |
-
instructionsToggle.addEventListener('click', () => {{
|
| 1010 |
-
instructionsContent.classList.toggle('expanded');
|
| 1011 |
-
instructionsToggle.textContent = instructionsContent.classList.contains('expanded')
|
| 1012 |
-
? 'Hide Instructions' : 'Show Instructions';
|
| 1013 |
-
}});
|
| 1014 |
-
}}
|
| 1015 |
-
|
| 1016 |
-
// Check if this is a chat environment
|
| 1017 |
-
const isChatEnv = document.getElementById('chat-messages') !== null;
|
| 1018 |
-
|
| 1019 |
-
if (isChatEnv) {{
|
| 1020 |
-
// Chat environment event listeners
|
| 1021 |
-
document.getElementById('send-message-btn').addEventListener('click', () => {{
|
| 1022 |
-
this.sendMessage();
|
| 1023 |
-
}});
|
| 1024 |
-
|
| 1025 |
-
// Send message on Enter (but allow Shift+Enter for new lines)
|
| 1026 |
-
document.getElementById('message-input').addEventListener('keydown', (e) => {{
|
| 1027 |
-
if (e.key === 'Enter' && !e.shiftKey) {{
|
| 1028 |
-
e.preventDefault();
|
| 1029 |
-
this.sendMessage();
|
| 1030 |
-
}}
|
| 1031 |
-
}});
|
| 1032 |
-
}} else {{
|
| 1033 |
-
// Traditional action form submission
|
| 1034 |
-
const actionForm = document.getElementById('action-form');
|
| 1035 |
-
if (actionForm) {{
|
| 1036 |
-
actionForm.addEventListener('submit', (e) => {{
|
| 1037 |
-
e.preventDefault();
|
| 1038 |
-
this.submitAction();
|
| 1039 |
-
}});
|
| 1040 |
-
}}
|
| 1041 |
-
}}
|
| 1042 |
-
|
| 1043 |
-
// Reset button
|
| 1044 |
-
document.getElementById('reset-btn').addEventListener('click', () => {{
|
| 1045 |
-
this.resetEnvironment();
|
| 1046 |
-
}});
|
| 1047 |
-
|
| 1048 |
-
// State button
|
| 1049 |
-
document.getElementById('state-btn').addEventListener('click', () => {{
|
| 1050 |
-
this.getState();
|
| 1051 |
-
}});
|
| 1052 |
-
}}
|
| 1053 |
-
|
| 1054 |
-
async sendMessage() {{
|
| 1055 |
-
const messageInput = document.getElementById('message-input');
|
| 1056 |
-
const roleSelect = document.getElementById('message-role');
|
| 1057 |
-
const message = messageInput.value.trim();
|
| 1058 |
-
const role = roleSelect.value;
|
| 1059 |
-
|
| 1060 |
-
if (!message) {{
|
| 1061 |
-
return;
|
| 1062 |
-
}}
|
| 1063 |
-
|
| 1064 |
-
// Add message to chat display immediately
|
| 1065 |
-
this.addMessageToChat(role, message);
|
| 1066 |
-
|
| 1067 |
-
// Clear input
|
| 1068 |
-
messageInput.value = '';
|
| 1069 |
-
|
| 1070 |
-
try {{
|
| 1071 |
-
// Send message to server to convert to action and step
|
| 1072 |
-
const response = await fetch('/web/step', {{
|
| 1073 |
-
method: 'POST',
|
| 1074 |
-
headers: {{ 'Content-Type': 'application/json' }},
|
| 1075 |
-
body: JSON.stringify({{
|
| 1076 |
-
message: {{
|
| 1077 |
-
role: role,
|
| 1078 |
-
content: message
|
| 1079 |
-
}}
|
| 1080 |
-
}})
|
| 1081 |
-
}});
|
| 1082 |
-
|
| 1083 |
-
if (!response.ok) {{
|
| 1084 |
-
throw new Error(`HTTP error! status: ${{response.status}}`);
|
| 1085 |
-
}}
|
| 1086 |
-
|
| 1087 |
-
const result = await response.json();
|
| 1088 |
-
console.log('Message sent:', result);
|
| 1089 |
-
}} catch (error) {{
|
| 1090 |
-
console.error('Error sending message:', error);
|
| 1091 |
-
alert('Error sending message: ' + error.message);
|
| 1092 |
-
}}
|
| 1093 |
-
}}
|
| 1094 |
-
|
| 1095 |
-
addMessageToChat(role, content) {{
|
| 1096 |
-
const chatMessages = document.getElementById('chat-messages');
|
| 1097 |
-
const messageDiv = document.createElement('div');
|
| 1098 |
-
messageDiv.className = `chat-message ${{role}}`;
|
| 1099 |
-
|
| 1100 |
-
messageDiv.innerHTML = `
|
| 1101 |
-
<div class="message-role">${{role.charAt(0).toUpperCase() + role.slice(1)}}</div>
|
| 1102 |
-
<div class="message-content">${{content}}</div>
|
| 1103 |
-
`;
|
| 1104 |
-
|
| 1105 |
-
chatMessages.appendChild(messageDiv);
|
| 1106 |
-
chatMessages.scrollTop = chatMessages.scrollHeight;
|
| 1107 |
-
}}
|
| 1108 |
-
|
| 1109 |
-
async submitAction() {{
|
| 1110 |
-
const formData = new FormData(document.getElementById('action-form'));
|
| 1111 |
-
const action = {{}};
|
| 1112 |
-
|
| 1113 |
-
// Collect form data
|
| 1114 |
-
for (const [key, value] of formData.entries()) {{
|
| 1115 |
-
if (value !== '') {{
|
| 1116 |
-
// Handle tensor fields (tokens) - convert comma-separated string to array
|
| 1117 |
-
if (key === 'tokens') {{
|
| 1118 |
-
try {{
|
| 1119 |
-
action[key] = value.split(',').map(x => parseInt(x.trim())).filter(x => !isNaN(x));
|
| 1120 |
-
}} catch (e) {{
|
| 1121 |
-
console.error('Error parsing tokens:', e);
|
| 1122 |
-
action[key] = [];
|
| 1123 |
-
}}
|
| 1124 |
-
}} else {{
|
| 1125 |
-
action[key] = value;
|
| 1126 |
-
}}
|
| 1127 |
-
}}
|
| 1128 |
-
}}
|
| 1129 |
-
|
| 1130 |
-
try {{
|
| 1131 |
-
const response = await fetch('/web/step', {{
|
| 1132 |
-
method: 'POST',
|
| 1133 |
-
headers: {{ 'Content-Type': 'application/json' }},
|
| 1134 |
-
body: JSON.stringify({{ action }})
|
| 1135 |
-
}});
|
| 1136 |
-
|
| 1137 |
-
if (!response.ok) {{
|
| 1138 |
-
throw new Error(`HTTP error! status: ${{response.status}}`);
|
| 1139 |
-
}}
|
| 1140 |
-
|
| 1141 |
-
const result = await response.json();
|
| 1142 |
-
console.log('Step result:', result);
|
| 1143 |
-
}} catch (error) {{
|
| 1144 |
-
console.error('Error submitting action:', error);
|
| 1145 |
-
alert('Error submitting action: ' + error.message);
|
| 1146 |
-
}}
|
| 1147 |
-
}}
|
| 1148 |
-
|
| 1149 |
-
async resetEnvironment() {{
|
| 1150 |
-
try {{
|
| 1151 |
-
const response = await fetch('/web/reset', {{
|
| 1152 |
-
method: 'POST',
|
| 1153 |
-
headers: {{ 'Content-Type': 'application/json' }}
|
| 1154 |
-
}});
|
| 1155 |
-
|
| 1156 |
-
if (!response.ok) {{
|
| 1157 |
-
throw new Error(`HTTP error! status: ${{response.status}}`);
|
| 1158 |
-
}}
|
| 1159 |
-
|
| 1160 |
-
const result = await response.json();
|
| 1161 |
-
console.log('Reset result:', result);
|
| 1162 |
-
}} catch (error) {{
|
| 1163 |
-
console.error('Error resetting environment:', error);
|
| 1164 |
-
alert('Error resetting environment: ' + error.message);
|
| 1165 |
-
}}
|
| 1166 |
-
}}
|
| 1167 |
-
|
| 1168 |
-
async getState() {{
|
| 1169 |
-
try {{
|
| 1170 |
-
const response = await fetch('/web/state');
|
| 1171 |
-
const state = await response.json();
|
| 1172 |
-
console.log('Current state:', state);
|
| 1173 |
-
alert('Current state: ' + JSON.stringify(state, null, 2));
|
| 1174 |
-
}} catch (error) {{
|
| 1175 |
-
console.error('Error getting state:', error);
|
| 1176 |
-
alert('Error getting state: ' + error.message);
|
| 1177 |
-
}}
|
| 1178 |
-
}}
|
| 1179 |
-
|
| 1180 |
-
updateConnectionStatus(connected) {{
|
| 1181 |
-
const indicator = document.getElementById('connection-status');
|
| 1182 |
-
if (connected) {{
|
| 1183 |
-
indicator.className = 'status-indicator status-connected';
|
| 1184 |
-
}} else {{
|
| 1185 |
-
indicator.className = 'status-indicator status-disconnected';
|
| 1186 |
-
}}
|
| 1187 |
-
}}
|
| 1188 |
-
|
| 1189 |
-
updateUI(episodeState) {{
|
| 1190 |
-
// Check if this is a chat environment
|
| 1191 |
-
const isChatEnv = document.getElementById('chat-messages') !== null;
|
| 1192 |
-
|
| 1193 |
-
// Update current state
|
| 1194 |
-
document.getElementById('env-status').textContent =
|
| 1195 |
-
episodeState.is_reset ? 'Reset' : 'Running';
|
| 1196 |
-
document.getElementById('episode-id').textContent =
|
| 1197 |
-
episodeState.episode_id || '-';
|
| 1198 |
-
document.getElementById('step-count').textContent =
|
| 1199 |
-
episodeState.step_count.toString();
|
| 1200 |
-
|
| 1201 |
-
if (isChatEnv) {{
|
| 1202 |
-
// Update chat interface
|
| 1203 |
-
this.updateChatInterface(episodeState);
|
| 1204 |
-
}} else {{
|
| 1205 |
-
// Update traditional observation display
|
| 1206 |
-
const observationDiv = document.getElementById('current-observation');
|
| 1207 |
-
if (episodeState.current_observation) {{
|
| 1208 |
-
observationDiv.textContent = JSON.stringify(
|
| 1209 |
-
episodeState.current_observation, null, 2
|
| 1210 |
-
);
|
| 1211 |
-
}} else {{
|
| 1212 |
-
observationDiv.textContent = 'No observation yet';
|
| 1213 |
-
}}
|
| 1214 |
-
}}
|
| 1215 |
-
|
| 1216 |
-
// Update action logs
|
| 1217 |
-
const logsDiv = document.getElementById('action-logs');
|
| 1218 |
-
if (episodeState.action_logs.length === 0) {{
|
| 1219 |
-
logsDiv.innerHTML = 'No actions taken yet';
|
| 1220 |
-
}} else {{
|
| 1221 |
-
logsDiv.innerHTML = episodeState.action_logs.map(log => `
|
| 1222 |
-
<div class="log-entry">
|
| 1223 |
-
<div class="log-timestamp">${{log.timestamp}} (Step ${{log.step_count}})</div>
|
| 1224 |
-
<div class="log-action">Action: ${{JSON.stringify(log.action, null, 2)}}</div>
|
| 1225 |
-
<div class="log-observation">Observation: ${{JSON.stringify(log.observation, null, 2)}}</div>
|
| 1226 |
-
<div>
|
| 1227 |
-
<span class="log-reward">Reward: ${{log.reward !== null ? log.reward : 'None'}}</span>
|
| 1228 |
-
${{log.done ? '<span class="log-done">DONE</span>' : ''}}
|
| 1229 |
-
</div>
|
| 1230 |
-
</div>
|
| 1231 |
-
`).join('');
|
| 1232 |
-
}}
|
| 1233 |
-
}}
|
| 1234 |
-
|
| 1235 |
-
updateChatInterface(episodeState) {{
|
| 1236 |
-
const chatMessages = document.getElementById('chat-messages');
|
| 1237 |
-
if (!chatMessages) return;
|
| 1238 |
-
|
| 1239 |
-
// Clear existing messages (except system message)
|
| 1240 |
-
const systemMessage = chatMessages.querySelector('.chat-message.system');
|
| 1241 |
-
chatMessages.innerHTML = '';
|
| 1242 |
-
if (systemMessage) {{
|
| 1243 |
-
chatMessages.appendChild(systemMessage);
|
| 1244 |
-
}}
|
| 1245 |
-
|
| 1246 |
-
// Add messages from current observation
|
| 1247 |
-
if (episodeState.current_observation && episodeState.current_observation.messages) {{
|
| 1248 |
-
episodeState.current_observation.messages.forEach(msg => {{
|
| 1249 |
-
this.addMessageToChat(msg.role, msg.content);
|
| 1250 |
-
}});
|
| 1251 |
-
}}
|
| 1252 |
-
}}
|
| 1253 |
-
}}
|
| 1254 |
-
|
| 1255 |
-
// Initialize the web interface when the page loads
|
| 1256 |
-
document.addEventListener('DOMContentLoaded', () => {{
|
| 1257 |
-
new OpenEnvWebInterface();
|
| 1258 |
-
}});
|
| 1259 |
-
</script>
|
| 1260 |
-
</body>
|
| 1261 |
-
</html>
|
| 1262 |
-
""".replace('{_generate_action_form_fields(action_fields)}', _generate_action_form_fields(action_fields))
|
| 1263 |
-
|
| 1264 |
-
|
| 1265 |
-
def _generate_instructions_section(metadata: Optional[EnvironmentMetadata]) -> str:
|
| 1266 |
-
"""Generate the instructions section with environment documentation."""
|
| 1267 |
-
if not metadata or not metadata.readme_content:
|
| 1268 |
-
return ''
|
| 1269 |
-
|
| 1270 |
-
# Convert markdown to HTML (basic conversion)
|
| 1271 |
-
import re
|
| 1272 |
-
html_content = _markdown_to_html(metadata.readme_content)
|
| 1273 |
-
|
| 1274 |
-
return f'''
|
| 1275 |
-
<!-- Instructions Section -->
|
| 1276 |
-
<div class="instructions-section">
|
| 1277 |
-
<div class="instructions-header">
|
| 1278 |
-
<h3 class="instructions-title">{metadata.name}</h3>
|
| 1279 |
-
<button class="instructions-toggle" id="instructions-toggle">Show Instructions</button>
|
| 1280 |
-
</div>
|
| 1281 |
-
<div class="instructions-content" id="instructions-content">
|
| 1282 |
-
<div class="instructions-readme">
|
| 1283 |
-
{html_content}
|
| 1284 |
-
</div>
|
| 1285 |
-
</div>
|
| 1286 |
-
</div>
|
| 1287 |
-
'''
|
| 1288 |
|
| 1289 |
|
| 1290 |
def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]:
|
| 1291 |
"""Extract enhanced field metadata from Action class for form generation."""
|
| 1292 |
-
|
| 1293 |
-
|
| 1294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1295 |
action_fields = []
|
| 1296 |
-
|
| 1297 |
-
|
| 1298 |
-
|
| 1299 |
-
for field_name, field_info in action_cls.__dataclass_fields__.items():
|
| 1300 |
-
if field_name == 'metadata':
|
| 1301 |
continue
|
| 1302 |
-
|
| 1303 |
-
|
| 1304 |
-
|
| 1305 |
-
|
| 1306 |
-
|
| 1307 |
-
|
| 1308 |
-
|
| 1309 |
-
|
| 1310 |
-
|
| 1311 |
-
|
| 1312 |
-
|
| 1313 |
-
|
| 1314 |
-
|
| 1315 |
-
|
| 1316 |
-
|
| 1317 |
-
|
| 1318 |
-
|
| 1319 |
-
|
| 1320 |
-
|
| 1321 |
-
|
| 1322 |
-
|
| 1323 |
-
|
| 1324 |
-
|
|
|
|
|
|
|
| 1325 |
return action_fields
|
| 1326 |
|
| 1327 |
|
| 1328 |
-
def
|
| 1329 |
-
|
| 1330 |
-
|
| 1331 |
-
from
|
| 1332 |
-
|
| 1333 |
-
metadata = {}
|
| 1334 |
-
|
| 1335 |
-
# Extract description from field docstring or annotation
|
| 1336 |
-
if hasattr(field_info, 'metadata') and field_info.metadata:
|
| 1337 |
-
# Check for custom metadata
|
| 1338 |
-
for meta in field_info.metadata:
|
| 1339 |
-
if isinstance(meta, dict):
|
| 1340 |
-
metadata.update(meta)
|
| 1341 |
-
|
| 1342 |
-
# Extract type information
|
| 1343 |
-
field_type = field_info.type
|
| 1344 |
-
origin = get_origin(field_type)
|
| 1345 |
-
|
| 1346 |
-
# Handle Literal types for dropdown choices
|
| 1347 |
-
if origin is Literal:
|
| 1348 |
-
args = get_args(field_type)
|
| 1349 |
-
metadata['choices'] = list(args)
|
| 1350 |
-
|
| 1351 |
-
# Handle Optional types
|
| 1352 |
-
if origin is Union:
|
| 1353 |
-
args = get_args(field_type)
|
| 1354 |
-
if len(args) == 2 and type(None) in args:
|
| 1355 |
-
# This is Optional[SomeType]
|
| 1356 |
-
non_none_type = args[0] if args[1] is type(None) else args[1]
|
| 1357 |
-
metadata['optional'] = True
|
| 1358 |
-
# Recursively check the non-None type for choices
|
| 1359 |
-
if get_origin(non_none_type) is Literal:
|
| 1360 |
-
metadata['choices'] = list(get_args(non_none_type))
|
| 1361 |
-
else:
|
| 1362 |
-
# Regular Union type
|
| 1363 |
-
metadata['choices'] = [str(arg) for arg in args if arg is not type(None)]
|
| 1364 |
-
|
| 1365 |
-
# Handle numeric constraints
|
| 1366 |
-
if field_type in (int, float):
|
| 1367 |
-
# Check for common constraint patterns in field name
|
| 1368 |
-
if 'count' in field_name.lower() or 'num' in field_name.lower():
|
| 1369 |
-
metadata['min_value'] = 0
|
| 1370 |
-
if 'id' in field_name.lower():
|
| 1371 |
-
metadata['min_value'] = 0
|
| 1372 |
-
|
| 1373 |
-
# Generate placeholder text
|
| 1374 |
-
if 'message' in field_name.lower():
|
| 1375 |
-
metadata['placeholder'] = f'Enter {field_name.replace("_", " ")}...'
|
| 1376 |
-
elif 'code' in field_name.lower():
|
| 1377 |
-
metadata['placeholder'] = 'Enter Python code here...'
|
| 1378 |
-
elif 'tokens' in field_name.lower():
|
| 1379 |
-
metadata['placeholder'] = 'Enter comma-separated token IDs (e.g., 1,2,3,4,5)'
|
| 1380 |
-
else:
|
| 1381 |
-
metadata['placeholder'] = f'Enter {field_name.replace("_", " ")}...'
|
| 1382 |
-
|
| 1383 |
-
# Generate help text based on field name and type
|
| 1384 |
-
if 'action_id' in field_name.lower():
|
| 1385 |
-
metadata['help_text'] = 'The action ID to execute in the environment'
|
| 1386 |
-
elif 'game_name' in field_name.lower():
|
| 1387 |
-
metadata['help_text'] = 'Name of the game or environment'
|
| 1388 |
-
elif 'tokens' in field_name.lower():
|
| 1389 |
-
metadata['help_text'] = 'Token IDs as a comma-separated list of integers'
|
| 1390 |
-
elif 'code' in field_name.lower():
|
| 1391 |
-
metadata['help_text'] = 'Python code to execute in the environment'
|
| 1392 |
-
elif 'message' in field_name.lower():
|
| 1393 |
-
metadata['help_text'] = 'Text message to send'
|
| 1394 |
-
|
| 1395 |
-
return metadata
|
| 1396 |
|
|
|
|
|
|
|
|
|
|
| 1397 |
|
| 1398 |
-
|
| 1399 |
-
"""Determine the appropriate HTML input type for a field type."""
|
| 1400 |
-
import typing
|
| 1401 |
-
from typing import get_origin, get_args, Literal, Union
|
| 1402 |
-
|
| 1403 |
-
# Handle direct types
|
| 1404 |
-
if field_type == str:
|
| 1405 |
-
return "text"
|
| 1406 |
-
elif field_type == int:
|
| 1407 |
-
return "number"
|
| 1408 |
-
elif field_type == float:
|
| 1409 |
-
return "number"
|
| 1410 |
-
elif field_type == bool:
|
| 1411 |
-
return "checkbox"
|
| 1412 |
-
|
| 1413 |
-
# Handle complex types
|
| 1414 |
-
origin = get_origin(field_type)
|
| 1415 |
-
|
| 1416 |
-
if origin is Literal:
|
| 1417 |
return "select"
|
| 1418 |
-
|
| 1419 |
-
|
| 1420 |
-
|
| 1421 |
-
|
| 1422 |
-
|
| 1423 |
-
|
| 1424 |
-
|
| 1425 |
-
|
| 1426 |
-
|
| 1427 |
-
|
| 1428 |
-
|
| 1429 |
-
|
| 1430 |
-
|
|
|
|
|
|
|
| 1431 |
return "text"
|
| 1432 |
|
|
|
|
|
|
|
| 1433 |
|
| 1434 |
-
|
| 1435 |
-
|
| 1436 |
-
|
| 1437 |
-
|
| 1438 |
-
|
| 1439 |
-
|
| 1440 |
-
|
| 1441 |
-
|
| 1442 |
-
|
| 1443 |
-
html_content = re.sub(r'^# (.*?)$', r'<h1>\1</h1>', html_content, flags=re.MULTILINE)
|
| 1444 |
-
html_content = re.sub(r'^## (.*?)$', r'<h2>\1</h2>', html_content, flags=re.MULTILINE)
|
| 1445 |
-
html_content = re.sub(r'^### (.*?)$', r'<h3>\1</h3>', html_content, flags=re.MULTILINE)
|
| 1446 |
-
|
| 1447 |
-
# Convert code blocks
|
| 1448 |
-
html_content = re.sub(r'```(.*?)\n(.*?)\n```', r'<pre><code>\2</code></pre>', html_content, flags=re.DOTALL)
|
| 1449 |
-
html_content = re.sub(r'`([^`]+)`', r'<code>\1</code>', html_content)
|
| 1450 |
-
|
| 1451 |
-
# Convert bold and italic
|
| 1452 |
-
html_content = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', html_content)
|
| 1453 |
-
html_content = re.sub(r'\*(.*?)\*', r'<em>\1</em>', html_content)
|
| 1454 |
-
|
| 1455 |
-
# Convert lists
|
| 1456 |
-
html_content = re.sub(r'^- (.*?)$', r'<li>\1</li>', html_content, flags=re.MULTILINE)
|
| 1457 |
-
html_content = re.sub(r'(<li>.*</li>)', r'<ul>\1</ul>', html_content, flags=re.DOTALL)
|
| 1458 |
-
|
| 1459 |
-
# Convert line breaks
|
| 1460 |
-
html_content = html_content.replace('\n', '<br>')
|
| 1461 |
-
|
| 1462 |
-
return html_content
|
| 1463 |
-
|
| 1464 |
-
|
| 1465 |
-
def _generate_action_interface(action_fields: List[Dict[str, Any]], is_chat_env: bool) -> str:
|
| 1466 |
-
"""Generate either a chat interface or action form based on environment type."""
|
| 1467 |
-
if is_chat_env:
|
| 1468 |
-
return _generate_chat_interface()
|
| 1469 |
-
else:
|
| 1470 |
-
return _generate_action_form(action_fields)
|
| 1471 |
-
|
| 1472 |
-
def _generate_chat_interface() -> str:
|
| 1473 |
-
"""Generate a chat-style interface for chat environments."""
|
| 1474 |
-
return '''
|
| 1475 |
-
<!-- Chat Interface -->
|
| 1476 |
-
<div class="chat-interface">
|
| 1477 |
-
<h3>Chat Interface</h3>
|
| 1478 |
-
<div class="chat-messages" id="chat-messages">
|
| 1479 |
-
<div class="chat-message system">
|
| 1480 |
-
<div class="message-role">System</div>
|
| 1481 |
-
<div class="message-content">Chat environment ready. Send a message to start the conversation.</div>
|
| 1482 |
-
</div>
|
| 1483 |
-
</div>
|
| 1484 |
-
<div class="chat-input-container">
|
| 1485 |
-
<div class="role-selector">
|
| 1486 |
-
<label for="message-role">Role:</label>
|
| 1487 |
-
<select id="message-role">
|
| 1488 |
-
<option value="user">User</option>
|
| 1489 |
-
<option value="assistant">Assistant</option>
|
| 1490 |
-
</select>
|
| 1491 |
-
</div>
|
| 1492 |
-
<div class="message-input">
|
| 1493 |
-
<textarea id="message-input" placeholder="Type your message here..." rows="3"></textarea>
|
| 1494 |
-
<button class="btn" id="send-message-btn">Send Message</button>
|
| 1495 |
-
</div>
|
| 1496 |
-
</div>
|
| 1497 |
-
</div>
|
| 1498 |
-
'''
|
| 1499 |
-
|
| 1500 |
-
def _generate_action_form(action_fields: List[Dict[str, Any]]) -> str:
|
| 1501 |
-
"""Generate a traditional action form for non-chat environments."""
|
| 1502 |
-
return f'''
|
| 1503 |
-
<!-- Action Form -->
|
| 1504 |
-
<div class="action-form">
|
| 1505 |
-
<h3>Take Action</h3>
|
| 1506 |
-
<form id="action-form">
|
| 1507 |
-
{_generate_action_form_fields(action_fields)}
|
| 1508 |
-
<button type="submit" class="btn" id="step-btn">Step</button>
|
| 1509 |
-
</form>
|
| 1510 |
-
</div>
|
| 1511 |
-
'''
|
| 1512 |
-
|
| 1513 |
-
def _generate_action_form_fields(action_fields: List[Dict[str, Any]]) -> str:
|
| 1514 |
-
"""Generate HTML form fields for action input with enhanced metadata."""
|
| 1515 |
-
if not action_fields:
|
| 1516 |
-
return '<p>No action fields available</p>'
|
| 1517 |
-
|
| 1518 |
-
fields_html = []
|
| 1519 |
-
for field in action_fields:
|
| 1520 |
-
field_html = _generate_single_field(field)
|
| 1521 |
-
fields_html.append(field_html)
|
| 1522 |
-
|
| 1523 |
-
return '\n'.join(fields_html)
|
| 1524 |
-
|
| 1525 |
-
|
| 1526 |
-
def _generate_single_field(field: Dict[str, Any]) -> str:
|
| 1527 |
-
"""Generate HTML for a single form field with enhanced metadata."""
|
| 1528 |
-
field_name = field['name']
|
| 1529 |
-
field_type = field['type']
|
| 1530 |
-
required = field['required']
|
| 1531 |
-
placeholder = field.get('placeholder', '')
|
| 1532 |
-
help_text = field.get('help_text', '')
|
| 1533 |
-
choices = field.get('choices', [])
|
| 1534 |
-
min_value = field.get('min_value')
|
| 1535 |
-
max_value = field.get('max_value')
|
| 1536 |
-
default_value = field.get('default_value')
|
| 1537 |
-
|
| 1538 |
-
# Build label with required indicator
|
| 1539 |
-
label_text = field_name.replace('_', ' ').title()
|
| 1540 |
-
if required:
|
| 1541 |
-
label_text += ' <span style="color: red;">*</span>'
|
| 1542 |
-
|
| 1543 |
-
# Build input attributes
|
| 1544 |
-
input_attrs = []
|
| 1545 |
-
if required:
|
| 1546 |
-
input_attrs.append('required')
|
| 1547 |
-
if placeholder:
|
| 1548 |
-
input_attrs.append(f'placeholder="{placeholder}"')
|
| 1549 |
-
if min_value is not None:
|
| 1550 |
-
input_attrs.append(f'min="{min_value}"')
|
| 1551 |
-
if max_value is not None:
|
| 1552 |
-
input_attrs.append(f'max="{max_value}"')
|
| 1553 |
-
if default_value is not None:
|
| 1554 |
-
input_attrs.append(f'value="{default_value}"')
|
| 1555 |
-
|
| 1556 |
-
attrs_str = ' '.join(input_attrs)
|
| 1557 |
-
|
| 1558 |
-
if field_type == 'checkbox':
|
| 1559 |
-
return f'''
|
| 1560 |
-
<div class="form-group">
|
| 1561 |
-
<label>
|
| 1562 |
-
<input type="checkbox" name="{field_name}" value="true" {attrs_str}>
|
| 1563 |
-
{label_text}
|
| 1564 |
-
</label>
|
| 1565 |
-
{f'<small class="help-text">{help_text}</small>' if help_text else ''}
|
| 1566 |
-
</div>
|
| 1567 |
-
'''
|
| 1568 |
-
|
| 1569 |
-
elif field_type == 'select':
|
| 1570 |
-
options_html = []
|
| 1571 |
-
if not required:
|
| 1572 |
-
options_html.append(f'<option value="">-- Select {label_text} --</option>')
|
| 1573 |
-
|
| 1574 |
-
for choice in choices:
|
| 1575 |
-
selected = 'selected' if str(choice) == str(default_value) else ''
|
| 1576 |
-
options_html.append(f'<option value="{choice}" {selected}>{choice}</option>')
|
| 1577 |
-
|
| 1578 |
-
return f'''
|
| 1579 |
-
<div class="form-group">
|
| 1580 |
-
<label for="{field_name}">{label_text}:</label>
|
| 1581 |
-
<select name="{field_name}" id="{field_name}" {attrs_str}>
|
| 1582 |
-
{''.join(options_html)}
|
| 1583 |
-
</select>
|
| 1584 |
-
{f'<small class="help-text">{help_text}</small>' if help_text else ''}
|
| 1585 |
-
</div>
|
| 1586 |
-
'''
|
| 1587 |
-
|
| 1588 |
-
elif field_type == 'tensor':
|
| 1589 |
-
return f'''
|
| 1590 |
-
<div class="form-group">
|
| 1591 |
-
<label for="{field_name}">{label_text} (comma-separated integers):</label>
|
| 1592 |
-
<input type="text" name="{field_name}" id="{field_name}" {attrs_str}>
|
| 1593 |
-
<small class="help-text">{help_text or 'Enter token IDs as comma-separated integers (e.g., 1,2,3,4,5)'}</small>
|
| 1594 |
-
</div>
|
| 1595 |
-
'''
|
| 1596 |
-
|
| 1597 |
-
elif field_type == 'text' and ('message' in field_name.lower() or 'code' in field_name.lower()):
|
| 1598 |
-
return f'''
|
| 1599 |
-
<div class="form-group">
|
| 1600 |
-
<label for="{field_name}">{label_text}:</label>
|
| 1601 |
-
<textarea name="{field_name}" id="{field_name}" rows="3" {attrs_str}></textarea>
|
| 1602 |
-
{f'<small class="help-text">{help_text}</small>' if help_text else ''}
|
| 1603 |
-
</div>
|
| 1604 |
-
'''
|
| 1605 |
-
|
| 1606 |
else:
|
| 1607 |
-
return f'''
|
| 1608 |
-
|
| 1609 |
-
|
| 1610 |
-
|
| 1611 |
-
|
| 1612 |
-
|
| 1613 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
"""
|
| 8 |
Web interface for OpenEnv environments.
|
| 9 |
|
| 10 |
+
When ENABLE_WEB_INTERFACE is set, the server exposes a Gradio UI at /web for
|
| 11 |
+
reset, step, and state observation. Controlled by the CLI enable_interface
|
| 12 |
+
option (e.g. openenv push --enable-interface) or ENABLE_WEB_INTERFACE env var.
|
| 13 |
"""
|
| 14 |
|
| 15 |
from __future__ import annotations
|
| 16 |
|
| 17 |
+
import asyncio
|
| 18 |
import json
|
| 19 |
+
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
|
|
| 20 |
from datetime import datetime
|
| 21 |
+
from typing import Any, Callable, Dict, List, Optional, Type
|
| 22 |
|
| 23 |
+
import gradio as gr
|
| 24 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 25 |
+
from pydantic import BaseModel, ConfigDict, Field
|
|
|
|
| 26 |
|
| 27 |
+
from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME
|
| 28 |
+
from .gradio_ui import build_gradio_app, get_gradio_display_title
|
| 29 |
from .interfaces import Environment
|
| 30 |
+
from .serialization import deserialize_action_with_preprocessing, serialize_observation
|
| 31 |
+
from .types import Action, EnvironmentMetadata, Observation, State
|
| 32 |
|
| 33 |
+
# Quick Start markdown template; placeholders match init suffixes (__ENV_NAME__, __ENV_CLASS_NAME__*).
|
| 34 |
+
DEFAULT_QUICK_START_MARKDOWN = """
|
| 35 |
+
### Connect to this environment
|
| 36 |
|
| 37 |
+
Connect from Python using `__ENV_CLASS_NAME__Env`:
|
| 38 |
+
|
| 39 |
+
```python
|
| 40 |
+
from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env
|
| 41 |
+
|
| 42 |
+
with __ENV_CLASS_NAME__Env.from_env("<SPACE_ID>") as env:
|
| 43 |
+
result = await env.step(__ENV_CLASS_NAME__Action(message="..."))
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
Or connect directly to a running server:
|
| 47 |
+
|
| 48 |
+
```python
|
| 49 |
+
env = __ENV_CLASS_NAME__Env(base_url="http://localhost:8000")
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
### Contribute to this environment
|
| 53 |
+
|
| 54 |
+
Submit improvements via pull request on the Hugging Face Hub.
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
openenv fork <SPACE_ID> --repo-id <your-username>/<your-repo-name>
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
Then make your changes and submit a pull request:
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
cd <forked-repo>
|
| 64 |
+
openenv push <SPACE_ID> --create-pr
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
For more information, see the [OpenEnv documentation](https://meta-pytorch.org/OpenEnv/).
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_quick_start_markdown(
|
| 72 |
+
metadata: Optional[EnvironmentMetadata],
|
| 73 |
+
action_cls: Type[Action],
|
| 74 |
+
observation_cls: Type[Observation],
|
| 75 |
+
) -> str:
|
| 76 |
+
"""
|
| 77 |
+
Build Quick Start markdown with class names replaced from current env (init-style suffixes).
|
| 78 |
+
|
| 79 |
+
Uses the same placeholder names as the init template so that __ENV_CLASS_NAME__Env,
|
| 80 |
+
__ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation and __ENV_NAME__ are
|
| 81 |
+
replaced with the actual class/package names.
|
| 82 |
+
"""
|
| 83 |
+
import os
|
| 84 |
+
|
| 85 |
+
# Prefix from action class (e.g. EchoAction -> Echo)
|
| 86 |
+
action_name = getattr(action_cls, "__name__", "Action")
|
| 87 |
+
if action_name.endswith("Action"):
|
| 88 |
+
prefix = action_name[: -len("Action")]
|
| 89 |
+
else:
|
| 90 |
+
prefix = action_name.replace("Action", "").strip() or "Env"
|
| 91 |
+
|
| 92 |
+
env_client_name = f"{prefix}Env"
|
| 93 |
+
obs_name = getattr(observation_cls, "__name__", "Observation")
|
| 94 |
+
pkg_name = (metadata.name if metadata else "env").replace(" ", "_").lower()
|
| 95 |
+
|
| 96 |
+
space_id = os.environ.get("SPACE_ID", "<hf-username>/<hf-repo-name>")
|
| 97 |
+
|
| 98 |
+
content = DEFAULT_QUICK_START_MARKDOWN
|
| 99 |
+
content = content.replace("__ENV_CLASS_NAME__Env", env_client_name)
|
| 100 |
+
content = content.replace("__ENV_CLASS_NAME__Action", action_name)
|
| 101 |
+
content = content.replace("__ENV_CLASS_NAME__Observation", obs_name)
|
| 102 |
+
content = content.replace("__ENV_CLASS_NAME__", prefix)
|
| 103 |
+
content = content.replace("__ENV_NAME__", pkg_name)
|
| 104 |
+
content = content.replace("<SPACE_ID>", space_id)
|
| 105 |
+
return content.strip()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def load_environment_metadata(
|
| 109 |
+
env: Environment, env_name: Optional[str] = None
|
| 110 |
+
) -> EnvironmentMetadata:
|
| 111 |
"""
|
| 112 |
Load environment metadata including README content.
|
| 113 |
+
|
| 114 |
Args:
|
| 115 |
+
env: The environment instance, class, or factory function.
|
| 116 |
+
- If a class: used as a factory, won't call instance methods
|
| 117 |
+
- If a function: used as a factory, won't call instance methods
|
| 118 |
+
- If an instance: may call get_metadata() if available
|
| 119 |
env_name: Optional environment name for README file lookup
|
| 120 |
+
|
| 121 |
Returns:
|
| 122 |
EnvironmentMetadata with loaded information
|
| 123 |
"""
|
| 124 |
+
import inspect
|
| 125 |
+
|
| 126 |
+
# Determine what type of env we received:
|
| 127 |
+
# 1. A class (used as factory) - e.g., PythonCodeActEnv
|
| 128 |
+
# 2. A function (factory function) - e.g., create_chat_environment
|
| 129 |
+
# 3. An actual instance - e.g., SnakeEnvironment()
|
| 130 |
+
is_class = inspect.isclass(env)
|
| 131 |
+
is_function = inspect.isfunction(env) or inspect.ismethod(env)
|
| 132 |
+
is_factory = is_class or is_function
|
| 133 |
+
|
| 134 |
+
# Try to get metadata from environment if it's an instance with get_metadata
|
| 135 |
+
if not is_factory and hasattr(env, "get_metadata"):
|
| 136 |
return env.get_metadata()
|
| 137 |
+
|
| 138 |
+
# Determine the class name for default metadata
|
| 139 |
+
if is_class:
|
| 140 |
+
# env is the class itself
|
| 141 |
+
class_name = env.__name__
|
| 142 |
+
elif is_function:
|
| 143 |
+
# env is a factory function - use its name or derive from env_name
|
| 144 |
+
class_name = env_name or env.__name__
|
| 145 |
+
else:
|
| 146 |
+
# env is an instance
|
| 147 |
+
class_name = env.__class__.__name__
|
| 148 |
+
|
| 149 |
# Default metadata
|
| 150 |
metadata = EnvironmentMetadata(
|
| 151 |
+
name=env_name or class_name,
|
| 152 |
+
description=f"{class_name} environment",
|
| 153 |
+
version="1.0.0",
|
| 154 |
)
|
| 155 |
+
|
| 156 |
# Try to load README from file system
|
| 157 |
readme_content = _load_readme_from_filesystem(env_name)
|
| 158 |
if readme_content:
|
| 159 |
metadata.readme_content = readme_content
|
| 160 |
+
|
| 161 |
return metadata
|
| 162 |
|
| 163 |
|
| 164 |
def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]:
|
| 165 |
"""
|
| 166 |
Load README content from the filesystem.
|
| 167 |
+
|
| 168 |
Tries multiple locations:
|
| 169 |
1. Container filesystem: /app/README.md
|
| 170 |
2. Local development: src/envs/{env_name}/README.md
|
|
|
|
| 172 |
"""
|
| 173 |
import os
|
| 174 |
from pathlib import Path
|
| 175 |
+
|
| 176 |
# Try container filesystem first
|
| 177 |
container_readme = Path("/app/README.md")
|
| 178 |
if container_readme.exists():
|
| 179 |
try:
|
| 180 |
+
return container_readme.read_text(encoding="utf-8")
|
| 181 |
except Exception:
|
| 182 |
pass
|
| 183 |
+
|
| 184 |
# Try environment variable path
|
| 185 |
custom_path = os.environ.get("ENV_README_PATH")
|
| 186 |
if custom_path and Path(custom_path).exists():
|
| 187 |
try:
|
| 188 |
+
return Path(custom_path).read_text(encoding="utf-8")
|
| 189 |
except Exception:
|
| 190 |
pass
|
| 191 |
+
|
| 192 |
# Try local development path
|
| 193 |
if env_name:
|
| 194 |
local_readme = Path(f"src/envs/{env_name}/README.md")
|
| 195 |
if local_readme.exists():
|
| 196 |
try:
|
| 197 |
+
return local_readme.read_text(encoding="utf-8")
|
| 198 |
except Exception:
|
| 199 |
pass
|
| 200 |
+
|
| 201 |
return None
|
| 202 |
|
| 203 |
|
| 204 |
+
class ActionLog(BaseModel):
|
|
|
|
| 205 |
"""Log entry for an action taken."""
|
| 206 |
+
|
| 207 |
+
model_config = ConfigDict(extra="forbid", validate_assignment=True)
|
| 208 |
+
|
| 209 |
+
timestamp: str = Field(description="Timestamp when action was taken")
|
| 210 |
+
action: Dict[str, Any] = Field(description="Action that was taken")
|
| 211 |
+
observation: Dict[str, Any] = Field(description="Observation returned from action")
|
| 212 |
+
reward: Optional[float] = Field(
|
| 213 |
+
default=None, description="Reward received from action"
|
| 214 |
+
)
|
| 215 |
+
done: bool = Field(description="Whether the episode is done after this action")
|
| 216 |
+
step_count: int = Field(description="Step count when this action was taken")
|
| 217 |
|
| 218 |
|
| 219 |
+
class EpisodeState(BaseModel):
|
|
|
|
| 220 |
"""Current episode state for the web interface."""
|
| 221 |
+
|
| 222 |
+
model_config = ConfigDict(extra="forbid", validate_assignment=True)
|
| 223 |
+
|
| 224 |
+
episode_id: Optional[str] = Field(default=None, description="Current episode ID")
|
| 225 |
+
step_count: int = Field(description="Current step count in episode")
|
| 226 |
+
current_observation: Optional[Dict[str, Any]] = Field(
|
| 227 |
+
default=None, description="Current observation"
|
| 228 |
+
)
|
| 229 |
+
action_logs: List[ActionLog] = Field(
|
| 230 |
+
default_factory=list, description="List of action logs"
|
| 231 |
+
)
|
| 232 |
+
is_reset: bool = Field(
|
| 233 |
+
default=True, description="Whether the episode has been reset"
|
| 234 |
+
)
|
| 235 |
|
| 236 |
|
| 237 |
class WebInterfaceManager:
|
| 238 |
"""Manages the web interface for an environment."""
|
| 239 |
+
|
| 240 |
+
MAX_ACTION_LOGS = 1000
|
| 241 |
+
|
| 242 |
def __init__(
|
| 243 |
self,
|
| 244 |
env: Environment,
|
|
|
|
| 246 |
observation_cls: Type[Observation],
|
| 247 |
metadata: Optional[EnvironmentMetadata] = None,
|
| 248 |
):
|
| 249 |
+
import inspect
|
| 250 |
+
|
| 251 |
+
# If env is a class or factory function, instantiate it
|
| 252 |
+
if inspect.isclass(env) or inspect.isfunction(env):
|
| 253 |
+
self.env = env()
|
| 254 |
+
else:
|
| 255 |
+
self.env = env
|
| 256 |
self.action_cls = action_cls
|
| 257 |
self.observation_cls = observation_cls
|
| 258 |
self.metadata = metadata or EnvironmentMetadata(
|
| 259 |
name=env.__class__.__name__,
|
| 260 |
+
description=f"{env.__class__.__name__} environment",
|
| 261 |
)
|
| 262 |
self.episode_state = EpisodeState(
|
| 263 |
episode_id=None,
|
| 264 |
step_count=0,
|
| 265 |
current_observation=None,
|
| 266 |
+
action_logs=[],
|
| 267 |
)
|
| 268 |
self.connected_clients: List[WebSocket] = []
|
| 269 |
+
# Thread pool for running sync code (e.g., Playwright sync API) in async context
|
| 270 |
+
self._executor = ThreadPoolExecutor(max_workers=1)
|
| 271 |
+
|
| 272 |
+
async def _run_sync_in_thread_pool(self, func, *args, **kwargs):
|
| 273 |
+
"""Run a synchronous function in the thread pool executor.
|
| 274 |
+
|
| 275 |
+
This is needed for environments using sync libraries (e.g., Playwright sync API)
|
| 276 |
+
that cannot be called directly from an async context.
|
| 277 |
+
"""
|
| 278 |
+
loop = asyncio.get_event_loop()
|
| 279 |
+
# Use default arguments to capture values at lambda definition time
|
| 280 |
+
# to avoid closure issues with late binding
|
| 281 |
+
return await loop.run_in_executor(
|
| 282 |
+
self._executor, lambda f=func, a=args, kw=kwargs: f(*a, **kw)
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
async def connect_websocket(self, websocket: WebSocket):
|
| 286 |
"""Connect a new WebSocket client."""
|
| 287 |
await websocket.accept()
|
| 288 |
self.connected_clients.append(websocket)
|
| 289 |
+
|
| 290 |
# Send current state to the new client
|
| 291 |
await self._send_state_update()
|
| 292 |
+
|
| 293 |
async def disconnect_websocket(self, websocket: WebSocket):
|
| 294 |
"""Disconnect a WebSocket client."""
|
| 295 |
if websocket in self.connected_clients:
|
| 296 |
self.connected_clients.remove(websocket)
|
| 297 |
+
|
| 298 |
async def _send_state_update(self):
|
| 299 |
"""Send current state to all connected clients."""
|
| 300 |
if not self.connected_clients:
|
| 301 |
return
|
| 302 |
+
|
| 303 |
state_data = {
|
| 304 |
"type": "state_update",
|
| 305 |
+
"episode_state": self.episode_state.model_dump(),
|
| 306 |
}
|
| 307 |
+
|
| 308 |
# Send to all connected clients
|
| 309 |
disconnected_clients = []
|
| 310 |
for client in self.connected_clients:
|
| 311 |
try:
|
| 312 |
await client.send_text(json.dumps(state_data))
|
| 313 |
+
except Exception:
|
| 314 |
disconnected_clients.append(client)
|
| 315 |
+
|
| 316 |
# Remove disconnected clients
|
| 317 |
for client in disconnected_clients:
|
| 318 |
self.connected_clients.remove(client)
|
| 319 |
+
|
| 320 |
async def reset_environment(self) -> Dict[str, Any]:
|
| 321 |
"""Reset the environment and update state."""
|
| 322 |
+
# Run sync reset in thread pool to avoid blocking event loop
|
| 323 |
+
# and to support environments using sync libraries (e.g., Playwright)
|
| 324 |
+
observation: Observation = await self._run_sync_in_thread_pool(self.env.reset)
|
| 325 |
+
state: State = self.env.state
|
| 326 |
+
|
| 327 |
+
# Serialize observation once using shared utility
|
| 328 |
+
serialized = serialize_observation(observation)
|
| 329 |
+
|
| 330 |
# Update episode state
|
| 331 |
self.episode_state.episode_id = state.episode_id
|
| 332 |
self.episode_state.step_count = 0
|
| 333 |
+
self.episode_state.current_observation = serialized["observation"]
|
| 334 |
self.episode_state.action_logs = []
|
| 335 |
self.episode_state.is_reset = True
|
| 336 |
+
|
| 337 |
# Send state update
|
| 338 |
await self._send_state_update()
|
| 339 |
+
|
| 340 |
+
return serialized
|
| 341 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 343 |
"""Execute a step in the environment and update state."""
|
| 344 |
+
# Deserialize action with preprocessing for web interface special cases
|
| 345 |
+
action: Action = deserialize_action_with_preprocessing(
|
| 346 |
+
action_data, self.action_cls
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
# Run sync step in thread pool to avoid blocking event loop
|
| 350 |
+
# and to support environments using sync libraries (e.g., Playwright)
|
| 351 |
+
observation: Observation = await self._run_sync_in_thread_pool(
|
| 352 |
+
self.env.step, action
|
| 353 |
+
)
|
| 354 |
+
state: State = self.env.state
|
| 355 |
+
|
| 356 |
+
# Serialize observation once using shared utility
|
| 357 |
+
serialized = serialize_observation(observation)
|
| 358 |
+
|
| 359 |
# Create action log
|
| 360 |
action_log = ActionLog(
|
| 361 |
timestamp=datetime.now().isoformat(),
|
| 362 |
+
action=action.model_dump(exclude={"metadata"}),
|
| 363 |
+
observation=serialized["observation"],
|
| 364 |
reward=observation.reward,
|
| 365 |
done=observation.done,
|
| 366 |
+
step_count=state.step_count,
|
| 367 |
)
|
| 368 |
+
|
| 369 |
# Update episode state
|
| 370 |
self.episode_state.episode_id = state.episode_id
|
| 371 |
self.episode_state.step_count = state.step_count
|
| 372 |
+
self.episode_state.current_observation = serialized["observation"]
|
| 373 |
self.episode_state.action_logs.append(action_log)
|
| 374 |
+
if len(self.episode_state.action_logs) > self.MAX_ACTION_LOGS:
|
| 375 |
+
self.episode_state.action_logs = self.episode_state.action_logs[
|
| 376 |
+
-self.MAX_ACTION_LOGS :
|
| 377 |
+
]
|
| 378 |
self.episode_state.is_reset = False
|
| 379 |
+
|
| 380 |
# Send state update
|
| 381 |
await self._send_state_update()
|
| 382 |
+
|
| 383 |
+
return serialized
|
| 384 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
def get_state(self) -> Dict[str, Any]:
|
| 386 |
"""Get current environment state."""
|
| 387 |
+
state: State = self.env.state
|
| 388 |
+
return state.model_dump()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
|
| 390 |
|
| 391 |
def create_web_interface_app(
|
|
|
|
| 393 |
action_cls: Type[Action],
|
| 394 |
observation_cls: Type[Observation],
|
| 395 |
env_name: Optional[str] = None,
|
| 396 |
+
max_concurrent_envs: Optional[int] = None,
|
| 397 |
+
concurrency_config: Optional[Any] = None,
|
| 398 |
+
gradio_builder: Optional[Callable[..., Any]] = None,
|
| 399 |
) -> FastAPI:
|
| 400 |
"""
|
| 401 |
Create a FastAPI application with web interface for the given environment.
|
| 402 |
+
|
| 403 |
Args:
|
| 404 |
env: The Environment instance to serve
|
| 405 |
action_cls: The Action subclass this environment expects
|
| 406 |
observation_cls: The Observation subclass this environment returns
|
| 407 |
env_name: Optional environment name for README loading
|
| 408 |
+
max_concurrent_envs: Maximum concurrent WebSocket sessions
|
| 409 |
+
concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings
|
| 410 |
+
gradio_builder: Optional callable (web_manager, action_fields, metadata,
|
| 411 |
+
is_chat_env, title, quick_start_md) -> gr.Blocks to use instead of the
|
| 412 |
+
default Gradio UI. Lets envs replace or customize the /web interface.
|
| 413 |
+
|
| 414 |
Returns:
|
| 415 |
FastAPI application instance with web interface
|
| 416 |
"""
|
| 417 |
from .http_server import create_fastapi_app
|
| 418 |
+
|
| 419 |
# Create the base environment app
|
| 420 |
+
app = create_fastapi_app(
|
| 421 |
+
env, action_cls, observation_cls, max_concurrent_envs, concurrency_config
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
# Load environment metadata
|
| 425 |
metadata = load_environment_metadata(env, env_name)
|
| 426 |
+
|
| 427 |
# Create web interface manager
|
| 428 |
web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)
|
| 429 |
+
|
| 430 |
+
# Web API routes first (so they take precedence over Gradio mount at /web)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
@app.get("/web/metadata")
|
| 432 |
async def web_metadata():
|
| 433 |
"""Get environment metadata."""
|
| 434 |
+
return web_manager.metadata.model_dump()
|
| 435 |
+
|
| 436 |
+
@app.websocket("/ws/ui")
|
| 437 |
+
async def websocket_ui_endpoint(websocket: WebSocket):
|
| 438 |
+
"""WebSocket endpoint for web UI real-time updates.
|
| 439 |
+
|
| 440 |
+
Note: Uses /ws/ui to avoid conflict with /ws in http_server.py
|
| 441 |
+
which is used for concurrent environment sessions.
|
| 442 |
+
"""
|
| 443 |
await web_manager.connect_websocket(websocket)
|
| 444 |
try:
|
| 445 |
while True:
|
|
|
|
| 447 |
await websocket.receive_text()
|
| 448 |
except WebSocketDisconnect:
|
| 449 |
await web_manager.disconnect_websocket(websocket)
|
| 450 |
+
|
| 451 |
@app.post("/web/reset")
|
| 452 |
async def web_reset():
|
| 453 |
"""Reset endpoint for web interface."""
|
| 454 |
return await web_manager.reset_environment()
|
| 455 |
+
|
| 456 |
@app.post("/web/step")
|
| 457 |
async def web_step(request: Dict[str, Any]):
|
| 458 |
"""Step endpoint for web interface."""
|
| 459 |
# Check if this is a message-based request (chat environment)
|
| 460 |
if "message" in request:
|
| 461 |
message = request["message"]
|
| 462 |
+
if hasattr(web_manager.env, "message_to_action"):
|
| 463 |
+
action = web_manager.env.message_to_action(message)
|
| 464 |
+
if hasattr(action, "tokens"):
|
| 465 |
+
action_data = {"tokens": action.tokens.tolist()}
|
| 466 |
+
else:
|
| 467 |
+
action_data = action.model_dump(exclude={"metadata"})
|
| 468 |
+
else:
|
| 469 |
+
action_data = {"message": message}
|
| 470 |
else:
|
| 471 |
action_data = request.get("action", {})
|
| 472 |
+
|
| 473 |
return await web_manager.step_environment(action_data)
|
| 474 |
+
|
| 475 |
@app.get("/web/state")
|
| 476 |
async def web_state():
|
| 477 |
"""State endpoint for web interface."""
|
| 478 |
return web_manager.get_state()
|
| 479 |
+
|
| 480 |
+
action_fields = _extract_action_fields(action_cls)
|
| 481 |
+
is_chat_env = _is_chat_env(action_cls)
|
| 482 |
+
quick_start_md = get_quick_start_markdown(metadata, action_cls, observation_cls)
|
| 483 |
+
|
| 484 |
+
default_blocks = build_gradio_app(
|
| 485 |
+
web_manager,
|
| 486 |
+
action_fields,
|
| 487 |
+
metadata,
|
| 488 |
+
is_chat_env,
|
| 489 |
+
title=metadata.name,
|
| 490 |
+
quick_start_md=quick_start_md,
|
| 491 |
+
)
|
| 492 |
+
if gradio_builder is not None:
|
| 493 |
+
custom_blocks = gradio_builder(
|
| 494 |
+
web_manager,
|
| 495 |
+
action_fields,
|
| 496 |
+
metadata,
|
| 497 |
+
is_chat_env,
|
| 498 |
+
metadata.name,
|
| 499 |
+
quick_start_md,
|
| 500 |
+
)
|
| 501 |
+
if not isinstance(custom_blocks, gr.Blocks):
|
| 502 |
+
raise TypeError(
|
| 503 |
+
f"gradio_builder must return a gr.Blocks instance, "
|
| 504 |
+
f"got {type(custom_blocks).__name__}"
|
| 505 |
+
)
|
| 506 |
+
gradio_blocks = gr.TabbedInterface(
|
| 507 |
+
[default_blocks, custom_blocks],
|
| 508 |
+
tab_names=["Playground", "Visualization"],
|
| 509 |
+
title=get_gradio_display_title(metadata),
|
| 510 |
+
)
|
| 511 |
+
else:
|
| 512 |
+
gradio_blocks = default_blocks
|
| 513 |
+
app = gr.mount_gradio_app(
|
| 514 |
+
app,
|
| 515 |
+
gradio_blocks,
|
| 516 |
+
path="/web",
|
| 517 |
+
theme=OPENENV_GRADIO_THEME,
|
| 518 |
+
css=OPENENV_GRADIO_CSS,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
return app
|
| 522 |
|
| 523 |
|
| 524 |
+
def _is_chat_env(action_cls: Type[Action]) -> bool:
|
| 525 |
+
"""Return True if the action class is a chat-style env (tokens field)."""
|
| 526 |
+
if hasattr(action_cls, "model_fields"):
|
| 527 |
+
for field_name, field_info in action_cls.model_fields.items():
|
| 528 |
+
if (
|
| 529 |
+
field_name == "tokens"
|
| 530 |
+
and hasattr(field_info.annotation, "__name__")
|
| 531 |
+
and "Tensor" in str(field_info.annotation)
|
| 532 |
+
):
|
| 533 |
+
return True
|
| 534 |
+
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
|
| 536 |
|
| 537 |
def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]:
|
| 538 |
"""Extract enhanced field metadata from Action class for form generation."""
|
| 539 |
+
# Use Pydantic's JSON schema generation for robust metadata extraction
|
| 540 |
+
try:
|
| 541 |
+
schema = action_cls.model_json_schema()
|
| 542 |
+
except AttributeError:
|
| 543 |
+
# Fallback for non-Pydantic v2 models or if something goes wrong
|
| 544 |
+
return []
|
| 545 |
+
|
| 546 |
+
properties = schema.get("properties", {})
|
| 547 |
+
required_fields = schema.get("required", [])
|
| 548 |
+
|
| 549 |
action_fields = []
|
| 550 |
+
|
| 551 |
+
for field_name, field_info in properties.items():
|
| 552 |
+
if field_name == "metadata":
|
|
|
|
|
|
|
| 553 |
continue
|
| 554 |
+
|
| 555 |
+
# JSON schema "type" can be a string or list/undefined
|
| 556 |
+
# Determine our internal input type
|
| 557 |
+
input_type = _determine_input_type_from_schema(field_info, field_name)
|
| 558 |
+
|
| 559 |
+
is_required = field_name in required_fields
|
| 560 |
+
|
| 561 |
+
action_fields.append(
|
| 562 |
+
{
|
| 563 |
+
"name": field_name,
|
| 564 |
+
"type": input_type,
|
| 565 |
+
"required": is_required,
|
| 566 |
+
"description": field_info.get("description", ""),
|
| 567 |
+
"default_value": field_info.get("default"),
|
| 568 |
+
"choices": field_info.get("enum"),
|
| 569 |
+
"min_value": field_info.get("minimum"),
|
| 570 |
+
"max_value": field_info.get("maximum"),
|
| 571 |
+
"min_length": field_info.get("minLength"),
|
| 572 |
+
"max_length": field_info.get("maxLength"),
|
| 573 |
+
"pattern": field_info.get("pattern"),
|
| 574 |
+
"placeholder": _generate_placeholder(field_name, field_info),
|
| 575 |
+
"help_text": _generate_help_text(field_name, field_info),
|
| 576 |
+
}
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
return action_fields
|
| 580 |
|
| 581 |
|
| 582 |
+
def _determine_input_type_from_schema(
|
| 583 |
+
field_info: Dict[str, Any], field_name: str
|
| 584 |
+
) -> str:
|
| 585 |
+
"""Determine input type from JSON schema for form generation (Gradio UI)."""
|
| 586 |
+
schema_type = field_info.get("type")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
|
| 588 |
+
# Check for specific tensor field convention
|
| 589 |
+
if "tokens" in field_name.lower():
|
| 590 |
+
return "tensor"
|
| 591 |
|
| 592 |
+
if "enum" in field_info:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 593 |
return "select"
|
| 594 |
+
|
| 595 |
+
if schema_type == "boolean":
|
| 596 |
+
return "checkbox"
|
| 597 |
+
|
| 598 |
+
if schema_type == "integer" or schema_type == "number":
|
| 599 |
+
return "number"
|
| 600 |
+
|
| 601 |
+
if schema_type == "string":
|
| 602 |
+
# Check if it should be a textarea
|
| 603 |
+
if (
|
| 604 |
+
field_info.get("maxLength", 0) > 100
|
| 605 |
+
or "message" in field_name.lower()
|
| 606 |
+
or "code" in field_name.lower()
|
| 607 |
+
):
|
| 608 |
+
return "textarea"
|
| 609 |
return "text"
|
| 610 |
|
| 611 |
+
# Default fallback
|
| 612 |
+
return "text"
|
| 613 |
|
| 614 |
+
|
| 615 |
+
def _generate_placeholder(field_name: str, field_info: Dict[str, Any]) -> str:
|
| 616 |
+
"""Generate placeholder text."""
|
| 617 |
+
if "message" in field_name.lower():
|
| 618 |
+
return f"Enter {field_name.replace('_', ' ')}..."
|
| 619 |
+
elif "code" in field_name.lower():
|
| 620 |
+
return "Enter Python code here..."
|
| 621 |
+
elif "tokens" in field_name.lower():
|
| 622 |
+
return "Enter comma-separated token IDs (e.g., 1,2,3,4,5)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 623 |
else:
|
| 624 |
+
return f"Enter {field_name.replace('_', ' ')}..."
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
def _generate_help_text(field_name: str, field_info: Dict[str, Any]) -> str:
|
| 628 |
+
"""Generate help text."""
|
| 629 |
+
description = field_info.get("description", "")
|
| 630 |
+
if description:
|
| 631 |
+
return description
|
| 632 |
+
|
| 633 |
+
if "action_id" in field_name.lower():
|
| 634 |
+
return "The action ID to execute in environment"
|
| 635 |
+
elif "game_name" in field_name.lower():
|
| 636 |
+
return "Name of game or environment"
|
| 637 |
+
elif "tokens" in field_name.lower():
|
| 638 |
+
return "Token IDs as a comma-separated list of integers"
|
| 639 |
+
elif "code" in field_name.lower():
|
| 640 |
+
return "Python code to execute in environment"
|
| 641 |
+
elif "message" in field_name.lower():
|
| 642 |
+
return "Text message to send"
|
| 643 |
+
|
| 644 |
+
return ""
|
src/core/evals/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Evaluation harness support for OpenEnv."""
|
| 8 |
+
|
| 9 |
+
from openenv.core.evals.base import EvalHarness
|
| 10 |
+
from openenv.core.evals.inspect_harness import InspectAIHarness
|
| 11 |
+
from openenv.core.evals.types import EvalConfig, EvalResult
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
"EvalHarness",
|
| 15 |
+
"EvalConfig",
|
| 16 |
+
"EvalResult",
|
| 17 |
+
"InspectAIHarness",
|
| 18 |
+
]
|
src/core/evals/base.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Base class for evaluation harnesses."""
|
| 8 |
+
|
| 9 |
+
from abc import ABC, abstractmethod
|
| 10 |
+
from typing import Any, Dict
|
| 11 |
+
|
| 12 |
+
from openenv.core.evals.types import EvalConfig, EvalResult
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class EvalHarness(ABC):
|
| 16 |
+
"""Abstract base class for evaluation harnesses.
|
| 17 |
+
|
| 18 |
+
Subclasses implement run() to define evaluation logic.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
@abstractmethod
|
| 22 |
+
def run(
|
| 23 |
+
self,
|
| 24 |
+
harness_version: str,
|
| 25 |
+
library_versions: Dict[str, str],
|
| 26 |
+
dataset: str,
|
| 27 |
+
eval_parameters: Dict[str, Any],
|
| 28 |
+
) -> Dict[str, Any]:
|
| 29 |
+
"""Run the evaluation and return scores.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
harness_version: Version of the evaluation harness.
|
| 33 |
+
library_versions: Versions of libraries used in the evaluation.
|
| 34 |
+
dataset: Name of the dataset to evaluate on.
|
| 35 |
+
eval_parameters: Parameters for the evaluation.
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
Dictionary of scores from the evaluation.
|
| 39 |
+
"""
|
| 40 |
+
raise NotImplementedError
|
| 41 |
+
|
| 42 |
+
def run_from_config(self, config: EvalConfig) -> EvalResult:
|
| 43 |
+
"""Run evaluation from an EvalConfig and return an EvalResult.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
config: Configuration for the evaluation.
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
EvalResult containing the config and scores.
|
| 50 |
+
"""
|
| 51 |
+
scores = self.run(
|
| 52 |
+
harness_version=config.harness_version,
|
| 53 |
+
library_versions=config.library_versions,
|
| 54 |
+
dataset=config.dataset,
|
| 55 |
+
eval_parameters=config.eval_parameters,
|
| 56 |
+
)
|
| 57 |
+
return EvalResult(config=config, scores=scores)
|
| 58 |
+
|
| 59 |
+
@property
|
| 60 |
+
def name(self) -> str:
|
| 61 |
+
"""Return the name of the harness (class name)."""
|
| 62 |
+
return self.__class__.__name__
|
src/core/evals/inspect_harness.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Inspect AI harness integration for OpenEnv.
|
| 8 |
+
|
| 9 |
+
Requires the ``inspect-ai`` package: ``pip install 'inspect-ai>=0.3.0'``
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from typing import Any, Dict, Optional
|
| 15 |
+
|
| 16 |
+
from openenv.core.evals.base import EvalHarness
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class InspectAIHarness(EvalHarness):
|
| 20 |
+
"""Evaluation harness wrapping Inspect AI's ``eval()`` function.
|
| 21 |
+
|
| 22 |
+
All ``inspect_ai`` imports are deferred to :meth:`run` so this class is
|
| 23 |
+
importable without inspect-ai installed. An ``ImportError`` with a clear
|
| 24 |
+
message is raised at call time if the dependency is missing.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
log_dir: Directory for evaluation log output. Defaults to None
|
| 28 |
+
(Inspect AI writes logs to its default location).
|
| 29 |
+
|
| 30 |
+
``eval_parameters`` keys accepted by :meth:`run`:
|
| 31 |
+
|
| 32 |
+
+--------------------------+----------+-----------------+-----------------------------------+
|
| 33 |
+
| Key | Type | Default | Purpose |
|
| 34 |
+
+==========================+==========+=================+===================================+
|
| 35 |
+
| ``model`` | str | *required* | Model string, e.g. "openai/gpt-4o"|
|
| 36 |
+
| ``task`` | str|None | ``dataset`` arg | Task file path or task string |
|
| 37 |
+
| ``task_args`` | dict | ``{}`` | Arguments to pass to the task |
|
| 38 |
+
| ``max_samples`` | int|None | None | Limit samples per task |
|
| 39 |
+
| ``temperature`` | float|None| None | Model generation temperature |
|
| 40 |
+
| ``max_tokens`` | int|None | None | Max generation tokens |
|
| 41 |
+
| ``epochs`` | int|None | None | Number of evaluation epochs |
|
| 42 |
+
| ``solver`` | list|None| None | Solver pipeline override |
|
| 43 |
+
| ``scorer`` | list|None| None | Scorer override |
|
| 44 |
+
| ``model_args`` | dict | ``{}`` | Provider-specific model kwargs |
|
| 45 |
+
+--------------------------+----------+-----------------+-----------------------------------+
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
*,
|
| 51 |
+
log_dir: Optional[str] = None,
|
| 52 |
+
):
|
| 53 |
+
self.log_dir = log_dir
|
| 54 |
+
|
| 55 |
+
def run(
|
| 56 |
+
self,
|
| 57 |
+
harness_version: str,
|
| 58 |
+
library_versions: Dict[str, str],
|
| 59 |
+
dataset: str,
|
| 60 |
+
eval_parameters: Dict[str, Any],
|
| 61 |
+
) -> Dict[str, Any]:
|
| 62 |
+
"""Run an Inspect AI evaluation.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
harness_version: Version of inspect-ai being used.
|
| 66 |
+
library_versions: Versions of supporting libraries.
|
| 67 |
+
dataset: Default task string (used when ``task`` is not specified
|
| 68 |
+
in *eval_parameters*).
|
| 69 |
+
eval_parameters: See class docstring for accepted keys.
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
Dictionary mapping metric names to scores.
|
| 73 |
+
|
| 74 |
+
Raises:
|
| 75 |
+
ImportError: If ``inspect-ai`` is not installed.
|
| 76 |
+
ValueError: If ``model`` is missing from *eval_parameters*.
|
| 77 |
+
RuntimeError: If the evaluation fails (log status is not "success").
|
| 78 |
+
"""
|
| 79 |
+
try:
|
| 80 |
+
from inspect_ai import eval as inspect_eval
|
| 81 |
+
except ImportError:
|
| 82 |
+
raise ImportError(
|
| 83 |
+
"inspect-ai is required for InspectAIHarness. "
|
| 84 |
+
"Install it with: pip install 'inspect-ai>=0.3.0'"
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# Extract required model parameter
|
| 88 |
+
model = eval_parameters.get("model")
|
| 89 |
+
if model is None:
|
| 90 |
+
raise ValueError(
|
| 91 |
+
"eval_parameters must include 'model' "
|
| 92 |
+
"(e.g. 'openai/gpt-4o', 'hf/meta-llama/...')."
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# Task: explicit parameter or fall back to dataset
|
| 96 |
+
task = eval_parameters.get("task", dataset)
|
| 97 |
+
|
| 98 |
+
# Build eval kwargs
|
| 99 |
+
eval_kwargs: Dict[str, Any] = {}
|
| 100 |
+
|
| 101 |
+
task_args = eval_parameters.get("task_args", {})
|
| 102 |
+
if task_args:
|
| 103 |
+
eval_kwargs["task_args"] = task_args
|
| 104 |
+
|
| 105 |
+
model_args = eval_parameters.get("model_args", {})
|
| 106 |
+
if model_args:
|
| 107 |
+
eval_kwargs["model_args"] = model_args
|
| 108 |
+
|
| 109 |
+
for key in ("max_samples", "temperature", "max_tokens", "epochs"):
|
| 110 |
+
value = eval_parameters.get(key)
|
| 111 |
+
if value is not None:
|
| 112 |
+
eval_kwargs[key] = value
|
| 113 |
+
|
| 114 |
+
if eval_parameters.get("solver") is not None:
|
| 115 |
+
eval_kwargs["solver"] = eval_parameters["solver"]
|
| 116 |
+
|
| 117 |
+
if eval_parameters.get("scorer") is not None:
|
| 118 |
+
eval_kwargs["scorer"] = eval_parameters["scorer"]
|
| 119 |
+
|
| 120 |
+
if self.log_dir is not None:
|
| 121 |
+
eval_kwargs["log_dir"] = self.log_dir
|
| 122 |
+
|
| 123 |
+
# Run evaluation
|
| 124 |
+
logs = inspect_eval(task, model=model, **eval_kwargs)
|
| 125 |
+
|
| 126 |
+
# Extract results from the first log
|
| 127 |
+
if not logs:
|
| 128 |
+
raise RuntimeError(
|
| 129 |
+
"Inspect AI evaluation returned no logs. "
|
| 130 |
+
"Check that the task and model arguments are valid."
|
| 131 |
+
)
|
| 132 |
+
log = logs[0]
|
| 133 |
+
if log.status != "success":
|
| 134 |
+
raise RuntimeError(
|
| 135 |
+
f"Inspect AI evaluation failed with status: {log.status}"
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
return self._extract_scores(log)
|
| 139 |
+
|
| 140 |
+
def _extract_scores(self, log: Any) -> Dict[str, Any]:
|
| 141 |
+
"""Parse an EvalLog's results into a flat score dictionary.
|
| 142 |
+
|
| 143 |
+
Iterates over ``log.results.scores`` (a list of ``EvalScore``),
|
| 144 |
+
flattening each scorer's ``metrics`` dict into a single output dict.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
log: An ``inspect_ai`` ``EvalLog`` object.
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
Dictionary mapping metric names to their values.
|
| 151 |
+
"""
|
| 152 |
+
scores: Dict[str, Any] = {}
|
| 153 |
+
if log.results is None:
|
| 154 |
+
return scores
|
| 155 |
+
|
| 156 |
+
for eval_score in log.results.scores:
|
| 157 |
+
for metric_name, metric in eval_score.metrics.items():
|
| 158 |
+
scores[metric_name] = metric.value
|
| 159 |
+
|
| 160 |
+
return scores
|