burtenshaw HF Staff commited on
Commit
81b02bf
·
verified ·
1 Parent(s): 2e2a17e

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +10 -5
  2. README.md +23 -4
  3. __init__.py +9 -9
  4. client.py +5 -8
  5. envs/repl_env/README.md +450 -0
  6. envs/repl_env/__init__.py +78 -0
  7. envs/repl_env/client.py +466 -0
  8. envs/repl_env/models.py +110 -0
  9. envs/repl_env/openenv.yaml +6 -0
  10. envs/repl_env/prompts.py +389 -0
  11. envs/repl_env/pyproject.toml +43 -0
  12. envs/repl_env/server/Dockerfile +80 -0
  13. envs/repl_env/server/__init__.py +19 -0
  14. envs/repl_env/server/app.py +92 -0
  15. envs/repl_env/server/python_executor.py +339 -0
  16. envs/repl_env/server/repl_environment.py +516 -0
  17. models.py +4 -12
  18. prompts.py +23 -10
  19. pyproject.toml +1 -1
  20. server/Dockerfile +80 -0
  21. server/__init__.py +1 -1
  22. server/app.py +3 -1
  23. server/python_executor.py +5 -16
  24. server/repl_environment.py +12 -30
  25. src/__init__.py +7 -0
  26. src/core/README.md +212 -0
  27. src/core/__init__.py +81 -0
  28. src/core/client_types.py +23 -0
  29. src/core/containers/__init__.py +7 -0
  30. src/core/containers/images/Dockerfile +64 -0
  31. src/core/containers/images/README.md +92 -0
  32. src/core/containers/runtime/__init__.py +25 -0
  33. src/core/containers/runtime/daytona_provider.py +572 -0
  34. src/core/containers/runtime/providers.py +669 -0
  35. src/core/containers/runtime/uv_provider.py +224 -0
  36. src/core/containers/test_local_docker_provider.py +260 -0
  37. src/core/env_client.py +484 -0
  38. src/core/env_server/__init__.py +150 -0
  39. src/core/env_server/base_transforms.py +29 -0
  40. src/core/env_server/exceptions.py +105 -0
  41. src/core/env_server/gradio_theme.py +128 -0
  42. src/core/env_server/gradio_ui.py +240 -0
  43. src/core/env_server/http_server.py +1391 -0
  44. src/core/env_server/interfaces.py +297 -0
  45. src/core/env_server/mcp_environment.py +624 -0
  46. src/core/env_server/mcp_types.py +321 -0
  47. src/core/env_server/route_config.py +57 -0
  48. src/core/env_server/serialization.py +137 -0
  49. src/core/env_server/types.py +387 -0
  50. src/core/env_server/web_interface.py +644 -0
Dockerfile CHANGED
@@ -11,7 +11,7 @@
11
  # The build script (openenv build) handles context detection and sets appropriate build args.
12
 
13
  ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
14
- FROM ${BASE_IMAGE} AS builder
15
 
16
  WORKDIR /app
17
 
@@ -40,22 +40,26 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
40
 
41
  # Install dependencies using uv sync
42
  # If uv.lock exists, use it; otherwise resolve on the fly
 
 
 
 
43
  RUN --mount=type=cache,target=/root/.cache/uv \
44
  if [ -f uv.lock ]; then \
45
- uv sync --frozen --no-install-project --no-editable; \
46
  else \
47
  uv sync --no-install-project --no-editable; \
48
  fi
49
 
50
  RUN --mount=type=cache,target=/root/.cache/uv \
51
  if [ -f uv.lock ]; then \
52
- uv sync --frozen --no-editable; \
53
  else \
54
  uv sync --no-editable; \
55
  fi
56
 
57
  # Final runtime stage
58
- FROM ${BASE_IMAGE}
59
 
60
  WORKDIR /app
61
 
@@ -77,5 +81,6 @@ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
77
 
78
  # Run the FastAPI server
79
  # The module path is constructed to work with the /app/env structure
80
- ENV ENABLE_WEB_INTERFACE=true
81
  CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
 
 
 
11
  # The build script (openenv build) handles context detection and sets appropriate build args.
12
 
13
  ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
14
+ FROM ghcr.io/meta-pytorch/openenv-base:latest AS builder
15
 
16
  WORKDIR /app
17
 
 
40
 
41
  # Install dependencies using uv sync
42
  # If uv.lock exists, use it; otherwise resolve on the fly
43
+ RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
44
+ install -m 0755 /root/.local/bin/uv /usr/local/bin/uv && \
45
+ install -m 0755 /root/.local/bin/uvx /usr/local/bin/uvx
46
+
47
  RUN --mount=type=cache,target=/root/.cache/uv \
48
  if [ -f uv.lock ]; then \
49
+ uv sync --no-install-project --no-editable; \
50
  else \
51
  uv sync --no-install-project --no-editable; \
52
  fi
53
 
54
  RUN --mount=type=cache,target=/root/.cache/uv \
55
  if [ -f uv.lock ]; then \
56
+ uv sync --no-editable; \
57
  else \
58
  uv sync --no-editable; \
59
  fi
60
 
61
  # Final runtime stage
62
+ FROM ghcr.io/meta-pytorch/openenv-base:latest
63
 
64
  WORKDIR /app
65
 
 
81
 
82
  # Run the FastAPI server
83
  # The module path is constructed to work with the /app/env structure
 
84
  CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
85
+
86
+ ENV ENABLE_WEB_INTERFACE=true
README.md CHANGED
@@ -1,16 +1,33 @@
1
  ---
2
  title: REPL Environment Server
3
  emoji: 🎮
4
- colorFrom: yellow
5
- colorTo: indigo
6
  sdk: docker
7
  pinned: false
8
  app_port: 8000
9
  base_path: /web
10
  tags:
 
11
  - openenv
12
  ---
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # REPL Environment for OpenEnv
15
 
16
  A Python REPL environment for training language models on code execution tasks, based on the [Recursive Language Models (RLM)](https://arxiv.org/abs/2512.24601) paradigm.
@@ -99,7 +116,7 @@ from repl_env import REPLEnv
99
  env = REPLEnv.from_docker_image("repl-env:latest")
100
 
101
  # Or from HuggingFace Hub
102
- env = REPLEnv.from_hub("openenv/repl-env")
103
  ```
104
 
105
  ## API Reference
@@ -431,13 +448,15 @@ uv run --project . server
431
 
432
  ### Using Docker
433
  ```bash
 
 
434
  docker build -t repl-env:latest -f server/Dockerfile .
435
  docker run -p 8000:8000 repl-env:latest
436
  ```
437
 
438
  ### Testing
439
  ```bash
440
- pytest tests/
441
  ```
442
 
443
  ## References
 
1
  ---
2
  title: REPL Environment Server
3
  emoji: 🎮
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: docker
7
  pinned: false
8
  app_port: 8000
9
  base_path: /web
10
  tags:
11
+ - openenv-0.2.2
12
  - openenv
13
  ---
14
 
15
+ ## Hugging Face Space Deployment
16
+
17
+ This Space is built from OpenEnv environment `repl_env`.
18
+
19
+ - Space URL: `https://huggingface.co/spaces/openenv/repl`
20
+ - OpenEnv pinned ref: `0.2.2`
21
+ - Hub tag: `openenv`
22
+
23
+ ### Connecting from Code
24
+
25
+ ```python
26
+ from envs.repl_env import Env
27
+
28
+ env = Env(base_url="https://huggingface.co/spaces/openenv/repl")
29
+ ```
30
+
31
  # REPL Environment for OpenEnv
32
 
33
  A Python REPL environment for training language models on code execution tasks, based on the [Recursive Language Models (RLM)](https://arxiv.org/abs/2512.24601) paradigm.
 
116
  env = REPLEnv.from_docker_image("repl-env:latest")
117
 
118
  # Or from HuggingFace Hub
119
+ env = REPLEnv.from_hub("openenv/repl")
120
  ```
121
 
122
  ## API Reference
 
448
 
449
  ### Using Docker
450
  ```bash
451
+ # From the repl_env directory
452
+ cd envs/repl_env
453
  docker build -t repl-env:latest -f server/Dockerfile .
454
  docker run -p 8000:8000 repl-env:latest
455
  ```
456
 
457
  ### Testing
458
  ```bash
459
+ pytest tests/envs/test_repl_env.py
460
  ```
461
 
462
  ## References
__init__.py CHANGED
@@ -40,20 +40,20 @@ References:
40
  - Alex Zhang Blog: https://alexzhang13.github.io/blog/2025/rlm/
41
  """
42
 
43
- from .models import REPLAction, REPLObservation, REPLState, CodeBlockResult
44
  from .client import REPLEnv
 
45
  from .prompts import (
46
- # System prompts
47
- RLM_SYSTEM_PROMPT,
48
- RLM_SYSTEM_PROMPT_QWEN,
49
- # Prompt building
50
- QueryMetadata,
51
  build_rlm_system_prompt,
52
  build_user_prompt,
53
- build_initial_prompt,
54
  # Parsing utilities
55
  extract_code_blocks,
56
- format_observation,
 
 
 
 
 
57
  )
58
 
59
  __all__ = [
@@ -74,5 +74,5 @@ __all__ = [
74
  "build_initial_prompt",
75
  # Parsing utilities
76
  "extract_code_blocks",
77
- "format_observation",
78
  ]
 
40
  - Alex Zhang Blog: https://alexzhang13.github.io/blog/2025/rlm/
41
  """
42
 
 
43
  from .client import REPLEnv
44
+ from .models import CodeBlockResult, REPLAction, REPLObservation, REPLState
45
  from .prompts import (
46
+ build_initial_prompt,
 
 
 
 
47
  build_rlm_system_prompt,
48
  build_user_prompt,
 
49
  # Parsing utilities
50
  extract_code_blocks,
51
+ format_observations,
52
+ # Prompt building
53
+ QueryMetadata,
54
+ # System prompts
55
+ RLM_SYSTEM_PROMPT,
56
+ RLM_SYSTEM_PROMPT_QWEN,
57
  )
58
 
59
  __all__ = [
 
74
  "build_initial_prompt",
75
  # Parsing utilities
76
  "extract_code_blocks",
77
+ "format_observations",
78
  ]
client.py CHANGED
@@ -38,11 +38,12 @@ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
38
  try:
39
  from openenv.core.client_types import StepResult
40
  from openenv.core.env_client import EnvClient
41
- from .models import REPLAction, REPLObservation, REPLState, CodeBlockResult
 
42
  except ImportError:
 
43
  from openenv.core.client_types import StepResult
44
  from openenv.core.env_client import EnvClient
45
- from models import REPLAction, REPLObservation, REPLState, CodeBlockResult
46
 
47
  if TYPE_CHECKING:
48
  from .server.repl_environment import REPLEnvironment
@@ -265,9 +266,7 @@ class REPLEnv:
265
  Returns:
266
  StepResult with done=True.
267
  """
268
- return self.step(
269
- REPLAction(code="", is_final=True, final_answer=answer)
270
- )
271
 
272
  def get_variable(self, name: str) -> StepResult[REPLObservation]:
273
  """
@@ -315,9 +314,7 @@ class REPLEnv:
315
  self._remote_client.close()
316
  self._remote_client = None
317
 
318
- def _wrap_observation(
319
- self, obs: REPLObservation
320
- ) -> StepResult[REPLObservation]:
321
  """Wrap a local REPLObservation in a StepResult."""
322
  return StepResult(
323
  observation=obs,
 
38
  try:
39
  from openenv.core.client_types import StepResult
40
  from openenv.core.env_client import EnvClient
41
+
42
+ from .models import CodeBlockResult, REPLAction, REPLObservation, REPLState
43
  except ImportError:
44
+ from models import CodeBlockResult, REPLAction, REPLObservation, REPLState
45
  from openenv.core.client_types import StepResult
46
  from openenv.core.env_client import EnvClient
 
47
 
48
  if TYPE_CHECKING:
49
  from .server.repl_environment import REPLEnvironment
 
266
  Returns:
267
  StepResult with done=True.
268
  """
269
+ return self.step(REPLAction(code="", is_final=True, final_answer=answer))
 
 
270
 
271
  def get_variable(self, name: str) -> StepResult[REPLObservation]:
272
  """
 
314
  self._remote_client.close()
315
  self._remote_client = None
316
 
317
+ def _wrap_observation(self, obs: REPLObservation) -> StepResult[REPLObservation]:
 
 
318
  """Wrap a local REPLObservation in a StepResult."""
319
  return StepResult(
320
  observation=obs,
envs/repl_env/README.md ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: REPL Environment Server
3
+ emoji: 🎮
4
+ colorFrom: yellow
5
+ colorTo: indigo
6
+ sdk: docker
7
+ pinned: false
8
+ app_port: 8000
9
+ base_path: /web
10
+ tags:
11
+ - openenv
12
+ ---
13
+
14
+ # REPL Environment for OpenEnv
15
+
16
+ A Python REPL environment for training language models on code execution tasks, based on the [Recursive Language Models (RLM)](https://arxiv.org/abs/2512.24601) paradigm.
17
+
18
+ ## Overview
19
+
20
+ The RLM paradigm allows language models to:
21
+ - Execute Python code in a sandboxed REPL environment
22
+ - Make recursive calls to themselves or other LMs via `llm_query()` / `llm_query_batched()`
23
+ - Handle near-infinite context by programmatically decomposing and exploring data
24
+ - Terminate with explicit `FINAL(answer)` or `answer = {"content": ..., "ready": True}` signals
25
+
26
+ ## Features
27
+
28
+ - **Unified API**: Same `REPLEnv` class works for both local and remote execution
29
+ - **Sandboxed Python Execution**: Safe code execution with restricted builtins
30
+ - **Context Loading**: Load large contexts that agents can explore programmatically
31
+ - **Multiple Finalization Patterns**:
32
+ - Direct call: `FINAL(answer)` - helper function injected into namespace
33
+ - Print pattern: `print('FINAL(answer)')` or `print('FINAL_VAR(var_name)')`
34
+ - Prime Intellect style: `answer = {"content": "...", "ready": True}`
35
+ - **Iteration Limits**: Configurable maximum steps per episode
36
+ - **Reward Signals**: Customizable reward functions for RL training
37
+ - **Optional LLM Oracle**: Can enable `llm_query()` and `llm_query_batched()` for recursive calls
38
+
39
+ ## Quick Start
40
+
41
+ ### Local Mode (No Server Required)
42
+
43
+ ```python
44
+ from repl_env import REPLEnv
45
+
46
+ # Create environment - runs locally by default
47
+ with REPLEnv() as env:
48
+ result = env.reset(
49
+ context="This is a large document with lots of text...",
50
+ task_prompt="Find the word count"
51
+ )
52
+
53
+ # Execute code iteratively
54
+ result = env.execute("words = context.split()")
55
+ result = env.execute("count = len(words)")
56
+ result = env.execute("print(f'FINAL({count})')")
57
+
58
+ print(f"Done: {result.done}")
59
+ print(f"Final Answer: {env.state().final_answer}")
60
+ ```
61
+
62
+ ### Remote Server Mode
63
+
64
+ ```python
65
+ from repl_env import REPLEnv
66
+
67
+ # Connect to a running server - same API!
68
+ with REPLEnv(base_url="https://my-server.hf.space") as env:
69
+ result = env.reset(context="...", task_prompt="...")
70
+ result = env.execute("count = len(context)")
71
+ result = env.execute("print(f'FINAL({count})')")
72
+ ```
73
+
74
+ ### Local Mode with LLM Support
75
+
76
+ ```python
77
+ from repl_env import REPLEnv
78
+
79
+ def my_llm_query(prompt: str) -> str:
80
+ return your_llm.generate(prompt)
81
+
82
+ def my_llm_query_batched(prompts: list[str]) -> list[str]:
83
+ return [my_llm_query(p) for p in prompts]
84
+
85
+ # Pass LLM functions for recursive calls
86
+ with REPLEnv(llm_query_fn=my_llm_query, llm_batch_fn=my_llm_query_batched) as env:
87
+ result = env.reset(context=large_document, task_prompt="Summarize this")
88
+
89
+ # Now the executed code can use llm_query() and llm_query_batched()!
90
+ result = env.execute("summary = llm_query('Summarize: ' + context[:1000])")
91
+ ```
92
+
93
+ ### From Docker or HuggingFace Hub
94
+
95
+ ```python
96
+ from repl_env import REPLEnv
97
+
98
+ # Start from Docker image
99
+ env = REPLEnv.from_docker_image("repl-env:latest")
100
+
101
+ # Or from HuggingFace Hub
102
+ env = REPLEnv.from_hub("openenv/repl-env")
103
+ ```
104
+
105
+ ## API Reference
106
+
107
+ ### REPLEnv
108
+
109
+ ```python
110
+ class REPLEnv:
111
+ def __init__(
112
+ self,
113
+ base_url: str | None = None, # Server URL (None = local mode)
114
+ *,
115
+ # Local-only options
116
+ llm_query_fn: Callable | None = None, # Function for llm_query()
117
+ llm_batch_fn: Callable | None = None, # Function for llm_query_batched()
118
+ max_output_length: int = 8192, # Max stdout/stderr chars
119
+ context_preview_length: int = 500, # Chars in context preview
120
+ reward_on_success: float = 1.0, # Reward on FINAL()
121
+ reward_on_iteration: float = 0.0, # Reward per step
122
+ reward_on_failure: float = -0.1, # Reward on max iterations
123
+ reward_on_error: float = -0.05, # Reward on execution error
124
+ # Remote-only options
125
+ connect_timeout_s: float = 10.0,
126
+ message_timeout_s: float = 60.0,
127
+ ): ...
128
+
129
+ def reset(
130
+ self,
131
+ *,
132
+ context: str = "", # Text to analyze (as `context` variable)
133
+ task_prompt: str = "", # Task description
134
+ max_iterations: int = 30, # Max code execution steps
135
+ seed: int | None = None, # Random seed
136
+ episode_id: str | None = None, # Custom episode ID
137
+ hf_token: str | None = None, # HF token for llm_query (remote mode)
138
+ llm_model: str | None = None, # Model for llm_query (remote mode)
139
+ ) -> StepResult[REPLObservation]: ...
140
+
141
+ def execute(self, code: str) -> StepResult[REPLObservation]: ...
142
+ def step(self, action: REPLAction) -> StepResult[REPLObservation]: ...
143
+ def submit_final_answer(self, answer: str) -> StepResult[REPLObservation]: ...
144
+ def state(self) -> REPLState: ...
145
+ def close(self) -> None: ...
146
+ ```
147
+
148
+ ### Action Space
149
+
150
+ ```python
151
+ class REPLAction:
152
+ code: str = "" # Python code to execute
153
+ is_final: bool = False # Whether this signals the final answer
154
+ final_answer: str | None = None # The final answer (if is_final=True)
155
+ ```
156
+
157
+ ### Observation Space
158
+
159
+ ```python
160
+ class REPLObservation:
161
+ result: CodeBlockResult # Execution result (stdout, stderr, etc.)
162
+ context_preview: str | None # First 500 chars of context
163
+ context_length: int # Total context length
164
+ available_variables: list # Variables in namespace
165
+ iteration: int # Current iteration
166
+ max_iterations: int # Max iterations
167
+ done: bool # Episode complete?
168
+ reward: float # Step reward
169
+ metadata: dict # Additional info (final_answer, etc.)
170
+ ```
171
+
172
+ ## Finalization Patterns
173
+
174
+ ### Pattern 1: Direct FINAL() call (recommended)
175
+ ```python
176
+ result = env.execute("answer = 42")
177
+ result = env.execute("FINAL(answer)")
178
+ # -> done=True, final_answer="42"
179
+ ```
180
+
181
+ ### Pattern 2: FINAL() via print
182
+ ```python
183
+ result = env.execute("answer = 42")
184
+ result = env.execute("print(f'FINAL({answer})')")
185
+ # -> done=True, final_answer="42"
186
+ ```
187
+
188
+ ### Pattern 3: FINAL_VAR() for variable reference
189
+ ```python
190
+ result = env.execute("my_result = 'The answer is 42'")
191
+ # Direct call (recommended) - pass variable name as string
192
+ # FINAL_VAR looks up the variable and returns FINAL(value)
193
+ result = env.execute('FINAL_VAR("my_result")')
194
+ # -> done=True, final_answer="The answer is 42"
195
+
196
+ # Also works via print (for regex detection)
197
+ result = env.execute("print('FINAL_VAR(my_result)')")
198
+ # -> done=True, final_answer="The answer is 42"
199
+ ```
200
+
201
+ ### Pattern 4: Prime Intellect style answer dict
202
+ ```python
203
+ result = env.execute("answer['content'] = '42'")
204
+ result = env.execute("answer['ready'] = True")
205
+ # -> done=True, final_answer="42"
206
+ ```
207
+
208
+ ## Prompts Module
209
+
210
+ The `prompts` module provides RLM-style prompts and parsing utilities:
211
+
212
+ ```python
213
+ from repl_env.prompts import (
214
+ # System prompts (from official RLM repo)
215
+ RLM_SYSTEM_PROMPT, # Base prompt with llm_query_batched
216
+ RLM_SYSTEM_PROMPT_QWEN, # For Qwen models (adds cost warning)
217
+
218
+ # Prompt building
219
+ QueryMetadata, # Context metadata dataclass
220
+ build_rlm_system_prompt, # Build system messages with metadata
221
+ build_user_prompt, # Build user prompt for each iteration
222
+ build_initial_prompt, # Convenience wrapper for iteration 0
223
+
224
+ # Parsing utilities
225
+ extract_code_blocks, # Extract code from ```repl``` or ```python``` blocks
226
+ format_observation, # Format execution result for LLM
227
+ )
228
+
229
+ # Example: Build messages using official RLM style
230
+ query_metadata = QueryMetadata(
231
+ context_lengths=[len(context)],
232
+ context_total_length=len(context),
233
+ context_type="str",
234
+ )
235
+ messages = build_rlm_system_prompt(RLM_SYSTEM_PROMPT_QWEN, query_metadata)
236
+ messages.append(build_user_prompt(root_prompt="Count words in the context", iteration=0))
237
+
238
+ # Extract code from LLM response (supports ```repl``` and ```python```)
239
+ response = "Here's my solution:\n```repl\ncount = len(context.split())\nFINAL(count)\n```"
240
+ code_blocks = extract_code_blocks(response) # ["count = len(context.split())\nFINAL(count)"]
241
+ ```
242
+
243
+ ## Examples
244
+
245
+ See the `examples/` directory for complete working examples:
246
+
247
+ - **`examples/repl_with_llm.py`** - Full RLM loop with local Qwen model
248
+ - **`examples/repl_oolong_simple.py`** - RLM on Oolong benchmark with HuggingFace Inference API
249
+
250
+ Run examples:
251
+ ```bash
252
+ # Full RLM example with local model (requires GPU)
253
+ python examples/repl_with_llm.py
254
+
255
+ # Oolong benchmark with HF Inference API (requires HF_TOKEN)
256
+ python examples/repl_oolong_simple.py
257
+ ```
258
+
259
+ ## Model Usage
260
+
261
+ ### Inference Loop
262
+
263
+ A typical model inference loop where the LLM generates code and the environment executes it:
264
+
265
+ ```python
266
+ from repl_env import REPLEnv
267
+ from repl_env.prompts import RLM_SYSTEM_PROMPT, build_initial_prompt, extract_code_blocks, format_observation
268
+
269
+ # Works with both local and remote!
270
+ with REPLEnv(base_url="http://localhost:8000") as env: # or REPLEnv() for local
271
+ result = env.reset(
272
+ context="The quick brown fox jumps over the lazy dog. " * 1000,
273
+ task_prompt="Count how many times 'fox' appears"
274
+ )
275
+
276
+ messages = [
277
+ {"role": "system", "content": RLM_SYSTEM_PROMPT},
278
+ {"role": "user", "content": build_initial_prompt(
279
+ task_prompt="Count how many times 'fox' appears",
280
+ context_length=result.observation.context_length,
281
+ context_preview=result.observation.context_preview,
282
+ variables=result.observation.available_variables,
283
+ )},
284
+ ]
285
+
286
+ while not result.done:
287
+ # Get code from LLM
288
+ response = your_llm.chat(messages)
289
+ code_blocks = extract_code_blocks(response)
290
+
291
+ for code in code_blocks:
292
+ result = env.execute(code)
293
+ if result.done:
294
+ break
295
+
296
+ # Update conversation
297
+ messages.append({"role": "assistant", "content": response})
298
+ messages.append({"role": "user", "content": format_observation(result.observation)})
299
+
300
+ print(f"Final answer: {env.state().final_answer}")
301
+ ```
302
+
303
+ ### Recursive LLM Calls (RLM Paradigm)
304
+
305
+ The key insight of RLM is that models can make recursive calls to themselves or other LLMs from within the code:
306
+
307
+ ```python
308
+ from repl_env import REPLEnv
309
+
310
+ def llm_query(prompt: str) -> str:
311
+ """Single LLM call - model can call this from executed code"""
312
+ return your_llm.generate(prompt)
313
+
314
+ def llm_query_batched(prompts: list[str]) -> list[str]:
315
+ """Batch LLM calls for efficiency (parallel in production)"""
316
+ return [your_llm.generate(p) for p in prompts]
317
+
318
+ # Create environment with LLM oracle (local mode)
319
+ with REPLEnv(llm_query_fn=llm_query, llm_batch_fn=llm_query_batched) as env:
320
+ result = env.reset(
321
+ context=massive_document, # Could be 100K+ chars
322
+ task_prompt="Summarize each section and find key themes"
323
+ )
324
+
325
+ # The model can now generate code like this:
326
+ code = """
327
+ # Split document into sections
328
+ sections = context.split('\\n\\n')
329
+
330
+ # Use LLM to summarize each section (recursive call!)
331
+ summaries = llm_query_batched([f"Summarize: {s[:1000]}" for s in sections[:10]])
332
+
333
+ # Combine summaries
334
+ combined = '\\n'.join(summaries)
335
+
336
+ # Final synthesis using another LLM call
337
+ answer['content'] = llm_query(f"Find key themes in: {combined}")
338
+ answer['ready'] = True
339
+ """
340
+
341
+ result = env.execute(code)
342
+ print(f"Done: {result.done}, Answer: {env.state().final_answer}")
343
+ ```
344
+
345
+ ### RL Training Integration
346
+
347
+ For RL training, integrate with frameworks like TRL, prime-rl, or verifiers:
348
+
349
+ ```python
350
+ from repl_env import REPLEnv
351
+
352
+ def collect_trajectory(env, policy, context, task):
353
+ """Collect a single trajectory for RL training"""
354
+ result = env.reset(context=context, task_prompt=task)
355
+
356
+ trajectory = []
357
+ total_reward = 0
358
+
359
+ while not result.done:
360
+ # Policy generates code
361
+ code = policy.generate(result.observation)
362
+
363
+ # Step environment
364
+ next_result = env.execute(code)
365
+
366
+ # Store transition
367
+ trajectory.append({
368
+ "observation": result.observation,
369
+ "action": code,
370
+ "reward": next_result.reward,
371
+ "next_observation": next_result.observation,
372
+ "done": next_result.done,
373
+ })
374
+
375
+ total_reward += next_result.reward
376
+ result = next_result
377
+
378
+ return trajectory, total_reward
379
+
380
+ # Training loop
381
+ with REPLEnv(
382
+ reward_on_success=1.0,
383
+ reward_on_iteration=0.0,
384
+ reward_on_error=-0.05,
385
+ reward_on_failure=-0.1,
386
+ ) as env:
387
+ for epoch in range(num_epochs):
388
+ for context, task, ground_truth in dataset:
389
+ trajectory, reward = collect_trajectory(env, policy, context, task)
390
+
391
+ # Verify answer correctness (optional external reward)
392
+ if trajectory:
393
+ final_answer = env.state().final_answer
394
+ if final_answer == ground_truth:
395
+ reward += verification_bonus
396
+
397
+ # Update policy (use your RL framework - PPO, GRPO, DPO, etc.)
398
+ policy.update(trajectory, reward)
399
+ ```
400
+
401
+ ### Reward Configuration
402
+
403
+ Configure rewards for different outcomes:
404
+
405
+ ```python
406
+ env = REPLEnv(
407
+ reward_on_success=1.0, # When FINAL() is called
408
+ reward_on_iteration=0.0, # Per step (can be negative to encourage efficiency)
409
+ reward_on_error=-0.05, # When code execution fails
410
+ reward_on_failure=-0.1, # When max iterations reached without answer
411
+ )
412
+ ```
413
+
414
+ ## Environment Configuration
415
+
416
+ | Environment Variable | Description | Default |
417
+ |---------------------|-------------|---------|
418
+ | `REPL_CONTEXT` | Initial context to load | "" |
419
+ | `REPL_TASK_PROMPT` | Task description | "" |
420
+ | `REPL_MAX_ITERATIONS` | Max steps per episode | 30 |
421
+ | `HF_TOKEN` | HuggingFace token for llm_query (server fallback) | None |
422
+ | `LLM_MODEL` | Model for llm_query/llm_query_batched | Qwen/Qwen3-Coder-480B-A35B-Instruct |
423
+
424
+ ## Running the Server
425
+
426
+ ### Using UV
427
+ ```bash
428
+ cd envs/repl_env
429
+ uv run --project . server
430
+ ```
431
+
432
+ ### Using Docker
433
+ ```bash
434
+ # From the repl_env directory
435
+ cd envs/repl_env
436
+ docker build -t repl-env:latest -f server/Dockerfile .
437
+ docker run -p 8000:8000 repl-env:latest
438
+ ```
439
+
440
+ ### Testing
441
+ ```bash
442
+ pytest tests/envs/test_repl_env.py
443
+ ```
444
+
445
+ ## References
446
+
447
+ - [RLM Paper (arXiv:2512.24601)](https://arxiv.org/abs/2512.24601)
448
+ - [RLM Implementation](https://github.com/alexzhang13/rlm)
449
+ - [Alex Zhang's RLM Blog](https://alexzhang13.github.io/blog/2025/rlm/)
450
+ - [Prime Intellect RLM Blog](https://www.primeintellect.ai/blog/rlm)
envs/repl_env/__init__.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ REPL Environment for OpenEnv.
9
+
10
+ A Python REPL environment for training language models on code execution tasks,
11
+ based on the Recursive Language Models (RLM) paradigm.
12
+
13
+ This environment allows language models to:
14
+ - Execute Python code in a sandboxed REPL
15
+ - Work with large contexts loaded as variables
16
+ - Finalize answers via FINAL(), FINAL_VAR(), or answer dict pattern
17
+ - Optionally make recursive LLM calls via llm_query() / llm_query_batched()
18
+
19
+ Example:
20
+ >>> from repl_env import REPLEnv, REPLAction
21
+ >>>
22
+ >>> # Start from Docker
23
+ >>> env = REPLEnv.from_docker_image("repl-env:latest")
24
+ >>>
25
+ >>> # Reset with context
26
+ >>> result = env.reset(context="Hello World", task_prompt="Count characters")
27
+ >>>
28
+ >>> # Execute code
29
+ >>> result = env.execute("count = len(context)")
30
+ >>> result = env.execute("print(f'FINAL({count})')")
31
+ >>>
32
+ >>> # Check result
33
+ >>> print(f"Done: {result.done}, Answer: {result.observation.metadata['final_answer']}")
34
+ >>>
35
+ >>> env.close()
36
+
37
+ References:
38
+ - RLM Paper: https://arxiv.org/abs/2512.24601
39
+ - Prime Intellect Blog: https://www.primeintellect.ai/blog/rlm
40
+ - Alex Zhang Blog: https://alexzhang13.github.io/blog/2025/rlm/
41
+ """
42
+
43
+ from .client import REPLEnv
44
+ from .models import CodeBlockResult, REPLAction, REPLObservation, REPLState
45
+ from .prompts import (
46
+ build_initial_prompt,
47
+ build_rlm_system_prompt,
48
+ build_user_prompt,
49
+ # Parsing utilities
50
+ extract_code_blocks,
51
+ format_observations,
52
+ # Prompt building
53
+ QueryMetadata,
54
+ # System prompts
55
+ RLM_SYSTEM_PROMPT,
56
+ RLM_SYSTEM_PROMPT_QWEN,
57
+ )
58
+
59
+ __all__ = [
60
+ # Models
61
+ "REPLAction",
62
+ "REPLObservation",
63
+ "REPLState",
64
+ "CodeBlockResult",
65
+ # Client
66
+ "REPLEnv",
67
+ # System prompts
68
+ "RLM_SYSTEM_PROMPT",
69
+ "RLM_SYSTEM_PROMPT_QWEN",
70
+ # Prompt building
71
+ "QueryMetadata",
72
+ "build_rlm_system_prompt",
73
+ "build_user_prompt",
74
+ "build_initial_prompt",
75
+ # Parsing utilities
76
+ "extract_code_blocks",
77
+ "format_observations",
78
+ ]
envs/repl_env/client.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ REPL Environment Client.
9
+
10
+ This module provides a unified client for the REPL Environment that works
11
+ with both remote servers (via WebSocket) and local execution (no server needed).
12
+
13
+ Examples:
14
+ # Connect to remote server with your HF token for sub-LLM calls
15
+ env = REPLEnv(base_url="https://my-server.hf.space")
16
+ result = env.reset(
17
+ context="...",
18
+ task_prompt="...",
19
+ hf_token=os.environ["HF_TOKEN"], # Server uses this for llm_query
20
+ )
21
+
22
+ # Run locally (no server)
23
+ env = REPLEnv()
24
+
25
+ # Local with LLM support
26
+ env = REPLEnv(llm_query_fn=my_llm, llm_batch_fn=my_batch)
27
+
28
+ # All use the same interface
29
+ result = env.execute("x = len(context)")
30
+ env.close()
31
+ """
32
+
33
+ from __future__ import annotations
34
+
35
+ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
36
+
37
+ # Support both in-repo and standalone imports
38
+ try:
39
+ from openenv.core.client_types import StepResult
40
+ from openenv.core.env_client import EnvClient
41
+
42
+ from .models import CodeBlockResult, REPLAction, REPLObservation, REPLState
43
+ except ImportError:
44
+ from models import CodeBlockResult, REPLAction, REPLObservation, REPLState
45
+ from openenv.core.client_types import StepResult
46
+ from openenv.core.env_client import EnvClient
47
+
48
+ if TYPE_CHECKING:
49
+ from .server.repl_environment import REPLEnvironment
50
+
51
+
52
class REPLEnv:
    """
    Unified client for the REPL Environment.

    One interface, two interchangeable backends:

    * Local mode (``base_url is None``): a ``REPLEnvironment`` runs
      in-process -- no server needed.  The ``llm_*_fn`` hooks, output
      limits, and reward knobs apply only to this mode.
    * Remote mode (``base_url`` set): calls are forwarded to a server
      through an internal WebSocket client.  The timeout options apply
      only to this mode.

    Examples:
        >>> # Connect to a running server
        >>> with REPLEnv(base_url="http://localhost:8000") as env:
        ...     result = env.reset(context="Hello World", task_prompt="Count chars")
        ...     result = env.execute("count = len(context)")
        ...     result = env.execute("print(f'FINAL({count})')")
        ...     print(result.done)  # True

        >>> # Run locally without a server
        >>> with REPLEnv() as env:
        ...     result = env.reset(context="Hello World", task_prompt="Count chars")
        ...     result = env.execute("count = len(context)")
        ...     print(result.observation.result.success)  # True

        >>> # Local with LLM support for recursive calls
        >>> def my_llm(prompt: str) -> str:
        ...     return "LLM response"
        >>> with REPLEnv(llm_query_fn=my_llm) as env:
        ...     result = env.reset(context="...")
        ...     result = env.execute("response = llm_query('Summarize: ' + context)")

        >>> # From a Docker image or the HuggingFace Hub
        >>> env = REPLEnv.from_docker_image("repl-env:latest")
        >>> env = REPLEnv.from_hub("openenv/repl-env")
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        *,
        # Local-only options (ignored when base_url is set)
        llm_query_fn: Optional[Callable[[str], str]] = None,
        llm_batch_fn: Optional[Callable[[List[str]], List[str]]] = None,
        max_output_length: int = 8192,
        context_preview_length: int = 500,
        reward_on_success: float = 1.0,
        reward_on_iteration: float = 0.0,
        reward_on_failure: float = -0.1,
        reward_on_error: float = -0.05,
        # Connection options (ignored when running locally)
        connect_timeout_s: float = 10.0,
        message_timeout_s: float = 60.0,
    ):
        """
        Store configuration; neither backend is started until first use.

        Args:
            base_url: Server URL. ``None`` selects in-process local mode.
            llm_query_fn: Implementation of ``llm_query()`` (local mode only).
            llm_batch_fn: Implementation of ``llm_query_batched()`` (local mode only).
            max_output_length: Max stdout/stderr chars per execution (local only).
            context_preview_length: Chars shown in context previews (local only).
            reward_on_success: Reward when a final answer is submitted (local only).
            reward_on_iteration: Reward per intermediate step (local only).
            reward_on_failure: Reward when max iterations is reached (local only).
            reward_on_error: Reward when code execution fails (local only).
            connect_timeout_s: WebSocket connection timeout (remote only).
            message_timeout_s: Per-message response timeout (remote only).
        """
        self._base_url = base_url

        # Exactly one of these becomes non-None in _ensure_initialized().
        self._local_env: Optional[REPLEnvironment] = None
        self._remote_client: Optional[_RemoteREPLClient] = None

        # Local-mode configuration, consumed when the env is constructed.
        self._llm_query_fn = llm_query_fn
        self._llm_batch_fn = llm_batch_fn
        self._max_output_length = max_output_length
        self._context_preview_length = context_preview_length
        self._reward_on_success = reward_on_success
        self._reward_on_iteration = reward_on_iteration
        self._reward_on_failure = reward_on_failure
        self._reward_on_error = reward_on_error

        # Remote-mode configuration.
        self._connect_timeout_s = connect_timeout_s
        self._message_timeout_s = message_timeout_s

        # Container/runtime lifecycle provider; only factory methods set this.
        self._provider = None

    def _ensure_initialized(self) -> None:
        """Create the appropriate backend on first use (idempotent)."""
        if self._local_env is not None or self._remote_client is not None:
            return

        if self._base_url is not None:
            # Remote mode: open a WebSocket connection to the server.
            self._remote_client = _RemoteREPLClient(
                base_url=self._base_url,
                connect_timeout_s=self._connect_timeout_s,
                message_timeout_s=self._message_timeout_s,
                provider=self._provider,
            )
            self._remote_client.connect()
        else:
            # Local mode: build the environment directly in this process.
            from .server.repl_environment import REPLEnvironment

            self._local_env = REPLEnvironment(
                max_output_length=self._max_output_length,
                context_preview_length=self._context_preview_length,
                reward_on_success=self._reward_on_success,
                reward_on_iteration=self._reward_on_iteration,
                reward_on_failure=self._reward_on_failure,
                reward_on_error=self._reward_on_error,
                llm_query_fn=self._llm_query_fn,
                llm_batch_fn=self._llm_batch_fn,
            )

    def reset(
        self,
        *,
        context: str = "",
        task_prompt: str = "",
        max_iterations: int = 30,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        hf_token: Optional[str] = None,
        llm_model: Optional[str] = None,
    ) -> StepResult[REPLObservation]:
        """
        Start a new episode.

        Args:
            context: Text loaded into the REPL as the ``context`` variable.
            task_prompt: Description of the task to solve.
            max_iterations: Max code-execution steps before timeout.
            seed: Optional random seed for reproducibility.
            episode_id: Optional custom episode identifier.
            hf_token: Optional HuggingFace token used for llm_query /
                llm_query_batched sub-LLM calls instead of the server's own
                token.  Security: not stored in state or logged.
            llm_model: Optional model name for the LLM helper functions
                (default: Qwen3-Coder-480B).

        Returns:
            StepResult with the initial observation.
        """
        self._ensure_initialized()

        if self._remote_client is not None:
            return self._remote_client.reset(
                context=context,
                task_prompt=task_prompt,
                max_iterations=max_iterations,
                seed=seed,
                episode_id=episode_id,
                hf_token=hf_token,
                llm_model=llm_model,
            )

        assert self._local_env is not None
        # Local env exposes max_iterations as a plain attribute.
        self._local_env.max_iterations = max_iterations
        obs = self._local_env.reset(
            seed=seed,
            episode_id=episode_id,
            context=context,
            task_prompt=task_prompt,
            hf_token=hf_token,
            llm_model=llm_model,
        )
        return self._wrap_observation(obs)

    def step(self, action: REPLAction) -> StepResult[REPLObservation]:
        """
        Execute one REPL action.

        Args:
            action: REPLAction containing the code to execute.

        Returns:
            StepResult with the execution observation.
        """
        self._ensure_initialized()

        if self._remote_client is not None:
            return self._remote_client.step(action)

        assert self._local_env is not None
        return self._wrap_observation(self._local_env.step(action))

    def execute(self, code: str) -> StepResult[REPLObservation]:
        """
        Run Python code in the REPL (shorthand for step() with a code action).

        Args:
            code: Python code to execute.

        Returns:
            StepResult with the execution observation.
        """
        return self.step(REPLAction(code=code))

    def submit_final_answer(self, answer: str) -> StepResult[REPLObservation]:
        """
        Submit the final answer, terminating the episode.

        Args:
            answer: The final answer string.

        Returns:
            StepResult with done=True.
        """
        return self.step(REPLAction(code="", is_final=True, final_answer=answer))

    def get_variable(self, name: str) -> StepResult[REPLObservation]:
        """
        Print the repr() of a REPL variable so it appears in stdout.

        Args:
            name: Variable name to retrieve.

        Returns:
            StepResult whose observation stdout contains the value.
        """
        return self.execute(f"print(repr({name}))")

    def state(self) -> REPLState:
        """
        Fetch the current environment state.

        Returns:
            REPLState describing the episode, iteration and namespace.
        """
        self._ensure_initialized()

        if self._remote_client is not None:
            return self._remote_client.state()

        assert self._local_env is not None
        return self._local_env.state

    def list_variables(self) -> List[str]:
        """
        List the variable names currently defined in the session.

        Returns:
            List of variable names.
        """
        return self.state().namespace_keys

    def close(self) -> None:
        """Release whichever backend was created (safe to call twice)."""
        if self._local_env is not None:
            self._local_env.close()
            self._local_env = None

        if self._remote_client is not None:
            self._remote_client.close()
            self._remote_client = None

    def _wrap_observation(self, obs: REPLObservation) -> StepResult[REPLObservation]:
        """Lift a local REPLObservation into the StepResult shape."""
        return StepResult(observation=obs, reward=obs.reward, done=obs.done)

    # -- Context manager protocol ------------------------------------------

    def __enter__(self) -> "REPLEnv":
        """Initialize the backend and return self."""
        self._ensure_initialized()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Release resources on context exit."""
        self.close()

    # -- Factory methods ---------------------------------------------------

    @classmethod
    def _attach(cls, base_url: str, provider: Any) -> "REPLEnv":
        """Build a connected client that owns `provider`'s lifecycle."""
        env = cls(base_url=base_url)
        env._provider = provider  # must be set before the client connects
        env._ensure_initialized()
        return env

    @classmethod
    def from_docker_image(
        cls,
        image: str,
        **kwargs: Any,
    ) -> "REPLEnv":
        """
        Launch `image` in a local Docker container and connect to it.

        Args:
            image: Docker image name to run (e.g., "repl-env:latest").
            **kwargs: Additional arguments passed to container start.

        Returns:
            Connected REPLEnv instance.
        """
        from openenv.core.containers.runtime import LocalDockerProvider

        provider = LocalDockerProvider()
        base_url = provider.start_container(image, **kwargs)
        provider.wait_for_ready(base_url)
        return cls._attach(base_url, provider)

    @classmethod
    def from_hub(
        cls,
        repo_id: str,
        *,
        use_docker: bool = True,
        **kwargs: Any,
    ) -> "REPLEnv":
        """
        Create a REPL environment from a HuggingFace Space.

        Args:
            repo_id: HuggingFace space identifier (e.g., "openenv/repl-env").
            use_docker: If True, pull from the HF registry; otherwise run
                the project with UV.
            **kwargs: Additional arguments passed to the provider.

        Returns:
            Connected REPLEnv instance.
        """
        if use_docker:
            from openenv.core.containers.runtime import LocalDockerProvider

            provider = LocalDockerProvider()
            tag = kwargs.pop("tag", "latest")
            image = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}"
            base_url = provider.start_container(image, **kwargs)
            provider.wait_for_ready(base_url)
        else:
            from openenv.core.containers.runtime import UVProvider

            project_path = kwargs.pop(
                "project_path", f"git+https://huggingface.co/spaces/{repo_id}"
            )
            provider = UVProvider(project_path=project_path, **kwargs)
            base_url = provider.start()
            provider.wait_for_ready()

        return cls._attach(base_url, provider)
406
+
407
+
408
class _RemoteREPLClient(EnvClient[REPLAction, REPLObservation, REPLState]):
    """
    Internal WebSocket client used by REPLEnv in remote mode.

    Implements the three EnvClient hooks: action serialization,
    step-result parsing, and state parsing.
    """

    def _step_payload(self, action: REPLAction) -> Dict:
        """Serialize a REPLAction into the JSON body of a step request."""
        return {
            "code": action.code,
            "is_final": action.is_final,
            "final_answer": action.final_answer,
        }

    def _parse_result(self, payload: Dict) -> StepResult[REPLObservation]:
        """Deserialize a server step response into StepResult[REPLObservation]."""
        obs_data = payload.get("observation", {})
        raw_result = obs_data.get("result", {})
        done = payload.get("done", False)
        reward = payload.get("reward")

        block_result = CodeBlockResult(
            stdout=raw_result.get("stdout", ""),
            stderr=raw_result.get("stderr", ""),
            locals_snapshot=raw_result.get("locals_snapshot", {}),
            execution_time=raw_result.get("execution_time", 0.0),
            success=raw_result.get("success", True),
            exception=raw_result.get("exception"),
        )
        observation = REPLObservation(
            result=block_result,
            context_preview=obs_data.get("context_preview"),
            context_length=obs_data.get("context_length", 0),
            available_variables=obs_data.get("available_variables", []),
            iteration=obs_data.get("iteration", 0),
            max_iterations=obs_data.get("max_iterations", 30),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> REPLState:
        """Deserialize a server state response into a REPLState."""
        return REPLState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            context=payload.get("context"),
            task_prompt=payload.get("task_prompt"),
            iteration=payload.get("iteration", 0),
            max_iterations=payload.get("max_iterations", 30),
            namespace_keys=payload.get("namespace_keys", []),
            final_answer=payload.get("final_answer"),
            total_execution_time=payload.get("total_execution_time", 0.0),
        )
envs/repl_env/models.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Data models for the REPL Environment.
9
+
10
+ The REPL environment provides a Python REPL for training language models
11
+ on code execution tasks, based on the Recursive Language Models (RLM) paradigm.
12
+
13
+ Supports two finalization patterns:
14
+ 1. RLM-style: print('FINAL(answer)') or print('FINAL_VAR(var_name)')
15
+ 2. Prime Intellect style: answer = {"content": "...", "ready": True}
16
+ """
17
+
18
+ from typing import Any, Dict, List, Optional
19
+
20
+ from pydantic import BaseModel, Field
21
+
22
# NOTE(review): this was previously a try/except ImportError whose fallback
# imported the *exact same* module path as the try branch, so the advertised
# "standalone" fallback could never take effect (dead code).  Collapsed to a
# single import; if a genuine standalone import path exists (compare the
# fallback in client.py), the except branch must use a different path.
from openenv.core.env_server.types import Action, Observation, State
27
+
28
+
29
class REPLAction(Action):
    """One REPL step: Python code to run, plus optional finalization.

    An episode can be finalized in three ways:
      1. RLM-style: print('FINAL(answer)') or print('FINAL_VAR(var_name)') in code
      2. Prime Intellect style: answer = {"content": "...", "ready": True} in namespace
      3. Explicit: set is_final=True together with final_answer
    """

    # The code string may be empty when the action only carries a final answer.
    code: str = Field(default="", description="Python code to execute")
    is_final: bool = Field(
        default=False,
        description="Whether this action signals the final answer",
    )
    final_answer: Optional[str] = Field(
        default=None, description="Final answer if is_final=True"
    )
46
+
47
+
48
class CodeBlockResult(BaseModel):
    """Outcome of executing one block of code in the REPL."""

    stdout: str = Field(default="", description="Standard output from execution")
    stderr: str = Field(default="", description="Standard error from execution")
    locals_snapshot: Dict[str, str] = Field(
        default_factory=dict,
        description="String representations of new/modified variables",
    )
    execution_time: float = Field(
        default=0.0, ge=0, description="Execution time in seconds"
    )
    success: bool = Field(default=True, description="Whether execution succeeded")
    exception: Optional[str] = Field(
        default=None, description="Exception message if execution failed"
    )
64
+
65
+
66
class REPLObservation(Observation):
    """What the agent sees after one code execution in the REPL."""

    result: CodeBlockResult = Field(
        default_factory=CodeBlockResult, description="Result of code execution"
    )
    context_preview: Optional[str] = Field(
        default=None,
        description="Preview of the context (first N chars) if context is loaded",
    )
    context_length: int = Field(
        default=0, ge=0, description="Total length of context in characters"
    )
    available_variables: List[str] = Field(
        default_factory=list,
        description="List of variable names available in the namespace",
    )
    iteration: int = Field(default=0, ge=0, description="Current iteration number")
    max_iterations: int = Field(
        default=30, ge=1, description="Maximum allowed iterations"
    )
87
+
88
+
89
class REPLState(State):
    """Episode-level state tracked by the REPL environment."""

    context: Optional[str] = Field(
        default=None, description="The context/problem to work with"
    )
    task_prompt: Optional[str] = Field(
        default=None, description="The task description to solve"
    )
    iteration: int = Field(default=0, ge=0, description="Current iteration number")
    max_iterations: int = Field(
        default=30, ge=1, description="Max iterations before termination"
    )
    namespace_keys: List[str] = Field(
        default_factory=list, description="Variables currently in namespace"
    )
    final_answer: Optional[str] = Field(
        default=None, description="Final answer if episode is complete"
    )
    total_execution_time: float = Field(
        default=0.0, ge=0, description="Total code execution time in seconds"
    )
envs/repl_env/openenv.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: repl
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
envs/repl_env/prompts.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ RLM System Prompts and Parsing Utilities for the REPL Environment.
9
+
10
+ Based on the official RLM repo: https://github.com/alexzhang13/rlm
11
+
12
+ Two versions available:
13
+ - RLM_SYSTEM_PROMPT: Base prompt from the repo (with llm_query_batched)
14
+ - RLM_SYSTEM_PROMPT_QWEN: For Qwen3-Coder-480B (adds IMPORTANT cost warning)
15
+
16
+ Parsing utilities help extract code blocks and format observations.
17
+ """
18
+
19
+ import re
20
+ import textwrap
21
+ from dataclasses import dataclass
22
+ from typing import List, Optional
23
+
24
+
25
+ # =============================================================================
26
+ # Query Metadata (for context info)
27
+ # =============================================================================
28
+
29
+
30
@dataclass
class QueryMetadata:
    """Size/shape information about the context, used when building prompts.

    Attributes:
        context_lengths: Per-chunk character counts of the context.
        context_total_length: Total character count of the context
            (presumably the sum of ``context_lengths``).
        context_type: Python type of the context -- "str" or "List[str]".
    """

    context_lengths: List[int]
    context_total_length: int
    context_type: str = "str"
37
+
38
+
39
+ # =============================================================================
40
+ # System Prompt from Official RLM Repo
41
+ # =============================================================================
42
+
43
+ RLM_SYSTEM_PROMPT = textwrap.dedent(
44
+ """You are tasked with answering a query with associated context. You can access, transform, and analyze this context interactively in a REPL environment that can recursively query sub-LLMs, which you are strongly encouraged to use as much as possible. You will be queried iteratively until you provide a final answer.
45
+
46
+ The REPL environment is initialized with:
47
+ 1. A `context` variable that contains extremely important information about your query. You should check the content of the `context` variable to understand what you are working with. Make sure you look through it sufficiently as you answer your query.
48
+ 2. A `llm_query` function that allows you to query an LLM (that can handle around 500K chars) inside your REPL environment.
49
+ 3. A `llm_query_batched` function that allows you to query multiple prompts concurrently: `llm_query_batched(prompts: List[str]) -> List[str]`. This is much faster than sequential `llm_query` calls when you have multiple independent queries. Results are returned in the same order as the input prompts.
50
+ 4. The ability to use `print()` statements to view the output of your REPL code and continue your reasoning.
51
+
52
+ You will only be able to see truncated outputs from the REPL environment, so you should use the query LLM function on variables you want to analyze. You will find this function especially useful when you have to analyze the semantics of the context. Use these variables as buffers to build up your final answer.
53
+ Make sure to explicitly look through the entire context in REPL before answering your query. An example strategy is to first look at the context and figure out a chunking strategy, then break up the context into smart chunks, and query an LLM per chunk with a particular question and save the answers to a buffer, then query an LLM with all the buffers to produce your final answer.
54
+
55
+ You can use the REPL environment to help you understand your context, especially if it is huge. Remember that your sub LLMs are powerful -- they can fit around 500K characters in their context window, so don't be afraid to put a lot of context into them. For example, a viable strategy is to feed 10 documents per sub-LLM query. Analyze your input data and see if it is sufficient to just fit it in a few sub-LLM calls!
56
+
57
+ When you want to execute Python code in the REPL environment, wrap it in triple backticks with 'repl' language identifier. For example, say we want our recursive model to search for the magic number in the context (assuming the context is a string), and the context is very long, so we want to chunk it:
58
+ ```repl
59
+ chunk = context[:10000]
60
+ answer = llm_query(f"What is the magic number in the context? Here is the chunk: {{chunk}}")
61
+ print(answer)
62
+ ```
63
+
64
+ As an example, suppose you're trying to answer a question about a book. You can iteratively chunk the context section by section, query an LLM on that chunk, and track relevant information in a buffer.
65
+ ```repl
66
+ query = "In Harry Potter and the Sorcerer's Stone, did Gryffindor win the House Cup because they led?"
67
+ for i, section in enumerate(context):
68
+ if i == len(context) - 1:
69
+ buffer = llm_query(f"You are on the last section of the book. So far you know that: {{buffers}}. Gather from this last section to answer {{query}}. Here is the section: {{section}}")
70
+ print(f"Based on reading iteratively through the book, the answer is: {{buffer}}")
71
+ else:
72
+ buffer = llm_query(f"You are iteratively looking through a book, and are on section {{i}} of {{len(context)}}. Gather information to help answer {{query}}. Here is the section: {{section}}")
73
+ print(f"After section {{i}} of {{len(context)}}, you have tracked: {{buffer}}")
74
+ ```
75
+
76
+ As another example, when the context isn't that long (e.g. <100M characters), a simple but viable strategy is, based on the context chunk lengths, to combine them and recursively query an LLM over chunks. For example, if the context is a List[str], we ask the same query over each chunk using `llm_query_batched` for concurrent processing:
77
+ ```repl
78
+ query = 'A man became famous for his book "The Great Gatsby". How many jobs did he have?'
79
+ # Suppose our context is ~1M chars, and we want each sub-LLM query to be ~0.1M chars so we split it into 10 chunks
80
+ chunk_size = len(context) // 10
81
+ chunks = []
82
+ for i in range(10):
83
+ if i < 9:
84
+ chunk_str = "\\n".join(context[i*chunk_size:(i+1)*chunk_size])
85
+ else:
86
+ chunk_str = "\\n".join(context[i*chunk_size:])
87
+ chunks.append(chunk_str)
88
+
89
+ # Use batched query for concurrent processing - much faster than sequential calls!
90
+ prompts = [f"Try to answer the following query: {{query}}. Here are the documents:\\n{{chunk}}. Only answer if you are confident in your answer based on the evidence." for chunk in chunks]
91
+ answers = llm_query_batched(prompts)
92
+ for i, answer in enumerate(answers):
93
+ print(f"I got the answer from chunk {{i}}: {{answer}}")
94
+ final_answer = llm_query(f"Aggregating all the answers per chunk, answer the original query about total number of jobs: {{query}}\\n\\nAnswers:\\n" + "\\n".join(answers))
95
+ ```
96
+
97
+ As a final example, after analyzing the context and realizing it's separated by Markdown headers, we can maintain state through buffers by chunking the context by headers, and iteratively querying an LLM over it:
98
+ ```repl
99
+ # After finding out the context is separated by Markdown headers, we can chunk, summarize, and answer
100
+ import re
101
+ sections = re.split(r'### (.+)', context["content"])
102
+ buffers = []
103
+ for i in range(1, len(sections), 2):
104
+ header = sections[i]
105
+ info = sections[i+1]
106
+ summary = llm_query(f"Summarize this {{header}} section: {{info}}")
107
+ buffers.append(f"{{header}}: {{summary}}")
108
+ final_answer = llm_query(f"Based on these summaries, answer the original query: {{query}}\\n\\nSummaries:\\n" + "\\n".join(buffers))
109
+ ```
110
+ In the next step, we can return FINAL_VAR("final_answer").
111
+
112
+ IMPORTANT: When you are done with the iterative process, you MUST provide a final answer using one of the FINAL functions. Do not use these unless you have completed your task. You have two options:
113
+ 1. Use FINAL(value) to provide the answer directly, e.g., FINAL(42) or FINAL(my_variable)
114
+ 2. Use FINAL_VAR("variable_name") to return a variable by name, e.g., FINAL_VAR("final_answer")
115
+
116
+ Think step by step carefully, plan, and execute this plan immediately in your response -- do not just say "I will do this" or "I will do that". Output to the REPL environment and recursive LLMs as much as possible. Remember to explicitly answer the original query in your final answer.
117
+ """
118
+ )
119
+
120
+
121
+ # =============================================================================
122
+ # System Prompt for Qwen3-Coder-480B (with IMPORTANT cost warning from paper)
123
+ # Adds cost warning after the "sub LLMs are powerful" paragraph
124
+ # =============================================================================
125
+
126
+ RLM_SYSTEM_PROMPT_QWEN = textwrap.dedent(
127
+ """You are tasked with answering a query with associated context. You can access, transform, and analyze this context interactively in a REPL environment that can recursively query sub-LLMs, which you are strongly encouraged to use as much as possible. You will be queried iteratively until you provide a final answer.
128
+
129
+ The REPL environment is initialized with:
130
+ 1. A `context` variable that contains extremely important information about your query. You should check the content of the `context` variable to understand what you are working with. Make sure you look through it sufficiently as you answer your query.
131
+ 2. A `llm_query` function that allows you to query an LLM (that can handle around 500K chars) inside your REPL environment.
132
+ 3. A `llm_query_batched` function that allows you to query multiple prompts concurrently: `llm_query_batched(prompts: List[str]) -> List[str]`. This is much faster than sequential `llm_query` calls when you have multiple independent queries. Results are returned in the same order as the input prompts.
133
+ 4. The ability to use `print()` statements to view the output of your REPL code and continue your reasoning.
134
+
135
+ You will only be able to see truncated outputs from the REPL environment, so you should use the query LLM function on variables you want to analyze. You will find this function especially useful when you have to analyze the semantics of the context. Use these variables as buffers to build up your final answer.
136
+ Make sure to explicitly look through the entire context in REPL before answering your query. An example strategy is to first look at the context and figure out a chunking strategy, then break up the context into smart chunks, and query an LLM per chunk with a particular question and save the answers to a buffer, then query an LLM with all the buffers to produce your final answer.
137
+
138
+ You can use the REPL environment to help you understand your context, especially if it is huge. Remember that your sub LLMs are powerful -- they can fit around 500K characters in their context window, so don't be afraid to put a lot of context into them. For example, a viable strategy is to feed 10 documents per sub-LLM query. Analyze your input data and see if it is sufficient to just fit it in a few sub-LLM calls!
139
+
140
+ IMPORTANT: Be very careful about using 'llm_query' as it incurs high runtime costs. Always batch as much information as reasonably possible into each call (aim for around ~200k characters per call). For example, if you have 1000 lines of information to process, it's much better to split into chunks of 5 and call 'llm_query' on each chunk (200 calls total) rather than making 1000 individual calls. Minimize the number of 'llm_query' calls by batching related information together.
141
+
142
+ When you want to execute Python code in the REPL environment, wrap it in triple backticks with 'repl' language identifier. For example, say we want our recursive model to search for the magic number in the context (assuming the context is a string), and the context is very long, so we want to chunk it:
143
+ ```repl
144
+ chunk = context[:10000]
145
+ answer = llm_query(f"What is the magic number in the context? Here is the chunk: {{chunk}}")
146
+ print(answer)
147
+ ```
148
+
149
+ As an example, suppose you're trying to answer a question about a book. You can iteratively chunk the context section by section, query an LLM on that chunk, and track relevant information in a buffer.
150
+ ```repl
151
+ query = "In Harry Potter and the Sorcerer's Stone, did Gryffindor win the House Cup because they led?"
152
+ for i, section in enumerate(context):
153
+ if i == len(context) - 1:
154
+ buffer = llm_query(f"You are on the last section of the book. So far you know that: {{buffers}}. Gather from this last section to answer {{query}}. Here is the section: {{section}}")
155
+ print(f"Based on reading iteratively through the book, the answer is: {{buffer}}")
156
+ else:
157
+ buffer = llm_query(f"You are iteratively looking through a book, and are on section {{i}} of {{len(context)}}. Gather information to help answer {{query}}. Here is the section: {{section}}")
158
+ print(f"After section {{i}} of {{len(context)}}, you have tracked: {{buffer}}")
159
+ ```
160
+
161
+ As another example, when the context isn't that long (e.g. <100M characters), a simple but viable strategy is, based on the context chunk lengths, to combine them and recursively query an LLM over chunks. For example, if the context is a List[str], we ask the same query over each chunk using `llm_query_batched` for concurrent processing:
162
+ ```repl
163
+ query = "A man became famous for his book 'The Great Gatsby'. How many jobs did he have?"
164
+ # Suppose our context is ~1M chars, and we want each sub-LLM query to be ~0.1M chars so we split it into 10 chunks
165
+ chunk_size = len(context) // 10
166
+ chunks = []
167
+ for i in range(10):
168
+ if i < 9:
169
+ chunk_str = "\\n".join(context[i*chunk_size:(i+1)*chunk_size])
170
+ else:
171
+ chunk_str = "\\n".join(context[i*chunk_size:])
172
+ chunks.append(chunk_str)
173
+
174
+ # Use batched query for concurrent processing - much faster than sequential calls!
175
+ prompts = [f"Try to answer the following query: {{query}}. Here are the documents:\\n{{chunk}}. Only answer if you are confident in your answer based on the evidence." for chunk in chunks]
176
+ answers = llm_query_batched(prompts)
177
+ for i, answer in enumerate(answers):
178
+ print(f"I got the answer from chunk {{i}}: {{answer}}")
179
+ final_answer = llm_query(f"Aggregating all the answers per chunk, answer the original query about total number of jobs: {{query}}\\n\\nAnswers:\\n" + "\\n".join(answers))
180
+ ```
181
+
182
+ As a final example, after analyzing the context and realizing it's separated by Markdown headers, we can maintain state through buffers by chunking the context by headers, and iteratively querying an LLM over it:
183
+ ```repl
184
+ # After finding out the context is separated by Markdown headers, we can chunk, summarize, and answer
185
+ import re
186
+ sections = re.split(r'### (.+)', context["content"])
187
+ buffers = []
188
+ for i in range(1, len(sections), 2):
189
+ header = sections[i]
190
+ info = sections[i+1]
191
+ summary = llm_query(f"Summarize this {{header}} section: {{info}}")
192
+ buffers.append(f"{{header}}: {{summary}}")
193
+ final_answer = llm_query(f"Based on these summaries, answer the original query: {{query}}\\n\\nSummaries:\\n" + "\\n".join(buffers))
194
+ ```
195
+ In the next step, we can return FINAL_VAR("final_answer").
196
+
197
+ IMPORTANT: When you are done with the iterative process, you MUST provide a final answer using one of the FINAL functions. Do not use these unless you have completed your task. You have two options:
198
+ 1. Use FINAL(value) to provide the answer directly, e.g., FINAL(42) or FINAL(my_variable)
199
+ 2. Use FINAL_VAR("variable_name") to return a variable by name, e.g., FINAL_VAR("final_answer")
200
+
201
+ Think step by step carefully, plan, and execute this plan immediately in your response -- do not just say "I will do this" or "I will do that". Output to the REPL environment and recursive LLMs as much as possible. Remember to explicitly answer the original query in your final answer.
202
+ """
203
+ )
204
+
205
+
206
+ # =============================================================================
207
+ # User Prompt Templates (from official RLM repo)
208
+ # =============================================================================
209
+
210
+ USER_PROMPT = """Think step-by-step on what to do using the REPL environment (which contains the context) to answer the prompt.\n\nContinue using the REPL environment, which has the `context` variable, and querying sub-LLMs by writing to ```repl``` tags, and determine your answer. Your next action:"""
211
+
212
+ USER_PROMPT_WITH_ROOT = """Think step-by-step on what to do using the REPL environment (which contains the context) to answer the original prompt: \"{root_prompt}\".\n\nContinue using the REPL environment, which has the `context` variable, and querying sub-LLMs by writing to ```repl``` tags, and determine your answer. Your next action:"""
213
+
214
+
215
+ # =============================================================================
216
+ # Prompt Building Functions (from official RLM repo)
217
+ # =============================================================================
218
+
219
+
220
+ def build_rlm_system_prompt(
221
+ system_prompt: str,
222
+ query_metadata: QueryMetadata,
223
+ ) -> List[dict]:
224
+ """
225
+ Build the initial system prompt for the REPL environment based on extra prompt metadata.
226
+
227
+ Args:
228
+ system_prompt: The system prompt to use
229
+ query_metadata: QueryMetadata object containing context metadata
230
+
231
+ Returns:
232
+ List of message dictionaries [system, assistant(metadata)]
233
+ """
234
+ context_lengths = query_metadata.context_lengths
235
+ context_total_length = query_metadata.context_total_length
236
+ context_type = query_metadata.context_type
237
+
238
+ # If there are more than 100 chunks, truncate to the first 100 chunks.
239
+ if len(context_lengths) > 100:
240
+ others = len(context_lengths) - 100
241
+ context_lengths_str = (
242
+ str(context_lengths[:100]) + "... [" + str(others) + " others]"
243
+ )
244
+ else:
245
+ context_lengths_str = str(context_lengths)
246
+
247
+ metadata_prompt = f"Your context is a {context_type} with {context_total_length} total characters, and is broken up into chunks of char lengths: {context_lengths_str}."
248
+
249
+ return [
250
+ {"role": "system", "content": system_prompt},
251
+ {"role": "assistant", "content": metadata_prompt},
252
+ ]
253
+
254
+
255
+ def build_user_prompt(
256
+ root_prompt: Optional[str] = None,
257
+ iteration: int = 0,
258
+ context_count: int = 1,
259
+ history_count: int = 0,
260
+ ) -> dict:
261
+ """
262
+ Build the user prompt for a given iteration.
263
+
264
+ Args:
265
+ root_prompt: The original query/task
266
+ iteration: Current iteration number (0 = first)
267
+ context_count: Number of context variables available
268
+ history_count: Number of prior conversation histories
269
+
270
+ Returns:
271
+ User message dict
272
+ """
273
+ if iteration == 0:
274
+ safeguard = "You have not interacted with the REPL environment or seen your prompt / context yet. Your next action should be to look through and figure out how to answer the prompt, so don't just provide a final answer yet.\n\n"
275
+ prompt = safeguard + (
276
+ USER_PROMPT_WITH_ROOT.format(root_prompt=root_prompt)
277
+ if root_prompt
278
+ else USER_PROMPT
279
+ )
280
+ else:
281
+ prompt = (
282
+ "The history before is your previous interactions with the REPL environment. "
283
+ + (
284
+ USER_PROMPT_WITH_ROOT.format(root_prompt=root_prompt)
285
+ if root_prompt
286
+ else USER_PROMPT
287
+ )
288
+ )
289
+
290
+ # Inform model about multiple contexts if present
291
+ if context_count > 1:
292
+ prompt += f"\n\nNote: You have {context_count} contexts available (context_0 through context_{context_count - 1})."
293
+
294
+ # Inform model about prior conversation histories if present
295
+ if history_count > 0:
296
+ if history_count == 1:
297
+ prompt += "\n\nNote: You have 1 prior conversation history available in the `history` variable."
298
+ else:
299
+ prompt += f"\n\nNote: You have {history_count} prior conversation histories available (history_0 through history_{history_count - 1})."
300
+
301
+ return {"role": "user", "content": prompt}
302
+
303
+
304
+ # =============================================================================
305
+ # Convenience Functions (for backward compatibility)
306
+ # =============================================================================
307
+
308
+
309
+ def build_initial_prompt(
310
+ task_prompt: str,
311
+ context_length: int,
312
+ context_preview: Optional[str] = None,
313
+ variables: Optional[List[str]] = None,
314
+ **kwargs,
315
+ ) -> str:
316
+ """Build the initial user prompt (convenience wrapper).
317
+
318
+ Args:
319
+ task_prompt: The task to accomplish
320
+ context_length: Total length of the context
321
+ context_preview: Preview of the context (not used)
322
+ variables: List of available variable names (not used)
323
+
324
+ Returns:
325
+ Formatted initial prompt string
326
+ """
327
+ return build_user_prompt(root_prompt=task_prompt, iteration=0)["content"]
328
+
329
+
330
+ # =============================================================================
331
+ # Parsing Utilities
332
+ # =============================================================================
333
+
334
+
335
+ def extract_code_blocks(text: str, language: str = "python") -> List[str]:
336
+ """Extract code blocks from LLM response.
337
+
338
+ Supports both ```repl``` (official RLM) and ```python``` style blocks.
339
+
340
+ Args:
341
+ text: The LLM response text
342
+ language: Language identifier to match (default "python")
343
+
344
+ Returns:
345
+ List of code strings extracted from the response
346
+ """
347
+ # Match 'repl' (official) and 'python' (common alternative)
348
+ patterns = [
349
+ r"```repl\s*(.*?)```",
350
+ rf"```{language}\s*(.*?)```",
351
+ ]
352
+
353
+ all_matches = []
354
+ for pattern in patterns:
355
+ matches = re.findall(pattern, text, re.DOTALL)
356
+ all_matches.extend(m.strip() for m in matches if m.strip())
357
+
358
+ return all_matches
359
+
360
+
361
+ def format_observations(observations) -> str:
362
+ """Format REPL observations into observation text for the LLM.
363
+
364
+ Args:
365
+ observations: List of REPL observations from env.step()
366
+
367
+ Returns:
368
+ Formatted observation string
369
+ """
370
+ formatted = []
371
+ for i, observation in enumerate(observations, 1):
372
+ output = (
373
+ observation.result.stdout.strip()
374
+ if observation.result.stdout
375
+ else "(no output)"
376
+ )
377
+ if observation.result.success:
378
+ formatted.append(f"Code block {i} output:\n{output}")
379
+ else:
380
+ error = (
381
+ observation.result.stderr
382
+ or observation.result.exception
383
+ or "Unknown error"
384
+ )
385
+ formatted.append(
386
+ f"Code block {i} output:\n{output}\n\nERROR: {error}\n"
387
+ f"Fix the error in code block {i}. Remember: 'context' is already defined."
388
+ )
389
+ return "\n\n".join(formatted)
envs/repl_env/pyproject.toml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=45", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "openenv-repl"
13
+ version = "0.1.0"
14
+ description = "Recursive Language Model REPL Environment for OpenEnv"
15
+ requires-python = ">=3.10"
16
+ dependencies = [
17
+ # Core OpenEnv dependencies (required for server functionality)
18
+ "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@main",
19
+ "fastapi>=0.115.0",
20
+ "pydantic>=2.0.0",
21
+ "uvicorn>=0.24.0",
22
+ "requests>=2.31.0",
23
+ # Environment-specific dependencies
24
+ "smolagents>=1.22.0,<2",
25
+ # LLM support via HuggingFace Inference API
26
+ "huggingface_hub>=0.20.0",
27
+ ]
28
+
29
+ [project.optional-dependencies]
30
+ dev = [
31
+ "pytest>=8.0.0",
32
+ "pytest-cov>=4.0.0",
33
+ ]
34
+
35
+ [project.scripts]
36
+ # Server entry point - enables running via: uv run --project . server
37
+ # or: python -m repl_env.server.app
38
+ server = "repl_env.server.app:main"
39
+
40
+ [tool.setuptools]
41
+ # Explicitly list packages - "repl_env" maps to current dir
42
+ packages = ["repl_env", "repl_env.server"]
43
+ package-dir = {"repl_env" = ".", "repl_env.server" = "server"}
envs/repl_env/server/Dockerfile ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Multi-stage build using openenv-base
8
+ # This Dockerfile is flexible and works for both:
9
+ # - In-repo environments (with local src/core)
10
+ # - Standalone environments (with openenv from pip)
11
+ # The build script (openenv build) handles context detection and sets appropriate build args.
12
+
13
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
14
+ FROM ${BASE_IMAGE} AS builder
15
+
16
+ WORKDIR /app
17
+
18
+ # Build argument to control whether we're building standalone or in-repo
19
+ ARG BUILD_MODE=in-repo
20
+ ARG ENV_NAME=repl_env
21
+
22
+ # Copy environment code (always at root of build context)
23
+ COPY . /app/env
24
+
25
+ # For in-repo builds, openenv-core is already in the pyproject.toml dependencies
26
+ # For standalone builds, openenv-core will be installed from pip via pyproject.toml
27
+ WORKDIR /app/env
28
+
29
+ # Ensure uv is available (for local builds where base image lacks it)
30
+ RUN if ! command -v uv >/dev/null 2>&1; then \
31
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
32
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
33
+ mv /root/.local/bin/uvx /usr/local/bin/uvx; \
34
+ fi
35
+
36
+ # Install git for building from git repos (build-time only)
37
+ RUN apt-get update && apt-get install -y --no-install-recommends \
38
+ git \
39
+ && rm -rf /var/lib/apt/lists/*
40
+
41
+ # Install dependencies using uv sync
42
+ # If uv.lock exists, use it; otherwise resolve on the fly
43
+ RUN --mount=type=cache,target=/root/.cache/uv \
44
+ if [ -f uv.lock ]; then \
45
+ uv sync --frozen --no-install-project --no-editable; \
46
+ else \
47
+ uv sync --no-install-project --no-editable; \
48
+ fi
49
+
50
+ RUN --mount=type=cache,target=/root/.cache/uv \
51
+ if [ -f uv.lock ]; then \
52
+ uv sync --frozen --no-editable; \
53
+ else \
54
+ uv sync --no-editable; \
55
+ fi
56
+
57
+ # Final runtime stage
58
+ FROM ${BASE_IMAGE}
59
+
60
+ WORKDIR /app
61
+
62
+ # Copy the virtual environment from builder
63
+ COPY --from=builder /app/env/.venv /app/.venv
64
+
65
+ # Copy the environment code
66
+ COPY --from=builder /app/env /app/env
67
+
68
+ # Set PATH to use the virtual environment
69
+ ENV PATH="/app/.venv/bin:$PATH"
70
+
71
+ # Set PYTHONPATH so imports work correctly
72
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
73
+
74
+ # Health check using Python (more portable than curl/wget)
75
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
76
+ CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
77
+
78
+ # Run the FastAPI server
79
+ # The module path is constructed to work with the /app/env structure
80
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
envs/repl_env/server/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ REPL Environment Server Components.
9
+
10
+ This module contains the server-side implementation of the REPL environment.
11
+ """
12
+
13
+ from .python_executor import PythonExecutor
14
+ from .repl_environment import REPLEnvironment
15
+
16
+ __all__ = [
17
+ "REPLEnvironment",
18
+ "PythonExecutor",
19
+ ]
envs/repl_env/server/app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ FastAPI application for the REPL Environment.
9
+
10
+ This module creates an HTTP server that exposes the REPLEnvironment
11
+ over HTTP and WebSocket endpoints, compatible with EnvClient.
12
+
13
+ The server includes llm_query and llm_query_batched support via HuggingFace Inference API,
14
+ enabling the Recursive Language Model (RLM) paradigm.
15
+
16
+ LLM Token Configuration:
17
+ 1. Client can pass `hf_token` in reset() - RECOMMENDED
18
+ 2. Server fallback: HF_TOKEN environment variable
19
+
20
+ LLM functions are created dynamically in REPLEnvironment.reset() based on the
21
+ available token (client or server).
22
+
23
+ Usage:
24
+ # Development (with auto-reload):
25
+ uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
26
+
27
+ # Production:
28
+ uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
29
+
30
+ # Or run directly:
31
+ uv run --project . server
32
+
33
+ Environment Variables:
34
+ HF_TOKEN: Fallback HuggingFace API token (client token takes priority)
35
+ LLM_MODEL: Model to use for llm_query/llm_query_batched (default: Qwen/Qwen3-Coder-480B-A35B-Instruct)
36
+ """
37
+
38
+ import os
39
+
40
+ # Support both in-repo and standalone imports
41
+ try:
42
+ # In-repo imports (when running from OpenEnv repository)
43
+ from openenv.core.env_server.http_server import create_app
44
+
45
+ from ..models import REPLAction, REPLObservation
46
+ from .repl_environment import REPLEnvironment
47
+ except ImportError:
48
+ from models import REPLAction, REPLObservation
49
+
50
+ # Standalone imports (when environment is standalone with openenv from pip)
51
+ from openenv.core.env_server.http_server import create_app
52
+ from server.repl_environment import REPLEnvironment
53
+
54
+
55
+ # ============== LLM CONFIGURATION ==============
56
+ LLM_MODEL = os.environ.get("LLM_MODEL", "Qwen/Qwen3-Coder-480B-A35B-Instruct")
57
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
58
+ # ===============================================
59
+
60
+ # Log LLM configuration
61
+ if HF_TOKEN:
62
+ print(f"[REPL Server] LLM support ENABLED (server token configured)")
63
+ print(f"[REPL Server] Default model: {LLM_MODEL}")
64
+ else:
65
+ print("[REPL Server] No server HF_TOKEN configured")
66
+ print(
67
+ "[REPL Server] LLM functions will be enabled if client passes hf_token in reset()"
68
+ )
69
+
70
+ # Simple factory - LLM functions are created dynamically in reset() based on token
71
+ env_factory = REPLEnvironment
72
+
73
+ # Create the app with web interface and README integration
74
+ app = create_app(env_factory, REPLAction, REPLObservation, env_name="repl_env")
75
+
76
+
77
+ def main():
78
+ """
79
+ Entry point for direct execution via uv run or python -m.
80
+
81
+ This function enables running the server without Docker:
82
+ uv run --project . server
83
+ python -m envs.repl_env.server.app
84
+ openenv serve repl_env
85
+ """
86
+ import uvicorn
87
+
88
+ uvicorn.run(app, host="0.0.0.0", port=8000)
89
+
90
+
91
+ if __name__ == "__main__":
92
+ main()
envs/repl_env/server/python_executor.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Sandboxed Python code executor for the REPL environment.
9
+
10
+ Uses smolagents.LocalPythonExecutor as the backend for battle-tested sandboxed
11
+ execution, with RLM-specific features on top:
12
+ - Context loading (set_context)
13
+ - Variable access (get_variable, list_variables)
14
+ - Function injection (inject_function for llm_query, llm_query_batched)
15
+ - Output capped at 8,192 characters per turn (configurable)
16
+ - Persistent namespace across code blocks
17
+ """
18
+
19
+ import json
20
+ import logging
21
+ import time
22
+ import traceback
23
+ from collections.abc import Callable
24
+ from typing import Any, Dict, List, Optional
25
+
26
+ from smolagents import LocalPythonExecutor
27
+
28
+ logger = logging.getLogger(__name__)
29
+ logger.addHandler(logging.NullHandler())
30
+
31
+
32
+ class PythonExecutor:
33
+ """Sandboxed Python code executor with persistent namespace.
34
+
35
+ Wraps smolagents.LocalPythonExecutor with RLM-specific features:
36
+ - Context loading for RLM tasks
37
+ - Variable tracking for observation
38
+ - Function injection for llm_query, llm_query_batched
39
+ - Configurable output length limit (default 8192 chars per Prime Intellect)
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ max_output_length: int = 8192,
45
+ allowed_imports: Optional[List[str]] = None,
46
+ ):
47
+ """Initialize the executor.
48
+
49
+ Args:
50
+ max_output_length: Maximum characters for stdout/stderr (default 8192)
51
+ allowed_imports: List of allowed module names for import
52
+
53
+ Note:
54
+ smolagents.LocalPythonExecutor does NOT support wall-clock timeouts.
55
+ Instead, it limits operations (10M ops) and while iterations (1M).
56
+ """
57
+ self.max_output_length = max_output_length
58
+
59
+ # Default allowed imports for RLM tasks
60
+ default_imports = [
61
+ "re",
62
+ "json",
63
+ "math",
64
+ "random",
65
+ "collections",
66
+ "itertools",
67
+ "functools",
68
+ "operator",
69
+ "string",
70
+ "textwrap",
71
+ "difflib",
72
+ "statistics",
73
+ "decimal",
74
+ "fractions",
75
+ "datetime",
76
+ "copy",
77
+ "pprint",
78
+ "typing",
79
+ "dataclasses",
80
+ "enum",
81
+ "bisect",
82
+ "heapq",
83
+ "array",
84
+ "struct",
85
+ "base64",
86
+ "hashlib",
87
+ "hmac",
88
+ "uuid",
89
+ ]
90
+
91
+ self.allowed_imports = allowed_imports or default_imports
92
+
93
+ # Initialize the smolagents executor
94
+ self._executor = LocalPythonExecutor(
95
+ additional_authorized_imports=self.allowed_imports
96
+ )
97
+
98
+ # Track variables we've set (for list_variables)
99
+ self._user_variables: set[str] = set()
100
+
101
+ # Track callable functions to register with send_tools
102
+ self._callable_tools: Dict[str, Callable[..., Any]] = {}
103
+
104
+ # Register helper utilities
105
+ self._register_helpers()
106
+
107
+ def _register_helpers(self) -> None:
108
+ """Register helper functions with the executor."""
109
+ helpers = {
110
+ "format_exc": traceback.format_exc,
111
+ "safe_json_dumps": lambda obj: json.dumps(obj, default=lambda o: repr(o)),
112
+ }
113
+ # Register helpers as callable tools
114
+ for name, func in helpers.items():
115
+ self.inject_function(name, func)
116
+
117
+ def _sync_callable_tools(self) -> None:
118
+ """Sync callable functions with the executor via send_tools."""
119
+ if self._callable_tools:
120
+ try:
121
+ # Type ignore: smolagents accepts callables despite Tool type hint
122
+ self._executor.send_tools(self._callable_tools) # type: ignore[arg-type]
123
+ except Exception:
124
+ logger.debug(
125
+ "send_tools failed; continuing without extra tools",
126
+ exc_info=True,
127
+ )
128
+
129
+ def set_context(self, context: str, variable_name: str = "context") -> None:
130
+ """Load context into namespace as a variable.
131
+
132
+ Args:
133
+ context: The context string to load
134
+ variable_name: Name of the variable (default "context")
135
+ """
136
+ self.set_variable(variable_name, context)
137
+
138
+ def set_variable(self, name: str, value: Any) -> None:
139
+ """Set a variable in the namespace.
140
+
141
+ Args:
142
+ name: Variable name
143
+ value: Variable value
144
+ """
145
+ # Access the executor's internal state to set variables
146
+ if hasattr(self._executor, "state"):
147
+ self._executor.state[name] = value
148
+ else:
149
+ # Fallback: store in injected vars for later retrieval
150
+ self._executor._injected_vars = getattr(
151
+ self._executor, "_injected_vars", {}
152
+ )
153
+ self._executor._injected_vars[name] = value
154
+
155
+ self._user_variables.add(name)
156
+
157
+ def get_variable(self, name: str) -> Optional[Any]:
158
+ """Retrieve a variable from namespace.
159
+
160
+ Args:
161
+ name: Variable name
162
+
163
+ Returns:
164
+ The variable value or None if not found
165
+ """
166
+ # Try to get from executor's state
167
+ if hasattr(self._executor, "state"):
168
+ return self._executor.state.get(name)
169
+
170
+ # Fallback to injected vars
171
+ if hasattr(self._executor, "_injected_vars"):
172
+ return self._executor._injected_vars.get(name)
173
+
174
+ return None
175
+
176
+ def list_variables(self) -> List[str]:
177
+ """List non-private variables in namespace.
178
+
179
+ Returns:
180
+ List of variable names (excluding private and builtins)
181
+ """
182
+ variables = set()
183
+
184
+ # Get from executor's state
185
+ if hasattr(self._executor, "state"):
186
+ for key in self._executor.state:
187
+ if not key.startswith("_"):
188
+ variables.add(key)
189
+
190
+ # Include tracked user variables
191
+ variables.update(self._user_variables)
192
+
193
+ return list(variables)
194
+
195
+ def execute(self, code: str) -> Dict[str, Any]:
196
+ """Execute Python code and return results.
197
+
198
+ Args:
199
+ code: Python code to execute
200
+
201
+ Returns:
202
+ Dictionary with stdout, stderr, locals_snapshot, execution_time,
203
+ success, and exception fields
204
+ """
205
+ start_time = time.time()
206
+ success = True
207
+ exception_msg = None
208
+ new_locals: Dict[str, str] = {}
209
+
210
+ # Track state before execution
211
+ pre_state_keys = set()
212
+ if hasattr(self._executor, "state"):
213
+ pre_state_keys = set(self._executor.state.keys())
214
+
215
+ stdout_parts: list[str] = []
216
+ stderr_parts: list[str] = []
217
+
218
+ try:
219
+ exec_result = self._executor(code)
220
+
221
+ # Extract logs/prints
222
+ try:
223
+ logs = getattr(exec_result, "logs", None)
224
+ if logs:
225
+ stdout_parts.append(str(logs))
226
+ except Exception:
227
+ logger.debug("Failed to read exec_result.logs", exc_info=True)
228
+
229
+ # Extract the result / output value
230
+ try:
231
+ if hasattr(exec_result, "output"):
232
+ out_val = exec_result.output
233
+ if out_val is not None:
234
+ try:
235
+ stdout_parts.append(json.dumps(out_val))
236
+ except Exception:
237
+ stdout_parts.append(repr(out_val))
238
+ except Exception:
239
+ logger.debug("Failed to read exec_result.output", exc_info=True)
240
+
241
+ # Check for errors
242
+ try:
243
+ err = getattr(exec_result, "error", None)
244
+ if err:
245
+ stderr_parts.append(str(err))
246
+ success = False
247
+ exception_msg = str(err)
248
+ except Exception:
249
+ logger.debug("Failed to read exec_result.error", exc_info=True)
250
+
251
+ try:
252
+ ex = getattr(exec_result, "exception", None)
253
+ if ex:
254
+ stderr_parts.append(str(ex))
255
+ success = False
256
+ exception_msg = str(ex)
257
+ except Exception:
258
+ logger.debug("Failed to read exec_result.exception", exc_info=True)
259
+
260
+ # Determine success from exit_code if available
261
+ try:
262
+ if hasattr(exec_result, "exit_code"):
263
+ if exec_result.exit_code is not None and exec_result.exit_code != 0:
264
+ success = False
265
+ elif hasattr(exec_result, "success"):
266
+ success = bool(exec_result.success)
267
+ except Exception:
268
+ logger.debug("Failed to determine exec_result exit code", exc_info=True)
269
+
270
+ except Exception as e:
271
+ success = False
272
+ exception_msg = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
273
+ stderr_parts.append(exception_msg)
274
+
275
+ execution_time = time.time() - start_time
276
+
277
+ # Capture new/modified variables
278
+ if hasattr(self._executor, "state"):
279
+ for key in self._executor.state:
280
+ if key not in pre_state_keys and not key.startswith("_"):
281
+ try:
282
+ val = self._executor.state[key]
283
+ val_repr = repr(val)
284
+ if len(val_repr) > 500:
285
+ val_repr = val_repr[:500] + "..."
286
+ new_locals[key] = val_repr
287
+ self._user_variables.add(key)
288
+ except Exception:
289
+ new_locals[key] = "<unrepresentable>"
290
+
291
+ # Compose stdout/stderr
292
+ stdout = "\n".join(part for part in stdout_parts if part)
293
+ stderr = "\n".join(part for part in stderr_parts if part)
294
+
295
+ # Truncate output to max_output_length
296
+ if len(stdout) > self.max_output_length:
297
+ stdout = (
298
+ stdout[: self.max_output_length]
299
+ + f"\n... (truncated, total {len(stdout)} chars)"
300
+ )
301
+
302
+ if len(stderr) > self.max_output_length:
303
+ stderr = (
304
+ stderr[: self.max_output_length]
305
+ + f"\n... (truncated, total {len(stderr)} chars)"
306
+ )
307
+
308
+ return {
309
+ "stdout": stdout,
310
+ "stderr": stderr,
311
+ "locals_snapshot": new_locals,
312
+ "execution_time": execution_time,
313
+ "success": success,
314
+ "exception": exception_msg,
315
+ }
316
+
317
+ def reset(self) -> None:
318
+ """Reset namespace to initial state."""
319
+ # Create a new executor instance
320
+ self._executor = LocalPythonExecutor(
321
+ additional_authorized_imports=self.allowed_imports
322
+ )
323
+ self._user_variables.clear()
324
+ self._callable_tools.clear()
325
+ self._register_helpers()
326
+
327
+ def inject_function(self, name: str, func: Callable[..., Any]) -> None:
328
+ """Inject a callable function into the namespace.
329
+
330
+ Used for adding llm_query, llm_query_batched, FINAL, etc.
331
+
332
+ Args:
333
+ name: Function name in namespace
334
+ func: The callable to inject
335
+ """
336
+ # Add to callable tools and sync with executor
337
+ self._callable_tools[name] = func
338
+ self._user_variables.add(name)
339
+ self._sync_callable_tools()
envs/repl_env/server/repl_environment.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ REPL Environment Implementation.
9
+
10
+ A Python REPL environment for training language models on code execution tasks,
11
+ based on the Recursive Language Models (RLM) paradigm.
12
+
13
+ References:
14
+ - RLM Paper: https://arxiv.org/abs/2512.24601
15
+ - Prime Intellect Blog: https://www.primeintellect.ai/blog/rlm
16
+ - Alex Zhang Blog: https://alexzhang13.github.io/blog/2025/rlm/
17
+ """
18
+
19
+ import os
20
+ import re
21
+ from collections.abc import Callable
22
+ from typing import Any, Dict, List, Optional
23
+ from uuid import uuid4
24
+
25
# Support both in-repo and standalone imports.
# NOTE: the original try/except ImportError here was a copy-paste defect —
# both branches contained byte-identical imports, so the fallback could never
# resolve anything the primary import did not. Both layouts expose these
# names via the ``openenv`` package, so a single import suffices.
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import EnvironmentMetadata
32
+
33
+ try:
34
+ from ..models import CodeBlockResult, REPLAction, REPLObservation, REPLState
35
+ except ImportError:
36
+ from models import CodeBlockResult, REPLAction, REPLObservation, REPLState
37
+
38
+ try:
39
+ from .python_executor import PythonExecutor
40
+ except ImportError:
41
+ from python_executor import PythonExecutor
42
+
43
+
44
class REPLEnvironment(Environment):
    """
    A REPL environment for training language models to use code execution.

    Based on the Recursive Language Models (RLM) paradigm, this environment
    allows language models to:
    - Execute Python code in a sandboxed REPL
    - Work with large contexts loaded as variables
    - Finalize answers via FINAL(), FINAL_VAR(), or the answer-dict pattern
    - Optionally make recursive LLM calls via llm_query() / llm_query_batched()

    Supports two finalization patterns:
    1. RLM-style: print('FINAL(answer)') or print('FINAL_VAR(var_name)')
    2. Prime Intellect style: answer = {"content": "...", "ready": True}

    Example:
        >>> env = REPLEnvironment(context="Hello World", task_prompt="Count chars")
        >>> obs = env.reset()
        >>> print(obs.context_preview)  # "Hello World"
        >>>
        >>> obs = env.step(REPLAction(code="result = len(context)"))
        >>> print(obs.result.success)  # True
        >>>
        >>> obs = env.step(REPLAction(code="print(f'FINAL({result})')"))
        >>> print(obs.done)  # True
        >>> print(obs.metadata["final_answer"])  # "11"
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(
        self,
        context: Optional[str] = None,
        task_prompt: Optional[str] = None,
        max_iterations: int = 30,
        max_output_length: int = 8192,
        context_preview_length: int = 500,
        reward_on_success: float = 1.0,
        reward_on_iteration: float = 0.0,
        reward_on_failure: float = -0.1,
        reward_on_error: float = -0.05,
        llm_query_fn: Optional[Callable[[str], str]] = None,
        llm_batch_fn: Optional[Callable[[List[str]], List[str]]] = None,
    ):
        """Initialize the REPL environment.

        Args:
            context: Initial context to load (falls back to REPL_CONTEXT env var)
            task_prompt: Task description (falls back to REPL_TASK_PROMPT env var)
            max_iterations: Maximum steps per episode (default 30). Note that
                the REPL_MAX_ITERATIONS env var, when set, takes precedence
                over this argument.
            max_output_length: Max chars for stdout/stderr per turn (default 8192)
            context_preview_length: Chars to show in context preview (default 500)
            reward_on_success: Reward when final answer is submitted (default 1.0)
            reward_on_iteration: Reward per iteration step (default 0.0)
            reward_on_failure: Reward when max iterations reached (default -0.1)
            reward_on_error: Reward when code execution fails (default -0.05)
            llm_query_fn: Optional function for llm_query() support
            llm_batch_fn: Optional function for llm_query_batched() support
        """
        # Use `is not None` (not truthiness) so an explicitly-passed empty
        # string is honored instead of silently falling back to the env var.
        self.initial_context = (
            context if context is not None else os.environ.get("REPL_CONTEXT", "")
        )
        self.initial_task_prompt = (
            task_prompt
            if task_prompt is not None
            else os.environ.get("REPL_TASK_PROMPT", "")
        )
        self.max_iterations = int(os.environ.get("REPL_MAX_ITERATIONS", max_iterations))
        self.max_output_length = max_output_length
        self.context_preview_length = context_preview_length

        # Reward configuration
        self.reward_on_success = reward_on_success
        self.reward_on_iteration = reward_on_iteration
        self.reward_on_failure = reward_on_failure
        self.reward_on_error = reward_on_error

        # Optional LLM functions for recursive calls
        self.llm_query_fn = llm_query_fn
        self.llm_batch_fn = llm_batch_fn

        # State (initialized on reset)
        self._state: Optional[REPLState] = None
        self._executor: Optional[PythonExecutor] = None

    def _create_llm_functions(
        self,
        hf_token: str,
        llm_model: Optional[str] = None,
    ) -> None:
        """Create LLM functions dynamically using a client-provided token.

        This allows clients to use their own HF token instead of the server's.

        Security: The token is used only to initialize the InferenceClient
        and is NOT stored in state, logged, or persisted anywhere.

        Args:
            hf_token: HuggingFace API token (not logged or persisted)
            llm_model: Model to use (default: Qwen/Qwen3-Coder-480B-A35B-Instruct)
        """
        from concurrent.futures import ThreadPoolExecutor, as_completed

        try:
            from huggingface_hub import InferenceClient
        except ImportError:
            # huggingface_hub not installed: llm_query/llm_query_batched
            # simply stay unavailable rather than crashing reset().
            return

        model = llm_model or os.environ.get(
            "LLM_MODEL", "Qwen/Qwen3-Coder-480B-A35B-Instruct"
        )
        client = InferenceClient(model=model, token=hf_token)

        def llm_query(prompt: str) -> str:
            """Query the LLM with a prompt and return the response."""
            try:
                messages = [{"role": "user", "content": prompt}]
                response = client.chat_completion(
                    messages=messages,
                    max_tokens=2048,
                    temperature=0.7,
                )
                return response.choices[0].message.content or ""
            except Exception as e:
                # Return the error as text so sandboxed user code can see it
                # instead of the whole step failing.
                return f"Error calling LLM: {e}"

        def llm_query_batched(prompts: List[str]) -> List[str]:
            """Query the LLM with multiple prompts in parallel."""
            if not prompts:
                return []

            max_workers = min(len(prompts), 8)
            # Pre-size the result list so responses land at the index of the
            # prompt that produced them, regardless of completion order.
            results: List[str] = [""] * len(prompts)

            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                future_to_idx = {
                    executor.submit(llm_query, prompt): idx
                    for idx, prompt in enumerate(prompts)
                }
                for future in as_completed(future_to_idx):
                    idx = future_to_idx[future]
                    try:
                        results[idx] = future.result()
                    except Exception as e:
                        results[idx] = f"Error: {e}"

            return results

        self.llm_query_fn = llm_query
        self.llm_batch_fn = llm_query_batched

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        context: Optional[str] = None,
        task_prompt: Optional[str] = None,
        hf_token: Optional[str] = None,
        llm_model: Optional[str] = None,
        **kwargs: Any,
    ) -> REPLObservation:
        """Reset the environment with optional new context.

        Args:
            seed: Optional random seed (for reproducibility)
            episode_id: Optional episode identifier (if not provided, one is generated)
            context: Context to load (overrides initial_context)
            task_prompt: Task description (overrides initial_task_prompt)
            hf_token: Optional HuggingFace token for llm_query/llm_query_batched.
                If provided, creates LLM functions using this token.
                Security: Token is NOT stored in state or logged.
            llm_model: Optional model name for LLM functions (default: from env or Qwen3-Coder)
            **kwargs: Additional reset parameters

        Returns:
            Initial REPLObservation with environment ready message
        """
        # `is not None` so an explicit empty string clears the stored default
        # instead of falling back to it.
        effective_context = (
            context if context is not None else self.initial_context
        )
        effective_task_prompt = (
            task_prompt if task_prompt is not None else self.initial_task_prompt
        )

        # Create LLM functions if not already provided at init.
        # Priority: client hf_token > server HF_TOKEN env var.
        # NOTE(review): because functions are only created once, a *different*
        # hf_token passed on a later reset() is ignored — confirm this one-shot
        # caching is intended for multi-client deployments.
        if not self.llm_query_fn:
            effective_token = hf_token or os.environ.get("HF_TOKEN")
            if effective_token:
                self._create_llm_functions(effective_token, llm_model)

        # Initialize state
        self._state = REPLState(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
            context=effective_context,
            task_prompt=effective_task_prompt,
            iteration=0,
            max_iterations=self.max_iterations,
            namespace_keys=[],
            final_answer=None,
            total_execution_time=0.0,
        )

        # Initialize executor
        self._executor = PythonExecutor(max_output_length=self.max_output_length)

        # Initialize answer dict (Prime Intellect style)
        self._executor.set_variable("answer", {"content": "", "ready": False})

        # Load context into namespace if provided
        if effective_context:
            self._executor.set_context(effective_context)

        # Inject LLM functions if provided.
        # Names: llm_query (single), llm_query_batched (official RLM), llm_batch (alias)
        if self.llm_query_fn:
            self._executor.inject_function("llm_query", self.llm_query_fn)
        if self.llm_batch_fn:
            self._executor.inject_function(
                "llm_query_batched", self.llm_batch_fn
            )  # Official name
            self._executor.inject_function("llm_batch", self.llm_batch_fn)  # Alias

        # Inject FINAL helper function so both FINAL(x) and print(f'FINAL({x})') work.
        # Returns the FINAL pattern as a string so it appears in stdout for detection.
        def final_helper(value):
            """Helper that returns FINAL(value) string for detection."""
            return f"FINAL({value})"

        self._executor.inject_function("FINAL", final_helper)

        # Inject FINAL_VAR helper that looks up a variable and returns FINAL(value).
        # Matches official RLM behavior: strips quotes from var_name and looks
        # it up in the namespace.
        executor = self._executor  # Capture for closure

        def final_var_helper(var_name: str):
            """Look up variable by name and return FINAL(value) for detection."""
            # Strip quotes if present (handles both FINAL_VAR("x") and FINAL_VAR(x))
            var_name_clean = str(var_name).strip().strip("\"'")
            value = executor.get_variable(var_name_clean)
            if value is not None:
                return f"FINAL({value})"
            # Fallback: emit the raw pattern so _extract_final_answer's regex
            # can still resolve the variable later.
            return f"FINAL_VAR({var_name_clean})"

        self._executor.inject_function("FINAL_VAR", final_var_helper)

        # Update namespace keys
        self._state.namespace_keys = self._executor.list_variables()

        # Build initial message
        message_parts = ["REPL environment initialized."]
        if effective_context:
            message_parts.append(
                f"Context loaded ({len(effective_context)} chars). Use 'context' variable to access it."
            )
        if effective_task_prompt:
            message_parts.append(f"Task: {effective_task_prompt}")
        message_parts.append(
            "Use answer['content'] to store your answer, and set answer['ready'] = True when done."
        )

        return REPLObservation(
            result=CodeBlockResult(
                stdout="\n".join(message_parts),
                stderr="",
                locals_snapshot={},
                execution_time=0.0,
                success=True,
                exception=None,
            ),
            context_preview=(
                effective_context[: self.context_preview_length]
                if effective_context
                else None
            ),
            context_length=len(effective_context) if effective_context else 0,
            available_variables=self._state.namespace_keys,
            iteration=0,
            max_iterations=self.max_iterations,
            done=False,
            reward=0.0,
            metadata={
                "task_prompt": effective_task_prompt,
                "message": "Environment ready.",
            },
        )

    def step(
        self,
        action: REPLAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> REPLObservation:
        """Execute code and return observation.

        Args:
            action: REPLAction containing code to execute
            timeout_s: Optional timeout in seconds (not currently used)
            **kwargs: Additional step parameters

        Returns:
            REPLObservation with execution results

        Raises:
            RuntimeError: If called before reset()
        """
        if self._state is None or self._executor is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")

        self._state.step_count += 1
        self._state.iteration += 1

        # Check if agent explicitly signals final answer
        if action.is_final:
            self._state.final_answer = action.final_answer or ""
            return self._create_final_observation(
                success=True,
                message="Final answer submitted.",
                reward=self.reward_on_success,
            )

        # Check iteration limit
        if self._state.iteration >= self.max_iterations:
            # Salvage a partial answer from the answer dict, if any
            answer_var = self._executor.get_variable("answer")
            if isinstance(answer_var, dict):
                content = answer_var.get("content")
                if content:
                    self._state.final_answer = str(content)
            return self._create_final_observation(
                success=False,
                message=f"Maximum iterations ({self.max_iterations}) reached.",
                reward=self.reward_on_failure,
            )

        # Execute code
        result = self._executor.execute(action.code)
        self._state.total_execution_time += result["execution_time"]
        self._state.namespace_keys = self._executor.list_variables()

        # Calculate reward: per-step baseline plus an error penalty on failure
        reward = self.reward_on_iteration
        if not result["success"]:
            reward += self.reward_on_error

        # Check for final answer patterns
        final_answer = self._extract_final_answer(result["stdout"])
        done = final_answer is not None

        if done:
            self._state.final_answer = final_answer
            reward = self.reward_on_success

        return REPLObservation(
            result=CodeBlockResult(
                stdout=result["stdout"],
                stderr=result["stderr"],
                locals_snapshot=result["locals_snapshot"],
                execution_time=result["execution_time"],
                success=result["success"],
                exception=result["exception"],
            ),
            context_preview=(
                self._state.context[: self.context_preview_length]
                if self._state.context
                else None
            ),
            context_length=len(self._state.context) if self._state.context else 0,
            available_variables=self._state.namespace_keys,
            iteration=self._state.iteration,
            max_iterations=self.max_iterations,
            done=done,
            reward=reward,
            metadata={
                "task_prompt": self._state.task_prompt,
                "final_answer": final_answer,
                "execution_time": result["execution_time"],
            },
        )

    def _extract_final_answer(self, stdout: str) -> Optional[str]:
        """Extract final answer from output.

        Supports multiple patterns:
        1. RLM-style: FINAL(answer) in stdout
        2. RLM-style: FINAL_VAR(variable_name) in stdout
        3. Prime Intellect style: answer = {"content": "...", "ready": True} in namespace

        Args:
            stdout: Standard output from code execution

        Returns:
            Final answer string or None if not found
        """
        # Pattern 1: RLM-style FINAL(answer).
        # NOTE(review): the non-greedy match stops at the first ')', so an
        # answer containing parentheses is truncated — confirm acceptable.
        final_match = re.search(r"FINAL\((.*?)\)", stdout, re.DOTALL)
        if final_match:
            return final_match.group(1).strip()

        # Pattern 2: RLM-style FINAL_VAR(variable_name)
        final_var_match = re.search(r"FINAL_VAR\((\w+)\)", stdout)
        if final_var_match and self._executor:
            var_name = final_var_match.group(1)
            value = self._executor.get_variable(var_name)
            if value is not None:
                return str(value)

        # Pattern 3: Prime Intellect style answer dict
        if self._executor:
            answer_var = self._executor.get_variable("answer")
            if isinstance(answer_var, dict) and answer_var.get("ready", False):
                return str(answer_var.get("content", ""))

        return None

    def _create_final_observation(
        self, success: bool, message: str, reward: float
    ) -> REPLObservation:
        """Create observation for episode termination.

        Args:
            success: Whether the episode ended successfully
            message: Termination message
            reward: Final reward value

        Returns:
            Final REPLObservation with done=True
        """
        return REPLObservation(
            result=CodeBlockResult(
                stdout=message,
                stderr="",
                locals_snapshot={},
                execution_time=0.0,
                success=success,
                exception=None,
            ),
            context_preview=None,
            context_length=0,
            available_variables=[],
            iteration=self._state.iteration if self._state else 0,
            max_iterations=self.max_iterations,
            done=True,
            reward=reward,
            metadata={
                "final_answer": self._state.final_answer if self._state else None,
                "total_execution_time": (
                    self._state.total_execution_time if self._state else 0
                ),
                "total_iterations": self._state.iteration if self._state else 0,
            },
        )

    @property
    def state(self) -> REPLState:
        """Get the current environment state.

        Returns:
            Current REPLState

        Raises:
            RuntimeError: If environment not initialized
        """
        if self._state is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        return self._state

    def close(self) -> None:
        """Cleanup resources (drops executor and state references)."""
        self._executor = None
        self._state = None

    def get_metadata(self) -> EnvironmentMetadata:
        """Get environment metadata.

        Returns:
            EnvironmentMetadata with environment info
        """
        return EnvironmentMetadata(
            name="repl_env",
            description="Python REPL environment for RLM-style code execution",
            version="0.1.0",
        )
models.py CHANGED
@@ -48,9 +48,7 @@ class REPLAction(Action):
48
  class CodeBlockResult(BaseModel):
49
  """Result of executing a single code block."""
50
 
51
- stdout: str = Field(
52
- default="", description="Standard output from execution"
53
- )
54
  stderr: str = Field(default="", description="Standard error from execution")
55
  locals_snapshot: Dict[str, str] = Field(
56
  default_factory=dict,
@@ -59,9 +57,7 @@ class CodeBlockResult(BaseModel):
59
  execution_time: float = Field(
60
  default=0.0, ge=0, description="Execution time in seconds"
61
  )
62
- success: bool = Field(
63
- default=True, description="Whether execution succeeded"
64
- )
65
  exception: Optional[str] = Field(
66
  default=None, description="Exception message if execution failed"
67
  )
@@ -84,9 +80,7 @@ class REPLObservation(Observation):
84
  default_factory=list,
85
  description="List of variable names available in the namespace",
86
  )
87
- iteration: int = Field(
88
- default=0, ge=0, description="Current iteration number"
89
- )
90
  max_iterations: int = Field(
91
  default=30, ge=1, description="Maximum allowed iterations"
92
  )
@@ -101,9 +95,7 @@ class REPLState(State):
101
  task_prompt: Optional[str] = Field(
102
  default=None, description="The task description to solve"
103
  )
104
- iteration: int = Field(
105
- default=0, ge=0, description="Current iteration number"
106
- )
107
  max_iterations: int = Field(
108
  default=30, ge=1, description="Max iterations before termination"
109
  )
 
48
  class CodeBlockResult(BaseModel):
49
  """Result of executing a single code block."""
50
 
51
+ stdout: str = Field(default="", description="Standard output from execution")
 
 
52
  stderr: str = Field(default="", description="Standard error from execution")
53
  locals_snapshot: Dict[str, str] = Field(
54
  default_factory=dict,
 
57
  execution_time: float = Field(
58
  default=0.0, ge=0, description="Execution time in seconds"
59
  )
60
+ success: bool = Field(default=True, description="Whether execution succeeded")
 
 
61
  exception: Optional[str] = Field(
62
  default=None, description="Exception message if execution failed"
63
  )
 
80
  default_factory=list,
81
  description="List of variable names available in the namespace",
82
  )
83
+ iteration: int = Field(default=0, ge=0, description="Current iteration number")
 
 
84
  max_iterations: int = Field(
85
  default=30, ge=1, description="Maximum allowed iterations"
86
  )
 
95
  task_prompt: Optional[str] = Field(
96
  default=None, description="The task description to solve"
97
  )
98
+ iteration: int = Field(default=0, ge=0, description="Current iteration number")
 
 
99
  max_iterations: int = Field(
100
  default=30, ge=1, description="Max iterations before termination"
101
  )
prompts.py CHANGED
@@ -358,19 +358,32 @@ def extract_code_blocks(text: str, language: str = "python") -> List[str]:
358
  return all_matches
359
 
360
 
361
- def format_observation(obs) -> str:
362
- """Format a REPLObservation into observation text for the LLM.
363
 
364
  Args:
365
- obs: The REPLObservation from env.step()
366
 
367
  Returns:
368
  Formatted observation string
369
  """
370
- output = obs.result.stdout.strip() if obs.result.stdout else "(no output)"
371
-
372
- if obs.result.success:
373
- return f"Code output:\n{output}"
374
- else:
375
- error = obs.result.stderr or obs.result.exception or "Unknown error"
376
- return f"Code output:\n{output}\n\nERROR: {error}\nFix the error. Remember: 'context' is already defined."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  return all_matches
359
 
360
 
361
def format_observations(observations) -> str:
    """Format REPL observations into observation text for the LLM.

    Args:
        observations: List of REPL observations from env.step()

    Returns:
        Formatted observation string
    """
    sections = []
    for idx, obs in enumerate(observations, 1):
        stdout = obs.result.stdout
        output = stdout.strip() if stdout else "(no output)"
        if obs.result.success:
            sections.append(f"Code block {idx} output:\n{output}")
            continue
        # Prefer stderr, then the captured exception, for the error detail.
        error = obs.result.stderr or obs.result.exception or "Unknown error"
        sections.append(
            f"Code block {idx} output:\n{output}\n\nERROR: {error}\n"
            f"Fix the error in code block {idx}. Remember: 'context' is already defined."
        )
    return "\n\n".join(sections)
pyproject.toml CHANGED
@@ -15,7 +15,7 @@ description = "Recursive Language Model REPL Environment for OpenEnv"
15
  requires-python = ">=3.10"
16
  dependencies = [
17
  # Core OpenEnv dependencies (required for server functionality)
18
- "openenv-core @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.1",
19
  "fastapi>=0.115.0",
20
  "pydantic>=2.0.0",
21
  "uvicorn>=0.24.0",
 
15
  requires-python = ">=3.10"
16
  dependencies = [
17
  # Core OpenEnv dependencies (required for server functionality)
18
+ "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@main",
19
  "fastapi>=0.115.0",
20
  "pydantic>=2.0.0",
21
  "uvicorn>=0.24.0",
server/Dockerfile ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Multi-stage build using openenv-base.
# This Dockerfile is flexible and works for both:
#   - In-repo environments (with local src/core)
#   - Standalone environments (with openenv from pip)
# The build script (openenv build) handles context detection and sets appropriate build args.

ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
FROM ${BASE_IMAGE} AS builder

WORKDIR /app

# Build argument to control whether we're building standalone or in-repo
ARG BUILD_MODE=in-repo
ARG ENV_NAME=repl_env

# Copy environment code (always at root of build context)
COPY . /app/env

# For in-repo builds, openenv-core is already in the pyproject.toml dependencies.
# For standalone builds, openenv-core will be installed from pip via pyproject.toml.
WORKDIR /app/env

# Ensure uv is available (for local builds where the base image lacks it)
RUN if ! command -v uv >/dev/null 2>&1; then \
    curl -LsSf https://astral.sh/uv/install.sh | sh && \
    mv /root/.local/bin/uv /usr/local/bin/uv && \
    mv /root/.local/bin/uvx /usr/local/bin/uvx; \
    fi

# Install git for building from git repos (build-time only; removed apt lists keep the layer small)
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies using uv sync.
# If uv.lock exists, use it (reproducible); otherwise resolve on the fly.
# First pass installs only dependencies so the project layer can change without re-resolving.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
    uv sync --frozen --no-install-project --no-editable; \
    else \
    uv sync --no-install-project --no-editable; \
    fi

# Second pass installs the project itself into the venv.
RUN --mount=type=cache,target=/root/.cache/uv \
    if [ -f uv.lock ]; then \
    uv sync --frozen --no-editable; \
    else \
    uv sync --no-editable; \
    fi

# Final runtime stage (copies only the venv and code, dropping build tooling)
FROM ${BASE_IMAGE}

WORKDIR /app

# Copy the virtual environment from builder
COPY --from=builder /app/env/.venv /app/.venv

# Copy the environment code
COPY --from=builder /app/env /app/env

# Set PATH to use the virtual environment
ENV PATH="/app/.venv/bin:$PATH"

# Set PYTHONPATH so imports work correctly
ENV PYTHONPATH="/app/env:$PYTHONPATH"

# Health check using Python (more portable than curl/wget)
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1

# Run the FastAPI server.
# The module path is constructed to work with the /app/env structure.
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
server/__init__.py CHANGED
@@ -10,8 +10,8 @@ REPL Environment Server Components.
10
  This module contains the server-side implementation of the REPL environment.
11
  """
12
 
13
- from .repl_environment import REPLEnvironment
14
  from .python_executor import PythonExecutor
 
15
 
16
  __all__ = [
17
  "REPLEnvironment",
 
10
  This module contains the server-side implementation of the REPL environment.
11
  """
12
 
 
13
  from .python_executor import PythonExecutor
14
+ from .repl_environment import REPLEnvironment
15
 
16
  __all__ = [
17
  "REPLEnvironment",
server/app.py CHANGED
@@ -41,12 +41,14 @@ import os
41
  try:
42
  # In-repo imports (when running from OpenEnv repository)
43
  from openenv.core.env_server.http_server import create_app
 
44
  from ..models import REPLAction, REPLObservation
45
  from .repl_environment import REPLEnvironment
46
  except ImportError:
 
 
47
  # Standalone imports (when environment is standalone with openenv from pip)
48
  from openenv.core.env_server.http_server import create_app
49
- from models import REPLAction, REPLObservation
50
  from server.repl_environment import REPLEnvironment
51
 
52
 
 
41
  try:
42
  # In-repo imports (when running from OpenEnv repository)
43
  from openenv.core.env_server.http_server import create_app
44
+
45
  from ..models import REPLAction, REPLObservation
46
  from .repl_environment import REPLEnvironment
47
  except ImportError:
48
+ from models import REPLAction, REPLObservation
49
+
50
  # Standalone imports (when environment is standalone with openenv from pip)
51
  from openenv.core.env_server.http_server import create_app
 
52
  from server.repl_environment import REPLEnvironment
53
 
54
 
server/python_executor.py CHANGED
@@ -108,9 +108,7 @@ class PythonExecutor:
108
  """Register helper functions with the executor."""
109
  helpers = {
110
  "format_exc": traceback.format_exc,
111
- "safe_json_dumps": lambda obj: json.dumps(
112
- obj, default=lambda o: repr(o)
113
- ),
114
  }
115
  # Register helpers as callable tools
116
  for name, func in helpers.items():
@@ -257,30 +255,21 @@ class PythonExecutor:
257
  success = False
258
  exception_msg = str(ex)
259
  except Exception:
260
- logger.debug(
261
- "Failed to read exec_result.exception", exc_info=True
262
- )
263
 
264
  # Determine success from exit_code if available
265
  try:
266
  if hasattr(exec_result, "exit_code"):
267
- if (
268
- exec_result.exit_code is not None
269
- and exec_result.exit_code != 0
270
- ):
271
  success = False
272
  elif hasattr(exec_result, "success"):
273
  success = bool(exec_result.success)
274
  except Exception:
275
- logger.debug(
276
- "Failed to determine exec_result exit code", exc_info=True
277
- )
278
 
279
  except Exception as e:
280
  success = False
281
- exception_msg = (
282
- f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
283
- )
284
  stderr_parts.append(exception_msg)
285
 
286
  execution_time = time.time() - start_time
 
108
  """Register helper functions with the executor."""
109
  helpers = {
110
  "format_exc": traceback.format_exc,
111
+ "safe_json_dumps": lambda obj: json.dumps(obj, default=lambda o: repr(o)),
 
 
112
  }
113
  # Register helpers as callable tools
114
  for name, func in helpers.items():
 
255
  success = False
256
  exception_msg = str(ex)
257
  except Exception:
258
+ logger.debug("Failed to read exec_result.exception", exc_info=True)
 
 
259
 
260
  # Determine success from exit_code if available
261
  try:
262
  if hasattr(exec_result, "exit_code"):
263
+ if exec_result.exit_code is not None and exec_result.exit_code != 0:
 
 
 
264
  success = False
265
  elif hasattr(exec_result, "success"):
266
  success = bool(exec_result.success)
267
  except Exception:
268
+ logger.debug("Failed to determine exec_result exit code", exc_info=True)
 
 
269
 
270
  except Exception as e:
271
  success = False
272
+ exception_msg = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
 
 
273
  stderr_parts.append(exception_msg)
274
 
275
  execution_time = time.time() - start_time
server/repl_environment.py CHANGED
@@ -31,9 +31,9 @@ except ImportError:
31
  from openenv.core.env_server.types import EnvironmentMetadata
32
 
33
  try:
34
- from ..models import REPLAction, REPLObservation, REPLState, CodeBlockResult
35
  except ImportError:
36
- from models import REPLAction, REPLObservation, REPLState, CodeBlockResult
37
 
38
  try:
39
  from .python_executor import PythonExecutor
@@ -102,12 +102,8 @@ class REPLEnvironment(Environment):
102
  llm_batch_fn: Optional function for llm_query_batched() support
103
  """
104
  self.initial_context = context or os.environ.get("REPL_CONTEXT", "")
105
- self.initial_task_prompt = task_prompt or os.environ.get(
106
- "REPL_TASK_PROMPT", ""
107
- )
108
- self.max_iterations = int(
109
- os.environ.get("REPL_MAX_ITERATIONS", max_iterations)
110
- )
111
  self.max_output_length = max_output_length
112
  self.context_preview_length = context_preview_length
113
 
@@ -141,7 +137,7 @@ class REPLEnvironment(Environment):
141
  hf_token: HuggingFace API token (not logged or persisted)
142
  llm_model: Model to use (default: Qwen/Qwen3-Coder-480B-A35B-Instruct)
143
  """
144
- from concurrent.futures import ThreadPoolExecutor, as_completed
145
 
146
  try:
147
  from huggingface_hub import InferenceClient
@@ -242,9 +238,7 @@ class REPLEnvironment(Environment):
242
  )
243
 
244
  # Initialize executor
245
- self._executor = PythonExecutor(
246
- max_output_length=self.max_output_length
247
- )
248
 
249
  # Initialize answer dict (Prime Intellect style)
250
  self._executor.set_variable("answer", {"content": "", "ready": False})
@@ -261,9 +255,7 @@ class REPLEnvironment(Environment):
261
  self._executor.inject_function(
262
  "llm_query_batched", self.llm_batch_fn
263
  ) # Official name
264
- self._executor.inject_function(
265
- "llm_batch", self.llm_batch_fn
266
- ) # Alias
267
 
268
  # Inject FINAL helper function so both FINAL(x) and print(f'FINAL({x})') work
269
  # Returns the FINAL pattern as a string so it appears in stdout for detection
@@ -285,9 +277,7 @@ class REPLEnvironment(Environment):
285
  value = executor.get_variable(var_name_clean)
286
  if value is not None:
287
  return f"FINAL({value})"
288
- return (
289
- f"FINAL_VAR({var_name_clean})" # Fallback for regex detection
290
- )
291
 
292
  self._executor.inject_function("FINAL_VAR", final_var_helper)
293
 
@@ -349,9 +339,7 @@ class REPLEnvironment(Environment):
349
  REPLObservation with execution results
350
  """
351
  if self._state is None or self._executor is None:
352
- raise RuntimeError(
353
- "Environment not initialized. Call reset() first."
354
- )
355
 
356
  self._state.step_count += 1
357
  self._state.iteration += 1
@@ -409,9 +397,7 @@ class REPLEnvironment(Environment):
409
  if self._state.context
410
  else None
411
  ),
412
- context_length=len(self._state.context)
413
- if self._state.context
414
- else 0,
415
  available_variables=self._state.namespace_keys,
416
  iteration=self._state.iteration,
417
  max_iterations=self.max_iterations,
@@ -490,9 +476,7 @@ class REPLEnvironment(Environment):
490
  done=True,
491
  reward=reward,
492
  metadata={
493
- "final_answer": self._state.final_answer
494
- if self._state
495
- else None,
496
  "total_execution_time": (
497
  self._state.total_execution_time if self._state else 0
498
  ),
@@ -511,9 +495,7 @@ class REPLEnvironment(Environment):
511
  RuntimeError: If environment not initialized
512
  """
513
  if self._state is None:
514
- raise RuntimeError(
515
- "Environment not initialized. Call reset() first."
516
- )
517
  return self._state
518
 
519
  def close(self) -> None:
 
31
  from openenv.core.env_server.types import EnvironmentMetadata
32
 
33
  try:
34
+ from ..models import CodeBlockResult, REPLAction, REPLObservation, REPLState
35
  except ImportError:
36
+ from models import CodeBlockResult, REPLAction, REPLObservation, REPLState
37
 
38
  try:
39
  from .python_executor import PythonExecutor
 
102
  llm_batch_fn: Optional function for llm_query_batched() support
103
  """
104
  self.initial_context = context or os.environ.get("REPL_CONTEXT", "")
105
+ self.initial_task_prompt = task_prompt or os.environ.get("REPL_TASK_PROMPT", "")
106
+ self.max_iterations = int(os.environ.get("REPL_MAX_ITERATIONS", max_iterations))
 
 
 
 
107
  self.max_output_length = max_output_length
108
  self.context_preview_length = context_preview_length
109
 
 
137
  hf_token: HuggingFace API token (not logged or persisted)
138
  llm_model: Model to use (default: Qwen/Qwen3-Coder-480B-A35B-Instruct)
139
  """
140
+ from concurrent.futures import as_completed, ThreadPoolExecutor
141
 
142
  try:
143
  from huggingface_hub import InferenceClient
 
238
  )
239
 
240
  # Initialize executor
241
+ self._executor = PythonExecutor(max_output_length=self.max_output_length)
 
 
242
 
243
  # Initialize answer dict (Prime Intellect style)
244
  self._executor.set_variable("answer", {"content": "", "ready": False})
 
255
  self._executor.inject_function(
256
  "llm_query_batched", self.llm_batch_fn
257
  ) # Official name
258
+ self._executor.inject_function("llm_batch", self.llm_batch_fn) # Alias
 
 
259
 
260
  # Inject FINAL helper function so both FINAL(x) and print(f'FINAL({x})') work
261
  # Returns the FINAL pattern as a string so it appears in stdout for detection
 
277
  value = executor.get_variable(var_name_clean)
278
  if value is not None:
279
  return f"FINAL({value})"
280
+ return f"FINAL_VAR({var_name_clean})" # Fallback for regex detection
 
 
281
 
282
  self._executor.inject_function("FINAL_VAR", final_var_helper)
283
 
 
339
  REPLObservation with execution results
340
  """
341
  if self._state is None or self._executor is None:
342
+ raise RuntimeError("Environment not initialized. Call reset() first.")
 
 
343
 
344
  self._state.step_count += 1
345
  self._state.iteration += 1
 
397
  if self._state.context
398
  else None
399
  ),
400
+ context_length=len(self._state.context) if self._state.context else 0,
 
 
401
  available_variables=self._state.namespace_keys,
402
  iteration=self._state.iteration,
403
  max_iterations=self.max_iterations,
 
476
  done=True,
477
  reward=reward,
478
  metadata={
479
+ "final_answer": self._state.final_answer if self._state else None,
 
 
480
  "total_execution_time": (
481
  self._state.total_execution_time if self._state else 0
482
  ),
 
495
  RuntimeError: If environment not initialized
496
  """
497
  if self._state is None:
498
+ raise RuntimeError("Environment not initialized. Call reset() first.")
 
 
499
  return self._state
500
 
501
  def close(self) -> None:
src/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """EnvTorch: Standardized agentic execution environments."""
src/core/README.md ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # <img width="35" height="35" alt="image" src="https://github.com/user-attachments/assets/2700a971-e5d6-4036-b03f-2f89c9791609" /> OpenEnv: Agentic Execution Environments
2
+
3
+ An e2e framework for creating, deploying and using isolated execution environments for agentic RL training, built using Gymnasium style simple APIs. OpenEnv provides a standard for interacting with agentic execution environments via simple Gymnasium style APIs - step(), reset(), state(). Users of agentic execution environments can interact with the environment during RL training loops using these simple APIs.
4
+
5
+ In addition to making it easier for researchers and RL framework writers, we also provide tools for environment creators making it easier for them to create richer environments and make them available over familiar protocols like HTTP and packaged using canonical technologies like docker. Environment creators can use the OpenEnv framework to create environments that are isolated, secure, and easy to deploy and use.
6
+
7
+
8
+ ## Overview
9
+ `openenv.core` provides the foundational building blocks for creating and interacting with containerized environments over HTTP. It enables you to build agent environments that can be deployed as Docker containers and accessed via a simple HTTP API.
10
+
11
+ > ⚠️ **Early Development Warning** OpenEnv is currently in an experimental
12
+ > stage. You should expect bugs, incomplete features, and APIs that may change
13
+ > in future versions. The project welcomes bugfixes, but to make sure things are
14
+ > well coordinated you should discuss any significant change before starting the
15
+ > work. It's recommended that you signal your intention to contribute in the
16
+ > issue tracker, either by filing a new issue or by claiming an existing one.
17
+
18
+
19
+ # OpenEnv Core
20
+
21
+ Core components for OpenEnv - a framework for building HTTP-based agentic environments.
22
+
23
+ ## Features
24
+
25
+ - **EnvClient**: Async-first client for interacting with remote environments
26
+ - **SyncEnvClient**: Synchronous wrapper via `.sync()` for sync codebases
27
+ - **HTTPEnvServer**: FastAPI-based server wrapper for exposing environments over HTTP/WebSocket
28
+ - **Container Providers**: Pluggable architecture for running containers (Docker, Kubernetes, etc.)
29
+ - **Type System**: Strongly-typed Action/Observation/State interfaces
30
+ - **Web Interface**: Optional web UI for interacting with environments
31
+
32
+ ## Installation
33
+
34
+ ```bash
35
+ pip install "openenv[core]"
36
+ ```
37
+
38
+ For development:
39
+ ```bash
40
+ pip install "openenv[core]"
41
+ ```
42
+
43
+ ## Quick Start
44
+
45
+ ### Creating an Environment Client
46
+
47
+ EnvClient is **async by default**. Use `async with` and `await` for all operations:
48
+
49
+ ```python
50
+ import asyncio
51
+ from openenv.core import EnvClient, StepResult
52
+ from dataclasses import dataclass
53
+ from typing import Any
54
+
55
+ @dataclass
56
+ class MyAction:
57
+ text: str
58
+
59
+ @dataclass
60
+ class MyObservation:
61
+ response: str
62
+
63
+ class MyEnvClient(EnvClient[MyAction, MyObservation, Any]):
64
+ def _step_payload(self, action: MyAction) -> dict:
65
+ return {"text": action.text}
66
+
67
+ def _parse_result(self, payload: dict) -> StepResult[MyObservation]:
68
+ obs_data = payload["observation"]
69
+ return StepResult(
70
+ observation=MyObservation(**obs_data),
71
+ reward=payload.get("reward"),
72
+ done=payload.get("done", False)
73
+ )
74
+
75
+ def _parse_state(self, payload: dict) -> Any:
76
+ return payload
77
+
78
+ # Async usage (recommended)
79
+ async def main():
80
+ client = await MyEnvClient.from_docker_image("my-env:latest")
81
+ async with client:
82
+ result = await client.reset()
83
+ step_result = await client.step(MyAction(text="hello"))
84
+
85
+ asyncio.run(main())
86
+
87
+ # Sync usage (via .sync() wrapper)
88
+ with MyEnvClient(base_url="http://localhost:8000").sync() as client:
89
+ result = client.reset()
90
+ step_result = client.step(MyAction(text="hello"))
91
+ ```
92
+
93
+ ### Creating an Environment Server
94
+
95
+ ```python
96
+ from openenv.core.env_server import Environment, HTTPEnvServer, create_app
97
+ from dataclasses import dataclass
98
+
99
+ @dataclass
100
+ class MyAction:
101
+ text: str
102
+
103
+ @dataclass
104
+ class MyObservation:
105
+ response: str
106
+ reward: float = 0.0
107
+ done: bool = False
108
+
109
+ class MyEnvironment(Environment):
110
+ def reset(self) -> MyObservation:
111
+ return MyObservation(response="Ready")
112
+
113
+ def step(self, action: MyAction) -> MyObservation:
114
+ return MyObservation(
115
+ response=f"Echo: {action.text}",
116
+ reward=1.0,
117
+ done=False
118
+ )
119
+
120
+ # Create FastAPI app
121
+ env = MyEnvironment()
122
+ app = create_app(env, MyAction, MyObservation)
123
+
124
+ # Run with: uvicorn module:app --host 0.0.0.0 --port 8000
125
+ ```
126
+
127
+ ## Container Providers
128
+
129
+ OpenEnv Core supports multiple container providers:
130
+
131
+ ### Local Docker Provider
132
+
133
+ ```python
134
+ from openenv.core.containers.runtime import LocalDockerProvider
135
+
136
+ provider = LocalDockerProvider()
137
+ base_url = provider.start_container("my-env:latest")
138
+ provider.wait_for_ready(base_url)
139
+ # Use environment...
140
+ provider.stop_container()
141
+ ```
142
+
143
+ ### Kubernetes Provider (Coming Soon)
144
+
145
+ ```python
146
+ from openenv.core.containers.runtime import KubernetesProvider
147
+
148
+ provider = KubernetesProvider(namespace="envs")
149
+ base_url = provider.start_container("my-env:latest")
150
+ # Use environment...
151
+ provider.stop_container()
152
+ ```
153
+
154
+
155
+ ## API Reference
156
+
157
+ ### EnvClient
158
+
159
+ Async base class for environment clients. Key methods:
160
+
161
+ - `async connect()`: Establish WebSocket connection
162
+ - `async reset(**kwargs)`: Reset environment
163
+ - `async step(action)`: Execute action
164
+ - `async state()`: Get current state
165
+ - `async close()`: Close connection and cleanup
166
+ - `sync()`: Return a SyncEnvClient wrapper for synchronous usage
167
+
168
+ Abstract methods to implement:
169
+ - `_step_payload(action)`: Convert action to JSON
170
+ - `_parse_result(payload)`: Parse response to StepResult
171
+ - `_parse_state(payload)`: Parse state response
172
+
173
+ ### SyncEnvClient
174
+
175
+ Synchronous wrapper around EnvClient. Use `client.sync()` to get one:
176
+
177
+ ```python
178
+ sync_client = async_client.sync()
179
+ with sync_client:
180
+ result = sync_client.reset()
181
+ result = sync_client.step(action)
182
+ ```
183
+
184
+ ### HTTPEnvServer
185
+
186
+ Server wrapper with these methods:
187
+
188
+ - `register_routes(app)`: Register endpoints on FastAPI app
189
+ - `_deserialize_action(data)`: Convert JSON to Action
190
+ - `_serialize_observation(obs)`: Convert Observation to JSON
191
+
192
+ ### Environment Interface
193
+
194
+ Base interface for environment implementations:
195
+
196
+ - `reset()`: Reset environment and return initial observation
197
+ - `step(action)`: Execute action and return observation
198
+ - `state`: Property returning current environment state
199
+
200
+ ## License
201
+
202
+ This project is licensed under the BSD-3-Clause License - see the LICENSE file for details.
203
+
204
+ ## Contributing
205
+
206
+ Contributions are welcome! Please see the main OpenEnv repository for contribution guidelines.
207
+
208
+ ## Links
209
+
210
+ - **Homepage**: https://github.com/meta-pytorch/OpenEnv
211
+ - **Documentation**: https://github.com/meta-pytorch/OpenEnv/blob/main/README.md
212
+ - **Bug Tracker**: https://github.com/meta-pytorch/OpenEnv/issues
src/core/__init__.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Core components for agentic environments."""
8
+
9
+ from __future__ import annotations
10
+
11
+ from importlib import import_module
12
+ from typing import TYPE_CHECKING
13
+
14
+ from . import env_server
15
+ from .env_server import * # noqa: F403
16
+
17
+ if TYPE_CHECKING:
18
+ from .env_client import EnvClient
19
+ from .generic_client import GenericAction, GenericEnvClient
20
+ from .llm_client import (
21
+ AnthropicClient,
22
+ create_llm_client,
23
+ LLMClient,
24
+ LLMResponse,
25
+ OpenAIClient,
26
+ ToolCall,
27
+ )
28
+ from .mcp_client import MCPClientBase, MCPToolClient
29
+ from .sync_client import SyncEnvClient
30
+
31
+ __all__ = [
32
+ "EnvClient",
33
+ "SyncEnvClient",
34
+ "GenericEnvClient",
35
+ "GenericAction",
36
+ "MCPClientBase",
37
+ "MCPToolClient",
38
+ "AnthropicClient",
39
+ "LLMClient",
40
+ "LLMResponse",
41
+ "OpenAIClient",
42
+ "ToolCall",
43
+ "create_llm_client",
44
+ ] + env_server.__all__ # type: ignore
45
+
46
+
47
+ _LAZY_ATTRS = {
48
+ "EnvClient": (".env_client", "EnvClient"),
49
+ "SyncEnvClient": (".sync_client", "SyncEnvClient"),
50
+ "GenericEnvClient": (".generic_client", "GenericEnvClient"),
51
+ "GenericAction": (".generic_client", "GenericAction"),
52
+ "MCPClientBase": (".mcp_client", "MCPClientBase"),
53
+ "MCPToolClient": (".mcp_client", "MCPToolClient"),
54
+ "AnthropicClient": (".llm_client", "AnthropicClient"),
55
+ "LLMClient": (".llm_client", "LLMClient"),
56
+ "LLMResponse": (".llm_client", "LLMResponse"),
57
+ "OpenAIClient": (".llm_client", "OpenAIClient"),
58
+ "ToolCall": (".llm_client", "ToolCall"),
59
+ "create_llm_client": (".llm_client", "create_llm_client"),
60
+ }
61
+
62
+
63
+ def __getattr__(name: str):
64
+ if name in _LAZY_ATTRS:
65
+ module_path, attr_name = _LAZY_ATTRS[name]
66
+ module = import_module(module_path, __name__)
67
+ value = getattr(module, attr_name)
68
+ globals()[name] = value
69
+ return value
70
+
71
+ try:
72
+ value = getattr(env_server, name)
73
+ except AttributeError as exc:
74
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from exc
75
+
76
+ globals()[name] = value
77
+ return value
78
+
79
+
80
+ def __dir__() -> list[str]:
81
+ return sorted(set(globals().keys()) | set(__all__))
src/core/client_types.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Type definitions for EnvTorch
2
+ from dataclasses import dataclass
3
+ from typing import Generic, Optional, TypeVar
4
+
5
+ # Generic type for observations
6
+ ObsT = TypeVar("ObsT")
7
+ StateT = TypeVar("StateT")
8
+
9
+
10
+ @dataclass
11
+ class StepResult(Generic[ObsT]):
12
+ """
13
+ Represents the result of one environment step.
14
+
15
+ Attributes:
16
+ observation: The environment's observation after the action.
17
+ reward: Scalar reward for this step (optional).
18
+ done: Whether the episode is finished.
19
+ """
20
+
21
+ observation: ObsT
22
+ reward: Optional[float] = None
23
+ done: bool = False
src/core/containers/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Container management for environment servers."""
src/core/containers/images/Dockerfile ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ #
8
+ # OpenEnv Base Image
9
+ #
10
+ # This is the standard base image for all OpenEnv environment servers.
11
+ # It includes the minimal dependencies needed to run HTTP environment servers
12
+ # and uv for fast dependency management.
13
+ #
14
+ # Build from repo root: docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
15
+ # Tag: docker tag openenv-base:latest openenv-base:0.2.0
16
+ #
17
+
18
+ FROM ghcr.io/astral-sh/uv:0.5.27-python3.11-bookworm-slim AS builder
19
+
20
+ # Set working directory
21
+ WORKDIR /app
22
+
23
+ # Copy core pyproject.toml and lockfile for dependency installation
24
+ COPY pyproject.toml uv.lock* ./
25
+
26
+ # Install core dependencies using uv with cache mount
27
+ RUN --mount=type=cache,target=/root/.cache/uv \
28
+ uv pip install --system -r pyproject.toml
29
+
30
+ # Final runtime stage
31
+ FROM python:3.11-slim
32
+
33
+ # Set metadata
34
+ LABEL maintainer="OpenEnv Team"
35
+ LABEL description="Base image for OpenEnv based environment servers with uv"
36
+ LABEL version="0.2.0"
37
+
38
+ # Install system dependencies
39
+ RUN apt-get update && apt-get install -y --no-install-recommends \
40
+ curl \
41
+ ca-certificates \
42
+ && rm -rf /var/lib/apt/lists/*
43
+
44
+ # Copy uv from builder
45
+ COPY --from=builder /usr/local/bin/uv /usr/local/bin/uvx /usr/local/bin/
46
+
47
+ # Copy installed Python packages from builder
48
+ COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
49
+
50
+ # Copy console scripts installed by pip (uvicorn, fastapi, etc.)
51
+ COPY --from=builder /usr/local/bin/uvicorn /usr/local/bin/fastapi /usr/local/bin/
52
+
53
+ # Set working directory
54
+ WORKDIR /app
55
+
56
+ # Default environment variables
57
+ ENV PYTHONPATH=/app/src
58
+ ENV PYTHONUNBUFFERED=1
59
+ ENV UV_SYSTEM_PYTHON=1
60
+
61
+ # Default expose port (can be overridden)
62
+ EXPOSE 8000
63
+
64
+ # Note: CMD should be specified in child Dockerfiles
src/core/containers/images/README.md ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # OpenEnv Base Image
2
+
3
+ Standard base image for all OpenEnv environment servers.
4
+
5
+ ## What's Included
6
+
7
+ | Layer | Size | Contents |
8
+ |-------|------|----------|
9
+ | python:3.11-slim | 200 MB | Base Python runtime |
10
+ | + Dependencies | 100 MB | FastAPI, uvicorn, requests |
11
+ | **Total** | **~300 MB** | Ready for environment servers |
12
+
13
+ ## Image Sizes
14
+
15
+ ```
16
+ openenv-base:latest 300 MB (python + fastapi + uvicorn)
17
+ ```
18
+ echo-env:latest 500 MB (python + fastapi + uvicorn + app)
19
+ coding-env:latest 520 MB (python + fastapi + uvicorn + app + tools)
20
+ another-env:latest 510 MB (python + fastapi + uvicorn + app)
21
+ ---
22
+ Total: 1.5 GB (with lots of duplication)
23
+ ```
24
+
25
+ ### With Base Images (✅ Solution)
26
+ ```
27
+ openenv-base:latest 300 MB (python + fastapi + uvicorn)
28
+ echo-env:latest 50 MB (app only, uses base)
29
+ coding-env:latest 70 MB (app + tools, uses base)
30
+ another-env:latest 45 MB (app only, uses base)
31
+ ---
32
+ Total: 465 MB (base shared, minimal duplication)
33
+ ```
34
+
35
+ ## Building the Base Image
36
+
37
+ ```bash
38
+ # From project root
39
+ docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
40
+ ```
41
+
42
+ ## Usage in Environment Dockerfiles
43
+
44
+ Each environment Dockerfile should start with:
45
+
46
+ ```dockerfile
47
+ FROM openenv-base:latest
48
+
49
+ # Copy only environment-specific files
50
+ COPY src/openenv/core/ /app/src/openenv/core/
51
+ COPY envs/my_env/ /app/envs/my_env/
52
+
53
+ # Run the server
54
+ CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
55
+ ```
56
+
57
+ ## Base Image Contents
58
+
59
+ - Python 3.11-slim
60
+ - FastAPI >= 0.104.0
61
+ - Uvicorn >= 0.24.0
62
+ - Requests >= 2.25.0
63
+ - curl (for health checks)
64
+
65
+ ## Example: Building Echo Environment
66
+
67
+ ```bash
68
+ # Step 1: Build base image (do this once)
69
+ docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
70
+
71
+ # Step 2: Build echo environment (uses base)
72
+ docker build -t echo-env:latest -f envs/echo_env/server/Dockerfile .
73
+
74
+ # Step 3: Run echo environment
75
+ docker run -p 8000:8000 echo-env:latest
76
+ ```
77
+
78
+ ## Updating the Base
79
+
80
+ When dependencies need updating:
81
+
82
+ 1. Update `src/openenv/core/containers/images/Dockerfile`
83
+ 2. Rebuild base image
84
+ 3. Rebuild all environment images (they'll use new base)
85
+
86
+ ```bash
87
+ # Update base
88
+ docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
89
+
90
+ # Rebuild environments (they automatically use new base)
91
+ docker build -t echo-env:latest -f envs/echo_env/server/Dockerfile .
92
+ ```
src/core/containers/runtime/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Container runtime providers."""
8
+
9
+ from .providers import (
10
+ ContainerProvider,
11
+ DockerSwarmProvider,
12
+ KubernetesProvider,
13
+ LocalDockerProvider,
14
+ RuntimeProvider,
15
+ )
16
+ from .uv_provider import UVProvider
17
+
18
+ __all__ = [
19
+ "ContainerProvider",
20
+ "DockerSwarmProvider",
21
+ "LocalDockerProvider",
22
+ "KubernetesProvider",
23
+ "RuntimeProvider",
24
+ "UVProvider",
25
+ ]
src/core/containers/runtime/daytona_provider.py ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Daytona container provider for running OpenEnv environments in Daytona cloud sandboxes.
9
+
10
+ Requires the ``daytona`` SDK: ``pip install daytona>=0.10``
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ import os
17
+ import shlex
18
+ import time
19
+ from typing import Any, Callable, Dict, Optional
20
+
21
+ import yaml
22
+
23
+ from .providers import ContainerProvider
24
+
25
+
26
+ class DaytonaProvider(ContainerProvider):
27
+ """
28
+ Container provider that runs environments in Daytona cloud sandboxes.
29
+
30
+ Example:
31
+ >>> provider = DaytonaProvider(api_key="your-key")
32
+ >>> image = DaytonaProvider.image_from_dockerfile("envs/echo_env/server/Dockerfile")
33
+ >>> base_url = provider.start_container(image)
34
+ >>> provider.wait_for_ready(base_url)
35
+ >>> provider.stop_container()
36
+ """
37
+
38
+ _dockerfile_registry: Dict[str, Dict[str, Any]] = {}
39
+
40
+ def __init__(
41
+ self,
42
+ *,
43
+ api_key: Optional[str] = None,
44
+ public: bool = False,
45
+ resources: Optional[Any] = None,
46
+ auto_stop_interval: int = 15,
47
+ target: Optional[str] = None,
48
+ on_snapshot_create_logs: Optional[Callable[[str], None]] = None,
49
+ cmd: Optional[str] = None,
50
+ create_timeout: float = 300,
51
+ ):
52
+ """
53
+ Args:
54
+ api_key: Daytona API key. Falls back to ``DAYTONA_API_KEY`` env var.
55
+ public: If True, the sandbox preview is publicly accessible.
56
+ resources: Optional ``daytona.Resources`` instance for CPU/memory.
57
+ auto_stop_interval: Minutes of inactivity before auto-stop (0 disables).
58
+ target: Daytona target region (e.g. "us").
59
+ on_snapshot_create_logs: Callback for snapshot build log lines.
60
+ cmd: Shell command to start the server inside the sandbox.
61
+ create_timeout: Seconds to wait for sandbox creation (default 300).
62
+ Heavy images (e.g. with Playwright/Chromium) may need more.
63
+ """
64
+ from daytona import Daytona, DaytonaConfig
65
+
66
+ config_kwargs: Dict[str, Any] = {}
67
+ resolved_key = api_key or os.environ.get("DAYTONA_API_KEY")
68
+ if resolved_key:
69
+ config_kwargs["api_key"] = resolved_key
70
+ if target:
71
+ config_kwargs["target"] = target
72
+
73
+ self._daytona = Daytona(DaytonaConfig(**config_kwargs))
74
+ self._public = public
75
+ self._resources = resources
76
+ self._auto_stop_interval = auto_stop_interval
77
+ self._on_snapshot_create_logs = on_snapshot_create_logs
78
+ self._cmd = cmd
79
+ self._create_timeout = create_timeout
80
+ self._sandbox: Any = None
81
+ self._preview_url: Optional[str] = None
82
+
83
+ def _discover_server_cmd(self, sandbox: Any, port: int = 8000) -> str:
84
+ """Discover the server command from ``openenv.yaml`` inside *sandbox*.
85
+
86
+ Finds the file, reads the ``app`` field, and constructs a command
87
+ of the form ``cd <env_root> && python -m uvicorn <app> --host 0.0.0.0 --port <port>``.
88
+
89
+ Raises:
90
+ ValueError: If ``openenv.yaml`` is not found or lacks an ``app`` field.
91
+ """
92
+ yaml_path = self._find_openenv_yaml(sandbox)
93
+ if yaml_path is None:
94
+ raise ValueError(
95
+ "Could not find openenv.yaml inside the sandbox. "
96
+ "Pass an explicit cmd= to DaytonaProvider or start_container()."
97
+ )
98
+
99
+ cat_resp = sandbox.process.exec(f"cat {shlex.quote(yaml_path)}", timeout=10)
100
+ content = cat_resp.result if hasattr(cat_resp, "result") else str(cat_resp)
101
+ app = self._parse_app_field(content)
102
+ if app is None:
103
+ raise ValueError(
104
+ f"openenv.yaml at {yaml_path} does not contain an 'app' field. "
105
+ "Pass an explicit cmd= to DaytonaProvider or start_container()."
106
+ )
107
+
108
+ # The directory containing openenv.yaml is the env root
109
+ env_root = yaml_path.rsplit("/", 1)[0]
110
+ return (
111
+ f"cd {shlex.quote(env_root)} && "
112
+ f"python -m uvicorn {shlex.quote(app)} --host 0.0.0.0 --port {port}"
113
+ )
114
+
115
+ def _find_openenv_yaml(self, sandbox: Any) -> Optional[str]:
116
+ """Locate ``openenv.yaml`` inside the sandbox.
117
+
118
+ Tries the modern layout path ``/app/env/openenv.yaml`` first,
119
+ then falls back to a ``find`` command for the old layout.
120
+ """
121
+ # Fast path: modern Dockerfile layout
122
+ resp = sandbox.process.exec(
123
+ "test -f /app/env/openenv.yaml && echo found", timeout=10
124
+ )
125
+ out = resp.result if hasattr(resp, "result") else str(resp)
126
+ if "found" in (out or ""):
127
+ return "/app/env/openenv.yaml"
128
+
129
+ # Fallback: search for it (redirect stderr so error messages
130
+ # like "No such file or directory" don't get mistaken for paths).
131
+ resp = sandbox.process.exec(
132
+ "find /app -maxdepth 4 -name openenv.yaml -print -quit 2>/dev/null",
133
+ timeout=10,
134
+ )
135
+ path = (resp.result if hasattr(resp, "result") else str(resp) or "").strip()
136
+ if path and path.startswith("/"):
137
+ return path
138
+
139
+ return None
140
+
141
@staticmethod
def _parse_app_field(yaml_content: str) -> Optional[str]:
    """Pull the ``app`` entry out of raw ``openenv.yaml`` text.

    Delegates to PyYAML so quoting, comments, and nesting behave
    exactly as a normal loader would.

    Returns:
        The stripped ``app`` string, or ``None`` when the document
        is unparsable, not a mapping, or lacks a usable value.
    """
    try:
        parsed = yaml.safe_load(yaml_content) or {}
    except Exception:
        # Malformed YAML: treat as "no app field" rather than failing.
        return None

    if not isinstance(parsed, dict):
        return None

    app = parsed.get("app")
    if not isinstance(app, str):
        return None
    app = app.strip()
    return app or None
160
+
161
@staticmethod
def _parse_dockerfile_cmd(dockerfile_content: str) -> Optional[str]:
    """Extract the server command from the last ``CMD`` in a Dockerfile.

    Supports both the exec form (``CMD ["prog", "arg"]``, joined with
    spaces) and the shell form (``CMD prog arg``). When several ``CMD``
    instructions appear (e.g. multi-stage builds), the final one wins —
    the same semantics Docker itself applies. Comment lines are skipped.

    Returns:
        The command as a single string, or ``None`` if no ``CMD`` found.
    """
    import re

    cmd_text: Optional[str] = None
    for raw_line in dockerfile_content.splitlines():
        candidate = raw_line.strip()
        if candidate.startswith("#"):
            # CMD inside a comment does not count.
            continue
        found = re.match(r"CMD\s+(.+)", candidate, flags=re.IGNORECASE)
        if found is not None:
            cmd_text = found.group(1).strip()

    if cmd_text is None:
        return None

    # Exec form: CMD ["executable", "param1", ...] — valid JSON array.
    if cmd_text.startswith("["):
        tokens = None
        try:
            tokens = json.loads(cmd_text)
        except (json.JSONDecodeError, TypeError):
            # Not valid JSON; fall through and treat as shell form.
            pass
        if isinstance(tokens, list) and all(isinstance(t, str) for t in tokens):
            return " ".join(tokens)

    # Shell form: CMD executable param1 ...
    return cmd_text or None
200
+
201
@staticmethod
def strip_buildkit_syntax(dockerfile_content: str) -> str:
    """Remove BuildKit ``--mount=...`` flags from ``RUN`` instructions.

    Handles single-line flags, multi-line continuations, and multiple
    ``--mount`` flags spread across continuation lines. Only leading
    ``--mount`` flags are removed (before the actual command starts).

    Daytona's ``Image.from_dockerfile`` does not support BuildKit
    ``--mount`` syntax. This helper strips the flags so that standard
    Dockerfiles (like the ones generated by ``openenv build``) can
    be used directly.
    """
    import re

    def strip_leading_mounts(text: str) -> str:
        # Repeatedly chop `--mount=...` tokens off the front of the
        # text; stops at the first token that is not a mount flag.
        remaining = text
        while True:
            match = re.match(r"\s*--mount=\S+\s*", remaining)
            if not match:
                return remaining
            remaining = remaining[match.end() :]

    lines = dockerfile_content.split("\n")
    result: list[str] = []
    # State machine over lines:
    #   in_run          — currently inside a (possibly continued) RUN.
    #   in_mount_prefix — still in the leading run of --mount flags,
    #                     i.e. no real command content seen yet.
    in_run = False
    in_mount_prefix = False

    for line in lines:
        line_out = line
        run_start = False
        if re.match(r"\s*RUN(\s+|$)", line, flags=re.IGNORECASE):
            in_run = True
            in_mount_prefix = True
            run_start = True

        if in_run and in_mount_prefix:
            # Remember whether this physical line continued (trailing
            # backslash) so we can restore the continuation after
            # stripping possibly removed it.
            original_ends_with_slash = line_out.rstrip().endswith("\\")
            if run_start:
                # Split "RUN " prefix from the rest so only the flag
                # portion is stripped, preserving the keyword.
                match = re.match(r"(\s*RUN\s+)(.*)$", line_out, flags=re.IGNORECASE)
                if match:
                    run_prefix, remainder = match.group(1), match.group(2)
                else:
                    # Bare "RUN" with no arguments on this line.
                    run_prefix, remainder = line_out, ""
                new_remainder = strip_leading_mounts(remainder)
                line_out = run_prefix + new_remainder
                content_for_check = new_remainder
            else:
                # Continuation line: the whole line may be mount flags.
                new_remainder = strip_leading_mounts(line_out)
                line_out = new_remainder
                content_for_check = new_remainder

            # Re-attach the continuation backslash if stripping ate it,
            # so the multi-line RUN still parses as one instruction.
            if original_ends_with_slash and not line_out.rstrip().endswith("\\"):
                line_out = line_out.rstrip() + " \\"

            # Once real command content appears (anything besides
            # whitespace or a lone backslash), stop stripping flags.
            if content_for_check.strip() not in ("", "\\"):
                in_mount_prefix = False

        # A line without a trailing backslash ends the RUN instruction.
        if in_run and not line_out.rstrip().endswith("\\"):
            in_run = False
            in_mount_prefix = False

        result.append(line_out)

    return "\n".join(result)
266
+
267
@classmethod
def image_from_dockerfile(
    cls,
    dockerfile_path: str,
    context_dir: str | None = None,
) -> str:
    """Validate a Dockerfile and return a ``dockerfile:`` URI for
    :meth:`start_container`.

    Eagerly validates the Dockerfile (existence, COPY sources,
    BuildKit stripping) and stores the processed content in an
    internal registry. The actual ``daytona.Image`` is created
    later inside ``start_container``.

    Args:
        dockerfile_path: Path to the Dockerfile on disk.
        context_dir: Build context directory. Defaults to the
            Dockerfile's grandparent directory, matching the
            ``openenv init`` convention where Dockerfiles live in
            ``<env>/server/Dockerfile`` and the build context is
            ``<env>/``. Pass explicitly for non-standard layouts
            (e.g. ``context_dir="."`` for repo-root contexts).

    Returns:
        A ``"dockerfile:<abs_path>"`` string to pass to
        ``start_container``.

    Raises:
        FileNotFoundError: If *dockerfile_path* does not exist.
        ValueError: If *context_dir* is given but does not exist,
            or if COPY sources in the Dockerfile cannot be found
            under the resolved context directory.
    """
    import pathlib
    import re

    src = pathlib.Path(dockerfile_path).resolve()
    if not src.is_file():
        raise FileNotFoundError(f"Dockerfile not found: {dockerfile_path}")

    if context_dir is not None:
        ctx = pathlib.Path(context_dir)
        if not ctx.is_dir():
            raise ValueError(f"context_dir does not exist: {context_dir}")
    else:
        # Default: grandparent of the Dockerfile, matching the
        # openenv init layout (<env>/server/Dockerfile -> <env>/).
        ctx = src.parent.parent

    content = src.read_text()
    stripped = cls.strip_buildkit_syntax(content)

    # Validate that COPY sources exist under the context directory.
    # This catches mismatches early (e.g. a Dockerfile expecting repo
    # root as context when we defaulted to the env directory).
    for line in stripped.splitlines():
        # `COPY --from=...` (multi-stage) has no on-disk source — skipped
        # by the negative lookahead.
        m = re.match(r"^\s*COPY\s+(?!--from=)(\S+)\s+", line, re.IGNORECASE)
        if not m:
            continue
        copy_src = m.group(1)
        if copy_src.startswith("/"):
            # Absolute sources cannot be resolved against the context here.
            continue
        resolved = ctx / copy_src
        # The glob fallback lets wildcard sources (e.g. `*.txt`) validate.
        if not resolved.exists() and not any(ctx.glob(copy_src)):
            raise ValueError(
                f"Dockerfile COPY source '{copy_src}' not found "
                f"under context_dir '{ctx}'. This Dockerfile may "
                f"expect a different build context (e.g. the repo "
                f"root). Pass context_dir explicitly."
            )

    # Parse CMD from the original Dockerfile so start_container can
    # use it as a fallback when openenv.yaml is unavailable.
    parsed_cmd = cls._parse_dockerfile_cmd(content)

    cls._dockerfile_registry[str(src)] = {
        "stripped_content": stripped,
        "context_dir": str(ctx),
        "server_cmd": parsed_cmd,
    }

    return f"dockerfile:{src}"
349
+
350
def start_container(
    self,
    image: str,
    port: Optional[int] = None,
    env_vars: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> str:
    """
    Create a Daytona sandbox from a Docker image or snapshot.

    Daytona does not execute the image's CMD (known bug — ENTRYPOINT
    runs, CMD does not). The server command is resolved in order:

    1. Explicit ``cmd`` passed to the constructor.
    2. ``cmd`` key in ``**kwargs`` (popped before forwarding).
    3. Auto-discovered from ``openenv.yaml`` inside the sandbox.
    4. ``CMD`` parsed from the Dockerfile (when *image* came from
       ``image_from_dockerfile``).

    Args:
        image: Docker image name (e.g. ``"echo-env:latest"``),
            ``"snapshot:<name>"`` to create from a pre-built snapshot,
            or ``"dockerfile:<path>"`` returned by
            :meth:`image_from_dockerfile`.
        port: Must be ``None`` or ``8000``. Daytona exposes port 8000
            via its preview proxy; other ports raise ``ValueError``.
        env_vars: Environment variables forwarded to the sandbox.
        **kwargs: ``cmd`` (str) to override the server command;
            remaining kwargs passed through to ``Daytona.create()``.

    Returns:
        HTTPS preview URL for the sandbox (base_url).
    """
    if port is not None and port != 8000:
        raise ValueError(
            f"DaytonaProvider only supports port 8000 (got {port}). "
            "The Daytona preview proxy routes to port 8000 inside the sandbox."
        )

    # Resolve the server command (may be None; discovery happens after
    # sandbox creation when we can inspect the filesystem).
    cmd = kwargs.pop("cmd", None) or self._cmd

    # CMD parsed from Dockerfile (populated for "dockerfile:" images).
    parsed_cmd: Optional[str] = None

    # Build creation params
    create_kwargs: Dict[str, Any] = {}
    if env_vars:
        create_kwargs["env_vars"] = env_vars
    if self._public:
        create_kwargs["public"] = True
    # 15 is Daytona's default auto-stop interval; only forward when the
    # caller customized it.
    if self._auto_stop_interval != 15:
        create_kwargs["auto_stop_interval"] = self._auto_stop_interval

    if image.startswith("snapshot:"):
        from daytona import CreateSandboxFromSnapshotParams

        snapshot_name = image[len("snapshot:") :]
        params = CreateSandboxFromSnapshotParams(
            snapshot=snapshot_name, **create_kwargs
        )
    elif image.startswith("dockerfile:"):
        from daytona import CreateSandboxFromImageParams, Image

        dockerfile_path = image[len("dockerfile:") :]
        meta = self._dockerfile_registry.get(dockerfile_path)
        if meta is None:
            raise ValueError(
                f"No registered Dockerfile metadata for {dockerfile_path}. "
                "Call DaytonaProvider.image_from_dockerfile() first."
            )

        parsed_cmd = meta.get("server_cmd")

        # Build the daytona Image from the pre-stripped content.
        # A temporary file is required because Image.from_dockerfile
        # takes a path, not raw content; it is cleaned up immediately.
        import pathlib
        import uuid

        ctx = pathlib.Path(meta["context_dir"])
        tmp_name = f".daytona-{uuid.uuid4().hex[:8]}.dockerfile"
        tmp_path = ctx / tmp_name
        try:
            tmp_path.write_text(meta["stripped_content"])
            daytona_image = Image.from_dockerfile(str(tmp_path))
        finally:
            tmp_path.unlink(missing_ok=True)

        img_kwargs: Dict[str, Any] = {
            "image": daytona_image,
            **create_kwargs,
        }
        if self._resources is not None:
            img_kwargs["resources"] = self._resources
        params = CreateSandboxFromImageParams(**img_kwargs)
    else:
        from daytona import CreateSandboxFromImageParams

        img_kwargs = {"image": image, **create_kwargs}
        if self._resources is not None:
            img_kwargs["resources"] = self._resources
        params = CreateSandboxFromImageParams(**img_kwargs)

    # Create sandbox
    extra: Dict[str, Any] = dict(kwargs)
    if self._on_snapshot_create_logs is not None:
        extra["on_snapshot_create_logs"] = self._on_snapshot_create_logs

    self._sandbox = self._daytona.create(
        params, timeout=self._create_timeout, **extra
    )

    try:
        # Discover server command from openenv.yaml if not explicitly set.
        if cmd is None:
            try:
                cmd = self._discover_server_cmd(self._sandbox)
            except ValueError:
                # Fall back to CMD parsed from Dockerfile (if available).
                if parsed_cmd:
                    cmd = parsed_cmd
                else:
                    raise

        # Wrap in bash -c so compound commands (cd ... && uvicorn ...)
        # are handled correctly by nohup. Write PID so we can check
        # if the process crashed later in wait_for_ready().
        escaped_cmd = shlex.quote(cmd)
        self._sandbox.process.exec(
            f"nohup bash -c {escaped_cmd} > /tmp/openenv-server.log 2>&1 &"
            " echo $! > /tmp/openenv-server.pid",
            timeout=10,
        )

        # Get a signed preview URL for port 8000. The token is
        # embedded in the URL itself so no extra headers are needed.
        signed = self._sandbox.create_signed_preview_url(
            8000, expires_in_seconds=86400
        )
        self._preview_url = signed.url
    except Exception:
        # Any failure after sandbox creation: tear the sandbox down so
        # we never leak a running (billed) Daytona sandbox.
        self.stop_container()
        raise

    return self._preview_url
495
+
496
def refresh_preview_url(self) -> str:
    """Issue a fresh 24-hour signed preview URL for the running sandbox.

    Daytona signed URLs expire after at most 24 hours, so long-running
    sessions call this to obtain a replacement. The new URL targets the
    same sandbox, but existing clients must reconnect using it.

    Raises:
        RuntimeError: If no sandbox is currently active.
    """
    sandbox = self._sandbox
    if sandbox is None:
        raise RuntimeError("No active sandbox to refresh URL for.")
    self._preview_url = sandbox.create_signed_preview_url(
        8000, expires_in_seconds=86400
    ).url
    return self._preview_url
508
+
509
def stop_container(self) -> None:
    """Tear down the Daytona sandbox and clear cached handles.

    No-op when no sandbox is active. The local references are cleared
    even if the remote delete raises, so the provider never holds a
    stale handle.
    """
    sandbox = self._sandbox
    if sandbox is None:
        return

    try:
        self._daytona.delete(sandbox)
    finally:
        self._sandbox = None
        self._preview_url = None
519
+
520
def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None:
    """
    Poll the /health endpoint until the sandbox is ready.

    Uses a longer default timeout (120s) than Docker providers because
    Daytona sandboxes may have cold-start latency.

    Args:
        base_url: Preview URL returned by ``start_container()``.
        timeout_s: Maximum seconds to wait.

    Raises:
        TimeoutError: If the sandbox doesn't become ready in time.
        RuntimeError: If the server process died (detected via PID check).
    """
    import requests

    health_url = f"{base_url}/health"

    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            response = requests.get(health_url, timeout=5.0)
            if response.status_code == 200:
                return
        except requests.RequestException:
            # Not up yet (connection refused / proxy not routed); retry.
            pass

        # Early exit: if the server process died, raise immediately
        # instead of waiting for the full health-check timeout.
        # The PID file is written by start_container's nohup launcher.
        if self._sandbox is not None:
            resp = self._sandbox.process.exec(
                "kill -0 $(cat /tmp/openenv-server.pid) 2>/dev/null"
                " && echo RUNNING || echo DEAD",
                timeout=10,
            )
            out = resp.result if hasattr(resp, "result") else str(resp)
            if "DEAD" in (out or ""):
                # Surface the server's own log to make the failure
                # diagnosable without shelling into the sandbox.
                log_resp = self._sandbox.process.exec(
                    "cat /tmp/openenv-server.log 2>/dev/null", timeout=10
                )
                log = (
                    log_resp.result
                    if hasattr(log_resp, "result")
                    else str(log_resp)
                )
                raise RuntimeError(f"Server process died.\nLog:\n{log}")

        time.sleep(1.0)

    raise TimeoutError(
        f"Daytona sandbox at {base_url} did not become ready within {timeout_s}s"
    )
src/core/containers/runtime/providers.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Container provider abstractions for running environment servers.
9
+
10
+ This module provides a pluggable architecture for different container providers
11
+ (local Docker, Kubernetes, cloud providers, etc.) to be used with EnvClient.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from abc import ABC, abstractmethod
17
+ from typing import Any, Dict, Optional, Sequence
18
+
19
+
20
class ContainerProvider(ABC):
    """
    Abstract base class for container providers.

    Providers implement this interface to support different container platforms:
    - LocalDockerProvider: Runs containers on local Docker daemon
    - KubernetesProvider: Runs containers in Kubernetes cluster
    - FargateProvider: Runs containers on AWS Fargate
    - CloudRunProvider: Runs containers on Google Cloud Run

    The provider manages a single container lifecycle and provides the base URL
    for connecting to it.

    Example:
        >>> provider = LocalDockerProvider()
        >>> base_url = provider.start_container("echo-env:latest")
        >>> print(base_url)  # http://localhost:8000
        >>> # Use the environment via base_url
        >>> provider.stop_container()
    """

    @abstractmethod
    def start_container(
        self,
        image: str,
        port: Optional[int] = None,
        env_vars: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> str:
        """
        Start a container from the specified image.

        Args:
            image: Container image name (e.g., "echo-env:latest")
            port: Port to expose (if None, provider chooses)
            env_vars: Environment variables to pass to container
            **kwargs: Provider-specific options

        Returns:
            Base URL to connect to the container (e.g., "http://localhost:8000")

        Raises:
            RuntimeError: If container fails to start
        """
        pass

    @abstractmethod
    def stop_container(self) -> None:
        """
        Stop and remove the running container.

        This cleans up the container that was started by start_container().
        """
        pass

    @abstractmethod
    def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None:
        """
        Wait for the container to be ready to accept requests.

        This typically polls the /health endpoint until it returns 200.

        Args:
            base_url: Base URL of the container
            timeout_s: Maximum time to wait

        Raises:
            TimeoutError: If container doesn't become ready in time
        """
        pass
90
+
91
+
92
class LocalDockerProvider(ContainerProvider):
    """
    Container provider for local Docker daemon.

    This provider runs containers on the local machine using Docker.
    Useful for development and testing.

    Example:
        >>> provider = LocalDockerProvider()
        >>> base_url = provider.start_container("echo-env:latest")
        >>> # Container running on http://localhost:<random-port>
        >>> provider.stop_container()
    """

    def __init__(self):
        """Initialize the provider and verify the Docker CLI is usable.

        Raises:
            RuntimeError: If the ``docker`` CLI is missing, fails, or
                does not respond within 5 seconds.
        """
        self._container_id: Optional[str] = None
        self._container_name: Optional[str] = None

        # Fail fast at construction time if Docker is unavailable.
        import subprocess

        try:
            subprocess.run(
                ["docker", "version"],
                check=True,
                capture_output=True,
                timeout=5,
            )
        except (
            subprocess.CalledProcessError,
            FileNotFoundError,
            subprocess.TimeoutExpired,
        ):
            raise RuntimeError(
                "Docker is not available. Please install Docker Desktop or Docker Engine."
            )

    def start_container(
        self,
        image: str,
        port: Optional[int] = None,
        env_vars: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> str:
        """
        Start a Docker container locally.

        Args:
            image: Docker image name
            port: Host port to publish (if None, finds an available port);
                always mapped to container port 8000.
            env_vars: Environment variables for the container
            **kwargs: Additional Docker run options (currently unused)

        Returns:
            Base URL to connect to the container

        Raises:
            RuntimeError: If ``docker run`` fails.
        """
        import subprocess
        import time

        # Find available port if not specified
        if port is None:
            port = self._find_available_port()

        # Generate a unique container name so repeated starts don't clash.
        self._container_name = self._generate_container_name(image)

        # Build docker run command (list form — no shell involved).
        cmd = [
            "docker",
            "run",
            "-d",  # Detached
            "--name",
            self._container_name,
            "-p",
            f"{port}:8000",  # Map host port to the server's fixed port 8000
        ]

        # Add environment variables
        if env_vars:
            for key, value in env_vars.items():
                cmd.extend(["-e", f"{key}={value}"])

        # Add image
        cmd.append(image)

        # Run container
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, check=True)
            self._container_id = result.stdout.strip()
        except subprocess.CalledProcessError as e:
            error_msg = f"Failed to start Docker container.\nCommand: {' '.join(cmd)}\nExit code: {e.returncode}\nStderr: {e.stderr}\nStdout: {e.stdout}"
            raise RuntimeError(error_msg) from e

        # Give the daemon a moment before callers start polling health.
        time.sleep(1)

        base_url = f"http://localhost:{port}"
        return base_url

    def stop_container(self) -> None:
        """
        Stop and remove the Docker container.

        Bug fix: stop and remove are attempted independently. Previously
        both calls shared one ``try`` block, so a container that was
        already stopped (making ``docker stop`` fail) was never removed
        and leaked on the daemon.
        """
        if self._container_id is None:
            return

        import subprocess

        try:
            try:
                subprocess.run(
                    ["docker", "stop", self._container_id],
                    capture_output=True,
                    check=True,
                    timeout=10,
                )
            except subprocess.CalledProcessError:
                # Container might already be stopped; still remove below.
                pass

            try:
                subprocess.run(
                    ["docker", "rm", self._container_id],
                    capture_output=True,
                    check=True,
                    timeout=10,
                )
            except subprocess.CalledProcessError:
                # Container might already be removed.
                pass
        finally:
            # Always drop local references, even on timeout.
            self._container_id = None
            self._container_name = None

    def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None:
        """
        Wait for container to be ready by polling /health endpoint.

        Args:
            base_url: Base URL of the container
            timeout_s: Maximum time to wait

        Raises:
            TimeoutError: If container doesn't become ready
        """
        import time

        import requests

        start_time = time.time()
        health_url = f"{base_url}/health"

        # Bypass proxy for localhost to avoid proxy issues
        proxies = {"http": None, "https": None}

        while time.time() - start_time < timeout_s:
            try:
                response = requests.get(health_url, timeout=2.0, proxies=proxies)
                if response.status_code == 200:
                    return
            except requests.RequestException:
                # Server not accepting connections yet; keep polling.
                pass

            time.sleep(0.5)

        raise TimeoutError(
            f"Container at {base_url} did not become ready within {timeout_s}s"
        )

    def _find_available_port(self) -> int:
        """
        Find an available port on localhost.

        Note: the port is released before Docker binds it, so a race
        with another process is possible (standard ephemeral-port TOCTOU).

        Returns:
            An available port number
        """
        import socket

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            s.listen(1)
            port = s.getsockname()[1]
        return port

    def _generate_container_name(self, image: str) -> str:
        """
        Generate a unique container name based on image name and timestamp.

        Args:
            image: Docker image name

        Returns:
            A unique container name (``<image-basename>-<millis>``)
        """
        import time

        # Drop any registry/repo prefix and the tag.
        clean_image = image.split("/")[-1].split(":")[0]
        timestamp = int(time.time() * 1000)
        return f"{clean_image}-{timestamp}"
289
+
290
+
291
class DockerSwarmProvider(ContainerProvider):
    """
    Container provider that uses Docker Swarm services for local concurrency.

    This provider creates a replicated Swarm service backed by the local Docker
    engine. The built-in load-balancer fans requests across the replicas,
    allowing multiple container instances to run concurrently on the developer
    workstation (mirroring the workflow described in the Docker stack docs).
    """

    def __init__(
        self,
        *,
        auto_init_swarm: bool = True,
        overlay_network: Optional[str] = None,
    ):
        """
        Args:
            auto_init_swarm: Whether to call ``docker swarm init`` when Swarm
                is not active. Otherwise, user must manually initialize Swarm.
            overlay_network: Optional overlay network name for the service.
                When provided, the network is created with
                ``docker network create --driver overlay --attachable`` if it
                does not already exist.

        Raises:
            RuntimeError: If Docker is unavailable, Swarm cannot be
                initialized, or the overlay network cannot be created.
        """
        self._service_name: Optional[str] = None
        self._service_id: Optional[str] = None
        self._published_port: Optional[int] = None
        self._overlay_network = overlay_network
        self._auto_init_swarm = auto_init_swarm

        # Fail fast: validate the whole environment at construction time.
        self._ensure_docker_available()
        self._ensure_swarm_initialized()
        if self._overlay_network:
            self._ensure_overlay_network(self._overlay_network)

    def start_container(
        self,
        image: str,
        port: Optional[int] = None,
        env_vars: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> str:
        """
        Start (or scale) a Swarm service for the given image.

        Supported kwargs:
            replicas (int): Number of container replicas (default: 2).
            cpu_limit (float | str): CPU limit passed to ``--limit-cpu``.
            memory_limit (str): Memory limit passed to ``--limit-memory``.
            constraints (Sequence[str]): Placement constraints.
            labels (Dict[str, str]): Service labels.
            command (Sequence[str] | str): Override container command.

        Returns:
            Base URL of the published service on localhost.

        Raises:
            ValueError: On unrecognized kwargs.
            RuntimeError: If ``docker service create`` fails.
        """
        import shlex
        import subprocess
        import time

        # Reject typos/unsupported options up front instead of silently
        # ignoring them.
        allowed_kwargs = {
            "replicas",
            "cpu_limit",
            "memory_limit",
            "constraints",
            "labels",
            "command",
        }
        unknown = set(kwargs) - allowed_kwargs
        if unknown:
            raise ValueError(f"Unsupported kwargs for DockerSwarmProvider: {unknown}")

        replicas = int(kwargs.get("replicas", 2))
        cpu_limit = kwargs.get("cpu_limit")
        memory_limit = kwargs.get("memory_limit")
        constraints: Optional[Sequence[str]] = kwargs.get("constraints")
        labels: Optional[Dict[str, str]] = kwargs.get("labels")
        command_override = kwargs.get("command")

        if port is None:
            port = self._find_available_port()

        self._service_name = self._generate_service_name(image)
        self._published_port = port

        cmd = [
            "docker",
            "service",
            "create",
            "--detach",
            "--name",
            self._service_name,
            "--replicas",
            str(max(1, replicas)),  # never create a 0-replica service
            "--publish",
            f"{port}:8000",  # Swarm LB publishes host port -> container 8000
        ]

        if self._overlay_network:
            cmd.extend(["--network", self._overlay_network])

        if env_vars:
            for key, value in env_vars.items():
                cmd.extend(["--env", f"{key}={value}"])

        if cpu_limit is not None:
            cmd.extend(["--limit-cpu", str(cpu_limit)])

        if memory_limit is not None:
            cmd.extend(["--limit-memory", str(memory_limit)])

        if constraints:
            for constraint in constraints:
                cmd.extend(["--constraint", constraint])

        if labels:
            for key, value in labels.items():
                cmd.extend(["--label", f"{key}={value}"])

        cmd.append(image)

        # Command override goes after the image, per `docker service
        # create` argument order. Strings are shell-split client-side.
        if command_override:
            if isinstance(command_override, str):
                cmd.extend(shlex.split(command_override))
            else:
                cmd.extend(command_override)

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                check=True,
            )
            self._service_id = result.stdout.strip()
        except subprocess.CalledProcessError as e:
            error_msg = (
                "Failed to start Docker Swarm service.\n"
                f"Command: {' '.join(cmd)}\n"
                f"Exit code: {e.returncode}\n"
                f"Stdout: {e.stdout}\n"
                f"Stderr: {e.stderr}"
            )
            raise RuntimeError(error_msg) from e

        # Give Swarm a brief moment to schedule the tasks.
        time.sleep(1.0)

        return f"http://localhost:{port}"

    def stop_container(self) -> None:
        """
        Remove the Swarm service (and keep the Swarm manager running).
        """
        if not self._service_name:
            return

        import subprocess

        try:
            subprocess.run(
                ["docker", "service", "rm", self._service_name],
                capture_output=True,
                check=True,
                timeout=10,
            )
        except subprocess.CalledProcessError:
            # Service may already be gone; ignore.
            pass
        finally:
            # Always clear local state, even on timeout.
            self._service_name = None
            self._service_id = None
            self._published_port = None

    def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None:
        """
        Wait for at least one replica to become healthy by polling /health.

        Note: With Swarm's load balancer, requests round-robin across replicas,
        so this only verifies that at least one replica is responding. Some
        replicas may still be starting when this returns.

        Raises:
            TimeoutError: If no replica responds within ``timeout_s``.
        """
        import time

        import requests

        deadline = time.time() + timeout_s
        health_url = f"{base_url}/health"

        # Bypass proxy for localhost to avoid proxy issues
        proxies = {"http": None, "https": None}

        while time.time() < deadline:
            try:
                response = requests.get(health_url, timeout=2.0, proxies=proxies)
                if response.status_code == 200:
                    return
            except requests.RequestException:
                # No replica reachable yet; keep polling.
                pass

            time.sleep(0.5)

        raise TimeoutError(
            f"Swarm service at {base_url} did not become ready within {timeout_s}s"
        )

    def _ensure_docker_available(self) -> None:
        # Probe the docker CLI; raise a uniform RuntimeError on any failure.
        import subprocess

        try:
            subprocess.run(
                ["docker", "version"],
                check=True,
                capture_output=True,
                timeout=5,
            )
        except (
            subprocess.CalledProcessError,
            FileNotFoundError,
            subprocess.TimeoutExpired,
        ) as exc:
            raise RuntimeError(
                "Docker is not available. Please install Docker Desktop or Docker Engine."
            ) from exc

    def _ensure_swarm_initialized(self) -> None:
        # Check the node's Swarm state and optionally `swarm init` it.
        # NOTE(review): a subprocess.TimeoutExpired from either call is
        # not caught here and would propagate — confirm that is intended.
        import subprocess

        try:
            result = subprocess.run(
                ["docker", "info", "--format", "{{.Swarm.LocalNodeState}}"],
                capture_output=True,
                text=True,
                check=True,
                timeout=5,
            )
            state = result.stdout.strip().lower()
            if state == "active":
                return
        except subprocess.CalledProcessError:
            state = "unknown"

        if not self._auto_init_swarm:
            raise RuntimeError(
                f"Docker Swarm is not active (state={state}). Enable Swarm manually or pass auto_init_swarm=True."
            )

        try:
            subprocess.run(
                ["docker", "swarm", "init"],
                check=True,
                capture_output=True,
                timeout=10,
            )
        except subprocess.CalledProcessError as e:
            raise RuntimeError("Failed to initialize Docker Swarm") from e

    def _ensure_overlay_network(self, network: str) -> None:
        # Create the attachable overlay network only if it doesn't exist
        # (inspect returning 0 means it already does).
        import subprocess

        inspect = subprocess.run(
            ["docker", "network", "inspect", network],
            capture_output=True,
            text=True,
            check=False,
        )
        if inspect.returncode == 0:
            return

        try:
            subprocess.run(
                [
                    "docker",
                    "network",
                    "create",
                    "--driver",
                    "overlay",
                    "--attachable",
                    network,
                ],
                check=True,
                capture_output=True,
                timeout=10,
            )
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Failed to create overlay network '{network}'") from e

    def _find_available_port(self) -> int:
        # Ask the OS for a free ephemeral port; subject to the usual
        # TOCTOU race between release and Swarm publishing it.
        import socket

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            s.listen(1)
            port = s.getsockname()[1]
        return port

    def _generate_service_name(self, image: str) -> str:
        # Unique, readable service name: <image-basename>-swarm-<millis>.
        import time

        clean_image = image.split("/")[-1].split(":")[0]
        timestamp = int(time.time() * 1000)
        return f"{clean_image}-swarm-{timestamp}"
591
+
592
+
593
class KubernetesProvider(ContainerProvider):
    """
    Container provider for Kubernetes clusters.

    This provider creates pods in a Kubernetes cluster and exposes them
    via services or port-forwarding.

    NOTE(review): placeholder — none of the abstract methods are
    implemented yet, so this class is itself still abstract and cannot
    be instantiated. The example below documents the intended API only.

    Example:
        >>> provider = KubernetesProvider(namespace="envtorch-dev")
        >>> base_url = provider.start_container("echo-env:latest")
        >>> # Pod running in k8s, accessible via service or port-forward
        >>> provider.stop_container()
    """

    pass
608
+
609
+
610
class RuntimeProvider(ABC):
    """
    Abstract base class for runtime providers that are not container providers.
    Providers implement this interface to support different runtime platforms:
    - UVProvider: Runs environments via `uv run`

    The provider manages a single runtime lifecycle and provides the base URL
    for connecting to it. Also usable as a context manager: ``__enter__``
    starts the runtime and ``__exit__`` stops it.

    Example:
        >>> provider = UVProvider(project_path="/path/to/env")
        >>> base_url = provider.start()
        >>> print(base_url)  # http://localhost:8000
        >>> provider.stop()
    """

    @abstractmethod
    def start(
        self,
        port: Optional[int] = None,
        env_vars: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> str:
        """
        Start the runtime.

        Args:
            port: Port to expose (if None, provider chooses)
            env_vars: Environment variables for the runtime
            **kwargs: Additional provider-specific runtime options

        Returns:
            Base URL to connect to the runtime (e.g., "http://localhost:8000")
        """

    @abstractmethod
    def stop(self) -> None:
        """
        Stop the runtime.
        """
        pass

    @abstractmethod
    def wait_for_ready(self, timeout_s: float = 30.0) -> None:
        """
        Wait for the runtime to be ready to accept requests.

        Args:
            timeout_s: Maximum time to wait before giving up.
        """
        pass

    def __enter__(self) -> "RuntimeProvider":
        """
        Start the runtime and return this provider (context-manager entry).
        """
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        """
        Stop the runtime on context-manager exit.

        Returns False so any exception raised inside the ``with`` body
        propagates to the caller.
        """
        self.stop()
        return False
src/core/containers/runtime/uv_provider.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Providers for launching ASGI applications via ``uv run``."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import socket
7
+ import subprocess
8
+ import time
9
+ from typing import Dict, Optional
10
+
11
+ import requests
12
+
13
+ from .providers import RuntimeProvider
14
+
15
+
16
+ def _check_uv_installed() -> None:
17
+ try:
18
+ subprocess.check_output(["uv", "--version"])
19
+ except FileNotFoundError as exc:
20
+ raise RuntimeError(
21
+ "`uv` executable not found. Install uv from https://docs.astral.sh and ensure it is on PATH."
22
+ ) from exc
23
+
24
+
25
+ def _find_free_port() -> int:
26
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
27
+ sock.bind(("", 0))
28
+ sock.listen(1)
29
+ return sock.getsockname()[1]
30
+
31
+
32
+ def _create_uv_command(
33
+ *,
34
+ host: str,
35
+ port: int,
36
+ reload: bool,
37
+ workers: int,
38
+ app: str,
39
+ project_path: str,
40
+ ) -> list[str]:
41
+ command: list[str] = ["uv", "run", "--isolated", "--project", project_path]
42
+
43
+ command.append("--")
44
+ command.extend(
45
+ [
46
+ "uvicorn",
47
+ app,
48
+ "--host",
49
+ host,
50
+ "--port",
51
+ str(port),
52
+ "--workers",
53
+ str(workers),
54
+ ]
55
+ )
56
+
57
+ if reload:
58
+ command.append("--reload")
59
+
60
+ return command
61
+
62
+
63
+ def _poll_health(health_url: str, timeout_s: float) -> None:
64
+ """Poll a health endpoint until it returns HTTP 200 or times out."""
65
+
66
+ deadline = time.time() + timeout_s
67
+ while time.time() < deadline:
68
+ try:
69
+ timeout = max(0.0001, min(deadline - time.time(), 2.0))
70
+ response = requests.get(health_url, timeout=timeout)
71
+ if response.status_code == 200:
72
+ return
73
+ except requests.RequestException:
74
+ continue
75
+
76
+ time.sleep(0.5)
77
+
78
+ raise TimeoutError(f"Server did not become ready within {timeout_s:.1f} seconds")
79
+
80
+
81
+ class UVProvider(RuntimeProvider):
82
+ """
83
+ RuntimeProvider implementation backed by ``uv run``.
84
+
85
+ Args:
86
+ project_path: Local path to a uv project (passed to ``uv run --project``)
87
+ app: ASGI application path for uvicorn (defaults to ``server.app:app``)
88
+ host: Host interface to bind to (defaults to ``0.0.0.0``)
89
+ reload: Whether to enable uvicorn's reload mode
90
+ env_vars: Environment variables to pass through to the spawned process
91
+ context_timeout_s: How long to wait for the environment to become ready
92
+
93
+ Example:
94
+ >>> provider = UVProvider(project_path="/path/to/env")
95
+ >>> base_url = provider.start()
96
+ >>> print(base_url) # http://localhost:8000
97
+ >>> # Use the environment via base_url
98
+ >>> provider.stop()
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ *,
104
+ project_path: str,
105
+ app: str = "server.app:app",
106
+ host: str = "0.0.0.0",
107
+ reload: bool = False,
108
+ env_vars: Optional[Dict[str, str]] = None,
109
+ context_timeout_s: float = 60.0,
110
+ ):
111
+ """Initialize the UVProvider."""
112
+ self.project_path = os.path.abspath(project_path)
113
+ self.app = app
114
+ self.host = host
115
+ self.reload = reload
116
+ self.env_vars = env_vars
117
+ self.context_timeout_s = context_timeout_s
118
+ _check_uv_installed()
119
+ self._process = None
120
+ self._base_url = None
121
+
122
+ def start(
123
+ self,
124
+ port: Optional[int] = None,
125
+ env_vars: Optional[Dict[str, str]] = None,
126
+ workers: int = 1,
127
+ **_: Dict[str, str],
128
+ ) -> str:
129
+ """
130
+ Start the environment via `uv run`.
131
+
132
+ Args:
133
+ port: The port to bind the environment to
134
+ env_vars: Environment variables to pass to the environment
135
+ workers: The number of workers to use
136
+
137
+ Returns:
138
+ The base URL of the environment
139
+
140
+ Raises:
141
+ RuntimeError: If the environment is already running
142
+ """
143
+ if self._process is not None and self._process.poll() is None:
144
+ raise RuntimeError("UVProvider is already running")
145
+
146
+ bind_port = port or _find_free_port()
147
+
148
+ command = _create_uv_command(
149
+ host=self.host,
150
+ port=bind_port,
151
+ reload=self.reload,
152
+ workers=workers,
153
+ app=self.app,
154
+ project_path=self.project_path,
155
+ )
156
+
157
+ env = os.environ.copy()
158
+
159
+ if self.env_vars:
160
+ env.update(self.env_vars)
161
+ if env_vars:
162
+ env.update(env_vars)
163
+
164
+ try:
165
+ self._process = subprocess.Popen(command, env=env)
166
+ except OSError as exc:
167
+ raise RuntimeError(f"Failed to launch `uv run`: {exc}") from exc
168
+
169
+ client_host = "127.0.0.1" if self.host in {"0.0.0.0", "::"} else self.host
170
+ self._base_url = f"http://{client_host}:{bind_port}"
171
+ return self._base_url
172
+
173
+ def wait_for_ready(self, timeout_s: float = 60.0) -> None:
174
+ """
175
+ Wait for the environment to become ready.
176
+
177
+ Args:
178
+ timeout_s: The timeout to wait for the environment to become ready
179
+
180
+ Raises:
181
+ RuntimeError: If the environment is not running
182
+ TimeoutError: If the environment does not become ready within the timeout
183
+ """
184
+ if self._process and self._process.poll() is not None:
185
+ code = self._process.returncode
186
+ raise RuntimeError(f"uv process exited prematurely with code {code}")
187
+
188
+ _poll_health(f"{self._base_url}/health", timeout_s=timeout_s)
189
+
190
+ def stop(self) -> None:
191
+ """
192
+ Stop the environment.
193
+
194
+ Raises:
195
+ RuntimeError: If the environment is not running
196
+ """
197
+ if self._process is None:
198
+ return
199
+
200
+ if self._process.poll() is None:
201
+ self._process.terminate()
202
+ try:
203
+ self._process.wait(timeout=10.0)
204
+ except subprocess.TimeoutExpired:
205
+ self._process.kill()
206
+ self._process.wait(timeout=5.0)
207
+
208
+ self._process = None
209
+ self._base_url = None
210
+
211
+ @property
212
+ def base_url(self) -> str:
213
+ """
214
+ The base URL of the environment.
215
+
216
+ Returns:
217
+ The base URL of the environment
218
+
219
+ Raises:
220
+ RuntimeError: If the environment is not running
221
+ """
222
+ if self._base_url is None:
223
+ raise RuntimeError("UVProvider has not been started")
224
+ return self._base_url
src/core/containers/test_local_docker_provider.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ End-to-end test for LocalDockerProvider.
4
+
5
+ This script tests the complete flow:
6
+ 1. Start a container using LocalDockerProvider
7
+ 2. Wait for it to be ready
8
+ 3. Make HTTP requests to test the environment
9
+ 4. Clean up the container
10
+ """
11
+
12
+ import sys
13
+ from pathlib import Path
14
+
15
+ # Add src to path
16
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
17
+
18
+ import requests
19
+ from openenv.core.containers.runtime import LocalDockerProvider
20
+
21
+
22
+ # TODO: Remove this test or make it a functional test since this will be tested in e2e test for echo env
23
def test_local_docker_provider():
    """Test LocalDockerProvider end-to-end.

    Requires a local Docker daemon with an ``echo-env:latest`` image built.
    Returns True on success and False on failure; the result is aggregated
    by the ``__main__`` runner in this file rather than by pytest.
    """
    print("=" * 60)
    print("LocalDockerProvider End-to-End Test")
    print("=" * 60)
    print()

    provider = None

    try:
        # Step 1: Create provider
        print("Step 1: Creating LocalDockerProvider...")
        provider = LocalDockerProvider()
        print("✓ Provider created\n")

        # Step 2: Start container
        print("Step 2: Starting echo-env container...")
        base_url = provider.start_container("echo-env:latest")
        print(f"✓ Container started at: {base_url}")
        # NOTE(review): peeks at provider internals (_container_id /
        # _container_name) purely for diagnostics output.
        if provider._container_id:
            print(f" Container ID: {provider._container_id[:12]}...")
        if provider._container_name:
            print(f" Container name: {provider._container_name}\n")

        # Step 3: Wait for ready
        print("Step 3: Waiting for container to be ready...")
        provider.wait_for_ready(base_url, timeout_s=30.0)
        print("✓ Container is ready!\n")

        # Step 4: Test health endpoint
        print("Step 4: Testing /health endpoint...")
        response = requests.get(f"{base_url}/health")
        print(f" Status: {response.status_code}")
        print(f" Response: {response.json()}")
        assert response.status_code == 200
        assert response.json()["status"] == "healthy"
        print("✓ Health check passed\n")

        # Step 5: Test reset endpoint
        print("Step 5: Testing /reset endpoint...")
        response = requests.post(
            f"{base_url}/reset",
            json={},
            headers={"Content-Type": "application/json"},
        )
        print(f" Status: {response.status_code}")
        data = response.json()
        print(f" Message: {data['observation']['echoed_message']}")
        print(f" Reward: {data['reward']}")
        print(f" Done: {data['done']}")
        assert response.status_code == 200
        assert data["observation"]["echoed_message"] == "Echo environment ready!"
        print("✓ Reset test passed\n")

        # Step 6: Test step endpoint
        print("Step 6: Testing /step endpoint...")
        response = requests.post(
            f"{base_url}/step",
            json={"action": {"message": "Hello from LocalDockerProvider!"}},
            headers={"Content-Type": "application/json"},
        )
        print(f" Status: {response.status_code}")
        data = response.json()
        print(f" Echoed: {data['observation']['echoed_message']}")
        print(f" Length: {data['observation']['message_length']}")
        print(f" Reward: {data['reward']}")
        assert response.status_code == 200
        assert (
            data["observation"]["echoed_message"] == "Hello from LocalDockerProvider!"
        )
        assert data["observation"]["message_length"] == 31
        print("✓ Step test passed\n")

        # Step 7: Test state endpoint
        print("Step 7: Testing /state endpoint...")
        response = requests.get(f"{base_url}/state")
        print(f" Status: {response.status_code}")
        data = response.json()
        print(f" Episode ID: {data['episode_id']}")
        print(f" Step count: {data['step_count']}")
        assert response.status_code == 200
        assert data["step_count"] == 1  # One step from above
        print("✓ State test passed\n")

        # Step 8: Multiple steps
        print("Step 8: Testing multiple steps...")
        for i in range(3):
            response = requests.post(
                f"{base_url}/step",
                json={"action": {"message": f"Message {i + 1}"}},
                headers={"Content-Type": "application/json"},
            )
            assert response.status_code == 200
            print(f" Step {i + 1}: ✓")

        # Check state updated
        response = requests.get(f"{base_url}/state")
        data = response.json()
        assert data["step_count"] == 4  # 1 + 3 more steps
        print(f" Final step count: {data['step_count']}")
        print("✓ Multiple steps test passed\n")

        print("=" * 60)
        print("✓ All tests passed!")
        print("=" * 60)
        print()

        return True

    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        import traceback

        traceback.print_exc()
        return False

    finally:
        # Step 9: Cleanup — always attempt to stop the container, even on
        # failure, so no orphaned containers are left behind.
        if provider is not None:
            print("\nStep 9: Cleaning up container...")
            try:
                provider.stop_container()
                print("✓ Container stopped and removed\n")
            except Exception as e:
                print(f"⚠️ Cleanup warning: {e}\n")
148
+
149
+
150
def test_provider_with_custom_port():
    """Test provider with custom port."""
    for banner_line in ("=" * 60, "LocalDockerProvider with Custom Port Test", "=" * 60):
        print(banner_line)
    print()

    docker_provider = None

    try:
        docker_provider = LocalDockerProvider()

        print("Starting container on custom port 8123...")
        endpoint = docker_provider.start_container("echo-env:latest", port=8123)
        print(f"✓ Started at: {endpoint}")
        assert ":8123" in endpoint

        print("Waiting for ready...")
        docker_provider.wait_for_ready(endpoint)
        print("✓ Ready!")

        print("Testing health...")
        health_response = requests.get(f"{endpoint}/health")
        assert health_response.status_code == 200
        print("✓ Health check passed")

        print("\n✓ Custom port test passed!\n")
        return True

    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        return False

    finally:
        if docker_provider is not None:
            docker_provider.stop_container()
            print("✓ Cleaned up\n")
187
+
188
+
189
def test_provider_with_env_vars():
    """Test provider with environment variables."""
    for banner_line in (
        "=" * 60,
        "LocalDockerProvider with Environment Variables Test",
        "=" * 60,
    ):
        print(banner_line)
    print()

    docker_provider = None

    try:
        docker_provider = LocalDockerProvider()

        print("Starting container with environment variables...")
        extra_env = {"DEBUG": "true", "LOG_LEVEL": "info"}
        endpoint = docker_provider.start_container("echo-env:latest", env_vars=extra_env)
        print(f"✓ Started at: {endpoint}")

        print("Waiting for ready...")
        docker_provider.wait_for_ready(endpoint)
        print("✓ Ready!")

        print("Testing health...")
        health_response = requests.get(f"{endpoint}/health")
        assert health_response.status_code == 200
        print("✓ Health check passed")

        print("\n✓ Environment variables test passed!\n")
        return True

    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        return False

    finally:
        if docker_provider is not None:
            docker_provider.stop_container()
            print("✓ Cleaned up\n")
226
+ print("✓ Cleaned up\n")
227
+
228
+
229
if __name__ == "__main__":
    print()
    print("🐳 LocalDockerProvider Test Suite")
    print()

    results = []

    # Run basic test
    results.append(("Basic End-to-End", test_local_docker_provider()))

    # Run custom port test
    results.append(("Custom Port", test_provider_with_custom_port()))

    # Run environment variables test
    results.append(("Environment Variables", test_provider_with_env_vars()))

    # Summary
    print("=" * 60)
    print("Test Summary")
    print("=" * 60)
    for name, passed in results:
        status = "✓ PASSED" if passed else "✗ FAILED"
        print(f"{name:25} {status}")
    print("=" * 60)

    all_passed = all(result for _, result in results)
    if all_passed:
        print("\n🎉 All tests passed!")
        # sys.exit instead of the site-module exit() helper, which is meant
        # for interactive use and is absent when Python runs with -S.
        sys.exit(0)
    else:
        print("\n❌ Some tests failed")
        sys.exit(1)
src/core/env_client.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Environment client for persistent sessions.
9
+
10
+ This module provides a WebSocket-based client that maintains a persistent connection
11
+ to an environment server, enabling efficient multi-step interactions without
12
+ the overhead of HTTP request/response cycles.
13
+
14
+ The client is async by default. For synchronous usage, use the `.sync()` method
15
+ to get a `SyncEnvClient` wrapper.
16
+
17
+ Example (async):
18
+ >>> async with GenericEnvClient(base_url="ws://localhost:8000") as env:
19
+ ... result = await env.reset()
20
+ ... result = await env.step({"code": "print('hello')"})
21
+
22
+ Example (sync wrapper):
23
+ >>> env = GenericEnvClient(base_url="ws://localhost:8000").sync()
24
+ >>> with env:
25
+ ... result = env.reset()
26
+ ... result = env.step({"code": "print('hello')"})
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import asyncio
32
+ import json
33
+ import os
34
+ from abc import ABC, abstractmethod
35
+ from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar
36
+
37
+ from .client_types import StateT, StepResult
38
+ from .containers.runtime import LocalDockerProvider, UVProvider
39
+ from .utils import convert_to_ws_url
40
+
41
+ if TYPE_CHECKING:
42
+ from websockets.asyncio.client import ClientConnection
43
+
44
+ from .containers.runtime import ContainerProvider, RuntimeProvider
45
+ from .sync_client import SyncEnvClient
46
+
47
+ from websockets.asyncio.client import connect as ws_connect
48
+
49
+ ActT = TypeVar("ActT")
50
+ ObsT = TypeVar("ObsT")
51
+ EnvClientT = TypeVar("EnvClientT", bound="EnvClient")
52
+
53
+
54
class EnvClient(ABC, Generic[ActT, ObsT, StateT]):
    """
    Async environment client for persistent sessions.

    This client maintains a persistent WebSocket connection to an environment
    server, enabling efficient multi-step interactions. Each client instance
    corresponds to a dedicated environment session on the server.

    The client is async by default. For synchronous usage, use the `.sync()`
    method to get a `SyncEnvClient` wrapper.

    Features:
    - Lower latency for sequential interactions
    - Session state is maintained server-side
    - Better suited for long-running episodes
    - Async by default for modern Python async/await patterns

    Example (async):
        >>> from envs.coding_env.client import CodingEnv
        >>>
        >>> # Connect to a server using async context manager
        >>> async with CodingEnv(base_url="ws://localhost:8000") as env:
        ...     result = await env.reset(seed=42)
        ...     while not result.done:
        ...         action = agent.predict(result.observation)
        ...         result = await env.step(action)

    Example (sync wrapper):
        >>> env = CodingEnv(base_url="ws://localhost:8000").sync()
        >>> with env:
        ...     result = env.reset(seed=42)
        ...     result = env.step(action)
    """

    def __init__(
        self,
        base_url: str,
        connect_timeout_s: float = 10.0,
        message_timeout_s: float = 60.0,
        max_message_size_mb: float = 100.0,
        provider: Optional["ContainerProvider | RuntimeProvider"] = None,
        mode: Optional[str] = None,
    ):
        """
        Initialize environment client.

        Args:
            base_url: Base URL of the environment server (http:// or ws://).
                Will be converted to ws:// if http:// is provided.
            connect_timeout_s: Timeout for establishing WebSocket connection
            message_timeout_s: Timeout for receiving responses to messages
            max_message_size_mb: Maximum WebSocket message size in megabytes.
                Default 100MB to handle large observations (screenshots, DOM, etc.)
            provider: Optional container/runtime provider for lifecycle management.
                Can be a ContainerProvider (Docker) or RuntimeProvider (UV).
            mode: Communication mode: 'simulation' for Gym-style API (default) or
                'production' for MCP JSON-RPC protocol. Can also be set via the
                OPENENV_CLIENT_MODE environment variable. Constructor parameter
                takes precedence over environment variable. Case-insensitive.
        """
        # Determine mode (constructor > env var > default)
        if mode is None:
            mode = os.environ.get("OPENENV_CLIENT_MODE", "simulation")

        # Normalize and validate mode
        mode = mode.lower()
        if mode not in ("simulation", "production"):
            raise ValueError(
                f"Invalid mode: '{mode}'. Must be 'simulation' or 'production'. "
                f"Set via constructor parameter or OPENENV_CLIENT_MODE environment variable."
            )

        # Store mode (use object.__setattr__ to bypass immutability enforced
        # by __setattr__ below — mode must never change after construction)
        object.__setattr__(self, "_mode", mode)

        # Convert HTTP URL to WebSocket URL
        ws_url = convert_to_ws_url(base_url)

        self._ws_url = f"{ws_url}/ws"
        self._connect_timeout = connect_timeout_s
        self._message_timeout = message_timeout_s
        self._max_message_size = int(
            max_message_size_mb * 1024 * 1024
        )  # Convert MB to bytes
        self._provider = provider
        self._ws: Optional[ClientConnection] = None

    def __setattr__(self, name: str, value: Any) -> None:
        """Prevent modification of _mode after initialization."""
        if name == "_mode" and hasattr(self, "_mode"):
            raise AttributeError("Cannot modify mode after initialization")
        super().__setattr__(name, value)

    async def connect(self) -> "EnvClient":
        """
        Establish WebSocket connection to the server.

        Idempotent: returns immediately if already connected.

        Returns:
            self for method chaining

        Raises:
            ConnectionError: If connection cannot be established
        """
        if self._ws is not None:
            return self

        # Bypass proxy for localhost connections
        ws_url_lower = self._ws_url.lower()
        is_localhost = "localhost" in ws_url_lower or "127.0.0.1" in ws_url_lower

        # NOTE(review): mutating os.environ is process-global; if other async
        # tasks read NO_PROXY concurrently during connect, they observe the
        # temporary value — confirm this is acceptable.
        old_no_proxy = os.environ.get("NO_PROXY")
        if is_localhost:
            # Set NO_PROXY to bypass proxy for localhost
            current_no_proxy = old_no_proxy or ""
            if "localhost" not in current_no_proxy.lower():
                os.environ["NO_PROXY"] = (
                    f"{current_no_proxy},localhost,127.0.0.1"
                    if current_no_proxy
                    else "localhost,127.0.0.1"
                )

        try:
            self._ws = await ws_connect(
                self._ws_url,
                open_timeout=self._connect_timeout,
                max_size=self._max_message_size,
            )
        except Exception as e:
            raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e
        finally:
            # Restore original NO_PROXY value
            if is_localhost:
                if old_no_proxy is None:
                    os.environ.pop("NO_PROXY", None)
                else:
                    os.environ["NO_PROXY"] = old_no_proxy

        return self

    async def disconnect(self) -> None:
        """Close the WebSocket connection (best effort, never raises)."""
        if self._ws is not None:
            try:
                # Send close message
                await self._send({"type": "close"})
            except Exception:
                pass  # Best effort
            try:
                await self._ws.close()
            except Exception:
                pass
            self._ws = None

    async def _ensure_connected(self) -> None:
        """Ensure WebSocket connection is established."""
        if self._ws is None:
            await self.connect()

    async def _send(self, message: Dict[str, Any]) -> None:
        """Send a message over the WebSocket, connecting first if needed."""
        await self._ensure_connected()
        assert self._ws is not None
        await self._ws.send(json.dumps(message))

    async def _receive(self) -> Dict[str, Any]:
        """Receive and parse a message from the WebSocket.

        Raises:
            asyncio.TimeoutError: If no message arrives within message_timeout_s.
        """
        assert self._ws is not None
        raw = await asyncio.wait_for(self._ws.recv(), timeout=self._message_timeout)
        return json.loads(raw)

    async def _send_and_receive(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """Send a message and wait for response.

        Raises:
            RuntimeError: If the server replies with a message of type "error".
        """
        await self._send(message)
        response = await self._receive()

        # Check for error response
        if response.get("type") == "error":
            error_data = response.get("data", {})
            raise RuntimeError(
                f"Server error: {error_data.get('message', 'Unknown error')} "
                f"(code: {error_data.get('code', 'UNKNOWN')})"
            )

        return response

    @classmethod
    async def from_docker_image(
        cls: Type[EnvClientT],
        image: str,
        provider: Optional["ContainerProvider"] = None,
        **kwargs: Any,
    ) -> EnvClientT:
        """
        Create an environment client by spinning up a Docker container.

        Args:
            image: Docker image name to run (e.g., "coding-env:latest")
            provider: Container provider to use (defaults to LocalDockerProvider)
            **kwargs: Additional arguments to pass to provider.start_container()

        Returns:
            Connected client instance
        """
        if provider is None:
            provider = LocalDockerProvider()

        # Start container
        base_url = provider.start_container(image, **kwargs)

        # Wait for server to be ready
        provider.wait_for_ready(base_url)

        # Create and connect client
        # NOTE(review): if connect() raises here, the started container is
        # not stopped — consider wrapping in try/except with cleanup.
        client = cls(base_url=base_url, provider=provider)
        await client.connect()

        return client

    @classmethod
    async def from_env(
        cls: Type[EnvClientT],
        repo_id: str,
        *,
        use_docker: bool = True,
        provider: Optional["ContainerProvider | RuntimeProvider"] = None,
        **provider_kwargs: Any,
    ) -> EnvClientT:
        """
        Create a client from a Hugging Face Space.

        Args:
            repo_id: Hugging Face space identifier ``{org}/{space}``.
            use_docker: When ``True`` (default) pull from the HF registry and
                launch via :class:`LocalDockerProvider`. When ``False`` run the
                space locally with :class:`UVProvider`.
            provider: Optional provider instance to reuse. Must be a
                :class:`ContainerProvider` when ``use_docker=True`` and a
                :class:`RuntimeProvider` otherwise.
            provider_kwargs: Additional keyword arguments forwarded to
                either the container provider's ``start_container`` (docker)
                or to the ``UVProvider`` constructor/start (uv). When
                ``use_docker=False``, the ``project_path`` argument can be
                used to override the default git URL
                (``git+https://huggingface.co/spaces/{repo_id}``).

        Returns:
            Connected client instance

        Examples:
            >>> # Pull and run from HF Docker registry
            >>> env = await MyEnv.from_env("openenv/echo-env")
            >>>
            >>> # Run locally with UV (clones the space)
            >>> env = await MyEnv.from_env("openenv/echo-env", use_docker=False)
            >>>
            >>> # Run from a local checkout
            >>> env = await MyEnv.from_env(
            ...     "openenv/echo-env",
            ...     use_docker=False,
            ...     project_path="/path/to/local/checkout"
            ... )
        """
        # Extract start args that apply to both providers
        start_args = {}
        for key in ("port", "env_vars", "workers"):
            if key in provider_kwargs:
                start_args[key] = provider_kwargs.pop(key)

        if use_docker:
            # Docker mode: pull from HF registry
            docker_provider = provider or LocalDockerProvider()
            tag = provider_kwargs.pop("tag", "latest")
            # HF registry convention: "org/space" -> "org-space" image name
            image = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}"
            base_url = docker_provider.start_container(
                image, **start_args, **provider_kwargs
            )
            docker_provider.wait_for_ready(base_url)

            # NOTE(review): as in from_docker_image, a connect() failure here
            # leaves the container running — confirm cleanup policy.
            client = cls(base_url=base_url, provider=docker_provider)
            await client.connect()
            return client
        else:
            # UV mode: clone and run with uv
            if provider is None:
                uv_kwargs = dict(provider_kwargs)
                project_path = uv_kwargs.pop("project_path", None)
                if project_path is None:
                    project_path = f"git+https://huggingface.co/spaces/{repo_id}"

                provider = UVProvider(project_path=project_path, **uv_kwargs)
            else:
                if provider_kwargs:
                    raise ValueError(
                        "provider_kwargs cannot be used when supplying a provider instance"
                    )

            base_url = provider.start(**start_args)
            provider.wait_for_ready()

            client = cls(base_url=base_url, provider=provider)
            await client.connect()
            return client

    @abstractmethod
    def _step_payload(self, action: ActT) -> Dict[str, Any]:
        """Convert an Action object to the JSON data expected by the env server."""
        raise NotImplementedError

    @abstractmethod
    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]:
        """Convert a JSON response from the env server to StepResult[ObsT]."""
        raise NotImplementedError

    @abstractmethod
    def _parse_state(self, payload: Dict[str, Any]) -> StateT:
        """Convert a JSON response from the state endpoint to a State object."""
        raise NotImplementedError

    async def reset(self, **kwargs: Any) -> StepResult[ObsT]:
        """
        Reset the environment with optional parameters.

        Args:
            **kwargs: Optional parameters passed to the environment's reset method.
                Common parameters include:
                - seed: Random seed for reproducibility
                - episode_id: Custom episode identifier

        Returns:
            StepResult containing initial observation
        """
        message = {
            "type": "reset",
            "data": kwargs,
        }
        response = await self._send_and_receive(message)
        return self._parse_result(response.get("data", {}))

    async def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]:
        """
        Execute an action in the environment.

        Args:
            action: The action to execute
            **kwargs: Optional parameters (currently ignored)

        Returns:
            StepResult containing observation, reward, and done status
        """
        message = {
            "type": "step",
            "data": self._step_payload(action),
        }
        response = await self._send_and_receive(message)
        return self._parse_result(response.get("data", {}))

    async def state(self) -> StateT:
        """
        Get the current environment state from the server.

        Returns:
            State object with environment state information
        """
        message = {"type": "state"}
        response = await self._send_and_receive(message)
        return self._parse_state(response.get("data", {}))

    async def close(self) -> None:
        """
        Close the WebSocket connection and clean up resources.

        If this client was created via from_docker_image() or from_env(),
        this will also stop and remove the associated container/process.
        """
        await self.disconnect()

        if self._provider is not None:
            # Handle both ContainerProvider and RuntimeProvider via duck typing
            if hasattr(self._provider, "stop_container"):
                self._provider.stop_container()
            elif hasattr(self._provider, "stop"):
                self._provider.stop()

    async def __aenter__(self) -> "EnvClient":
        """Enter async context manager, ensuring connection is established."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Exit async context manager, closing connection."""
        await self.close()

    def __enter__(self) -> "EnvClient":
        """Sync context manager entry - raises error suggesting async usage."""
        raise TypeError(
            "EnvClient is async by default. Use 'async with' instead of 'with', "
            "or call .sync() to get a synchronous wrapper:\n"
            "    async with client:  # async usage\n"
            "    with client.sync():  # sync wrapper"
        )

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Sync context manager exit - should not be reached."""
        pass  # pragma: no cover

    def sync(self) -> "SyncEnvClient":
        """
        Return a synchronous wrapper around this async client.

        Use this method when you need synchronous access to the environment
        without async/await syntax. This is useful for:
        - Integration with synchronous codebases
        - Interactive/REPL usage
        - Stopping async from "infecting" the call stack

        Returns:
            SyncEnvClient wrapper that provides synchronous methods

        Example:
            >>> # Create async client and get sync wrapper
            >>> async_client = GenericEnvClient(base_url="http://localhost:8000")
            >>> sync_client = async_client.sync()
            >>>
            >>> # Use synchronous API
            >>> with sync_client:
            ...     result = sync_client.reset()
            ...     result = sync_client.step({"code": "print('hello')"})
        """
        # Imported lazily to avoid a circular import at module load time.
        from .sync_client import SyncEnvClient

        return SyncEnvClient(self)
+ return SyncEnvClient(self)
src/core/env_server/__init__.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Core environment interfaces and types."""
8
+
9
+ from .base_transforms import CompositeTransform, NullTransform
10
+ from .exceptions import (
11
+ ConcurrencyConfigurationError,
12
+ EnvironmentFactoryError,
13
+ OpenEnvError,
14
+ SessionCapacityError,
15
+ SessionCreationError,
16
+ SessionNotFoundError,
17
+ )
18
+ from .http_server import create_app, create_fastapi_app, HTTPEnvServer
19
+ from .interfaces import Environment, Message, ModelTokenizer, Transform
20
+
21
+ try:
22
+ from .mcp_environment import MCPEnvironment
23
+ except ModuleNotFoundError:
24
+ MCPEnvironment = None # type: ignore[assignment]
25
+
26
+ from .mcp_types import (
27
+ CallToolAction,
28
+ CallToolObservation,
29
+ JsonRpcError,
30
+ # JSON-RPC types
31
+ JsonRpcErrorCode,
32
+ JsonRpcRequest,
33
+ JsonRpcResponse,
34
+ ListToolsAction,
35
+ ListToolsObservation,
36
+ McpMethod,
37
+ RESERVED_TOOL_NAMES,
38
+ Tool,
39
+ ToolError,
40
+ ToolErrorType,
41
+ WSMCPMessage,
42
+ WSMCPResponse,
43
+ )
44
+ from .route_config import GetEndpointConfig
45
+ from .serialization import (
46
+ deserialize_action,
47
+ deserialize_action_with_preprocessing,
48
+ serialize_observation,
49
+ )
50
+ from .types import (
51
+ Action,
52
+ BaseMessage,
53
+ ConcurrencyConfig,
54
+ HealthResponse,
55
+ HealthStatus,
56
+ Observation,
57
+ SchemaResponse,
58
+ ServerCapacityStatus,
59
+ ServerMode,
60
+ SessionInfo,
61
+ State,
62
+ WSCloseMessage,
63
+ WSErrorCode,
64
+ WSErrorResponse,
65
+ WSIncomingMessage,
66
+ WSObservationResponse,
67
+ WSResetMessage,
68
+ WSStateMessage,
69
+ WSStateResponse,
70
+ WSStepMessage,
71
+ )
72
+
73
+ try:
74
+ from .web_interface import create_web_interface_app, WebInterfaceManager
75
+ except ModuleNotFoundError:
76
+ create_web_interface_app = None # type: ignore[assignment]
77
+ WebInterfaceManager = None # type: ignore[assignment]
78
+
79
+ __all__ = [
80
+ # Core interfaces
81
+ "Environment",
82
+ "Transform",
83
+ "Message",
84
+ "ModelTokenizer",
85
+ # Types
86
+ "Action",
87
+ "Observation",
88
+ "State",
89
+ "SchemaResponse",
90
+ "HealthResponse",
91
+ # Enums
92
+ "HealthStatus",
93
+ "ServerMode",
94
+ "WSErrorCode",
95
+ # WebSocket message types
96
+ "BaseMessage",
97
+ "WSIncomingMessage",
98
+ "WSResetMessage",
99
+ "WSStepMessage",
100
+ "WSStateMessage",
101
+ "WSCloseMessage",
102
+ "WSObservationResponse",
103
+ "WSStateResponse",
104
+ "WSErrorResponse",
105
+ # Concurrency types
106
+ "ConcurrencyConfig",
107
+ "ServerCapacityStatus",
108
+ "SessionInfo",
109
+ # Exceptions
110
+ "OpenEnvError",
111
+ "ConcurrencyConfigurationError",
112
+ "SessionCapacityError",
113
+ "SessionNotFoundError",
114
+ "SessionCreationError",
115
+ "EnvironmentFactoryError",
116
+ # Base transforms
117
+ "CompositeTransform",
118
+ "NullTransform",
119
+ # HTTP Server
120
+ "HTTPEnvServer",
121
+ "create_app",
122
+ "create_fastapi_app",
123
+ # Web Interface
124
+ "create_web_interface_app",
125
+ "WebInterfaceManager",
126
+ # Serialization utilities
127
+ "deserialize_action",
128
+ "deserialize_action_with_preprocessing",
129
+ "serialize_observation",
130
+ # Route configuration
131
+ "GetEndpointConfig",
132
+ # MCP types
133
+ "Tool",
134
+ "ToolError",
135
+ "ToolErrorType",
136
+ "ListToolsAction",
137
+ "CallToolAction",
138
+ "ListToolsObservation",
139
+ "CallToolObservation",
140
+ "WSMCPMessage",
141
+ "WSMCPResponse",
142
+ "RESERVED_TOOL_NAMES",
143
+ "MCPEnvironment",
144
+ # JSON-RPC types
145
+ "JsonRpcErrorCode",
146
+ "JsonRpcError",
147
+ "JsonRpcRequest",
148
+ "JsonRpcResponse",
149
+ "McpMethod",
150
+ ]
src/core/env_server/base_transforms.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Base transform implementations for composing environment-specific transforms."""
8
+
9
+ from .interfaces import Transform
10
+ from .types import Observation
11
+
12
+
13
class CompositeTransform(Transform):
    """Chains several transforms: each stage's output feeds the next stage."""

    def __init__(self, transforms: list[Transform]):
        self.transforms = transforms

    def __call__(self, observation: Observation) -> Observation:
        result = observation
        for stage in self.transforms:
            result = stage(result)
        return result
23
+
24
+
25
class NullTransform(Transform):
    """Identity transform: hands back the observation untouched."""

    def __call__(self, observation: Observation) -> Observation:
        return observation
src/core/env_server/exceptions.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Custom exceptions for environment server operations."""
8
+
9
+ from typing import Optional
10
+
11
+
12
class OpenEnvError(Exception):
    """Root of the OpenEnv exception hierarchy; catch this for any library error."""
16
+
17
+
18
class ConcurrencyConfigurationError(OpenEnvError):
    """
    Raised when an environment is misconfigured for concurrent sessions.

    Emitted during server startup when max_concurrent_envs > 1 is requested
    for an environment not marked SUPPORTS_CONCURRENT_SESSIONS.
    """

    def __init__(
        self,
        environment_name: str,
        max_concurrent_envs: int,
        message: Optional[str] = None,
    ):
        self.environment_name = environment_name
        self.max_concurrent_envs = max_concurrent_envs

        default = (
            f"Environment '{environment_name}' is not marked as SUPPORTS_CONCURRENT_SESSIONS. "
            f"Cannot run with max_concurrent_envs={max_concurrent_envs}. "
            f"Either set max_concurrent_envs=1 or ensure the environment "
            f"properly isolates session state and set SUPPORTS_CONCURRENT_SESSIONS=True."
        )
        super().__init__(default if message is None else message)
44
+
45
+
46
class SessionCapacityError(OpenEnvError):
    """
    Raised when the server cannot accept new sessions due to capacity limits.

    Emitted when a new WebSocket connection arrives but the server already
    holds max_concurrent_envs active sessions.
    """

    def __init__(
        self,
        active_sessions: int,
        max_sessions: int,
        message: Optional[str] = None,
    ):
        self.active_sessions = active_sessions
        self.max_sessions = max_sessions

        default = (
            f"Server at capacity: {active_sessions}/{max_sessions} sessions active. "
            f"Cannot accept new connections."
        )
        super().__init__(default if message is None else message)
70
+
71
+
72
class SessionNotFoundError(OpenEnvError):
    """Raised when attempting to access a session that does not exist."""

    def __init__(self, session_id: str, message: Optional[str] = None):
        self.session_id = session_id
        default = f"Session '{session_id}' not found."
        super().__init__(default if message is None else message)
82
+
83
+
84
class SessionCreationError(OpenEnvError):
    """Raised when a session cannot be created."""

    def __init__(self, reason: str, message: Optional[str] = None):
        self.reason = reason
        default = f"Failed to create session: {reason}"
        super().__init__(default if message is None else message)
94
+
95
+
96
class EnvironmentFactoryError(OpenEnvError):
    """Raised when the environment factory fails to create an instance."""

    def __init__(self, factory_name: str, message: Optional[str] = None):
        self.factory_name = factory_name
        default = f"Environment factory '{factory_name}' failed to create instance."
        super().__init__(default if message is None else message)
src/core/env_server/gradio_theme.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Unified terminal-style theme for OpenEnv Gradio UI (light/dark)."""
8
+
9
+ from __future__ import annotations
10
+
11
+ import gradio as gr
12
+
13
# Monospace font stack for code/terminal widgets; first installed match wins.
_MONO_FONTS = (
    "JetBrains Mono",
    "Fira Code",
    "Cascadia Code",
    "Consolas",
    "ui-monospace",
    "monospace",
)

# Primary UI font stack for prose and labels.
_CORE_FONT = (
    "Lato",
    "Inter",
    "Arial",
    "Helvetica",
    "sans-serif",
)

# Square corners everywhere: every Gradio size tier maps to a 0px radius.
_ZERO_RADIUS = gr.themes.Size(
    xxs="0px",
    xs="0px",
    sm="0px",
    md="0px",
    lg="0px",
    xl="0px",
    xxl="0px",
)

# Green accent palette, ordered light (c50) to dark (c950).
_GREEN_HUE = gr.themes.Color(
    c50="#e6f4ea",
    c100="#ceead6",
    c200="#a8dab5",
    c300="#6fcc8b",
    c400="#3fb950",
    c500="#238636",
    c600="#1a7f37",
    c700="#116329",
    c800="#0a4620",
    c900="#033a16",
    c950="#04200d",
)

# Neutral gray palette used for secondary elements, borders, and text.
_NEUTRAL_HUE = gr.themes.Color(
    c50="#f6f8fa",
    c100="#eaeef2",
    c200="#d0d7de",
    c300="#afb8c1",
    c400="#8c959f",
    c500="#6e7781",
    c600="#57606a",
    c700="#424a53",
    c800="#32383f",
    c900="#24292f",
    c950="#1b1f24",
)

# Shared theme object for all OpenEnv Gradio apps. Light-mode values come
# first; the *_dark overrides below switch the same roles for dark mode.
OPENENV_GRADIO_THEME = gr.themes.Base(
    primary_hue=_GREEN_HUE,
    secondary_hue=_NEUTRAL_HUE,
    neutral_hue=_NEUTRAL_HUE,
    font=_CORE_FONT,
    font_mono=_MONO_FONTS,
    radius_size=_ZERO_RADIUS,
).set(
    body_background_fill="#ffffff",
    background_fill_primary="#ffffff",
    background_fill_secondary="#f6f8fa",
    block_background_fill="#ffffff",
    block_border_color="#ffffff",
    block_label_text_color="#57606a",
    block_title_text_color="#24292f",
    border_color_primary="#d0d7de",
    input_background_fill="#ffffff",
    input_border_color="#d0d7de",
    button_primary_background_fill="#1a7f37",
    button_primary_background_fill_hover="#116329",
    button_primary_text_color="#ffffff",
    button_secondary_background_fill="#f6f8fa",
    button_secondary_background_fill_hover="#eaeef2",
    button_secondary_text_color="#24292f",
    button_secondary_border_color="#d0d7de",
    body_background_fill_dark="#0d1117",
    background_fill_primary_dark="#0d1117",
    background_fill_secondary_dark="#0d1117",
    block_background_fill_dark="#0d1117",
    block_border_color_dark="#0d1117",
    block_label_text_color_dark="#8b949e",
    block_title_text_color_dark="#c9d1d9",
    border_color_primary_dark="#30363d",
    input_background_fill_dark="#0d1117",
    input_border_color_dark="#30363d",
    button_primary_background_fill_dark="#30363d",
    button_primary_background_fill_hover_dark="#484f58",
    button_primary_text_color_dark="#c9d1d9",
    button_secondary_background_fill_dark="#21262d",
    button_secondary_background_fill_hover_dark="#30363d",
    button_secondary_text_color_dark="#c9d1d9",
    button_secondary_border_color_dark="#30363d",
)

# Extra CSS injected alongside the theme: enforces square corners, pads the
# two layout columns, strips chrome from markdown blocks, and softens the
# column divider colors in dark mode.
OPENENV_GRADIO_CSS = """
* { border-radius: 0 !important; }
.col-left { padding: 16px !important; }
.col-right { padding: 16px !important; }
.prose, .markdown-text, .md,
.prose > *, .markdown-text > * {
background: transparent !important;
border: none !important;
box-shadow: none !important;
}
.dark .col-left {
border-left-color: rgba(139, 148, 158, 0.4) !important;
}
.dark .col-right {
border-left-color: rgba(201, 209, 217, 0.3) !important;
}
"""
src/core/env_server/gradio_ui.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Gradio-based web UI for OpenEnv environments.
9
+
10
+ Replaces the legacy HTML/JavaScript interface when ENABLE_WEB_INTERFACE is set.
11
+ Mount at /web via gr.mount_gradio_app() from create_web_interface_app().
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import re
18
+ from typing import Any, Dict, List, Optional
19
+
20
+ import gradio as gr
21
+
22
+ from .types import EnvironmentMetadata
23
+
24
+
25
+ def _escape_md(text: str) -> str:
26
+ """Escape Markdown special characters in user-controlled content."""
27
+ return re.sub(r"([\\`*_\{\}\[\]()#+\-.!|~>])", r"\\\1", str(text))
28
+
29
+
30
+ def _format_observation(data: Dict[str, Any]) -> str:
31
+ """Format reset/step response for Markdown display."""
32
+ lines: List[str] = []
33
+ obs = data.get("observation", {})
34
+ if isinstance(obs, dict):
35
+ if obs.get("prompt"):
36
+ lines.append(f"**Prompt:**\n\n{_escape_md(obs['prompt'])}\n")
37
+ messages = obs.get("messages", [])
38
+ if messages:
39
+ lines.append("**Messages:**\n")
40
+ for msg in messages:
41
+ sender = _escape_md(str(msg.get("sender_id", "?")))
42
+ content = _escape_md(str(msg.get("content", "")))
43
+ cat = _escape_md(str(msg.get("category", "")))
44
+ lines.append(f"- `[{cat}]` Player {sender}: {content}")
45
+ lines.append("")
46
+ reward = data.get("reward")
47
+ done = data.get("done")
48
+ if reward is not None:
49
+ lines.append(f"**Reward:** `{reward}`")
50
+ if done is not None:
51
+ lines.append(f"**Done:** `{done}`")
52
+ return "\n".join(lines) if lines else "*No observation data*"
53
+
54
+
55
+ def _readme_section(metadata: Optional[EnvironmentMetadata]) -> str:
56
+ """README content for the left panel."""
57
+ if not metadata or not metadata.readme_content:
58
+ return "*No README available.*"
59
+ return metadata.readme_content
60
+
61
+
62
+ def get_gradio_display_title(
63
+ metadata: Optional[EnvironmentMetadata],
64
+ fallback: str = "OpenEnv Environment",
65
+ ) -> str:
66
+ """Return the title used for the Gradio app (browser tab and Blocks)."""
67
+ name = metadata.name if metadata else fallback
68
+ return f"OpenEnv Agentic Environment: {name}"
69
+
70
+
71
+ def build_gradio_app(
72
+ web_manager: Any,
73
+ action_fields: List[Dict[str, Any]],
74
+ metadata: Optional[EnvironmentMetadata],
75
+ is_chat_env: bool,
76
+ title: str = "OpenEnv Environment",
77
+ quick_start_md: Optional[str] = None,
78
+ ) -> gr.Blocks:
79
+ """
80
+ Build a Gradio Blocks app for the OpenEnv web interface.
81
+
82
+ Args:
83
+ web_manager: WebInterfaceManager (reset/step_environment, get_state).
84
+ action_fields: Field dicts from _extract_action_fields(action_cls).
85
+ metadata: Environment metadata for README/name.
86
+ is_chat_env: If True, single message textbox; else form from action_fields.
87
+ title: App title (overridden by metadata.name when present; see get_gradio_display_title).
88
+ quick_start_md: Optional Quick Start markdown (class names already replaced).
89
+
90
+ Returns:
91
+ gr.Blocks to mount with gr.mount_gradio_app(app, blocks, path="/web").
92
+ """
93
+ readme_content = _readme_section(metadata)
94
+ display_title = get_gradio_display_title(metadata, fallback=title)
95
+
96
+ async def reset_env():
97
+ try:
98
+ data = await web_manager.reset_environment()
99
+ obs_md = _format_observation(data)
100
+ return (
101
+ obs_md,
102
+ json.dumps(data, indent=2),
103
+ "Environment reset successfully.",
104
+ )
105
+ except Exception as e:
106
+ return ("", "", f"Error: {e}")
107
+
108
+ def _step_with_action(action_data: Dict[str, Any]):
109
+ async def _run():
110
+ try:
111
+ data = await web_manager.step_environment(action_data)
112
+ obs_md = _format_observation(data)
113
+ return (
114
+ obs_md,
115
+ json.dumps(data, indent=2),
116
+ "Step complete.",
117
+ )
118
+ except Exception as e:
119
+ return ("", "", f"Error: {e}")
120
+
121
+ return _run
122
+
123
+ async def step_chat(message: str):
124
+ if not (message or str(message).strip()):
125
+ return ("", "", "Please enter an action message.")
126
+ action = {"message": str(message).strip()}
127
+ return await _step_with_action(action)()
128
+
129
+ def get_state_sync():
130
+ try:
131
+ data = web_manager.get_state()
132
+ return json.dumps(data, indent=2)
133
+ except Exception as e:
134
+ return f"Error: {e}"
135
+
136
+ with gr.Blocks(title=display_title) as demo:
137
+ with gr.Row():
138
+ with gr.Column(scale=1, elem_classes="col-left"):
139
+ if quick_start_md:
140
+ with gr.Accordion("Quick Start", open=True):
141
+ gr.Markdown(quick_start_md)
142
+ with gr.Accordion("README", open=False):
143
+ gr.Markdown(readme_content)
144
+
145
+ with gr.Column(scale=2, elem_classes="col-right"):
146
+ obs_display = gr.Markdown(
147
+ value=("# Playground\n\nClick **Reset** to start a new episode."),
148
+ )
149
+ with gr.Group():
150
+ if is_chat_env:
151
+ action_input = gr.Textbox(
152
+ label="Action message",
153
+ placeholder="e.g. Enter your message...",
154
+ )
155
+ step_inputs = [action_input]
156
+ step_fn = step_chat
157
+ else:
158
+ step_inputs = []
159
+ for field in action_fields:
160
+ name = field["name"]
161
+ field_type = field.get("type", "text")
162
+ label = name.replace("_", " ").title()
163
+ placeholder = field.get("placeholder", "")
164
+ if field_type == "checkbox":
165
+ inp = gr.Checkbox(label=label)
166
+ elif field_type == "number":
167
+ inp = gr.Number(label=label)
168
+ elif field_type == "select":
169
+ choices = field.get("choices") or []
170
+ inp = gr.Dropdown(
171
+ choices=choices,
172
+ label=label,
173
+ allow_custom_value=False,
174
+ )
175
+ elif field_type in ("textarea", "tensor"):
176
+ inp = gr.Textbox(
177
+ label=label,
178
+ placeholder=placeholder,
179
+ lines=3,
180
+ )
181
+ else:
182
+ inp = gr.Textbox(
183
+ label=label,
184
+ placeholder=placeholder,
185
+ )
186
+ step_inputs.append(inp)
187
+
188
+ async def step_form(*values):
189
+ if not action_fields:
190
+ return await _step_with_action({})()
191
+ action_data = {}
192
+ for i, field in enumerate(action_fields):
193
+ if i >= len(values):
194
+ break
195
+ name = field["name"]
196
+ val = values[i]
197
+ if field.get("type") == "checkbox":
198
+ action_data[name] = bool(val)
199
+ elif val is not None and val != "":
200
+ action_data[name] = val
201
+ return await _step_with_action(action_data)()
202
+
203
+ step_fn = step_form
204
+
205
+ with gr.Row():
206
+ step_btn = gr.Button("Step", variant="primary")
207
+ reset_btn = gr.Button("Reset", variant="secondary")
208
+ state_btn = gr.Button("Get state", variant="secondary")
209
+ with gr.Row():
210
+ status = gr.Textbox(
211
+ label="Status",
212
+ interactive=False,
213
+ )
214
+ raw_json = gr.Code(
215
+ label="Raw JSON response",
216
+ language="json",
217
+ interactive=False,
218
+ )
219
+
220
+ reset_btn.click(
221
+ fn=reset_env,
222
+ outputs=[obs_display, raw_json, status],
223
+ )
224
+ step_btn.click(
225
+ fn=step_fn,
226
+ inputs=step_inputs,
227
+ outputs=[obs_display, raw_json, status],
228
+ )
229
+ if is_chat_env:
230
+ action_input.submit(
231
+ fn=step_fn,
232
+ inputs=step_inputs,
233
+ outputs=[obs_display, raw_json, status],
234
+ )
235
+ state_btn.click(
236
+ fn=get_state_sync,
237
+ outputs=[raw_json],
238
+ )
239
+
240
+ return demo
src/core/env_server/http_server.py ADDED
@@ -0,0 +1,1391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ HTTP server wrapper for Environment instances.
9
+
10
+ This module provides utilities to wrap any Environment subclass and expose it
11
+ over HTTP and WebSocket endpoints that EnvClient can consume.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import asyncio
17
+ import inspect
18
+ import json
19
+ import os
20
+ import time
21
+ import uuid
22
+ from concurrent.futures import ThreadPoolExecutor
23
+ from typing import Any, Callable, Dict, Optional, Type
24
+
25
+ from fastapi import (
26
+ Body,
27
+ FastAPI,
28
+ HTTPException,
29
+ Request,
30
+ status,
31
+ WebSocket,
32
+ WebSocketDisconnect,
33
+ )
34
+ from pydantic import ValidationError
35
+
36
+ from .interfaces import Environment
37
+ from .mcp_environment import get_server_tools
38
+ from .mcp_types import (
39
+ JsonRpcErrorCode,
40
+ JsonRpcRequest,
41
+ JsonRpcResponse,
42
+ McpMethod,
43
+ WSMCPMessage,
44
+ WSMCPResponse,
45
+ )
46
+ from .route_config import GetEndpointConfig, register_get_endpoints
47
+ from .serialization import deserialize_action, serialize_observation
48
+ from .types import (
49
+ Action,
50
+ ConcurrencyConfig,
51
+ EnvironmentMetadata,
52
+ HealthResponse,
53
+ HealthStatus,
54
+ Observation,
55
+ ResetRequest,
56
+ ResetResponse,
57
+ SchemaResponse,
58
+ ServerCapacityStatus,
59
+ ServerMode,
60
+ SessionInfo,
61
+ State,
62
+ StepRequest,
63
+ StepResponse,
64
+ WSCloseMessage,
65
+ WSErrorCode,
66
+ WSErrorResponse,
67
+ WSObservationResponse,
68
+ WSResetMessage,
69
+ WSStateMessage,
70
+ WSStateResponse,
71
+ WSStepMessage,
72
+ )
73
+
74
+
75
def _make_json_serializable(obj: Any) -> Any:
    """
    Convert an object to a JSON-serializable form.

    Handles Pydantic models, dataclasses, sets, and other common types.

    Args:
        obj: The object to convert

    Returns:
        A JSON-serializable representation of the object
    """
    if obj is None:
        return None
    if isinstance(obj, (str, int, float, bool)):
        return obj
    # Sets/frozensets previously fell through to str(obj), yielding text like
    # "{'a'}" instead of a JSON array; serialize them as lists alongside
    # list/tuple. (Iteration order of a set is arbitrary.)
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [_make_json_serializable(item) for item in obj]
    if isinstance(obj, dict):
        return {k: _make_json_serializable(v) for k, v in obj.items()}
    if hasattr(obj, "model_dump"):
        # Pydantic model
        return obj.model_dump()
    if hasattr(obj, "__dict__"):
        # Generic object (includes plain dataclasses): serialize attributes.
        return {k: _make_json_serializable(v) for k, v in obj.__dict__.items()}
    # Fallback to string representation
    return str(obj)
+ return str(obj)
103
+
104
+
105
+ from .exceptions import (
106
+ ConcurrencyConfigurationError,
107
+ EnvironmentFactoryError,
108
+ SessionCapacityError,
109
+ )
110
+
111
+
112
+ class HTTPEnvServer:
113
+ """
114
+ HTTP server wrapper for Environment instances.
115
+
116
+ This class wraps an Environment and exposes its reset(), step(), and state
117
+ methods as HTTP and WebSocket endpoints compatible with EnvClient.
118
+
119
+ The server expects:
120
+ - Action deserialization: Converts JSON dict to Action subclass
121
+ - Observation serialization: Converts Observation subclass to JSON dict
122
+
123
+ Example:
124
+ >>> from core.env_server import HTTPEnvServer
125
+ >>> from envs.coding_env.server import CodeExecutionEnvironment
126
+ >>> from envs.coding_env.models import CodeAction, CodeObservation
127
+ >>>
128
+ >>> # Pass environment class (factory pattern)
129
+ >>> server = HTTPEnvServer(
130
+ ... env=CodeExecutionEnvironment,
131
+ ... action_cls=CodeAction,
132
+ ... observation_cls=CodeObservation,
133
+ ... max_concurrent_envs=4,
134
+ ... )
135
+ >>>
136
+ >>> # Register routes with FastAPI
137
+ >>> from fastapi import FastAPI
138
+ >>> app = FastAPI()
139
+ >>> server.register_routes(app)
140
+ """
141
+
142
    def __init__(
        self,
        env: Callable[[], Environment],
        action_cls: Type[Action],
        observation_cls: Type[Observation],
        max_concurrent_envs: Optional[int] = None,
        concurrency_config: Optional[ConcurrencyConfig] = None,
    ):
        """
        Initialize HTTP server wrapper.

        Args:
            env: Environment factory (callable) that creates new instances.
                Will be called to create a new environment for each WebSocket session.
            action_cls: The Action subclass this environment expects
            observation_cls: The Observation subclass this environment returns
            max_concurrent_envs: Maximum number of concurrent WebSocket sessions.
                Mutually exclusive with concurrency_config.
            concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings.
                Mutually exclusive with max_concurrent_envs.

        Raises:
            TypeError: If env is not callable (e.g. an instance was passed).
            ValueError: If both max_concurrent_envs and concurrency_config are provided.
            ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an
                environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS.
        """
        # Validate that env is callable (catches the common mistake of
        # passing an already-constructed environment instance).
        if not callable(env):
            raise TypeError(
                f"env must be a callable (class or factory function), got {type(env)}. "
                f"Pass the environment class (e.g., MyEnvironment) not an instance (e.g., MyEnvironment())."
            )

        self._env_factory: Callable[[], Environment] = env

        # Handle concurrency configuration: the two knobs are mutually
        # exclusive; a plain max_concurrent_envs is wrapped in a default
        # ConcurrencyConfig, and with neither we fall back to a single session.
        if max_concurrent_envs is not None and concurrency_config is not None:
            raise ValueError(
                "Cannot specify both 'max_concurrent_envs' and 'concurrency_config'. "
                "Please use only one method to configure concurrency."
            )

        if concurrency_config is not None:
            self._concurrency_config = concurrency_config
        elif max_concurrent_envs is not None:
            self._concurrency_config = ConcurrencyConfig(
                max_concurrent_envs=max_concurrent_envs,
                session_timeout=None,
            )
        else:
            # Default configuration
            self._concurrency_config = ConcurrencyConfig(
                max_concurrent_envs=1,
                session_timeout=None,
            )

        self._max_concurrent_envs = self._concurrency_config.max_concurrent_envs

        # Validate concurrency configuration (raises if the environment does
        # not declare support for the requested number of sessions).
        self._validate_concurrency_safety()

        self.action_cls = action_cls
        self.observation_cls = observation_cls

        # Session management for WebSocket connections: per-session
        # environment instances, per-session executors, metadata, and a lock
        # guarding all three maps.
        self._sessions: Dict[str, Environment] = {}
        self._session_executors: Dict[str, ThreadPoolExecutor] = {}
        self._session_info: Dict[str, SessionInfo] = {}
        self._session_lock = asyncio.Lock()

        # Create thread pool for running sync code in async context
        # This is needed for environments using sync libraries (e.g., Playwright)
        self._executor = ThreadPoolExecutor(max_workers=32)
215
+
216
+ def _validate_concurrency_safety(self) -> None:
217
+ """
218
+ Validate that the environment supports the configured concurrency level.
219
+
220
+ Raises:
221
+ ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an
222
+ environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS.
223
+ """
224
+ if self._max_concurrent_envs <= 1:
225
+ return
226
+
227
+ if inspect.isclass(self._env_factory):
228
+ env_cls = self._env_factory
229
+ else:
230
+ _temp_env = self._env_factory()
231
+ env_cls = type(_temp_env)
232
+ _temp_env.close()
233
+ del _temp_env
234
+
235
+ if not getattr(env_cls, "SUPPORTS_CONCURRENT_SESSIONS", False):
236
+ raise ConcurrencyConfigurationError(
237
+ environment_name=env_cls.__name__,
238
+ max_concurrent_envs=self._max_concurrent_envs,
239
+ )
240
+
241
+ def get_capacity_status(self) -> ServerCapacityStatus:
242
+ """
243
+ Get the current capacity status of the server.
244
+
245
+ Returns:
246
+ ServerCapacityStatus with current session counts and availability.
247
+ """
248
+ return ServerCapacityStatus.from_counts(
249
+ active=len(self._sessions),
250
+ max_sessions=self._max_concurrent_envs,
251
+ )
252
+
253
+ async def _run_sync_in_thread_pool(
254
+ self, func: Callable[..., Observation], *args, **kwargs
255
+ ) -> Observation:
256
+ """Run a synchronous function in the thread pool executor."""
257
+ loop = asyncio.get_event_loop()
258
+ return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs))
259
+
260
+ def _get_valid_kwargs(
261
+ self,
262
+ sig: inspect.Signature,
263
+ kwargs: Dict[str, Any],
264
+ skip_params: Optional[set[str]] = None,
265
+ ) -> Dict[str, Any]:
266
+ """Filter kwargs to only include parameters accepted by the function signature."""
267
+ if skip_params is None:
268
+ skip_params = set()
269
+
270
+ valid_kwargs = {}
271
+
272
+ has_kwargs = any(
273
+ p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
274
+ )
275
+
276
+ for k, v in kwargs.items():
277
+ if k in sig.parameters or has_kwargs:
278
+ if k not in skip_params:
279
+ valid_kwargs[k] = v
280
+
281
+ return valid_kwargs
282
+
283
+ async def _create_session(self) -> tuple[str, Environment]:
284
+ """
285
+ Create a new WebSocket session with its own environment instance.
286
+
287
+ Returns:
288
+ Tuple of (session_id, environment)
289
+
290
+ Raises:
291
+ SessionCapacityError: If max concurrent sessions reached
292
+ EnvironmentFactoryError: If the factory fails to create an environment
293
+ """
294
+ async with self._session_lock:
295
+ if len(self._sessions) >= self._max_concurrent_envs:
296
+ raise SessionCapacityError(
297
+ active_sessions=len(self._sessions),
298
+ max_sessions=self._max_concurrent_envs,
299
+ )
300
+
301
+ session_id = str(uuid.uuid4())
302
+ current_time = time.time()
303
+
304
+ # Create executor and reserve slot so capacity is not exceeded while
305
+ # we create the env outside the lock (avoids blocking other sessions)
306
+ executor = ThreadPoolExecutor(max_workers=1)
307
+ self._session_executors[session_id] = executor
308
+ self._sessions[session_id] = None # placeholder until env is ready
309
+
310
+ try:
311
+ # Create environment in the executor thread (outside lock)
312
+ loop = asyncio.get_event_loop()
313
+ env = await loop.run_in_executor(executor, self._env_factory)
314
+ except Exception as e:
315
+ async with self._session_lock:
316
+ executor.shutdown(wait=False)
317
+ self._session_executors.pop(session_id, None)
318
+ self._sessions.pop(session_id, None)
319
+ factory_name = getattr(
320
+ self._env_factory, "__name__", str(self._env_factory)
321
+ )
322
+ raise EnvironmentFactoryError(factory_name) from e
323
+
324
+ async with self._session_lock:
325
+ self._sessions[session_id] = env
326
+ self._session_info[session_id] = SessionInfo(
327
+ session_id=session_id,
328
+ created_at=current_time,
329
+ last_activity_at=current_time,
330
+ step_count=0,
331
+ environment_type=type(env).__name__,
332
+ )
333
+
334
+ return session_id, env
335
+
336
+ async def _destroy_session(self, session_id: str) -> None:
337
+ """
338
+ Destroy a WebSocket session and cleanup resources.
339
+
340
+ Args:
341
+ session_id: The session ID to destroy
342
+ """
343
+ async with self._session_lock:
344
+ env = self._sessions.pop(session_id, None)
345
+ executor = self._session_executors.pop(session_id, None)
346
+ self._session_info.pop(session_id, None)
347
+
348
+ # Run close() in the same executor where the env was created
349
+ # This is required for thread-sensitive libraries like Playwright/greenlet
350
+ if env is not None:
351
+ if executor is not None:
352
+ try:
353
+ loop = asyncio.get_event_loop()
354
+ await loop.run_in_executor(executor, env.close)
355
+ except Exception:
356
+ # If executor close fails, try direct close as fallback
357
+ try:
358
+ env.close()
359
+ except Exception:
360
+ pass # Best effort cleanup
361
+ else:
362
+ try:
363
+ env.close()
364
+ except Exception:
365
+ pass # Best effort cleanup
366
+
367
+ # Shutdown executor after close is done
368
+ if executor is not None:
369
+ executor.shutdown(wait=False)
370
+
371
+ def _update_session_activity(
372
+ self, session_id: str, increment_step: bool = False
373
+ ) -> None:
374
+ """
375
+ Update session activity timestamp and optionally increment step count.
376
+
377
+ Args:
378
+ session_id: The session ID to update
379
+ increment_step: If True, increment the step count
380
+ """
381
+ if session_id in self._session_info:
382
+ self._session_info[session_id].last_activity_at = time.time()
383
+ if increment_step:
384
+ self._session_info[session_id].step_count += 1
385
+
386
+ def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
387
+ """
388
+ Get information about a specific session.
389
+
390
+ Args:
391
+ session_id: The session ID to query
392
+
393
+ Returns:
394
+ SessionInfo if the session exists, None otherwise
395
+ """
396
+ return self._session_info.get(session_id)
397
+
398
+ async def _run_in_session_executor(
399
+ self, session_id: str, func: Callable[..., Observation], *args, **kwargs
400
+ ) -> Observation:
401
+ """Run a synchronous function in the session's thread pool executor."""
402
+ executor = self._session_executors.get(session_id, self._executor)
403
+ loop = asyncio.get_event_loop()
404
+ return await loop.run_in_executor(executor, lambda: func(*args, **kwargs))
405
+
406
+ @property
407
+ def active_sessions(self) -> int:
408
+ """Return the number of active WebSocket sessions."""
409
+ return len(self._sessions)
410
+
411
+ @property
412
+ def max_concurrent_envs(self) -> int:
413
+ """Return the maximum number of concurrent environments."""
414
+ return self._max_concurrent_envs
415
+
416
+ @property
417
+ def is_concurrency_safe(self) -> bool:
418
+ """Return whether the environment is marked as concurrency safe."""
419
+ import inspect
420
+
421
+ if inspect.isclass(self._env_factory):
422
+ return getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False)
423
+ else:
424
+ _temp_env = self._env_factory()
425
+ result = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False)
426
+ _temp_env.close()
427
+ del _temp_env
428
+ return result
429
+
430
+ @property
431
+ def concurrency_config(self) -> ConcurrencyConfig:
432
+ """Return the concurrency configuration."""
433
+ return self._concurrency_config
434
+
435
    def register_routes(
        self, app: FastAPI, mode: ServerMode | str = ServerMode.SIMULATION
    ) -> None:
        """
        Register HTTP routes on a FastAPI application.

        Args:
            app: FastAPI application instance
            mode: Server mode - either SIMULATION or PRODUCTION (or string equivalents).
                In production mode, simulation control endpoints (/reset, /step, /state)
                are NOT registered. Only safe endpoints (/health, /schema, /metadata, /ws)
                are available. Defaults to SIMULATION for backwards compatibility.

        Raises:
            ValueError: If mode is not a valid ServerMode or string equivalent.
        """
        # Convert string to ServerMode enum for backwards compatibility
        if isinstance(mode, str):
            try:
                mode = ServerMode(mode.lower())
            except ValueError:
                valid_modes = [m.value for m in ServerMode]
                raise ValueError(
                    f"Invalid mode: '{mode}'. Must be one of: {valid_modes}"
                )

        # Helper function to handle reset endpoint.
        # Stateless HTTP mode: a fresh environment per request, closed in `finally`.
        async def reset_handler(
            request: ResetRequest = Body(default_factory=ResetRequest),
        ) -> ResetResponse:
            """Reset endpoint - returns initial observation."""
            _env = self._env_factory()

            try:
                kwargs = request.model_dump(exclude_unset=True)

                # Detect whether the subclass overrides reset_async by comparing
                # the underlying function object against the base class's.
                is_async = _env.reset_async.__func__ is not Environment.reset_async

                if is_async:
                    sig = inspect.signature(_env.reset_async)
                else:
                    sig = inspect.signature(_env.reset)
                valid_kwargs = self._get_valid_kwargs(sig, kwargs)

                if is_async:
                    observation = await _env.reset_async(**valid_kwargs)
                else:
                    # Sync reset runs off the event loop in the shared pool.
                    observation = await self._run_sync_in_thread_pool(
                        _env.reset, **valid_kwargs
                    )
                return ResetResponse(**serialize_observation(observation))
            finally:
                _env.close()

        # Helper function to handle step endpoint
        async def step_handler(request: StepRequest) -> StepResponse:
            """Step endpoint - executes action and returns observation."""
            action_data = request.action

            try:
                action = deserialize_action(action_data, self.action_cls)
            except ValidationError as e:
                # Surface pydantic validation details as an HTTP 422 response.
                raise HTTPException(
                    status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors()
                )

            _env = self._env_factory()

            try:
                kwargs = request.model_dump(exclude_unset=True, exclude={"action"})

                # Same override-detection trick as reset_handler.
                is_async = _env.step_async.__func__ is not Environment.step_async

                if is_async:
                    sig = inspect.signature(_env.step_async)
                else:
                    sig = inspect.signature(_env.step)
                valid_kwargs = self._get_valid_kwargs(
                    sig, kwargs, skip_params={"action"}
                )

                if is_async:
                    observation = await _env.step_async(action, **valid_kwargs)
                else:
                    observation = await self._run_sync_in_thread_pool(
                        _env.step, action, **valid_kwargs
                    )

                return StepResponse(**serialize_observation(observation))
            finally:
                _env.close()

        # Helper function to handle MCP endpoint
        async def mcp_handler(
            request: JsonRpcRequest, session_env: Optional[Environment] = None
        ) -> JsonRpcResponse:
            """
            Handle MCP JSON-RPC requests.

            Supports tools/list and tools/call methods in JSON-RPC 2.0 format.
            """
            method = request.method
            request_id = request.id

            # Use provided session environment or create temporary one.
            # A session env outlives this call, so it must not be closed here.
            if session_env is not None:
                _env = session_env
                should_close = False
            else:
                _env = self._env_factory()
                should_close = True
            try:
                if method == McpMethod.TOOLS_LIST:
                    # Check if environment is MCP-enabled
                    if not hasattr(_env, "mcp_client"):
                        return JsonRpcResponse.error_response(
                            JsonRpcErrorCode.INTERNAL_ERROR,
                            "Environment does not support MCP",
                            request_id=request_id,
                        )

                    # Use async context manager for MCP client
                    async with _env.mcp_client:
                        tools = await _env.mcp_client.list_tools()

                    return JsonRpcResponse.success(
                        result={
                            "tools": [
                                t.model_dump() if hasattr(t, "model_dump") else dict(t)
                                for t in tools
                            ]
                        },
                        request_id=request_id,
                    )

                elif method == McpMethod.TOOLS_CALL:
                    params = request.params
                    tool_name = params.get("name")
                    arguments = params.get("arguments", {})

                    if not hasattr(_env, "mcp_client"):
                        return JsonRpcResponse.error_response(
                            JsonRpcErrorCode.INTERNAL_ERROR,
                            "Environment does not support MCP",
                            request_id=request_id,
                        )

                    if not tool_name:
                        return JsonRpcResponse.error_response(
                            JsonRpcErrorCode.INVALID_REQUEST,
                            "Missing 'name' in params",
                            request_id=request_id,
                        )

                    # Use async context manager for MCP client
                    async with _env.mcp_client:
                        result = await _env.mcp_client.call_tool(
                            name=tool_name, arguments=arguments
                        )

                    # Ensure result is JSON serializable
                    serializable_result = _make_json_serializable(result)

                    return JsonRpcResponse.success(
                        result=serializable_result,
                        request_id=request_id,
                    )

                else:
                    return JsonRpcResponse.error_response(
                        JsonRpcErrorCode.METHOD_NOT_FOUND,
                        f"Method not found: {method}",
                        request_id=request_id,
                    )

            except Exception as e:
                return JsonRpcResponse.error_response(
                    JsonRpcErrorCode.INTERNAL_ERROR,
                    str(e),
                    request_id=request_id,
                )
            finally:
                if should_close:
                    _env.close()

        # Register MCP WebSocket endpoint (available in both production and simulation modes)
        @app.websocket("/mcp")
        async def mcp_websocket_endpoint(websocket: WebSocket):
            """
            WebSocket endpoint for MCP JSON-RPC requests.

            Each WebSocket connection gets its own environment instance for MCP operations.

            Message Protocol:
            - Client sends: JSON-RPC 2.0 request (tools/list, tools/call)
            - Server responds: JSON-RPC 2.0 response (result or error)
            """
            await websocket.accept()

            session_id = None
            session_env = None

            try:
                # Create session with dedicated environment
                session_id, session_env = await self._create_session()

                while True:
                    # Receive message from client
                    raw_message = await websocket.receive_text()

                    try:
                        jsonrpc_dict = json.loads(raw_message)
                        jsonrpc_request = JsonRpcRequest(**jsonrpc_dict)
                    except json.JSONDecodeError as e:
                        # Malformed JSON: report and keep the connection open.
                        error_resp = JsonRpcResponse.error_response(
                            JsonRpcErrorCode.PARSE_ERROR,
                            f"Parse error: {e}",
                        )
                        await websocket.send_text(error_resp.model_dump_json())
                        continue
                    except ValidationError as e:
                        # Valid JSON but not a valid JSON-RPC request envelope.
                        error_resp = JsonRpcResponse.error_response(
                            JsonRpcErrorCode.INVALID_REQUEST,
                            f"Invalid request: {e}",
                        )
                        await websocket.send_text(error_resp.model_dump_json())
                        continue

                    try:
                        # Call mcp_handler with session environment
                        response = await mcp_handler(
                            jsonrpc_request, session_env=session_env
                        )
                        await websocket.send_text(response.model_dump_json())
                    except Exception as e:
                        error_resp = JsonRpcResponse.error_response(
                            JsonRpcErrorCode.INTERNAL_ERROR,
                            str(e),
                            request_id=jsonrpc_request.id,
                        )
                        await websocket.send_text(error_resp.model_dump_json())

            except WebSocketDisconnect:
                pass
            except SessionCapacityError as e:
                error_resp = JsonRpcResponse.error_response(
                    JsonRpcErrorCode.SERVER_ERROR,
                    str(e),
                    data={
                        "active_sessions": e.active_sessions,
                        "max_sessions": e.max_sessions,
                    },
                )
                await websocket.send_text(error_resp.model_dump_json())
            except EnvironmentFactoryError as e:
                error_resp = JsonRpcResponse.error_response(
                    JsonRpcErrorCode.SERVER_ERROR,
                    str(e),
                    data={"factory_name": e.factory_name},
                )
                await websocket.send_text(error_resp.model_dump_json())
            except Exception as e:
                error_resp = JsonRpcResponse.error_response(
                    JsonRpcErrorCode.SERVER_ERROR,
                    str(e),
                )
                await websocket.send_text(error_resp.model_dump_json())
            finally:
                # Always tear down the per-connection session and environment.
                if session_id:
                    await self._destroy_session(session_id)
                try:
                    await websocket.close()
                except RuntimeError:
                    pass  # socket may already be closed by the client

        # Register simulation control routes only in simulation mode
        if mode == ServerMode.SIMULATION:

            @app.post(
                "/reset",
                response_model=ResetResponse,
                tags=["Environment Control"],
                summary="Reset the environment",
                description="""
        Reset the environment to its initial state and return the first observation.

        You can optionally provide a seed for reproducibility and an episode_id for tracking.
        """,
                responses={
                    200: {
                        "description": "Environment reset successfully",
                        "content": {
                            "application/json": {
                                "example": {
                                    "observation": {"status": "ready", "data": {}},
                                    "reward": None,
                                    "done": False,
                                }
                            }
                        },
                    }
                },
            )
            async def reset(
                request: ResetRequest = Body(default_factory=ResetRequest),
            ) -> ResetResponse:
                return await reset_handler(request)

            @app.post(
                "/step",
                response_model=StepResponse,
                tags=["Environment Control"],
                summary="Execute an action in the environment",
                description="""
        Execute an action in the environment and receive the resulting observation.

        The action must conform to the environment's action schema, which can be
        retrieved from the `/schema` endpoint. If the action is invalid,
        the endpoint will return HTTP 422 with detailed validation errors.

        The response includes:
        - **observation**: The environment's response to the action
        - **reward**: Optional reward signal (float or None)
        - **done**: Boolean indicating if the episode has terminated
        """,
                responses={
                    200: {
                        "description": "Action executed successfully",
                        "content": {
                            "application/json": {
                                "example": {
                                    "observation": {"status": "success", "data": {}},
                                    "reward": 1.0,
                                    "done": False,
                                }
                            }
                        },
                    },
                    422: {
                        "description": "Validation error - invalid action format or values",
                        "content": {
                            "application/json": {
                                "example": {
                                    "detail": [
                                        {
                                            "type": "string_too_short",
                                            "loc": ["body", "action", "message"],
                                            "msg": "String should have at least 1 character",
                                            "input": "",
                                        }
                                    ]
                                }
                            }
                        },
                    },
                    500: {
                        "description": "Internal server error during action execution"
                    },
                },
            )
            async def step(request: StepRequest) -> StepResponse:
                return await step_handler(request)

        # State/metadata handlers create a throwaway environment per request.
        def get_state_handler() -> State:
            _env = self._env_factory()
            try:
                return _env.state
            finally:
                _env.close()

        def get_metadata_handler() -> EnvironmentMetadata:
            _env = self._env_factory()
            try:
                return _env.get_metadata()
            finally:
                _env.close()

        # Build list of GET endpoints based on mode
        get_endpoints = [
            GetEndpointConfig(
                path="/metadata",
                handler=get_metadata_handler,
                response_model=EnvironmentMetadata,
                tag="Environment Info",
                summary="Get environment metadata",
                description="""
        Get metadata about this environment.

        Returns information about the environment including name, description,
        version, author, and documentation links.
        """,
            ),
            GetEndpointConfig(
                path="/health",
                handler=lambda: HealthResponse(status=HealthStatus.HEALTHY),
                response_model=HealthResponse,
                tag="Health",
                summary="Health check",
                description="Check if the environment server is running and healthy.",
            ),
        ]

        # Only register /state endpoint in simulation mode
        if mode == ServerMode.SIMULATION:
            get_endpoints.insert(
                0,
                GetEndpointConfig(
                    path="/state",
                    handler=get_state_handler,
                    response_model=State,
                    tag="State Management",
                    summary="Get current environment state",
                    description="""
        Retrieve the current internal state of the environment.

        The structure of the state object is defined by the environment's State model.
        """,
                ),
            )

        register_get_endpoints(app, get_endpoints)

        # Register combined schema endpoint
        @app.get(
            "/schema",
            response_model=SchemaResponse,
            tags=["Schema"],
            summary="Get all JSON schemas",
            description="""
        Get JSON schemas for actions, observations, and state in a single response.

        Returns a combined schema object containing:
        - **action**: JSON schema for actions accepted by this environment
        - **observation**: JSON schema for observations returned by this environment
        - **state**: JSON schema for environment state objects

        This is more efficient than calling individual schema endpoints and provides
        all schema information needed to interact with the environment.
        """,
            responses={
                200: {
                    "description": "Combined schemas retrieved successfully",
                    "content": {
                        "application/json": {
                            "example": {
                                "action": {
                                    "type": "object",
                                    "properties": {"message": {"type": "string"}},
                                },
                                "observation": {
                                    "type": "object",
                                    "properties": {"response": {"type": "string"}},
                                },
                                "state": {
                                    "type": "object",
                                    "properties": {"step_count": {"type": "integer"}},
                                },
                            }
                        }
                    },
                }
            },
        )
        async def get_schemas() -> SchemaResponse:
            """Return all schemas in one response."""
            return SchemaResponse(
                action=self.action_cls.model_json_schema(),
                observation=self.observation_cls.model_json_schema(),
                state=State.model_json_schema(),
            )

        # Register MCP endpoint for production mode (direct MCP access)
        @app.post("/mcp")
        async def mcp_endpoint(request_raw: Request) -> Dict[str, Any]:
            """
            MCP JSON-RPC endpoint for production mode.

            Bypasses step() overhead and provides direct access to MCP tools.
            Supports tools/list and tools/call methods.
            """
            # Parse JSON manually to handle parse errors gracefully
            try:
                body = await request_raw.body()
                request_dict = json.loads(body)
                request = JsonRpcRequest(**request_dict)
            except json.JSONDecodeError:
                return JsonRpcResponse.error_response(
                    JsonRpcErrorCode.PARSE_ERROR
                ).model_dump()
            except ValidationError as e:
                return JsonRpcResponse.error_response(
                    JsonRpcErrorCode.INVALID_REQUEST,
                    f"Invalid request: {e}",
                ).model_dump()
            except Exception:
                return JsonRpcResponse.error_response(
                    JsonRpcErrorCode.PARSE_ERROR
                ).model_dump()

            method = request.method
            params = request.params
            request_id = request.id

            # Create a temporary environment for MCP access
            _env = self._env_factory()

            try:
                # Check if environment supports MCP
                if not hasattr(_env, "mcp_client") and not hasattr(_env, "mcp_server"):
                    return JsonRpcResponse.error_response(
                        JsonRpcErrorCode.INTERNAL_ERROR,
                        "Environment does not support MCP",
                        request_id=request_id,
                    ).model_dump()

                if method == McpMethod.TOOLS_LIST:
                    # List tools from MCP server; prefer the client when present
                    if hasattr(_env, "mcp_client") and _env.mcp_client:
                        async with _env.mcp_client:
                            tools = await _env.mcp_client.list_tools()
                        return JsonRpcResponse.success(
                            result={
                                "tools": [
                                    t.model_dump()
                                    if hasattr(t, "model_dump")
                                    else dict(t)
                                    for t in tools
                                ]
                            },
                            request_id=request_id,
                        ).model_dump()
                    elif hasattr(_env, "mcp_server") and _env.mcp_server:
                        # Use server directly
                        tools = []
                        for tool_name, tool in get_server_tools(
                            _env.mcp_server
                        ).items():
                            tool_dict = {
                                "name": tool.name,
                                "description": tool.description or "",
                                "inputSchema": tool.parameters or {},
                            }
                            tools.append(tool_dict)
                        return JsonRpcResponse.success(
                            result={"tools": tools},
                            request_id=request_id,
                        ).model_dump()
                    else:
                        return JsonRpcResponse.error_response(
                            JsonRpcErrorCode.INTERNAL_ERROR,
                            "MCP server not available",
                            request_id=request_id,
                        ).model_dump()

                elif method == McpMethod.TOOLS_CALL:
                    tool_name = params.get("name")
                    arguments = params.get("arguments", {})

                    if not tool_name:
                        return JsonRpcResponse.error_response(
                            JsonRpcErrorCode.INVALID_PARAMS,
                            "Invalid params - 'name' is required",
                            request_id=request_id,
                        ).model_dump()

                    # Call tool via MCP
                    if hasattr(_env, "mcp_client") and _env.mcp_client:
                        async with _env.mcp_client:
                            result = await _env.mcp_client.call_tool(
                                name=tool_name, arguments=arguments
                            )
                    elif hasattr(_env, "mcp_server") and _env.mcp_server:
                        # Call tool directly on FastMCP server
                        server_tools = get_server_tools(_env.mcp_server)
                        if tool_name in server_tools:
                            tool = server_tools[tool_name]
                            result = tool.fn(**arguments)
                        else:
                            return JsonRpcResponse.error_response(
                                JsonRpcErrorCode.INVALID_PARAMS,
                                f"Tool not found: {tool_name}",
                                request_id=request_id,
                            ).model_dump()
                    else:
                        return JsonRpcResponse.error_response(
                            JsonRpcErrorCode.INTERNAL_ERROR,
                            "MCP server not available",
                            request_id=request_id,
                        ).model_dump()

                    # Make result JSON serializable
                    serializable_result = _make_json_serializable(result)

                    return JsonRpcResponse.success(
                        result=serializable_result,
                        request_id=request_id,
                    ).model_dump()

                else:
                    return JsonRpcResponse.error_response(
                        JsonRpcErrorCode.METHOD_NOT_FOUND,
                        f"Method not found: {method}",
                        request_id=request_id,
                    ).model_dump()

            except Exception as e:
                return JsonRpcResponse.error_response(
                    JsonRpcErrorCode.INTERNAL_ERROR,
                    str(e),
                    request_id=request_id,
                ).model_dump()
            finally:
                _env.close()

        # Register WebSocket endpoint for persistent sessions
        @app.websocket("/ws")
        async def websocket_endpoint(websocket: WebSocket):
            """
            WebSocket endpoint for persistent environment sessions.

            Each WebSocket connection gets its own environment instance.

            Message Protocol:
            - Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage
            - Server responds: WSObservationResponse | WSStateResponse | WSErrorResponse
            """
            await websocket.accept()

            session_id = None
            session_env = None

            try:
                # Create session with dedicated environment
                session_id, session_env = await self._create_session()

                while True:
                    # Receive message from client
                    raw_message = await websocket.receive_text()

                    try:
                        message_dict = json.loads(raw_message)
                    except json.JSONDecodeError as e:
                        error_resp = WSErrorResponse(
                            data={
                                "message": f"Invalid JSON: {e}",
                                "code": WSErrorCode.INVALID_JSON,
                            }
                        )
                        await websocket.send_text(error_resp.model_dump_json())
                        continue

                    msg_type = message_dict.get("type", "")

                    try:
                        match msg_type:
                            case "reset":
                                msg = WSResetMessage(**message_dict)

                                # Same __func__ override-detection as the HTTP
                                # handlers above.
                                is_async = (
                                    session_env.reset_async.__func__
                                    is not Environment.reset_async
                                )

                                if is_async:
                                    sig = inspect.signature(session_env.reset_async)
                                    valid_kwargs = self._get_valid_kwargs(sig, msg.data)
                                    observation = await session_env.reset_async(
                                        **valid_kwargs
                                    )
                                else:
                                    sig = inspect.signature(session_env.reset)
                                    valid_kwargs = self._get_valid_kwargs(sig, msg.data)
                                    observation = await self._run_in_session_executor(
                                        session_id, session_env.reset, **valid_kwargs
                                    )

                                self._update_session_activity(session_id)

                                response = WSObservationResponse(
                                    data=serialize_observation(observation),
                                )

                            case "step":
                                msg = WSStepMessage(**message_dict)
                                action = deserialize_action(msg.data, self.action_cls)

                                is_async = (
                                    session_env.step_async.__func__
                                    is not Environment.step_async
                                )

                                if is_async:
                                    observation = await session_env.step_async(action)
                                else:
                                    observation = await self._run_in_session_executor(
                                        session_id, session_env.step, action
                                    )

                                self._update_session_activity(
                                    session_id, increment_step=True
                                )

                                response = WSObservationResponse(
                                    data=serialize_observation(observation)
                                )

                            case "state":
                                msg = WSStateMessage(**message_dict)
                                state = session_env.state
                                if hasattr(state, "model_dump"):
                                    state_data = state.model_dump()
                                else:
                                    state_data = dict(state) if state else {}

                                response = WSStateResponse(data=state_data)

                            case "close":
                                # Client-initiated shutdown; cleanup happens in
                                # the outer finally block.
                                msg = WSCloseMessage(**message_dict)
                                break

                            case "mcp":
                                msg = WSMCPMessage(**message_dict)
                                try:
                                    rpc_request = JsonRpcRequest(**msg.data)
                                # NOTE(review): ValidationError is redundant in
                                # this tuple (Exception already covers it);
                                # kept as-is.
                                except (ValidationError, Exception) as e:
                                    rpc_response = JsonRpcResponse.error_response(
                                        JsonRpcErrorCode.INVALID_REQUEST,
                                        f"Invalid request: {e}",
                                    )
                                else:
                                    rpc_response = await mcp_handler(
                                        rpc_request,
                                        session_env=session_env,
                                    )
                                response = WSMCPResponse(data=rpc_response.model_dump())

                            case _:
                                response = WSErrorResponse(
                                    data={
                                        "message": f"Unknown message type: {msg_type}",
                                        "code": WSErrorCode.UNKNOWN_TYPE,
                                    }
                                )

                        await websocket.send_text(response.model_dump_json())

                    except ValidationError as e:
                        error_resp = WSErrorResponse(
                            data={
                                "message": "Invalid message",
                                "code": WSErrorCode.VALIDATION_ERROR,
                                "errors": e.errors(),
                            }
                        )
                        await websocket.send_text(error_resp.model_dump_json())
                    except Exception as e:
                        error_resp = WSErrorResponse(
                            data={
                                "message": str(e),
                                "code": WSErrorCode.EXECUTION_ERROR,
                            }
                        )
                        await websocket.send_text(error_resp.model_dump_json())

            except WebSocketDisconnect:
                pass
            except SessionCapacityError as e:
                error_resp = WSErrorResponse(
                    data={
                        "message": str(e),
                        "code": WSErrorCode.CAPACITY_REACHED,
                        "active_sessions": e.active_sessions,
                        "max_sessions": e.max_sessions,
                    }
                )
                await websocket.send_text(error_resp.model_dump_json())
            except EnvironmentFactoryError as e:
                error_resp = WSErrorResponse(
                    data={
                        "message": str(e),
                        "code": WSErrorCode.FACTORY_ERROR,
                        "factory_name": e.factory_name,
                    }
                )
                await websocket.send_text(error_resp.model_dump_json())
            except Exception as e:
                error_resp = WSErrorResponse(
                    data={"message": str(e), "code": WSErrorCode.SESSION_ERROR}
                )
                await websocket.send_text(error_resp.model_dump_json())
            finally:
                # Always tear down the per-connection session and environment.
                if session_id:
                    await self._destroy_session(session_id)
                try:
                    await websocket.close()
                except RuntimeError:
                    pass  # socket may already be closed by the client
1232
+
1233
+
1234
def create_app(
    env: Callable[[], Environment],
    action_cls: Type[Action],
    observation_cls: Type[Observation],
    env_name: Optional[str] = None,
    max_concurrent_envs: Optional[int] = None,
    concurrency_config: Optional[ConcurrencyConfig] = None,
    gradio_builder: Optional[Callable[..., Any]] = None,
) -> FastAPI:
    """
    Create a FastAPI application with or without web interface.

    The Gradio web interface (with README integration) is toggled via the
    ENABLE_WEB_INTERFACE environment variable; accepted truthy values are
    "true", "1", and "yes" (case-insensitive). When disabled, a plain
    FastAPI app is created.

    Args:
        env: Environment factory (callable) that creates new instances
        action_cls: The Action subclass this environment expects
        observation_cls: The Observation subclass this environment returns
        env_name: Optional environment name for README loading
        max_concurrent_envs: Maximum concurrent WebSocket sessions.
            Mutually exclusive with concurrency_config.
        concurrency_config: Optional ConcurrencyConfig for advanced concurrency
            settings. Mutually exclusive with max_concurrent_envs.
        gradio_builder: Optional callable to build a custom Gradio UI at /web.
            Signature: (web_manager, action_fields, metadata, is_chat_env,
            title, quick_start_md) -> gr.Blocks. When None, the default Gradio
            app is used. See docs/customizing-web-ui.md.

    Returns:
        FastAPI application instance, with or without the web interface.
    """
    flag = os.getenv("ENABLE_WEB_INTERFACE", "false").lower()
    enable_web = flag in ("true", "1", "yes")

    if not enable_web:
        # Plain API server without the web interface.
        return create_fastapi_app(
            env, action_cls, observation_cls, max_concurrent_envs, concurrency_config
        )

    # Gradio-based web UI (gradio is a core dependency)
    from .web_interface import create_web_interface_app

    return create_web_interface_app(
        env,
        action_cls,
        observation_cls,
        env_name,
        max_concurrent_envs,
        concurrency_config,
        gradio_builder=gradio_builder,
    )
1292
+
1293
+
1294
+ def create_fastapi_app(
1295
+ env: Callable[[], Environment],
1296
+ action_cls: Type[Action],
1297
+ observation_cls: Type[Observation],
1298
+ max_concurrent_envs: Optional[int] = None,
1299
+ concurrency_config: Optional[ConcurrencyConfig] = None,
1300
+ ) -> FastAPI:
1301
+ """
1302
+ Create a FastAPI application with comprehensive documentation.
1303
+
1304
+ Args:
1305
+ env: Environment factory (callable) that creates new instances
1306
+ action_cls: The Action subclass this environment expects
1307
+ observation_cls: The Observation subclass this environment returns
1308
+ max_concurrent_envs: Maximum concurrent WebSocket sessions.
1309
+ Mutually exclusive with concurrency_config.
1310
+ concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings.
1311
+ Mutually exclusive with max_concurrent_envs.
1312
+
1313
+ Returns:
1314
+ FastAPI application instance
1315
+ """
1316
+ try:
1317
+ from fastapi import FastAPI
1318
+ except ImportError:
1319
+ raise ImportError(
1320
+ "FastAPI is required. Install with: pip install fastapi uvicorn"
1321
+ )
1322
+
1323
+ app = FastAPI(
1324
+ title="OpenEnv Environment HTTP API",
1325
+ version="1.0.0",
1326
+ description="""
1327
+ # OpenEnv Environment HTTP API
1328
+
1329
+ HTTP API for interacting with OpenEnv environments through a standardized interface.
1330
+
1331
+ ## Features
1332
+
1333
+ * **Environment Reset**: Initialize or restart episodes
1334
+ * **Action Execution**: Send actions and receive observations
1335
+ * **State Inspection**: Query current environment state
1336
+ * **Schema Access**: Retrieve JSON schemas for actions and observations
1337
+
1338
+ ## Workflow
1339
+
1340
+ 1. Call `/reset` to start a new episode and get initial observation
1341
+ 2. Call `/step` repeatedly with actions to interact with environment
1342
+ 3. Episode ends when observation returns `done: true`
1343
+ 4. Call `/state` anytime to inspect current environment state
1344
+
1345
+ ## Documentation
1346
+
1347
+ * **Swagger UI**: Available at `/docs`
1348
+ * **ReDoc**: Available at `/redoc`
1349
+ * **OpenAPI Schema**: Available at `/openapi.json`
1350
+ """,
1351
+ openapi_tags=[
1352
+ {
1353
+ "name": "Environment Control",
1354
+ "description": "Core operations for environment interaction (reset, step)",
1355
+ },
1356
+ {
1357
+ "name": "State Management",
1358
+ "description": "Operations for inspecting environment state",
1359
+ },
1360
+ {
1361
+ "name": "Environment Info",
1362
+ "description": "Information about the environment",
1363
+ },
1364
+ {
1365
+ "name": "Schema",
1366
+ "description": "JSON Schema endpoints for actions, observations, and state",
1367
+ },
1368
+ {"name": "Health", "description": "Service health and status checks"},
1369
+ ],
1370
+ docs_url="/docs",
1371
+ redoc_url="/redoc",
1372
+ openapi_url="/openapi.json",
1373
+ contact={
1374
+ "name": "OpenEnv Team",
1375
+ "url": "https://github.com/meta-pytorch/OpenEnv",
1376
+ },
1377
+ license_info={
1378
+ "name": "BSD-3-Clause",
1379
+ "url": "https://github.com/meta-pytorch/OpenEnv/blob/main/LICENSE",
1380
+ },
1381
+ )
1382
+
1383
+ server = HTTPEnvServer(
1384
+ env,
1385
+ action_cls,
1386
+ observation_cls,
1387
+ max_concurrent_envs,
1388
+ concurrency_config=concurrency_config,
1389
+ )
1390
+ server.register_routes(app)
1391
+ return app
src/core/env_server/interfaces.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import inspect
8
+ from abc import ABC, abstractmethod
9
+ from typing import Any, Generic, Optional, Protocol, TYPE_CHECKING, TypeVar
10
+
11
+ from typing_extensions import TypedDict
12
+
13
+ from .types import Action, EnvironmentMetadata, Observation, State
14
+
15
+ if TYPE_CHECKING:
16
+ from openenv.core.rubrics import Rubric
17
+
18
+ ActT = TypeVar("ActT", bound=Action)
19
+ ObsT = TypeVar("ObsT", bound=Observation)
20
+ StateT = TypeVar("StateT", bound=State)
21
+
22
+
23
+ class Message(TypedDict):
24
+ """A message in a conversation.
25
+
26
+ Compatible with Huggingface chat template format.
27
+ """
28
+
29
+ role: str
30
+ content: str
31
+
32
+
33
+ class ModelTokenizer(Protocol):
34
+ """Protocol for tokenizers that support chat templates.
35
+
36
+ This protocol defines the interface that tokenizers must implement
37
+ to work with chat-based environments. It's compatible with
38
+ Huggingface transformers tokenizers.
39
+ """
40
+
41
+ def apply_chat_template(
42
+ self,
43
+ conversation: list[Message],
44
+ tokenize: bool = True,
45
+ return_tensors: str | None = None,
46
+ **kwargs: Any,
47
+ ) -> Any:
48
+ """Apply a chat template to format and optionally tokenize a conversation.
49
+
50
+ Args:
51
+ conversation: List of message dictionaries with 'role' and 'content'
52
+ tokenize: Whether to tokenize the output
53
+ return_tensors: Format for returned tensors ('pt' for PyTorch)
54
+ **kwargs: Additional arguments
55
+
56
+ Returns:
57
+ Formatted and optionally tokenized conversation
58
+ """
59
+ ...
60
+
61
+ def decode(
62
+ self, token_ids: Any, skip_special_tokens: bool = False, **kwargs: Any
63
+ ) -> str:
64
+ """Decode token IDs back to text.
65
+
66
+ Args:
67
+ token_ids: Token IDs to decode
68
+ skip_special_tokens: Whether to skip special tokens in output
69
+ **kwargs: Additional arguments
70
+
71
+ Returns:
72
+ Decoded text string
73
+ """
74
+ ...
75
+
76
+
77
+ class Transform(ABC, Generic[ObsT]):
78
+ """Transform observations to add rewards, metrics, or other modifications.
79
+
80
+ Transforms follow the TorchRL pattern where they take an observation
81
+ and return a (potentially modified) observation. This allows for
82
+ flexible reward computation and observation augmentation.
83
+ """
84
+
85
+ @abstractmethod
86
+ def __call__(self, observation: ObsT) -> ObsT:
87
+ """Transform an observation.
88
+
89
+ Args:
90
+ observation: The input observation
91
+
92
+ Returns:
93
+ The transformed observation
94
+ """
95
+ pass
96
+
97
+
98
+ class Environment(ABC, Generic[ActT, ObsT, StateT]):
99
+ """Base class for all environment servers following Gym/Gymnasium API.
100
+
101
+ Args:
102
+ transform: Optional transform to apply to observations
103
+ rubric: Optional rubric for reward computation. When provided, the
104
+ rubric's output can be used to set the observation's reward in step().
105
+
106
+ Class Attributes:
107
+ SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions.
108
+ When True, multiple WebSocket connections can each have their own
109
+ environment instance (up to max_concurrent_envs). When False (default),
110
+ the environment should only be used with a single session at a time.
111
+
112
+ Set this to True in your Environment subclass if:
113
+ - The environment uses proper session isolation (e.g., unique working dirs)
114
+ - No shared mutable state exists between instances
115
+ - External resources (databases, APIs) can handle concurrent access
116
+
117
+ Attributes:
118
+ rubric: Optional rubric for computing rewards. Environments can set this
119
+ in __init__ and use it in step() to compute observation rewards.
120
+ Training infrastructure can access it for introspection:
121
+ for name, r in env.rubric.named_rubrics():
122
+ print(f"{name}: {r.last_score}")
123
+
124
+ See RFC 004 for rubric design: rfcs/004-rubrics.md
125
+ """
126
+
127
+ # Class-level flag indicating whether this environment supports concurrent sessions
128
+ SUPPORTS_CONCURRENT_SESSIONS: bool = False
129
+
130
+ # Optional rubric for reward computation
131
+ rubric: Optional["Rubric"]
132
+
133
+ def __init__(
134
+ self,
135
+ transform: Optional[Transform[ObsT]] = None,
136
+ rubric: Optional["Rubric"] = None,
137
+ ):
138
+ self.transform = transform
139
+ self.rubric = rubric
140
+
141
+ @abstractmethod
142
+ def reset(
143
+ self,
144
+ seed: Optional[int] = None,
145
+ episode_id: Optional[str] = None,
146
+ **kwargs: Any,
147
+ ) -> ObsT:
148
+ """Reset the environment and return initial observation."""
149
+ pass
150
+
151
+ async def reset_async(
152
+ self,
153
+ seed: Optional[int] = None,
154
+ episode_id: Optional[str] = None,
155
+ **kwargs: Any,
156
+ ) -> ObsT:
157
+ """Async version of reset. Default implementation calls sync reset.
158
+
159
+ Override to provide true async implementation.
160
+ """
161
+ return self.reset(seed=seed, episode_id=episode_id, **kwargs)
162
+
163
+ @abstractmethod
164
+ def step(
165
+ self,
166
+ action: ActT,
167
+ timeout_s: Optional[float] = None,
168
+ **kwargs: Any,
169
+ ) -> ObsT:
170
+ """Take a step in the environment."""
171
+ pass
172
+
173
+ async def step_async(
174
+ self,
175
+ action: ActT,
176
+ timeout_s: Optional[float] = None,
177
+ **kwargs: Any,
178
+ ) -> ObsT:
179
+ """Async version of step. Default implementation calls sync step.
180
+
181
+ Override to provide true async implementation.
182
+ """
183
+ return self.step(action, timeout_s=timeout_s, **kwargs)
184
+
185
+ @property
186
+ @abstractmethod
187
+ def state(self) -> StateT:
188
+ """Get the current environment state."""
189
+ pass
190
+
191
+ def get_metadata(self) -> EnvironmentMetadata:
192
+ """
193
+ Get metadata about this environment.
194
+
195
+ Override this method to provide custom metadata for the environment.
196
+ Default implementation returns basic metadata derived from class name.
197
+
198
+ Returns:
199
+ EnvironmentMetadata with environment information
200
+ """
201
+ return EnvironmentMetadata(
202
+ name=self.__class__.__name__,
203
+ description=f"{self.__class__.__name__} environment",
204
+ version="1.0.0",
205
+ )
206
+
207
+ def _apply_transform(self, observation: ObsT) -> ObsT:
208
+ """Apply transform if one is provided."""
209
+ if self.transform is not None:
210
+ return self.transform(observation)
211
+ return observation
212
+
213
+ def _apply_rubric(self, action: ActT, observation: ObsT) -> float:
214
+ """Apply rubric if one is provided.
215
+
216
+ Args:
217
+ action: The action taken by the agent.
218
+ observation: The resulting observation.
219
+
220
+ Returns:
221
+ Reward value from the rubric, or 0.0 if no rubric is set.
222
+
223
+ Usage in step():
224
+ def step(self, action: MyAction, ...) -> MyObservation:
225
+ # ... execute action and create observation ...
226
+ observation.reward = self._apply_rubric(action, observation)
227
+ return observation
228
+ """
229
+ if self.rubric is not None:
230
+ return self.rubric(action, observation)
231
+ return 0.0
232
+
233
+ async def _apply_rubric_async(self, action: ActT, observation: ObsT) -> float:
234
+ """Apply rubric asynchronously if one is provided.
235
+
236
+ Args:
237
+ action: The action taken by the agent.
238
+ observation: The resulting observation.
239
+
240
+ Returns:
241
+ Reward value from the rubric, or 0.0 if no rubric is set.
242
+
243
+ Usage in step_async():
244
+ async def step_async(self, action: MyAction, ...) -> MyObservation:
245
+ # ... execute action and create observation ...
246
+ observation.reward = await self._apply_rubric_async(action, observation)
247
+ return observation
248
+ """
249
+ if self.rubric is not None:
250
+ result = self.rubric(action, observation)
251
+ # If rubric returns a coroutine, await it
252
+ if inspect.iscoroutine(result):
253
+ return await result
254
+ return result
255
+ return 0.0
256
+
257
+ def _reset_rubric(self) -> None:
258
+ """Reset the rubric state if one is provided.
259
+
260
+ Call this in reset() to clear any trajectory state in the rubric.
261
+
262
+ Usage in reset():
263
+ def reset(self, ...) -> MyObservation:
264
+ self._reset_rubric()
265
+ # ... create initial observation ...
266
+ return observation
267
+ """
268
+ if self.rubric is not None:
269
+ self.rubric.reset()
270
+
271
+ async def _reset_rubric_async(self) -> None:
272
+ """Reset the rubric state asynchronously if one is provided.
273
+
274
+ Call this in reset_async() to clear any trajectory state in the rubric.
275
+
276
+ Usage in reset_async():
277
+ async def reset_async(self, ...) -> MyObservation:
278
+ await self._reset_rubric_async()
279
+ # ... create initial observation ...
280
+ return observation
281
+ """
282
+ if self.rubric is not None:
283
+ # Check if rubric has async reset method
284
+ if hasattr(self.rubric, "reset_async"):
285
+ result = self.rubric.reset_async()
286
+ if inspect.iscoroutine(result):
287
+ await result
288
+ else:
289
+ self.rubric.reset()
290
+
291
+ def close(self) -> None:
292
+ """Clean up resources used by the environment.
293
+
294
+ Override this method to implement custom cleanup logic.
295
+ Called when the environment is being destroyed or reset.
296
+ """
297
+ pass
src/core/env_server/mcp_environment.py ADDED
@@ -0,0 +1,624 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ MCP Environment base class for OpenEnv.
9
+
10
+ This module provides the MCPEnvironment base class that integrates FastMCP servers
11
+ with OpenEnv's Gym-style Environment interface. It handles MCP tool discovery
12
+ and invocation through the step() API, following RFC 003.
13
+
14
+ Key features:
15
+ - Automatic routing of ListToolsAction and CallToolAction to MCP server
16
+ - Reserved tool name validation (reset, step, state, close are protected)
17
+ - Timeout handling for tool calls
18
+ - Proper error categorization (tool not found, execution errors, timeouts)
19
+ - Mode-aware tool registration (production vs simulation)
20
+ - Code mode support via get_callables() and execute_code()
21
+
22
+ Usage:
23
+ from fastmcp import FastMCP
24
+ from openenv.core.env_server.mcp_environment import MCPEnvironment
25
+
26
+ class MyMCPEnv(MCPEnvironment):
27
+ def __init__(self):
28
+ mcp = FastMCP("my-server")
29
+
30
+ # Register mode-specific tools
31
+ @self.tool(mode="production")
32
+ def my_tool(arg: str) -> str:
33
+ return f"Production: {arg}"
34
+
35
+ @self.tool(mode="simulation")
36
+ def my_tool(arg: str) -> str:
37
+ return f"Simulation: {arg}"
38
+
39
+ super().__init__(mcp)
40
+
41
+ def reset(self, seed=None, episode_id=None, **kwargs):
42
+ # Reset logic here
43
+ ...
44
+
45
+ def _step_impl(self, action):
46
+ # Handle non-MCP actions
47
+ ...
48
+
49
+ @property
50
+ def state(self):
51
+ # Return current state
52
+ ...
53
+ """
54
+
55
+ import asyncio
56
+ import inspect
57
+ from abc import abstractmethod
58
+ from collections import defaultdict
59
+ from typing import Any, Callable, Dict, Optional
60
+
61
+ from fastmcp import Client
62
+ from fastmcp.client.client import CallToolResult
63
+ from mcp.types import TextContent
64
+
65
+ from ..utils import run_async_safely
66
+ from .interfaces import Environment
67
+ from .mcp_types import (
68
+ CallToolAction,
69
+ CallToolObservation,
70
+ ListToolsAction,
71
+ ListToolsObservation,
72
+ RESERVED_TOOL_NAMES,
73
+ Tool,
74
+ ToolError,
75
+ ToolErrorType,
76
+ )
77
+ from .types import Action, Observation
78
+
79
+
80
+ # Default timeout for MCP tool calls in seconds
81
+ MCP_TOOL_CALL_TIMEOUT = 30.0
82
+
83
+ # Valid modes for tool registration
84
+ VALID_MODES = {"production", "simulation"}
85
+
86
+
87
+ def get_server_tools(mcp_server: Any) -> Dict[str, Any]:
88
+ """
89
+ Get tools from a FastMCP server, compatible with both 2.x and 3.x.
90
+
91
+ Returns:
92
+ Dictionary mapping tool names to tool objects.
93
+ """
94
+ # FastMCP 2.x: get_tools() returns dict {name: Tool}
95
+ if hasattr(mcp_server, "get_tools"):
96
+ result = run_async_safely(mcp_server.get_tools())
97
+ if isinstance(result, dict):
98
+ return result
99
+ # FastMCP 3.x: list_tools() returns list of Tool objects
100
+ if hasattr(mcp_server, "list_tools"):
101
+ tools_list = run_async_safely(mcp_server.list_tools())
102
+ return {t.name: t for t in tools_list}
103
+ return {}
104
+
105
+
106
+ class MCPEnvironment(Environment):
107
+ """
108
+ Base class for environments that expose tools via MCP (Model Context Protocol).
109
+
110
+ MCPEnvironment bridges FastMCP servers with OpenEnv's Gym-style API, allowing
111
+ agents to discover and invoke MCP tools through the standard step() interface.
112
+
113
+ The class automatically handles:
114
+ - ListToolsAction: Returns available tools from the MCP server
115
+ - CallToolAction: Invokes a specific tool with arguments
116
+
117
+ All other actions are delegated to the abstract _step_impl() method,
118
+ which subclasses must implement.
119
+
120
+ Args:
121
+ mcp_server: A FastMCP server instance containing tool definitions.
122
+ The server's tools will be validated against reserved names.
123
+ transform: Optional transform to apply to observations (inherited from Environment).
124
+
125
+ Raises:
126
+ ValueError: If any tool in the MCP server uses a reserved name
127
+ (reset, step, state, close).
128
+
129
+ Example:
130
+ >>> from fastmcp import FastMCP
131
+ >>> mcp = FastMCP("calculator")
132
+ >>> @mcp.tool()
133
+ ... def add(a: int, b: int) -> int:
134
+ ... return a + b
135
+ >>> env = MyMCPEnvironment(mcp)
136
+ >>> obs = env.step(ListToolsAction())
137
+ >>> obs.tools[0].name
138
+ 'add'
139
+ """
140
+
141
+ def __init__(self, mcp_server: Any, transform: Optional[Any] = None) -> None:
142
+ """
143
+ Initialize the MCP environment.
144
+
145
+ Args:
146
+ mcp_server: A FastMCP server instance with tool definitions.
147
+ transform: Optional transform to apply to observations.
148
+
149
+ Raises:
150
+ ValueError: If any tool uses a reserved name (reset, step, state, close).
151
+ """
152
+ super().__init__(transform=transform)
153
+
154
+ # Validate tool names before storing
155
+ self._validate_tool_names(mcp_server)
156
+
157
+ self.mcp_server = mcp_server
158
+ self.mcp_client = Client(mcp_server)
159
+
160
+ # Track mode-specific tools: {tool_name: {mode: func}}
161
+ # mode can be "production", "simulation", or None (available in all modes)
162
+ self._mode_tools = defaultdict(dict)
163
+
164
+ # Track tool schemas for list_tools: {tool_name: {mode: schema}}
165
+ self._mode_tool_schemas = defaultdict(dict)
166
+
167
+ @property
168
+ def supports_code_mode(self) -> bool:
169
+ """Check if this environment supports code mode (execute_code)."""
170
+ return True
171
+
172
+ def _get_server_tools(self, mcp_server: Any) -> Dict[str, Any]:
173
+ """
174
+ Get tools from a FastMCP server, compatible with both 2.x and 3.x.
175
+
176
+ Returns:
177
+ Dictionary mapping tool names to tool objects.
178
+ """
179
+ return get_server_tools(mcp_server)
180
+
181
+ def get_callables(self) -> Dict[str, Callable]:
182
+ """
183
+ Get callable functions for code mode.
184
+
185
+ Returns tool functions as direct Python callables, enabling code mode
186
+ where agents write Python code that calls tools directly (no JSON-RPC
187
+ overhead). Mode-specific tools are filtered by the current mode.
188
+
189
+ Returns:
190
+ Dictionary mapping tool names to callables.
191
+ """
192
+ callables: Dict[str, Callable] = {}
193
+ current_mode = getattr(self, "_mode", None)
194
+
195
+ # Extract callables from FastMCP server using public API
196
+ for tool_name, tool in self._get_server_tools(self.mcp_server).items():
197
+ if hasattr(tool, "fn") and callable(tool.fn):
198
+ callables[tool_name] = tool.fn
199
+
200
+ # Add mode-specific tools available in current mode
201
+ for tool_name, mode_funcs in self._mode_tools.items():
202
+ if None in mode_funcs:
203
+ # Tool available in all modes (already in FastMCP if registered there)
204
+ if tool_name not in callables:
205
+ callables[tool_name] = mode_funcs[None]
206
+ elif current_mode in mode_funcs:
207
+ # Tool available in current mode only
208
+ callables[tool_name] = mode_funcs[current_mode]
209
+
210
+ return callables
211
+
212
+ def execute_code(self, code: str) -> Observation:
213
+ """
214
+ Execute Python code with tools available as callables.
215
+
216
+ This enables the CodeAct pattern where agents write Python code
217
+ that calls tools directly as functions, avoiding JSON-RPC overhead.
218
+
219
+ Args:
220
+ code: Python code to execute. Tools are available as functions
221
+ in the execution namespace. Set a variable named 'result'
222
+ to capture the return value.
223
+
224
+ Returns:
225
+ Observation with result in metadata["result"] or error in
226
+ metadata["error"].
227
+ """
228
+ namespace = self.get_callables()
229
+
230
+ result_dict: Dict[str, Any] = {}
231
+ try:
232
+ exec(code, namespace, result_dict)
233
+ result = result_dict.get("result")
234
+ return Observation(done=False, reward=0.0, metadata={"result": result})
235
+ except SyntaxError as e:
236
+ return Observation(
237
+ done=False, reward=0.0, metadata={"error": f"Syntax error: {str(e)}"}
238
+ )
239
+ except Exception as e:
240
+ return Observation(done=False, reward=0.0, metadata={"error": str(e)})
241
+
242
+ def _validate_tool_names(self, mcp_server: Any) -> None:
243
+ """
244
+ Validate that no tools use reserved names.
245
+
246
+ Reserved names (reset, step, state, close) are protected to maintain
247
+ the dual API boundary between infrastructure and agent APIs.
248
+
249
+ Args:
250
+ mcp_server: The FastMCP server to validate.
251
+
252
+ Raises:
253
+ ValueError: If any tool uses a reserved name.
254
+ """
255
+ tools_dict = self._get_server_tools(mcp_server)
256
+ if tools_dict:
257
+ tool_names = set(tools_dict.keys())
258
+ conflicts = tool_names & RESERVED_TOOL_NAMES
259
+ if conflicts:
260
+ raise ValueError(
261
+ f"MCP tools cannot use reserved names: {sorted(conflicts)}. "
262
+ f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}"
263
+ )
264
+
265
+ def tool(self, mode: Optional[str] = None) -> Callable:
266
+ """
267
+ Decorator for registering mode-aware tools.
268
+
269
+ Args:
270
+ mode: Optional mode for the tool ("production" or "simulation").
271
+ If None, tool is available in all modes.
272
+
273
+ Returns:
274
+ A decorator function for registering tools.
275
+
276
+ Raises:
277
+ ValueError: If mode is not None, "production", or "simulation".
278
+ """
279
+ if mode is not None and mode not in VALID_MODES:
280
+ raise ValueError(
281
+ f"Invalid mode '{mode}'. Mode must be 'production', 'simulation', or None."
282
+ )
283
+
284
+ def decorator(func: Callable) -> Callable:
285
+ tool_name = func.__name__
286
+ # Validate tool name is not reserved
287
+ if tool_name in RESERVED_TOOL_NAMES:
288
+ raise ValueError(
289
+ f"Tool name '{tool_name}' is reserved and cannot be used. "
290
+ f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}"
291
+ )
292
+
293
+ # If mode is None, register with FastMCP as usual
294
+ if mode is None:
295
+ decorated_func = self.mcp_server.tool()(func)
296
+ self._mode_tools[tool_name][None] = func
297
+ return decorated_func
298
+
299
+ # For mode-specific tools, don't register with FastMCP
300
+ # Instead, track them ourselves
301
+ self._mode_tools[tool_name][mode] = func
302
+
303
+ # Extract schema information from function signature
304
+ sig = inspect.signature(func)
305
+ schema = {
306
+ "type": "object",
307
+ "properties": {},
308
+ "required": [],
309
+ }
310
+
311
+ for param_name, param in sig.parameters.items():
312
+ # Get type annotation
313
+ param_type = param.annotation
314
+ json_type = "string" # default
315
+ if param_type in (int, "int"):
316
+ json_type = "integer"
317
+ elif param_type in (float, "float"):
318
+ json_type = "number"
319
+ elif param_type in (bool, "bool"):
320
+ json_type = "boolean"
321
+
322
+ schema["properties"][param_name] = {"type": json_type}
323
+
324
+ # If no default value, it's required
325
+ if param.default == inspect.Parameter.empty:
326
+ schema["required"].append(param_name)
327
+
328
+ # Store the schema for this mode-specific tool
329
+ self._mode_tool_schemas[tool_name][mode] = {
330
+ "name": tool_name,
331
+ "description": func.__doc__ or "",
332
+ "input_schema": schema,
333
+ }
334
+
335
+ return func
336
+
337
+ return decorator
338
+
339
+ def step(
340
+ self,
341
+ action: Action,
342
+ timeout_s: Optional[float] = None,
343
+ **kwargs: Any,
344
+ ) -> Observation:
345
+ """
346
+ Execute an action in the environment.
347
+
348
+ This method routes MCP-specific actions (ListToolsAction, CallToolAction)
349
+ to the appropriate handlers, while delegating all other actions to
350
+ the subclass's _step_impl() method.
351
+
352
+ Args:
353
+ action: The action to execute. Can be:
354
+ - ListToolsAction: Returns available MCP tools
355
+ - CallToolAction: Invokes a specific MCP tool
356
+ - Any other Action: Delegated to _step_impl()
357
+ timeout_s: Optional timeout in seconds for the action.
358
+ Defaults to MCP_TOOL_CALL_TIMEOUT (30s) for MCP actions.
359
+ **kwargs: Additional arguments passed to handlers.
360
+
361
+ Returns:
362
+ Observation appropriate to the action type:
363
+ - ListToolsObservation for ListToolsAction
364
+ - CallToolObservation for CallToolAction
365
+ - Subclass-defined Observation for other actions
366
+ """
367
+ if isinstance(action, ListToolsAction):
368
+ return self._handle_list_tools()
369
+ elif isinstance(action, CallToolAction):
370
+ return self._handle_call_tool(action, timeout_s=timeout_s)
371
+ else:
372
+ return self._step_impl(action, timeout_s=timeout_s, **kwargs)
373
+
374
+ def _handle_list_tools(self) -> ListToolsObservation:
375
+ """
376
+ Handle a ListToolsAction by querying the MCP server.
377
+
378
+ Returns:
379
+ ListToolsObservation containing all available tools with their
380
+ names, descriptions, and input schemas, filtered by current mode.
381
+ """
382
+ try:
383
+ # Get current mode
384
+ current_mode = getattr(self, "_mode", None)
385
+
386
+ # Start with tools from FastMCP server (mode=None tools)
387
+ tools_result = run_async_safely(self._async_list_tools())
388
+
389
+ # Build list of Tool objects
390
+ tools = []
391
+
392
+ # Add FastMCP tools that are not mode-specific
393
+ for tool in tools_result:
394
+ if tool.name not in self._mode_tool_schemas:
395
+ tools.append(
396
+ Tool(
397
+ name=tool.name,
398
+ description=tool.description or "",
399
+ input_schema=tool.inputSchema
400
+ if hasattr(tool, "inputSchema")
401
+ else {},
402
+ )
403
+ )
404
+
405
+ # Add mode-specific tools available in current mode
406
+ for tool_name, mode_schemas in self._mode_tool_schemas.items():
407
+ if None in mode_schemas:
408
+ # Tool available in all modes
409
+ schema = mode_schemas[None]
410
+ tools.append(
411
+ Tool(
412
+ name=schema["name"],
413
+ description=schema["description"],
414
+ input_schema=schema["input_schema"],
415
+ )
416
+ )
417
+ elif current_mode in mode_schemas:
418
+ # Tool available in current mode
419
+ schema = mode_schemas[current_mode]
420
+ tools.append(
421
+ Tool(
422
+ name=schema["name"],
423
+ description=schema["description"],
424
+ input_schema=schema["input_schema"],
425
+ )
426
+ )
427
+
428
+ return ListToolsObservation(tools=tools)
429
+
430
+ except Exception as e:
431
+ # Return an observation with error in metadata
432
+ return ListToolsObservation(
433
+ tools=[],
434
+ metadata={
435
+ "error": str(e),
436
+ "error_type": "list_tools_failed",
437
+ },
438
+ )
439
+
440
+ async def _async_list_tools(self) -> list:
441
+ """
442
+ Async helper to list tools from the MCP client.
443
+
444
+ Returns:
445
+ List of tool objects from the MCP server.
446
+ """
447
+ async with self.mcp_client:
448
+ return await self.mcp_client.list_tools()
449
+
450
+ def _handle_call_tool(
451
+ self,
452
+ action: CallToolAction,
453
+ timeout_s: Optional[float] = None,
454
+ ) -> CallToolObservation:
455
+ """
456
+ Handle a CallToolAction by invoking the specified tool.
457
+
458
+ Args:
459
+ action: The CallToolAction containing tool_name and arguments.
460
+ timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s).
461
+
462
+ Returns:
463
+ CallToolObservation with the tool's result or an error.
464
+ """
465
+ timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT
466
+
467
+ # Check if this is a mode-specific tool
468
+ tool_name = action.tool_name
469
+ current_mode = getattr(self, "_mode", None)
470
+
471
+ if tool_name in self._mode_tools:
472
+ mode_info = self._mode_tools[tool_name]
473
+
474
+ # Check if tool is available in current mode
475
+ # Tool is available if:
476
+ # 1. It has a None mode (available in all modes), OR
477
+ # 2. It has an implementation for the current mode
478
+ if None in mode_info:
479
+ # Use the mode-agnostic version
480
+ func = mode_info[None]
481
+ elif current_mode in mode_info:
482
+ # Use the mode-specific version
483
+ func = mode_info[current_mode]
484
+ else:
485
+ # Tool not available in current mode
486
+ return CallToolObservation(
487
+ tool_name=tool_name,
488
+ result=None,
489
+ error=ToolError(
490
+ error_type=ToolErrorType.TOOL_NOT_FOUND,
491
+ message=f"Tool '{tool_name}' not available in {current_mode} mode",
492
+ ),
493
+ )
494
+
495
+ # Call the mode-specific function directly
496
+ try:
497
+ # Check if function is async and await if necessary
498
+ if inspect.iscoroutinefunction(func):
499
+ result = run_async_safely(func(**action.arguments))
500
+ else:
501
+ result = func(**action.arguments)
502
+
503
+ # Wrap result in CallToolResult format to match FastMCP behavior
504
+ return CallToolObservation(
505
+ tool_name=tool_name,
506
+ result=CallToolResult(
507
+ content=[TextContent(type="text", text=str(result))],
508
+ structured_content={"result": result},
509
+ meta=None,
510
+ data=result,
511
+ is_error=False,
512
+ ),
513
+ )
514
+ except Exception as e:
515
+ return CallToolObservation(
516
+ tool_name=tool_name,
517
+ result=None,
518
+ error=ToolError(
519
+ error_type=ToolErrorType.EXECUTION_ERROR,
520
+ message=str(e),
521
+ ),
522
+ )
523
+
524
+ # Not a mode-specific tool, use FastMCP
525
+ try:
526
+ # Run the async call_tool with timeout
527
+ # Use run_async_safely to handle both sync and async contexts
528
+ result = run_async_safely(
529
+ asyncio.wait_for(
530
+ self._async_call_tool(action.tool_name, action.arguments),
531
+ timeout=timeout,
532
+ )
533
+ )
534
+
535
+ return CallToolObservation(
536
+ tool_name=action.tool_name,
537
+ result=result,
538
+ )
539
+
540
+ except asyncio.TimeoutError:
541
+ return CallToolObservation(
542
+ tool_name=action.tool_name,
543
+ result=None,
544
+ error=ToolError(
545
+ error_type=ToolErrorType.TIMEOUT,
546
+ message=f"Tool '{action.tool_name}' timed out after {timeout} seconds",
547
+ ),
548
+ )
549
+
550
+ except Exception as e:
551
+ error_message = str(e)
552
+
553
+ # Determine error type based on the exception
554
+ if (
555
+ "not found" in error_message.lower()
556
+ or "unknown tool" in error_message.lower()
557
+ ):
558
+ error_type = ToolErrorType.TOOL_NOT_FOUND
559
+ elif (
560
+ "invalid" in error_message.lower()
561
+ or "argument" in error_message.lower()
562
+ ):
563
+ error_type = ToolErrorType.INVALID_ARGS
564
+ else:
565
+ error_type = ToolErrorType.EXECUTION_ERROR
566
+
567
+ return CallToolObservation(
568
+ tool_name=action.tool_name,
569
+ result=None,
570
+ error=ToolError(
571
+ error_type=error_type,
572
+ message=error_message,
573
+ ),
574
+ )
575
+
576
+ async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any:
577
+ """
578
+ Async helper to call a tool on the MCP server.
579
+
580
+ Args:
581
+ tool_name: Name of the tool to invoke.
582
+ arguments: Dictionary of arguments to pass to the tool.
583
+
584
+ Returns:
585
+ The result from the tool execution.
586
+ """
587
+ async with self.mcp_client:
588
+ return await self.mcp_client.call_tool(tool_name, arguments)
589
+
590
+ @abstractmethod
591
+ def _step_impl(
592
+ self,
593
+ action: Action,
594
+ timeout_s: Optional[float] = None,
595
+ **kwargs: Any,
596
+ ) -> Observation:
597
+ """
598
+ Handle non-MCP actions in the environment.
599
+
600
+ Subclasses must implement this method to handle any actions that are
601
+ not ListToolsAction or CallToolAction. This is where environment-specific
602
+ action processing should occur.
603
+
604
+ Args:
605
+ action: The action to execute (guaranteed not to be an MCP action).
606
+ timeout_s: Optional timeout in seconds.
607
+ **kwargs: Additional arguments.
608
+
609
+ Returns:
610
+ An Observation appropriate for the action.
611
+ """
612
+ pass
613
+
614
+ def close(self) -> None:
615
+ """
616
+ Clean up resources used by the environment.
617
+
618
+ This method cleans up the MCP client and any other resources.
619
+ Subclasses should call super().close() if they override this method.
620
+ """
621
+ # The MCP client uses async context manager, so cleanup happens
622
+ # automatically when the context exits. We just clear references.
623
+ self.mcp_client = None
624
+ self.mcp_server = None
src/core/env_server/mcp_types.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ MCP (Model Context Protocol) type definitions for OpenEnv.
9
+
10
+ This module defines strongly typed models for MCP tool discovery and invocation,
11
+ following RFC 003. These types map MCP's REST-like API (tools/list, tools/call)
12
+ to Gym-style action types.
13
+
14
+ Key design decisions:
15
+ - Tool discovery (list_tools) does NOT require reset() first
16
+ - Reserved tool names (reset, step, state, close) are prohibited
17
+ - Both step() and WebSocket /mcp paths are supported
18
+ """
19
+
20
+ from enum import Enum
21
+ from typing import Any, Dict, List, Literal, Optional, Union
22
+
23
+ from pydantic import BaseModel, ConfigDict, Field
24
+
25
+ from .types import Action, BaseMessage, Observation
26
+
27
+
28
+ # =============================================================================
29
+ # JSON-RPC 2.0 Types
30
+ # =============================================================================
31
+
32
+
33
class JsonRpcErrorCode(int, Enum):
    """
    Standard JSON-RPC 2.0 error codes.

    See: https://www.jsonrpc.org/specification#error_object
    """

    # Codes defined by the JSON-RPC 2.0 specification.
    PARSE_ERROR = -32700  # Server received malformed JSON
    INVALID_REQUEST = -32600  # Payload is not a well-formed Request object
    METHOD_NOT_FOUND = -32601  # Requested method is unknown to the server
    INVALID_PARAMS = -32602  # Method exists but the parameters are wrong
    INTERNAL_ERROR = -32603  # Server hit an internal JSON-RPC failure

    # -32000..-32099 is reserved for implementation-defined server errors.
    SERVER_ERROR = -32000  # Generic server-side error
49
+
50
+
51
class McpMethod(str, Enum):
    """MCP method names supported by this server."""

    TOOLS_LIST = "tools/list"  # tool discovery
    TOOLS_CALL = "tools/call"  # tool invocation
56
+
57
+
58
class JsonRpcError(BaseModel):
    """
    JSON-RPC 2.0 error object.

    See: https://www.jsonrpc.org/specification#error_object
    """

    model_config = ConfigDict(extra="forbid")

    code: int = Field(description="Error code indicating the error type")
    message: str = Field(description="Short description of the error")
    data: Optional[Any] = Field(
        default=None, description="Additional error information"
    )

    @classmethod
    def from_code(
        cls, code: JsonRpcErrorCode, message: Optional[str] = None, data: Any = None
    ) -> "JsonRpcError":
        """Create an error from a standard error code."""
        # Canonical messages from the JSON-RPC 2.0 spec, used when the caller
        # does not supply (or supplies an empty) message.
        fallback_messages = {
            JsonRpcErrorCode.PARSE_ERROR: "Parse error",
            JsonRpcErrorCode.INVALID_REQUEST: "Invalid Request",
            JsonRpcErrorCode.METHOD_NOT_FOUND: "Method not found",
            JsonRpcErrorCode.INVALID_PARAMS: "Invalid params",
            JsonRpcErrorCode.INTERNAL_ERROR: "Internal error",
            JsonRpcErrorCode.SERVER_ERROR: "Server error",
        }
        resolved_message = message or fallback_messages.get(code, "Unknown error")
        return cls(code=code.value, message=resolved_message, data=data)
91
+
92
+
93
class JsonRpcRequest(BaseModel):
    """
    JSON-RPC 2.0 request object.

    A request with id=None is a JSON-RPC notification (no response expected).

    See: https://www.jsonrpc.org/specification#request_object
    """

    model_config = ConfigDict(extra="forbid")

    jsonrpc: Literal["2.0"] = Field(description="JSON-RPC version, must be '2.0'")
    method: str = Field(description="Name of the method to be invoked")
    params: Dict[str, Any] = Field(
        default_factory=dict, description="Parameter values for the method"
    )
    id: Optional[Union[str, int]] = Field(
        default=None, description="Request identifier established by the client"
    )
110
+
111
+
112
class JsonRpcResponse(BaseModel):
    """
    JSON-RPC 2.0 response object.

    Per JSON-RPC 2.0 spec, a response has either 'result' or 'error', not both.
    Serialization therefore emits exactly one of the two keys: 'error' when an
    error is set, otherwise 'result' (even if the result is None).

    See: https://www.jsonrpc.org/specification#response_object
    """

    model_config = ConfigDict(extra="forbid")

    jsonrpc: Literal["2.0"] = Field(default="2.0", description="JSON-RPC version")
    result: Optional[Any] = Field(
        default=None, description="Result of the method invocation"
    )
    error: Optional[JsonRpcError] = Field(
        default=None, description="Error object if method invocation failed"
    )
    id: Optional[Union[str, int]] = Field(
        default=None, description="Request identifier from the request"
    )

    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
        """Serialize to dict, emitting exactly one of 'result'/'error' (JSON-RPC compliance).

        Fix over the previous hand-rolled version: keyword arguments such as
        ``mode``, ``by_alias`` or ``exclude_none`` are forwarded to Pydantic
        instead of being silently discarded. The mutually exclusive field is
        dropped via ``exclude`` (an explicit caller-provided ``exclude`` wins).
        """
        # Drop 'result' when an error is present, otherwise drop 'error'.
        kwargs.setdefault(
            "exclude", {"result"} if self.error is not None else {"error"}
        )
        return super().model_dump(**kwargs)

    def model_dump_json(self, **kwargs: Any) -> str:
        """Serialize to a JSON string with the same result/error exclusion rules.

        Serialization keyword arguments are intentionally not forwarded here;
        the payload shape is fixed by model_dump() above.
        """
        import json

        return json.dumps(self.model_dump())

    @classmethod
    def success(
        cls, result: Any, request_id: Optional[Union[str, int]] = None
    ) -> "JsonRpcResponse":
        """Create a success response."""
        return cls(result=result, id=request_id)

    @classmethod
    def error_response(
        cls,
        code: JsonRpcErrorCode,
        message: Optional[str] = None,
        data: Any = None,
        request_id: Optional[Union[str, int]] = None,
    ) -> "JsonRpcResponse":
        """Create an error response from a standard error code."""
        return cls(
            error=JsonRpcError.from_code(code, message, data),
            id=request_id,
        )
176
+
177
+
178
+ # =============================================================================
179
+ # MCP Tool Types
180
+ # =============================================================================
181
+
182
+
183
class Tool(BaseModel):
    """
    Strongly typed MCP tool specification.

    Follows the MCP ToolSpec format for tool discovery.
    See: https://modelcontextprotocol.io/specification/2025-06-18/server/tools
    """

    model_config = ConfigDict(extra="forbid")

    name: str = Field(description="Unique identifier for the tool")
    description: str = Field(
        description="Human-readable description of what the tool does"
    )
    # NOTE(review): the MCP wire format spells this field `inputSchema`
    # (camelCase) and no alias is declared here — confirm the transport layer
    # converts the key, otherwise `extra="forbid"` will reject raw MCP payloads.
    input_schema: Dict[str, Any] = Field(
        description="JSON Schema for the tool's input parameters"
    )
200
+
201
+
202
class ToolErrorType(str, Enum):
    """Categories of failure for a tool invocation."""

    EXECUTION_ERROR = "execution_error"  # The tool started but failed while running
    INVALID_ARGS = "invalid_args"  # Arguments did not match what the tool expects
    TRANSPORT_ERROR = "transport_error"  # Could not communicate with the server
    TOOL_NOT_FOUND = "tool_not_found"  # No tool registered under that name
    TIMEOUT = "timeout"  # The call did not complete in time
210
+
211
+
212
class ToolError(BaseModel):
    """
    Structured error for tool execution failures.

    This is used for transport/framework errors, NOT for errors returned
    by the tool itself (those go in the result field).
    See CallToolObservation for how the two error channels are separated.
    """

    model_config = ConfigDict(extra="forbid")

    error_type: ToolErrorType = Field(description="Category of the error")
    message: str = Field(description="Human-readable error message")
224
+
225
+
226
+ # --- MCP Actions ---
227
+
228
+
229
class ListToolsAction(Action):
    """
    Request list of available tools from the environment.

    This action triggers MCP's tools/list operation and returns
    all available tools with their schemas.

    Note: Does NOT require reset() to be called first.
    """

    # Discriminator used to route this action to the MCP tools/list handler.
    type: Literal["list_tools"] = Field(
        default="list_tools", description="Action type discriminator"
    )
242
+
243
+
244
class CallToolAction(Action):
    """
    Call a specific tool via MCP.

    This action triggers MCP's tools/call operation with the
    specified tool name and arguments.
    """

    # Discriminator used to route this action to the MCP tools/call handler.
    type: Literal["call_tool"] = Field(
        default="call_tool", description="Action type discriminator"
    )
    tool_name: str = Field(description="Name of the tool to call")
    arguments: Dict[str, Any] = Field(
        default_factory=dict, description="Arguments to pass to the tool"
    )
259
+
260
+
261
+ # --- MCP Observations ---
262
+
263
+
264
class ListToolsObservation(Observation):
    """
    Response containing available tools.

    Returned when processing a ListToolsAction.
    """

    tools: List[Tool] = Field(description="List of available tools with their schemas")
272
+
273
+
274
class CallToolObservation(Observation):
    """
    Response from tool execution.

    Contains the tool's result or an error if the call failed.
    Tool-specific errors (from the tool itself) are included in the result.
    Transport/framework errors use the error field.
    When error is set, result is expected to be None.
    """

    tool_name: str = Field(description="Name of the tool that was called")
    result: Any = Field(
        default=None, description="Tool-specific result (may include tool errors)"
    )
    error: Optional[ToolError] = Field(
        default=None, description="Transport/framework error if call failed"
    )
290
+
291
+
292
+ # --- WebSocket Message Types for MCP ---
293
+
294
+
295
class WSMCPMessage(BaseMessage):
    """
    WebSocket message for MCP JSON-RPC requests.

    Allows direct MCP access via WebSocket for production inference,
    bypassing the step() API.
    """

    type: Literal["mcp"] = Field(default="mcp", description="Message type")
    # Raw JSON-RPC envelope; see JsonRpcRequest for the expected keys.
    data: Dict[str, Any] = Field(description="JSON-RPC payload (method, params, id)")
305
+
306
+
307
class WSMCPResponse(BaseModel):
    """
    WebSocket response for MCP JSON-RPC.

    Contains the JSON-RPC response from the MCP server.
    """

    model_config = ConfigDict(extra="forbid")

    # NOTE(review): sibling WS response models use Literal["..."] for `type`;
    # this one is a plain str defaulting to "mcp" — confirm whether other
    # values are ever assigned before tightening.
    type: str = Field(default="mcp", description="Response type")
    data: Dict[str, Any] = Field(description="JSON-RPC response payload")
318
+
319
+
320
+ # Reserved tool names that cannot be used (protects dual API boundary)
321
+ RESERVED_TOOL_NAMES = frozenset(["reset", "step", "state", "close"])
src/core/env_server/route_config.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Route configuration utilities for declarative FastAPI route registration.
9
+
10
+ This module provides utilities to reduce boilerplate in route registration
11
+ by using configuration objects instead of repeated function calls.
12
+ """
13
+
14
+ from dataclasses import dataclass
15
+ from typing import Callable, List, Type
16
+
17
+ from fastapi import FastAPI
18
+ from pydantic import BaseModel
19
+
20
+
21
@dataclass
class GetEndpointConfig:
    """Configuration for a simple GET endpoint.

    Consumed by register_get_endpoints() to declare routes instead of
    repeating app.get(...) boilerplate.
    """

    # URL path for the route, e.g. "/health".
    path: str
    # Zero-argument callable producing the response body.
    handler: Callable[[], BaseModel | dict]
    # Model (or dict) passed as FastAPI's response_model.
    response_model: Type[BaseModel] | type[dict]
    # OpenAPI tag the route is grouped under in the docs UI.
    tag: str
    summary: str
    description: str
31
+
32
+
33
def register_get_endpoints(app: FastAPI, configs: List[GetEndpointConfig]) -> None:
    """
    Register multiple GET endpoints on *app* from declarative configuration.

    Args:
        app: FastAPI application instance
        configs: List of GET endpoint configurations
    """

    def _bind(
        handler: Callable[[], BaseModel | dict],
    ) -> Callable[[], BaseModel | dict]:
        # The closure keeps the handler out of the endpoint's signature, so
        # FastAPI never sees it as a (non-serializable) default parameter.
        async def _endpoint() -> BaseModel | dict:
            return handler()

        return _endpoint

    for cfg in configs:
        decorator = app.get(
            cfg.path,
            response_model=cfg.response_model,
            tags=[cfg.tag],
            summary=cfg.summary,
            description=cfg.description,
        )
        decorator(_bind(cfg.handler))
src/core/env_server/serialization.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Shared serialization and deserialization utilities for OpenEnv HTTP servers.
9
+
10
+ This module provides common utilities for converting between JSON dictionaries
11
+ and Pydantic models (Action/Observation) to eliminate code duplication across
12
+ HTTP server and web interface implementations.
13
+ """
14
+
15
+ from typing import Any, Dict, Type
16
+
17
+ from .types import Action, Observation
18
+
19
+
20
def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action:
    """
    Convert JSON dict to Action instance using Pydantic validation.

    This is a basic deserialization that works for most environments.
    No field preprocessing is applied; the dict is handed to Pydantic as-is.
    For special cases (e.g. tensor fields, custom type conversions),
    use deserialize_action_with_preprocessing().

    Args:
        action_data: Dictionary containing action data
        action_cls: The Action subclass to instantiate

    Returns:
        Action instance

    Raises:
        ValidationError: If action_data is invalid for the action class

    Note:
        This uses Pydantic's model_validate() for automatic validation.
    """
    return action_cls.model_validate(action_data)
42
+
43
+
44
def deserialize_action_with_preprocessing(
    action_data: Dict[str, Any], action_cls: Type[Action]
) -> Action:
    """
    Convert JSON dict to Action instance with preprocessing for special types.

    This version handles common type conversions needed for web interfaces:
    - Converting lists/strings to tensors for the 'tokens' field
    - Converting string action_id to int
    - All other fields pass through unchanged

    Args:
        action_data: Dictionary containing action data
        action_cls: The Action subclass to instantiate

    Returns:
        Action instance

    Raises:
        ValidationError: If action_data is invalid for the action class
    """
    processed_data = {
        key: _preprocess_action_value(key, value)
        for key, value in action_data.items()
    }
    return action_cls.model_validate(processed_data)


def _preprocess_action_value(key: str, value: Any) -> Any:
    """Coerce a single action field into the type the environment expects.

    Best-effort: every failed conversion falls back to the original value
    (or an empty list for unparseable 'tokens' strings) rather than raising.
    """
    if key == "tokens" and isinstance(value, (list, str)):
        if isinstance(value, str):
            # A string is expected to be a JSON-encoded list of token ids.
            try:
                import json

                value = json.loads(value)
            except Exception:
                # Unparseable payload: treat as empty token list.
                value = []
        if isinstance(value, list):
            try:
                import torch  # type: ignore

                return torch.tensor(value, dtype=torch.long)
            except ImportError:
                # Torch not installed: keep the plain list.
                return value
        # JSON parsed to something other than a list: pass it through.
        return value
    if key == "action_id" and isinstance(value, str):
        # Web forms submit numbers as strings; convert when possible.
        try:
            return int(value)
        except ValueError:
            return value
    return value
100
+
101
+
102
def serialize_observation(observation: Observation) -> Dict[str, Any]:
    """
    Convert an Observation into the JSON-compatible dict shape EnvClient expects.

    Args:
        observation: Observation instance

    Returns:
        Dictionary compatible with EnvClient._parse_result():

        {
            "observation": {...},  # Observation fields
            "reward": float | None,
            "done": bool,
        }
    """
    # reward/done/metadata are promoted to the top level of the envelope,
    # so keep them out of the nested observation payload.
    excluded_fields = {"reward", "done", "metadata"}
    payload = observation.model_dump(exclude=excluded_fields)
    return {
        "observation": payload,
        "reward": observation.reward,
        "done": observation.done,
    }
src/core/env_server/types.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from enum import Enum
8
+ from typing import Annotated, Any, Dict, Literal, Optional, Union
9
+
10
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
11
+
12
+
13
+ # Type aliases
14
+ Scalar = Union[int, float, bool]
15
+
16
+
17
+ # =============================================================================
18
+ # Enums for Type Safety
19
+ # =============================================================================
20
+
21
+
22
class ServerMode(str, Enum):
    """Server operation mode.

    NOTE(review): the distinction between the two modes is not defined in
    this module — confirm semantics against the server that consumes it.
    """

    SIMULATION = "simulation"
    PRODUCTION = "production"
27
+
28
+
29
class HealthStatus(str, Enum):
    """Server health status values.

    Used as the payload of HealthResponse (defaults to HEALTHY).
    """

    HEALTHY = "healthy"
    UNHEALTHY = "unhealthy"
    DEGRADED = "degraded"
35
+
36
+
37
class WSErrorCode(str, Enum):
    """WebSocket error codes for structured error handling."""

    INVALID_JSON = "INVALID_JSON"  # message body was not parseable JSON
    UNKNOWN_TYPE = "UNKNOWN_TYPE"  # message `type` not recognized
    VALIDATION_ERROR = "VALIDATION_ERROR"  # payload failed schema validation
    EXECUTION_ERROR = "EXECUTION_ERROR"  # environment raised during handling
    CAPACITY_REACHED = "CAPACITY_REACHED"  # no free session slots
    FACTORY_ERROR = "FACTORY_ERROR"  # environment construction failed
    SESSION_ERROR = "SESSION_ERROR"  # session-level failure
47
+
48
+
49
+ # =============================================================================
50
+ # Core Types
51
+ # =============================================================================
52
+
53
+
54
class Action(BaseModel):
    """Base class for all environment actions.

    All action subclasses should inherit from this base class.
    Uses Pydantic for automatic validation and serialization.
    Subclasses add their own typed fields; unknown fields are rejected
    because of extra="forbid".
    """

    model_config = ConfigDict(
        extra="forbid",  # Reject unknown fields
        validate_assignment=True,  # Validate on field assignment
        arbitrary_types_allowed=True,  # Allow numpy arrays, torch tensors, etc.
    )

    metadata: Dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata for the action"
    )
70
+
71
+
72
class Observation(BaseModel):
    """Base class for all environment observations.

    All observation subclasses should inherit from this base class.
    Uses Pydantic for automatic validation and serialization.
    serialize_observation() promotes reward/done/metadata to the top level
    of the wire envelope.
    """

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )

    done: bool = Field(default=False, description="Whether the episode has terminated")
    # NOTE(review): equivalent to Optional[Scalar] via the module alias — the
    # inline spelling is kept; confirm before unifying.
    reward: bool | int | float | None = Field(
        default=None, description="Reward signal from the last action"
    )
    metadata: Dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata for the observation"
    )
92
+
93
+
94
class ResetRequest(BaseModel):
    """Request model for environment reset.

    extra="allow" lets environments accept custom reset parameters beyond
    seed/episode_id without schema changes.
    """

    model_config = ConfigDict(
        extra="allow",  # Allow extra fields for custom reset parameters
        json_schema_extra={"examples": [{"seed": 42, "episode_id": "episode-001"}, {}]},
    )

    seed: Optional[int] = Field(
        default=None, ge=0, description="Random seed for reproducible episodes"
    )
    episode_id: Optional[str] = Field(
        default=None, max_length=255, description="Custom episode identifier"
    )
108
+
109
+
110
class ResetResponse(BaseModel):
    """Response model for environment reset.

    Mirrors StepResponse's envelope (observation/reward/done).
    """

    model_config = ConfigDict(extra="forbid")

    observation: Dict[str, Any] = Field(
        ..., description="Initial observation from the environment"
    )
    reward: Optional[float] = Field(
        default=None, description="Initial reward (typically None at reset)"
    )
    done: bool = Field(
        default=False, description="Whether episode is already done (typically False)"
    )
124
+
125
+
126
class StepRequest(BaseModel):
    """Request model for environment step.

    extra="allow" lets environments accept custom step parameters (e.g.
    render, verbose) alongside the required action payload.
    """

    model_config = ConfigDict(
        extra="allow",  # Allow extra fields for custom step parameters
        json_schema_extra={
            "examples": [
                {"action": {"value": 1}, "timeout_s": 30.0},
                {"action": {"value": 1}, "render": True, "verbose": False},
            ]
        },
    )

    action: Dict[str, Any] = Field(
        ...,
        description="Action to execute, must conform to environment's action schema",
    )
    timeout_s: Optional[float] = Field(
        default=None,
        gt=0,
        description="Optional timeout in seconds for action execution",
    )
    request_id: Optional[str] = Field(
        default=None,
        max_length=255,
        description="Optional request identifier for tracking",
    )
153
+
154
+
155
class StepResponse(BaseModel):
    """Response model for environment step.

    Matches the envelope produced by serialize_observation().
    """

    model_config = ConfigDict(extra="forbid")

    observation: Dict[str, Any] = Field(
        ..., description="Observation resulting from the action"
    )
    reward: Optional[float] = Field(
        default=None, description="Reward signal from the action"
    )
    done: bool = Field(default=False, description="Whether the episode has terminated")
167
+
168
+
169
class BaseMessage(BaseModel):
    """Base class for WebSocket messages with shared configuration.

    Subclasses inherit strict validation: unknown fields are rejected and
    assignments are re-validated.
    """

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
    )
176
+
177
+
178
class State(BaseModel):
    """Base class for environment state.

    Represents internal environment state, separate from observations.
    Unlike Action/Observation, extra fields are allowed so environments can
    attach arbitrary state without subclassing.
    """

    model_config = ConfigDict(
        extra="allow",  # Allow extra fields for flexibility
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )

    episode_id: Optional[str] = Field(
        default=None, description="Unique identifier for the current episode"
    )
    step_count: int = Field(
        default=0,
        ge=0,  # Greater than or equal to 0
        description="Number of steps taken in the current episode",
    )
198
+
199
+
200
class CodeExecResult(BaseMessage):
    """Result of code execution containing stdout, stderr, and exit code."""

    stdout: str = Field(description="Standard output from code execution")
    stderr: str = Field(description="Standard error from code execution")
    # Conventionally 0 on success; non-zero indicates failure.
    exit_code: int = Field(description="Exit code from code execution")
206
+
207
+
208
class EnvironmentMetadata(BaseMessage):
    """Metadata about an environment for documentation and UI purposes.

    All fields except name/description are optional.
    """

    name: str = Field(description="Name of the environment")
    description: str = Field(description="Description of what the environment does")
    readme_content: Optional[str] = Field(
        default=None, description="Content of the README file for the environment"
    )
    version: Optional[str] = Field(
        default=None, description="Version of the environment"
    )
    author: Optional[str] = Field(default=None, description="Author of the environment")
    documentation_url: Optional[str] = Field(
        default=None, description="URL to the environment's documentation"
    )
223
+
224
+
225
class SchemaResponse(BaseMessage):
    """Response model for the combined schema endpoint.

    Each field carries a JSON Schema document (as a plain dict).
    """

    action: Dict[str, Any] = Field(
        description="JSON schema for actions accepted by this environment"
    )
    observation: Dict[str, Any] = Field(
        description="JSON schema for observations returned by this environment"
    )
    state: Dict[str, Any] = Field(
        description="JSON schema for environment state objects"
    )
237
+
238
+
239
class HealthResponse(BaseMessage):
    """Response model for health check endpoint.

    Defaults to HEALTHY so a bare HealthResponse() is a positive answer.
    """

    status: HealthStatus = Field(
        default=HealthStatus.HEALTHY,
        description="Health status of the environment server",
    )
246
+
247
+
248
class WSResetMessage(BaseMessage):
    """WebSocket message to reset the environment.

    `type` acts as the discriminator within WSIncomingMessage.
    """

    type: Literal["reset"] = Field(default="reset", description="Message type")
    data: Dict[str, Any] = Field(
        default_factory=dict,
        description="Optional reset parameters (seed, episode_id, etc.)",
    )
256
+
257
+
258
class WSStepMessage(BaseMessage):
    """WebSocket message to execute a step.

    Unlike reset, the action payload is required (no default).
    """

    type: Literal["step"] = Field(default="step", description="Message type")
    data: Dict[str, Any] = Field(
        ..., description="Action data conforming to environment's action schema"
    )
265
+
266
+
267
class WSStateMessage(BaseMessage):
    """WebSocket message to request current state. Carries no payload."""

    type: Literal["state"] = Field(default="state", description="Message type")
271
+
272
+
273
class WSCloseMessage(BaseMessage):
    """WebSocket message to close the session. Carries no payload."""

    type: Literal["close"] = Field(default="close", description="Message type")
277
+
278
+
279
+ # Discriminated union for incoming WebSocket messages
280
+ # Note: WSMCPMessage is defined in mcp_types.py to avoid circular imports
281
+ # The union here covers the core message types; MCP messages are handled separately
282
+ WSIncomingMessage = Annotated[
283
+ WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage,
284
+ Field(discriminator="type"),
285
+ ]
286
+
287
+
288
class WSObservationResponse(BaseModel):
    """WebSocket response containing an observation.

    Sent in reply to reset/step messages.
    """

    model_config = ConfigDict(extra="forbid")

    type: Literal["observation"] = Field(
        default="observation", description="Response type"
    )
    data: Dict[str, Any] = Field(description="Observation data")
297
+
298
+
299
class WSStateResponse(BaseModel):
    """WebSocket response containing environment state.

    Sent in reply to a WSStateMessage.
    """

    model_config = ConfigDict(extra="forbid")

    type: Literal["state"] = Field(default="state", description="Response type")
    data: Dict[str, Any] = Field(description="State data")
306
+
307
+
308
class WSErrorResponse(BaseModel):
    """WebSocket response for errors.

    The data payload is expected to include a message and a WSErrorCode.
    """

    model_config = ConfigDict(extra="forbid")

    type: Literal["error"] = Field(default="error", description="Response type")
    data: Dict[str, Any] = Field(description="Error details including message and code")
315
+
316
+
317
class ConcurrencyConfig(BaseMessage):
    """Configuration for concurrent environment sessions.

    Defaults to a single concurrent session and no idle timeout.
    """

    max_concurrent_envs: int = Field(
        default=1,
        ge=1,
        description="Maximum number of concurrent WebSocket sessions allowed",
    )
    session_timeout: Optional[float] = Field(
        default=None,
        gt=0,
        description="Timeout in seconds for inactive sessions. None means no timeout.",
    )
330
+
331
+
332
class ServerCapacityStatus(BaseMessage):
    """Status of server capacity for concurrent sessions.

    Validation guarantees active_sessions <= max_sessions, so
    available_slots is always non-negative.
    """

    active_sessions: int = Field(
        ge=0,
        description="Number of currently active sessions",
    )
    max_sessions: int = Field(
        ge=1,
        description="Maximum number of allowed sessions",
    )

    @model_validator(mode="after")
    def check_capacity_bounds(self) -> "ServerCapacityStatus":
        # Reject inconsistent counts up front rather than letting
        # available_slots go negative downstream.
        if self.active_sessions > self.max_sessions:
            raise ValueError(
                f"active_sessions ({self.active_sessions}) cannot exceed "
                f"max_sessions ({self.max_sessions})"
            )
        return self

    @property
    def available_slots(self) -> int:
        """Number of available session slots."""
        return self.max_sessions - self.active_sessions

    @property
    def is_at_capacity(self) -> bool:
        """Whether the server has reached maximum capacity."""
        return self.available_slots == 0

    @classmethod
    def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus":
        """Create status from active and max session counts."""
        return cls(
            active_sessions=active,
            max_sessions=max_sessions,
        )
370
+
371
+
372
class SessionInfo(BaseMessage):
    """Information about an active session.

    Timestamps are Unix epoch seconds (floats).
    """

    session_id: str = Field(description="Unique identifier for the session")
    created_at: float = Field(description="Unix timestamp when the session was created")
    last_activity_at: float = Field(
        description="Unix timestamp of the last activity in the session"
    )
    step_count: int = Field(
        default=0,
        ge=0,
        description="Number of steps executed in this session",
    )
    environment_type: str = Field(
        description="Environment type for this session (e.g. `CodingEnv`)"
    )
src/core/env_server/web_interface.py ADDED
@@ -0,0 +1,644 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Web interface for OpenEnv environments.
9
+
10
+ When ENABLE_WEB_INTERFACE is set, the server exposes a Gradio UI at /web for
11
+ reset, step, and state observation. Controlled by the CLI enable_interface
12
+ option (e.g. openenv push --enable-interface) or ENABLE_WEB_INTERFACE env var.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import asyncio
18
+ import json
19
+ from concurrent.futures import ThreadPoolExecutor
20
+ from datetime import datetime
21
+ from typing import Any, Callable, Dict, List, Optional, Type
22
+
23
+ import gradio as gr
24
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
25
+ from pydantic import BaseModel, ConfigDict, Field
26
+
27
+ from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME
28
+ from .gradio_ui import build_gradio_app, get_gradio_display_title
29
+ from .interfaces import Environment
30
+ from .serialization import deserialize_action_with_preprocessing, serialize_observation
31
+ from .types import Action, EnvironmentMetadata, Observation, State
32
+
33
# Quick Start markdown template; placeholders match init suffixes (__ENV_NAME__, __ENV_CLASS_NAME__*).
DEFAULT_QUICK_START_MARKDOWN = """
### Connect to this environment

Connect from Python using `__ENV_CLASS_NAME__Env`:

```python
from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env

with __ENV_CLASS_NAME__Env.from_env("<SPACE_ID>") as env:
    result = await env.step(__ENV_CLASS_NAME__Action(message="..."))
```

Or connect directly to a running server:

```python
env = __ENV_CLASS_NAME__Env(base_url="http://localhost:8000")
```

### Contribute to this environment

Submit improvements via pull request on the Hugging Face Hub.

```bash
openenv fork <SPACE_ID> --repo-id <your-username>/<your-repo-name>
```

Then make your changes and submit a pull request:

```bash
cd <forked-repo>
openenv push <SPACE_ID> --create-pr
```

For more information, see the [OpenEnv documentation](https://meta-pytorch.org/OpenEnv/).
"""


def get_quick_start_markdown(
    metadata: Optional[EnvironmentMetadata],
    action_cls: Type[Action],
    observation_cls: Type[Observation],
) -> str:
    """
    Build Quick Start markdown with class names replaced from current env (init-style suffixes).

    Uses the same placeholder names as the init template so that __ENV_CLASS_NAME__Env,
    __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation and __ENV_NAME__ are
    replaced with the actual class/package names.
    """
    import os

    # Derive the class prefix from the action class (e.g. EchoAction -> Echo).
    raw_action_name = getattr(action_cls, "__name__", "Action")
    if raw_action_name.endswith("Action"):
        class_prefix = raw_action_name[: -len("Action")]
    else:
        class_prefix = raw_action_name.replace("Action", "").strip() or "Env"

    pkg_name = (metadata.name if metadata else "env").replace(" ", "_").lower()
    space_id = os.environ.get("SPACE_ID", "<hf-username>/<hf-repo-name>")

    # Apply substitutions in order: the suffixed placeholders must be handled
    # before the bare __ENV_CLASS_NAME__ so they are not partially consumed.
    replacements = (
        ("__ENV_CLASS_NAME__Env", f"{class_prefix}Env"),
        ("__ENV_CLASS_NAME__Action", raw_action_name),
        ("__ENV_CLASS_NAME__Observation", getattr(observation_cls, "__name__", "Observation")),
        ("__ENV_CLASS_NAME__", class_prefix),
        ("__ENV_NAME__", pkg_name),
        ("<SPACE_ID>", space_id),
    )

    rendered = DEFAULT_QUICK_START_MARKDOWN
    for placeholder, value in replacements:
        rendered = rendered.replace(placeholder, value)
    return rendered.strip()
106
+
107
+
108
def load_environment_metadata(
    env: Environment, env_name: Optional[str] = None
) -> EnvironmentMetadata:
    """
    Load environment metadata including README content.

    Args:
        env: The environment instance, class, or factory function.
            - If a class: used as a factory, won't call instance methods
            - If a function: used as a factory, won't call instance methods
            - If an instance: may call get_metadata() if available
        env_name: Optional environment name for README file lookup

    Returns:
        EnvironmentMetadata with loaded information
    """
    import inspect

    # Classes and plain functions/methods are treated as factories; we must
    # not invoke instance-level APIs such as get_metadata() on them.
    env_is_class = inspect.isclass(env)
    env_is_callable_factory = inspect.isfunction(env) or inspect.ismethod(env)

    if not (env_is_class or env_is_callable_factory) and hasattr(env, "get_metadata"):
        # Real instance that publishes its own metadata.
        return env.get_metadata()

    # Pick a display name appropriate to what we were handed.
    if env_is_class:
        display_name = env.__name__
    elif env_is_callable_factory:
        # Prefer the configured env name; fall back to the factory's name.
        display_name = env_name or env.__name__
    else:
        display_name = env.__class__.__name__

    metadata = EnvironmentMetadata(
        name=env_name or display_name,
        description=f"{display_name} environment",
        version="1.0.0",
    )

    # Attach README text when one can be found on disk.
    readme_content = _load_readme_from_filesystem(env_name)
    if readme_content:
        metadata.readme_content = readme_content

    return metadata
162
+
163
+
164
def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]:
    """
    Load README content from the filesystem.

    Tries multiple locations, in the order the code actually checks them
    (the previous docstring listed the last two in the wrong order):
    1. Container filesystem: /app/README.md
    2. Path named by the ENV_README_PATH environment variable
    3. Local development: src/envs/{env_name}/README.md

    Args:
        env_name: Environment package name for the local-development lookup;
            when None, that location is skipped.

    Returns:
        README text from the first readable candidate, or None if none exists.
    """
    import os
    from pathlib import Path

    def _read_if_exists(path: Path) -> Optional[str]:
        # Best-effort read: unreadable or undecodable files count as absent.
        if path.exists():
            try:
                return path.read_text(encoding="utf-8")
            except Exception:
                pass
        return None

    candidates = [Path("/app/README.md")]

    custom_path = os.environ.get("ENV_README_PATH")
    if custom_path:
        candidates.append(Path(custom_path))

    if env_name:
        candidates.append(Path(f"src/envs/{env_name}/README.md"))

    for candidate in candidates:
        content = _read_if_exists(candidate)
        if content is not None:
            return content

    return None
202
+
203
+
204
class ActionLog(BaseModel):
    """Log entry for an action taken.

    One entry is appended per environment step; the web UI renders these as
    the action history. Action/observation payloads are stored as
    already-serialized dicts, not model objects.
    """

    # Reject unknown keys and re-validate on assignment so UI state stays consistent.
    model_config = ConfigDict(extra="forbid", validate_assignment=True)

    timestamp: str = Field(description="Timestamp when action was taken")
    action: Dict[str, Any] = Field(description="Action that was taken")
    observation: Dict[str, Any] = Field(description="Observation returned from action")
    reward: Optional[float] = Field(
        default=None, description="Reward received from action"
    )
    done: bool = Field(description="Whether the episode is done after this action")
    step_count: int = Field(description="Step count when this action was taken")
217
+
218
+
219
class EpisodeState(BaseModel):
    """Current episode state for the web interface.

    Mirrors the environment's episode progress so it can be pushed to
    connected WebSocket clients as a single JSON-serializable snapshot.
    """

    # Reject unknown keys and re-validate on assignment so UI state stays consistent.
    model_config = ConfigDict(extra="forbid", validate_assignment=True)

    episode_id: Optional[str] = Field(default=None, description="Current episode ID")
    step_count: int = Field(description="Current step count in episode")
    current_observation: Optional[Dict[str, Any]] = Field(
        default=None, description="Current observation"
    )
    action_logs: List[ActionLog] = Field(
        default_factory=list, description="List of action logs"
    )
    is_reset: bool = Field(
        default=True, description="Whether the episode has been reset"
    )
235
+
236
+
237
class WebInterfaceManager:
    """Manages the web interface for an environment.

    Owns a single environment instance (instantiating it when a class or
    factory is supplied), mirrors the current episode state for the UI, and
    broadcasts state updates to all connected WebSocket clients.
    """

    # Cap on retained action logs so long-lived sessions stay bounded.
    MAX_ACTION_LOGS = 1000

    def __init__(
        self,
        env: Environment,
        action_cls: Type[Action],
        observation_cls: Type[Observation],
        metadata: Optional[EnvironmentMetadata] = None,
    ):
        import inspect

        # If env is a class or factory function/method, instantiate it.
        # ismethod is included to match the factory detection used by
        # load_environment_metadata.
        if inspect.isclass(env) or inspect.isfunction(env) or inspect.ismethod(env):
            self.env = env()
        else:
            self.env = env
        self.action_cls = action_cls
        self.observation_cls = observation_cls
        # Default metadata must come from the *resolved* instance: using the
        # raw `env` argument would report "type"/"function" as the name
        # whenever a class or factory was passed in.
        self.metadata = metadata or EnvironmentMetadata(
            name=self.env.__class__.__name__,
            description=f"{self.env.__class__.__name__} environment",
        )
        self.episode_state = EpisodeState(
            episode_id=None,
            step_count=0,
            current_observation=None,
            action_logs=[],
        )
        self.connected_clients: List[WebSocket] = []
        # Thread pool for running sync code (e.g., Playwright sync API) in async context
        self._executor = ThreadPoolExecutor(max_workers=1)

    async def _run_sync_in_thread_pool(self, func, *args, **kwargs):
        """Run a synchronous function in the thread pool executor.

        This is needed for environments using sync libraries (e.g., Playwright sync API)
        that cannot be called directly from an async context.
        """
        # get_running_loop() is the correct API inside a coroutine;
        # get_event_loop() is deprecated in this context since Python 3.10.
        loop = asyncio.get_running_loop()
        # Use default arguments to capture values at lambda definition time
        # to avoid closure issues with late binding
        return await loop.run_in_executor(
            self._executor, lambda f=func, a=args, kw=kwargs: f(*a, **kw)
        )

    async def connect_websocket(self, websocket: WebSocket):
        """Connect a new WebSocket client."""
        await websocket.accept()
        self.connected_clients.append(websocket)

        # Send current state to the new client
        await self._send_state_update()

    async def disconnect_websocket(self, websocket: WebSocket):
        """Disconnect a WebSocket client."""
        if websocket in self.connected_clients:
            self.connected_clients.remove(websocket)

    async def _send_state_update(self):
        """Send current state to all connected clients."""
        if not self.connected_clients:
            return

        state_data = {
            "type": "state_update",
            "episode_state": self.episode_state.model_dump(),
        }

        # Send to all connected clients; remember failures for cleanup below.
        disconnected_clients = []
        for client in self.connected_clients:
            try:
                await client.send_text(json.dumps(state_data))
            except Exception:
                disconnected_clients.append(client)

        # Remove disconnected clients (after iteration, so the list is not
        # mutated while being traversed).
        for client in disconnected_clients:
            self.connected_clients.remove(client)

    async def reset_environment(self) -> Dict[str, Any]:
        """Reset the environment and update state.

        Returns:
            The serialized observation payload from the reset.
        """
        # Run sync reset in thread pool to avoid blocking event loop
        # and to support environments using sync libraries (e.g., Playwright)
        observation: Observation = await self._run_sync_in_thread_pool(self.env.reset)
        state: State = self.env.state

        # Serialize observation once using shared utility
        serialized = serialize_observation(observation)

        # Update episode state
        self.episode_state.episode_id = state.episode_id
        self.episode_state.step_count = 0
        self.episode_state.current_observation = serialized["observation"]
        self.episode_state.action_logs = []
        self.episode_state.is_reset = True

        # Send state update
        await self._send_state_update()

        return serialized

    async def step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a step in the environment and update state.

        Args:
            action_data: Raw action payload from the web UI.

        Returns:
            The serialized observation payload from the step.
        """
        # Deserialize action with preprocessing for web interface special cases
        action: Action = deserialize_action_with_preprocessing(
            action_data, self.action_cls
        )

        # Run sync step in thread pool to avoid blocking event loop
        # and to support environments using sync libraries (e.g., Playwright)
        observation: Observation = await self._run_sync_in_thread_pool(
            self.env.step, action
        )
        state: State = self.env.state

        # Serialize observation once using shared utility
        serialized = serialize_observation(observation)

        # Create action log
        action_log = ActionLog(
            timestamp=datetime.now().isoformat(),
            action=action.model_dump(exclude={"metadata"}),
            observation=serialized["observation"],
            reward=observation.reward,
            done=observation.done,
            step_count=state.step_count,
        )

        # Update episode state; trim the log so memory use stays bounded.
        self.episode_state.episode_id = state.episode_id
        self.episode_state.step_count = state.step_count
        self.episode_state.current_observation = serialized["observation"]
        self.episode_state.action_logs.append(action_log)
        if len(self.episode_state.action_logs) > self.MAX_ACTION_LOGS:
            self.episode_state.action_logs = self.episode_state.action_logs[
                -self.MAX_ACTION_LOGS :
            ]
        self.episode_state.is_reset = False

        # Send state update
        await self._send_state_update()

        return serialized

    def get_state(self) -> Dict[str, Any]:
        """Get current environment state."""
        state: State = self.env.state
        return state.model_dump()
389
+
390
+
391
def create_web_interface_app(
    env: Environment,
    action_cls: Type[Action],
    observation_cls: Type[Observation],
    env_name: Optional[str] = None,
    max_concurrent_envs: Optional[int] = None,
    concurrency_config: Optional[Any] = None,
    gradio_builder: Optional[Callable[..., Any]] = None,
) -> FastAPI:
    """
    Create a FastAPI application with web interface for the given environment.

    Args:
        env: The Environment instance to serve
        action_cls: The Action subclass this environment expects
        observation_cls: The Observation subclass this environment returns
        env_name: Optional environment name for README loading
        max_concurrent_envs: Maximum concurrent WebSocket sessions
        concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings
        gradio_builder: Optional callable (web_manager, action_fields, metadata,
            is_chat_env, title, quick_start_md) -> gr.Blocks to use instead of the
            default Gradio UI. Lets envs replace or customize the /web interface.

    Returns:
        FastAPI application instance with web interface
    """
    # Imported lazily to avoid a circular import with http_server.
    from .http_server import create_fastapi_app

    # Create the base environment app
    app = create_fastapi_app(
        env, action_cls, observation_cls, max_concurrent_envs, concurrency_config
    )

    # Load environment metadata
    metadata = load_environment_metadata(env, env_name)

    # Create web interface manager (instantiates env if a class/factory was given)
    web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata)

    # Web API routes first (so they take precedence over Gradio mount at /web)
    @app.get("/web/metadata")
    async def web_metadata():
        """Get environment metadata."""
        return web_manager.metadata.model_dump()

    @app.websocket("/ws/ui")
    async def websocket_ui_endpoint(websocket: WebSocket):
        """WebSocket endpoint for web UI real-time updates.

        Note: Uses /ws/ui to avoid conflict with /ws in http_server.py
        which is used for concurrent environment sessions.
        """
        await web_manager.connect_websocket(websocket)
        try:
            while True:
                # Keep connection alive; inbound messages are ignored.
                await websocket.receive_text()
        except WebSocketDisconnect:
            await web_manager.disconnect_websocket(websocket)

    @app.post("/web/reset")
    async def web_reset():
        """Reset endpoint for web interface."""
        return await web_manager.reset_environment()

    @app.post("/web/step")
    async def web_step(request: Dict[str, Any]):
        """Step endpoint for web interface.

        Accepts either {"message": ...} for chat-style environments (routed
        through the env's message_to_action when available) or
        {"action": {...}} with an explicit action payload.
        """
        # Check if this is a message-based request (chat environment)
        if "message" in request:
            message = request["message"]
            if hasattr(web_manager.env, "message_to_action"):
                action = web_manager.env.message_to_action(message)
                if hasattr(action, "tokens"):
                    # Token actions are sent as a plain list for JSON transport.
                    action_data = {"tokens": action.tokens.tolist()}
                else:
                    action_data = action.model_dump(exclude={"metadata"})
            else:
                action_data = {"message": message}
        else:
            action_data = request.get("action", {})

        return await web_manager.step_environment(action_data)

    @app.get("/web/state")
    async def web_state():
        """State endpoint for web interface."""
        return web_manager.get_state()

    # Build the Gradio UI: form metadata from the action schema, plus the
    # Quick Start panel rendered from the current env's names.
    action_fields = _extract_action_fields(action_cls)
    is_chat_env = _is_chat_env(action_cls)
    quick_start_md = get_quick_start_markdown(metadata, action_cls, observation_cls)

    default_blocks = build_gradio_app(
        web_manager,
        action_fields,
        metadata,
        is_chat_env,
        title=metadata.name,
        quick_start_md=quick_start_md,
    )
    if gradio_builder is not None:
        # Custom UI is shown alongside the default playground in tabs.
        custom_blocks = gradio_builder(
            web_manager,
            action_fields,
            metadata,
            is_chat_env,
            metadata.name,
            quick_start_md,
        )
        if not isinstance(custom_blocks, gr.Blocks):
            raise TypeError(
                f"gradio_builder must return a gr.Blocks instance, "
                f"got {type(custom_blocks).__name__}"
            )
        gradio_blocks = gr.TabbedInterface(
            [default_blocks, custom_blocks],
            tab_names=["Playground", "Visualization"],
            title=get_gradio_display_title(metadata),
        )
    else:
        gradio_blocks = default_blocks
    # NOTE(review): the theme/css kwargs require a Gradio version whose
    # mount_gradio_app accepts them — confirm against the pinned version.
    app = gr.mount_gradio_app(
        app,
        gradio_blocks,
        path="/web",
        theme=OPENENV_GRADIO_THEME,
        css=OPENENV_GRADIO_CSS,
    )

    return app
522
+
523
+
524
def _is_chat_env(action_cls: Type[Action]) -> bool:
    """Return True if the action class is a chat-style env (tokens field).

    Detection is string-based on the field annotation so it also matches
    parameterized or forward-reference annotations (e.g. ``Optional[Tensor]``
    or ``"torch.Tensor"``). The previous ``hasattr(annotation, "__name__")``
    gate never used the ``__name__`` and wrongly excluded such annotations.
    """
    if hasattr(action_cls, "model_fields"):
        for field_name, field_info in action_cls.model_fields.items():
            if field_name == "tokens" and "Tensor" in str(field_info.annotation):
                return True
    return False
535
+
536
+
537
def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]:
    """Extract enhanced field metadata from Action class for form generation."""
    # Lean on Pydantic v2's JSON schema generation for robust metadata.
    try:
        json_schema = action_cls.model_json_schema()
    except AttributeError:
        # Not a Pydantic v2 model (or schema generation unavailable).
        return []

    props = json_schema.get("properties", {})
    required = set(json_schema.get("required", []))

    fields: List[Dict[str, Any]] = []
    for name, spec in props.items():
        # The internal bookkeeping field is never rendered as a form input.
        if name == "metadata":
            continue

        # JSON schema "type" may be a string, a list, or absent; the helper
        # maps it onto our internal widget type.
        fields.append(
            {
                "name": name,
                "type": _determine_input_type_from_schema(spec, name),
                "required": name in required,
                "description": spec.get("description", ""),
                "default_value": spec.get("default"),
                "choices": spec.get("enum"),
                "min_value": spec.get("minimum"),
                "max_value": spec.get("maximum"),
                "min_length": spec.get("minLength"),
                "max_length": spec.get("maxLength"),
                "pattern": spec.get("pattern"),
                "placeholder": _generate_placeholder(name, spec),
                "help_text": _generate_help_text(name, spec),
            }
        )

    return fields
580
+
581
+
582
def _determine_input_type_from_schema(
    field_info: Dict[str, Any], field_name: str
) -> str:
    """Determine input type from JSON schema for form generation (Gradio UI)."""
    lowered = field_name.lower()

    # Convention: any field whose name mentions "tokens" is a tensor input,
    # regardless of its declared schema type.
    if "tokens" in lowered:
        return "tensor"

    # Enumerated values render as a dropdown.
    if "enum" in field_info:
        return "select"

    schema_type = field_info.get("type")
    if schema_type == "boolean":
        return "checkbox"
    if schema_type in ("integer", "number"):
        return "number"
    if schema_type == "string":
        # Long or free-form text (messages, code) gets a multi-line widget.
        wants_textarea = (
            field_info.get("maxLength", 0) > 100
            or "message" in lowered
            or "code" in lowered
        )
        return "textarea" if wants_textarea else "text"

    # Unknown or absent type: fall back to a plain text input.
    return "text"
613
+
614
+
615
def _generate_placeholder(field_name: str, field_info: Dict[str, Any]) -> str:
    """Generate placeholder text for a form input.

    Args:
        field_name: Name of the action field.
        field_info: JSON-schema fragment for the field (currently unused;
            kept for signature parity with the other generators).

    Returns:
        Placeholder string shown in the empty input widget.
    """
    lowered = field_name.lower()
    # Message-like names deliberately get the generic placeholder (the
    # original code had an explicit "message" branch that duplicated the
    # fallback); only non-message names are matched against the
    # specialized patterns.
    if "message" not in lowered:
        if "code" in lowered:
            return "Enter Python code here..."
        if "tokens" in lowered:
            return "Enter comma-separated token IDs (e.g., 1,2,3,4,5)"
    return f"Enter {field_name.replace('_', ' ')}..."
625
+
626
+
627
def _generate_help_text(field_name: str, field_info: Dict[str, Any]) -> str:
    """Generate help text."""
    # An explicit (non-empty) schema description always wins.
    explicit = field_info.get("description", "")
    if explicit:
        return explicit

    # Otherwise fall back to canned help for well-known name patterns.
    # Order matters: the first matching substring is used.
    known_patterns = (
        ("action_id", "The action ID to execute in environment"),
        ("game_name", "Name of game or environment"),
        ("tokens", "Token IDs as a comma-separated list of integers"),
        ("code", "Python code to execute in environment"),
        ("message", "Text message to send"),
    )
    lowered = field_name.lower()
    for pattern, help_text in known_patterns:
        if pattern in lowered:
            return help_text

    return ""