diff --git a/Dockerfile b/Dockerfile
index 9d94785d19b64c8dfd4c7c7057d8ac2ee4f40191..2a5c2a5d0f65e78026fc35d946ea9e43e8e5b185 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,7 +11,7 @@
 # The build script (openenv build) handles context detection and sets appropriate build args.
 
 ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
 
-FROM ${BASE_IMAGE} AS builder
+FROM ghcr.io/meta-pytorch/openenv-base:latest AS builder
 
 WORKDIR /app
 
@@ -40,22 +40,26 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
 
 # Install dependencies using uv sync
 # If uv.lock exists, use it; otherwise resolve on the fly
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
+    install -m 0755 /root/.local/bin/uv /usr/local/bin/uv && \
+    install -m 0755 /root/.local/bin/uvx /usr/local/bin/uvx
+
 RUN --mount=type=cache,target=/root/.cache/uv \
     if [ -f uv.lock ]; then \
-        uv sync --frozen --no-install-project --no-editable; \
+        uv sync --no-install-project --no-editable; \
     else \
         uv sync --no-install-project --no-editable; \
    fi
 
 RUN --mount=type=cache,target=/root/.cache/uv \
     if [ -f uv.lock ]; then \
-        uv sync --frozen --no-editable; \
+        uv sync --no-editable; \
    else \
         uv sync --no-editable; \
    fi
 
 # Final runtime stage
-FROM ${BASE_IMAGE}
+FROM ghcr.io/meta-pytorch/openenv-base:latest
 
 WORKDIR /app
 
@@ -77,5 +81,6 @@ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
 
 # Run the FastAPI server
 # The module path is constructed to work with the /app/env structure
-ENV ENABLE_WEB_INTERFACE=true
 CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
+
+ENV ENABLE_WEB_INTERFACE=true
diff --git a/README.md b/README.md
index 228eebfff972a8153f08ed853727d6a0cb96c129..a315c3a99ab382e8f11fef957f6f6960cc5cd844 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,33 @@
 ---
 title: REPL Environment Server
 emoji: 🎮
-colorFrom: yellow
-colorTo: indigo
+colorFrom: blue
+colorTo: green
 sdk: docker
 pinned: false
 app_port: 8000
 base_path: /web
 tags:
+  - openenv-0.2.2
   - openenv
 ---
 
+## Hugging Face Space Deployment
+
+This Space is built from OpenEnv environment `repl_env`.
+
+- Space URL: `https://huggingface.co/spaces/openenv/repl`
+- OpenEnv pinned ref: `0.2.2`
+- Hub tag: `openenv`
+
+### Connecting from Code
+
+```python
+from envs.repl_env import REPLEnv
+
+env = REPLEnv(base_url="https://huggingface.co/spaces/openenv/repl")
+```
+
 # REPL Environment for OpenEnv
 
 A Python REPL environment for training language models on code execution tasks, based on the [Recursive Language Models (RLM)](https://arxiv.org/abs/2512.24601) paradigm.
@@ -99,7 +116,7 @@ from repl_env import REPLEnv
 env = REPLEnv.from_docker_image("repl-env:latest")
 
 # Or from HuggingFace Hub
-env = REPLEnv.from_hub("openenv/repl-env")
+env = REPLEnv.from_hub("openenv/repl")
 ```
 
 ## API Reference
@@ -431,13 +448,15 @@ uv run --project . server
 ### Using Docker
 
 ```bash
+# From the repl_env directory
+cd envs/repl_env
 docker build -t repl-env:latest -f server/Dockerfile .
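+# Run the container, exposing the FastAPI server on port 8000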
 docker run -p 8000:8000 repl-env:latest
 ```
 
 ### Testing
 
 ```bash
-pytest tests/
+pytest tests/envs/test_repl_env.py
 ```
 
 ## References
diff --git a/__init__.py b/__init__.py
index 21c79d6e3d546515318b42f4831f3653f4dbcc3e..88316d0ff753febb91e2a4bfd81c723a3a4cdd65 100644
--- a/__init__.py
+++ b/__init__.py
@@ -40,20 +40,20 @@ References:
     - Alex Zhang Blog: https://alexzhang13.github.io/blog/2025/rlm/
 """
 
-from .models import REPLAction, REPLObservation, REPLState, CodeBlockResult
 from .client import REPLEnv
+from .models import CodeBlockResult, REPLAction, REPLObservation, REPLState
 from .prompts import (
-    # System prompts
-    RLM_SYSTEM_PROMPT,
-    RLM_SYSTEM_PROMPT_QWEN,
-    # Prompt building
-    QueryMetadata,
+    build_initial_prompt,
     build_rlm_system_prompt,
     build_user_prompt,
-    build_initial_prompt,
     # Parsing utilities
     extract_code_blocks,
-    format_observation,
+    format_observations,
+    # Prompt building
+    QueryMetadata,
+    # System prompts
+    RLM_SYSTEM_PROMPT,
+    RLM_SYSTEM_PROMPT_QWEN,
 )
 
 __all__ = [
@@ -74,5 +74,5 @@ __all__ = [
     "build_initial_prompt",
     # Parsing utilities
     "extract_code_blocks",
-    "format_observation",
+    "format_observations",
 ]
diff --git a/client.py b/client.py
index 54cb2caab0312e4dad86f209f09344ef442c0033..608deeeb3fcbd11b8f9bd72522896bf0de37634b 100644
--- a/client.py
+++ b/client.py
@@ -38,11 +38,12 @@ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
 try:
     from openenv.core.client_types import StepResult
     from openenv.core.env_client import EnvClient
-    from .models import REPLAction, REPLObservation, REPLState, CodeBlockResult
+
+    from .models import CodeBlockResult, REPLAction, REPLObservation, REPLState
 except ImportError:
+    from models import CodeBlockResult, REPLAction, REPLObservation, REPLState
     from openenv.core.client_types import StepResult
     from openenv.core.env_client import EnvClient
-    from models import REPLAction, REPLObservation, REPLState, CodeBlockResult
 
 if TYPE_CHECKING:
     from .server.repl_environment import REPLEnvironment
@@ -265,9 +266,7 @@ class REPLEnv:
         Returns:
             StepResult with done=True.
         """
-        return self.step(
-            REPLAction(code="", is_final=True, final_answer=answer)
-        )
+        return self.step(REPLAction(code="", is_final=True, final_answer=answer))
 
     def get_variable(self, name: str) -> StepResult[REPLObservation]:
         """
@@ -315,9 +314,7 @@ class REPLEnv:
             self._remote_client.close()
             self._remote_client = None
 
-    def _wrap_observation(
-        self, obs: REPLObservation
-    ) -> StepResult[REPLObservation]:
+    def _wrap_observation(self, obs: REPLObservation) -> StepResult[REPLObservation]:
         """Wrap a local REPLObservation in a StepResult."""
         return StepResult(
             observation=obs,
diff --git a/envs/repl_env/README.md b/envs/repl_env/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f61c3a4e8a4b87d2dc047b1de1ab8713e49b4dfc
--- /dev/null
+++ b/envs/repl_env/README.md
@@ -0,0 +1,450 @@
+---
+title: REPL Environment Server
+emoji: 🎮
+colorFrom: yellow
+colorTo: indigo
+sdk: docker
+pinned: false
+app_port: 8000
+base_path: /web
+tags:
+  - openenv
+---
+
+# REPL Environment for OpenEnv
+
+A Python REPL environment for training language models on code execution tasks, based on the [Recursive Language Models (RLM)](https://arxiv.org/abs/2512.24601) paradigm.
+
+## Overview
+
+The RLM paradigm allows language models to:
+- Execute Python code in a sandboxed REPL environment
+- Make recursive calls to themselves or other LMs via `llm_query()` / `llm_query_batched()`
+- Handle near-infinite context by programmatically decomposing and exploring data
+- Terminate with explicit `FINAL(answer)` or `answer = {"content": ..., "ready": True}` signals
+
+## Features
+
+- **Unified API**: Same `REPLEnv` class works for both local and remote execution
+- **Sandboxed Python Execution**: Safe code execution with restricted builtins
+- **Context Loading**: Load large contexts that agents can explore programmatically
+- **Multiple Finalization Patterns**:
+  - Direct call: `FINAL(answer)` - helper function injected into namespace
+  - Print pattern: `print('FINAL(answer)')` or `print('FINAL_VAR(var_name)')`
+  - Prime Intellect style: `answer = {"content": "...", "ready": True}`
+- **Iteration Limits**: Configurable maximum steps per episode
+- **Reward Signals**: Customizable reward functions for RL training
+- **Optional LLM Oracle**: Can enable `llm_query()` and `llm_query_batched()` for recursive calls
+
+## Quick Start
+
+### Local Mode (No Server Required)
+
+```python
+from repl_env import REPLEnv
+
+# Create environment - runs locally by default
+with REPLEnv() as env:
+    result = env.reset(
+        context="This is a large document with lots of text...",
+        task_prompt="Find the word count"
+    )
+
+    # Execute code iteratively
+    result = env.execute("words = context.split()")
+    result = env.execute("count = len(words)")
+    result = env.execute("print(f'FINAL({count})')")
+
+    print(f"Done: {result.done}")
+    print(f"Final Answer: {env.state().final_answer}")
+```
+
+### Remote Server Mode
+
+```python
+from repl_env import REPLEnv
+
+# Connect to a running server - same API!
+with REPLEnv(base_url="https://my-server.hf.space") as env:
+    result = env.reset(context="...", task_prompt="...")
+    result = env.execute("count = len(context)")
+    result = env.execute("print(f'FINAL({count})')")
+```
+
+### Local Mode with LLM Support
+
+```python
+from repl_env import REPLEnv
+
+def my_llm_query(prompt: str) -> str:
+    return your_llm.generate(prompt)
+
+def my_llm_query_batched(prompts: list[str]) -> list[str]:
+    return [my_llm_query(p) for p in prompts]
+
+# Pass LLM functions for recursive calls
+with REPLEnv(llm_query_fn=my_llm_query, llm_batch_fn=my_llm_query_batched) as env:
+    result = env.reset(context=large_document, task_prompt="Summarize this")
+
+    # Now the executed code can use llm_query() and llm_query_batched()!
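+    # For example, ask a sub-LLM to summarize the first 1,000 characters
+    # of the loaded context from inside the executed code: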
+    result = env.execute("summary = llm_query('Summarize: ' + context[:1000])")
+```
+
+### From Docker or HuggingFace Hub
+
+```python
+from repl_env import REPLEnv
+
+# Start from Docker image
+env = REPLEnv.from_docker_image("repl-env:latest")
+
+# Or from HuggingFace Hub
+env = REPLEnv.from_hub("openenv/repl-env")
+```
+
+## API Reference
+
+### REPLEnv
+
+```python
+class REPLEnv:
+    def __init__(
+        self,
+        base_url: str | None = None,  # Server URL (None = local mode)
+        *,
+        # Local-only options
+        llm_query_fn: Callable | None = None,  # Function for llm_query()
+        llm_batch_fn: Callable | None = None,  # Function for llm_query_batched()
+        max_output_length: int = 8192,  # Max stdout/stderr chars
+        context_preview_length: int = 500,  # Chars in context preview
+        reward_on_success: float = 1.0,  # Reward on FINAL()
+        reward_on_iteration: float = 0.0,  # Reward per step
+        reward_on_failure: float = -0.1,  # Reward on max iterations
+        reward_on_error: float = -0.05,  # Reward on execution error
+        # Remote-only options
+        connect_timeout_s: float = 10.0,
+        message_timeout_s: float = 60.0,
+    ): ...
+
+    def reset(
+        self,
+        *,
+        context: str = "",  # Text to analyze (as `context` variable)
+        task_prompt: str = "",  # Task description
+        max_iterations: int = 30,  # Max code execution steps
+        seed: int | None = None,  # Random seed
+        episode_id: str | None = None,  # Custom episode ID
+        hf_token: str | None = None,  # HF token for llm_query (remote mode)
+        llm_model: str | None = None,  # Model for llm_query (remote mode)
+    ) -> StepResult[REPLObservation]: ...
+
+    def execute(self, code: str) -> StepResult[REPLObservation]: ...
+    def step(self, action: REPLAction) -> StepResult[REPLObservation]: ...
+    def submit_final_answer(self, answer: str) -> StepResult[REPLObservation]: ...
+    def state(self) -> REPLState: ...
+    def close(self) -> None: ...
+```
+
+### Action Space
+
+```python
+class REPLAction:
+    code: str = ""  # Python code to execute
+    is_final: bool = False  # Whether this signals the final answer
+    final_answer: str | None = None  # The final answer (if is_final=True)
+```
+
+### Observation Space
+
+```python
+class REPLObservation:
+    result: CodeBlockResult  # Execution result (stdout, stderr, etc.)
+    context_preview: str | None  # First 500 chars of context
+    context_length: int  # Total context length
+    available_variables: list  # Variables in namespace
+    iteration: int  # Current iteration
+    max_iterations: int  # Max iterations
+    done: bool  # Episode complete?
+    reward: float  # Step reward
+    metadata: dict  # Additional info (final_answer, etc.)
+```
+
+## Finalization Patterns
+
+### Pattern 1: Direct FINAL() call (recommended)
+```python
+result = env.execute("answer = 42")
+result = env.execute("FINAL(answer)")
+# -> done=True, final_answer="42"
+```
+
+### Pattern 2: FINAL() via print
+```python
+result = env.execute("answer = 42")
+result = env.execute("print(f'FINAL({answer})')")
+# -> done=True, final_answer="42"
+```
+
+### Pattern 3: FINAL_VAR() for variable reference
+```python
+result = env.execute("my_result = 'The answer is 42'")
+# Direct call (recommended) - pass variable name as string
+# FINAL_VAR looks up the variable and returns FINAL(value)
+result = env.execute('FINAL_VAR("my_result")')
+# -> done=True, final_answer="The answer is 42"
+
+# Also works via print (for regex detection)
+result = env.execute("print('FINAL_VAR(my_result)')")
+# -> done=True, final_answer="The answer is 42"
+```
+
+### Pattern 4: Prime Intellect style answer dict
+```python
+result = env.execute("answer['content'] = '42'")
+result = env.execute("answer['ready'] = True")
+# -> done=True, final_answer="42"
+```
+
+## Prompts Module
+
+The `prompts` module provides RLM-style prompts and parsing utilities:
+
+```python
+from repl_env.prompts import (
+    # System prompts (from official RLM repo)
+    RLM_SYSTEM_PROMPT,  # Base prompt with llm_query_batched
+    RLM_SYSTEM_PROMPT_QWEN,  # For Qwen models (adds cost warning)
+
+    # Prompt building
+    QueryMetadata,  # Context metadata dataclass
+    build_rlm_system_prompt,  # Build system messages with metadata
+    build_user_prompt,  # Build user prompt for each iteration
+    build_initial_prompt,  # Convenience wrapper for iteration 0
+
+    # Parsing utilities
+    extract_code_blocks,  # Extract code from ```repl``` or ```python``` blocks
+    format_observations,  # Format execution results for the LLM
+)
+
+# Example: Build messages using official RLM style
+query_metadata = QueryMetadata(
+    context_lengths=[len(context)],
+    context_total_length=len(context),
+    context_type="str",
+)
+messages = build_rlm_system_prompt(RLM_SYSTEM_PROMPT_QWEN, query_metadata)
+messages.append(build_user_prompt(root_prompt="Count words in the context", iteration=0))
+
+# Extract code from LLM response (supports ```repl``` and ```python```)
+response = "Here's my solution:\n```repl\ncount = len(context.split())\nFINAL(count)\n```"
+code_blocks = extract_code_blocks(response)  # ["count = len(context.split())\nFINAL(count)"]
+```
+
+## Examples
+
+See the `examples/` directory for complete working examples:
+
+- **`examples/repl_with_llm.py`** - Full RLM loop with local Qwen model
+- **`examples/repl_oolong_simple.py`** - RLM on Oolong benchmark with HuggingFace Inference API
+
+Run examples:
+```bash
+# Full RLM example with local model (requires GPU)
+python examples/repl_with_llm.py
+
+# Oolong benchmark with HF Inference API (requires HF_TOKEN)
+python examples/repl_oolong_simple.py
+```
+
+## Model Usage
+
+### Inference Loop
+
+A typical model inference loop where the LLM generates code and the environment executes it:
+
+```python
+from repl_env import REPLEnv
+from repl_env.prompts import RLM_SYSTEM_PROMPT, build_initial_prompt, extract_code_blocks, format_observations
+
+# Works with both local and remote!
+with REPLEnv(base_url="http://localhost:8000") as env:  # or REPLEnv() for local
+    result = env.reset(
+        context="The quick brown fox jumps over the lazy dog. " * 1000,
+        task_prompt="Count how many times 'fox' appears"
+    )
+
+    messages = [
+        {"role": "system", "content": RLM_SYSTEM_PROMPT},
+        {"role": "user", "content": build_initial_prompt(
+            task_prompt="Count how many times 'fox' appears",
+            context_length=result.observation.context_length,
+            context_preview=result.observation.context_preview,
+            variables=result.observation.available_variables,
+        )},
+    ]
+
+    while not result.done:
+        # Get code from LLM
+        response = your_llm.chat(messages)
+        code_blocks = extract_code_blocks(response)
+
+        for code in code_blocks:
+            result = env.execute(code)
+            if result.done:
+                break
+
+        # Update conversation
+        messages.append({"role": "assistant", "content": response})
+        messages.append({"role": "user", "content": format_observations([result.observation])})
+
+    print(f"Final answer: {env.state().final_answer}")
+```
+
+### Recursive LLM Calls (RLM Paradigm)
+
+The key insight of RLM is that models can make recursive calls to themselves or other LLMs from within the code:
+
+```python
+from repl_env import REPLEnv
+
+def llm_query(prompt: str) -> str:
+    """Single LLM call - model can call this from executed code"""
+    return your_llm.generate(prompt)
+
+def llm_query_batched(prompts: list[str]) -> list[str]:
+    """Batch LLM calls for efficiency (parallel in production)"""
+    return [your_llm.generate(p) for p in prompts]
+
+# Create environment with LLM oracle (local mode)
+with REPLEnv(llm_query_fn=llm_query, llm_batch_fn=llm_query_batched) as env:
+    result = env.reset(
+        context=massive_document,  # Could be 100K+ chars
+        task_prompt="Summarize each section and find key themes"
+    )
+
+    # The model can now generate code like this:
+    code = """
+# Split document into sections
+sections = context.split('\\n\\n')
+
+# Use LLM to summarize each section (recursive call!)
+summaries = llm_query_batched([f"Summarize: {s[:1000]}" for s in sections[:10]])
+
+# Combine summaries
+combined = '\\n'.join(summaries)
+
+# Final synthesis using another LLM call
+answer['content'] = llm_query(f"Find key themes in: {combined}")
+answer['ready'] = True
+"""
+
+    result = env.execute(code)
+    print(f"Done: {result.done}, Answer: {env.state().final_answer}")
+```
+
+### RL Training Integration
+
+For RL training, integrate with frameworks like TRL, prime-rl, or verifiers:
+
+```python
+from repl_env import REPLEnv
+
+def collect_trajectory(env, policy, context, task):
+    """Collect a single trajectory for RL training"""
+    result = env.reset(context=context, task_prompt=task)
+
+    trajectory = []
+    total_reward = 0
+
+    while not result.done:
+        # Policy generates code
+        code = policy.generate(result.observation)
+
+        # Step environment
+        next_result = env.execute(code)
+
+        # Store transition
+        trajectory.append({
+            "observation": result.observation,
+            "action": code,
+            "reward": next_result.reward,
+            "next_observation": next_result.observation,
+            "done": next_result.done,
+        })
+
+        total_reward += next_result.reward
+        result = next_result
+
+    return trajectory, total_reward
+
+# Training loop
+with REPLEnv(
+    reward_on_success=1.0,
+    reward_on_iteration=0.0,
+    reward_on_error=-0.05,
+    reward_on_failure=-0.1,
+) as env:
+    for epoch in range(num_epochs):
+        for context, task, ground_truth in dataset:
+            trajectory, reward = collect_trajectory(env, policy, context, task)
+
+            # Verify answer correctness (optional external reward)
+            if trajectory:
+                final_answer = env.state().final_answer
+                if final_answer == ground_truth:
+                    reward += verification_bonus
+
+            # Update policy (use your RL framework - PPO, GRPO, DPO, etc.)
+            policy.update(trajectory, reward)
+```
+
+### Reward Configuration
+
+Configure rewards for different outcomes:
+
+```python
+env = REPLEnv(
+    reward_on_success=1.0,  # When FINAL() is called
+    reward_on_iteration=0.0,  # Per step (can be negative to encourage efficiency)
+    reward_on_error=-0.05,  # When code execution fails
+    reward_on_failure=-0.1,  # When max iterations reached without answer
+)
+```
+
+## Environment Configuration
+
+| Environment Variable | Description | Default |
+|---------------------|-------------|---------|
+| `REPL_CONTEXT` | Initial context to load | "" |
+| `REPL_TASK_PROMPT` | Task description | "" |
+| `REPL_MAX_ITERATIONS` | Max steps per episode | 30 |
+| `HF_TOKEN` | HuggingFace token for llm_query (server fallback) | None |
+| `LLM_MODEL` | Model for llm_query/llm_query_batched | Qwen/Qwen3-Coder-480B-A35B-Instruct |
+
+## Running the Server
+
+### Using UV
+```bash
+cd envs/repl_env
+uv run --project . server
+```
+
+### Using Docker
+```bash
+# From the repl_env directory
+cd envs/repl_env
+docker build -t repl-env:latest -f server/Dockerfile .
+docker run -p 8000:8000 repl-env:latest
+```
+
+### Testing
+```bash
+pytest tests/envs/test_repl_env.py
+```
+
+## References
+
+- [RLM Paper (arXiv:2512.24601)](https://arxiv.org/abs/2512.24601)
+- [RLM Implementation](https://github.com/alexzhang13/rlm)
+- [Alex Zhang's RLM Blog](https://alexzhang13.github.io/blog/2025/rlm/)
+- [Prime Intellect RLM Blog](https://www.primeintellect.ai/blog/rlm)
diff --git a/envs/repl_env/__init__.py b/envs/repl_env/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..88316d0ff753febb91e2a4bfd81c723a3a4cdd65
--- /dev/null
+++ b/envs/repl_env/__init__.py
@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+REPL Environment for OpenEnv.
+
+A Python REPL environment for training language models on code execution tasks,
+based on the Recursive Language Models (RLM) paradigm.
+
+This environment allows language models to:
+- Execute Python code in a sandboxed REPL
+- Work with large contexts loaded as variables
+- Finalize answers via FINAL(), FINAL_VAR(), or answer dict pattern
+- Optionally make recursive LLM calls via llm_query() / llm_query_batched()
+
+Example:
+    >>> from repl_env import REPLEnv, REPLAction
+    >>>
+    >>> # Start from Docker
+    >>> env = REPLEnv.from_docker_image("repl-env:latest")
+    >>>
+    >>> # Reset with context
+    >>> result = env.reset(context="Hello World", task_prompt="Count characters")
+    >>>
+    >>> # Execute code
+    >>> result = env.execute("count = len(context)")
+    >>> result = env.execute("print(f'FINAL({count})')")
+    >>>
+    >>> # Check result
+    >>> print(f"Done: {result.done}, Answer: {result.observation.metadata['final_answer']}")
+    >>>
+    >>> env.close()
+
+References:
+    - RLM Paper: https://arxiv.org/abs/2512.24601
+    - Prime Intellect Blog: https://www.primeintellect.ai/blog/rlm
+    - Alex Zhang Blog: https://alexzhang13.github.io/blog/2025/rlm/
+"""
+
+from .client import REPLEnv
+from .models import CodeBlockResult, REPLAction, REPLObservation, REPLState
+from .prompts import (
+    build_initial_prompt,
+    build_rlm_system_prompt,
+    build_user_prompt,
+    # Parsing utilities
+    extract_code_blocks,
+    format_observations,
+    # Prompt building
+    QueryMetadata,
+    # System prompts
+    RLM_SYSTEM_PROMPT,
+    RLM_SYSTEM_PROMPT_QWEN,
+)
+
+__all__ = [
+    # Models
+    "REPLAction",
+    "REPLObservation",
+    "REPLState",
+    "CodeBlockResult",
+    # Client
+    "REPLEnv",
+    # System prompts
+    "RLM_SYSTEM_PROMPT",
+    "RLM_SYSTEM_PROMPT_QWEN",
+    # Prompt building
+    "QueryMetadata",
+    "build_rlm_system_prompt",
+    "build_user_prompt",
+    "build_initial_prompt",
+    # Parsing utilities
+    "extract_code_blocks",
+    "format_observations",
+]
diff --git a/envs/repl_env/client.py b/envs/repl_env/client.py
new file mode 100644
index 0000000000000000000000000000000000000000..608deeeb3fcbd11b8f9bd72522896bf0de37634b
--- /dev/null
+++ b/envs/repl_env/client.py
@@ -0,0 +1,466 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+REPL Environment Client.
+
+This module provides a unified client for the REPL Environment that works
+with both remote servers (via WebSocket) and local execution (no server needed).
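+
+The backend is selected lazily: if base_url is given, a WebSocket client
+connects to that server; otherwise an in-process REPLEnvironment is created
+on first use.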
+
+Examples:
+    # Connect to remote server with your HF token for sub-LLM calls
+    env = REPLEnv(base_url="https://my-server.hf.space")
+    result = env.reset(
+        context="...",
+        task_prompt="...",
+        hf_token=os.environ["HF_TOKEN"],  # Server uses this for llm_query
+    )
+
+    # Run locally (no server)
+    env = REPLEnv()
+
+    # Local with LLM support
+    env = REPLEnv(llm_query_fn=my_llm, llm_batch_fn=my_batch)
+
+    # All use the same interface
+    result = env.execute("x = len(context)")
+    env.close()
+"""
+
+from __future__ import annotations
+
+from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
+
+# Support both in-repo and standalone imports
+try:
+    from openenv.core.client_types import StepResult
+    from openenv.core.env_client import EnvClient
+
+    from .models import CodeBlockResult, REPLAction, REPLObservation, REPLState
+except ImportError:
+    from models import CodeBlockResult, REPLAction, REPLObservation, REPLState
+    from openenv.core.client_types import StepResult
+    from openenv.core.env_client import EnvClient
+
+if TYPE_CHECKING:
+    from .server.repl_environment import REPLEnvironment
+
+
+class REPLEnv:
+    """
+    Unified client for the REPL Environment.
+
+    Works with both remote servers and local execution, providing the same
+    interface regardless of where the code runs.
+
+    Examples:
+        >>> # Connect to a running server
+        >>> with REPLEnv(base_url="http://localhost:8000") as env:
+        ...     result = env.reset(context="Hello World", task_prompt="Count chars")
+        ...     result = env.execute("count = len(context)")
+        ...     result = env.execute("print(f'FINAL({count})')")
+        ...     print(result.done)  # True
+
+        >>> # Run locally without a server
+        >>> with REPLEnv() as env:
+        ...     result = env.reset(context="Hello World", task_prompt="Count chars")
+        ...     result = env.execute("count = len(context)")
+        ...     print(result.observation.result.success)  # True
+
+        >>> # Local with LLM support for recursive calls
+        >>> def my_llm(prompt: str) -> str:
+        ...     return "LLM response"
+        >>> with REPLEnv(llm_query_fn=my_llm) as env:
+        ...     result = env.reset(context="...")
+        ...     result = env.execute("response = llm_query('Summarize: ' + context)")
+
+        >>> # From Docker image
+        >>> env = REPLEnv.from_docker_image("repl-env:latest")
+
+        >>> # From HuggingFace Hub
+        >>> env = REPLEnv.from_hub("openenv/repl-env")
+    """
+
+    def __init__(
+        self,
+        base_url: Optional[str] = None,
+        *,
+        # Local-only options (ignored when base_url is set)
+        llm_query_fn: Optional[Callable[[str], str]] = None,
+        llm_batch_fn: Optional[Callable[[List[str]], List[str]]] = None,
+        max_output_length: int = 8192,
+        context_preview_length: int = 500,
+        reward_on_success: float = 1.0,
+        reward_on_iteration: float = 0.0,
+        reward_on_failure: float = -0.1,
+        reward_on_error: float = -0.05,
+        # Connection options (ignored when running locally)
+        connect_timeout_s: float = 10.0,
+        message_timeout_s: float = 60.0,
+    ):
+        """
+        Initialize REPL environment.
+
+        Args:
+            base_url: Server URL. If None, runs locally without a server.
+            llm_query_fn: Function for llm_query() calls (local mode only).
+            llm_batch_fn: Function for llm_query_batched() calls (local mode only).
+            max_output_length: Max stdout/stderr chars per execution (local only).
+            context_preview_length: Chars to show in context preview (local only).
+            reward_on_success: Reward when final answer submitted (local only).
+            reward_on_iteration: Reward per iteration step (local only).
+            reward_on_failure: Reward when max iterations reached (local only).
+            reward_on_error: Reward when code execution fails (local only).
+            connect_timeout_s: WebSocket connection timeout (remote only).
+            message_timeout_s: Message response timeout (remote only).
+        """
+        self._base_url = base_url
+        self._local_env: Optional[REPLEnvironment] = None
+        self._remote_client: Optional[_RemoteREPLClient] = None
+
+        # Store local-mode options
+        self._llm_query_fn = llm_query_fn
+        self._llm_batch_fn = llm_batch_fn
+        self._max_output_length = max_output_length
+        self._context_preview_length = context_preview_length
+        self._reward_on_success = reward_on_success
+        self._reward_on_iteration = reward_on_iteration
+        self._reward_on_failure = reward_on_failure
+        self._reward_on_error = reward_on_error
+
+        # Store remote-mode options
+        self._connect_timeout_s = connect_timeout_s
+        self._message_timeout_s = message_timeout_s
+
+        # Provider for container/runtime lifecycle (set by factory methods)
+        self._provider = None
+
+    def _ensure_initialized(self) -> None:
+        """Initialize the appropriate backend (local or remote)."""
+        if self._local_env is not None or self._remote_client is not None:
+            return
+
+        if self._base_url is None:
+            # Local mode: create REPLEnvironment directly
+            from .server.repl_environment import REPLEnvironment
+
+            self._local_env = REPLEnvironment(
+                max_output_length=self._max_output_length,
+                context_preview_length=self._context_preview_length,
+                reward_on_success=self._reward_on_success,
+                reward_on_iteration=self._reward_on_iteration,
+                reward_on_failure=self._reward_on_failure,
+                reward_on_error=self._reward_on_error,
+                llm_query_fn=self._llm_query_fn,
+                llm_batch_fn=self._llm_batch_fn,
+            )
+        else:
+            # Remote mode: create WebSocket client
+            self._remote_client = _RemoteREPLClient(
+                base_url=self._base_url,
+                connect_timeout_s=self._connect_timeout_s,
+                message_timeout_s=self._message_timeout_s,
+                provider=self._provider,
+            )
+            self._remote_client.connect()
+
+    def reset(
+        self,
+        *,
+        context: str = "",
+        task_prompt: str = "",
+        max_iterations: int = 30,
+        seed: Optional[int] = None,
+        episode_id: Optional[str] = None,
+        hf_token: Optional[str] = None,
+        llm_model: Optional[str] = None,
+    ) -> StepResult[REPLObservation]:
+        """
+        Reset the environment for a new episode.
+
+        Args:
+            context: Text content to analyze (accessible as `context` variable).
+            task_prompt: Description of the task to solve.
+            max_iterations: Maximum code execution steps before timeout.
+            seed: Optional random seed for reproducibility.
+            episode_id: Optional custom episode identifier.
+            hf_token: Optional HuggingFace token for llm_query/llm_query_batched.
+                When provided, the server uses this token for sub-LLM calls
+                instead of its own configured token.
+                Security: Token is NOT stored in state or logged.
+            llm_model: Optional model name for LLM functions (default: Qwen3-Coder-480B).
+
+        Returns:
+            StepResult with initial observation.
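+
+        Example:
+            >>> result = env.reset(context="Hello World", task_prompt="Count chars")
+            >>> result.observation.context_length
+            11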
+ """ + self._ensure_initialized() + + if self._local_env is not None: + # Local mode + self._local_env.max_iterations = max_iterations + obs = self._local_env.reset( + seed=seed, + episode_id=episode_id, + context=context, + task_prompt=task_prompt, + hf_token=hf_token, + llm_model=llm_model, + ) + return self._wrap_observation(obs) + else: + # Remote mode + assert self._remote_client is not None + return self._remote_client.reset( + context=context, + task_prompt=task_prompt, + max_iterations=max_iterations, + seed=seed, + episode_id=episode_id, + hf_token=hf_token, + llm_model=llm_model, + ) + + def step(self, action: REPLAction) -> StepResult[REPLObservation]: + """ + Execute a REPL action. + + Args: + action: REPLAction containing code to execute. + + Returns: + StepResult with execution observation. + """ + self._ensure_initialized() + + if self._local_env is not None: + obs = self._local_env.step(action) + return self._wrap_observation(obs) + else: + assert self._remote_client is not None + return self._remote_client.step(action) + + def execute(self, code: str) -> StepResult[REPLObservation]: + """ + Execute Python code in the REPL. + + Convenience method that wraps step() with a code-only action. + + Args: + code: Python code to execute. + + Returns: + StepResult with execution observation. + """ + return self.step(REPLAction(code=code)) + + def submit_final_answer(self, answer: str) -> StepResult[REPLObservation]: + """ + Submit a final answer and terminate the episode. + + Args: + answer: The final answer string. + + Returns: + StepResult with done=True. + """ + return self.step(REPLAction(code="", is_final=True, final_answer=answer)) + + def get_variable(self, name: str) -> StepResult[REPLObservation]: + """ + Retrieve and print a variable from the REPL namespace. + + Args: + name: Variable name to retrieve. + + Returns: + StepResult with variable value in stdout. + """ + return self.execute(f"print(repr({name}))") + + def state(self) -> REPLState: + """ + Get current environment state. + + Returns: + REPLState with current environment information. + """ + self._ensure_initialized() + + if self._local_env is not None: + return self._local_env.state + else: + assert self._remote_client is not None + return self._remote_client.state() + + def list_variables(self) -> List[str]: + """ + Get list of available variables in the current session. + + Returns: + List of variable names. + """ + return self.state().namespace_keys + + def close(self) -> None: + """Clean up resources.""" + if self._local_env is not None: + self._local_env.close() + self._local_env = None + + if self._remote_client is not None: + self._remote_client.close() + self._remote_client = None + + def _wrap_observation(self, obs: REPLObservation) -> StepResult[REPLObservation]: + """Wrap a local REPLObservation in a StepResult.""" + return StepResult( + observation=obs, + reward=obs.reward, + done=obs.done, + ) + + # Context manager support + + def __enter__(self) -> "REPLEnv": + """Enter context manager.""" + self._ensure_initialized() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit context manager.""" + self.close() + + # Factory methods + + @classmethod + def from_docker_image( + cls, + image: str, + **kwargs: Any, + ) -> "REPLEnv": + """ + Create a REPL environment by spinning up a Docker container. + + Args: + image: Docker image name to run (e.g., "repl-env:latest"). + **kwargs: Additional arguments passed to container start. + + Returns: + Connected REPLEnv instance. 
+ """ + from openenv.core.containers.runtime import LocalDockerProvider + + provider = LocalDockerProvider() + base_url = provider.start_container(image, **kwargs) + provider.wait_for_ready(base_url) + + env = cls(base_url=base_url) + env._provider = provider + env._ensure_initialized() + return env + + @classmethod + def from_hub( + cls, + repo_id: str, + *, + use_docker: bool = True, + **kwargs: Any, + ) -> "REPLEnv": + """ + Create a REPL environment from a HuggingFace Space. + + Args: + repo_id: HuggingFace space identifier (e.g., "openenv/repl-env"). + use_docker: If True, pull from HF registry. If False, run with UV. + **kwargs: Additional arguments passed to provider. + + Returns: + Connected REPLEnv instance. + """ + if use_docker: + from openenv.core.containers.runtime import LocalDockerProvider + + provider = LocalDockerProvider() + tag = kwargs.pop("tag", "latest") + image = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" + base_url = provider.start_container(image, **kwargs) + provider.wait_for_ready(base_url) + else: + from openenv.core.containers.runtime import UVProvider + + project_path = kwargs.pop( + "project_path", f"git+https://huggingface.co/spaces/{repo_id}" + ) + provider = UVProvider(project_path=project_path, **kwargs) + base_url = provider.start() + provider.wait_for_ready() + + env = cls(base_url=base_url) + env._provider = provider + env._ensure_initialized() + return env + + +class _RemoteREPLClient(EnvClient[REPLAction, REPLObservation, REPLState]): + """ + Internal WebSocket client for remote REPL connections. + + This is the original EnvClient-based implementation, now used internally + by REPLEnv for remote mode. + """ + + def _step_payload(self, action: REPLAction) -> Dict: + """Convert REPLAction to JSON payload for step request.""" + return { + "code": action.code, + "is_final": action.is_final, + "final_answer": action.final_answer, + } + + def _parse_result(self, payload: Dict) -> StepResult[REPLObservation]: + """Parse server response into StepResult[REPLObservation].""" + obs_data = payload.get("observation", {}) + result_data = obs_data.get("result", {}) + + observation = REPLObservation( + result=CodeBlockResult( + stdout=result_data.get("stdout", ""), + stderr=result_data.get("stderr", ""), + locals_snapshot=result_data.get("locals_snapshot", {}), + execution_time=result_data.get("execution_time", 0.0), + success=result_data.get("success", True), + exception=result_data.get("exception"), + ), + context_preview=obs_data.get("context_preview"), + context_length=obs_data.get("context_length", 0), + available_variables=obs_data.get("available_variables", []), + iteration=obs_data.get("iteration", 0), + max_iterations=obs_data.get("max_iterations", 30), + done=payload.get("done", False), + reward=payload.get("reward"), + metadata=obs_data.get("metadata", {}), + ) + + return StepResult( + observation=observation, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict) -> REPLState: + """Parse server response into REPLState object.""" + return REPLState( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + context=payload.get("context"), + task_prompt=payload.get("task_prompt"), + iteration=payload.get("iteration", 0), + max_iterations=payload.get("max_iterations", 30), + namespace_keys=payload.get("namespace_keys", []), + final_answer=payload.get("final_answer"), + total_execution_time=payload.get("total_execution_time", 0.0), + ) diff --git 
a/envs/repl_env/models.py b/envs/repl_env/models.py new file mode 100644 index 0000000000000000000000000000000000000000..c2a8651de625a19b2b77753cf756ff5ffa721f82 --- /dev/null +++ b/envs/repl_env/models.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Data models for the REPL Environment. + +The REPL environment provides a Python REPL for training language models +on code execution tasks, based on the Recursive Language Models (RLM) paradigm. + +Supports two finalization patterns: +1. RLM-style: print('FINAL(answer)') or print('FINAL_VAR(var_name)') +2. Prime Intellect style: answer = {"content": "...", "ready": True} +""" + +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + +# Support both in-repo and standalone imports +try: + from openenv.core.env_server.types import Action, Observation, State +except ImportError: + from openenv.core.env_server.types import Action, Observation, State + + +class REPLAction(Action): + """Action containing Python code to execute in the REPL. + + Supports multiple finalization patterns: + 1. RLM-style: print('FINAL(answer)') or print('FINAL_VAR(var_name)') in code + 2. Prime Intellect style: answer = {"content": "...", "ready": True} in namespace + 3. Explicit: Set is_final=True with final_answer + """ + + code: str = Field(default="", description="Python code to execute") + is_final: bool = Field( + default=False, + description="Whether this action signals the final answer", + ) + final_answer: Optional[str] = Field( + default=None, description="Final answer if is_final=True" + ) + + +class CodeBlockResult(BaseModel): + """Result of executing a single code block.""" + + stdout: str = Field(default="", description="Standard output from execution") + stderr: str = Field(default="", description="Standard error from execution") + locals_snapshot: Dict[str, str] = Field( + default_factory=dict, + description="String representations of new/modified variables", + ) + execution_time: float = Field( + default=0.0, ge=0, description="Execution time in seconds" + ) + success: bool = Field(default=True, description="Whether execution succeeded") + exception: Optional[str] = Field( + default=None, description="Exception message if execution failed" + ) + + +class REPLObservation(Observation): + """Observation returned after code execution in the REPL.""" + + result: CodeBlockResult = Field( + default_factory=CodeBlockResult, description="Result of code execution" + ) + context_preview: Optional[str] = Field( + default=None, + description="Preview of the context (first N chars) if context is loaded", + ) + context_length: int = Field( + default=0, ge=0, description="Total length of context in characters" + ) + available_variables: List[str] = Field( + default_factory=list, + description="List of variable names available in the namespace", + ) + iteration: int = Field(default=0, ge=0, description="Current iteration number") + max_iterations: int = Field( + default=30, ge=1, description="Maximum allowed iterations" + ) + + +class REPLState(State): + """Extended state for REPL environment.""" + + context: Optional[str] = Field( + default=None, description="The context/problem to work with" + ) + task_prompt: Optional[str] = Field( + default=None, description="The task description to solve" + ) + iteration: int = Field(default=0, ge=0, 
description="Current iteration number") + max_iterations: int = Field( + default=30, ge=1, description="Max iterations before termination" + ) + namespace_keys: List[str] = Field( + default_factory=list, description="Variables currently in namespace" + ) + final_answer: Optional[str] = Field( + default=None, description="Final answer if episode is complete" + ) + total_execution_time: float = Field( + default=0.0, ge=0, description="Total code execution time in seconds" + ) diff --git a/envs/repl_env/openenv.yaml b/envs/repl_env/openenv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5181c365ae40fae66486c43b0f84ff4c6d2487a1 --- /dev/null +++ b/envs/repl_env/openenv.yaml @@ -0,0 +1,6 @@ +spec_version: 1 +name: repl +type: space +runtime: fastapi +app: server.app:app +port: 8000 diff --git a/envs/repl_env/prompts.py b/envs/repl_env/prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..7c3b7feb9094c486d02108692b004c48ce973491 --- /dev/null +++ b/envs/repl_env/prompts.py @@ -0,0 +1,389 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +RLM System Prompts and Parsing Utilities for the REPL Environment. + +Based on the official RLM repo: https://github.com/alexzhang13/rlm + +Two versions available: +- RLM_SYSTEM_PROMPT: Base prompt from the repo (with llm_query_batched) +- RLM_SYSTEM_PROMPT_QWEN: For Qwen3-Coder-480B (adds IMPORTANT cost warning) + +Parsing utilities help extract code blocks and format observations. +""" + +import re +import textwrap +from dataclasses import dataclass +from typing import List, Optional + + +# ============================================================================= +# Query Metadata (for context info) +# ============================================================================= + + +@dataclass +class QueryMetadata: + """Metadata about the context for building prompts.""" + + context_lengths: List[int] + context_total_length: int + context_type: str = "str" # "str" or "List[str]" + + +# ============================================================================= +# System Prompt from Official RLM Repo +# ============================================================================= + +RLM_SYSTEM_PROMPT = textwrap.dedent( + """You are tasked with answering a query with associated context. You can access, transform, and analyze this context interactively in a REPL environment that can recursively query sub-LLMs, which you are strongly encouraged to use as much as possible. You will be queried iteratively until you provide a final answer. + + The REPL environment is initialized with: +1. A `context` variable that contains extremely important information about your query. You should check the content of the `context` variable to understand what you are working with. Make sure you look through it sufficiently as you answer your query. +2. A `llm_query` function that allows you to query an LLM (that can handle around 500K chars) inside your REPL environment. +3. A `llm_query_batched` function that allows you to query multiple prompts concurrently: `llm_query_batched(prompts: List[str]) -> List[str]`. This is much faster than sequential `llm_query` calls when you have multiple independent queries. Results are returned in the same order as the input prompts. +4. 
The ability to use `print()` statements to view the output of your REPL code and continue your reasoning. + +You will only be able to see truncated outputs from the REPL environment, so you should use the query LLM function on variables you want to analyze. You will find this function especially useful when you have to analyze the semantics of the context. Use these variables as buffers to build up your final answer. +Make sure to explicitly look through the entire context in REPL before answering your query. An example strategy is to first look at the context and figure out a chunking strategy, then break up the context into smart chunks, and query an LLM per chunk with a particular question and save the answers to a buffer, then query an LLM with all the buffers to produce your final answer. + +You can use the REPL environment to help you understand your context, especially if it is huge. Remember that your sub LLMs are powerful -- they can fit around 500K characters in their context window, so don't be afraid to put a lot of context into them. For example, a viable strategy is to feed 10 documents per sub-LLM query. Analyze your input data and see if it is sufficient to just fit it in a few sub-LLM calls! + +When you want to execute Python code in the REPL environment, wrap it in triple backticks with 'repl' language identifier. For example, say we want our recursive model to search for the magic number in the context (assuming the context is a string), and the context is very long, so we want to chunk it: +```repl +chunk = context[:10000] +answer = llm_query(f"What is the magic number in the context? Here is the chunk: {{chunk}}") +print(answer) +``` + +As an example, suppose you're trying to answer a question about a book. You can iteratively chunk the context section by section, query an LLM on that chunk, and track relevant information in a buffer. +```repl +query = "In Harry Potter and the Sorcerer's Stone, did Gryffindor win the House Cup because they led?" +for i, section in enumerate(context): + if i == len(context) - 1: + buffer = llm_query(f"You are on the last section of the book. So far you know that: {{buffers}}. Gather from this last section to answer {{query}}. Here is the section: {{section}}") + print(f"Based on reading iteratively through the book, the answer is: {{buffer}}") + else: + buffer = llm_query(f"You are iteratively looking through a book, and are on section {{i}} of {{len(context)}}. Gather information to help answer {{query}}. Here is the section: {{section}}") + print(f"After section {{i}} of {{len(context)}}, you have tracked: {{buffer}}") +``` + +As another example, when the context isn't that long (e.g. >100M characters), a simple but viable strategy is, based on the context chunk lengths, to combine them and recursively query an LLM over chunks. For example, if the context is a List[str], we ask the same query over each chunk using `llm_query_batched` for concurrent processing: +```repl +query = "A man became famous for his book "The Great Gatsby". How many jobs did he have?" +# Suppose our context is ~1M chars, and we want each sub-LLM query to be ~0.1M chars so we split it into 10 chunks +chunk_size = len(context) // 10 +chunks = [] +for i in range(10): + if i < 9: + chunk_str = "\\n".join(context[i*chunk_size:(i+1)*chunk_size]) + else: + chunk_str = "\\n".join(context[i*chunk_size:]) + chunks.append(chunk_str) + +# Use batched query for concurrent processing - much faster than sequential calls! 
+prompts = [f"Try to answer the following query: {{query}}. Here are the documents:\\n{{chunk}}. Only answer if you are confident in your answer based on the evidence." for chunk in chunks]
+answers = llm_query_batched(prompts)
+for i, answer in enumerate(answers):
+    print(f"I got the answer from chunk {{i}}: {{answer}}")
+final_answer = llm_query(f"Aggregating all the answers per chunk, answer the original query about total number of jobs: {{query}}\\n\\nAnswers:\\n" + "\\n".join(answers))
+```
+
+As a final example, after analyzing the context and realizing it's separated by Markdown headers, we can maintain state through buffers by chunking the context by headers, and iteratively querying an LLM over it:
+```repl
+# After finding out the context is separated by Markdown headers, we can chunk, summarize, and answer
+import re
+sections = re.split(r'### (.+)', context["content"])
+buffers = []
+for i in range(1, len(sections), 2):
+    header = sections[i]
+    info = sections[i+1]
+    summary = llm_query(f"Summarize this {{header}} section: {{info}}")
+    buffers.append(f"{{header}}: {{summary}}")
+final_answer = llm_query(f"Based on these summaries, answer the original query: {{query}}\\n\\nSummaries:\\n" + "\\n".join(buffers))
+```
+In the next step, we can return FINAL_VAR("final_answer").
+
+IMPORTANT: When you are done with the iterative process, you MUST provide a final answer using one of the FINAL functions. Do not use these unless you have completed your task. You have two options:
+1. Use FINAL(value) to provide the answer directly, e.g., FINAL(42) or FINAL(my_variable)
+2. Use FINAL_VAR("variable_name") to return a variable by name, e.g., FINAL_VAR("final_answer")
+
+Think step by step carefully, plan, and execute this plan immediately in your response -- do not just say "I will do this" or "I will do that". Output to the REPL environment and recursive LLMs as much as possible. Remember to explicitly answer the original query in your final answer.
+"""
+)
+
+
+# =============================================================================
+# System Prompt for Qwen3-Coder-480B (with IMPORTANT cost warning from paper)
+# Adds cost warning after the "sub LLMs are powerful" paragraph
+# =============================================================================
+
+RLM_SYSTEM_PROMPT_QWEN = textwrap.dedent(
+    """You are tasked with answering a query with associated context. You can access, transform, and analyze this context interactively in a REPL environment that can recursively query sub-LLMs, which you are strongly encouraged to use as much as possible. You will be queried iteratively until you provide a final answer.
+
+The REPL environment is initialized with:
+1. A `context` variable that contains extremely important information about your query. You should check the content of the `context` variable to understand what you are working with. Make sure you look through it sufficiently as you answer your query.
+2. A `llm_query` function that allows you to query an LLM (that can handle around 500K chars) inside your REPL environment.
+3. A `llm_query_batched` function that allows you to query multiple prompts concurrently: `llm_query_batched(prompts: List[str]) -> List[str]`. This is much faster than sequential `llm_query` calls when you have multiple independent queries. Results are returned in the same order as the input prompts.
+4. The ability to use `print()` statements to view the output of your REPL code and continue your reasoning.
+
+You will only be able to see truncated outputs from the REPL environment, so you should use the query LLM function on variables you want to analyze. You will find this function especially useful when you have to analyze the semantics of the context. Use these variables as buffers to build up your final answer.
+Make sure to explicitly look through the entire context in REPL before answering your query. An example strategy is to first look at the context and figure out a chunking strategy, then break up the context into smart chunks, and query an LLM per chunk with a particular question and save the answers to a buffer, then query an LLM with all the buffers to produce your final answer.
+
+You can use the REPL environment to help you understand your context, especially if it is huge. Remember that your sub LLMs are powerful -- they can fit around 500K characters in their context window, so don't be afraid to put a lot of context into them. For example, a viable strategy is to feed 10 documents per sub-LLM query. Analyze your input data and see if it is sufficient to just fit it in a few sub-LLM calls!
+
+IMPORTANT: Be very careful about using 'llm_query' as it incurs high runtime costs. Always batch as much information as reasonably possible into each call (aim for around ~200k characters per call). For example, if you have 1000 lines of information to process, it's much better to split into chunks of 5 and call 'llm_query' on each chunk (200 calls total) rather than making 1000 individual calls. Minimize the number of 'llm_query' calls by batching related information together.
+
+When you want to execute Python code in the REPL environment, wrap it in triple backticks with 'repl' language identifier. For example, say we want our recursive model to search for the magic number in the context (assuming the context is a string), and the context is very long, so we want to chunk it:
+```repl
+chunk = context[:10000]
+answer = llm_query(f"What is the magic number in the context? Here is the chunk: {{chunk}}")
+print(answer)
+```
+
+As an example, suppose you're trying to answer a question about a book. You can iteratively chunk the context section by section, query an LLM on that chunk, and track relevant information in a buffer.
+```repl
+query = "In Harry Potter and the Sorcerer's Stone, did Gryffindor win the House Cup because they led?"
+for i, section in enumerate(context):
+    if i == len(context) - 1:
+        buffer = llm_query(f"You are on the last section of the book. So far you know that: {{buffers}}. Gather from this last section to answer {{query}}. Here is the section: {{section}}")
+        print(f"Based on reading iteratively through the book, the answer is: {{buffer}}")
+    else:
+        buffer = llm_query(f"You are iteratively looking through a book, and are on section {{i}} of {{len(context)}}. Gather information to help answer {{query}}. Here is the section: {{section}}")
+        print(f"After section {{i}} of {{len(context)}}, you have tracked: {{buffer}}")
+```
+
+As another example, when the context isn't that long (e.g. >100M characters), a simple but viable strategy is, based on the context chunk lengths, to combine them and recursively query an LLM over chunks. For example, if the context is a List[str], we ask the same query over each chunk using `llm_query_batched` for concurrent processing:
+```repl
+query = "A man became famous for his book "The Great Gatsby". How many jobs did he have?"
+# Suppose our context is ~1M chars, and we want each sub-LLM query to be ~0.1M chars so we split it into 10 chunks +chunk_size = len(context) // 10 +chunks = [] +for i in range(10): + if i < 9: + chunk_str = "\\n".join(context[i*chunk_size:(i+1)*chunk_size]) + else: + chunk_str = "\\n".join(context[i*chunk_size:]) + chunks.append(chunk_str) + +# Use batched query for concurrent processing - much faster than sequential calls! +prompts = [f"Try to answer the following query: {{query}}. Here are the documents:\\n{{chunk}}. Only answer if you are confident in your answer based on the evidence." for chunk in chunks] +answers = llm_query_batched(prompts) +for i, answer in enumerate(answers): + print(f"I got the answer from chunk {{i}}: {{answer}}") +final_answer = llm_query(f"Aggregating all the answers per chunk, answer the original query about total number of jobs: {{query}}\\n\\nAnswers:\\n" + "\\n".join(answers)) +``` + +As a final example, after analyzing the context and realizing its separated by Markdown headers, we can maintain state through buffers by chunking the context by headers, and iteratively querying an LLM over it: +```repl +# After finding out the context is separated by Markdown headers, we can chunk, summarize, and answer +import re +sections = re.split(r'### (.+)', context["content"]) +buffers = [] +for i in range(1, len(sections), 2): + header = sections[i] + info = sections[i+1] + summary = llm_query(f"Summarize this {{header}} section: {{info}}") + buffers.append(f"{{header}}: {{summary}}") +final_answer = llm_query(f"Based on these summaries, answer the original query: {{query}}\\n\\nSummaries:\\n" + "\\n".join(buffers)) +``` +In the next step, we can return FINAL_VAR("final_answer"). + +IMPORTANT: When you are done with the iterative process, you MUST provide a final answer using one of the FINAL functions. Do not use these unless you have completed your task. You have two options: +1. Use FINAL(value) to provide the answer directly, e.g., FINAL(42) or FINAL(my_variable) +2. Use FINAL_VAR("variable_name") to return a variable by name, e.g., FINAL_VAR("final_answer") + +Think step by step carefully, plan, and execute this plan immediately in your response -- do not just say "I will do this" or "I will do that". Output to the REPL environment and recursive LLMs as much as possible. Remember to explicitly answer the original query in your final answer. +""" +) + + +# ============================================================================= +# User Prompt Templates (from official RLM repo) +# ============================================================================= + +USER_PROMPT = """Think step-by-step on what to do using the REPL environment (which contains the context) to answer the prompt.\n\nContinue using the REPL environment, which has the `context` variable, and querying sub-LLMs by writing to ```repl``` tags, and determine your answer. Your next action:""" + +USER_PROMPT_WITH_ROOT = """Think step-by-step on what to do using the REPL environment (which contains the context) to answer the original prompt: \"{root_prompt}\".\n\nContinue using the REPL environment, which has the `context` variable, and querying sub-LLMs by writing to ```repl``` tags, and determine your answer. 
Your next action:""" + + +# ============================================================================= +# Prompt Building Functions (from official RLM repo) +# ============================================================================= + + +def build_rlm_system_prompt( + system_prompt: str, + query_metadata: QueryMetadata, +) -> List[dict]: + """ + Build the initial system prompt for the REPL environment based on extra prompt metadata. + + Args: + system_prompt: The system prompt to use + query_metadata: QueryMetadata object containing context metadata + + Returns: + List of message dictionaries [system, assistant(metadata)] + """ + context_lengths = query_metadata.context_lengths + context_total_length = query_metadata.context_total_length + context_type = query_metadata.context_type + + # If there are more than 100 chunks, truncate to the first 100 chunks. + if len(context_lengths) > 100: + others = len(context_lengths) - 100 + context_lengths_str = ( + str(context_lengths[:100]) + "... [" + str(others) + " others]" + ) + else: + context_lengths_str = str(context_lengths) + + metadata_prompt = f"Your context is a {context_type} with {context_total_length} total characters, and is broken up into chunks of char lengths: {context_lengths_str}." + + return [ + {"role": "system", "content": system_prompt}, + {"role": "assistant", "content": metadata_prompt}, + ] + + +def build_user_prompt( + root_prompt: Optional[str] = None, + iteration: int = 0, + context_count: int = 1, + history_count: int = 0, +) -> dict: + """ + Build the user prompt for a given iteration. + + Args: + root_prompt: The original query/task + iteration: Current iteration number (0 = first) + context_count: Number of context variables available + history_count: Number of prior conversation histories + + Returns: + User message dict + """ + if iteration == 0: + safeguard = "You have not interacted with the REPL environment or seen your prompt / context yet. Your next action should be to look through and figure out how to answer the prompt, so don't just provide a final answer yet.\n\n" + prompt = safeguard + ( + USER_PROMPT_WITH_ROOT.format(root_prompt=root_prompt) + if root_prompt + else USER_PROMPT + ) + else: + prompt = ( + "The history before is your previous interactions with the REPL environment. " + + ( + USER_PROMPT_WITH_ROOT.format(root_prompt=root_prompt) + if root_prompt + else USER_PROMPT + ) + ) + + # Inform model about multiple contexts if present + if context_count > 1: + prompt += f"\n\nNote: You have {context_count} contexts available (context_0 through context_{context_count - 1})." + + # Inform model about prior conversation histories if present + if history_count > 0: + if history_count == 1: + prompt += "\n\nNote: You have 1 prior conversation history available in the `history` variable." + else: + prompt += f"\n\nNote: You have {history_count} prior conversation histories available (history_0 through history_{history_count - 1})." + + return {"role": "user", "content": prompt} + + +# ============================================================================= +# Convenience Functions (for backward compatibility) +# ============================================================================= + + +def build_initial_prompt( + task_prompt: str, + context_length: int, + context_preview: Optional[str] = None, + variables: Optional[List[str]] = None, + **kwargs, +) -> str: + """Build the initial user prompt (convenience wrapper). 
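+
+    This simply delegates to build_user_prompt(root_prompt=task_prompt,
+    iteration=0); the remaining parameters are accepted for backward
+    compatibility and are ignored.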
+ + Args: + task_prompt: The task to accomplish + context_length: Total length of the context + context_preview: Preview of the context (not used) + variables: List of available variable names (not used) + + Returns: + Formatted initial prompt string + """ + return build_user_prompt(root_prompt=task_prompt, iteration=0)["content"] + + +# ============================================================================= +# Parsing Utilities +# ============================================================================= + + +def extract_code_blocks(text: str, language: str = "python") -> List[str]: + """Extract code blocks from LLM response. + + Supports both ```repl``` (official RLM) and ```python``` style blocks. + + Args: + text: The LLM response text + language: Language identifier to match (default "python") + + Returns: + List of code strings extracted from the response + """ + # Match 'repl' (official) and 'python' (common alternative) + patterns = [ + r"```repl\s*(.*?)```", + rf"```{language}\s*(.*?)```", + ] + + all_matches = [] + for pattern in patterns: + matches = re.findall(pattern, text, re.DOTALL) + all_matches.extend(m.strip() for m in matches if m.strip()) + + return all_matches + + +def format_observations(observations) -> str: + """Format REPL observations into observation text for the LLM. + + Args: + observations: List of REPL observations from env.step() + + Returns: + Formatted observation string + """ + formatted = [] + for i, observation in enumerate(observations, 1): + output = ( + observation.result.stdout.strip() + if observation.result.stdout + else "(no output)" + ) + if observation.result.success: + formatted.append(f"Code block {i} output:\n{output}") + else: + error = ( + observation.result.stderr + or observation.result.exception + or "Unknown error" + ) + formatted.append( + f"Code block {i} output:\n{output}\n\nERROR: {error}\n" + f"Fix the error in code block {i}. Remember: 'context' is already defined." + ) + return "\n\n".join(formatted) diff --git a/envs/repl_env/pyproject.toml b/envs/repl_env/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..b1073b5bd91840f63f4a8ce84a852b57ef442701 --- /dev/null +++ b/envs/repl_env/pyproject.toml @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +[build-system] +requires = ["setuptools>=45", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "openenv-repl" +version = "0.1.0" +description = "Recursive Language Model REPL Environment for OpenEnv" +requires-python = ">=3.10" +dependencies = [ + # Core OpenEnv dependencies (required for server functionality) + "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@main", + "fastapi>=0.115.0", + "pydantic>=2.0.0", + "uvicorn>=0.24.0", + "requests>=2.31.0", + # Environment-specific dependencies + "smolagents>=1.22.0,<2", + # LLM support via HuggingFace Inference API + "huggingface_hub>=0.20.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-cov>=4.0.0", +] + +[project.scripts] +# Server entry point - enables running via: uv run --project . 
server +# or: python -m repl_env.server.app +server = "repl_env.server.app:main" + +[tool.setuptools] +# Explicitly list packages - "repl_env" maps to current dir +packages = ["repl_env", "repl_env.server"] +package-dir = {"repl_env" = ".", "repl_env.server" = "server"} diff --git a/envs/repl_env/server/Dockerfile b/envs/repl_env/server/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..756d65c6cf58644f550b5b257699698ccb26440e --- /dev/null +++ b/envs/repl_env/server/Dockerfile @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Multi-stage build using openenv-base +# This Dockerfile is flexible and works for both: +# - In-repo environments (with local src/core) +# - Standalone environments (with openenv from pip) +# The build script (openenv build) handles context detection and sets appropriate build args. + +ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest +FROM ${BASE_IMAGE} AS builder + +WORKDIR /app + +# Build argument to control whether we're building standalone or in-repo +ARG BUILD_MODE=in-repo +ARG ENV_NAME=repl_env + +# Copy environment code (always at root of build context) +COPY . /app/env + +# For in-repo builds, openenv-core is already in the pyproject.toml dependencies +# For standalone builds, openenv-core will be installed from pip via pyproject.toml +WORKDIR /app/env + +# Ensure uv is available (for local builds where base image lacks it) +RUN if ! command -v uv >/dev/null 2>&1; then \ + curl -LsSf https://astral.sh/uv/install.sh | sh && \ + mv /root/.local/bin/uv /usr/local/bin/uv && \ + mv /root/.local/bin/uvx /usr/local/bin/uvx; \ + fi + +# Install git for building from git repos (build-time only) +RUN apt-get update && apt-get install -y --no-install-recommends \ + git \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies using uv sync +# If uv.lock exists, use it; otherwise resolve on the fly +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ -f uv.lock ]; then \ + uv sync --frozen --no-install-project --no-editable; \ + else \ + uv sync --no-install-project --no-editable; \ + fi + +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ -f uv.lock ]; then \ + uv sync --frozen --no-editable; \ + else \ + uv sync --no-editable; \ + fi + +# Final runtime stage +FROM ${BASE_IMAGE} + +WORKDIR /app + +# Copy the virtual environment from builder +COPY --from=builder /app/env/.venv /app/.venv + +# Copy the environment code +COPY --from=builder /app/env /app/env + +# Set PATH to use the virtual environment +ENV PATH="/app/.venv/bin:$PATH" + +# Set PYTHONPATH so imports work correctly +ENV PYTHONPATH="/app/env:$PYTHONPATH" + +# Health check using Python (more portable than curl/wget) +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1 + +# Run the FastAPI server +# The module path is constructed to work with the /app/env structure +CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"] diff --git a/envs/repl_env/server/__init__.py b/envs/repl_env/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34a07139a0671dfa10d32e5b9655c9135291912a --- /dev/null +++ b/envs/repl_env/server/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +REPL Environment Server Components. + +This module contains the server-side implementation of the REPL environment. +""" + +from .python_executor import PythonExecutor +from .repl_environment import REPLEnvironment + +__all__ = [ + "REPLEnvironment", + "PythonExecutor", +] diff --git a/envs/repl_env/server/app.py b/envs/repl_env/server/app.py new file mode 100644 index 0000000000000000000000000000000000000000..dc1057dea982750b02115b7e872c74ce59b125d8 --- /dev/null +++ b/envs/repl_env/server/app.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +FastAPI application for the REPL Environment. + +This module creates an HTTP server that exposes the REPLEnvironment +over HTTP and WebSocket endpoints, compatible with EnvClient. + +The server includes llm_query and llm_query_batched support via HuggingFace Inference API, +enabling the Recursive Language Model (RLM) paradigm. + +LLM Token Configuration: + 1. Client can pass `hf_token` in reset() - RECOMMENDED + 2. Server fallback: HF_TOKEN environment variable + +LLM functions are created dynamically in REPLEnvironment.reset() based on the +available token (client or server). + +Usage: + # Development (with auto-reload): + uvicorn server.app:app --reload --host 0.0.0.0 --port 8000 + + # Production: + uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4 + + # Or run directly: + uv run --project . server + +Environment Variables: + HF_TOKEN: Fallback HuggingFace API token (client token takes priority) + LLM_MODEL: Model to use for llm_query/llm_query_batched (default: Qwen/Qwen3-Coder-480B-A35B-Instruct) +""" + +import os + +# Support both in-repo and standalone imports +try: + # In-repo imports (when running from OpenEnv repository) + from openenv.core.env_server.http_server import create_app + + from ..models import REPLAction, REPLObservation + from .repl_environment import REPLEnvironment +except ImportError: + from models import REPLAction, REPLObservation + + # Standalone imports (when environment is standalone with openenv from pip) + from openenv.core.env_server.http_server import create_app + from server.repl_environment import REPLEnvironment + + +# ============== LLM CONFIGURATION ============== +LLM_MODEL = os.environ.get("LLM_MODEL", "Qwen/Qwen3-Coder-480B-A35B-Instruct") +HF_TOKEN = os.environ.get("HF_TOKEN", None) +# =============================================== + +# Log LLM configuration +if HF_TOKEN: + print(f"[REPL Server] LLM support ENABLED (server token configured)") + print(f"[REPL Server] Default model: {LLM_MODEL}") +else: + print("[REPL Server] No server HF_TOKEN configured") + print( + "[REPL Server] LLM functions will be enabled if client passes hf_token in reset()" + ) + +# Simple factory - LLM functions are created dynamically in reset() based on token +env_factory = REPLEnvironment + +# Create the app with web interface and README integration +app = create_app(env_factory, REPLAction, REPLObservation, env_name="repl_env") + + +def main(): + """ + Entry point for direct execution via uv run or python -m. + + This function enables running the server without Docker: + uv run --project . 
server + python -m envs.repl_env.server.app + openenv serve repl_env + """ + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000) + + +if __name__ == "__main__": + main() diff --git a/envs/repl_env/server/python_executor.py b/envs/repl_env/server/python_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef71e46e81d478860ea2614fd9336e8ecfd4010 --- /dev/null +++ b/envs/repl_env/server/python_executor.py @@ -0,0 +1,339 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Sandboxed Python code executor for the REPL environment. + +Uses smolagents.LocalPythonExecutor as the backend for battle-tested sandboxed +execution, with RLM-specific features on top: +- Context loading (set_context) +- Variable access (get_variable, list_variables) +- Function injection (inject_function for llm_query, llm_query_batched) +- Output capped at 8,192 characters per turn (configurable) +- Persistent namespace across code blocks +""" + +import json +import logging +import time +import traceback +from collections.abc import Callable +from typing import Any, Dict, List, Optional + +from smolagents import LocalPythonExecutor + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class PythonExecutor: + """Sandboxed Python code executor with persistent namespace. + + Wraps smolagents.LocalPythonExecutor with RLM-specific features: + - Context loading for RLM tasks + - Variable tracking for observation + - Function injection for llm_query, llm_query_batched + - Configurable output length limit (default 8192 chars per Prime Intellect) + """ + + def __init__( + self, + max_output_length: int = 8192, + allowed_imports: Optional[List[str]] = None, + ): + """Initialize the executor. + + Args: + max_output_length: Maximum characters for stdout/stderr (default 8192) + allowed_imports: List of allowed module names for import + + Note: + smolagents.LocalPythonExecutor does NOT support wall-clock timeouts. + Instead, it limits operations (10M ops) and while iterations (1M). 
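+
+        Example (illustrative sketch of the result-dict contract):
+            >>> ex = PythonExecutor(max_output_length=1024)
+            >>> result = ex.execute("print(2 + 2)")
+            >>> result["success"]
+            True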
+ """ + self.max_output_length = max_output_length + + # Default allowed imports for RLM tasks + default_imports = [ + "re", + "json", + "math", + "random", + "collections", + "itertools", + "functools", + "operator", + "string", + "textwrap", + "difflib", + "statistics", + "decimal", + "fractions", + "datetime", + "copy", + "pprint", + "typing", + "dataclasses", + "enum", + "bisect", + "heapq", + "array", + "struct", + "base64", + "hashlib", + "hmac", + "uuid", + ] + + self.allowed_imports = allowed_imports or default_imports + + # Initialize the smolagents executor + self._executor = LocalPythonExecutor( + additional_authorized_imports=self.allowed_imports + ) + + # Track variables we've set (for list_variables) + self._user_variables: set[str] = set() + + # Track callable functions to register with send_tools + self._callable_tools: Dict[str, Callable[..., Any]] = {} + + # Register helper utilities + self._register_helpers() + + def _register_helpers(self) -> None: + """Register helper functions with the executor.""" + helpers = { + "format_exc": traceback.format_exc, + "safe_json_dumps": lambda obj: json.dumps(obj, default=lambda o: repr(o)), + } + # Register helpers as callable tools + for name, func in helpers.items(): + self.inject_function(name, func) + + def _sync_callable_tools(self) -> None: + """Sync callable functions with the executor via send_tools.""" + if self._callable_tools: + try: + # Type ignore: smolagents accepts callables despite Tool type hint + self._executor.send_tools(self._callable_tools) # type: ignore[arg-type] + except Exception: + logger.debug( + "send_tools failed; continuing without extra tools", + exc_info=True, + ) + + def set_context(self, context: str, variable_name: str = "context") -> None: + """Load context into namespace as a variable. + + Args: + context: The context string to load + variable_name: Name of the variable (default "context") + """ + self.set_variable(variable_name, context) + + def set_variable(self, name: str, value: Any) -> None: + """Set a variable in the namespace. + + Args: + name: Variable name + value: Variable value + """ + # Access the executor's internal state to set variables + if hasattr(self._executor, "state"): + self._executor.state[name] = value + else: + # Fallback: store in injected vars for later retrieval + self._executor._injected_vars = getattr( + self._executor, "_injected_vars", {} + ) + self._executor._injected_vars[name] = value + + self._user_variables.add(name) + + def get_variable(self, name: str) -> Optional[Any]: + """Retrieve a variable from namespace. + + Args: + name: Variable name + + Returns: + The variable value or None if not found + """ + # Try to get from executor's state + if hasattr(self._executor, "state"): + return self._executor.state.get(name) + + # Fallback to injected vars + if hasattr(self._executor, "_injected_vars"): + return self._executor._injected_vars.get(name) + + return None + + def list_variables(self) -> List[str]: + """List non-private variables in namespace. + + Returns: + List of variable names (excluding private and builtins) + """ + variables = set() + + # Get from executor's state + if hasattr(self._executor, "state"): + for key in self._executor.state: + if not key.startswith("_"): + variables.add(key) + + # Include tracked user variables + variables.update(self._user_variables) + + return list(variables) + + def execute(self, code: str) -> Dict[str, Any]: + """Execute Python code and return results. 
+ + Args: + code: Python code to execute + + Returns: + Dictionary with stdout, stderr, locals_snapshot, execution_time, + success, and exception fields + """ + start_time = time.time() + success = True + exception_msg = None + new_locals: Dict[str, str] = {} + + # Track state before execution + pre_state_keys = set() + if hasattr(self._executor, "state"): + pre_state_keys = set(self._executor.state.keys()) + + stdout_parts: list[str] = [] + stderr_parts: list[str] = [] + + try: + exec_result = self._executor(code) + + # Extract logs/prints + try: + logs = getattr(exec_result, "logs", None) + if logs: + stdout_parts.append(str(logs)) + except Exception: + logger.debug("Failed to read exec_result.logs", exc_info=True) + + # Extract the result / output value + try: + if hasattr(exec_result, "output"): + out_val = exec_result.output + if out_val is not None: + try: + stdout_parts.append(json.dumps(out_val)) + except Exception: + stdout_parts.append(repr(out_val)) + except Exception: + logger.debug("Failed to read exec_result.output", exc_info=True) + + # Check for errors + try: + err = getattr(exec_result, "error", None) + if err: + stderr_parts.append(str(err)) + success = False + exception_msg = str(err) + except Exception: + logger.debug("Failed to read exec_result.error", exc_info=True) + + try: + ex = getattr(exec_result, "exception", None) + if ex: + stderr_parts.append(str(ex)) + success = False + exception_msg = str(ex) + except Exception: + logger.debug("Failed to read exec_result.exception", exc_info=True) + + # Determine success from exit_code if available + try: + if hasattr(exec_result, "exit_code"): + if exec_result.exit_code is not None and exec_result.exit_code != 0: + success = False + elif hasattr(exec_result, "success"): + success = bool(exec_result.success) + except Exception: + logger.debug("Failed to determine exec_result exit code", exc_info=True) + + except Exception as e: + success = False + exception_msg = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}" + stderr_parts.append(exception_msg) + + execution_time = time.time() - start_time + + # Capture new/modified variables + if hasattr(self._executor, "state"): + for key in self._executor.state: + if key not in pre_state_keys and not key.startswith("_"): + try: + val = self._executor.state[key] + val_repr = repr(val) + if len(val_repr) > 500: + val_repr = val_repr[:500] + "..." + new_locals[key] = val_repr + self._user_variables.add(key) + except Exception: + new_locals[key] = "" + + # Compose stdout/stderr + stdout = "\n".join(part for part in stdout_parts if part) + stderr = "\n".join(part for part in stderr_parts if part) + + # Truncate output to max_output_length + if len(stdout) > self.max_output_length: + stdout = ( + stdout[: self.max_output_length] + + f"\n... (truncated, total {len(stdout)} chars)" + ) + + if len(stderr) > self.max_output_length: + stderr = ( + stderr[: self.max_output_length] + + f"\n... 
(truncated, total {len(stderr)} chars)" + ) + + return { + "stdout": stdout, + "stderr": stderr, + "locals_snapshot": new_locals, + "execution_time": execution_time, + "success": success, + "exception": exception_msg, + } + + def reset(self) -> None: + """Reset namespace to initial state.""" + # Create a new executor instance + self._executor = LocalPythonExecutor( + additional_authorized_imports=self.allowed_imports + ) + self._user_variables.clear() + self._callable_tools.clear() + self._register_helpers() + + def inject_function(self, name: str, func: Callable[..., Any]) -> None: + """Inject a callable function into the namespace. + + Used for adding llm_query, llm_query_batched, FINAL, etc. + + Args: + name: Function name in namespace + func: The callable to inject + """ + # Add to callable tools and sync with executor + self._callable_tools[name] = func + self._user_variables.add(name) + self._sync_callable_tools() diff --git a/envs/repl_env/server/repl_environment.py b/envs/repl_env/server/repl_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3d65b2a491834af2552cd8f8f1cb446ae5d3a5 --- /dev/null +++ b/envs/repl_env/server/repl_environment.py @@ -0,0 +1,516 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +REPL Environment Implementation. + +A Python REPL environment for training language models on code execution tasks, +based on the Recursive Language Models (RLM) paradigm. + +References: +- RLM Paper: https://arxiv.org/abs/2512.24601 +- Prime Intellect Blog: https://www.primeintellect.ai/blog/rlm +- Alex Zhang Blog: https://alexzhang13.github.io/blog/2025/rlm/ +""" + +import os +import re +from collections.abc import Callable +from typing import Any, Dict, List, Optional +from uuid import uuid4 + +# Support both in-repo and standalone imports +try: + from openenv.core.env_server.interfaces import Environment + from openenv.core.env_server.types import EnvironmentMetadata +except ImportError: + from openenv.core.env_server.interfaces import Environment + from openenv.core.env_server.types import EnvironmentMetadata + +try: + from ..models import CodeBlockResult, REPLAction, REPLObservation, REPLState +except ImportError: + from models import CodeBlockResult, REPLAction, REPLObservation, REPLState + +try: + from .python_executor import PythonExecutor +except ImportError: + from python_executor import PythonExecutor + + +class REPLEnvironment(Environment): + """ + A REPL environment for training language models to use code execution. + + Based on the Recursive Language Models (RLM) paradigm, this environment allows + language models to: + - Execute Python code in a sandboxed REPL + - Work with large contexts loaded as variables + - Finalize answers via FINAL(), FINAL_VAR(), or answer dict pattern + - Optionally make recursive LLM calls via llm_query() / llm_query_batched() + + Supports two finalization patterns: + 1. RLM-style: print('FINAL(answer)') or print('FINAL_VAR(var_name)') + 2. 
Prime Intellect style: answer = {"content": "...", "ready": True} + + Example: + >>> env = REPLEnvironment(context="Hello World", task_prompt="Count chars") + >>> obs = env.reset() + >>> print(obs.context_preview) # "Hello World" + >>> + >>> obs = env.step(REPLAction(code="result = len(context)")) + >>> print(obs.result.success) # True + >>> print(obs.available_variables) # ["context", "result", "answer"] + >>> + >>> obs = env.step(REPLAction(code="print(f'FINAL({result})')")) + >>> print(obs.done) # True + >>> print(obs.metadata["final_answer"]) # "11" + """ + + SUPPORTS_CONCURRENT_SESSIONS = True + + def __init__( + self, + context: Optional[str] = None, + task_prompt: Optional[str] = None, + max_iterations: int = 30, + max_output_length: int = 8192, + context_preview_length: int = 500, + reward_on_success: float = 1.0, + reward_on_iteration: float = 0.0, + reward_on_failure: float = -0.1, + reward_on_error: float = -0.05, + llm_query_fn: Optional[Callable[[str], str]] = None, + llm_batch_fn: Optional[Callable[[List[str]], List[str]]] = None, + ): + """Initialize the REPL environment. + + Args: + context: Initial context to load (can also be set via REPL_CONTEXT env var) + task_prompt: Task description (can also be set via REPL_TASK_PROMPT env var) + max_iterations: Maximum steps per episode (default 30, env var REPL_MAX_ITERATIONS) + max_output_length: Max chars for stdout/stderr per turn (default 8192) + context_preview_length: Chars to show in context preview (default 500) + reward_on_success: Reward when final answer is submitted (default 1.0) + reward_on_iteration: Reward per iteration step (default 0.0) + reward_on_failure: Reward when max iterations reached (default -0.1) + reward_on_error: Reward when code execution fails (default -0.05) + llm_query_fn: Optional function for llm_query() support + llm_batch_fn: Optional function for llm_query_batched() support + """ + self.initial_context = context or os.environ.get("REPL_CONTEXT", "") + self.initial_task_prompt = task_prompt or os.environ.get("REPL_TASK_PROMPT", "") + self.max_iterations = int(os.environ.get("REPL_MAX_ITERATIONS", max_iterations)) + self.max_output_length = max_output_length + self.context_preview_length = context_preview_length + + # Reward configuration + self.reward_on_success = reward_on_success + self.reward_on_iteration = reward_on_iteration + self.reward_on_failure = reward_on_failure + self.reward_on_error = reward_on_error + + # Optional LLM functions for recursive calls + self.llm_query_fn = llm_query_fn + self.llm_batch_fn = llm_batch_fn + + # State (initialized on reset) + self._state: Optional[REPLState] = None + self._executor: Optional[PythonExecutor] = None + + def _create_llm_functions( + self, + hf_token: str, + llm_model: Optional[str] = None, + ) -> None: + """Create LLM functions dynamically using client-provided token. + + This allows clients to use their own HF token instead of the server's. + + Security: The token is used only to initialize the InferenceClient + and is NOT stored in state, logged, or persisted anywhere. 
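+
+        Concurrency: the generated llm_query_batched fans prompts out over a
+        thread pool capped at 8 workers (max_workers below), so batched calls
+        run concurrently rather than one at a time.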
+ + Args: + hf_token: HuggingFace API token (not logged or persisted) + llm_model: Model to use (default: Qwen/Qwen3-Coder-480B-A35B-Instruct) + """ + from concurrent.futures import as_completed, ThreadPoolExecutor + + try: + from huggingface_hub import InferenceClient + except ImportError: + # huggingface_hub not installed, skip LLM functions + return + + model = llm_model or os.environ.get( + "LLM_MODEL", "Qwen/Qwen3-Coder-480B-A35B-Instruct" + ) + client = InferenceClient(model=model, token=hf_token) + + def llm_query(prompt: str) -> str: + """Query the LLM with a prompt and return the response.""" + try: + messages = [{"role": "user", "content": prompt}] + response = client.chat_completion( + messages=messages, + max_tokens=2048, + temperature=0.7, + ) + return response.choices[0].message.content or "" + except Exception as e: + return f"Error calling LLM: {e}" + + def llm_query_batched(prompts: List[str]) -> List[str]: + """Query the LLM with multiple prompts in parallel.""" + if not prompts: + return [] + + max_workers = min(len(prompts), 8) + results: List[str] = [""] * len(prompts) + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_idx = { + executor.submit(llm_query, prompt): idx + for idx, prompt in enumerate(prompts) + } + for future in as_completed(future_to_idx): + idx = future_to_idx[future] + try: + results[idx] = future.result() + except Exception as e: + results[idx] = f"Error: {e}" + + return results + + self.llm_query_fn = llm_query + self.llm_batch_fn = llm_query_batched + + def reset( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + context: Optional[str] = None, + task_prompt: Optional[str] = None, + hf_token: Optional[str] = None, + llm_model: Optional[str] = None, + **kwargs: Any, + ) -> REPLObservation: + """Reset the environment with optional new context. + + Args: + seed: Optional random seed (for reproducibility) + episode_id: Optional episode identifier (if not provided, one is generated) + context: Context to load (overrides initial_context) + task_prompt: Task description (overrides initial_task_prompt) + hf_token: Optional HuggingFace token for llm_query/llm_query_batched. + If provided, creates LLM functions using this token. + Security: Token is NOT stored in state or logged. 
+ llm_model: Optional model name for LLM functions (default: from env or Qwen3-Coder) + **kwargs: Additional reset parameters + + Returns: + Initial REPLObservation with environment ready message + """ + effective_context = context or self.initial_context + effective_task_prompt = task_prompt or self.initial_task_prompt + + # Create LLM functions if not already provided at init + # Priority: client hf_token > server HF_TOKEN env var + if not self.llm_query_fn: + effective_token = hf_token or os.environ.get("HF_TOKEN") + if effective_token: + self._create_llm_functions(effective_token, llm_model) + + # Initialize state + self._state = REPLState( + episode_id=episode_id or str(uuid4()), + step_count=0, + context=effective_context, + task_prompt=effective_task_prompt, + iteration=0, + max_iterations=self.max_iterations, + namespace_keys=[], + final_answer=None, + total_execution_time=0.0, + ) + + # Initialize executor + self._executor = PythonExecutor(max_output_length=self.max_output_length) + + # Initialize answer dict (Prime Intellect style) + self._executor.set_variable("answer", {"content": "", "ready": False}) + + # Load context into namespace if provided + if effective_context: + self._executor.set_context(effective_context) + + # Inject LLM functions if provided + # Names: llm_query (single), llm_query_batched (official RLM), llm_batch (alias) + if self.llm_query_fn: + self._executor.inject_function("llm_query", self.llm_query_fn) + if self.llm_batch_fn: + self._executor.inject_function( + "llm_query_batched", self.llm_batch_fn + ) # Official name + self._executor.inject_function("llm_batch", self.llm_batch_fn) # Alias + + # Inject FINAL helper function so both FINAL(x) and print(f'FINAL({x})') work + # Returns the FINAL pattern as a string so it appears in stdout for detection + def final_helper(value): + """Helper that returns FINAL(value) string for detection.""" + return f"FINAL({value})" + + self._executor.inject_function("FINAL", final_helper) + + # Inject FINAL_VAR helper that looks up variable and returns FINAL(value) + # This matches official RLM behavior - strips quotes from var_name and looks up in namespace + executor = self._executor # Capture for closure + + def final_var_helper(var_name: str): + """Look up variable by name and return FINAL(value) for detection.""" + # Strip quotes if present (handles both FINAL_VAR("x") and FINAL_VAR(x)) + var_name_clean = str(var_name).strip().strip("\"'") + # Look up variable in executor namespace + value = executor.get_variable(var_name_clean) + if value is not None: + return f"FINAL({value})" + return f"FINAL_VAR({var_name_clean})" # Fallback for regex detection + + self._executor.inject_function("FINAL_VAR", final_var_helper) + + # Update namespace keys + self._state.namespace_keys = self._executor.list_variables() + + # Build initial message + message_parts = ["REPL environment initialized."] + if effective_context: + message_parts.append( + f"Context loaded ({len(effective_context)} chars). Use 'context' variable to access it." + ) + if effective_task_prompt: + message_parts.append(f"Task: {effective_task_prompt}") + message_parts.append( + "Use answer['content'] to store your answer, and set answer['ready'] = True when done." 
+ ) + + return REPLObservation( + result=CodeBlockResult( + stdout="\n".join(message_parts), + stderr="", + locals_snapshot={}, + execution_time=0.0, + success=True, + exception=None, + ), + context_preview=( + effective_context[: self.context_preview_length] + if effective_context + else None + ), + context_length=len(effective_context) if effective_context else 0, + available_variables=self._state.namespace_keys, + iteration=0, + max_iterations=self.max_iterations, + done=False, + reward=0.0, + metadata={ + "task_prompt": effective_task_prompt, + "message": "Environment ready.", + }, + ) + + def step( + self, + action: REPLAction, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> REPLObservation: + """Execute code and return observation. + + Args: + action: REPLAction containing code to execute + timeout_s: Optional timeout in seconds (not currently used) + **kwargs: Additional step parameters + + Returns: + REPLObservation with execution results + """ + if self._state is None or self._executor is None: + raise RuntimeError("Environment not initialized. Call reset() first.") + + self._state.step_count += 1 + self._state.iteration += 1 + + # Check if agent explicitly signals final answer + if action.is_final: + self._state.final_answer = action.final_answer or "" + return self._create_final_observation( + success=True, + message="Final answer submitted.", + reward=self.reward_on_success, + ) + + # Check iteration limit + if self._state.iteration >= self.max_iterations: + # Check if there's a partial answer in the answer dict + answer_var = self._executor.get_variable("answer") + if isinstance(answer_var, dict) and answer_var.get("content"): + self._state.final_answer = str(answer_var.get("content", "")) + return self._create_final_observation( + success=False, + message=f"Maximum iterations ({self.max_iterations}) reached.", + reward=self.reward_on_failure, + ) + + # Execute code + result = self._executor.execute(action.code) + self._state.total_execution_time += result["execution_time"] + self._state.namespace_keys = self._executor.list_variables() + + # Calculate reward + reward = self.reward_on_iteration + if not result["success"]: + reward += self.reward_on_error + + # Check for final answer patterns + final_answer = self._extract_final_answer(result["stdout"]) + done = final_answer is not None + + if done: + self._state.final_answer = final_answer + reward = self.reward_on_success + + return REPLObservation( + result=CodeBlockResult( + stdout=result["stdout"], + stderr=result["stderr"], + locals_snapshot=result["locals_snapshot"], + execution_time=result["execution_time"], + success=result["success"], + exception=result["exception"], + ), + context_preview=( + self._state.context[: self.context_preview_length] + if self._state.context + else None + ), + context_length=len(self._state.context) if self._state.context else 0, + available_variables=self._state.namespace_keys, + iteration=self._state.iteration, + max_iterations=self.max_iterations, + done=done, + reward=reward, + metadata={ + "task_prompt": self._state.task_prompt, + "final_answer": final_answer, + "execution_time": result["execution_time"], + }, + ) + + def _extract_final_answer(self, stdout: str) -> Optional[str]: + """Extract final answer from output. + + Supports multiple patterns: + 1. RLM-style: FINAL(answer) in stdout + 2. RLM-style: FINAL_VAR(variable_name) in stdout + 3. 
Prime Intellect style: answer = {"content": "...", "ready": True} in namespace + + Args: + stdout: Standard output from code execution + + Returns: + Final answer string or None if not found + """ + # Pattern 1: RLM-style FINAL(answer) + final_match = re.search(r"FINAL\((.*?)\)", stdout, re.DOTALL) + if final_match: + return final_match.group(1).strip() + + # Pattern 2: RLM-style FINAL_VAR(variable_name) + final_var_match = re.search(r"FINAL_VAR\((\w+)\)", stdout) + if final_var_match and self._executor: + var_name = final_var_match.group(1) + value = self._executor.get_variable(var_name) + if value is not None: + return str(value) + + # Pattern 3: Prime Intellect style answer dict + if self._executor: + answer_var = self._executor.get_variable("answer") + if isinstance(answer_var, dict): + if answer_var.get("ready", False): + return str(answer_var.get("content", "")) + + return None + + def _create_final_observation( + self, success: bool, message: str, reward: float + ) -> REPLObservation: + """Create observation for episode termination. + + Args: + success: Whether the episode ended successfully + message: Termination message + reward: Final reward value + + Returns: + Final REPLObservation with done=True + """ + return REPLObservation( + result=CodeBlockResult( + stdout=message, + stderr="", + locals_snapshot={}, + execution_time=0.0, + success=success, + exception=None, + ), + context_preview=None, + context_length=0, + available_variables=[], + iteration=self._state.iteration if self._state else 0, + max_iterations=self.max_iterations, + done=True, + reward=reward, + metadata={ + "final_answer": self._state.final_answer if self._state else None, + "total_execution_time": ( + self._state.total_execution_time if self._state else 0 + ), + "total_iterations": self._state.iteration if self._state else 0, + }, + ) + + @property + def state(self) -> REPLState: + """Get the current environment state. + + Returns: + Current REPLState + + Raises: + RuntimeError: If environment not initialized + """ + if self._state is None: + raise RuntimeError("Environment not initialized. Call reset() first.") + return self._state + + def close(self) -> None: + """Cleanup resources.""" + self._executor = None + self._state = None + + def get_metadata(self) -> EnvironmentMetadata: + """Get environment metadata. 
+ + Returns: + EnvironmentMetadata with environment info + """ + return EnvironmentMetadata( + name="repl_env", + description="Python REPL environment for RLM-style code execution", + version="0.1.0", + ) diff --git a/models.py b/models.py index 3c93f91f46ced37c83d9b92cd6fa19e470082500..c2a8651de625a19b2b77753cf756ff5ffa721f82 100644 --- a/models.py +++ b/models.py @@ -48,9 +48,7 @@ class REPLAction(Action): class CodeBlockResult(BaseModel): """Result of executing a single code block.""" - stdout: str = Field( - default="", description="Standard output from execution" - ) + stdout: str = Field(default="", description="Standard output from execution") stderr: str = Field(default="", description="Standard error from execution") locals_snapshot: Dict[str, str] = Field( default_factory=dict, @@ -59,9 +57,7 @@ class CodeBlockResult(BaseModel): execution_time: float = Field( default=0.0, ge=0, description="Execution time in seconds" ) - success: bool = Field( - default=True, description="Whether execution succeeded" - ) + success: bool = Field(default=True, description="Whether execution succeeded") exception: Optional[str] = Field( default=None, description="Exception message if execution failed" ) @@ -84,9 +80,7 @@ class REPLObservation(Observation): default_factory=list, description="List of variable names available in the namespace", ) - iteration: int = Field( - default=0, ge=0, description="Current iteration number" - ) + iteration: int = Field(default=0, ge=0, description="Current iteration number") max_iterations: int = Field( default=30, ge=1, description="Maximum allowed iterations" ) @@ -101,9 +95,7 @@ class REPLState(State): task_prompt: Optional[str] = Field( default=None, description="The task description to solve" ) - iteration: int = Field( - default=0, ge=0, description="Current iteration number" - ) + iteration: int = Field(default=0, ge=0, description="Current iteration number") max_iterations: int = Field( default=30, ge=1, description="Max iterations before termination" ) diff --git a/prompts.py b/prompts.py index ecd846faf812b8f2112aa3d3fee363a7f598fd28..7c3b7feb9094c486d02108692b004c48ce973491 100644 --- a/prompts.py +++ b/prompts.py @@ -358,19 +358,32 @@ def extract_code_blocks(text: str, language: str = "python") -> List[str]: return all_matches -def format_observation(obs) -> str: - """Format a REPLObservation into observation text for the LLM. +def format_observations(observations) -> str: + """Format REPL observations into observation text for the LLM. Args: - obs: The REPLObservation from env.step() + observations: List of REPL observations from env.step() Returns: Formatted observation string """ - output = obs.result.stdout.strip() if obs.result.stdout else "(no output)" - - if obs.result.success: - return f"Code output:\n{output}" - else: - error = obs.result.stderr or obs.result.exception or "Unknown error" - return f"Code output:\n{output}\n\nERROR: {error}\nFix the error. Remember: 'context' is already defined." + formatted = [] + for i, observation in enumerate(observations, 1): + output = ( + observation.result.stdout.strip() + if observation.result.stdout + else "(no output)" + ) + if observation.result.success: + formatted.append(f"Code block {i} output:\n{output}") + else: + error = ( + observation.result.stderr + or observation.result.exception + or "Unknown error" + ) + formatted.append( + f"Code block {i} output:\n{output}\n\nERROR: {error}\n" + f"Fix the error in code block {i}. Remember: 'context' is already defined." 
+ ) + return "\n\n".join(formatted) diff --git a/pyproject.toml b/pyproject.toml index 7d9f2a16cf379fbc5a54fb956f4c6212beba18dc..b1073b5bd91840f63f4a8ce84a852b57ef442701 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ description = "Recursive Language Model REPL Environment for OpenEnv" requires-python = ">=3.10" dependencies = [ # Core OpenEnv dependencies (required for server functionality) - "openenv-core @ git+https://github.com/meta-pytorch/OpenEnv.git@v0.2.1", + "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git@main", "fastapi>=0.115.0", "pydantic>=2.0.0", "uvicorn>=0.24.0", diff --git a/server/Dockerfile b/server/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..756d65c6cf58644f550b5b257699698ccb26440e --- /dev/null +++ b/server/Dockerfile @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Multi-stage build using openenv-base +# This Dockerfile is flexible and works for both: +# - In-repo environments (with local src/core) +# - Standalone environments (with openenv from pip) +# The build script (openenv build) handles context detection and sets appropriate build args. + +ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest +FROM ${BASE_IMAGE} AS builder + +WORKDIR /app + +# Build argument to control whether we're building standalone or in-repo +ARG BUILD_MODE=in-repo +ARG ENV_NAME=repl_env + +# Copy environment code (always at root of build context) +COPY . /app/env + +# For in-repo builds, openenv-core is already in the pyproject.toml dependencies +# For standalone builds, openenv-core will be installed from pip via pyproject.toml +WORKDIR /app/env + +# Ensure uv is available (for local builds where base image lacks it) +RUN if ! 
command -v uv >/dev/null 2>&1; then \ + curl -LsSf https://astral.sh/uv/install.sh | sh && \ + mv /root/.local/bin/uv /usr/local/bin/uv && \ + mv /root/.local/bin/uvx /usr/local/bin/uvx; \ + fi + +# Install git for building from git repos (build-time only) +RUN apt-get update && apt-get install -y --no-install-recommends \ + git \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies using uv sync +# If uv.lock exists, use it; otherwise resolve on the fly +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ -f uv.lock ]; then \ + uv sync --frozen --no-install-project --no-editable; \ + else \ + uv sync --no-install-project --no-editable; \ + fi + +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ -f uv.lock ]; then \ + uv sync --frozen --no-editable; \ + else \ + uv sync --no-editable; \ + fi + +# Final runtime stage +FROM ${BASE_IMAGE} + +WORKDIR /app + +# Copy the virtual environment from builder +COPY --from=builder /app/env/.venv /app/.venv + +# Copy the environment code +COPY --from=builder /app/env /app/env + +# Set PATH to use the virtual environment +ENV PATH="/app/.venv/bin:$PATH" + +# Set PYTHONPATH so imports work correctly +ENV PYTHONPATH="/app/env:$PYTHONPATH" + +# Health check using Python (more portable than curl/wget) +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1 + +# Run the FastAPI server +# The module path is constructed to work with the /app/env structure +CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"] diff --git a/server/__init__.py b/server/__init__.py index 7d573bc1e0459952cad57645cab1f666a0640c18..34a07139a0671dfa10d32e5b9655c9135291912a 100644 --- a/server/__init__.py +++ b/server/__init__.py @@ -10,8 +10,8 @@ REPL Environment Server Components. This module contains the server-side implementation of the REPL environment. 
""" -from .repl_environment import REPLEnvironment from .python_executor import PythonExecutor +from .repl_environment import REPLEnvironment __all__ = [ "REPLEnvironment", diff --git a/server/app.py b/server/app.py index 2616582400adfbbb8e75df3f556ec38cd325294c..dc1057dea982750b02115b7e872c74ce59b125d8 100644 --- a/server/app.py +++ b/server/app.py @@ -41,12 +41,14 @@ import os try: # In-repo imports (when running from OpenEnv repository) from openenv.core.env_server.http_server import create_app + from ..models import REPLAction, REPLObservation from .repl_environment import REPLEnvironment except ImportError: + from models import REPLAction, REPLObservation + # Standalone imports (when environment is standalone with openenv from pip) from openenv.core.env_server.http_server import create_app - from models import REPLAction, REPLObservation from server.repl_environment import REPLEnvironment diff --git a/server/python_executor.py b/server/python_executor.py index 6affa11fd4eb5d3644d960471a3df6f10a417485..1ef71e46e81d478860ea2614fd9336e8ecfd4010 100644 --- a/server/python_executor.py +++ b/server/python_executor.py @@ -108,9 +108,7 @@ class PythonExecutor: """Register helper functions with the executor.""" helpers = { "format_exc": traceback.format_exc, - "safe_json_dumps": lambda obj: json.dumps( - obj, default=lambda o: repr(o) - ), + "safe_json_dumps": lambda obj: json.dumps(obj, default=lambda o: repr(o)), } # Register helpers as callable tools for name, func in helpers.items(): @@ -257,30 +255,21 @@ class PythonExecutor: success = False exception_msg = str(ex) except Exception: - logger.debug( - "Failed to read exec_result.exception", exc_info=True - ) + logger.debug("Failed to read exec_result.exception", exc_info=True) # Determine success from exit_code if available try: if hasattr(exec_result, "exit_code"): - if ( - exec_result.exit_code is not None - and exec_result.exit_code != 0 - ): + if exec_result.exit_code is not None and exec_result.exit_code != 0: success = False elif hasattr(exec_result, "success"): success = bool(exec_result.success) except Exception: - logger.debug( - "Failed to determine exec_result exit code", exc_info=True - ) + logger.debug("Failed to determine exec_result exit code", exc_info=True) except Exception as e: success = False - exception_msg = ( - f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}" - ) + exception_msg = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}" stderr_parts.append(exception_msg) execution_time = time.time() - start_time diff --git a/server/repl_environment.py b/server/repl_environment.py index 7c0d52159ecf8e4dd31b9de22de76145839014c2..9c3d65b2a491834af2552cd8f8f1cb446ae5d3a5 100644 --- a/server/repl_environment.py +++ b/server/repl_environment.py @@ -31,9 +31,9 @@ except ImportError: from openenv.core.env_server.types import EnvironmentMetadata try: - from ..models import REPLAction, REPLObservation, REPLState, CodeBlockResult + from ..models import CodeBlockResult, REPLAction, REPLObservation, REPLState except ImportError: - from models import REPLAction, REPLObservation, REPLState, CodeBlockResult + from models import CodeBlockResult, REPLAction, REPLObservation, REPLState try: from .python_executor import PythonExecutor @@ -102,12 +102,8 @@ class REPLEnvironment(Environment): llm_batch_fn: Optional function for llm_query_batched() support """ self.initial_context = context or os.environ.get("REPL_CONTEXT", "") - self.initial_task_prompt = task_prompt or os.environ.get( - "REPL_TASK_PROMPT", "" - ) - 
self.max_iterations = int( - os.environ.get("REPL_MAX_ITERATIONS", max_iterations) - ) + self.initial_task_prompt = task_prompt or os.environ.get("REPL_TASK_PROMPT", "") + self.max_iterations = int(os.environ.get("REPL_MAX_ITERATIONS", max_iterations)) self.max_output_length = max_output_length self.context_preview_length = context_preview_length @@ -141,7 +137,7 @@ class REPLEnvironment(Environment): hf_token: HuggingFace API token (not logged or persisted) llm_model: Model to use (default: Qwen/Qwen3-Coder-480B-A35B-Instruct) """ - from concurrent.futures import ThreadPoolExecutor, as_completed + from concurrent.futures import as_completed, ThreadPoolExecutor try: from huggingface_hub import InferenceClient @@ -242,9 +238,7 @@ class REPLEnvironment(Environment): ) # Initialize executor - self._executor = PythonExecutor( - max_output_length=self.max_output_length - ) + self._executor = PythonExecutor(max_output_length=self.max_output_length) # Initialize answer dict (Prime Intellect style) self._executor.set_variable("answer", {"content": "", "ready": False}) @@ -261,9 +255,7 @@ class REPLEnvironment(Environment): self._executor.inject_function( "llm_query_batched", self.llm_batch_fn ) # Official name - self._executor.inject_function( - "llm_batch", self.llm_batch_fn - ) # Alias + self._executor.inject_function("llm_batch", self.llm_batch_fn) # Alias # Inject FINAL helper function so both FINAL(x) and print(f'FINAL({x})') work # Returns the FINAL pattern as a string so it appears in stdout for detection @@ -285,9 +277,7 @@ class REPLEnvironment(Environment): value = executor.get_variable(var_name_clean) if value is not None: return f"FINAL({value})" - return ( - f"FINAL_VAR({var_name_clean})" # Fallback for regex detection - ) + return f"FINAL_VAR({var_name_clean})" # Fallback for regex detection self._executor.inject_function("FINAL_VAR", final_var_helper) @@ -349,9 +339,7 @@ class REPLEnvironment(Environment): REPLObservation with execution results """ if self._state is None or self._executor is None: - raise RuntimeError( - "Environment not initialized. Call reset() first." - ) + raise RuntimeError("Environment not initialized. Call reset() first.") self._state.step_count += 1 self._state.iteration += 1 @@ -409,9 +397,7 @@ class REPLEnvironment(Environment): if self._state.context else None ), - context_length=len(self._state.context) - if self._state.context - else 0, + context_length=len(self._state.context) if self._state.context else 0, available_variables=self._state.namespace_keys, iteration=self._state.iteration, max_iterations=self.max_iterations, @@ -490,9 +476,7 @@ class REPLEnvironment(Environment): done=True, reward=reward, metadata={ - "final_answer": self._state.final_answer - if self._state - else None, + "final_answer": self._state.final_answer if self._state else None, "total_execution_time": ( self._state.total_execution_time if self._state else 0 ), @@ -511,9 +495,7 @@ class REPLEnvironment(Environment): RuntimeError: If environment not initialized """ if self._state is None: - raise RuntimeError( - "Environment not initialized. Call reset() first." - ) + raise RuntimeError("Environment not initialized. Call reset() first.") return self._state def close(self) -> None: diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..47e552876c5221e751fb28e7c7b583b02ee506bb --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""OpenEnv: Standardized agentic execution environments."""
diff --git a/src/core/README.md b/src/core/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5d153f1e4f72ce6c7b4e814c78c74e0e734c462b
--- /dev/null
+++ b/src/core/README.md
@@ -0,0 +1,212 @@
+# OpenEnv: Agentic Execution Environments
+
+An end-to-end framework for creating, deploying, and using isolated execution environments for agentic RL training. OpenEnv provides a standard for interacting with these environments through simple Gymnasium-style APIs - step(), reset(), state() - which can be called directly from RL training loops.
+
+Beyond serving researchers and RL framework authors, OpenEnv also gives environment creators tools for building richer environments and making them available over familiar protocols like HTTP, packaged with canonical technologies like Docker. Environments built with the OpenEnv framework are isolated, secure, and easy to deploy and use.
+
+
+## Overview
+`openenv.core` provides the foundational building blocks for creating and interacting with containerized environments over HTTP. It enables you to build agent environments that can be deployed as Docker containers and accessed via a simple HTTP API.
+
+> ⚠️ **Early Development Warning** OpenEnv is currently in an experimental
+> stage. You should expect bugs, incomplete features, and APIs that may change
+> in future versions. The project welcomes bugfixes, but to make sure things are
+> well coordinated you should discuss any significant change before starting the
+> work. It's recommended that you signal your intention to contribute in the
+> issue tracker, either by filing a new issue or by claiming an existing one.
+
+
+# OpenEnv Core
+
+Core components for OpenEnv - a framework for building HTTP-based agentic environments.
+
+## Features
+
+- **EnvClient**: Async-first client for interacting with remote environments
+- **SyncEnvClient**: Synchronous wrapper via `.sync()` for sync codebases
+- **HTTPEnvServer**: FastAPI-based server wrapper for exposing environments over HTTP/WebSocket
+- **Container Providers**: Pluggable architecture for running containers (Docker, Kubernetes, etc.)
+- **Type System**: Strongly-typed Action/Observation/State interfaces
+- **Web Interface**: Optional web UI for interacting with environments
+
+## Installation
+
+```bash
+pip install "openenv[core]"
+```
+
+## Quick Start
+
+### Creating an Environment Client
+
+EnvClient is **async by default**.
Use `async with` and `await` for all operations: + +```python +import asyncio +from openenv.core import EnvClient, StepResult +from dataclasses import dataclass +from typing import Any + +@dataclass +class MyAction: + text: str + +@dataclass +class MyObservation: + response: str + +class MyEnvClient(EnvClient[MyAction, MyObservation, Any]): + def _step_payload(self, action: MyAction) -> dict: + return {"text": action.text} + + def _parse_result(self, payload: dict) -> StepResult[MyObservation]: + obs_data = payload["observation"] + return StepResult( + observation=MyObservation(**obs_data), + reward=payload.get("reward"), + done=payload.get("done", False) + ) + + def _parse_state(self, payload: dict) -> Any: + return payload + +# Async usage (recommended) +async def main(): + client = await MyEnvClient.from_docker_image("my-env:latest") + async with client: + result = await client.reset() + step_result = await client.step(MyAction(text="hello")) + +asyncio.run(main()) + +# Sync usage (via .sync() wrapper) +with MyEnvClient(base_url="http://localhost:8000").sync() as client: + result = client.reset() + step_result = client.step(MyAction(text="hello")) +``` + +### Creating an Environment Server + +```python +from openenv.core.env_server import Environment, HTTPEnvServer, create_app +from dataclasses import dataclass + +@dataclass +class MyAction: + text: str + +@dataclass +class MyObservation: + response: str + reward: float = 0.0 + done: bool = False + +class MyEnvironment(Environment): + def reset(self) -> MyObservation: + return MyObservation(response="Ready") + + def step(self, action: MyAction) -> MyObservation: + return MyObservation( + response=f"Echo: {action.text}", + reward=1.0, + done=False + ) + +# Create FastAPI app +env = MyEnvironment() +app = create_app(env, MyAction, MyObservation) + +# Run with: uvicorn module:app --host 0.0.0.0 --port 8000 +``` + +## Container Providers + +OpenEnv Core supports multiple container providers: + +### Local Docker Provider + +```python +from openenv.core.containers.runtime import LocalDockerProvider + +provider = LocalDockerProvider() +base_url = provider.start_container("my-env:latest") +provider.wait_for_ready(base_url) +# Use environment... +provider.stop_container() +``` + +### Kubernetes Provider (Coming Soon) + +```python +from openenv.core.containers.runtime import KubernetesProvider + +provider = KubernetesProvider(namespace="envs") +base_url = provider.start_container("my-env:latest") +# Use environment... +provider.stop_container() +``` + + +## API Reference + +### EnvClient + +Async base class for environment clients. Key methods: + +- `async connect()`: Establish WebSocket connection +- `async reset(**kwargs)`: Reset environment +- `async step(action)`: Execute action +- `async state()`: Get current state +- `async close()`: Close connection and cleanup +- `sync()`: Return a SyncEnvClient wrapper for synchronous usage + +Abstract methods to implement: +- `_step_payload(action)`: Convert action to JSON +- `_parse_result(payload)`: Parse response to StepResult +- `_parse_state(payload)`: Parse state response + +### SyncEnvClient + +Synchronous wrapper around EnvClient. 
Use `client.sync()` to get one: + +```python +sync_client = async_client.sync() +with sync_client: + result = sync_client.reset() + result = sync_client.step(action) +``` + +### HTTPEnvServer + +Server wrapper with these methods: + +- `register_routes(app)`: Register endpoints on FastAPI app +- `_deserialize_action(data)`: Convert JSON to Action +- `_serialize_observation(obs)`: Convert Observation to JSON + +### Environment Interface + +Base interface for environment implementations: + +- `reset()`: Reset environment and return initial observation +- `step(action)`: Execute action and return observation +- `state`: Property returning current environment state + +## License + +This project is licensed under the BSD-3-Clause License - see the LICENSE file for details. + +## Contributing + +Contributions are welcome! Please see the main OpenEnv repository for contribution guidelines. + +## Links + +- **Homepage**: https://github.com/meta-pytorch/OpenEnv +- **Documentation**: https://github.com/meta-pytorch/OpenEnv/blob/main/README.md +- **Bug Tracker**: https://github.com/meta-pytorch/OpenEnv/issues diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..96065d6a80463e2fe599de7728243fc2adad7135 --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Core components for agentic environments.""" + +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING + +from . import env_server +from .env_server import * # noqa: F403 + +if TYPE_CHECKING: + from .env_client import EnvClient + from .generic_client import GenericAction, GenericEnvClient + from .llm_client import ( + AnthropicClient, + create_llm_client, + LLMClient, + LLMResponse, + OpenAIClient, + ToolCall, + ) + from .mcp_client import MCPClientBase, MCPToolClient + from .sync_client import SyncEnvClient + +__all__ = [ + "EnvClient", + "SyncEnvClient", + "GenericEnvClient", + "GenericAction", + "MCPClientBase", + "MCPToolClient", + "AnthropicClient", + "LLMClient", + "LLMResponse", + "OpenAIClient", + "ToolCall", + "create_llm_client", +] + env_server.__all__ # type: ignore + + +_LAZY_ATTRS = { + "EnvClient": (".env_client", "EnvClient"), + "SyncEnvClient": (".sync_client", "SyncEnvClient"), + "GenericEnvClient": (".generic_client", "GenericEnvClient"), + "GenericAction": (".generic_client", "GenericAction"), + "MCPClientBase": (".mcp_client", "MCPClientBase"), + "MCPToolClient": (".mcp_client", "MCPToolClient"), + "AnthropicClient": (".llm_client", "AnthropicClient"), + "LLMClient": (".llm_client", "LLMClient"), + "LLMResponse": (".llm_client", "LLMResponse"), + "OpenAIClient": (".llm_client", "OpenAIClient"), + "ToolCall": (".llm_client", "ToolCall"), + "create_llm_client": (".llm_client", "create_llm_client"), +} + + +def __getattr__(name: str): + if name in _LAZY_ATTRS: + module_path, attr_name = _LAZY_ATTRS[name] + module = import_module(module_path, __name__) + value = getattr(module, attr_name) + globals()[name] = value + return value + + try: + value = getattr(env_server, name) + except AttributeError as exc: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from exc + + globals()[name] = value + return value + + +def __dir__() -> list[str]: + return 
sorted(set(globals().keys()) | set(__all__)) diff --git a/src/core/client_types.py b/src/core/client_types.py new file mode 100644 index 0000000000000000000000000000000000000000..c7501c656b66a780f29bf23309aaf00fab8df432 --- /dev/null +++ b/src/core/client_types.py @@ -0,0 +1,23 @@ +# Type definitions for EnvTorch +from dataclasses import dataclass +from typing import Generic, Optional, TypeVar + +# Generic type for observations +ObsT = TypeVar("ObsT") +StateT = TypeVar("StateT") + + +@dataclass +class StepResult(Generic[ObsT]): + """ + Represents the result of one environment step. + + Attributes: + observation: The environment's observation after the action. + reward: Scalar reward for this step (optional). + done: Whether the episode is finished. + """ + + observation: ObsT + reward: Optional[float] = None + done: bool = False diff --git a/src/core/containers/__init__.py b/src/core/containers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38e67ef3cd60bf13a26ef7c8bf23986c3eb5990e --- /dev/null +++ b/src/core/containers/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Container management for environment servers.""" diff --git a/src/core/containers/images/Dockerfile b/src/core/containers/images/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..97bb1cf5e2ce0e58c82496cced3e58976baead4c --- /dev/null +++ b/src/core/containers/images/Dockerfile @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# +# OpenEnv Base Image +# +# This is the standard base image for all OpenEnv environment servers. +# It includes the minimal dependencies needed to run HTTP environment servers +# and uv for fast dependency management. +# +# Build from repo root: docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile . +# Tag: docker tag openenv-base:latest openenv-base:0.2.0 +# + +FROM ghcr.io/astral-sh/uv:0.5.27-python3.11-bookworm-slim AS builder + +# Set working directory +WORKDIR /app + +# Copy core pyproject.toml and lockfile for dependency installation +COPY pyproject.toml uv.lock* ./ + +# Install core dependencies using uv with cache mount +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r pyproject.toml + +# Final runtime stage +FROM python:3.11-slim + +# Set metadata +LABEL maintainer="OpenEnv Team" +LABEL description="Base image for OpenEnv based environment servers with uv" +LABEL version="0.2.0" + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +# Copy uv from builder +COPY --from=builder /usr/local/bin/uv /usr/local/bin/uvx /usr/local/bin/ + +# Copy installed Python packages from builder +COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages + +# Copy console scripts installed by pip (uvicorn, fastapi, etc.) 
+COPY --from=builder /usr/local/bin/uvicorn /usr/local/bin/fastapi /usr/local/bin/
+
+# Set working directory
+WORKDIR /app
+
+# Default environment variables
+ENV PYTHONPATH=/app/src
+ENV PYTHONUNBUFFERED=1
+ENV UV_SYSTEM_PYTHON=1
+
+# Default expose port (can be overridden)
+EXPOSE 8000
+
+# Note: CMD should be specified in child Dockerfiles
diff --git a/src/core/containers/images/README.md b/src/core/containers/images/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..69c387909fc487bf4bebb2a18dced2185ecf477d
--- /dev/null
+++ b/src/core/containers/images/README.md
@@ -0,0 +1,92 @@
+# OpenEnv Base Image
+
+Standard base image for all OpenEnv environment servers.
+
+## What's Included
+
+| Layer | Size | Contents |
+|-------|------|----------|
+| python:3.11-slim | 200 MB | Base Python runtime |
+| + Dependencies | 100 MB | FastAPI, uvicorn, requests |
+| **Total** | **~300 MB** | Ready for environment servers |
+
+## Image Sizes
+
+```
+openenv-base:latest 300 MB (python + fastapi + uvicorn)
+```
+
+### Without Base Images (❌ Problem)
+```
+echo-env:latest 500 MB (python + fastapi + uvicorn + app)
+coding-env:latest 520 MB (python + fastapi + uvicorn + app + tools)
+another-env:latest 510 MB (python + fastapi + uvicorn + app)
+---
+Total: 1.5 GB (with lots of duplication)
+```
+
+### With Base Images (✅ Solution)
+```
+openenv-base:latest 300 MB (python + fastapi + uvicorn)
+echo-env:latest 50 MB (app only, uses base)
+coding-env:latest 70 MB (app + tools, uses base)
+another-env:latest 45 MB (app only, uses base)
+---
+Total: 465 MB (base shared, minimal duplication)
+```
+
+## Building the Base Image
+
+```bash
+# From project root
+docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
+```
+
+## Usage in Environment Dockerfiles
+
+Each environment Dockerfile should start with:
+
+```dockerfile
+FROM openenv-base:latest
+
+# Copy only environment-specific files
+COPY src/openenv/core/ /app/src/openenv/core/
+COPY envs/my_env/ /app/envs/my_env/
+
+# Run the server
+CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
+```
+
+## Base Image Contents
+
+- Python 3.11-slim
+- FastAPI >= 0.104.0
+- Uvicorn >= 0.24.0
+- Requests >= 2.25.0
+- curl (for health checks)
+
+## Example: Building Echo Environment
+
+```bash
+# Step 1: Build base image (do this once)
+docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
+
+# Step 2: Build echo environment (uses base)
+docker build -t echo-env:latest -f envs/echo_env/server/Dockerfile .
+
+# Step 3: Run echo environment
+docker run -p 8000:8000 echo-env:latest
+```
+
+## Updating the Base
+
+When dependencies need updating:
+
+1. Update `src/openenv/core/containers/images/Dockerfile`
+2. Rebuild base image
+3. Rebuild all environment images (they'll use new base)
+
+```bash
+# Update base
+docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
+
+# Rebuild environments (they automatically use new base)
+docker build -t echo-env:latest -f envs/echo_env/server/Dockerfile .
+```
diff --git a/src/core/containers/runtime/__init__.py b/src/core/containers/runtime/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd514dc2fb78007e4ee1bf1f2e9777864bc76b00
--- /dev/null
+++ b/src/core/containers/runtime/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Container runtime providers.""" + +from .providers import ( + ContainerProvider, + DockerSwarmProvider, + KubernetesProvider, + LocalDockerProvider, + RuntimeProvider, +) +from .uv_provider import UVProvider + +__all__ = [ + "ContainerProvider", + "DockerSwarmProvider", + "LocalDockerProvider", + "KubernetesProvider", + "RuntimeProvider", + "UVProvider", +] diff --git a/src/core/containers/runtime/daytona_provider.py b/src/core/containers/runtime/daytona_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..08c899fa3f16520dbe7cb8c0804e23250d97f605 --- /dev/null +++ b/src/core/containers/runtime/daytona_provider.py @@ -0,0 +1,572 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Daytona container provider for running OpenEnv environments in Daytona cloud sandboxes. + +Requires the ``daytona`` SDK: ``pip install daytona>=0.10`` +""" + +from __future__ import annotations + +import json +import os +import shlex +import time +from typing import Any, Callable, Dict, Optional + +import yaml + +from .providers import ContainerProvider + + +class DaytonaProvider(ContainerProvider): + """ + Container provider that runs environments in Daytona cloud sandboxes. + + Example: + >>> provider = DaytonaProvider(api_key="your-key") + >>> image = DaytonaProvider.image_from_dockerfile("envs/echo_env/server/Dockerfile") + >>> base_url = provider.start_container(image) + >>> provider.wait_for_ready(base_url) + >>> provider.stop_container() + """ + + _dockerfile_registry: Dict[str, Dict[str, Any]] = {} + + def __init__( + self, + *, + api_key: Optional[str] = None, + public: bool = False, + resources: Optional[Any] = None, + auto_stop_interval: int = 15, + target: Optional[str] = None, + on_snapshot_create_logs: Optional[Callable[[str], None]] = None, + cmd: Optional[str] = None, + create_timeout: float = 300, + ): + """ + Args: + api_key: Daytona API key. Falls back to ``DAYTONA_API_KEY`` env var. + public: If True, the sandbox preview is publicly accessible. + resources: Optional ``daytona.Resources`` instance for CPU/memory. + auto_stop_interval: Minutes of inactivity before auto-stop (0 disables). + target: Daytona target region (e.g. "us"). + on_snapshot_create_logs: Callback for snapshot build log lines. + cmd: Shell command to start the server inside the sandbox. + create_timeout: Seconds to wait for sandbox creation (default 300). + Heavy images (e.g. with Playwright/Chromium) may need more. + """ + from daytona import Daytona, DaytonaConfig + + config_kwargs: Dict[str, Any] = {} + resolved_key = api_key or os.environ.get("DAYTONA_API_KEY") + if resolved_key: + config_kwargs["api_key"] = resolved_key + if target: + config_kwargs["target"] = target + + self._daytona = Daytona(DaytonaConfig(**config_kwargs)) + self._public = public + self._resources = resources + self._auto_stop_interval = auto_stop_interval + self._on_snapshot_create_logs = on_snapshot_create_logs + self._cmd = cmd + self._create_timeout = create_timeout + self._sandbox: Any = None + self._preview_url: Optional[str] = None + + def _discover_server_cmd(self, sandbox: Any, port: int = 8000) -> str: + """Discover the server command from ``openenv.yaml`` inside *sandbox*. 
+
+        Finds the file, reads the ``app`` field, and constructs a command
+        of the form ``cd <env_root> && python -m uvicorn <app> --host 0.0.0.0 --port <port>``.
+
+        Raises:
+            ValueError: If ``openenv.yaml`` is not found or lacks an ``app`` field.
+        """
+        yaml_path = self._find_openenv_yaml(sandbox)
+        if yaml_path is None:
+            raise ValueError(
+                "Could not find openenv.yaml inside the sandbox. "
+                "Pass an explicit cmd= to DaytonaProvider or start_container()."
+            )
+
+        cat_resp = sandbox.process.exec(f"cat {shlex.quote(yaml_path)}", timeout=10)
+        content = cat_resp.result if hasattr(cat_resp, "result") else str(cat_resp)
+        app = self._parse_app_field(content)
+        if app is None:
+            raise ValueError(
+                f"openenv.yaml at {yaml_path} does not contain an 'app' field. "
+                "Pass an explicit cmd= to DaytonaProvider or start_container()."
+            )
+
+        # The directory containing openenv.yaml is the env root
+        env_root = yaml_path.rsplit("/", 1)[0]
+        return (
+            f"cd {shlex.quote(env_root)} && "
+            f"python -m uvicorn {shlex.quote(app)} --host 0.0.0.0 --port {port}"
+        )
+
+    def _find_openenv_yaml(self, sandbox: Any) -> Optional[str]:
+        """Locate ``openenv.yaml`` inside the sandbox.
+
+        Tries the modern layout path ``/app/env/openenv.yaml`` first,
+        then falls back to a ``find`` command for the old layout.
+        """
+        # Fast path: modern Dockerfile layout
+        resp = sandbox.process.exec(
+            "test -f /app/env/openenv.yaml && echo found", timeout=10
+        )
+        out = resp.result if hasattr(resp, "result") else str(resp)
+        if "found" in (out or ""):
+            return "/app/env/openenv.yaml"
+
+        # Fallback: search for it (redirect stderr so error messages
+        # like "No such file or directory" don't get mistaken for paths).
+        resp = sandbox.process.exec(
+            "find /app -maxdepth 4 -name openenv.yaml -print -quit 2>/dev/null",
+            timeout=10,
+        )
+        path = ((resp.result if hasattr(resp, "result") else str(resp)) or "").strip()
+        if path and path.startswith("/"):
+            return path
+
+        return None
+
+    @staticmethod
+    def _parse_app_field(yaml_content: str) -> Optional[str]:
+        """Extract the ``app`` value from raw openenv.yaml content.
+
+        Uses PyYAML to handle comments, quotes, and nested keys correctly.
+        """
+        try:
+            data = yaml.safe_load(yaml_content) or {}
+        except Exception:
+            return None
+
+        if not isinstance(data, dict):
+            return None
+
+        value = data.get("app")
+        if isinstance(value, str):
+            value = value.strip()
+            return value if value else None
+        return None
+
+    @staticmethod
+    def _parse_dockerfile_cmd(dockerfile_content: str) -> Optional[str]:
+        """Extract the server command from the last ``CMD`` in a Dockerfile.
+
+        Handles exec form (``CMD ["prog", "arg"]``) and shell form
+        (``CMD prog arg``). When a Dockerfile has multiple ``CMD``
+        instructions (e.g. multi-stage builds), the last one wins - same
+        semantics as Docker itself. Lines where ``CMD`` appears inside a
+        comment are ignored.
+
+        Returns:
+            The command as a single string, or ``None`` if no ``CMD`` found.
+        """
+        import re
+
+        last_cmd: Optional[str] = None
+        for line in dockerfile_content.splitlines():
+            stripped = line.strip()
+            # Skip comments
+            if stripped.startswith("#"):
+                continue
+            match = re.match(r"CMD\s+(.+)", stripped, flags=re.IGNORECASE)
+            if match:
+                last_cmd = match.group(1).strip()
+
+        if last_cmd is None:
+            return None
+
+        # Exec form: CMD ["executable", "param1", ...]
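+        # Illustrative, hypothetical CMD (not from this repo): the exec-form
+        # line CMD ["uvicorn", "server.app:app", "--port", "8000"] JSON-decodes
+        # to a list and is joined into 'uvicorn server.app:app --port 8000'.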
+        if last_cmd.startswith("["):
+            try:
+                parts = json.loads(last_cmd)
+                if isinstance(parts, list) and all(isinstance(p, str) for p in parts):
+                    return " ".join(parts)
+            except (json.JSONDecodeError, TypeError):
+                pass
+
+        # Shell form: CMD executable param1 ...
+        return last_cmd if last_cmd else None
+
+    @staticmethod
+    def strip_buildkit_syntax(dockerfile_content: str) -> str:
+        """Remove BuildKit ``--mount=...`` flags from ``RUN`` instructions.
+
+        Handles single-line flags, multi-line continuations, and multiple
+        ``--mount`` flags spread across continuation lines. Only leading
+        ``--mount`` flags are removed (before the actual command starts).
+
+        Daytona's ``Image.from_dockerfile`` does not support BuildKit
+        ``--mount`` syntax. This helper strips the flags so that standard
+        Dockerfiles (like the ones generated by ``openenv build``) can
+        be used directly.
+        """
+        import re
+
+        def strip_leading_mounts(text: str) -> str:
+            remaining = text
+            while True:
+                match = re.match(r"\s*--mount=\S+\s*", remaining)
+                if not match:
+                    return remaining
+                remaining = remaining[match.end() :]
+
+        lines = dockerfile_content.split("\n")
+        result: list[str] = []
+        in_run = False
+        in_mount_prefix = False
+
+        for line in lines:
+            line_out = line
+            run_start = False
+            if re.match(r"\s*RUN(\s+|$)", line, flags=re.IGNORECASE):
+                in_run = True
+                in_mount_prefix = True
+                run_start = True
+
+            if in_run and in_mount_prefix:
+                original_ends_with_slash = line_out.rstrip().endswith("\\")
+                if run_start:
+                    match = re.match(r"(\s*RUN\s+)(.*)$", line_out, flags=re.IGNORECASE)
+                    if match:
+                        run_prefix, remainder = match.group(1), match.group(2)
+                    else:
+                        run_prefix, remainder = line_out, ""
+                    new_remainder = strip_leading_mounts(remainder)
+                    line_out = run_prefix + new_remainder
+                    content_for_check = new_remainder
+                else:
+                    new_remainder = strip_leading_mounts(line_out)
+                    line_out = new_remainder
+                    content_for_check = new_remainder
+
+                if original_ends_with_slash and not line_out.rstrip().endswith("\\"):
+                    line_out = line_out.rstrip() + " \\"
+
+                if content_for_check.strip() not in ("", "\\"):
+                    in_mount_prefix = False
+
+            if in_run and not line_out.rstrip().endswith("\\"):
+                in_run = False
+                in_mount_prefix = False
+
+            result.append(line_out)
+
+        return "\n".join(result)
+
+    @classmethod
+    def image_from_dockerfile(
+        cls,
+        dockerfile_path: str,
+        context_dir: str | None = None,
+    ) -> str:
+        """Validate a Dockerfile and return a ``dockerfile:<path>`` URI for
+        :meth:`start_container`.
+
+        Eagerly validates the Dockerfile (existence, COPY sources,
+        BuildKit stripping) and stores the processed content in an
+        internal registry. The actual ``daytona.Image`` is created
+        later inside ``start_container``.
+
+        Args:
+            dockerfile_path: Path to the Dockerfile on disk.
+            context_dir: Build context directory. Defaults to the
+                Dockerfile's grandparent directory, matching the
+                ``openenv init`` convention where Dockerfiles live in
+                ``<env>/server/Dockerfile`` and the build context is
+                ``<env>/``. Pass explicitly for non-standard layouts
+                (e.g. ``context_dir="."`` for repo-root contexts).
+
+        Returns:
+            A ``"dockerfile:<path>"`` string to pass to
+            ``start_container``.
+
+        Raises:
+            FileNotFoundError: If *dockerfile_path* does not exist.
+            ValueError: If *context_dir* is given but does not exist,
+                or if COPY sources in the Dockerfile cannot be found
+                under the resolved context directory.
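+
+        Example (paths are illustrative):
+            >>> image = DaytonaProvider.image_from_dockerfile(
+            ...     "envs/echo_env/server/Dockerfile"
+            ... )
+            >>> # Repo-root context for Dockerfiles that COPY from the repo root
+            >>> image = DaytonaProvider.image_from_dockerfile(
+            ...     "envs/echo_env/server/Dockerfile", context_dir="."
+            ... )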
+ """ + import pathlib + import re + + src = pathlib.Path(dockerfile_path).resolve() + if not src.is_file(): + raise FileNotFoundError(f"Dockerfile not found: {dockerfile_path}") + + if context_dir is not None: + ctx = pathlib.Path(context_dir) + if not ctx.is_dir(): + raise ValueError(f"context_dir does not exist: {context_dir}") + else: + # Default: grandparent of the Dockerfile, matching the + # openenv init layout (/server/Dockerfile -> /). + ctx = src.parent.parent + + content = src.read_text() + stripped = cls.strip_buildkit_syntax(content) + + # Validate that COPY sources exist under the context directory. + # This catches mismatches early (e.g. a Dockerfile expecting repo + # root as context when we defaulted to the env directory). + for line in stripped.splitlines(): + m = re.match(r"^\s*COPY\s+(?!--from=)(\S+)\s+", line, re.IGNORECASE) + if not m: + continue + copy_src = m.group(1) + if copy_src.startswith("/"): + continue + resolved = ctx / copy_src + if not resolved.exists() and not any(ctx.glob(copy_src)): + raise ValueError( + f"Dockerfile COPY source '{copy_src}' not found " + f"under context_dir '{ctx}'. This Dockerfile may " + f"expect a different build context (e.g. the repo " + f"root). Pass context_dir explicitly." + ) + + # Parse CMD from the original Dockerfile so start_container can + # use it as a fallback when openenv.yaml is unavailable. + parsed_cmd = cls._parse_dockerfile_cmd(content) + + cls._dockerfile_registry[str(src)] = { + "stripped_content": stripped, + "context_dir": str(ctx), + "server_cmd": parsed_cmd, + } + + return f"dockerfile:{src}" + + def start_container( + self, + image: str, + port: Optional[int] = None, + env_vars: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> str: + """ + Create a Daytona sandbox from a Docker image or snapshot. + + Daytona does not execute the image's CMD (known bug — ENTRYPOINT + runs, CMD does not). The server command is resolved in order: + + 1. Explicit ``cmd`` passed to the constructor. + 2. ``cmd`` key in ``**kwargs`` (popped before forwarding). + 3. Auto-discovered from ``openenv.yaml`` inside the sandbox. + 4. ``CMD`` parsed from the Dockerfile (when *image* came from + ``image_from_dockerfile``). + + Args: + image: Docker image name (e.g. ``"echo-env:latest"``), + ``"snapshot:"`` to create from a pre-built snapshot, + or ``"dockerfile:"`` returned by + :meth:`image_from_dockerfile`. + port: Must be ``None`` or ``8000``. Daytona exposes port 8000 + via its preview proxy; other ports raise ``ValueError``. + env_vars: Environment variables forwarded to the sandbox. + **kwargs: ``cmd`` (str) to override the server command; + remaining kwargs passed through to ``Daytona.create()``. + + Returns: + HTTPS preview URL for the sandbox (base_url). + """ + if port is not None and port != 8000: + raise ValueError( + f"DaytonaProvider only supports port 8000 (got {port}). " + "The Daytona preview proxy routes to port 8000 inside the sandbox." + ) + + # Resolve the server command (may be None; discovery happens after + # sandbox creation when we can inspect the filesystem). + cmd = kwargs.pop("cmd", None) or self._cmd + + # CMD parsed from Dockerfile (populated for "dockerfile:" images). 
+ parsed_cmd: Optional[str] = None + + # Build creation params + create_kwargs: Dict[str, Any] = {} + if env_vars: + create_kwargs["env_vars"] = env_vars + if self._public: + create_kwargs["public"] = True + if self._auto_stop_interval != 15: + create_kwargs["auto_stop_interval"] = self._auto_stop_interval + + if image.startswith("snapshot:"): + from daytona import CreateSandboxFromSnapshotParams + + snapshot_name = image[len("snapshot:") :] + params = CreateSandboxFromSnapshotParams( + snapshot=snapshot_name, **create_kwargs + ) + elif image.startswith("dockerfile:"): + from daytona import CreateSandboxFromImageParams, Image + + dockerfile_path = image[len("dockerfile:") :] + meta = self._dockerfile_registry.get(dockerfile_path) + if meta is None: + raise ValueError( + f"No registered Dockerfile metadata for {dockerfile_path}. " + "Call DaytonaProvider.image_from_dockerfile() first." + ) + + parsed_cmd = meta.get("server_cmd") + + # Build the daytona Image from the pre-stripped content. + import pathlib + import uuid + + ctx = pathlib.Path(meta["context_dir"]) + tmp_name = f".daytona-{uuid.uuid4().hex[:8]}.dockerfile" + tmp_path = ctx / tmp_name + try: + tmp_path.write_text(meta["stripped_content"]) + daytona_image = Image.from_dockerfile(str(tmp_path)) + finally: + tmp_path.unlink(missing_ok=True) + + img_kwargs: Dict[str, Any] = { + "image": daytona_image, + **create_kwargs, + } + if self._resources is not None: + img_kwargs["resources"] = self._resources + params = CreateSandboxFromImageParams(**img_kwargs) + else: + from daytona import CreateSandboxFromImageParams + + img_kwargs = {"image": image, **create_kwargs} + if self._resources is not None: + img_kwargs["resources"] = self._resources + params = CreateSandboxFromImageParams(**img_kwargs) + + # Create sandbox + extra: Dict[str, Any] = dict(kwargs) + if self._on_snapshot_create_logs is not None: + extra["on_snapshot_create_logs"] = self._on_snapshot_create_logs + + self._sandbox = self._daytona.create( + params, timeout=self._create_timeout, **extra + ) + + try: + # Discover server command from openenv.yaml if not explicitly set. + if cmd is None: + try: + cmd = self._discover_server_cmd(self._sandbox) + except ValueError: + # Fall back to CMD parsed from Dockerfile (if available). + if parsed_cmd: + cmd = parsed_cmd + else: + raise + + # Wrap in bash -c so compound commands (cd ... && uvicorn ...) + # are handled correctly by nohup. Write PID so we can check + # if the process crashed later in wait_for_ready(). + escaped_cmd = shlex.quote(cmd) + self._sandbox.process.exec( + f"nohup bash -c {escaped_cmd} > /tmp/openenv-server.log 2>&1 &" + " echo $! > /tmp/openenv-server.pid", + timeout=10, + ) + + # Get a signed preview URL for port 8000. The token is + # embedded in the URL itself so no extra headers are needed. + signed = self._sandbox.create_signed_preview_url( + 8000, expires_in_seconds=86400 + ) + self._preview_url = signed.url + except Exception: + self.stop_container() + raise + + return self._preview_url + + def refresh_preview_url(self) -> str: + """Get a fresh signed preview URL (valid for 24h). + + Daytona signed URLs expire after at most 24 hours. Call this to + get a new one for long-running sessions. The returned URL points + to the same sandbox — clients will need to reconnect using it. 
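+
+        Example (sketch):
+            >>> base_url = provider.refresh_preview_url()
+            >>> # Reconnect clients against the returned base_url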
+ """ + if self._sandbox is None: + raise RuntimeError("No active sandbox to refresh URL for.") + signed = self._sandbox.create_signed_preview_url(8000, expires_in_seconds=86400) + self._preview_url = signed.url + return self._preview_url + + def stop_container(self) -> None: + """Delete the Daytona sandbox.""" + if self._sandbox is None: + return + + try: + self._daytona.delete(self._sandbox) + finally: + self._sandbox = None + self._preview_url = None + + def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None: + """ + Poll the /health endpoint until the sandbox is ready. + + Uses a longer default timeout (120s) than Docker providers because + Daytona sandboxes may have cold-start latency. + + Args: + base_url: Preview URL returned by ``start_container()``. + timeout_s: Maximum seconds to wait. + + Raises: + TimeoutError: If the sandbox doesn't become ready in time. + RuntimeError: If the server process died (detected via PID check). + """ + import requests + + health_url = f"{base_url}/health" + + deadline = time.time() + timeout_s + while time.time() < deadline: + try: + response = requests.get(health_url, timeout=5.0) + if response.status_code == 200: + return + except requests.RequestException: + pass + + # Early exit: if the server process died, raise immediately + # instead of waiting for the full health-check timeout. + if self._sandbox is not None: + resp = self._sandbox.process.exec( + "kill -0 $(cat /tmp/openenv-server.pid) 2>/dev/null" + " && echo RUNNING || echo DEAD", + timeout=10, + ) + out = resp.result if hasattr(resp, "result") else str(resp) + if "DEAD" in (out or ""): + log_resp = self._sandbox.process.exec( + "cat /tmp/openenv-server.log 2>/dev/null", timeout=10 + ) + log = ( + log_resp.result + if hasattr(log_resp, "result") + else str(log_resp) + ) + raise RuntimeError(f"Server process died.\nLog:\n{log}") + + time.sleep(1.0) + + raise TimeoutError( + f"Daytona sandbox at {base_url} did not become ready within {timeout_s}s" + ) diff --git a/src/core/containers/runtime/providers.py b/src/core/containers/runtime/providers.py new file mode 100644 index 0000000000000000000000000000000000000000..54232a2495746f89cc81590ca87d03e6e48e3d2b --- /dev/null +++ b/src/core/containers/runtime/providers.py @@ -0,0 +1,669 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Container provider abstractions for running environment servers. + +This module provides a pluggable architecture for different container providers +(local Docker, Kubernetes, cloud providers, etc.) to be used with EnvClient. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Sequence + + +class ContainerProvider(ABC): + """ + Abstract base class for container providers. + + Providers implement this interface to support different container platforms: + - LocalDockerProvider: Runs containers on local Docker daemon + - KubernetesProvider: Runs containers in Kubernetes cluster + - FargateProvider: Runs containers on AWS Fargate + - CloudRunProvider: Runs containers on Google Cloud Run + + The provider manages a single container lifecycle and provides the base URL + for connecting to it. 
+ + Example: + >>> provider = LocalDockerProvider() + >>> base_url = provider.start_container("echo-env:latest") + >>> print(base_url) # http://localhost:8000 + >>> # Use the environment via base_url + >>> provider.stop_container() + """ + + @abstractmethod + def start_container( + self, + image: str, + port: Optional[int] = None, + env_vars: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> str: + """ + Start a container from the specified image. + + Args: + image: Container image name (e.g., "echo-env:latest") + port: Port to expose (if None, provider chooses) + env_vars: Environment variables to pass to container + **kwargs: Provider-specific options + + Returns: + Base URL to connect to the container (e.g., "http://localhost:8000") + + Raises: + RuntimeError: If container fails to start + """ + pass + + @abstractmethod + def stop_container(self) -> None: + """ + Stop and remove the running container. + + This cleans up the container that was started by start_container(). + """ + pass + + @abstractmethod + def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None: + """ + Wait for the container to be ready to accept requests. + + This typically polls the /health endpoint until it returns 200. + + Args: + base_url: Base URL of the container + timeout_s: Maximum time to wait + + Raises: + TimeoutError: If container doesn't become ready in time + """ + pass + + +class LocalDockerProvider(ContainerProvider): + """ + Container provider for local Docker daemon. + + This provider runs containers on the local machine using Docker. + Useful for development and testing. + + Example: + >>> provider = LocalDockerProvider() + >>> base_url = provider.start_container("echo-env:latest") + >>> # Container running on http://localhost: + >>> provider.stop_container() + """ + + def __init__(self): + """Initialize the local Docker provider.""" + self._container_id: Optional[str] = None + self._container_name: Optional[str] = None + + # Check if Docker is available + import subprocess + + try: + subprocess.run( + ["docker", "version"], + check=True, + capture_output=True, + timeout=5, + ) + except ( + subprocess.CalledProcessError, + FileNotFoundError, + subprocess.TimeoutExpired, + ): + raise RuntimeError( + "Docker is not available. Please install Docker Desktop or Docker Engine." + ) + + def start_container( + self, + image: str, + port: Optional[int] = None, + env_vars: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> str: + """ + Start a Docker container locally. 
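+
+        The container is launched detached (``docker run -d``) with the
+        chosen host port published to port 8000 inside the container.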
+ + Args: + image: Docker image name + port: Port to expose (if None, finds available port) + env_vars: Environment variables for the container + **kwargs: Additional Docker run options + + Returns: + Base URL to connect to the container + """ + import subprocess + import time + + # Find available port if not specified + if port is None: + port = self._find_available_port() + + # Generate container name + self._container_name = self._generate_container_name(image) + + # Build docker run command + cmd = [ + "docker", + "run", + "-d", # Detached + "--name", + self._container_name, + "-p", + f"{port}:8000", # Map port + ] + + # Add environment variables + if env_vars: + for key, value in env_vars.items(): + cmd.extend(["-e", f"{key}={value}"]) + + # Add image + cmd.append(image) + + # Run container + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + self._container_id = result.stdout.strip() + except subprocess.CalledProcessError as e: + error_msg = f"Failed to start Docker container.\nCommand: {' '.join(cmd)}\nExit code: {e.returncode}\nStderr: {e.stderr}\nStdout: {e.stdout}" + raise RuntimeError(error_msg) from e + + # Wait a moment for container to start + time.sleep(1) + + base_url = f"http://localhost:{port}" + return base_url + + def stop_container(self) -> None: + """ + Stop and remove the Docker container. + """ + if self._container_id is None: + return + + import subprocess + + try: + # Stop container + subprocess.run( + ["docker", "stop", self._container_id], + capture_output=True, + check=True, + timeout=10, + ) + + # Remove container + subprocess.run( + ["docker", "rm", self._container_id], + capture_output=True, + check=True, + timeout=10, + ) + except subprocess.CalledProcessError: + # Container might already be stopped/removed + pass + finally: + self._container_id = None + self._container_name = None + + def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None: + """ + Wait for container to be ready by polling /health endpoint. + + Args: + base_url: Base URL of the container + timeout_s: Maximum time to wait + + Raises: + TimeoutError: If container doesn't become ready + """ + import time + + import requests + + start_time = time.time() + health_url = f"{base_url}/health" + + # Bypass proxy for localhost to avoid proxy issues + proxies = {"http": None, "https": None} + + while time.time() - start_time < timeout_s: + try: + response = requests.get(health_url, timeout=2.0, proxies=proxies) + if response.status_code == 200: + return + except requests.RequestException: + pass + + time.sleep(0.5) + + raise TimeoutError( + f"Container at {base_url} did not become ready within {timeout_s}s" + ) + + def _find_available_port(self) -> int: + """ + Find an available port on localhost. + + Returns: + An available port number + """ + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + def _generate_container_name(self, image: str) -> str: + """ + Generate a unique container name based on image name and timestamp. + + Args: + image: Docker image name + + Returns: + A unique container name + """ + import time + + clean_image = image.split("/")[-1].split(":")[0] + timestamp = int(time.time() * 1000) + return f"{clean_image}-{timestamp}" + + +class DockerSwarmProvider(ContainerProvider): + """ + Container provider that uses Docker Swarm services for local concurrency. 
+ + This provider creates a replicated Swarm service backed by the local Docker + engine. The built-in load-balancer fans requests across the replicas, + allowing multiple container instances to run concurrently on the developer + workstation (mirroring the workflow described in the Docker stack docs). + """ + + def __init__( + self, + *, + auto_init_swarm: bool = True, + overlay_network: Optional[str] = None, + ): + """ + Args: + auto_init_swarm: Whether to call ``docker swarm init`` when Swarm + is not active. Otherwise, user must manually initialize Swarm. + overlay_network: Optional overlay network name for the service. + When provided, the network is created with + ``docker network create --driver overlay --attachable`` if it + does not already exist. + """ + self._service_name: Optional[str] = None + self._service_id: Optional[str] = None + self._published_port: Optional[int] = None + self._overlay_network = overlay_network + self._auto_init_swarm = auto_init_swarm + + self._ensure_docker_available() + self._ensure_swarm_initialized() + if self._overlay_network: + self._ensure_overlay_network(self._overlay_network) + + def start_container( + self, + image: str, + port: Optional[int] = None, + env_vars: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> str: + """ + Start (or scale) a Swarm service for the given image. + + Supported kwargs: + replicas (int): Number of container replicas (default: 2). + cpu_limit (float | str): CPU limit passed to ``--limit-cpu``. + memory_limit (str): Memory limit passed to ``--limit-memory``. + constraints (Sequence[str]): Placement constraints. + labels (Dict[str, str]): Service labels. + command (Sequence[str] | str): Override container command. + """ + import shlex + import subprocess + import time + + allowed_kwargs = { + "replicas", + "cpu_limit", + "memory_limit", + "constraints", + "labels", + "command", + } + unknown = set(kwargs) - allowed_kwargs + if unknown: + raise ValueError(f"Unsupported kwargs for DockerSwarmProvider: {unknown}") + + replicas = int(kwargs.get("replicas", 2)) + cpu_limit = kwargs.get("cpu_limit") + memory_limit = kwargs.get("memory_limit") + constraints: Optional[Sequence[str]] = kwargs.get("constraints") + labels: Optional[Dict[str, str]] = kwargs.get("labels") + command_override = kwargs.get("command") + + if port is None: + port = self._find_available_port() + + self._service_name = self._generate_service_name(image) + self._published_port = port + + cmd = [ + "docker", + "service", + "create", + "--detach", + "--name", + self._service_name, + "--replicas", + str(max(1, replicas)), + "--publish", + f"{port}:8000", + ] + + if self._overlay_network: + cmd.extend(["--network", self._overlay_network]) + + if env_vars: + for key, value in env_vars.items(): + cmd.extend(["--env", f"{key}={value}"]) + + if cpu_limit is not None: + cmd.extend(["--limit-cpu", str(cpu_limit)]) + + if memory_limit is not None: + cmd.extend(["--limit-memory", str(memory_limit)]) + + if constraints: + for constraint in constraints: + cmd.extend(["--constraint", constraint]) + + if labels: + for key, value in labels.items(): + cmd.extend(["--label", f"{key}={value}"]) + + cmd.append(image) + + if command_override: + if isinstance(command_override, str): + cmd.extend(shlex.split(command_override)) + else: + cmd.extend(command_override) + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True, + ) + self._service_id = result.stdout.strip() + except subprocess.CalledProcessError as e: + error_msg = ( + "Failed 
to start Docker Swarm service.\n" + f"Command: {' '.join(cmd)}\n" + f"Exit code: {e.returncode}\n" + f"Stdout: {e.stdout}\n" + f"Stderr: {e.stderr}" + ) + raise RuntimeError(error_msg) from e + + # Give Swarm a brief moment to schedule the tasks. + time.sleep(1.0) + + return f"http://localhost:{port}" + + def stop_container(self) -> None: + """ + Remove the Swarm service (and keep the Swarm manager running). + """ + if not self._service_name: + return + + import subprocess + + try: + subprocess.run( + ["docker", "service", "rm", self._service_name], + capture_output=True, + check=True, + timeout=10, + ) + except subprocess.CalledProcessError: + # Service may already be gone; ignore. + pass + finally: + self._service_name = None + self._service_id = None + self._published_port = None + + def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None: + """ + Wait for at least one replica to become healthy by polling /health. + + Note: With Swarm's load balancer, requests round-robin across replicas, + so this only verifies that at least one replica is responding. Some + replicas may still be starting when this returns. + """ + import time + + import requests + + deadline = time.time() + timeout_s + health_url = f"{base_url}/health" + + # Bypass proxy for localhost to avoid proxy issues + proxies = {"http": None, "https": None} + + while time.time() < deadline: + try: + response = requests.get(health_url, timeout=2.0, proxies=proxies) + if response.status_code == 200: + return + except requests.RequestException: + pass + + time.sleep(0.5) + + raise TimeoutError( + f"Swarm service at {base_url} did not become ready within {timeout_s}s" + ) + + def _ensure_docker_available(self) -> None: + import subprocess + + try: + subprocess.run( + ["docker", "version"], + check=True, + capture_output=True, + timeout=5, + ) + except ( + subprocess.CalledProcessError, + FileNotFoundError, + subprocess.TimeoutExpired, + ) as exc: + raise RuntimeError( + "Docker is not available. Please install Docker Desktop or Docker Engine." + ) from exc + + def _ensure_swarm_initialized(self) -> None: + import subprocess + + try: + result = subprocess.run( + ["docker", "info", "--format", "{{.Swarm.LocalNodeState}}"], + capture_output=True, + text=True, + check=True, + timeout=5, + ) + state = result.stdout.strip().lower() + if state == "active": + return + except subprocess.CalledProcessError: + state = "unknown" + + if not self._auto_init_swarm: + raise RuntimeError( + f"Docker Swarm is not active (state={state}). Enable Swarm manually or pass auto_init_swarm=True." 
+            )
+
+        try:
+            subprocess.run(
+                ["docker", "swarm", "init"],
+                check=True,
+                capture_output=True,
+                timeout=10,
+            )
+        except subprocess.CalledProcessError as e:
+            raise RuntimeError("Failed to initialize Docker Swarm") from e
+
+    def _ensure_overlay_network(self, network: str) -> None:
+        import subprocess
+
+        inspect = subprocess.run(
+            ["docker", "network", "inspect", network],
+            capture_output=True,
+            text=True,
+            check=False,
+        )
+        if inspect.returncode == 0:
+            return
+
+        try:
+            subprocess.run(
+                [
+                    "docker",
+                    "network",
+                    "create",
+                    "--driver",
+                    "overlay",
+                    "--attachable",
+                    network,
+                ],
+                check=True,
+                capture_output=True,
+                timeout=10,
+            )
+        except subprocess.CalledProcessError as e:
+            raise RuntimeError(f"Failed to create overlay network '{network}'") from e
+
+    def _find_available_port(self) -> int:
+        import socket
+
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            s.listen(1)
+            port = s.getsockname()[1]
+            return port
+
+    def _generate_service_name(self, image: str) -> str:
+        import time
+
+        clean_image = image.split("/")[-1].split(":")[0]
+        timestamp = int(time.time() * 1000)
+        return f"{clean_image}-swarm-{timestamp}"
+
+
+class KubernetesProvider(ContainerProvider):
+    """
+    Container provider for Kubernetes clusters.
+
+    This provider creates pods in a Kubernetes cluster and exposes them
+    via services or port-forwarding.
+
+    Example:
+        >>> provider = KubernetesProvider(namespace="envtorch-dev")
+        >>> base_url = provider.start_container("echo-env:latest")
+        >>> # Pod running in k8s, accessible via service or port-forward
+        >>> provider.stop_container()
+    """
+
+    pass
+
+
+class RuntimeProvider(ABC):
+    """
+    Abstract base class for runtime providers that are not container providers.
+
+    Providers implement this interface to support different runtime platforms:
+    - UVProvider: Runs environments via `uv run`
+
+    The provider manages a single runtime lifecycle and provides the base URL
+    for connecting to it.
+
+    Example:
+        >>> provider = UVProvider(project_path="/path/to/env")
+        >>> base_url = provider.start()
+        >>> print(base_url)  # http://localhost:8000
+        >>> provider.stop()
+    """
+
+    @abstractmethod
+    def start(
+        self,
+        port: Optional[int] = None,
+        env_vars: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Start the runtime.
+
+        Args:
+            port: Port to expose (if None, provider chooses)
+            env_vars: Environment variables for the runtime
+            **kwargs: Additional runtime options
+
+        Returns:
+            Base URL to connect to the runtime (e.g., "http://localhost:8000")
+        """
+
+    @abstractmethod
+    def stop(self) -> None:
+        """
+        Stop the runtime.
+        """
+        pass
+
+    @abstractmethod
+    def wait_for_ready(self, timeout_s: float = 30.0) -> None:
+        """
+        Wait for the runtime to be ready to accept requests.
+        """
+        pass
+
+    def __enter__(self) -> "RuntimeProvider":
+        """
+        Enter the runtime provider.
+        """
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc, tb) -> bool:
+        """
+        Exit the runtime provider.
+ """ + self.stop() + return False diff --git a/src/core/containers/runtime/uv_provider.py b/src/core/containers/runtime/uv_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..3ddc89b9bdccbd0d18604c3de5f49fd3cbc74612 --- /dev/null +++ b/src/core/containers/runtime/uv_provider.py @@ -0,0 +1,224 @@ +"""Providers for launching ASGI applications via ``uv run``.""" + +from __future__ import annotations + +import os +import socket +import subprocess +import time +from typing import Dict, Optional + +import requests + +from .providers import RuntimeProvider + + +def _check_uv_installed() -> None: + try: + subprocess.check_output(["uv", "--version"]) + except FileNotFoundError as exc: + raise RuntimeError( + "`uv` executable not found. Install uv from https://docs.astral.sh and ensure it is on PATH." + ) from exc + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", 0)) + sock.listen(1) + return sock.getsockname()[1] + + +def _create_uv_command( + *, + host: str, + port: int, + reload: bool, + workers: int, + app: str, + project_path: str, +) -> list[str]: + command: list[str] = ["uv", "run", "--isolated", "--project", project_path] + + command.append("--") + command.extend( + [ + "uvicorn", + app, + "--host", + host, + "--port", + str(port), + "--workers", + str(workers), + ] + ) + + if reload: + command.append("--reload") + + return command + + +def _poll_health(health_url: str, timeout_s: float) -> None: + """Poll a health endpoint until it returns HTTP 200 or times out.""" + + deadline = time.time() + timeout_s + while time.time() < deadline: + try: + timeout = max(0.0001, min(deadline - time.time(), 2.0)) + response = requests.get(health_url, timeout=timeout) + if response.status_code == 200: + return + except requests.RequestException: + continue + + time.sleep(0.5) + + raise TimeoutError(f"Server did not become ready within {timeout_s:.1f} seconds") + + +class UVProvider(RuntimeProvider): + """ + RuntimeProvider implementation backed by ``uv run``. + + Args: + project_path: Local path to a uv project (passed to ``uv run --project``) + app: ASGI application path for uvicorn (defaults to ``server.app:app``) + host: Host interface to bind to (defaults to ``0.0.0.0``) + reload: Whether to enable uvicorn's reload mode + env_vars: Environment variables to pass through to the spawned process + context_timeout_s: How long to wait for the environment to become ready + + Example: + >>> provider = UVProvider(project_path="/path/to/env") + >>> base_url = provider.start() + >>> print(base_url) # http://localhost:8000 + >>> # Use the environment via base_url + >>> provider.stop() + """ + + def __init__( + self, + *, + project_path: str, + app: str = "server.app:app", + host: str = "0.0.0.0", + reload: bool = False, + env_vars: Optional[Dict[str, str]] = None, + context_timeout_s: float = 60.0, + ): + """Initialize the UVProvider.""" + self.project_path = os.path.abspath(project_path) + self.app = app + self.host = host + self.reload = reload + self.env_vars = env_vars + self.context_timeout_s = context_timeout_s + _check_uv_installed() + self._process = None + self._base_url = None + + def start( + self, + port: Optional[int] = None, + env_vars: Optional[Dict[str, str]] = None, + workers: int = 1, + **_: Dict[str, str], + ) -> str: + """ + Start the environment via `uv run`. 
+ + Args: + port: The port to bind the environment to + env_vars: Environment variables to pass to the environment + workers: The number of workers to use + + Returns: + The base URL of the environment + + Raises: + RuntimeError: If the environment is already running + """ + if self._process is not None and self._process.poll() is None: + raise RuntimeError("UVProvider is already running") + + bind_port = port or _find_free_port() + + command = _create_uv_command( + host=self.host, + port=bind_port, + reload=self.reload, + workers=workers, + app=self.app, + project_path=self.project_path, + ) + + env = os.environ.copy() + + if self.env_vars: + env.update(self.env_vars) + if env_vars: + env.update(env_vars) + + try: + self._process = subprocess.Popen(command, env=env) + except OSError as exc: + raise RuntimeError(f"Failed to launch `uv run`: {exc}") from exc + + client_host = "127.0.0.1" if self.host in {"0.0.0.0", "::"} else self.host + self._base_url = f"http://{client_host}:{bind_port}" + return self._base_url + + def wait_for_ready(self, timeout_s: float = 60.0) -> None: + """ + Wait for the environment to become ready. + + Args: + timeout_s: The timeout to wait for the environment to become ready + + Raises: + RuntimeError: If the environment is not running + TimeoutError: If the environment does not become ready within the timeout + """ + if self._process and self._process.poll() is not None: + code = self._process.returncode + raise RuntimeError(f"uv process exited prematurely with code {code}") + + _poll_health(f"{self._base_url}/health", timeout_s=timeout_s) + + def stop(self) -> None: + """ + Stop the environment. + + Raises: + RuntimeError: If the environment is not running + """ + if self._process is None: + return + + if self._process.poll() is None: + self._process.terminate() + try: + self._process.wait(timeout=10.0) + except subprocess.TimeoutExpired: + self._process.kill() + self._process.wait(timeout=5.0) + + self._process = None + self._base_url = None + + @property + def base_url(self) -> str: + """ + The base URL of the environment. + + Returns: + The base URL of the environment + + Raises: + RuntimeError: If the environment is not running + """ + if self._base_url is None: + raise RuntimeError("UVProvider has not been started") + return self._base_url diff --git a/src/core/containers/test_local_docker_provider.py b/src/core/containers/test_local_docker_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..ac520a4b68afa699894dd68c0508b1e41936704c --- /dev/null +++ b/src/core/containers/test_local_docker_provider.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 +""" +End-to-end test for LocalDockerProvider. + +This script tests the complete flow: +1. Start a container using LocalDockerProvider +2. Wait for it to be ready +3. Make HTTP requests to test the environment +4. 
Clean up the container
+"""
+
+import sys
+from pathlib import Path
+
+# Add src to path
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+import requests
+from openenv.core.containers.runtime import LocalDockerProvider
+
+
+# TODO: Remove this test or make it a functional test since this will be tested in the e2e test for echo env
+def test_local_docker_provider():
+    """Test LocalDockerProvider end-to-end."""
+    print("=" * 60)
+    print("LocalDockerProvider End-to-End Test")
+    print("=" * 60)
+    print()
+
+    provider = None
+
+    try:
+        # Step 1: Create provider
+        print("Step 1: Creating LocalDockerProvider...")
+        provider = LocalDockerProvider()
+        print("✓ Provider created\n")
+
+        # Step 2: Start container
+        print("Step 2: Starting echo-env container...")
+        base_url = provider.start_container("echo-env:latest")
+        print(f"✓ Container started at: {base_url}")
+        if provider._container_id:
+            print(f"  Container ID: {provider._container_id[:12]}...")
+        if provider._container_name:
+            print(f"  Container name: {provider._container_name}\n")
+
+        # Step 3: Wait for ready
+        print("Step 3: Waiting for container to be ready...")
+        provider.wait_for_ready(base_url, timeout_s=30.0)
+        print("✓ Container is ready!\n")
+
+        # Step 4: Test health endpoint
+        print("Step 4: Testing /health endpoint...")
+        response = requests.get(f"{base_url}/health")
+        print(f"  Status: {response.status_code}")
+        print(f"  Response: {response.json()}")
+        assert response.status_code == 200
+        assert response.json()["status"] == "healthy"
+        print("✓ Health check passed\n")
+
+        # Step 5: Test reset endpoint
+        print("Step 5: Testing /reset endpoint...")
+        response = requests.post(
+            f"{base_url}/reset",
+            json={},
+            headers={"Content-Type": "application/json"},
+        )
+        print(f"  Status: {response.status_code}")
+        data = response.json()
+        print(f"  Message: {data['observation']['echoed_message']}")
+        print(f"  Reward: {data['reward']}")
+        print(f"  Done: {data['done']}")
+        assert response.status_code == 200
+        assert data["observation"]["echoed_message"] == "Echo environment ready!"
+        print("✓ Reset test passed\n")
+
+        # Step 6: Test step endpoint
+        print("Step 6: Testing /step endpoint...")
+        response = requests.post(
+            f"{base_url}/step",
+            json={"action": {"message": "Hello from LocalDockerProvider!"}},
+            headers={"Content-Type": "application/json"},
+        )
+        print(f"  Status: {response.status_code}")
+        data = response.json()
+        print(f"  Echoed: {data['observation']['echoed_message']}")
+        print(f"  Length: {data['observation']['message_length']}")
+        print(f"  Reward: {data['reward']}")
+        assert response.status_code == 200
+        assert (
+            data["observation"]["echoed_message"] == "Hello from LocalDockerProvider!"
+ ) + assert data["observation"]["message_length"] == 31 + print("✓ Step test passed\n") + + # Step 7: Test state endpoint + print("Step 7: Testing /state endpoint...") + response = requests.get(f"{base_url}/state") + print(f" Status: {response.status_code}") + data = response.json() + print(f" Episode ID: {data['episode_id']}") + print(f" Step count: {data['step_count']}") + assert response.status_code == 200 + assert data["step_count"] == 1 # One step from above + print("✓ State test passed\n") + + # Step 8: Multiple steps + print("Step 8: Testing multiple steps...") + for i in range(3): + response = requests.post( + f"{base_url}/step", + json={"action": {"message": f"Message {i + 1}"}}, + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 200 + print(f" Step {i + 1}: ✓") + + # Check state updated + response = requests.get(f"{base_url}/state") + data = response.json() + assert data["step_count"] == 4 # 1 + 3 more steps + print(f" Final step count: {data['step_count']}") + print("✓ Multiple steps test passed\n") + + print("=" * 60) + print("✓ All tests passed!") + print("=" * 60) + print() + + return True + + except Exception as e: + print(f"\n❌ Test failed: {e}") + import traceback + + traceback.print_exc() + return False + + finally: + # Step 9: Cleanup + if provider is not None: + print("\nStep 9: Cleaning up container...") + try: + provider.stop_container() + print("✓ Container stopped and removed\n") + except Exception as e: + print(f"⚠️ Cleanup warning: {e}\n") + + +def test_provider_with_custom_port(): + """Test provider with custom port.""" + print("=" * 60) + print("LocalDockerProvider with Custom Port Test") + print("=" * 60) + print() + + provider = None + + try: + provider = LocalDockerProvider() + + print("Starting container on custom port 8123...") + base_url = provider.start_container("echo-env:latest", port=8123) + print(f"✓ Started at: {base_url}") + assert ":8123" in base_url + + print("Waiting for ready...") + provider.wait_for_ready(base_url) + print("✓ Ready!") + + print("Testing health...") + response = requests.get(f"{base_url}/health") + assert response.status_code == 200 + print("✓ Health check passed") + + print("\n✓ Custom port test passed!\n") + return True + + except Exception as e: + print(f"\n❌ Test failed: {e}") + return False + + finally: + if provider is not None: + provider.stop_container() + print("✓ Cleaned up\n") + + +def test_provider_with_env_vars(): + """Test provider with environment variables.""" + print("=" * 60) + print("LocalDockerProvider with Environment Variables Test") + print("=" * 60) + print() + + provider = None + + try: + provider = LocalDockerProvider() + + print("Starting container with environment variables...") + base_url = provider.start_container( + "echo-env:latest", env_vars={"DEBUG": "true", "LOG_LEVEL": "info"} + ) + print(f"✓ Started at: {base_url}") + + print("Waiting for ready...") + provider.wait_for_ready(base_url) + print("✓ Ready!") + + print("Testing health...") + response = requests.get(f"{base_url}/health") + assert response.status_code == 200 + print("✓ Health check passed") + + print("\n✓ Environment variables test passed!\n") + return True + + except Exception as e: + print(f"\n❌ Test failed: {e}") + return False + + finally: + if provider is not None: + provider.stop_container() + print("✓ Cleaned up\n") + + +if __name__ == "__main__": + print() + print("🐳 LocalDockerProvider Test Suite") + print() + + results = [] + + # Run basic test + results.append(("Basic End-to-End", 
test_local_docker_provider())) + + # Run custom port test + results.append(("Custom Port", test_provider_with_custom_port())) + + # Run environment variables test + results.append(("Environment Variables", test_provider_with_env_vars())) + + # Summary + print("=" * 60) + print("Test Summary") + print("=" * 60) + for name, passed in results: + status = "✓ PASSED" if passed else "✗ FAILED" + print(f"{name:25} {status}") + print("=" * 60) + + all_passed = all(result for _, result in results) + if all_passed: + print("\n🎉 All tests passed!") + exit(0) + else: + print("\n❌ Some tests failed") + exit(1) diff --git a/src/core/env_client.py b/src/core/env_client.py new file mode 100644 index 0000000000000000000000000000000000000000..4ceb344bca20d55d2f9e7ba9aa39595ef61fca30 --- /dev/null +++ b/src/core/env_client.py @@ -0,0 +1,484 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Environment client for persistent sessions. + +This module provides a WebSocket-based client that maintains a persistent connection +to an environment server, enabling efficient multi-step interactions without +the overhead of HTTP request/response cycles. + +The client is async by default. For synchronous usage, use the `.sync()` method +to get a `SyncEnvClient` wrapper. + +Example (async): + >>> async with GenericEnvClient(base_url="ws://localhost:8000") as env: + ... result = await env.reset() + ... result = await env.step({"code": "print('hello')"}) + +Example (sync wrapper): + >>> env = GenericEnvClient(base_url="ws://localhost:8000").sync() + >>> with env: + ... result = env.reset() + ... result = env.step({"code": "print('hello')"}) +""" + +from __future__ import annotations + +import asyncio +import json +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar + +from .client_types import StateT, StepResult +from .containers.runtime import LocalDockerProvider, UVProvider +from .utils import convert_to_ws_url + +if TYPE_CHECKING: + from websockets.asyncio.client import ClientConnection + + from .containers.runtime import ContainerProvider, RuntimeProvider + from .sync_client import SyncEnvClient + +from websockets.asyncio.client import connect as ws_connect + +ActT = TypeVar("ActT") +ObsT = TypeVar("ObsT") +EnvClientT = TypeVar("EnvClientT", bound="EnvClient") + + +class EnvClient(ABC, Generic[ActT, ObsT, StateT]): + """ + Async environment client for persistent sessions. + + This client maintains a persistent WebSocket connection to an environment + server, enabling efficient multi-step interactions. Each client instance + corresponds to a dedicated environment session on the server. + + The client is async by default. For synchronous usage, use the `.sync()` + method to get a `SyncEnvClient` wrapper. + + Features: + - Lower latency for sequential interactions + - Session state is maintained server-side + - Better suited for long-running episodes + - Async by default for modern Python async/await patterns + + Example (async): + >>> from envs.coding_env.client import CodingEnv + >>> + >>> # Connect to a server using async context manager + >>> async with CodingEnv(base_url="ws://localhost:8000") as env: + ... result = await env.reset(seed=42) + ... while not result.done: + ... action = agent.predict(result.observation) + ... 
result = await env.step(action) + + Example (sync wrapper): + >>> env = CodingEnv(base_url="ws://localhost:8000").sync() + >>> with env: + ... result = env.reset(seed=42) + ... result = env.step(action) + """ + + def __init__( + self, + base_url: str, + connect_timeout_s: float = 10.0, + message_timeout_s: float = 60.0, + max_message_size_mb: float = 100.0, + provider: Optional["ContainerProvider | RuntimeProvider"] = None, + mode: Optional[str] = None, + ): + """ + Initialize environment client. + + Args: + base_url: Base URL of the environment server (http:// or ws://). + Will be converted to ws:// if http:// is provided. + connect_timeout_s: Timeout for establishing WebSocket connection + message_timeout_s: Timeout for receiving responses to messages + max_message_size_mb: Maximum WebSocket message size in megabytes. + Default 100MB to handle large observations (screenshots, DOM, etc.) + provider: Optional container/runtime provider for lifecycle management. + Can be a ContainerProvider (Docker) or RuntimeProvider (UV). + mode: Communication mode: 'simulation' for Gym-style API (default) or + 'production' for MCP JSON-RPC protocol. Can also be set via the + OPENENV_CLIENT_MODE environment variable. Constructor parameter + takes precedence over environment variable. Case-insensitive. + """ + # Determine mode (constructor > env var > default) + if mode is None: + mode = os.environ.get("OPENENV_CLIENT_MODE", "simulation") + + # Normalize and validate mode + mode = mode.lower() + if mode not in ("simulation", "production"): + raise ValueError( + f"Invalid mode: '{mode}'. Must be 'simulation' or 'production'. " + f"Set via constructor parameter or OPENENV_CLIENT_MODE environment variable." + ) + + # Store mode (use object.__setattr__ to bypass immutability) + object.__setattr__(self, "_mode", mode) + + # Convert HTTP URL to WebSocket URL + ws_url = convert_to_ws_url(base_url) + + self._ws_url = f"{ws_url}/ws" + self._connect_timeout = connect_timeout_s + self._message_timeout = message_timeout_s + self._max_message_size = int( + max_message_size_mb * 1024 * 1024 + ) # Convert MB to bytes + self._provider = provider + self._ws: Optional[ClientConnection] = None + + def __setattr__(self, name: str, value: Any) -> None: + """Prevent modification of _mode after initialization.""" + if name == "_mode" and hasattr(self, "_mode"): + raise AttributeError("Cannot modify mode after initialization") + super().__setattr__(name, value) + + async def connect(self) -> "EnvClient": + """ + Establish WebSocket connection to the server. 
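+
+        Calling connect() on an already-connected client is a no-op, so it is
+        safe to call defensively. A minimal sketch (GenericEnvClient stands in
+        for any concrete subclass):
+
+            client = GenericEnvClient(base_url="http://localhost:8000")
+            await client.connect()  # opens the websocket at ws://localhost:8000/ws
+            await client.connect()  # already connected: returns immediately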
+ + Returns: + self for method chaining + + Raises: + ConnectionError: If connection cannot be established + """ + if self._ws is not None: + return self + + # Bypass proxy for localhost connections + ws_url_lower = self._ws_url.lower() + is_localhost = "localhost" in ws_url_lower or "127.0.0.1" in ws_url_lower + + old_no_proxy = os.environ.get("NO_PROXY") + if is_localhost: + # Set NO_PROXY to bypass proxy for localhost + current_no_proxy = old_no_proxy or "" + if "localhost" not in current_no_proxy.lower(): + os.environ["NO_PROXY"] = ( + f"{current_no_proxy},localhost,127.0.0.1" + if current_no_proxy + else "localhost,127.0.0.1" + ) + + try: + self._ws = await ws_connect( + self._ws_url, + open_timeout=self._connect_timeout, + max_size=self._max_message_size, + ) + except Exception as e: + raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e + finally: + # Restore original NO_PROXY value + if is_localhost: + if old_no_proxy is None: + os.environ.pop("NO_PROXY", None) + else: + os.environ["NO_PROXY"] = old_no_proxy + + return self + + async def disconnect(self) -> None: + """Close the WebSocket connection.""" + if self._ws is not None: + try: + # Send close message + await self._send({"type": "close"}) + except Exception: + pass # Best effort + try: + await self._ws.close() + except Exception: + pass + self._ws = None + + async def _ensure_connected(self) -> None: + """Ensure WebSocket connection is established.""" + if self._ws is None: + await self.connect() + + async def _send(self, message: Dict[str, Any]) -> None: + """Send a message over the WebSocket.""" + await self._ensure_connected() + assert self._ws is not None + await self._ws.send(json.dumps(message)) + + async def _receive(self) -> Dict[str, Any]: + """Receive and parse a message from the WebSocket.""" + assert self._ws is not None + raw = await asyncio.wait_for(self._ws.recv(), timeout=self._message_timeout) + return json.loads(raw) + + async def _send_and_receive(self, message: Dict[str, Any]) -> Dict[str, Any]: + """Send a message and wait for response.""" + await self._send(message) + response = await self._receive() + + # Check for error response + if response.get("type") == "error": + error_data = response.get("data", {}) + raise RuntimeError( + f"Server error: {error_data.get('message', 'Unknown error')} " + f"(code: {error_data.get('code', 'UNKNOWN')})" + ) + + return response + + @classmethod + async def from_docker_image( + cls: Type[EnvClientT], + image: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> EnvClientT: + """ + Create an environment client by spinning up a Docker container. + + Args: + image: Docker image name to run (e.g., "coding-env:latest") + provider: Container provider to use (defaults to LocalDockerProvider) + **kwargs: Additional arguments to pass to provider.start_container() + + Returns: + Connected client instance + """ + if provider is None: + provider = LocalDockerProvider() + + # Start container + base_url = provider.start_container(image, **kwargs) + + # Wait for server to be ready + provider.wait_for_ready(base_url) + + # Create and connect client + client = cls(base_url=base_url, provider=provider) + await client.connect() + + return client + + @classmethod + async def from_env( + cls: Type[EnvClientT], + repo_id: str, + *, + use_docker: bool = True, + provider: Optional["ContainerProvider | RuntimeProvider"] = None, + **provider_kwargs: Any, + ) -> EnvClientT: + """ + Create a client from a Hugging Face Space. 
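+
+        In docker mode the image name is derived from the repo id as
+        ``registry.hf.space/{org}-{space}:{tag}``, where ``tag`` defaults to
+        ``latest`` and may be overridden via ``provider_kwargs``.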
+ + Args: + repo_id: Hugging Face space identifier ``{org}/{space}``. + use_docker: When ``True`` (default) pull from the HF registry and + launch via :class:`LocalDockerProvider`. When ``False`` run the + space locally with :class:`UVProvider`. + provider: Optional provider instance to reuse. Must be a + :class:`ContainerProvider` when ``use_docker=True`` and a + :class:`RuntimeProvider` otherwise. + provider_kwargs: Additional keyword arguments forwarded to + either the container provider's ``start_container`` (docker) + or to the ``UVProvider`` constructor/start (uv). When + ``use_docker=False``, the ``project_path`` argument can be + used to override the default git URL + (``git+https://huggingface.co/spaces/{repo_id}``). + + Returns: + Connected client instance + + Examples: + >>> # Pull and run from HF Docker registry + >>> env = await MyEnv.from_env("openenv/echo-env") + >>> + >>> # Run locally with UV (clones the space) + >>> env = await MyEnv.from_env("openenv/echo-env", use_docker=False) + >>> + >>> # Run from a local checkout + >>> env = await MyEnv.from_env( + ... "openenv/echo-env", + ... use_docker=False, + ... project_path="/path/to/local/checkout" + ... ) + """ + # Extract start args that apply to both providers + start_args = {} + for key in ("port", "env_vars", "workers"): + if key in provider_kwargs: + start_args[key] = provider_kwargs.pop(key) + + if use_docker: + # Docker mode: pull from HF registry + docker_provider = provider or LocalDockerProvider() + tag = provider_kwargs.pop("tag", "latest") + image = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" + base_url = docker_provider.start_container( + image, **start_args, **provider_kwargs + ) + docker_provider.wait_for_ready(base_url) + + client = cls(base_url=base_url, provider=docker_provider) + await client.connect() + return client + else: + # UV mode: clone and run with uv + if provider is None: + uv_kwargs = dict(provider_kwargs) + project_path = uv_kwargs.pop("project_path", None) + if project_path is None: + project_path = f"git+https://huggingface.co/spaces/{repo_id}" + + provider = UVProvider(project_path=project_path, **uv_kwargs) + else: + if provider_kwargs: + raise ValueError( + "provider_kwargs cannot be used when supplying a provider instance" + ) + + base_url = provider.start(**start_args) + provider.wait_for_ready() + + client = cls(base_url=base_url, provider=provider) + await client.connect() + return client + + @abstractmethod + def _step_payload(self, action: ActT) -> Dict[str, Any]: + """Convert an Action object to the JSON data expected by the env server.""" + raise NotImplementedError + + @abstractmethod + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: + """Convert a JSON response from the env server to StepResult[ObsT].""" + raise NotImplementedError + + @abstractmethod + def _parse_state(self, payload: Dict[str, Any]) -> StateT: + """Convert a JSON response from the state endpoint to a State object.""" + raise NotImplementedError + + async def reset(self, **kwargs: Any) -> StepResult[ObsT]: + """ + Reset the environment with optional parameters. + + Args: + **kwargs: Optional parameters passed to the environment's reset method. 
+ Common parameters include: + - seed: Random seed for reproducibility + - episode_id: Custom episode identifier + + Returns: + StepResult containing initial observation + """ + message = { + "type": "reset", + "data": kwargs, + } + response = await self._send_and_receive(message) + return self._parse_result(response.get("data", {})) + + async def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: + """ + Execute an action in the environment. + + Args: + action: The action to execute + **kwargs: Optional parameters (currently ignored) + + Returns: + StepResult containing observation, reward, and done status + """ + message = { + "type": "step", + "data": self._step_payload(action), + } + response = await self._send_and_receive(message) + return self._parse_result(response.get("data", {})) + + async def state(self) -> StateT: + """ + Get the current environment state from the server. + + Returns: + State object with environment state information + """ + message = {"type": "state"} + response = await self._send_and_receive(message) + return self._parse_state(response.get("data", {})) + + async def close(self) -> None: + """ + Close the WebSocket connection and clean up resources. + + If this client was created via from_docker_image() or from_env(), + this will also stop and remove the associated container/process. + """ + await self.disconnect() + + if self._provider is not None: + # Handle both ContainerProvider and RuntimeProvider + if hasattr(self._provider, "stop_container"): + self._provider.stop_container() + elif hasattr(self._provider, "stop"): + self._provider.stop() + + async def __aenter__(self) -> "EnvClient": + """Enter async context manager, ensuring connection is established.""" + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit async context manager, closing connection.""" + await self.close() + + def __enter__(self) -> "EnvClient": + """Sync context manager entry - raises error suggesting async usage.""" + raise TypeError( + "EnvClient is async by default. Use 'async with' instead of 'with', " + "or call .sync() to get a synchronous wrapper:\n" + " async with client: # async usage\n" + " with client.sync(): # sync wrapper" + ) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Sync context manager exit - should not be reached.""" + pass # pragma: no cover + + def sync(self) -> "SyncEnvClient": + """ + Return a synchronous wrapper around this async client. + + Use this method when you need synchronous access to the environment + without async/await syntax. This is useful for: + - Integration with synchronous codebases + - Interactive/REPL usage + - Stopping async from "infecting" the call stack + + Returns: + SyncEnvClient wrapper that provides synchronous methods + + Example: + >>> # Create async client and get sync wrapper + >>> async_client = GenericEnvClient(base_url="http://localhost:8000") + >>> sync_client = async_client.sync() + >>> + >>> # Use synchronous API + >>> with sync_client: + ... result = sync_client.reset() + ... result = sync_client.step({"code": "print('hello')"}) + """ + from .sync_client import SyncEnvClient + + return SyncEnvClient(self) diff --git a/src/core/env_server/__init__.py b/src/core/env_server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0f1f2845f09ec758c1fcedb16dbb771059156b --- /dev/null +++ b/src/core/env_server/__init__.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Core environment interfaces and types.""" + +from .base_transforms import CompositeTransform, NullTransform +from .exceptions import ( + ConcurrencyConfigurationError, + EnvironmentFactoryError, + OpenEnvError, + SessionCapacityError, + SessionCreationError, + SessionNotFoundError, +) +from .http_server import create_app, create_fastapi_app, HTTPEnvServer +from .interfaces import Environment, Message, ModelTokenizer, Transform + +try: + from .mcp_environment import MCPEnvironment +except ModuleNotFoundError: + MCPEnvironment = None # type: ignore[assignment] + +from .mcp_types import ( + CallToolAction, + CallToolObservation, + JsonRpcError, + # JSON-RPC types + JsonRpcErrorCode, + JsonRpcRequest, + JsonRpcResponse, + ListToolsAction, + ListToolsObservation, + McpMethod, + RESERVED_TOOL_NAMES, + Tool, + ToolError, + ToolErrorType, + WSMCPMessage, + WSMCPResponse, +) +from .route_config import GetEndpointConfig +from .serialization import ( + deserialize_action, + deserialize_action_with_preprocessing, + serialize_observation, +) +from .types import ( + Action, + BaseMessage, + ConcurrencyConfig, + HealthResponse, + HealthStatus, + Observation, + SchemaResponse, + ServerCapacityStatus, + ServerMode, + SessionInfo, + State, + WSCloseMessage, + WSErrorCode, + WSErrorResponse, + WSIncomingMessage, + WSObservationResponse, + WSResetMessage, + WSStateMessage, + WSStateResponse, + WSStepMessage, +) + +try: + from .web_interface import create_web_interface_app, WebInterfaceManager +except ModuleNotFoundError: + create_web_interface_app = None # type: ignore[assignment] + WebInterfaceManager = None # type: ignore[assignment] + +__all__ = [ + # Core interfaces + "Environment", + "Transform", + "Message", + "ModelTokenizer", + # Types + "Action", + "Observation", + "State", + "SchemaResponse", + "HealthResponse", + # Enums + "HealthStatus", + "ServerMode", + "WSErrorCode", + # WebSocket message types + "BaseMessage", + "WSIncomingMessage", + "WSResetMessage", + "WSStepMessage", + "WSStateMessage", + "WSCloseMessage", + "WSObservationResponse", + "WSStateResponse", + "WSErrorResponse", + # Concurrency types + "ConcurrencyConfig", + "ServerCapacityStatus", + "SessionInfo", + # Exceptions + "OpenEnvError", + "ConcurrencyConfigurationError", + "SessionCapacityError", + "SessionNotFoundError", + "SessionCreationError", + "EnvironmentFactoryError", + # Base transforms + "CompositeTransform", + "NullTransform", + # HTTP Server + "HTTPEnvServer", + "create_app", + "create_fastapi_app", + # Web Interface + "create_web_interface_app", + "WebInterfaceManager", + # Serialization utilities + "deserialize_action", + "deserialize_action_with_preprocessing", + "serialize_observation", + # Route configuration + "GetEndpointConfig", + # MCP types + "Tool", + "ToolError", + "ToolErrorType", + "ListToolsAction", + "CallToolAction", + "ListToolsObservation", + "CallToolObservation", + "WSMCPMessage", + "WSMCPResponse", + "RESERVED_TOOL_NAMES", + "MCPEnvironment", + # JSON-RPC types + "JsonRpcErrorCode", + "JsonRpcError", + "JsonRpcRequest", + "JsonRpcResponse", + "McpMethod", +] diff --git a/src/core/env_server/base_transforms.py b/src/core/env_server/base_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..ab48ebb48b58962ff56d282713a1d63907b0f390 --- /dev/null +++ b/src/core/env_server/base_transforms.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta 
Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Base transform implementations for composing environment-specific transforms.""" + +from .interfaces import Transform +from .types import Observation + + +class CompositeTransform(Transform): + """Combines multiple transforms into a single transform.""" + + def __init__(self, transforms: list[Transform]): + self.transforms = transforms + + def __call__(self, observation: Observation) -> Observation: + for transform in self.transforms: + observation = transform(observation) + return observation + + +class NullTransform(Transform): + """Default transform that passes through unchanged.""" + + def __call__(self, observation: Observation) -> Observation: + return observation diff --git a/src/core/env_server/exceptions.py b/src/core/env_server/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..5701913e0bcac67e6f84d3861d57c4949665677a --- /dev/null +++ b/src/core/env_server/exceptions.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Custom exceptions for environment server operations.""" + +from typing import Optional + + +class OpenEnvError(Exception): + """Base exception for all OpenEnv errors.""" + + pass + + +class ConcurrencyConfigurationError(OpenEnvError): + """ + Raised when an environment is misconfigured for concurrent sessions. + + This error is raised during server startup when max_concurrent_envs > 1 + is specified for an environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. + """ + + def __init__( + self, + environment_name: str, + max_concurrent_envs: int, + message: Optional[str] = None, + ): + self.environment_name = environment_name + self.max_concurrent_envs = max_concurrent_envs + + if message is None: + message = ( + f"Environment '{environment_name}' is not marked as SUPPORTS_CONCURRENT_SESSIONS. " + f"Cannot run with max_concurrent_envs={max_concurrent_envs}. " + f"Either set max_concurrent_envs=1 or ensure the environment " + f"properly isolates session state and set SUPPORTS_CONCURRENT_SESSIONS=True." + ) + + super().__init__(message) + + +class SessionCapacityError(OpenEnvError): + """ + Raised when the server cannot accept new sessions due to capacity limits. + + This error is raised when a new WebSocket connection is attempted but + the server has already reached max_concurrent_envs active sessions. + """ + + def __init__( + self, + active_sessions: int, + max_sessions: int, + message: Optional[str] = None, + ): + self.active_sessions = active_sessions + self.max_sessions = max_sessions + + if message is None: + message = ( + f"Server at capacity: {active_sessions}/{max_sessions} sessions active. " + f"Cannot accept new connections." + ) + + super().__init__(message) + + +class SessionNotFoundError(OpenEnvError): + """Raised when attempting to access a session that does not exist.""" + + def __init__(self, session_id: str, message: Optional[str] = None): + self.session_id = session_id + + if message is None: + message = f"Session '{session_id}' not found." 
+ + super().__init__(message) + + +class SessionCreationError(OpenEnvError): + """Raised when a session cannot be created.""" + + def __init__(self, reason: str, message: Optional[str] = None): + self.reason = reason + + if message is None: + message = f"Failed to create session: {reason}" + + super().__init__(message) + + +class EnvironmentFactoryError(OpenEnvError): + """Raised when the environment factory fails to create an instance.""" + + def __init__(self, factory_name: str, message: Optional[str] = None): + self.factory_name = factory_name + + if message is None: + message = f"Environment factory '{factory_name}' failed to create instance." + + super().__init__(message) diff --git a/src/core/env_server/gradio_theme.py b/src/core/env_server/gradio_theme.py new file mode 100644 index 0000000000000000000000000000000000000000..7cebea2284d8d19e41d5954b498bcc3bb7ff39a4 --- /dev/null +++ b/src/core/env_server/gradio_theme.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unified terminal-style theme for OpenEnv Gradio UI (light/dark).""" + +from __future__ import annotations + +import gradio as gr + +_MONO_FONTS = ( + "JetBrains Mono", + "Fira Code", + "Cascadia Code", + "Consolas", + "ui-monospace", + "monospace", +) + +_CORE_FONT = ( + "Lato", + "Inter", + "Arial", + "Helvetica", + "sans-serif", +) + +_ZERO_RADIUS = gr.themes.Size( + xxs="0px", + xs="0px", + sm="0px", + md="0px", + lg="0px", + xl="0px", + xxl="0px", +) + +_GREEN_HUE = gr.themes.Color( + c50="#e6f4ea", + c100="#ceead6", + c200="#a8dab5", + c300="#6fcc8b", + c400="#3fb950", + c500="#238636", + c600="#1a7f37", + c700="#116329", + c800="#0a4620", + c900="#033a16", + c950="#04200d", +) + +_NEUTRAL_HUE = gr.themes.Color( + c50="#f6f8fa", + c100="#eaeef2", + c200="#d0d7de", + c300="#afb8c1", + c400="#8c959f", + c500="#6e7781", + c600="#57606a", + c700="#424a53", + c800="#32383f", + c900="#24292f", + c950="#1b1f24", +) + +OPENENV_GRADIO_THEME = gr.themes.Base( + primary_hue=_GREEN_HUE, + secondary_hue=_NEUTRAL_HUE, + neutral_hue=_NEUTRAL_HUE, + font=_CORE_FONT, + font_mono=_MONO_FONTS, + radius_size=_ZERO_RADIUS, +).set( + body_background_fill="#ffffff", + background_fill_primary="#ffffff", + background_fill_secondary="#f6f8fa", + block_background_fill="#ffffff", + block_border_color="#ffffff", + block_label_text_color="#57606a", + block_title_text_color="#24292f", + border_color_primary="#d0d7de", + input_background_fill="#ffffff", + input_border_color="#d0d7de", + button_primary_background_fill="#1a7f37", + button_primary_background_fill_hover="#116329", + button_primary_text_color="#ffffff", + button_secondary_background_fill="#f6f8fa", + button_secondary_background_fill_hover="#eaeef2", + button_secondary_text_color="#24292f", + button_secondary_border_color="#d0d7de", + body_background_fill_dark="#0d1117", + background_fill_primary_dark="#0d1117", + background_fill_secondary_dark="#0d1117", + block_background_fill_dark="#0d1117", + block_border_color_dark="#0d1117", + block_label_text_color_dark="#8b949e", + block_title_text_color_dark="#c9d1d9", + border_color_primary_dark="#30363d", + input_background_fill_dark="#0d1117", + input_border_color_dark="#30363d", + button_primary_background_fill_dark="#30363d", + button_primary_background_fill_hover_dark="#484f58", + button_primary_text_color_dark="#c9d1d9", + 
button_secondary_background_fill_dark="#21262d", + button_secondary_background_fill_hover_dark="#30363d", + button_secondary_text_color_dark="#c9d1d9", + button_secondary_border_color_dark="#30363d", +) + +OPENENV_GRADIO_CSS = """ +* { border-radius: 0 !important; } +.col-left { padding: 16px !important; } +.col-right { padding: 16px !important; } +.prose, .markdown-text, .md, +.prose > *, .markdown-text > * { + background: transparent !important; + border: none !important; + box-shadow: none !important; +} +.dark .col-left { + border-left-color: rgba(139, 148, 158, 0.4) !important; +} +.dark .col-right { + border-left-color: rgba(201, 209, 217, 0.3) !important; +} +""" diff --git a/src/core/env_server/gradio_ui.py b/src/core/env_server/gradio_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..dc1a630bd1db39588304b42520f08bb45f477e81 --- /dev/null +++ b/src/core/env_server/gradio_ui.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Gradio-based web UI for OpenEnv environments. + +Replaces the legacy HTML/JavaScript interface when ENABLE_WEB_INTERFACE is set. +Mount at /web via gr.mount_gradio_app() from create_web_interface_app(). +""" + +from __future__ import annotations + +import json +import re +from typing import Any, Dict, List, Optional + +import gradio as gr + +from .types import EnvironmentMetadata + + +def _escape_md(text: str) -> str: + """Escape Markdown special characters in user-controlled content.""" + return re.sub(r"([\\`*_\{\}\[\]()#+\-.!|~>])", r"\\\1", str(text)) + + +def _format_observation(data: Dict[str, Any]) -> str: + """Format reset/step response for Markdown display.""" + lines: List[str] = [] + obs = data.get("observation", {}) + if isinstance(obs, dict): + if obs.get("prompt"): + lines.append(f"**Prompt:**\n\n{_escape_md(obs['prompt'])}\n") + messages = obs.get("messages", []) + if messages: + lines.append("**Messages:**\n") + for msg in messages: + sender = _escape_md(str(msg.get("sender_id", "?"))) + content = _escape_md(str(msg.get("content", ""))) + cat = _escape_md(str(msg.get("category", ""))) + lines.append(f"- `[{cat}]` Player {sender}: {content}") + lines.append("") + reward = data.get("reward") + done = data.get("done") + if reward is not None: + lines.append(f"**Reward:** `{reward}`") + if done is not None: + lines.append(f"**Done:** `{done}`") + return "\n".join(lines) if lines else "*No observation data*" + + +def _readme_section(metadata: Optional[EnvironmentMetadata]) -> str: + """README content for the left panel.""" + if not metadata or not metadata.readme_content: + return "*No README available.*" + return metadata.readme_content + + +def get_gradio_display_title( + metadata: Optional[EnvironmentMetadata], + fallback: str = "OpenEnv Environment", +) -> str: + """Return the title used for the Gradio app (browser tab and Blocks).""" + name = metadata.name if metadata else fallback + return f"OpenEnv Agentic Environment: {name}" + + +def build_gradio_app( + web_manager: Any, + action_fields: List[Dict[str, Any]], + metadata: Optional[EnvironmentMetadata], + is_chat_env: bool, + title: str = "OpenEnv Environment", + quick_start_md: Optional[str] = None, +) -> gr.Blocks: + """ + Build a Gradio Blocks app for the OpenEnv web interface. + + Args: + web_manager: WebInterfaceManager (reset/step_environment, get_state). 
+ action_fields: Field dicts from _extract_action_fields(action_cls). + metadata: Environment metadata for README/name. + is_chat_env: If True, single message textbox; else form from action_fields. + title: App title (overridden by metadata.name when present; see get_gradio_display_title). + quick_start_md: Optional Quick Start markdown (class names already replaced). + + Returns: + gr.Blocks to mount with gr.mount_gradio_app(app, blocks, path="/web"). + """ + readme_content = _readme_section(metadata) + display_title = get_gradio_display_title(metadata, fallback=title) + + async def reset_env(): + try: + data = await web_manager.reset_environment() + obs_md = _format_observation(data) + return ( + obs_md, + json.dumps(data, indent=2), + "Environment reset successfully.", + ) + except Exception as e: + return ("", "", f"Error: {e}") + + def _step_with_action(action_data: Dict[str, Any]): + async def _run(): + try: + data = await web_manager.step_environment(action_data) + obs_md = _format_observation(data) + return ( + obs_md, + json.dumps(data, indent=2), + "Step complete.", + ) + except Exception as e: + return ("", "", f"Error: {e}") + + return _run + + async def step_chat(message: str): + if not (message or str(message).strip()): + return ("", "", "Please enter an action message.") + action = {"message": str(message).strip()} + return await _step_with_action(action)() + + def get_state_sync(): + try: + data = web_manager.get_state() + return json.dumps(data, indent=2) + except Exception as e: + return f"Error: {e}" + + with gr.Blocks(title=display_title) as demo: + with gr.Row(): + with gr.Column(scale=1, elem_classes="col-left"): + if quick_start_md: + with gr.Accordion("Quick Start", open=True): + gr.Markdown(quick_start_md) + with gr.Accordion("README", open=False): + gr.Markdown(readme_content) + + with gr.Column(scale=2, elem_classes="col-right"): + obs_display = gr.Markdown( + value=("# Playground\n\nClick **Reset** to start a new episode."), + ) + with gr.Group(): + if is_chat_env: + action_input = gr.Textbox( + label="Action message", + placeholder="e.g. 
Enter your message...", + ) + step_inputs = [action_input] + step_fn = step_chat + else: + step_inputs = [] + for field in action_fields: + name = field["name"] + field_type = field.get("type", "text") + label = name.replace("_", " ").title() + placeholder = field.get("placeholder", "") + if field_type == "checkbox": + inp = gr.Checkbox(label=label) + elif field_type == "number": + inp = gr.Number(label=label) + elif field_type == "select": + choices = field.get("choices") or [] + inp = gr.Dropdown( + choices=choices, + label=label, + allow_custom_value=False, + ) + elif field_type in ("textarea", "tensor"): + inp = gr.Textbox( + label=label, + placeholder=placeholder, + lines=3, + ) + else: + inp = gr.Textbox( + label=label, + placeholder=placeholder, + ) + step_inputs.append(inp) + + async def step_form(*values): + if not action_fields: + return await _step_with_action({})() + action_data = {} + for i, field in enumerate(action_fields): + if i >= len(values): + break + name = field["name"] + val = values[i] + if field.get("type") == "checkbox": + action_data[name] = bool(val) + elif val is not None and val != "": + action_data[name] = val + return await _step_with_action(action_data)() + + step_fn = step_form + + with gr.Row(): + step_btn = gr.Button("Step", variant="primary") + reset_btn = gr.Button("Reset", variant="secondary") + state_btn = gr.Button("Get state", variant="secondary") + with gr.Row(): + status = gr.Textbox( + label="Status", + interactive=False, + ) + raw_json = gr.Code( + label="Raw JSON response", + language="json", + interactive=False, + ) + + reset_btn.click( + fn=reset_env, + outputs=[obs_display, raw_json, status], + ) + step_btn.click( + fn=step_fn, + inputs=step_inputs, + outputs=[obs_display, raw_json, status], + ) + if is_chat_env: + action_input.submit( + fn=step_fn, + inputs=step_inputs, + outputs=[obs_display, raw_json, status], + ) + state_btn.click( + fn=get_state_sync, + outputs=[raw_json], + ) + + return demo diff --git a/src/core/env_server/http_server.py b/src/core/env_server/http_server.py new file mode 100644 index 0000000000000000000000000000000000000000..658f63ef98bf78d278b8926271c217da23c79a37 --- /dev/null +++ b/src/core/env_server/http_server.py @@ -0,0 +1,1391 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +HTTP server wrapper for Environment instances. + +This module provides utilities to wrap any Environment subclass and expose it +over HTTP and WebSocket endpoints that EnvClient can consume. 
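+
+For reference, register_routes() exposes the following endpoints in simulation
+mode (production mode omits the /reset, /step, and /state control routes):
+
+- POST /reset    - reset and return the initial observation
+- POST /step     - execute an action; returns observation, reward, done
+- GET  /state    - current environment state
+- GET  /health   - health check
+- GET  /metadata - environment metadata
+- GET  /schema   - combined action/observation/state JSON schemas
+- POST /mcp      - MCP JSON-RPC endpoint (tools/list, tools/call)
+- WS   /mcp      - MCP JSON-RPC over a per-session WebSocket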
+""" + +from __future__ import annotations + +import asyncio +import inspect +import json +import os +import time +import uuid +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Callable, Dict, Optional, Type + +from fastapi import ( + Body, + FastAPI, + HTTPException, + Request, + status, + WebSocket, + WebSocketDisconnect, +) +from pydantic import ValidationError + +from .interfaces import Environment +from .mcp_environment import get_server_tools +from .mcp_types import ( + JsonRpcErrorCode, + JsonRpcRequest, + JsonRpcResponse, + McpMethod, + WSMCPMessage, + WSMCPResponse, +) +from .route_config import GetEndpointConfig, register_get_endpoints +from .serialization import deserialize_action, serialize_observation +from .types import ( + Action, + ConcurrencyConfig, + EnvironmentMetadata, + HealthResponse, + HealthStatus, + Observation, + ResetRequest, + ResetResponse, + SchemaResponse, + ServerCapacityStatus, + ServerMode, + SessionInfo, + State, + StepRequest, + StepResponse, + WSCloseMessage, + WSErrorCode, + WSErrorResponse, + WSObservationResponse, + WSResetMessage, + WSStateMessage, + WSStateResponse, + WSStepMessage, +) + + +def _make_json_serializable(obj: Any) -> Any: + """ + Convert an object to a JSON-serializable form. + + Handles Pydantic models, dataclasses, and other common types. + + Args: + obj: The object to convert + + Returns: + A JSON-serializable representation of the object + """ + if obj is None: + return None + if isinstance(obj, (str, int, float, bool)): + return obj + if isinstance(obj, (list, tuple)): + return [_make_json_serializable(item) for item in obj] + if isinstance(obj, dict): + return {k: _make_json_serializable(v) for k, v in obj.items()} + if hasattr(obj, "model_dump"): + # Pydantic model + return obj.model_dump() + if hasattr(obj, "__dict__"): + # Object with __dict__ + return {k: _make_json_serializable(v) for k, v in obj.__dict__.items()} + # Fallback to string representation + return str(obj) + + +from .exceptions import ( + ConcurrencyConfigurationError, + EnvironmentFactoryError, + SessionCapacityError, +) + + +class HTTPEnvServer: + """ + HTTP server wrapper for Environment instances. + + This class wraps an Environment and exposes its reset(), step(), and state + methods as HTTP and WebSocket endpoints compatible with EnvClient. + + The server expects: + - Action deserialization: Converts JSON dict to Action subclass + - Observation serialization: Converts Observation subclass to JSON dict + + Example: + >>> from core.env_server import HTTPEnvServer + >>> from envs.coding_env.server import CodeExecutionEnvironment + >>> from envs.coding_env.models import CodeAction, CodeObservation + >>> + >>> # Pass environment class (factory pattern) + >>> server = HTTPEnvServer( + ... env=CodeExecutionEnvironment, + ... action_cls=CodeAction, + ... observation_cls=CodeObservation, + ... max_concurrent_envs=4, + ... ) + >>> + >>> # Register routes with FastAPI + >>> from fastapi import FastAPI + >>> app = FastAPI() + >>> server.register_routes(app) + """ + + def __init__( + self, + env: Callable[[], Environment], + action_cls: Type[Action], + observation_cls: Type[Observation], + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, + ): + """ + Initialize HTTP server wrapper. + + Args: + env: Environment factory (callable) that creates new instances. + Will be called to create a new environment for each WebSocket session. 
+ action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + max_concurrent_envs: Maximum number of concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. + + Raises: + ValueError: If both max_concurrent_envs and concurrency_config are provided. + ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an + environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. + """ + # Validate that env is callable + if not callable(env): + raise TypeError( + f"env must be a callable (class or factory function), got {type(env)}. " + f"Pass the environment class (e.g., MyEnvironment) not an instance (e.g., MyEnvironment())." + ) + + self._env_factory: Callable[[], Environment] = env + + # Handle concurrency configuration + if max_concurrent_envs is not None and concurrency_config is not None: + raise ValueError( + "Cannot specify both 'max_concurrent_envs' and 'concurrency_config'. " + "Please use only one method to configure concurrency." + ) + + if concurrency_config is not None: + self._concurrency_config = concurrency_config + elif max_concurrent_envs is not None: + self._concurrency_config = ConcurrencyConfig( + max_concurrent_envs=max_concurrent_envs, + session_timeout=None, + ) + else: + # Default configuration + self._concurrency_config = ConcurrencyConfig( + max_concurrent_envs=1, + session_timeout=None, + ) + + self._max_concurrent_envs = self._concurrency_config.max_concurrent_envs + + # Validate concurrency configuration + self._validate_concurrency_safety() + + self.action_cls = action_cls + self.observation_cls = observation_cls + + # Session management for WebSocket connections + self._sessions: Dict[str, Environment] = {} + self._session_executors: Dict[str, ThreadPoolExecutor] = {} + self._session_info: Dict[str, SessionInfo] = {} + self._session_lock = asyncio.Lock() + + # Create thread pool for running sync code in async context + # This is needed for environments using sync libraries (e.g., Playwright) + self._executor = ThreadPoolExecutor(max_workers=32) + + def _validate_concurrency_safety(self) -> None: + """ + Validate that the environment supports the configured concurrency level. + + Raises: + ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an + environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. + """ + if self._max_concurrent_envs <= 1: + return + + if inspect.isclass(self._env_factory): + env_cls = self._env_factory + else: + _temp_env = self._env_factory() + env_cls = type(_temp_env) + _temp_env.close() + del _temp_env + + if not getattr(env_cls, "SUPPORTS_CONCURRENT_SESSIONS", False): + raise ConcurrencyConfigurationError( + environment_name=env_cls.__name__, + max_concurrent_envs=self._max_concurrent_envs, + ) + + def get_capacity_status(self) -> ServerCapacityStatus: + """ + Get the current capacity status of the server. + + Returns: + ServerCapacityStatus with current session counts and availability. 
+ """ + return ServerCapacityStatus.from_counts( + active=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) + + async def _run_sync_in_thread_pool( + self, func: Callable[..., Observation], *args, **kwargs + ) -> Observation: + """Run a synchronous function in the thread pool executor.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs)) + + def _get_valid_kwargs( + self, + sig: inspect.Signature, + kwargs: Dict[str, Any], + skip_params: Optional[set[str]] = None, + ) -> Dict[str, Any]: + """Filter kwargs to only include parameters accepted by the function signature.""" + if skip_params is None: + skip_params = set() + + valid_kwargs = {} + + has_kwargs = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + + for k, v in kwargs.items(): + if k in sig.parameters or has_kwargs: + if k not in skip_params: + valid_kwargs[k] = v + + return valid_kwargs + + async def _create_session(self) -> tuple[str, Environment]: + """ + Create a new WebSocket session with its own environment instance. + + Returns: + Tuple of (session_id, environment) + + Raises: + SessionCapacityError: If max concurrent sessions reached + EnvironmentFactoryError: If the factory fails to create an environment + """ + async with self._session_lock: + if len(self._sessions) >= self._max_concurrent_envs: + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) + + session_id = str(uuid.uuid4()) + current_time = time.time() + + # Create executor and reserve slot so capacity is not exceeded while + # we create the env outside the lock (avoids blocking other sessions) + executor = ThreadPoolExecutor(max_workers=1) + self._session_executors[session_id] = executor + self._sessions[session_id] = None # placeholder until env is ready + + try: + # Create environment in the executor thread (outside lock) + loop = asyncio.get_event_loop() + env = await loop.run_in_executor(executor, self._env_factory) + except Exception as e: + async with self._session_lock: + executor.shutdown(wait=False) + self._session_executors.pop(session_id, None) + self._sessions.pop(session_id, None) + factory_name = getattr( + self._env_factory, "__name__", str(self._env_factory) + ) + raise EnvironmentFactoryError(factory_name) from e + + async with self._session_lock: + self._sessions[session_id] = env + self._session_info[session_id] = SessionInfo( + session_id=session_id, + created_at=current_time, + last_activity_at=current_time, + step_count=0, + environment_type=type(env).__name__, + ) + + return session_id, env + + async def _destroy_session(self, session_id: str) -> None: + """ + Destroy a WebSocket session and cleanup resources. 
+ + Args: + session_id: The session ID to destroy + """ + async with self._session_lock: + env = self._sessions.pop(session_id, None) + executor = self._session_executors.pop(session_id, None) + self._session_info.pop(session_id, None) + + # Run close() in the same executor where the env was created + # This is required for thread-sensitive libraries like Playwright/greenlet + if env is not None: + if executor is not None: + try: + loop = asyncio.get_event_loop() + await loop.run_in_executor(executor, env.close) + except Exception: + # If executor close fails, try direct close as fallback + try: + env.close() + except Exception: + pass # Best effort cleanup + else: + try: + env.close() + except Exception: + pass # Best effort cleanup + + # Shutdown executor after close is done + if executor is not None: + executor.shutdown(wait=False) + + def _update_session_activity( + self, session_id: str, increment_step: bool = False + ) -> None: + """ + Update session activity timestamp and optionally increment step count. + + Args: + session_id: The session ID to update + increment_step: If True, increment the step count + """ + if session_id in self._session_info: + self._session_info[session_id].last_activity_at = time.time() + if increment_step: + self._session_info[session_id].step_count += 1 + + def get_session_info(self, session_id: str) -> Optional[SessionInfo]: + """ + Get information about a specific session. + + Args: + session_id: The session ID to query + + Returns: + SessionInfo if the session exists, None otherwise + """ + return self._session_info.get(session_id) + + async def _run_in_session_executor( + self, session_id: str, func: Callable[..., Observation], *args, **kwargs + ) -> Observation: + """Run a synchronous function in the session's thread pool executor.""" + executor = self._session_executors.get(session_id, self._executor) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(executor, lambda: func(*args, **kwargs)) + + @property + def active_sessions(self) -> int: + """Return the number of active WebSocket sessions.""" + return len(self._sessions) + + @property + def max_concurrent_envs(self) -> int: + """Return the maximum number of concurrent environments.""" + return self._max_concurrent_envs + + @property + def is_concurrency_safe(self) -> bool: + """Return whether the environment is marked as concurrency safe.""" + import inspect + + if inspect.isclass(self._env_factory): + return getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False) + else: + _temp_env = self._env_factory() + result = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False) + _temp_env.close() + del _temp_env + return result + + @property + def concurrency_config(self) -> ConcurrencyConfig: + """Return the concurrency configuration.""" + return self._concurrency_config + + def register_routes( + self, app: FastAPI, mode: ServerMode | str = ServerMode.SIMULATION + ) -> None: + """ + Register HTTP routes on a FastAPI application. + + Args: + app: FastAPI application instance + mode: Server mode - either SIMULATION or PRODUCTION (or string equivalents). + In production mode, simulation control endpoints (/reset, /step, /state) + are NOT registered. Only safe endpoints (/health, /schema, /metadata, /ws) + are available. Defaults to SIMULATION for backwards compatibility. + + Raises: + ValueError: If mode is not a valid ServerMode or string equivalent. 
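+
+        Example:
+            >>> app = FastAPI()
+            >>> server.register_routes(app, mode="production")  # omits /reset, /step, /state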
+ """ + # Convert string to ServerMode enum for backwards compatibility + if isinstance(mode, str): + try: + mode = ServerMode(mode.lower()) + except ValueError: + valid_modes = [m.value for m in ServerMode] + raise ValueError( + f"Invalid mode: '{mode}'. Must be one of: {valid_modes}" + ) + + # Helper function to handle reset endpoint + async def reset_handler( + request: ResetRequest = Body(default_factory=ResetRequest), + ) -> ResetResponse: + """Reset endpoint - returns initial observation.""" + _env = self._env_factory() + + try: + kwargs = request.model_dump(exclude_unset=True) + + is_async = _env.reset_async.__func__ is not Environment.reset_async + + if is_async: + sig = inspect.signature(_env.reset_async) + else: + sig = inspect.signature(_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, kwargs) + + if is_async: + observation = await _env.reset_async(**valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool( + _env.reset, **valid_kwargs + ) + return ResetResponse(**serialize_observation(observation)) + finally: + _env.close() + + # Helper function to handle step endpoint + async def step_handler(request: StepRequest) -> StepResponse: + """Step endpoint - executes action and returns observation.""" + action_data = request.action + + try: + action = deserialize_action(action_data, self.action_cls) + except ValidationError as e: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors() + ) + + _env = self._env_factory() + + try: + kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) + + is_async = _env.step_async.__func__ is not Environment.step_async + + if is_async: + sig = inspect.signature(_env.step_async) + else: + sig = inspect.signature(_env.step) + valid_kwargs = self._get_valid_kwargs( + sig, kwargs, skip_params={"action"} + ) + + if is_async: + observation = await _env.step_async(action, **valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool( + _env.step, action, **valid_kwargs + ) + + return StepResponse(**serialize_observation(observation)) + finally: + _env.close() + + # Helper function to handle MCP endpoint + async def mcp_handler( + request: JsonRpcRequest, session_env: Optional[Environment] = None + ) -> JsonRpcResponse: + """ + Handle MCP JSON-RPC requests. + + Supports tools/list and tools/call methods in JSON-RPC 2.0 format. 
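+
+        A tools/call exchange looks roughly like this (tool name and
+        arguments are illustrative):
+
+            -> {"jsonrpc": "2.0", "id": 1, "method": "tools/call",
+                "params": {"name": "echo", "arguments": {"message": "hi"}}}
+            <- {"jsonrpc": "2.0", "id": 1, "result": {...}}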
+ """ + method = request.method + request_id = request.id + + # Use provided session environment or create temporary one + if session_env is not None: + _env = session_env + should_close = False + else: + _env = self._env_factory() + should_close = True + try: + if method == McpMethod.TOOLS_LIST: + # Check if environment is MCP-enabled + if not hasattr(_env, "mcp_client"): + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "Environment does not support MCP", + request_id=request_id, + ) + + # Use async context manager for MCP client + async with _env.mcp_client: + tools = await _env.mcp_client.list_tools() + + return JsonRpcResponse.success( + result={ + "tools": [ + t.model_dump() if hasattr(t, "model_dump") else dict(t) + for t in tools + ] + }, + request_id=request_id, + ) + + elif method == McpMethod.TOOLS_CALL: + params = request.params + tool_name = params.get("name") + arguments = params.get("arguments", {}) + + if not hasattr(_env, "mcp_client"): + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "Environment does not support MCP", + request_id=request_id, + ) + + if not tool_name: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + "Missing 'name' in params", + request_id=request_id, + ) + + # Use async context manager for MCP client + async with _env.mcp_client: + result = await _env.mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + + # Ensure result is JSON serializable + serializable_result = _make_json_serializable(result) + + return JsonRpcResponse.success( + result=serializable_result, + request_id=request_id, + ) + + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.METHOD_NOT_FOUND, + f"Method not found: {method}", + request_id=request_id, + ) + + except Exception as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + request_id=request_id, + ) + finally: + if should_close: + _env.close() + + # Register MCP WebSocket endpoint (available in both production and simulation modes) + @app.websocket("/mcp") + async def mcp_websocket_endpoint(websocket: WebSocket): + """ + WebSocket endpoint for MCP JSON-RPC requests. + + Each WebSocket connection gets its own environment instance for MCP operations. 
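+
+        If the server is already at max_concurrent_envs active sessions, the
+        connection receives a JSON-RPC server error that includes the
+        active/max session counts, and no session is created.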
+ + Message Protocol: + - Client sends: JSON-RPC 2.0 request (tools/list, tools/call) + - Server responds: JSON-RPC 2.0 response (result or error) + """ + await websocket.accept() + + session_id = None + session_env = None + + try: + # Create session with dedicated environment + session_id, session_env = await self._create_session() + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + jsonrpc_dict = json.loads(raw_message) + jsonrpc_request = JsonRpcRequest(**jsonrpc_dict) + except json.JSONDecodeError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.PARSE_ERROR, + f"Parse error: {e}", + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + except ValidationError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Invalid request: {e}", + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + try: + # Call mcp_handler with session environment + response = await mcp_handler( + jsonrpc_request, session_env=session_env + ) + await websocket.send_text(response.model_dump_json()) + except Exception as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + request_id=jsonrpc_request.id, + ) + await websocket.send_text(error_resp.model_dump_json()) + + except WebSocketDisconnect: + pass + except SessionCapacityError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + data={ + "active_sessions": e.active_sessions, + "max_sessions": e.max_sessions, + }, + ) + await websocket.send_text(error_resp.model_dump_json()) + except EnvironmentFactoryError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + data={"factory_name": e.factory_name}, + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + ) + await websocket.send_text(error_resp.model_dump_json()) + finally: + if session_id: + await self._destroy_session(session_id) + try: + await websocket.close() + except RuntimeError: + pass + + # Register simulation control routes only in simulation mode + if mode == ServerMode.SIMULATION: + + @app.post( + "/reset", + response_model=ResetResponse, + tags=["Environment Control"], + summary="Reset the environment", + description=""" +Reset the environment to its initial state and return the first observation. + +You can optionally provide a seed for reproducibility and an episode_id for tracking. + """, + responses={ + 200: { + "description": "Environment reset successfully", + "content": { + "application/json": { + "example": { + "observation": {"status": "ready", "data": {}}, + "reward": None, + "done": False, + } + } + }, + } + }, + ) + async def reset( + request: ResetRequest = Body(default_factory=ResetRequest), + ) -> ResetResponse: + return await reset_handler(request) + + @app.post( + "/step", + response_model=StepResponse, + tags=["Environment Control"], + summary="Execute an action in the environment", + description=""" +Execute an action in the environment and receive the resulting observation. + +The action must conform to the environment's action schema, which can be +retrieved from the `/schema` endpoint. If the action is invalid, +the endpoint will return HTTP 422 with detailed validation errors. 
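+
+A request body has the shape below; the fields inside `action` are defined by
+the environment's action schema (this example assumes an echo-style action):
+
+```json
+{"action": {"message": "hello"}}
+```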
+ +The response includes: +- **observation**: The environment's response to the action +- **reward**: Optional reward signal (float or None) +- **done**: Boolean indicating if the episode has terminated + """, + responses={ + 200: { + "description": "Action executed successfully", + "content": { + "application/json": { + "example": { + "observation": {"status": "success", "data": {}}, + "reward": 1.0, + "done": False, + } + } + }, + }, + 422: { + "description": "Validation error - invalid action format or values", + "content": { + "application/json": { + "example": { + "detail": [ + { + "type": "string_too_short", + "loc": ["body", "action", "message"], + "msg": "String should have at least 1 character", + "input": "", + } + ] + } + } + }, + }, + 500: { + "description": "Internal server error during action execution" + }, + }, + ) + async def step(request: StepRequest) -> StepResponse: + return await step_handler(request) + + def get_state_handler() -> State: + _env = self._env_factory() + try: + return _env.state + finally: + _env.close() + + def get_metadata_handler() -> EnvironmentMetadata: + _env = self._env_factory() + try: + return _env.get_metadata() + finally: + _env.close() + + # Build list of GET endpoints based on mode + get_endpoints = [ + GetEndpointConfig( + path="/metadata", + handler=get_metadata_handler, + response_model=EnvironmentMetadata, + tag="Environment Info", + summary="Get environment metadata", + description=""" +Get metadata about this environment. + +Returns information about the environment including name, description, +version, author, and documentation links. + """, + ), + GetEndpointConfig( + path="/health", + handler=lambda: HealthResponse(status=HealthStatus.HEALTHY), + response_model=HealthResponse, + tag="Health", + summary="Health check", + description="Check if the environment server is running and healthy.", + ), + ] + + # Only register /state endpoint in simulation mode + if mode == ServerMode.SIMULATION: + get_endpoints.insert( + 0, + GetEndpointConfig( + path="/state", + handler=get_state_handler, + response_model=State, + tag="State Management", + summary="Get current environment state", + description=""" +Retrieve the current internal state of the environment. + +The structure of the state object is defined by the environment's State model. + """, + ), + ) + + register_get_endpoints(app, get_endpoints) + + # Register combined schema endpoint + @app.get( + "/schema", + response_model=SchemaResponse, + tags=["Schema"], + summary="Get all JSON schemas", + description=""" +Get JSON schemas for actions, observations, and state in a single response. + +Returns a combined schema object containing: +- **action**: JSON schema for actions accepted by this environment +- **observation**: JSON schema for observations returned by this environment +- **state**: JSON schema for environment state objects + +This is more efficient than calling individual schema endpoints and provides +all schema information needed to interact with the environment. 
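+
+For example, a client might fetch the schemas once at startup (a hypothetical
+snippet using the `requests` library; `base_url` is the server's address):
+
+    import requests
+
+    schemas = requests.get(f"{base_url}/schema").json()
+    action_schema = schemas["action"]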
+ """, + responses={ + 200: { + "description": "Combined schemas retrieved successfully", + "content": { + "application/json": { + "example": { + "action": { + "type": "object", + "properties": {"message": {"type": "string"}}, + }, + "observation": { + "type": "object", + "properties": {"response": {"type": "string"}}, + }, + "state": { + "type": "object", + "properties": {"step_count": {"type": "integer"}}, + }, + } + } + }, + } + }, + ) + async def get_schemas() -> SchemaResponse: + """Return all schemas in one response.""" + return SchemaResponse( + action=self.action_cls.model_json_schema(), + observation=self.observation_cls.model_json_schema(), + state=State.model_json_schema(), + ) + + # Register MCP endpoint for production mode (direct MCP access) + @app.post("/mcp") + async def mcp_endpoint(request_raw: Request) -> Dict[str, Any]: + """ + MCP JSON-RPC endpoint for production mode. + + Bypasses step() overhead and provides direct access to MCP tools. + Supports tools/list and tools/call methods. + """ + # Parse JSON manually to handle parse errors gracefully + try: + body = await request_raw.body() + request_dict = json.loads(body) + request = JsonRpcRequest(**request_dict) + except json.JSONDecodeError: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.PARSE_ERROR + ).model_dump() + except ValidationError as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Invalid request: {e}", + ).model_dump() + except Exception: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.PARSE_ERROR + ).model_dump() + + method = request.method + params = request.params + request_id = request.id + + # Create a temporary environment for MCP access + _env = self._env_factory() + + try: + # Check if environment supports MCP + if not hasattr(_env, "mcp_client") and not hasattr(_env, "mcp_server"): + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "Environment does not support MCP", + request_id=request_id, + ).model_dump() + + if method == McpMethod.TOOLS_LIST: + # List tools from MCP server + if hasattr(_env, "mcp_client") and _env.mcp_client: + async with _env.mcp_client: + tools = await _env.mcp_client.list_tools() + return JsonRpcResponse.success( + result={ + "tools": [ + t.model_dump() + if hasattr(t, "model_dump") + else dict(t) + for t in tools + ] + }, + request_id=request_id, + ).model_dump() + elif hasattr(_env, "mcp_server") and _env.mcp_server: + # Use server directly + tools = [] + for tool_name, tool in get_server_tools( + _env.mcp_server + ).items(): + tool_dict = { + "name": tool.name, + "description": tool.description or "", + "inputSchema": tool.parameters or {}, + } + tools.append(tool_dict) + return JsonRpcResponse.success( + result={"tools": tools}, + request_id=request_id, + ).model_dump() + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "MCP server not available", + request_id=request_id, + ).model_dump() + + elif method == McpMethod.TOOLS_CALL: + tool_name = params.get("name") + arguments = params.get("arguments", {}) + + if not tool_name: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + "Invalid params - 'name' is required", + request_id=request_id, + ).model_dump() + + # Call tool via MCP + if hasattr(_env, "mcp_client") and _env.mcp_client: + async with _env.mcp_client: + result = await _env.mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + elif hasattr(_env, "mcp_server") and _env.mcp_server: + # Call tool directly on 
FastMCP server + server_tools = get_server_tools(_env.mcp_server) + if tool_name in server_tools: + tool = server_tools[tool_name] + result = tool.fn(**arguments) + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + f"Tool not found: {tool_name}", + request_id=request_id, + ).model_dump() + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "MCP server not available", + request_id=request_id, + ).model_dump() + + # Make result JSON serializable + serializable_result = _make_json_serializable(result) + + return JsonRpcResponse.success( + result=serializable_result, + request_id=request_id, + ).model_dump() + + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.METHOD_NOT_FOUND, + f"Method not found: {method}", + request_id=request_id, + ).model_dump() + + except Exception as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + request_id=request_id, + ).model_dump() + finally: + _env.close() + + # Register WebSocket endpoint for persistent sessions + @app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket): + """ + WebSocket endpoint for persistent environment sessions. + + Each WebSocket connection gets its own environment instance. + + Message Protocol: + - Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage + - Server responds: WSObservationResponse | WSStateResponse | WSErrorResponse + """ + await websocket.accept() + + session_id = None + session_env = None + + try: + # Create session with dedicated environment + session_id, session_env = await self._create_session() + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + message_dict = json.loads(raw_message) + except json.JSONDecodeError as e: + error_resp = WSErrorResponse( + data={ + "message": f"Invalid JSON: {e}", + "code": WSErrorCode.INVALID_JSON, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + msg_type = message_dict.get("type", "") + + try: + match msg_type: + case "reset": + msg = WSResetMessage(**message_dict) + + is_async = ( + session_env.reset_async.__func__ + is not Environment.reset_async + ) + + if is_async: + sig = inspect.signature(session_env.reset_async) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await session_env.reset_async( + **valid_kwargs + ) + else: + sig = inspect.signature(session_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await self._run_in_session_executor( + session_id, session_env.reset, **valid_kwargs + ) + + self._update_session_activity(session_id) + + response = WSObservationResponse( + data=serialize_observation(observation), + ) + + case "step": + msg = WSStepMessage(**message_dict) + action = deserialize_action(msg.data, self.action_cls) + + is_async = ( + session_env.step_async.__func__ + is not Environment.step_async + ) + + if is_async: + observation = await session_env.step_async(action) + else: + observation = await self._run_in_session_executor( + session_id, session_env.step, action + ) + + self._update_session_activity( + session_id, increment_step=True + ) + + response = WSObservationResponse( + data=serialize_observation(observation) + ) + + case "state": + msg = WSStateMessage(**message_dict) + state = session_env.state + if hasattr(state, "model_dump"): + state_data = state.model_dump() + else: + state_data = dict(state) if state else {} + + response = 
WSStateResponse(data=state_data) + + case "close": + msg = WSCloseMessage(**message_dict) + break + + case "mcp": + msg = WSMCPMessage(**message_dict) + try: + rpc_request = JsonRpcRequest(**msg.data) + except (ValidationError, Exception) as e: + rpc_response = JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Invalid request: {e}", + ) + else: + rpc_response = await mcp_handler( + rpc_request, + session_env=session_env, + ) + response = WSMCPResponse(data=rpc_response.model_dump()) + + case _: + response = WSErrorResponse( + data={ + "message": f"Unknown message type: {msg_type}", + "code": WSErrorCode.UNKNOWN_TYPE, + } + ) + + await websocket.send_text(response.model_dump_json()) + + except ValidationError as e: + error_resp = WSErrorResponse( + data={ + "message": "Invalid message", + "code": WSErrorCode.VALIDATION_ERROR, + "errors": e.errors(), + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception as e: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": WSErrorCode.EXECUTION_ERROR, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + + except WebSocketDisconnect: + pass + except SessionCapacityError as e: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": WSErrorCode.CAPACITY_REACHED, + "active_sessions": e.active_sessions, + "max_sessions": e.max_sessions, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except EnvironmentFactoryError as e: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": WSErrorCode.FACTORY_ERROR, + "factory_name": e.factory_name, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception as e: + error_resp = WSErrorResponse( + data={"message": str(e), "code": WSErrorCode.SESSION_ERROR} + ) + await websocket.send_text(error_resp.model_dump_json()) + finally: + if session_id: + await self._destroy_session(session_id) + try: + await websocket.close() + except RuntimeError: + pass + + +def create_app( + env: Callable[[], Environment], + action_cls: Type[Action], + observation_cls: Type[Observation], + env_name: Optional[str] = None, + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, + gradio_builder: Optional[Callable[..., Any]] = None, +) -> FastAPI: + """ + Create a FastAPI application with or without web interface. + + This function creates a FastAPI app with the web interface enabled by default, + including README integration for better user experience. + + Args: + env: Environment factory (callable) that creates new instances + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + env_name: Optional environment name for README loading + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. + gradio_builder: Optional callable to build a custom Gradio UI at /web. + Signature: (web_manager, action_fields, metadata, is_chat_env, title, + quick_start_md) -> gr.Blocks. When None, the default Gradio app is used. + See docs/customizing-web-ui.md. 
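+
+    Example (a minimal sketch; MyEnv, MyAction, and MyObservation are
+    placeholder names for an environment implementation):
+        app = create_app(lambda: MyEnv(), MyAction, MyObservation)
+        # Serve with: uvicorn my_module:app --host 0.0.0.0 --port 8000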
+ + Returns: + FastAPI application instance with or without web interface and README integration + """ + # Check if web interface should be enabled + # This can be controlled via environment variable or build argument + enable_web = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ( + "true", + "1", + "yes", + ) + + if enable_web: + # Gradio-based web UI (gradio is a core dependency) + from .web_interface import create_web_interface_app + + return create_web_interface_app( + env, + action_cls, + observation_cls, + env_name, + max_concurrent_envs, + concurrency_config, + gradio_builder=gradio_builder, + ) + else: + # Use standard FastAPI app without web interface + return create_fastapi_app( + env, action_cls, observation_cls, max_concurrent_envs, concurrency_config + ) + + +def create_fastapi_app( + env: Callable[[], Environment], + action_cls: Type[Action], + observation_cls: Type[Observation], + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, +) -> FastAPI: + """ + Create a FastAPI application with comprehensive documentation. + + Args: + env: Environment factory (callable) that creates new instances + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. + + Returns: + FastAPI application instance + """ + try: + from fastapi import FastAPI + except ImportError: + raise ImportError( + "FastAPI is required. Install with: pip install fastapi uvicorn" + ) + + app = FastAPI( + title="OpenEnv Environment HTTP API", + version="1.0.0", + description=""" +# OpenEnv Environment HTTP API + +HTTP API for interacting with OpenEnv environments through a standardized interface. + +## Features + +* **Environment Reset**: Initialize or restart episodes +* **Action Execution**: Send actions and receive observations +* **State Inspection**: Query current environment state +* **Schema Access**: Retrieve JSON schemas for actions and observations + +## Workflow + +1. Call `/reset` to start a new episode and get initial observation +2. Call `/step` repeatedly with actions to interact with environment +3. Episode ends when observation returns `done: true` +4. 
Call `/state` anytime to inspect current environment state + +## Documentation + +* **Swagger UI**: Available at `/docs` +* **ReDoc**: Available at `/redoc` +* **OpenAPI Schema**: Available at `/openapi.json` + """, + openapi_tags=[ + { + "name": "Environment Control", + "description": "Core operations for environment interaction (reset, step)", + }, + { + "name": "State Management", + "description": "Operations for inspecting environment state", + }, + { + "name": "Environment Info", + "description": "Information about the environment", + }, + { + "name": "Schema", + "description": "JSON Schema endpoints for actions, observations, and state", + }, + {"name": "Health", "description": "Service health and status checks"}, + ], + docs_url="/docs", + redoc_url="/redoc", + openapi_url="/openapi.json", + contact={ + "name": "OpenEnv Team", + "url": "https://github.com/meta-pytorch/OpenEnv", + }, + license_info={ + "name": "BSD-3-Clause", + "url": "https://github.com/meta-pytorch/OpenEnv/blob/main/LICENSE", + }, + ) + + server = HTTPEnvServer( + env, + action_cls, + observation_cls, + max_concurrent_envs, + concurrency_config=concurrency_config, + ) + server.register_routes(app) + return app diff --git a/src/core/env_server/interfaces.py b/src/core/env_server/interfaces.py new file mode 100644 index 0000000000000000000000000000000000000000..9fa837549aa1e2bf1c439f1d7a52e845a556ae18 --- /dev/null +++ b/src/core/env_server/interfaces.py @@ -0,0 +1,297 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +from abc import ABC, abstractmethod +from typing import Any, Generic, Optional, Protocol, TYPE_CHECKING, TypeVar + +from typing_extensions import TypedDict + +from .types import Action, EnvironmentMetadata, Observation, State + +if TYPE_CHECKING: + from openenv.core.rubrics import Rubric + +ActT = TypeVar("ActT", bound=Action) +ObsT = TypeVar("ObsT", bound=Observation) +StateT = TypeVar("StateT", bound=State) + + +class Message(TypedDict): + """A message in a conversation. + + Compatible with Huggingface chat template format. + """ + + role: str + content: str + + +class ModelTokenizer(Protocol): + """Protocol for tokenizers that support chat templates. + + This protocol defines the interface that tokenizers must implement + to work with chat-based environments. It's compatible with + Huggingface transformers tokenizers. + """ + + def apply_chat_template( + self, + conversation: list[Message], + tokenize: bool = True, + return_tensors: str | None = None, + **kwargs: Any, + ) -> Any: + """Apply a chat template to format and optionally tokenize a conversation. + + Args: + conversation: List of message dictionaries with 'role' and 'content' + tokenize: Whether to tokenize the output + return_tensors: Format for returned tensors ('pt' for PyTorch) + **kwargs: Additional arguments + + Returns: + Formatted and optionally tokenized conversation + """ + ... + + def decode( + self, token_ids: Any, skip_special_tokens: bool = False, **kwargs: Any + ) -> str: + """Decode token IDs back to text. + + Args: + token_ids: Token IDs to decode + skip_special_tokens: Whether to skip special tokens in output + **kwargs: Additional arguments + + Returns: + Decoded text string + """ + ... + + +class Transform(ABC, Generic[ObsT]): + """Transform observations to add rewards, metrics, or other modifications. 
+ + Transforms follow the TorchRL pattern where they take an observation + and return a (potentially modified) observation. This allows for + flexible reward computation and observation augmentation. + """ + + @abstractmethod + def __call__(self, observation: ObsT) -> ObsT: + """Transform an observation. + + Args: + observation: The input observation + + Returns: + The transformed observation + """ + pass + + +class Environment(ABC, Generic[ActT, ObsT, StateT]): + """Base class for all environment servers following Gym/Gymnasium API. + + Args: + transform: Optional transform to apply to observations + rubric: Optional rubric for reward computation. When provided, the + rubric's output can be used to set the observation's reward in step(). + + Class Attributes: + SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions. + When True, multiple WebSocket connections can each have their own + environment instance (up to max_concurrent_envs). When False (default), + the environment should only be used with a single session at a time. + + Set this to True in your Environment subclass if: + - The environment uses proper session isolation (e.g., unique working dirs) + - No shared mutable state exists between instances + - External resources (databases, APIs) can handle concurrent access + + Attributes: + rubric: Optional rubric for computing rewards. Environments can set this + in __init__ and use it in step() to compute observation rewards. + Training infrastructure can access it for introspection: + for name, r in env.rubric.named_rubrics(): + print(f"{name}: {r.last_score}") + + See RFC 004 for rubric design: rfcs/004-rubrics.md + """ + + # Class-level flag indicating whether this environment supports concurrent sessions + SUPPORTS_CONCURRENT_SESSIONS: bool = False + + # Optional rubric for reward computation + rubric: Optional["Rubric"] + + def __init__( + self, + transform: Optional[Transform[ObsT]] = None, + rubric: Optional["Rubric"] = None, + ): + self.transform = transform + self.rubric = rubric + + @abstractmethod + def reset( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> ObsT: + """Reset the environment and return initial observation.""" + pass + + async def reset_async( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> ObsT: + """Async version of reset. Default implementation calls sync reset. + + Override to provide true async implementation. + """ + return self.reset(seed=seed, episode_id=episode_id, **kwargs) + + @abstractmethod + def step( + self, + action: ActT, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> ObsT: + """Take a step in the environment.""" + pass + + async def step_async( + self, + action: ActT, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> ObsT: + """Async version of step. Default implementation calls sync step. + + Override to provide true async implementation. + """ + return self.step(action, timeout_s=timeout_s, **kwargs) + + @property + @abstractmethod + def state(self) -> StateT: + """Get the current environment state.""" + pass + + def get_metadata(self) -> EnvironmentMetadata: + """ + Get metadata about this environment. + + Override this method to provide custom metadata for the environment. + Default implementation returns basic metadata derived from class name. 
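+
+        Example override (illustrative values):
+            def get_metadata(self) -> EnvironmentMetadata:
+                return EnvironmentMetadata(
+                    name="my_env",
+                    description="A custom environment",
+                    version="0.1.0",
+                )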
+ + Returns: + EnvironmentMetadata with environment information + """ + return EnvironmentMetadata( + name=self.__class__.__name__, + description=f"{self.__class__.__name__} environment", + version="1.0.0", + ) + + def _apply_transform(self, observation: ObsT) -> ObsT: + """Apply transform if one is provided.""" + if self.transform is not None: + return self.transform(observation) + return observation + + def _apply_rubric(self, action: ActT, observation: ObsT) -> float: + """Apply rubric if one is provided. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value from the rubric, or 0.0 if no rubric is set. + + Usage in step(): + def step(self, action: MyAction, ...) -> MyObservation: + # ... execute action and create observation ... + observation.reward = self._apply_rubric(action, observation) + return observation + """ + if self.rubric is not None: + return self.rubric(action, observation) + return 0.0 + + async def _apply_rubric_async(self, action: ActT, observation: ObsT) -> float: + """Apply rubric asynchronously if one is provided. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value from the rubric, or 0.0 if no rubric is set. + + Usage in step_async(): + async def step_async(self, action: MyAction, ...) -> MyObservation: + # ... execute action and create observation ... + observation.reward = await self._apply_rubric_async(action, observation) + return observation + """ + if self.rubric is not None: + result = self.rubric(action, observation) + # If rubric returns a coroutine, await it + if inspect.iscoroutine(result): + return await result + return result + return 0.0 + + def _reset_rubric(self) -> None: + """Reset the rubric state if one is provided. + + Call this in reset() to clear any trajectory state in the rubric. + + Usage in reset(): + def reset(self, ...) -> MyObservation: + self._reset_rubric() + # ... create initial observation ... + return observation + """ + if self.rubric is not None: + self.rubric.reset() + + async def _reset_rubric_async(self) -> None: + """Reset the rubric state asynchronously if one is provided. + + Call this in reset_async() to clear any trajectory state in the rubric. + + Usage in reset_async(): + async def reset_async(self, ...) -> MyObservation: + await self._reset_rubric_async() + # ... create initial observation ... + return observation + """ + if self.rubric is not None: + # Check if rubric has async reset method + if hasattr(self.rubric, "reset_async"): + result = self.rubric.reset_async() + if inspect.iscoroutine(result): + await result + else: + self.rubric.reset() + + def close(self) -> None: + """Clean up resources used by the environment. + + Override this method to implement custom cleanup logic. + Called when the environment is being destroyed or reset. + """ + pass diff --git a/src/core/env_server/mcp_environment.py b/src/core/env_server/mcp_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..03f66e37897ec81796d468f3d0590d465deddea1 --- /dev/null +++ b/src/core/env_server/mcp_environment.py @@ -0,0 +1,624 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +MCP Environment base class for OpenEnv. 
+
+This module provides the MCPEnvironment base class that integrates FastMCP servers
+with OpenEnv's Gym-style Environment interface. It handles MCP tool discovery
+and invocation through the step() API, following RFC 003.
+
+Key features:
+- Automatic routing of ListToolsAction and CallToolAction to MCP server
+- Reserved tool name validation (reset, step, state, close are protected)
+- Timeout handling for tool calls
+- Proper error categorization (tool not found, execution errors, timeouts)
+- Mode-aware tool registration (production vs simulation)
+- Code mode support via get_callables() and execute_code()
+
+Usage:
+    from fastmcp import FastMCP
+    from openenv.core.env_server.mcp_environment import MCPEnvironment
+
+    class MyMCPEnv(MCPEnvironment):
+        def __init__(self):
+            mcp = FastMCP("my-server")
+            # Initialize the base class first: self.tool() relies on state
+            # created in MCPEnvironment.__init__
+            super().__init__(mcp)
+
+            # Register mode-specific tools
+            @self.tool(mode="production")
+            def my_tool(arg: str) -> str:
+                return f"Production: {arg}"
+
+            @self.tool(mode="simulation")
+            def my_tool(arg: str) -> str:
+                return f"Simulation: {arg}"
+
+        def reset(self, seed=None, episode_id=None, **kwargs):
+            # Reset logic here
+            ...
+
+        def _step_impl(self, action):
+            # Handle non-MCP actions
+            ...
+
+        @property
+        def state(self):
+            # Return current state
+            ...
+"""
+
+import asyncio
+import inspect
+from abc import abstractmethod
+from collections import defaultdict
+from typing import Any, Callable, Dict, Optional
+
+from fastmcp import Client
+from fastmcp.client.client import CallToolResult
+from mcp.types import TextContent
+
+from ..utils import run_async_safely
+from .interfaces import Environment
+from .mcp_types import (
+    CallToolAction,
+    CallToolObservation,
+    ListToolsAction,
+    ListToolsObservation,
+    RESERVED_TOOL_NAMES,
+    Tool,
+    ToolError,
+    ToolErrorType,
+)
+from .types import Action, Observation
+
+
+# Default timeout for MCP tool calls in seconds
+MCP_TOOL_CALL_TIMEOUT = 30.0
+
+# Valid modes for tool registration
+VALID_MODES = {"production", "simulation"}
+
+
+def get_server_tools(mcp_server: Any) -> Dict[str, Any]:
+    """
+    Get tools from a FastMCP server, compatible with both 2.x and 3.x.
+
+    Returns:
+        Dictionary mapping tool names to tool objects.
+    """
+    # FastMCP 2.x: get_tools() returns dict {name: Tool}
+    if hasattr(mcp_server, "get_tools"):
+        result = run_async_safely(mcp_server.get_tools())
+        if isinstance(result, dict):
+            return result
+    # FastMCP 3.x: list_tools() returns list of Tool objects
+    if hasattr(mcp_server, "list_tools"):
+        tools_list = run_async_safely(mcp_server.list_tools())
+        return {t.name: t for t in tools_list}
+    return {}
+
+
+class MCPEnvironment(Environment):
+    """
+    Base class for environments that expose tools via MCP (Model Context Protocol).
+
+    MCPEnvironment bridges FastMCP servers with OpenEnv's Gym-style API, allowing
+    agents to discover and invoke MCP tools through the standard step() interface.
+
+    The class automatically handles:
+    - ListToolsAction: Returns available tools from the MCP server
+    - CallToolAction: Invokes a specific tool with arguments
+
+    All other actions are delegated to the abstract _step_impl() method,
+    which subclasses must implement.
+
+    Args:
+        mcp_server: A FastMCP server instance containing tool definitions.
+            The server's tools will be validated against reserved names.
+        transform: Optional transform to apply to observations (inherited from Environment).
+
+    Raises:
+        ValueError: If any tool in the MCP server uses a reserved name
+            (reset, step, state, close).
+ + Example: + >>> from fastmcp import FastMCP + >>> mcp = FastMCP("calculator") + >>> @mcp.tool() + ... def add(a: int, b: int) -> int: + ... return a + b + >>> env = MyMCPEnvironment(mcp) + >>> obs = env.step(ListToolsAction()) + >>> obs.tools[0].name + 'add' + """ + + def __init__(self, mcp_server: Any, transform: Optional[Any] = None) -> None: + """ + Initialize the MCP environment. + + Args: + mcp_server: A FastMCP server instance with tool definitions. + transform: Optional transform to apply to observations. + + Raises: + ValueError: If any tool uses a reserved name (reset, step, state, close). + """ + super().__init__(transform=transform) + + # Validate tool names before storing + self._validate_tool_names(mcp_server) + + self.mcp_server = mcp_server + self.mcp_client = Client(mcp_server) + + # Track mode-specific tools: {tool_name: {mode: func}} + # mode can be "production", "simulation", or None (available in all modes) + self._mode_tools = defaultdict(dict) + + # Track tool schemas for list_tools: {tool_name: {mode: schema}} + self._mode_tool_schemas = defaultdict(dict) + + @property + def supports_code_mode(self) -> bool: + """Check if this environment supports code mode (execute_code).""" + return True + + def _get_server_tools(self, mcp_server: Any) -> Dict[str, Any]: + """ + Get tools from a FastMCP server, compatible with both 2.x and 3.x. + + Returns: + Dictionary mapping tool names to tool objects. + """ + return get_server_tools(mcp_server) + + def get_callables(self) -> Dict[str, Callable]: + """ + Get callable functions for code mode. + + Returns tool functions as direct Python callables, enabling code mode + where agents write Python code that calls tools directly (no JSON-RPC + overhead). Mode-specific tools are filtered by the current mode. + + Returns: + Dictionary mapping tool names to callables. + """ + callables: Dict[str, Callable] = {} + current_mode = getattr(self, "_mode", None) + + # Extract callables from FastMCP server using public API + for tool_name, tool in self._get_server_tools(self.mcp_server).items(): + if hasattr(tool, "fn") and callable(tool.fn): + callables[tool_name] = tool.fn + + # Add mode-specific tools available in current mode + for tool_name, mode_funcs in self._mode_tools.items(): + if None in mode_funcs: + # Tool available in all modes (already in FastMCP if registered there) + if tool_name not in callables: + callables[tool_name] = mode_funcs[None] + elif current_mode in mode_funcs: + # Tool available in current mode only + callables[tool_name] = mode_funcs[current_mode] + + return callables + + def execute_code(self, code: str) -> Observation: + """ + Execute Python code with tools available as callables. + + This enables the CodeAct pattern where agents write Python code + that calls tools directly as functions, avoiding JSON-RPC overhead. + + Args: + code: Python code to execute. Tools are available as functions + in the execution namespace. Set a variable named 'result' + to capture the return value. + + Returns: + Observation with result in metadata["result"] or error in + metadata["error"]. 
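+
+        Example (assuming an ``add(a, b)`` tool is registered, as in the
+        class docstring above):
+            >>> obs = env.execute_code("result = add(2, 3)")
+            >>> obs.metadata["result"]
+            5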
+ """ + namespace = self.get_callables() + + result_dict: Dict[str, Any] = {} + try: + exec(code, namespace, result_dict) + result = result_dict.get("result") + return Observation(done=False, reward=0.0, metadata={"result": result}) + except SyntaxError as e: + return Observation( + done=False, reward=0.0, metadata={"error": f"Syntax error: {str(e)}"} + ) + except Exception as e: + return Observation(done=False, reward=0.0, metadata={"error": str(e)}) + + def _validate_tool_names(self, mcp_server: Any) -> None: + """ + Validate that no tools use reserved names. + + Reserved names (reset, step, state, close) are protected to maintain + the dual API boundary between infrastructure and agent APIs. + + Args: + mcp_server: The FastMCP server to validate. + + Raises: + ValueError: If any tool uses a reserved name. + """ + tools_dict = self._get_server_tools(mcp_server) + if tools_dict: + tool_names = set(tools_dict.keys()) + conflicts = tool_names & RESERVED_TOOL_NAMES + if conflicts: + raise ValueError( + f"MCP tools cannot use reserved names: {sorted(conflicts)}. " + f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}" + ) + + def tool(self, mode: Optional[str] = None) -> Callable: + """ + Decorator for registering mode-aware tools. + + Args: + mode: Optional mode for the tool ("production" or "simulation"). + If None, tool is available in all modes. + + Returns: + A decorator function for registering tools. + + Raises: + ValueError: If mode is not None, "production", or "simulation". + """ + if mode is not None and mode not in VALID_MODES: + raise ValueError( + f"Invalid mode '{mode}'. Mode must be 'production', 'simulation', or None." + ) + + def decorator(func: Callable) -> Callable: + tool_name = func.__name__ + # Validate tool name is not reserved + if tool_name in RESERVED_TOOL_NAMES: + raise ValueError( + f"Tool name '{tool_name}' is reserved and cannot be used. " + f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}" + ) + + # If mode is None, register with FastMCP as usual + if mode is None: + decorated_func = self.mcp_server.tool()(func) + self._mode_tools[tool_name][None] = func + return decorated_func + + # For mode-specific tools, don't register with FastMCP + # Instead, track them ourselves + self._mode_tools[tool_name][mode] = func + + # Extract schema information from function signature + sig = inspect.signature(func) + schema = { + "type": "object", + "properties": {}, + "required": [], + } + + for param_name, param in sig.parameters.items(): + # Get type annotation + param_type = param.annotation + json_type = "string" # default + if param_type in (int, "int"): + json_type = "integer" + elif param_type in (float, "float"): + json_type = "number" + elif param_type in (bool, "bool"): + json_type = "boolean" + + schema["properties"][param_name] = {"type": json_type} + + # If no default value, it's required + if param.default == inspect.Parameter.empty: + schema["required"].append(param_name) + + # Store the schema for this mode-specific tool + self._mode_tool_schemas[tool_name][mode] = { + "name": tool_name, + "description": func.__doc__ or "", + "input_schema": schema, + } + + return func + + return decorator + + def step( + self, + action: Action, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> Observation: + """ + Execute an action in the environment. + + This method routes MCP-specific actions (ListToolsAction, CallToolAction) + to the appropriate handlers, while delegating all other actions to + the subclass's _step_impl() method. 
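+
+        Example (doctest-style sketch, reusing the hypothetical ``add`` tool):
+            >>> obs = env.step(CallToolAction(tool_name="add", arguments={"a": 1, "b": 2}))
+            >>> obs.tool_name
+            'add'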
+ + Args: + action: The action to execute. Can be: + - ListToolsAction: Returns available MCP tools + - CallToolAction: Invokes a specific MCP tool + - Any other Action: Delegated to _step_impl() + timeout_s: Optional timeout in seconds for the action. + Defaults to MCP_TOOL_CALL_TIMEOUT (30s) for MCP actions. + **kwargs: Additional arguments passed to handlers. + + Returns: + Observation appropriate to the action type: + - ListToolsObservation for ListToolsAction + - CallToolObservation for CallToolAction + - Subclass-defined Observation for other actions + """ + if isinstance(action, ListToolsAction): + return self._handle_list_tools() + elif isinstance(action, CallToolAction): + return self._handle_call_tool(action, timeout_s=timeout_s) + else: + return self._step_impl(action, timeout_s=timeout_s, **kwargs) + + def _handle_list_tools(self) -> ListToolsObservation: + """ + Handle a ListToolsAction by querying the MCP server. + + Returns: + ListToolsObservation containing all available tools with their + names, descriptions, and input schemas, filtered by current mode. + """ + try: + # Get current mode + current_mode = getattr(self, "_mode", None) + + # Start with tools from FastMCP server (mode=None tools) + tools_result = run_async_safely(self._async_list_tools()) + + # Build list of Tool objects + tools = [] + + # Add FastMCP tools that are not mode-specific + for tool in tools_result: + if tool.name not in self._mode_tool_schemas: + tools.append( + Tool( + name=tool.name, + description=tool.description or "", + input_schema=tool.inputSchema + if hasattr(tool, "inputSchema") + else {}, + ) + ) + + # Add mode-specific tools available in current mode + for tool_name, mode_schemas in self._mode_tool_schemas.items(): + if None in mode_schemas: + # Tool available in all modes + schema = mode_schemas[None] + tools.append( + Tool( + name=schema["name"], + description=schema["description"], + input_schema=schema["input_schema"], + ) + ) + elif current_mode in mode_schemas: + # Tool available in current mode + schema = mode_schemas[current_mode] + tools.append( + Tool( + name=schema["name"], + description=schema["description"], + input_schema=schema["input_schema"], + ) + ) + + return ListToolsObservation(tools=tools) + + except Exception as e: + # Return an observation with error in metadata + return ListToolsObservation( + tools=[], + metadata={ + "error": str(e), + "error_type": "list_tools_failed", + }, + ) + + async def _async_list_tools(self) -> list: + """ + Async helper to list tools from the MCP client. + + Returns: + List of tool objects from the MCP server. + """ + async with self.mcp_client: + return await self.mcp_client.list_tools() + + def _handle_call_tool( + self, + action: CallToolAction, + timeout_s: Optional[float] = None, + ) -> CallToolObservation: + """ + Handle a CallToolAction by invoking the specified tool. + + Args: + action: The CallToolAction containing tool_name and arguments. + timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s). + + Returns: + CallToolObservation with the tool's result or an error. + """ + timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT + + # Check if this is a mode-specific tool + tool_name = action.tool_name + current_mode = getattr(self, "_mode", None) + + if tool_name in self._mode_tools: + mode_info = self._mode_tools[tool_name] + + # Check if tool is available in current mode + # Tool is available if: + # 1. It has a None mode (available in all modes), OR + # 2. 
It has an implementation for the current mode + if None in mode_info: + # Use the mode-agnostic version + func = mode_info[None] + elif current_mode in mode_info: + # Use the mode-specific version + func = mode_info[current_mode] + else: + # Tool not available in current mode + return CallToolObservation( + tool_name=tool_name, + result=None, + error=ToolError( + error_type=ToolErrorType.TOOL_NOT_FOUND, + message=f"Tool '{tool_name}' not available in {current_mode} mode", + ), + ) + + # Call the mode-specific function directly + try: + # Check if function is async and await if necessary + if inspect.iscoroutinefunction(func): + result = run_async_safely(func(**action.arguments)) + else: + result = func(**action.arguments) + + # Wrap result in CallToolResult format to match FastMCP behavior + return CallToolObservation( + tool_name=tool_name, + result=CallToolResult( + content=[TextContent(type="text", text=str(result))], + structured_content={"result": result}, + meta=None, + data=result, + is_error=False, + ), + ) + except Exception as e: + return CallToolObservation( + tool_name=tool_name, + result=None, + error=ToolError( + error_type=ToolErrorType.EXECUTION_ERROR, + message=str(e), + ), + ) + + # Not a mode-specific tool, use FastMCP + try: + # Run the async call_tool with timeout + # Use run_async_safely to handle both sync and async contexts + result = run_async_safely( + asyncio.wait_for( + self._async_call_tool(action.tool_name, action.arguments), + timeout=timeout, + ) + ) + + return CallToolObservation( + tool_name=action.tool_name, + result=result, + ) + + except asyncio.TimeoutError: + return CallToolObservation( + tool_name=action.tool_name, + result=None, + error=ToolError( + error_type=ToolErrorType.TIMEOUT, + message=f"Tool '{action.tool_name}' timed out after {timeout} seconds", + ), + ) + + except Exception as e: + error_message = str(e) + + # Determine error type based on the exception + if ( + "not found" in error_message.lower() + or "unknown tool" in error_message.lower() + ): + error_type = ToolErrorType.TOOL_NOT_FOUND + elif ( + "invalid" in error_message.lower() + or "argument" in error_message.lower() + ): + error_type = ToolErrorType.INVALID_ARGS + else: + error_type = ToolErrorType.EXECUTION_ERROR + + return CallToolObservation( + tool_name=action.tool_name, + result=None, + error=ToolError( + error_type=error_type, + message=error_message, + ), + ) + + async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any: + """ + Async helper to call a tool on the MCP server. + + Args: + tool_name: Name of the tool to invoke. + arguments: Dictionary of arguments to pass to the tool. + + Returns: + The result from the tool execution. + """ + async with self.mcp_client: + return await self.mcp_client.call_tool(tool_name, arguments) + + @abstractmethod + def _step_impl( + self, + action: Action, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> Observation: + """ + Handle non-MCP actions in the environment. + + Subclasses must implement this method to handle any actions that are + not ListToolsAction or CallToolAction. This is where environment-specific + action processing should occur. + + Args: + action: The action to execute (guaranteed not to be an MCP action). + timeout_s: Optional timeout in seconds. + **kwargs: Additional arguments. + + Returns: + An Observation appropriate for the action. + """ + pass + + def close(self) -> None: + """ + Clean up resources used by the environment. 
+ + This method cleans up the MCP client and any other resources. + Subclasses should call super().close() if they override this method. + """ + # The MCP client uses async context manager, so cleanup happens + # automatically when the context exits. We just clear references. + self.mcp_client = None + self.mcp_server = None diff --git a/src/core/env_server/mcp_types.py b/src/core/env_server/mcp_types.py new file mode 100644 index 0000000000000000000000000000000000000000..6aa5b7449e2fa60dea46efc6b0992a6359146b2b --- /dev/null +++ b/src/core/env_server/mcp_types.py @@ -0,0 +1,321 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +MCP (Model Context Protocol) type definitions for OpenEnv. + +This module defines strongly typed models for MCP tool discovery and invocation, +following RFC 003. These types map MCP's REST-like API (tools/list, tools/call) +to Gym-style action types. + +Key design decisions: +- Tool discovery (list_tools) does NOT require reset() first +- Reserved tool names (reset, step, state, close) are prohibited +- Both step() and WebSocket /mcp paths are supported +""" + +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field + +from .types import Action, BaseMessage, Observation + + +# ============================================================================= +# JSON-RPC 2.0 Types +# ============================================================================= + + +class JsonRpcErrorCode(int, Enum): + """ + Standard JSON-RPC 2.0 error codes. + + See: https://www.jsonrpc.org/specification#error_object + """ + + # Standard JSON-RPC errors + PARSE_ERROR = -32700 # Invalid JSON was received + INVALID_REQUEST = -32600 # JSON is not a valid Request object + METHOD_NOT_FOUND = -32601 # Method does not exist / is not available + INVALID_PARAMS = -32602 # Invalid method parameter(s) + INTERNAL_ERROR = -32603 # Internal JSON-RPC error + + # Server errors (reserved for implementation-defined errors) + SERVER_ERROR = -32000 # Generic server error + + +class McpMethod(str, Enum): + """Supported MCP method names.""" + + TOOLS_LIST = "tools/list" + TOOLS_CALL = "tools/call" + + +class JsonRpcError(BaseModel): + """ + JSON-RPC 2.0 error object. + + See: https://www.jsonrpc.org/specification#error_object + """ + + model_config = ConfigDict(extra="forbid") + + code: int = Field(description="Error code indicating the error type") + message: str = Field(description="Short description of the error") + data: Optional[Any] = Field( + default=None, description="Additional error information" + ) + + @classmethod + def from_code( + cls, code: JsonRpcErrorCode, message: Optional[str] = None, data: Any = None + ) -> "JsonRpcError": + """Create an error from a standard error code.""" + default_messages = { + JsonRpcErrorCode.PARSE_ERROR: "Parse error", + JsonRpcErrorCode.INVALID_REQUEST: "Invalid Request", + JsonRpcErrorCode.METHOD_NOT_FOUND: "Method not found", + JsonRpcErrorCode.INVALID_PARAMS: "Invalid params", + JsonRpcErrorCode.INTERNAL_ERROR: "Internal error", + JsonRpcErrorCode.SERVER_ERROR: "Server error", + } + return cls( + code=code.value, + message=message or default_messages.get(code, "Unknown error"), + data=data, + ) + + +class JsonRpcRequest(BaseModel): + """ + JSON-RPC 2.0 request object. 
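+
+    Example payload for a tool call (illustrative tool name and arguments):
+        {"jsonrpc": "2.0", "method": "tools/call",
+         "params": {"name": "add", "arguments": {"a": 1, "b": 2}}, "id": 1}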
+ + See: https://www.jsonrpc.org/specification#request_object + """ + + model_config = ConfigDict(extra="forbid") + + jsonrpc: Literal["2.0"] = Field(description="JSON-RPC version, must be '2.0'") + method: str = Field(description="Name of the method to be invoked") + params: Dict[str, Any] = Field( + default_factory=dict, description="Parameter values for the method" + ) + id: Optional[Union[str, int]] = Field( + default=None, description="Request identifier established by the client" + ) + + +class JsonRpcResponse(BaseModel): + """ + JSON-RPC 2.0 response object. + + Per JSON-RPC 2.0 spec, a response has either 'result' or 'error', not both. + This model excludes None values during serialization to comply with the spec. + + See: https://www.jsonrpc.org/specification#response_object + """ + + model_config = ConfigDict(extra="forbid") + + jsonrpc: Literal["2.0"] = Field(default="2.0", description="JSON-RPC version") + result: Optional[Any] = Field( + default=None, description="Result of the method invocation" + ) + error: Optional[JsonRpcError] = Field( + default=None, description="Error object if method invocation failed" + ) + id: Optional[Union[str, int]] = Field( + default=None, description="Request identifier from the request" + ) + + def model_dump(self, **kwargs) -> Dict[str, Any]: + """Serialize to dict, excluding result or error when None (JSON-RPC compliance).""" + # Always include jsonrpc and id, but only include result OR error + data: Dict[str, Any] = {"jsonrpc": self.jsonrpc, "id": self.id} + if self.error is not None: + data["error"] = ( + self.error.model_dump() + if hasattr(self.error, "model_dump") + else self.error + ) + else: + # Only include result if there's no error + data["result"] = self.result + return data + + def model_dump_json(self, **kwargs) -> str: + """Serialize to JSON string, excluding result or error when None (JSON-RPC compliance).""" + import json + + return json.dumps(self.model_dump()) + + @classmethod + def success( + cls, result: Any, request_id: Optional[Union[str, int]] = None + ) -> "JsonRpcResponse": + """Create a success response.""" + return cls(result=result, id=request_id) + + @classmethod + def error_response( + cls, + code: JsonRpcErrorCode, + message: Optional[str] = None, + data: Any = None, + request_id: Optional[Union[str, int]] = None, + ) -> "JsonRpcResponse": + """Create an error response from a standard error code.""" + return cls( + error=JsonRpcError.from_code(code, message, data), + id=request_id, + ) + + +# ============================================================================= +# MCP Tool Types +# ============================================================================= + + +class Tool(BaseModel): + """ + Strongly typed MCP tool specification. + + Follows the MCP ToolSpec format for tool discovery. 
+ See: https://modelcontextprotocol.io/specification/2025-06-18/server/tools + """ + + model_config = ConfigDict(extra="forbid") + + name: str = Field(description="Unique identifier for the tool") + description: str = Field( + description="Human-readable description of what the tool does" + ) + input_schema: Dict[str, Any] = Field( + description="JSON Schema for the tool's input parameters" + ) + + +class ToolErrorType(str, Enum): + """Types of errors that can occur during tool execution.""" + + EXECUTION_ERROR = "execution_error" # Tool ran but failed + INVALID_ARGS = "invalid_args" # Invalid arguments provided + TRANSPORT_ERROR = "transport_error" # Communication failure + TOOL_NOT_FOUND = "tool_not_found" # Tool doesn't exist + TIMEOUT = "timeout" # Operation timed out + + +class ToolError(BaseModel): + """ + Structured error for tool execution failures. + + This is used for transport/framework errors, NOT for errors returned + by the tool itself (those go in the result field). + """ + + model_config = ConfigDict(extra="forbid") + + error_type: ToolErrorType = Field(description="Category of the error") + message: str = Field(description="Human-readable error message") + + +# --- MCP Actions --- + + +class ListToolsAction(Action): + """ + Request list of available tools from the environment. + + This action triggers MCP's tools/list operation and returns + all available tools with their schemas. + + Note: Does NOT require reset() to be called first. + """ + + type: Literal["list_tools"] = Field( + default="list_tools", description="Action type discriminator" + ) + + +class CallToolAction(Action): + """ + Call a specific tool via MCP. + + This action triggers MCP's tools/call operation with the + specified tool name and arguments. + """ + + type: Literal["call_tool"] = Field( + default="call_tool", description="Action type discriminator" + ) + tool_name: str = Field(description="Name of the tool to call") + arguments: Dict[str, Any] = Field( + default_factory=dict, description="Arguments to pass to the tool" + ) + + +# --- MCP Observations --- + + +class ListToolsObservation(Observation): + """ + Response containing available tools. + + Returned when processing a ListToolsAction. + """ + + tools: List[Tool] = Field(description="List of available tools with their schemas") + + +class CallToolObservation(Observation): + """ + Response from tool execution. + + Contains the tool's result or an error if the call failed. + Tool-specific errors (from the tool itself) are included in the result. + Transport/framework errors use the error field. + """ + + tool_name: str = Field(description="Name of the tool that was called") + result: Any = Field( + default=None, description="Tool-specific result (may include tool errors)" + ) + error: Optional[ToolError] = Field( + default=None, description="Transport/framework error if call failed" + ) + + +# --- WebSocket Message Types for MCP --- + + +class WSMCPMessage(BaseMessage): + """ + WebSocket message for MCP JSON-RPC requests. + + Allows direct MCP access via WebSocket for production inference, + bypassing the step() API. + """ + + type: Literal["mcp"] = Field(default="mcp", description="Message type") + data: Dict[str, Any] = Field(description="JSON-RPC payload (method, params, id)") + + +class WSMCPResponse(BaseModel): + """ + WebSocket response for MCP JSON-RPC. + + Contains the JSON-RPC response from the MCP server. 
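+
+    Example (illustrative payload for a tools/list response):
+        {"type": "mcp", "data": {"jsonrpc": "2.0", "result": {"tools": []}, "id": 1}}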
+ """ + + model_config = ConfigDict(extra="forbid") + + type: str = Field(default="mcp", description="Response type") + data: Dict[str, Any] = Field(description="JSON-RPC response payload") + + +# Reserved tool names that cannot be used (protects dual API boundary) +RESERVED_TOOL_NAMES = frozenset(["reset", "step", "state", "close"]) diff --git a/src/core/env_server/route_config.py b/src/core/env_server/route_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d74a7f202be0731400a6b954dfd37d9012c1f8f7 --- /dev/null +++ b/src/core/env_server/route_config.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Route configuration utilities for declarative FastAPI route registration. + +This module provides utilities to reduce boilerplate in route registration +by using configuration objects instead of repeated function calls. +""" + +from dataclasses import dataclass +from typing import Callable, List, Type + +from fastapi import FastAPI +from pydantic import BaseModel + + +@dataclass +class GetEndpointConfig: + """Configuration for a simple GET endpoint.""" + + path: str + handler: Callable[[], BaseModel | dict] + response_model: Type[BaseModel] | type[dict] + tag: str + summary: str + description: str + + +def register_get_endpoints(app: FastAPI, configs: List[GetEndpointConfig]) -> None: + """ + Register multiple GET endpoints from configuration. + + Args: + app: FastAPI application instance + configs: List of GET endpoint configurations + """ + for config in configs: + # Capture handler in a closure to avoid non-serializable default parameter + def make_endpoint( + handler: Callable[[], BaseModel | dict], + ) -> Callable[[], BaseModel | dict]: + async def endpoint() -> BaseModel | dict: + return handler() + + return endpoint + + app.get( + config.path, + response_model=config.response_model, + tags=[config.tag], + summary=config.summary, + description=config.description, + )(make_endpoint(config.handler)) diff --git a/src/core/env_server/serialization.py b/src/core/env_server/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b50d9aeb873794044e77ee398a7f2b5fca8093 --- /dev/null +++ b/src/core/env_server/serialization.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared serialization and deserialization utilities for OpenEnv HTTP servers. + +This module provides common utilities for converting between JSON dictionaries +and Pydantic models (Action/Observation) to eliminate code duplication across +HTTP server and web interface implementations. +""" + +from typing import Any, Dict, Type + +from .types import Action, Observation + + +def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action: + """ + Convert JSON dict to Action instance using Pydantic validation. + + This is a basic deserialization that works for most environments. + For special cases (e.g., tensor fields, custom type conversions), + use deserialize_action_with_preprocessing(). 
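+
+    Example (``EchoAction`` is a hypothetical Action subclass with a single
+    ``message`` field):
+        >>> action = deserialize_action({"message": "hi"}, EchoAction)
+        >>> action.message
+        'hi'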
+ + Args: + action_data: Dictionary containing action data + action_cls: The Action subclass to instantiate + + Returns: + Action instance + + Raises: + ValidationError: If action_data is invalid for the action class + + Note: + This uses Pydantic's model_validate() for automatic validation. + """ + return action_cls.model_validate(action_data) + + +def deserialize_action_with_preprocessing( + action_data: Dict[str, Any], action_cls: Type[Action] +) -> Action: + """ + Convert JSON dict to Action instance with preprocessing for special types. + + This version handles common type conversions needed for web interfaces: + - Converting lists/strings to tensors for 'tokens' field + - Converting string action_id to int + - Other custom preprocessing as needed + + Args: + action_data: Dictionary containing action data + action_cls: The Action subclass to instantiate + + Returns: + Action instance + + Raises: + ValidationError: If action_data is invalid for the action class + """ + processed_data = {} + + for key, value in action_data.items(): + if key == "tokens" and isinstance(value, (list, str)): + # Convert list or string to tensor + if isinstance(value, str): + # If it's a string, try to parse it as a list of numbers + try: + import json + + value = json.loads(value) + except Exception: + # If parsing fails, treat as empty list + value = [] + if isinstance(value, list): + try: + import torch # type: ignore + + processed_data[key] = torch.tensor(value, dtype=torch.long) + except ImportError: + # If torch not available, keep as list + processed_data[key] = value + else: + processed_data[key] = value + elif key == "action_id" and isinstance(value, str): + # Convert action_id from string to int + try: + processed_data[key] = int(value) + except ValueError: + # If conversion fails, keep original value + processed_data[key] = value + else: + processed_data[key] = value + + return action_cls.model_validate(processed_data) + + +def serialize_observation(observation: Observation) -> Dict[str, Any]: + """ + Convert Observation instance to JSON-compatible dict using Pydantic. + + Args: + observation: Observation instance + + Returns: + Dictionary compatible with EnvClient._parse_result() + + The format matches what EnvClient expects: + { + "observation": {...}, # Observation fields + "reward": float | None, + "done": bool, + } + """ + # Use Pydantic's model_dump() for serialization + obs_dict = observation.model_dump( + exclude={ + "reward", + "done", + "metadata", + } # Exclude these from observation dict + ) + + # Extract reward and done directly from the observation + reward = observation.reward + done = observation.done + + # Return in EnvClient expected format + return { + "observation": obs_dict, + "reward": reward, + "done": done, + } diff --git a/src/core/env_server/types.py b/src/core/env_server/types.py new file mode 100644 index 0000000000000000000000000000000000000000..34a198013442e5000f7fbf75b7f24157b6c04683 --- /dev/null +++ b/src/core/env_server/types.py @@ -0,0 +1,387 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from enum import Enum +from typing import Annotated, Any, Dict, Literal, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field, model_validator + + +# Type aliases +Scalar = Union[int, float, bool] + + +# ============================================================================= +# Enums for Type Safety +# ============================================================================= + + +class ServerMode(str, Enum): + """Server operation mode.""" + + SIMULATION = "simulation" + PRODUCTION = "production" + + +class HealthStatus(str, Enum): + """Server health status values.""" + + HEALTHY = "healthy" + UNHEALTHY = "unhealthy" + DEGRADED = "degraded" + + +class WSErrorCode(str, Enum): + """WebSocket error codes for structured error handling.""" + + INVALID_JSON = "INVALID_JSON" + UNKNOWN_TYPE = "UNKNOWN_TYPE" + VALIDATION_ERROR = "VALIDATION_ERROR" + EXECUTION_ERROR = "EXECUTION_ERROR" + CAPACITY_REACHED = "CAPACITY_REACHED" + FACTORY_ERROR = "FACTORY_ERROR" + SESSION_ERROR = "SESSION_ERROR" + + +# ============================================================================= +# Core Types +# ============================================================================= + + +class Action(BaseModel): + """Base class for all environment actions. + + All action subclasses should inherit from this base class. + Uses Pydantic for automatic validation and serialization. + """ + + model_config = ConfigDict( + extra="forbid", # Reject unknown fields + validate_assignment=True, # Validate on field assignment + arbitrary_types_allowed=True, # Allow numpy arrays, torch tensors, etc. + ) + + metadata: Dict[str, Any] = Field( + default_factory=dict, description="Additional metadata for the action" + ) + + +class Observation(BaseModel): + """Base class for all environment observations. + + All observation subclasses should inherit from this base class. + Uses Pydantic for automatic validation and serialization. 
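+
+    Example (a minimal sketch; ``EchoObservation`` is hypothetical):
+        >>> class EchoObservation(Observation):
+        ...     output: str
+        >>> obs = EchoObservation(output="hi", reward=1.0)
+        >>> (obs.output, obs.reward, obs.done)
+        ('hi', 1.0, False)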
+ """ + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + arbitrary_types_allowed=True, + ) + + done: bool = Field(default=False, description="Whether the episode has terminated") + reward: bool | int | float | None = Field( + default=None, description="Reward signal from the last action" + ) + metadata: Dict[str, Any] = Field( + default_factory=dict, description="Additional metadata for the observation" + ) + + +class ResetRequest(BaseModel): + """Request model for environment reset.""" + + model_config = ConfigDict( + extra="allow", # Allow extra fields for custom reset parameters + json_schema_extra={"examples": [{"seed": 42, "episode_id": "episode-001"}, {}]}, + ) + + seed: Optional[int] = Field( + default=None, ge=0, description="Random seed for reproducible episodes" + ) + episode_id: Optional[str] = Field( + default=None, max_length=255, description="Custom episode identifier" + ) + + +class ResetResponse(BaseModel): + """Response model for environment reset.""" + + model_config = ConfigDict(extra="forbid") + + observation: Dict[str, Any] = Field( + ..., description="Initial observation from the environment" + ) + reward: Optional[float] = Field( + default=None, description="Initial reward (typically None at reset)" + ) + done: bool = Field( + default=False, description="Whether episode is already done (typically False)" + ) + + +class StepRequest(BaseModel): + """Request model for environment step.""" + + model_config = ConfigDict( + extra="allow", # Allow extra fields for custom step parameters + json_schema_extra={ + "examples": [ + {"action": {"value": 1}, "timeout_s": 30.0}, + {"action": {"value": 1}, "render": True, "verbose": False}, + ] + }, + ) + + action: Dict[str, Any] = Field( + ..., + description="Action to execute, must conform to environment's action schema", + ) + timeout_s: Optional[float] = Field( + default=None, + gt=0, + description="Optional timeout in seconds for action execution", + ) + request_id: Optional[str] = Field( + default=None, + max_length=255, + description="Optional request identifier for tracking", + ) + + +class StepResponse(BaseModel): + """Response model for environment step.""" + + model_config = ConfigDict(extra="forbid") + + observation: Dict[str, Any] = Field( + ..., description="Observation resulting from the action" + ) + reward: Optional[float] = Field( + default=None, description="Reward signal from the action" + ) + done: bool = Field(default=False, description="Whether the episode has terminated") + + +class BaseMessage(BaseModel): + """Base class for WebSocket messages with shared configuration.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + +class State(BaseModel): + """Base class for environment state. + + Represents internal environment state, separate from observations. 
+ """ + + model_config = ConfigDict( + extra="allow", # Allow extra fields for flexibility + validate_assignment=True, + arbitrary_types_allowed=True, + ) + + episode_id: Optional[str] = Field( + default=None, description="Unique identifier for the current episode" + ) + step_count: int = Field( + default=0, + ge=0, # Greater than or equal to 0 + description="Number of steps taken in the current episode", + ) + + +class CodeExecResult(BaseMessage): + """Result of code execution containing stdout, stderr, and exit code.""" + + stdout: str = Field(description="Standard output from code execution") + stderr: str = Field(description="Standard error from code execution") + exit_code: int = Field(description="Exit code from code execution") + + +class EnvironmentMetadata(BaseMessage): + """Metadata about an environment for documentation and UI purposes.""" + + name: str = Field(description="Name of the environment") + description: str = Field(description="Description of what the environment does") + readme_content: Optional[str] = Field( + default=None, description="Content of the README file for the environment" + ) + version: Optional[str] = Field( + default=None, description="Version of the environment" + ) + author: Optional[str] = Field(default=None, description="Author of the environment") + documentation_url: Optional[str] = Field( + default=None, description="URL to the environment's documentation" + ) + + +class SchemaResponse(BaseMessage): + """Response model for the combined schema endpoint.""" + + action: Dict[str, Any] = Field( + description="JSON schema for actions accepted by this environment" + ) + observation: Dict[str, Any] = Field( + description="JSON schema for observations returned by this environment" + ) + state: Dict[str, Any] = Field( + description="JSON schema for environment state objects" + ) + + +class HealthResponse(BaseMessage): + """Response model for health check endpoint.""" + + status: HealthStatus = Field( + default=HealthStatus.HEALTHY, + description="Health status of the environment server", + ) + + +class WSResetMessage(BaseMessage): + """WebSocket message to reset the environment.""" + + type: Literal["reset"] = Field(default="reset", description="Message type") + data: Dict[str, Any] = Field( + default_factory=dict, + description="Optional reset parameters (seed, episode_id, etc.)", + ) + + +class WSStepMessage(BaseMessage): + """WebSocket message to execute a step.""" + + type: Literal["step"] = Field(default="step", description="Message type") + data: Dict[str, Any] = Field( + ..., description="Action data conforming to environment's action schema" + ) + + +class WSStateMessage(BaseMessage): + """WebSocket message to request current state.""" + + type: Literal["state"] = Field(default="state", description="Message type") + + +class WSCloseMessage(BaseMessage): + """WebSocket message to close the session.""" + + type: Literal["close"] = Field(default="close", description="Message type") + + +# Discriminated union for incoming WebSocket messages +# Note: WSMCPMessage is defined in mcp_types.py to avoid circular imports +# The union here covers the core message types; MCP messages are handled separately +WSIncomingMessage = Annotated[ + WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage, + Field(discriminator="type"), +] + + +class WSObservationResponse(BaseModel): + """WebSocket response containing an observation.""" + + model_config = ConfigDict(extra="forbid") + + type: Literal["observation"] = Field( + default="observation", 
description="Response type" + ) + data: Dict[str, Any] = Field(description="Observation data") + + +class WSStateResponse(BaseModel): + """WebSocket response containing environment state.""" + + model_config = ConfigDict(extra="forbid") + + type: Literal["state"] = Field(default="state", description="Response type") + data: Dict[str, Any] = Field(description="State data") + + +class WSErrorResponse(BaseModel): + """WebSocket response for errors.""" + + model_config = ConfigDict(extra="forbid") + + type: Literal["error"] = Field(default="error", description="Response type") + data: Dict[str, Any] = Field(description="Error details including message and code") + + +class ConcurrencyConfig(BaseMessage): + """Configuration for concurrent environment sessions.""" + + max_concurrent_envs: int = Field( + default=1, + ge=1, + description="Maximum number of concurrent WebSocket sessions allowed", + ) + session_timeout: Optional[float] = Field( + default=None, + gt=0, + description="Timeout in seconds for inactive sessions. None means no timeout.", + ) + + +class ServerCapacityStatus(BaseMessage): + """Status of server capacity for concurrent sessions.""" + + active_sessions: int = Field( + ge=0, + description="Number of currently active sessions", + ) + max_sessions: int = Field( + ge=1, + description="Maximum number of allowed sessions", + ) + + @model_validator(mode="after") + def check_capacity_bounds(self) -> "ServerCapacityStatus": + if self.active_sessions > self.max_sessions: + raise ValueError( + f"active_sessions ({self.active_sessions}) cannot exceed " + f"max_sessions ({self.max_sessions})" + ) + return self + + @property + def available_slots(self) -> int: + """Number of available session slots.""" + return self.max_sessions - self.active_sessions + + @property + def is_at_capacity(self) -> bool: + """Whether the server has reached maximum capacity.""" + return self.available_slots == 0 + + @classmethod + def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus": + """Create status from active and max session counts.""" + return cls( + active_sessions=active, + max_sessions=max_sessions, + ) + + +class SessionInfo(BaseMessage): + """Information about an active session.""" + + session_id: str = Field(description="Unique identifier for the session") + created_at: float = Field(description="Unix timestamp when the session was created") + last_activity_at: float = Field( + description="Unix timestamp of the last activity in the session" + ) + step_count: int = Field( + default=0, + ge=0, + description="Number of steps executed in this session", + ) + environment_type: str = Field( + description="Environment type for this session (e.g. `CodingEnv`)" + ) diff --git a/src/core/env_server/web_interface.py b/src/core/env_server/web_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..284740eb408b8e2b798037918967b7a50abee72d --- /dev/null +++ b/src/core/env_server/web_interface.py @@ -0,0 +1,644 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Web interface for OpenEnv environments. + +When ENABLE_WEB_INTERFACE is set, the server exposes a Gradio UI at /web for +reset, step, and state observation. Controlled by the CLI enable_interface +option (e.g. openenv push --enable-interface) or ENABLE_WEB_INTERFACE env var. 
+""" + +from __future__ import annotations + +import asyncio +import json +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, Type + +import gradio as gr +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from pydantic import BaseModel, ConfigDict, Field + +from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME +from .gradio_ui import build_gradio_app, get_gradio_display_title +from .interfaces import Environment +from .serialization import deserialize_action_with_preprocessing, serialize_observation +from .types import Action, EnvironmentMetadata, Observation, State + +# Quick Start markdown template; placeholders match init suffixes (__ENV_NAME__, __ENV_CLASS_NAME__*). +DEFAULT_QUICK_START_MARKDOWN = """ +### Connect to this environment + +Connect from Python using `__ENV_CLASS_NAME__Env`: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env + +with __ENV_CLASS_NAME__Env.from_env("") as env: + result = await env.step(__ENV_CLASS_NAME__Action(message="...")) +``` + +Or connect directly to a running server: + +```python +env = __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") +``` + +### Contribute to this environment + +Submit improvements via pull request on the Hugging Face Hub. + +```bash +openenv fork --repo-id / +``` + +Then make your changes and submit a pull request: + +```bash +cd +openenv push --create-pr +``` + +For more information, see the [OpenEnv documentation](https://meta-pytorch.org/OpenEnv/). +""" + + +def get_quick_start_markdown( + metadata: Optional[EnvironmentMetadata], + action_cls: Type[Action], + observation_cls: Type[Observation], +) -> str: + """ + Build Quick Start markdown with class names replaced from current env (init-style suffixes). + + Uses the same placeholder names as the init template so that __ENV_CLASS_NAME__Env, + __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation and __ENV_NAME__ are + replaced with the actual class/package names. + """ + import os + + # Prefix from action class (e.g. EchoAction -> Echo) + action_name = getattr(action_cls, "__name__", "Action") + if action_name.endswith("Action"): + prefix = action_name[: -len("Action")] + else: + prefix = action_name.replace("Action", "").strip() or "Env" + + env_client_name = f"{prefix}Env" + obs_name = getattr(observation_cls, "__name__", "Observation") + pkg_name = (metadata.name if metadata else "env").replace(" ", "_").lower() + + space_id = os.environ.get("SPACE_ID", "/") + + content = DEFAULT_QUICK_START_MARKDOWN + content = content.replace("__ENV_CLASS_NAME__Env", env_client_name) + content = content.replace("__ENV_CLASS_NAME__Action", action_name) + content = content.replace("__ENV_CLASS_NAME__Observation", obs_name) + content = content.replace("__ENV_CLASS_NAME__", prefix) + content = content.replace("__ENV_NAME__", pkg_name) + content = content.replace("", space_id) + return content.strip() + + +def load_environment_metadata( + env: Environment, env_name: Optional[str] = None +) -> EnvironmentMetadata: + """ + Load environment metadata including README content. + + Args: + env: The environment instance, class, or factory function. 
+ - If a class: used as a factory, won't call instance methods + - If a function: used as a factory, won't call instance methods + - If an instance: may call get_metadata() if available + env_name: Optional environment name for README file lookup + + Returns: + EnvironmentMetadata with loaded information + """ + import inspect + + # Determine what type of env we received: + # 1. A class (used as factory) - e.g., PythonCodeActEnv + # 2. A function (factory function) - e.g., create_chat_environment + # 3. An actual instance - e.g., SnakeEnvironment() + is_class = inspect.isclass(env) + is_function = inspect.isfunction(env) or inspect.ismethod(env) + is_factory = is_class or is_function + + # Try to get metadata from environment if it's an instance with get_metadata + if not is_factory and hasattr(env, "get_metadata"): + return env.get_metadata() + + # Determine the class name for default metadata + if is_class: + # env is the class itself + class_name = env.__name__ + elif is_function: + # env is a factory function - use its name or derive from env_name + class_name = env_name or env.__name__ + else: + # env is an instance + class_name = env.__class__.__name__ + + # Default metadata + metadata = EnvironmentMetadata( + name=env_name or class_name, + description=f"{class_name} environment", + version="1.0.0", + ) + + # Try to load README from file system + readme_content = _load_readme_from_filesystem(env_name) + if readme_content: + metadata.readme_content = readme_content + + return metadata + + +def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]: + """ + Load README content from the filesystem. + + Tries multiple locations: + 1. Container filesystem: /app/README.md + 2. Local development: src/envs/{env_name}/README.md + 3. Environment variable: ENV_README_PATH + """ + import os + from pathlib import Path + + # Try container filesystem first + container_readme = Path("/app/README.md") + if container_readme.exists(): + try: + return container_readme.read_text(encoding="utf-8") + except Exception: + pass + + # Try environment variable path + custom_path = os.environ.get("ENV_README_PATH") + if custom_path and Path(custom_path).exists(): + try: + return Path(custom_path).read_text(encoding="utf-8") + except Exception: + pass + + # Try local development path + if env_name: + local_readme = Path(f"src/envs/{env_name}/README.md") + if local_readme.exists(): + try: + return local_readme.read_text(encoding="utf-8") + except Exception: + pass + + return None + + +class ActionLog(BaseModel): + """Log entry for an action taken.""" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + timestamp: str = Field(description="Timestamp when action was taken") + action: Dict[str, Any] = Field(description="Action that was taken") + observation: Dict[str, Any] = Field(description="Observation returned from action") + reward: Optional[float] = Field( + default=None, description="Reward received from action" + ) + done: bool = Field(description="Whether the episode is done after this action") + step_count: int = Field(description="Step count when this action was taken") + + +class EpisodeState(BaseModel): + """Current episode state for the web interface.""" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + episode_id: Optional[str] = Field(default=None, description="Current episode ID") + step_count: int = Field(description="Current step count in episode") + current_observation: Optional[Dict[str, Any]] = Field( + default=None, 
description="Current observation" + ) + action_logs: List[ActionLog] = Field( + default_factory=list, description="List of action logs" + ) + is_reset: bool = Field( + default=True, description="Whether the episode has been reset" + ) + + +class WebInterfaceManager: + """Manages the web interface for an environment.""" + + MAX_ACTION_LOGS = 1000 + + def __init__( + self, + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + metadata: Optional[EnvironmentMetadata] = None, + ): + import inspect + + # If env is a class or factory function, instantiate it + if inspect.isclass(env) or inspect.isfunction(env): + self.env = env() + else: + self.env = env + self.action_cls = action_cls + self.observation_cls = observation_cls + self.metadata = metadata or EnvironmentMetadata( + name=env.__class__.__name__, + description=f"{env.__class__.__name__} environment", + ) + self.episode_state = EpisodeState( + episode_id=None, + step_count=0, + current_observation=None, + action_logs=[], + ) + self.connected_clients: List[WebSocket] = [] + # Thread pool for running sync code (e.g., Playwright sync API) in async context + self._executor = ThreadPoolExecutor(max_workers=1) + + async def _run_sync_in_thread_pool(self, func, *args, **kwargs): + """Run a synchronous function in the thread pool executor. + + This is needed for environments using sync libraries (e.g., Playwright sync API) + that cannot be called directly from an async context. + """ + loop = asyncio.get_event_loop() + # Use default arguments to capture values at lambda definition time + # to avoid closure issues with late binding + return await loop.run_in_executor( + self._executor, lambda f=func, a=args, kw=kwargs: f(*a, **kw) + ) + + async def connect_websocket(self, websocket: WebSocket): + """Connect a new WebSocket client.""" + await websocket.accept() + self.connected_clients.append(websocket) + + # Send current state to the new client + await self._send_state_update() + + async def disconnect_websocket(self, websocket: WebSocket): + """Disconnect a WebSocket client.""" + if websocket in self.connected_clients: + self.connected_clients.remove(websocket) + + async def _send_state_update(self): + """Send current state to all connected clients.""" + if not self.connected_clients: + return + + state_data = { + "type": "state_update", + "episode_state": self.episode_state.model_dump(), + } + + # Send to all connected clients + disconnected_clients = [] + for client in self.connected_clients: + try: + await client.send_text(json.dumps(state_data)) + except Exception: + disconnected_clients.append(client) + + # Remove disconnected clients + for client in disconnected_clients: + self.connected_clients.remove(client) + + async def reset_environment(self) -> Dict[str, Any]: + """Reset the environment and update state.""" + # Run sync reset in thread pool to avoid blocking event loop + # and to support environments using sync libraries (e.g., Playwright) + observation: Observation = await self._run_sync_in_thread_pool(self.env.reset) + state: State = self.env.state + + # Serialize observation once using shared utility + serialized = serialize_observation(observation) + + # Update episode state + self.episode_state.episode_id = state.episode_id + self.episode_state.step_count = 0 + self.episode_state.current_observation = serialized["observation"] + self.episode_state.action_logs = [] + self.episode_state.is_reset = True + + # Send state update + await self._send_state_update() + + return serialized + + async def 
step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]: + """Execute a step in the environment and update state.""" + # Deserialize action with preprocessing for web interface special cases + action: Action = deserialize_action_with_preprocessing( + action_data, self.action_cls + ) + + # Run sync step in thread pool to avoid blocking event loop + # and to support environments using sync libraries (e.g., Playwright) + observation: Observation = await self._run_sync_in_thread_pool( + self.env.step, action + ) + state: State = self.env.state + + # Serialize observation once using shared utility + serialized = serialize_observation(observation) + + # Create action log + action_log = ActionLog( + timestamp=datetime.now().isoformat(), + action=action.model_dump(exclude={"metadata"}), + observation=serialized["observation"], + reward=observation.reward, + done=observation.done, + step_count=state.step_count, + ) + + # Update episode state + self.episode_state.episode_id = state.episode_id + self.episode_state.step_count = state.step_count + self.episode_state.current_observation = serialized["observation"] + self.episode_state.action_logs.append(action_log) + if len(self.episode_state.action_logs) > self.MAX_ACTION_LOGS: + self.episode_state.action_logs = self.episode_state.action_logs[ + -self.MAX_ACTION_LOGS : + ] + self.episode_state.is_reset = False + + # Send state update + await self._send_state_update() + + return serialized + + def get_state(self) -> Dict[str, Any]: + """Get current environment state.""" + state: State = self.env.state + return state.model_dump() + + +def create_web_interface_app( + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + env_name: Optional[str] = None, + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[Any] = None, + gradio_builder: Optional[Callable[..., Any]] = None, +) -> FastAPI: + """ + Create a FastAPI application with web interface for the given environment. + + Args: + env: The Environment instance to serve + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + env_name: Optional environment name for README loading + max_concurrent_envs: Maximum concurrent WebSocket sessions + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings + gradio_builder: Optional callable (web_manager, action_fields, metadata, + is_chat_env, title, quick_start_md) -> gr.Blocks to use instead of the + default Gradio UI. Lets envs replace or customize the /web interface. + + Returns: + FastAPI application instance with web interface + """ + from .http_server import create_fastapi_app + + # Create the base environment app + app = create_fastapi_app( + env, action_cls, observation_cls, max_concurrent_envs, concurrency_config + ) + + # Load environment metadata + metadata = load_environment_metadata(env, env_name) + + # Create web interface manager + web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) + + # Web API routes first (so they take precedence over Gradio mount at /web) + @app.get("/web/metadata") + async def web_metadata(): + """Get environment metadata.""" + return web_manager.metadata.model_dump() + + @app.websocket("/ws/ui") + async def websocket_ui_endpoint(websocket: WebSocket): + """WebSocket endpoint for web UI real-time updates. + + Note: Uses /ws/ui to avoid conflict with /ws in http_server.py + which is used for concurrent environment sessions. 
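+
+        Example (client side; a sketch assuming the third-party ``websockets``
+        package):
+            >>> import websockets  # doctest: +SKIP
+            >>> async with websockets.connect(
+            ...     "ws://localhost:8000/ws/ui"
+            ... ) as ws:  # doctest: +SKIP
+            ...     update = await ws.recv()  # JSON "state_update" payload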
+ """ + await web_manager.connect_websocket(websocket) + try: + while True: + # Keep connection alive + await websocket.receive_text() + except WebSocketDisconnect: + await web_manager.disconnect_websocket(websocket) + + @app.post("/web/reset") + async def web_reset(): + """Reset endpoint for web interface.""" + return await web_manager.reset_environment() + + @app.post("/web/step") + async def web_step(request: Dict[str, Any]): + """Step endpoint for web interface.""" + # Check if this is a message-based request (chat environment) + if "message" in request: + message = request["message"] + if hasattr(web_manager.env, "message_to_action"): + action = web_manager.env.message_to_action(message) + if hasattr(action, "tokens"): + action_data = {"tokens": action.tokens.tolist()} + else: + action_data = action.model_dump(exclude={"metadata"}) + else: + action_data = {"message": message} + else: + action_data = request.get("action", {}) + + return await web_manager.step_environment(action_data) + + @app.get("/web/state") + async def web_state(): + """State endpoint for web interface.""" + return web_manager.get_state() + + action_fields = _extract_action_fields(action_cls) + is_chat_env = _is_chat_env(action_cls) + quick_start_md = get_quick_start_markdown(metadata, action_cls, observation_cls) + + default_blocks = build_gradio_app( + web_manager, + action_fields, + metadata, + is_chat_env, + title=metadata.name, + quick_start_md=quick_start_md, + ) + if gradio_builder is not None: + custom_blocks = gradio_builder( + web_manager, + action_fields, + metadata, + is_chat_env, + metadata.name, + quick_start_md, + ) + if not isinstance(custom_blocks, gr.Blocks): + raise TypeError( + f"gradio_builder must return a gr.Blocks instance, " + f"got {type(custom_blocks).__name__}" + ) + gradio_blocks = gr.TabbedInterface( + [default_blocks, custom_blocks], + tab_names=["Playground", "Visualization"], + title=get_gradio_display_title(metadata), + ) + else: + gradio_blocks = default_blocks + app = gr.mount_gradio_app( + app, + gradio_blocks, + path="/web", + theme=OPENENV_GRADIO_THEME, + css=OPENENV_GRADIO_CSS, + ) + + return app + + +def _is_chat_env(action_cls: Type[Action]) -> bool: + """Return True if the action class is a chat-style env (tokens field).""" + if hasattr(action_cls, "model_fields"): + for field_name, field_info in action_cls.model_fields.items(): + if ( + field_name == "tokens" + and hasattr(field_info.annotation, "__name__") + and "Tensor" in str(field_info.annotation) + ): + return True + return False + + +def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]: + """Extract enhanced field metadata from Action class for form generation.""" + # Use Pydantic's JSON schema generation for robust metadata extraction + try: + schema = action_cls.model_json_schema() + except AttributeError: + # Fallback for non-Pydantic v2 models or if something goes wrong + return [] + + properties = schema.get("properties", {}) + required_fields = schema.get("required", []) + + action_fields = [] + + for field_name, field_info in properties.items(): + if field_name == "metadata": + continue + + # JSON schema "type" can be a string or list/undefined + # Determine our internal input type + input_type = _determine_input_type_from_schema(field_info, field_name) + + is_required = field_name in required_fields + + action_fields.append( + { + "name": field_name, + "type": input_type, + "required": is_required, + "description": field_info.get("description", ""), + "default_value": 
field_info.get("default"), + "choices": field_info.get("enum"), + "min_value": field_info.get("minimum"), + "max_value": field_info.get("maximum"), + "min_length": field_info.get("minLength"), + "max_length": field_info.get("maxLength"), + "pattern": field_info.get("pattern"), + "placeholder": _generate_placeholder(field_name, field_info), + "help_text": _generate_help_text(field_name, field_info), + } + ) + + return action_fields + + +def _determine_input_type_from_schema( + field_info: Dict[str, Any], field_name: str +) -> str: + """Determine input type from JSON schema for form generation (Gradio UI).""" + schema_type = field_info.get("type") + + # Check for specific tensor field convention + if "tokens" in field_name.lower(): + return "tensor" + + if "enum" in field_info: + return "select" + + if schema_type == "boolean": + return "checkbox" + + if schema_type == "integer" or schema_type == "number": + return "number" + + if schema_type == "string": + # Check if it should be a textarea + if ( + field_info.get("maxLength", 0) > 100 + or "message" in field_name.lower() + or "code" in field_name.lower() + ): + return "textarea" + return "text" + + # Default fallback + return "text" + + +def _generate_placeholder(field_name: str, field_info: Dict[str, Any]) -> str: + """Generate placeholder text.""" + if "message" in field_name.lower(): + return f"Enter {field_name.replace('_', ' ')}..." + elif "code" in field_name.lower(): + return "Enter Python code here..." + elif "tokens" in field_name.lower(): + return "Enter comma-separated token IDs (e.g., 1,2,3,4,5)" + else: + return f"Enter {field_name.replace('_', ' ')}..." + + +def _generate_help_text(field_name: str, field_info: Dict[str, Any]) -> str: + """Generate help text.""" + description = field_info.get("description", "") + if description: + return description + + if "action_id" in field_name.lower(): + return "The action ID to execute in environment" + elif "game_name" in field_name.lower(): + return "Name of game or environment" + elif "tokens" in field_name.lower(): + return "Token IDs as a comma-separated list of integers" + elif "code" in field_name.lower(): + return "Python code to execute in environment" + elif "message" in field_name.lower(): + return "Text message to send" + + return "" diff --git a/src/core/evals/__init__.py b/src/core/evals/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52e564a09b5e4976f2cd5a8c1fe1c7848bb47ecb --- /dev/null +++ b/src/core/evals/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Evaluation harness support for OpenEnv.""" + +from openenv.core.evals.base import EvalHarness +from openenv.core.evals.inspect_harness import InspectAIHarness +from openenv.core.evals.types import EvalConfig, EvalResult + +__all__ = [ + "EvalHarness", + "EvalConfig", + "EvalResult", + "InspectAIHarness", +] diff --git a/src/core/evals/base.py b/src/core/evals/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e457d8adb740569ad79143cbf70bc58b05a8cef9 --- /dev/null +++ b/src/core/evals/base.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Base class for evaluation harnesses.""" + +from abc import ABC, abstractmethod +from typing import Any, Dict + +from openenv.core.evals.types import EvalConfig, EvalResult + + +class EvalHarness(ABC): + """Abstract base class for evaluation harnesses. + + Subclasses implement run() to define evaluation logic. + """ + + @abstractmethod + def run( + self, + harness_version: str, + library_versions: Dict[str, str], + dataset: str, + eval_parameters: Dict[str, Any], + ) -> Dict[str, Any]: + """Run the evaluation and return scores. + + Args: + harness_version: Version of the evaluation harness. + library_versions: Versions of libraries used in the evaluation. + dataset: Name of the dataset to evaluate on. + eval_parameters: Parameters for the evaluation. + + Returns: + Dictionary of scores from the evaluation. + """ + raise NotImplementedError + + def run_from_config(self, config: EvalConfig) -> EvalResult: + """Run evaluation from an EvalConfig and return an EvalResult. + + Args: + config: Configuration for the evaluation. + + Returns: + EvalResult containing the config and scores. + """ + scores = self.run( + harness_version=config.harness_version, + library_versions=config.library_versions, + dataset=config.dataset, + eval_parameters=config.eval_parameters, + ) + return EvalResult(config=config, scores=scores) + + @property + def name(self) -> str: + """Return the name of the harness (class name).""" + return self.__class__.__name__ diff --git a/src/core/evals/inspect_harness.py b/src/core/evals/inspect_harness.py new file mode 100644 index 0000000000000000000000000000000000000000..6bf91105db6cf325587623891905e5cbc71c124e --- /dev/null +++ b/src/core/evals/inspect_harness.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Inspect AI harness integration for OpenEnv. + +Requires the ``inspect-ai`` package: ``pip install 'inspect-ai>=0.3.0'`` +""" + +from __future__ import annotations + +from typing import Any, Dict, Optional + +from openenv.core.evals.base import EvalHarness + + +class InspectAIHarness(EvalHarness): + """Evaluation harness wrapping Inspect AI's ``eval()`` function. + + All ``inspect_ai`` imports are deferred to :meth:`run` so this class is + importable without inspect-ai installed. An ``ImportError`` with a clear + message is raised at call time if the dependency is missing. + + Args: + log_dir: Directory for evaluation log output. Defaults to None + (Inspect AI writes logs to its default location). + + ``eval_parameters`` keys accepted by :meth:`run`: + + +--------------------------+----------+-----------------+-----------------------------------+ + | Key | Type | Default | Purpose | + +==========================+==========+=================+===================================+ + | ``model`` | str | *required* | Model string, e.g. 
"openai/gpt-4o"| + | ``task`` | str|None | ``dataset`` arg | Task file path or task string | + | ``task_args`` | dict | ``{}`` | Arguments to pass to the task | + | ``max_samples`` | int|None | None | Limit samples per task | + | ``temperature`` | float|None| None | Model generation temperature | + | ``max_tokens`` | int|None | None | Max generation tokens | + | ``epochs`` | int|None | None | Number of evaluation epochs | + | ``solver`` | list|None| None | Solver pipeline override | + | ``scorer`` | list|None| None | Scorer override | + | ``model_args`` | dict | ``{}`` | Provider-specific model kwargs | + +--------------------------+----------+-----------------+-----------------------------------+ + """ + + def __init__( + self, + *, + log_dir: Optional[str] = None, + ): + self.log_dir = log_dir + + def run( + self, + harness_version: str, + library_versions: Dict[str, str], + dataset: str, + eval_parameters: Dict[str, Any], + ) -> Dict[str, Any]: + """Run an Inspect AI evaluation. + + Args: + harness_version: Version of inspect-ai being used. + library_versions: Versions of supporting libraries. + dataset: Default task string (used when ``task`` is not specified + in *eval_parameters*). + eval_parameters: See class docstring for accepted keys. + + Returns: + Dictionary mapping metric names to scores. + + Raises: + ImportError: If ``inspect-ai`` is not installed. + ValueError: If ``model`` is missing from *eval_parameters*. + RuntimeError: If the evaluation fails (log status is not "success"). + """ + try: + from inspect_ai import eval as inspect_eval + except ImportError: + raise ImportError( + "inspect-ai is required for InspectAIHarness. " + "Install it with: pip install 'inspect-ai>=0.3.0'" + ) + + # Extract required model parameter + model = eval_parameters.get("model") + if model is None: + raise ValueError( + "eval_parameters must include 'model' " + "(e.g. 'openai/gpt-4o', 'hf/meta-llama/...')." + ) + + # Task: explicit parameter or fall back to dataset + task = eval_parameters.get("task", dataset) + + # Build eval kwargs + eval_kwargs: Dict[str, Any] = {} + + task_args = eval_parameters.get("task_args", {}) + if task_args: + eval_kwargs["task_args"] = task_args + + model_args = eval_parameters.get("model_args", {}) + if model_args: + eval_kwargs["model_args"] = model_args + + for key in ("max_samples", "temperature", "max_tokens", "epochs"): + value = eval_parameters.get(key) + if value is not None: + eval_kwargs[key] = value + + if eval_parameters.get("solver") is not None: + eval_kwargs["solver"] = eval_parameters["solver"] + + if eval_parameters.get("scorer") is not None: + eval_kwargs["scorer"] = eval_parameters["scorer"] + + if self.log_dir is not None: + eval_kwargs["log_dir"] = self.log_dir + + # Run evaluation + logs = inspect_eval(task, model=model, **eval_kwargs) + + # Extract results from the first log + if not logs: + raise RuntimeError( + "Inspect AI evaluation returned no logs. " + "Check that the task and model arguments are valid." + ) + log = logs[0] + if log.status != "success": + raise RuntimeError( + f"Inspect AI evaluation failed with status: {log.status}" + ) + + return self._extract_scores(log) + + def _extract_scores(self, log: Any) -> Dict[str, Any]: + """Parse an EvalLog's results into a flat score dictionary. + + Iterates over ``log.results.scores`` (a list of ``EvalScore``), + flattening each scorer's ``metrics`` dict into a single output dict. + + Args: + log: An ``inspect_ai`` ``EvalLog`` object. 
+ + Returns: + Dictionary mapping metric names to their values. + """ + scores: Dict[str, Any] = {} + if log.results is None: + return scores + + for eval_score in log.results.scores: + for metric_name, metric in eval_score.metrics.items(): + scores[metric_name] = metric.value + + return scores diff --git a/src/core/evals/types.py b/src/core/evals/types.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6b14f762624c607c345e5dff1bc77faa5b4b56 --- /dev/null +++ b/src/core/evals/types.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Pydantic models for eval configuration and results.""" + +from typing import Any, Dict + +from pydantic import BaseModel, ConfigDict, Field + + +class EvalConfig(BaseModel): + """Configuration for running an evaluation.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + harness_name: str = Field(description="Name of the evaluation harness") + harness_version: str = Field(description="Version of the evaluation harness") + library_versions: Dict[str, str] = Field( + description="Versions of libraries used in the evaluation" + ) + dataset: str = Field(description="Name of the dataset to evaluate on") + eval_parameters: Dict[str, Any] = Field(description="Parameters for the evaluation") + + +class EvalResult(BaseModel): + """Result of running an evaluation.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + config: EvalConfig = Field(description="Configuration used for the evaluation") + scores: Dict[str, Any] = Field(description="Scores from the evaluation") diff --git a/src/core/generic_client.py b/src/core/generic_client.py new file mode 100644 index 0000000000000000000000000000000000000000..17576862293feeebf68b4a90d6a4a80de369dd34 --- /dev/null +++ b/src/core/generic_client.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Generic environment client that works with raw dictionaries. + +This module provides a GenericEnvClient that doesn't require installing +environment-specific packages. It's useful for connecting to remote servers +without running any untrusted code locally. +""" + +from typing import Any, Dict + +from .client_types import StepResult +from .env_client import EnvClient + + +class GenericEnvClient(EnvClient[Dict[str, Any], Dict[str, Any], Dict[str, Any]]): + """ + Environment client that works with raw dictionaries instead of typed classes. + + This client doesn't require installing environment-specific packages, making it + ideal for: + - Connecting to remote servers without installing their packages + - Quick prototyping and testing + - Environments where type safety isn't needed + - Security-conscious scenarios where you don't want to run remote code + + The trade-off is that you lose type safety and IDE autocomplete for actions + and observations. Instead of typed objects, you work with plain dictionaries. + + Example: + >>> # Direct connection to a running server (no installation needed) + >>> with GenericEnvClient(base_url="http://localhost:8000") as env: + ... result = env.reset() + ... result = env.step({"code": "print('hello')"}) + ... 
print(result.observation) # Dict[str, Any] + ... print(result.observation.get("output")) + + >>> # From local Docker image + >>> env = GenericEnvClient.from_docker_image("coding-env:latest") + >>> result = env.reset() + >>> result = env.step({"code": "x = 1 + 2"}) + >>> env.close() + + >>> # From HuggingFace Hub (pulls Docker image, no pip install) + >>> env = GenericEnvClient.from_env("user/my-env", use_docker=True) + >>> result = env.reset() + >>> env.close() + + Note: + GenericEnvClient inherits `from_docker_image()` and `from_env()` from + EnvClient, so you can use it with Docker containers and HuggingFace + Spaces without any package installation. + """ + + def _step_payload(self, action: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert action to payload for the server. + + For GenericEnvClient, this handles both raw dictionaries and + typed Action objects (Pydantic models). If a Pydantic model is + passed, it will be converted to a dictionary using model_dump(). + + Args: + action: Action as a dictionary or Pydantic BaseModel + + Returns: + The action as a dictionary for the server + """ + # If it's already a dict, return as-is + if isinstance(action, dict): + return action + + # If it's a Pydantic model (Action subclass), convert to dict + if hasattr(action, "model_dump"): + return action.model_dump() + + # Fallback for other objects with __dict__ + if hasattr(action, "__dict__"): + return vars(action) + + # Last resort: try to convert to dict + return dict(action) + + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Dict[str, Any]]: + """ + Parse server response into a StepResult. + + Extracts the observation, reward, and done fields from the + server response. + + Args: + payload: Response payload from the server + + Returns: + StepResult with observation as a dictionary + """ + return StepResult( + observation=payload.get("observation", {}), + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """ + Parse state response from the server. + + For GenericEnvClient, this returns the payload as-is since + we're working with dictionaries. + + Args: + payload: State payload from the server + + Returns: + The state as a dictionary + """ + return payload + + +class GenericAction(Dict[str, Any]): + """ + A dictionary subclass for creating actions when using GenericEnvClient. + + This provides a semantic wrapper around dictionaries to make code more + readable when working with GenericEnvClient. It behaves exactly like a + dict but signals intent that this is an action for an environment. + + Example: + >>> # Without GenericAction (works fine) + >>> env.step({"code": "print('hello')"}) + + >>> # With GenericAction (more explicit) + >>> action = GenericAction(code="print('hello')") + >>> env.step(action) + + >>> # With multiple fields + >>> action = GenericAction(code="x = 1", timeout=30, metadata={"tag": "test"}) + >>> env.step(action) + + Note: + GenericAction is just a dict with a constructor that accepts keyword + arguments. It's provided for symmetry with typed Action classes and + to make code more readable. + """ + + def __init__(self, **kwargs: Any) -> None: + """ + Create a GenericAction from keyword arguments. 
+ + Args: + **kwargs: Action fields as keyword arguments + + Example: + >>> action = GenericAction(code="print(1)", timeout=30) + >>> action["code"] + 'print(1)' + """ + super().__init__(kwargs) + + def __repr__(self) -> str: + """Return a readable representation.""" + items = ", ".join(f"{k}={v!r}" for k, v in self.items()) + return f"GenericAction({items})" diff --git a/src/core/llm_client.py b/src/core/llm_client.py new file mode 100644 index 0000000000000000000000000000000000000000..9df2ff27ae7c2054108ff159b9dec8e4c9dd238c --- /dev/null +++ b/src/core/llm_client.py @@ -0,0 +1,506 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""LLM client abstraction for calling LLM endpoints. + +Provides a generic RPC abstraction: point it at an endpoint/port, tell it the +protocol, and it works. OpenAI-compatible API is the first implementation, +covering OpenAI, vLLM, TGI, Ollama, HuggingFace Inference API, etc. +Anthropic's native API is supported via ``AnthropicClient``. + +Usage: + client = OpenAIClient("http://localhost", 8000, model="meta-llama/...") + response = await client.complete("What is 2+2?") + + # Or use the factory for hosted APIs: + client = create_llm_client("openai", model="gpt-4", api_key="sk-...") + response = await client.complete_with_tools(messages, tools) +""" + +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + +from openai import AsyncOpenAI + + +@dataclass +class ToolCall: + """A single tool/function call returned by the model.""" + + id: str + name: str + args: dict[str, Any] + + +@dataclass +class LLMResponse: + """Normalized response from an LLM, with optional tool calls.""" + + content: str + tool_calls: list[ToolCall] = field(default_factory=list) + + def to_message_dict(self) -> dict[str, Any]: + """Convert to an OpenAI-format assistant message dict.""" + msg: dict[str, Any] = {"role": "assistant", "content": self.content} + if self.tool_calls: + msg["tool_calls"] = [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": json.dumps(tc.args), + }, + } + for tc in self.tool_calls + ] + return msg + + +class LLMClient(ABC): + """Abstract base for LLM endpoint clients. + + Subclass and implement ``complete()`` for your protocol. + + Args: + endpoint: The base URL of the LLM service (e.g. "http://localhost"). + port: The port the service listens on. + """ + + def __init__(self, endpoint: str, port: int): + self.endpoint = endpoint + self.port = port + + @abstractmethod + async def complete(self, prompt: str, **kwargs) -> str: + """Send a prompt, return the text response. + + Args: + prompt: The user prompt to send. + **kwargs: Override default parameters (temperature, max_tokens, etc.). + + Returns: + The model's text response. + """ + ... + + async def complete_with_tools( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]], + **kwargs: Any, + ) -> LLMResponse: + """Send messages with tool definitions, return a normalized response. + + Messages use OpenAI-format dicts (``{"role": "...", "content": "..."}``). + Tools use MCP tool definitions; they are converted internally. + + Args: + messages: Conversation history as OpenAI-format message dicts. + tools: MCP tool definitions. 
+ **kwargs: Override default parameters (temperature, max_tokens, etc.). + + Returns: + An ``LLMResponse`` with the model's text and any tool calls. + """ + raise NotImplementedError( + f"{type(self).__name__} does not support tool calling" + ) + + @property + def base_url(self) -> str: + """Construct base URL from endpoint and port.""" + return f"{self.endpoint}:{self.port}" + + +class OpenAIClient(LLMClient): + """Client for OpenAI-compatible APIs. + + Works with: OpenAI, vLLM, TGI, Ollama, HuggingFace Inference API, + or any endpoint that speaks the OpenAI chat completions format. + + Args: + endpoint: The base URL (e.g. "http://localhost"). + port: The port number. + model: Model name to pass to the API. + api_key: API key. Defaults to "not-needed" for local endpoints. + system_prompt: Optional system message prepended to every request. + temperature: Default sampling temperature. + max_tokens: Default max tokens in the response. + """ + + def __init__( + self, + endpoint: str, + port: int, + model: str, + api_key: str | None = None, + system_prompt: str | None = None, + temperature: float = 0.0, + max_tokens: int = 256, + ): + super().__init__(endpoint, port) + self.model = model + self.system_prompt = system_prompt + self.temperature = temperature + self.max_tokens = max_tokens + + self._client = AsyncOpenAI( + base_url=f"{self.base_url}/v1", + api_key=api_key if api_key is not None else "not-needed", + ) + + async def complete(self, prompt: str, **kwargs) -> str: + """Send a chat completion request. + + Args: + prompt: The user message. + **kwargs: Overrides for temperature, max_tokens. + + Returns: + The assistant's response text. + """ + messages = [] + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + messages.append({"role": "user", "content": prompt}) + + response = await self._client.chat.completions.create( + model=self.model, + messages=messages, + temperature=kwargs.get("temperature", self.temperature), + max_tokens=kwargs.get("max_tokens", self.max_tokens), + ) + return response.choices[0].message.content or "" + + async def complete_with_tools( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]], + **kwargs: Any, + ) -> LLMResponse: + create_kwargs: dict[str, Any] = { + "model": self.model, + "messages": messages, + "temperature": kwargs.get("temperature", self.temperature), + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + } + openai_tools = _mcp_tools_to_openai(tools) + if openai_tools: + create_kwargs["tools"] = openai_tools + + response = await self._client.chat.completions.create(**create_kwargs) + msg = response.choices[0].message + + tool_calls = [] + if msg.tool_calls: + for tc in msg.tool_calls: + tool_calls.append( + ToolCall( + id=tc.id, + name=tc.function.name, + args=json.loads(tc.function.arguments), + ) + ) + + return LLMResponse(content=msg.content or "", tool_calls=tool_calls) + + +class AnthropicClient(LLMClient): + """Client for Anthropic's Messages API. + + Requires the ``anthropic`` package (lazy-imported at construction time). + + Args: + endpoint: The base URL (e.g. "https://api.anthropic.com"). + port: The port number. + model: Model name (e.g. "claude-sonnet-4-20250514"). + api_key: Anthropic API key. + system_prompt: Optional system message prepended to every request. + temperature: Default sampling temperature. + max_tokens: Default max tokens in the response. 
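+
+    Example (illustrative; call from async code with a real API key):
+        >>> client = AnthropicClient(
+        ...     "https://api.anthropic.com",
+        ...     443,
+        ...     model="claude-sonnet-4-20250514",
+        ...     api_key="sk-ant-...",
+        ... )
+        >>> text = await client.complete("What is 2+2?")  # doctest: +SKIP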
+ """ + + def __init__( + self, + endpoint: str, + port: int, + model: str, + api_key: str | None = None, + system_prompt: str | None = None, + temperature: float = 0.0, + max_tokens: int = 256, + ): + super().__init__(endpoint, port) + self.model = model + self.system_prompt = system_prompt + self.temperature = temperature + self.max_tokens = max_tokens + + try: + from anthropic import AsyncAnthropic + except ImportError as exc: + raise ImportError( + "AnthropicClient requires the 'anthropic' package. " + "Install it with: pip install anthropic" + ) from exc + + self._client = AsyncAnthropic( + base_url=self.base_url, + api_key=api_key if api_key is not None else "not-needed", + ) + + async def complete(self, prompt: str, **kwargs) -> str: + create_kwargs: dict[str, Any] = { + "model": self.model, + "messages": [{"role": "user", "content": prompt}], + "temperature": kwargs.get("temperature", self.temperature), + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + } + if self.system_prompt: + create_kwargs["system"] = self.system_prompt + + response = await self._client.messages.create(**create_kwargs) + return "".join(block.text for block in response.content if block.type == "text") + + async def complete_with_tools( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]], + **kwargs: Any, + ) -> LLMResponse: + system, anthropic_msgs = _openai_msgs_to_anthropic(messages) + + create_kwargs: dict[str, Any] = { + "model": self.model, + "messages": anthropic_msgs, + "temperature": kwargs.get("temperature", self.temperature), + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + } + system_text = system or self.system_prompt + if system_text: + create_kwargs["system"] = system_text + anthropic_tools = _mcp_tools_to_anthropic(tools) + if anthropic_tools: + create_kwargs["tools"] = anthropic_tools + + response = await self._client.messages.create(**create_kwargs) + + content = "" + tool_calls = [] + for block in response.content: + if block.type == "text": + content += block.text + elif block.type == "tool_use": + tool_calls.append( + ToolCall(id=block.id, name=block.name, args=block.input) + ) + + return LLMResponse(content=content, tool_calls=tool_calls) + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + +_HOSTED_PROVIDERS: dict[str, tuple[str, int, type[LLMClient]]] = { + "openai": ("https://api.openai.com", 443, OpenAIClient), + "anthropic": ("https://api.anthropic.com", 443, AnthropicClient), +} + + +def create_llm_client( + provider: str, + model: str, + api_key: str, + *, + system_prompt: str | None = None, + temperature: float = 0.0, + max_tokens: int = 4096, +) -> LLMClient: + """Create an LLM client for a hosted provider. + + Args: + provider: Provider name ("openai" or "anthropic"). + model: Model identifier. + api_key: API key for the provider. + system_prompt: Optional system message prepended to every request. + temperature: Sampling temperature. + max_tokens: Maximum tokens in the response. + + Returns: + A configured ``LLMClient`` instance. + """ + key = provider.lower() + if key not in _HOSTED_PROVIDERS: + raise ValueError( + f"Unsupported provider: {provider!r}. 
" + f"Supported: {sorted(_HOSTED_PROVIDERS)}" + ) + endpoint, port, cls = _HOSTED_PROVIDERS[key] + return cls( + endpoint, + port, + model, + api_key=api_key, + system_prompt=system_prompt, + temperature=temperature, + max_tokens=max_tokens, + ) + + +# --------------------------------------------------------------------------- +# MCP tool-schema helpers +# --------------------------------------------------------------------------- + + +def _clean_mcp_schema(schema: dict[str, Any]) -> dict[str, Any]: + """Normalize an MCP tool ``inputSchema`` for LLM function-calling APIs.""" + if not isinstance(schema, dict): + return {"type": "object", "properties": {}, "required": []} + + # Shallow copy to avoid mutating the caller's schema dict. + schema = dict(schema) + + if "oneOf" in schema: + for option in schema["oneOf"]: + if isinstance(option, dict) and option.get("type") == "object": + schema = option + break + else: + return {"type": "object", "properties": {}, "required": []} + + if "allOf" in schema: + merged: dict[str, Any] = {"type": "object", "properties": {}, "required": []} + for sub in schema["allOf"]: + if isinstance(sub, dict): + if "properties" in sub: + merged["properties"].update(sub["properties"]) + if "required" in sub: + merged["required"].extend(sub["required"]) + schema = merged + + if "anyOf" in schema: + for option in schema["anyOf"]: + if isinstance(option, dict) and option.get("type") == "object": + schema = option + break + else: + return {"type": "object", "properties": {}, "required": []} + + schema.setdefault("type", "object") + if schema.get("type") == "object" and "properties" not in schema: + schema["properties"] = {} + return schema + + +def _mcp_tools_to_openai( + mcp_tools: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Convert MCP tool definitions to OpenAI function-calling format.""" + result = [] + for tool in mcp_tools: + input_schema = tool.get( + "inputSchema", {"type": "object", "properties": {}, "required": []} + ) + result.append( + { + "type": "function", + "function": { + "name": tool["name"], + "description": tool.get("description", ""), + "parameters": _clean_mcp_schema(input_schema), + }, + } + ) + return result + + +def _mcp_tools_to_anthropic( + mcp_tools: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Convert MCP tool definitions to Anthropic tool format.""" + result = [] + for tool in mcp_tools: + input_schema = tool.get( + "inputSchema", {"type": "object", "properties": {}, "required": []} + ) + result.append( + { + "name": tool["name"], + "description": tool.get("description", ""), + "input_schema": _clean_mcp_schema(input_schema), + } + ) + return result + + +def _openai_msgs_to_anthropic( + messages: list[dict[str, Any]], +) -> tuple[str, list[dict[str, Any]]]: + """Convert OpenAI-format messages to Anthropic format. + + Returns ``(system_text, anthropic_messages)``. System-role messages are + extracted and concatenated; tool-result messages are converted to + Anthropic's ``tool_result`` content blocks inside user turns. 
+ """ + system_parts: list[str] = [] + anthropic_msgs: list[dict[str, Any]] = [] + + for msg in messages: + role = msg["role"] + + if role == "system": + system_parts.append(msg["content"]) + + elif role == "user": + anthropic_msgs.append({"role": "user", "content": msg["content"]}) + + elif role == "assistant": + if msg.get("tool_calls"): + content: list[dict[str, Any]] = [] + if msg.get("content"): + content.append({"type": "text", "text": msg["content"]}) + for tc in msg["tool_calls"]: + args = tc["function"]["arguments"] + if isinstance(args, str): + args = json.loads(args) + content.append( + { + "type": "tool_use", + "id": tc["id"], + "name": tc["function"]["name"], + "input": args, + } + ) + anthropic_msgs.append({"role": "assistant", "content": content}) + else: + anthropic_msgs.append( + {"role": "assistant", "content": msg.get("content", "")} + ) + + elif role == "tool": + tool_result = { + "type": "tool_result", + "tool_use_id": msg["tool_call_id"], + "content": msg["content"], + } + # Anthropic requires tool results in user turns; merge if possible. + if ( + anthropic_msgs + and anthropic_msgs[-1]["role"] == "user" + and isinstance(anthropic_msgs[-1]["content"], list) + ): + anthropic_msgs[-1]["content"].append(tool_result) + else: + anthropic_msgs.append({"role": "user", "content": [tool_result]}) + + system = "\n\n".join(system_parts) + return system, anthropic_msgs diff --git a/src/core/mcp_client.py b/src/core/mcp_client.py new file mode 100644 index 0000000000000000000000000000000000000000..edac3529d3a34e798781d86cf4d2495dc9611713 --- /dev/null +++ b/src/core/mcp_client.py @@ -0,0 +1,377 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +MCP Client classes for tool-calling environments. + +This module provides async client classes for interacting with MCP-enabled environments: +- MCPClientBase: Base class with shared tool discovery +- MCPToolClient: Client for tool-calling style (one tool per step) + +These clients abstract away the MCP protocol details, providing a clean interface +for listing and calling tools on remote environments. All clients are async by default. + +Architecture Overview:: + + ┌─────────────────────────────────────────────────────────┐ + │ HTTPEnvServer │ + ├─────────────────────────────────────────────────────────┤ + │ Simulation Mode (default): │ + │ /ws → OpenEnv protocol (reset/step/state) │ + │ /mcp → MCP JSON-RPC (tools/list, tools/call) │ + │ /reset, /step, /state → HTTP endpoints │ + ├─────────────────────────────────────────────────────────┤ + │ Production Mode (use_production_mode=True): │ + │ /mcp → MCP JSON-RPC (tools/list, tools/call) │ + │ Bypasses step() for direct tool access │ + └─────────────────────────────────────────────────────────┘ + + Client Usage: + MCPToolClient (default) → /ws (step-based, with rewards) + MCPToolClient (production) → /mcp (direct tool access, no rewards) + +Example (async): + >>> from openenv.core.mcp_client import MCPToolClient + >>> + >>> async with MCPToolClient(base_url="http://localhost:8000") as env: + ... # Discover available tools + ... tools = await env.list_tools() + ... print([t.name for t in tools]) + ... + ... # Call a tool + ... result = await env.call_tool("echo_message", message="Hello!") + ... print(result) + +Example (sync wrapper): + >>> env = MCPToolClient(base_url="http://localhost:8000").sync() + >>> with env: + ... 
tools = env.list_tools() + ... result = env.call_tool("echo_message", message="Hello!") +""" + +from typing import Any, Dict, List, Optional + +from .client_types import StepResult +from .env_client import EnvClient +from .env_server.mcp_types import ( + CallToolAction, + CallToolObservation, + ListToolsAction, + ListToolsObservation, + Tool, + ToolError, +) +from .env_server.types import Observation, State + + +class MCPClientBase(EnvClient[Any, Observation, State]): + """ + Base class for MCP clients with tool discovery. + + This class provides the common `list_tools()` method for discovering + available tools from an MCP-enabled environment. Subclasses implement + specific interaction patterns (tool-calling or CodeAct). + + Attributes: + _tools_cache: Cached list of tools (populated on first `list_tools()` call) + """ + + def __init__( + self, + base_url: str, + connect_timeout_s: float = 10.0, + message_timeout_s: float = 60.0, + provider: Optional[Any] = None, + mode: Optional[str] = None, + ): + """ + Initialize MCP client. + + Args: + base_url: Base URL of the environment server (http:// or ws://). + connect_timeout_s: Timeout for establishing WebSocket connection. + message_timeout_s: Timeout for receiving responses to messages. + provider: Optional container/runtime provider for lifecycle management. + mode: Communication mode. Must be 'production' for MCP clients. Defaults to 'production'. + """ + # MCPClientBase defaults to production mode, but allow override for validation + if mode is None: + mode = "production" + + # Validate that mode is production + mode_lower = mode.lower() + if mode_lower != "production": + raise ValueError( + f"MCPToolClient only supports 'production' mode, got '{mode}'. " + f"Use GenericEnvClient for simulation mode." + ) + + super().__init__( + base_url=base_url, + connect_timeout_s=connect_timeout_s, + message_timeout_s=message_timeout_s, + provider=provider, + mode=mode, + ) + self._tools_cache: Optional[List[Tool]] = None + self.use_production_mode = False + + async def list_tools(self, use_cache: bool = True) -> List[Tool]: + """ + Discover available tools from the environment. + + Args: + use_cache: If True, return cached tools if available. + Set to False to force a fresh request. + + Returns: + List of Tool objects with name, description, and input_schema. + + Example: + >>> tools = await env.list_tools() + >>> for tool in tools: + ... 
print(f"{tool.name}: {tool.description}") + """ + if use_cache and self._tools_cache is not None: + return self._tools_cache + + # Use production mode HTTP endpoint if enabled + if self.use_production_mode: + import requests + + # Convert ws:// URL to http:// URL + url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://") + # Remove /ws suffix if present and add /mcp + url = url.rstrip("/ws").rstrip("/") + "/mcp" + + try: + response = requests.post( + url, + json={ + "jsonrpc": "2.0", + "method": "tools/list", + "params": {}, + "id": 1, + }, + ) + data = response.json() + if "result" in data and "tools" in data["result"]: + tools = [ + Tool( + name=t.get("name", ""), + description=t.get("description", ""), + input_schema=t.get( + "input_schema", t.get("inputSchema", {}) + ), + ) + for t in data["result"]["tools"] + ] + self._tools_cache = tools + return tools + except Exception: + # If HTTP request fails, return empty list + pass + return [] + + result = await self.step(ListToolsAction()) + self._tools_cache = result.observation.tools + return self._tools_cache + + def _step_payload(self, action: Any) -> Dict[str, Any]: + """Convert an Action object to the JSON data expected by the env server.""" + if isinstance(action, ListToolsAction): + return {"type": "list_tools"} + elif isinstance(action, CallToolAction): + return { + "type": "call_tool", + "tool_name": action.tool_name, + "arguments": action.arguments, + } + else: + # For unknown actions, try to serialize as dict + if hasattr(action, "model_dump"): + return action.model_dump() + return {"action": str(action)} + + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]: + """Convert a JSON response from the env server to StepResult[Observation].""" + obs_data = payload.get("observation", {}) + + # Check if this is a ListToolsObservation + if "tools" in obs_data: + tools = [ + Tool( + name=t.get("name", ""), + description=t.get("description", ""), + input_schema=t.get("input_schema", t.get("inputSchema", {})), + ) + for t in obs_data.get("tools", []) + ] + observation = ListToolsObservation( + tools=tools, + done=payload.get("done", False), + reward=payload.get("reward"), + metadata=obs_data.get("metadata", {}), + ) + # Check if this is a CallToolObservation + elif "tool_name" in obs_data: + error = None + if obs_data.get("error"): + error = ToolError(**obs_data["error"]) + + observation = CallToolObservation( + tool_name=obs_data.get("tool_name", ""), + result=obs_data.get("result"), + error=error, + done=payload.get("done", False), + reward=payload.get("reward"), + metadata=obs_data.get("metadata", {}), + ) + else: + # Generic observation + observation = Observation( + done=payload.get("done", False), + reward=payload.get("reward"), + metadata=obs_data.get("metadata", {}), + ) + + return StepResult( + observation=observation, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict[str, Any]) -> State: + """Convert a JSON response from the state endpoint to a State object.""" + return State( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + ) + + +class MCPToolClient(MCPClientBase): + """ + Async client for tool-calling style MCP interactions. + + Each step invokes a single tool. Use this for traditional function-calling + agent patterns where the agent decides which tool to call next. 
+ + This client provides convenience methods for tool discovery and invocation: + - `list_tools()`: Get all available tools with their schemas + - `call_tool(name, **kwargs)`: Invoke a tool by name with arguments + + Example (async): + >>> async with MCPToolClient(base_url="http://localhost:8000") as env: + ... # Reset the environment + ... await env.reset() + ... + ... # Discover available tools + ... tools = await env.list_tools() + ... print([t.name for t in tools]) # ['echo_message', 'echo_with_length'] + ... + ... # Call a tool directly + ... result = await env.call_tool("echo_message", message="Hello!") + ... print(result) # "Hello!" + ... + ... # Or use the full action interface + ... from openenv.core.env_server.mcp_types import CallToolAction + ... step_result = await env.step(CallToolAction( + ... tool_name="echo_with_length", + ... arguments={"message": "Test"} + ... )) + ... print(step_result.observation.result) + + Example (sync wrapper): + >>> env = MCPToolClient(base_url="http://localhost:8000").sync() + >>> with env: + ... tools = env.list_tools() + ... result = env.call_tool("echo_message", message="Hello!") + """ + + async def call_tool(self, name: str, **kwargs: Any) -> Any: + """ + Call a tool by name. + + This is a convenience method that creates a CallToolAction, executes it, + and returns the result directly. For more control, use `step()` with + a CallToolAction directly. + + Args: + name: Name of the tool to invoke (must match a tool from `list_tools()`). + **kwargs: Arguments to pass to the tool. Must match the tool's input_schema. + + Returns: + The tool's result. The type depends on the tool being called. + + Raises: + RuntimeError: If the server returns an error response. + + Example: + >>> result = await env.call_tool("add", a=5, b=3) + >>> print(result) # 8 + >>> + >>> result = await env.call_tool("greet", name="Claude") + >>> print(result) # "Hello, Claude!" + """ + action = CallToolAction(tool_name=name, arguments=kwargs) + result = await self.step(action) + obs = result.observation + + # Check for transport/framework errors + if isinstance(obs, CallToolObservation) and obs.error is not None: + raise RuntimeError( + f"Tool '{name}' failed: {obs.error.message} " + f"(type: {obs.error.error_type.value})" + ) + + # Return the result + if isinstance(obs, CallToolObservation): + result = obs.result + # Handle FastMCP CallToolResult objects + # - As object: has .data attribute + # - As dict (from JSON): has "data" key + if hasattr(result, "data"): + return result.data + if isinstance(result, dict) and "data" in result: + return result["data"] + return result + + # Fallback for unexpected observation types + return obs + + async def get_tool(self, name: str) -> Optional[Tool]: + """ + Get a specific tool by name. + + Args: + name: Name of the tool to find. + + Returns: + The Tool object if found, None otherwise. + + Example: + >>> tool = await env.get_tool("echo_message") + >>> if tool: + ... print(tool.description) + ... print(tool.input_schema) + """ + tools = await self.list_tools() + for tool in tools: + if tool.name == name: + return tool + return None + + async def has_tool(self, name: str) -> bool: + """ + Check if a tool exists. + + Args: + name: Name of the tool to check. + + Returns: + True if the tool exists, False otherwise. 
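+
+        Example:
+            >>> if await env.has_tool("echo_message"):
+            ...     result = await env.call_tool("echo_message", message="Hello!")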
+        """
+        return await self.get_tool(name) is not None
diff --git a/src/core/openenv/__init__.py b/src/core/openenv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cabe2abc6a70dacafe04f0583b27b2552bab1e47
--- /dev/null
+++ b/src/core/openenv/__init__.py
@@ -0,0 +1,54 @@
+"""Unified OpenEnv package bundling the CLI and core runtime."""
+
+from __future__ import annotations
+
+from importlib import import_module, metadata
+
+__all__ = [
+    "core",
+    "cli",
+    "AutoEnv",
+    "AutoAction",
+    "GenericEnvClient",
+    "GenericAction",
+    "SyncEnvClient",
+]
+
+try:
+    __version__ = metadata.version("openenv")  # type: ignore[arg-type]
+except metadata.PackageNotFoundError:  # pragma: no cover - local dev
+    __version__ = "0.0.0"
+
+
+_LAZY_MODULES = {
+    "core": ".core",
+    "cli": ".cli",
+}
+
+_LAZY_ATTRS = {
+    "AutoEnv": (".auto", "AutoEnv"),
+    "AutoAction": (".auto", "AutoAction"),
+    "GenericEnvClient": (".core", "GenericEnvClient"),
+    "GenericAction": (".core", "GenericAction"),
+    "SyncEnvClient": (".core", "SyncEnvClient"),
+}
+
+
+def __getattr__(name: str):
+    if name in _LAZY_MODULES:
+        module = import_module(_LAZY_MODULES[name], __name__)
+        globals()[name] = module
+        return module
+
+    if name in _LAZY_ATTRS:
+        module_path, attr_name = _LAZY_ATTRS[name]
+        module = import_module(module_path, __name__)
+        value = getattr(module, attr_name)
+        globals()[name] = value
+        return value
+
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+def __dir__() -> list[str]:
+    return sorted(set(globals().keys()) | set(__all__))
diff --git a/src/core/openenv/auto/__init__.py b/src/core/openenv/auto/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a154570d50d01ed430ea221c1896c06cbc1b7f1c
--- /dev/null
+++ b/src/core/openenv/auto/__init__.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+OpenEnv Auto Module
+===================
+
+Provides a HuggingFace-style auto-discovery API for OpenEnv environments.
+
+This module enables automatic environment and action class loading without
+manual imports:
+
+    >>> from openenv import AutoEnv, AutoAction
+    >>>
+    >>> # Load environment from installed package or HuggingFace Hub
+    >>> env = AutoEnv.from_env("coding-env")
+    >>>
+    >>> # Get action class
+    >>> CodeAction = AutoAction.from_env("coding")
+    >>> action = CodeAction(code="print('Hello!')")
+
+Classes:
+    AutoEnv: Automatic environment client selection and instantiation
+    AutoAction: Automatic action class selection
+
+The auto-discovery system works by:
+1. Discovering installed openenv-* packages via importlib.metadata
+2. Loading environment manifests (openenv.yaml) from package resources
+3. Supporting HuggingFace Hub repositories for remote environments
+4. Caching discovery results for performance
+"""
+
+from .auto_action import AutoAction
+from .auto_env import AutoEnv
+
+__all__ = ["AutoEnv", "AutoAction"]
diff --git a/src/core/openenv/auto/_discovery.py b/src/core/openenv/auto/_discovery.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dda19f4a393a38f74ac4e2508d5edc0e19f0990
--- /dev/null
+++ b/src/core/openenv/auto/_discovery.py
@@ -0,0 +1,584 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Environment Auto-Discovery System +================================== + +This module provides automatic discovery of OpenEnv environments by: +1. Discovering installed openenv-* packages using importlib.metadata +2. Loading manifests (openenv.yaml) from package resources +3. Caching results for performance +4. Supporting HuggingFace Hub downloads + +This enables AutoEnv to work without coupling to src/envs/ directory. +""" + +import importlib +import importlib.metadata +import importlib.resources +import json +import logging +import re +import tempfile +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Optional, Type + +import yaml + +logger = logging.getLogger(__name__) + + +@dataclass +class EnvironmentInfo: + """ + Rich information about a discovered environment. + + Attributes: + env_key: Environment key (e.g., "echo", "coding") + name: Full environment name (e.g., "echo_env") + package_name: Package name (e.g., "openenv-echo_env") + version: Version string + description: Human-readable description + client_module_path: Full module path to client (e.g., "echo_env.client") + client_class_name: Client class name (e.g., "EchoEnv") + action_class_name: Action class name (e.g., "EchoAction") + observation_class_name: Observation class name (e.g., "EchoObservation") + default_image: Default Docker image name (e.g., "echo-env:latest") + spec_version: OpenEnv spec version (from openenv.yaml) + manifest: Original manifest data + """ + + env_key: str + name: str + package_name: str + version: str + description: str + client_module_path: str + client_class_name: str + action_class_name: str + observation_class_name: str + default_image: str + spec_version: Optional[int] = None + manifest: Optional[Dict[str, Any]] = None + + def get_client_class(self) -> Type: + """ + Dynamically import and return the client class. + + Returns: + Client class (e.g., EchoEnv) + + Raises: + ImportError: If module or class cannot be imported + """ + try: + module = importlib.import_module(self.client_module_path) + return getattr(module, self.client_class_name) + except ImportError as e: + raise ImportError( + f"Failed to import {self.client_class_name} from {self.client_module_path}: {e}\n" + f"Make sure the package '{self.package_name}' is installed: " + f"pip install {self.package_name}" + ) from e + except AttributeError as e: + raise ImportError( + f"Class {self.client_class_name} not found in {self.client_module_path}: {e}" + ) from e + + def get_action_class(self) -> Type: + """ + Dynamically import and return the action class. + + Returns: + Action class (e.g., EchoAction) + + Raises: + ImportError: If module or class cannot be imported + """ + try: + module = importlib.import_module(self.client_module_path) + return getattr(module, self.action_class_name) + except ImportError as e: + raise ImportError( + f"Failed to import {self.action_class_name} from {self.client_module_path}: {e}\n" + f"Make sure the package '{self.package_name}' is installed: " + f"pip install {self.package_name}" + ) from e + except AttributeError as e: + raise ImportError( + f"Class {self.action_class_name} not found in {self.client_module_path}: {e}" + ) from e + + def get_observation_class(self) -> Type: + """ + Dynamically import and return the observation class. 
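+
+        The lookup mirrors ``get_client_class`` and ``get_action_class``: the
+        client module is imported lazily and the class is fetched by name.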
+ + Returns: + Observation class (e.g., EchoObservation) + + Raises: + ImportError: If module or class cannot be imported + """ + try: + module = importlib.import_module(self.client_module_path) + return getattr(module, self.observation_class_name) + except ImportError as e: + raise ImportError( + f"Failed to import {self.observation_class_name} from {self.client_module_path}: {e}\n" + f"Make sure the package '{self.package_name}' is installed: " + f"pip install {self.package_name}" + ) from e + except AttributeError as e: + raise ImportError( + f"Class {self.observation_class_name} not found in {self.client_module_path}: {e}" + ) from e + + +def _normalize_env_name(name: str) -> str: + """ + Normalize environment name to standard format. + + Args: + name: Input name (e.g., "echo", "echo-env", "echo_env") + + Returns: + Normalized name (e.g., "echo_env") + + Examples: + >>> _normalize_env_name("echo") + 'echo_env' + >>> _normalize_env_name("echo-env") + 'echo_env' + >>> _normalize_env_name("echo_env") + 'echo_env' + """ + # Remove common suffixes + name = re.sub(r"[-_]env$", "", name) + # Convert hyphens to underscores + name = name.replace("-", "_") + # Add _env suffix if not present + if not name.endswith("_env"): + name = f"{name}_env" + return name + + +def _is_hub_url(name: str) -> bool: + """ + Check if name is a HuggingFace Hub URL or repo ID. + + Args: + name: Input name + + Returns: + True if it looks like a Hub URL + + Examples: + >>> _is_hub_url("meta-pytorch/echo_env") + True + >>> _is_hub_url("https://huggingface.co/meta-pytorch/echo_env") + True + >>> _is_hub_url("echo") + False + """ + # Contains org/repo pattern or huggingface.co domain + return "/" in name or "huggingface.co" in name + + +def _infer_class_name(env_name: str, class_type: str) -> str: + """ + Infer class name from environment name using simple conventions. + + Args: + env_name: Environment name (e.g., "echo_env") + class_type: Type of class ("client", "action", "observation") + + Returns: + Inferred class name + + Examples: + >>> _infer_class_name("echo_env", "client") + 'EchoEnv' + >>> _infer_class_name("echo_env", "action") + 'EchoAction' + """ + # Remove _env suffix for base name + base_name = env_name.replace("_env", "") + + # Convert to PascalCase + pascal_name = "".join(word.capitalize() for word in base_name.split("_")) + + # Add suffix based on type + if class_type == "client": + return f"{pascal_name}Env" + elif class_type == "action": + return f"{pascal_name}Action" + elif class_type == "observation": + return f"{pascal_name}Observation" + else: + raise ValueError(f"Unknown class type: {class_type}") + + +def _load_manifest_from_package( + package_name: str, module_name: str +) -> Optional[Dict[str, Any]]: + """ + Load openenv.yaml manifest from an installed package. 
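+
+    A minimal manifest might look like this (illustrative; only keys consumed
+    by the discovery system are shown)::
+
+        name: echo_env
+        description: Echo environment
+        spec_version: 1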
+ + Args: + package_name: Package name (e.g., "openenv-echo_env") + module_name: Module name (e.g., "echo_env") + + Returns: + Parsed manifest dictionary, or None if not found + + """ + try: + # Try to read openenv.yaml from package + if hasattr(importlib.resources, "files"): + # Python 3.9+ + package_files = importlib.resources.files(module_name) + if (package_files / "openenv.yaml").is_file(): + manifest_text = (package_files / "openenv.yaml").read_text() + return yaml.safe_load(manifest_text) + else: + # Python 3.7-3.8 fallback + with importlib.resources.open_text(module_name, "openenv.yaml") as f: + return yaml.safe_load(f) + except (FileNotFoundError, ModuleNotFoundError, AttributeError): + logger.debug(f"No openenv.yaml found in {module_name}") + return None + except Exception as e: + logger.warning(f"Failed to load openenv.yaml from {module_name}: {e}") + return None + + +def _create_env_info_from_package( + package_name: str, module_name: str, version: str +) -> Optional[EnvironmentInfo]: + """ + Create EnvironmentInfo from an installed package. + + Args: + package_name: Package name (e.g., "openenv-echo_env") + module_name: Module name (e.g., "echo_env") + version: Package version + + Returns: + EnvironmentInfo instance, or None if invalid + """ + # Load manifest + manifest = _load_manifest_from_package(package_name, module_name) + + # Get environment name + if manifest and "name" in manifest: + env_name = manifest["name"] + else: + # Infer from module name + env_name = module_name + + # Normalize to ensure _env suffix + if not env_name.endswith("_env"): + env_name = f"{env_name}_env" + + # Determine env_key (e.g., "echo_env" → "echo") + env_key = env_name.replace("_env", "") if env_name.endswith("_env") else env_name + + # Get description + description = ( + manifest.get("description", f"{env_name} environment") + if manifest + else f"{env_name} environment" + ) + + # Get spec version + spec_version = manifest.get("spec_version") if manifest else None + + # Determine class names + # Check if manifest has custom class names (custom format) + if manifest and "action" in manifest and "observation" in manifest: + # Custom format (like coding_env) + client_class_name = _infer_class_name(env_name, "client") + action_class_name = manifest.get( + "action", _infer_class_name(env_name, "action") + ) + observation_class_name = manifest.get( + "observation", _infer_class_name(env_name, "observation") + ) + else: + # Use conventions + client_class_name = _infer_class_name(env_name, "client") + action_class_name = _infer_class_name(env_name, "action") + observation_class_name = _infer_class_name(env_name, "observation") + + # Module path is just module_name.client + client_module_path = f"{module_name}.client" + + # Determine default Docker image name + image_name = env_name.replace("_", "-") + default_image = f"{image_name}:latest" + + return EnvironmentInfo( + env_key=env_key, + name=env_name, + package_name=package_name, + version=version, + description=description, + client_module_path=client_module_path, + client_class_name=client_class_name, + action_class_name=action_class_name, + observation_class_name=observation_class_name, + default_image=default_image, + spec_version=spec_version, + manifest=manifest, + ) + + +class EnvironmentDiscovery: + """ + Auto-discovery system for OpenEnv environments using installed packages. + + This class discovers installed openenv-* packages and loads their metadata. 
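+
+    Example:
+        >>> discovery = EnvironmentDiscovery()
+        >>> envs = discovery.discover()
+        >>> "echo" in envs  # True if openenv-echo_env is installed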
+ """ + + def __init__(self): + """Initialize discovery system.""" + self._cache: Optional[Dict[str, EnvironmentInfo]] = None + self._cache_file = Path(tempfile.gettempdir()) / "openenv_discovery_cache.json" + + def _discover_installed_packages(self) -> Dict[str, EnvironmentInfo]: + """ + Discover all installed openenv-* packages. + + Returns: + Dictionary mapping env_key to EnvironmentInfo + """ + environments = {} + + # Invalidate import caches to ensure we pick up newly installed packages + importlib.invalidate_caches() + + # Get all installed packages + try: + distributions = importlib.metadata.distributions() + except Exception as e: + logger.warning(f"Failed to get installed packages: {e}") + return environments + + # Filter for openenv-* packages (exclude openenv-core) + for dist in distributions: + package_name = dist.metadata["Name"] + + if not package_name.startswith("openenv-"): + continue + + if package_name == "openenv-core": + continue + + # Get module name (e.g., "openenv-echo_env" → "echo_env") + module_name = package_name.replace("openenv-", "").replace("-", "_") + + # Get version + version = dist.version + + try: + # Create environment info + env_info = _create_env_info_from_package( + package_name, module_name, version + ) + + if env_info: + environments[env_info.env_key] = env_info + logger.debug( + f"Discovered environment: {env_info.env_key} ({package_name})" + ) + + except Exception as e: + logger.warning(f"Failed to load environment from {package_name}: {e}") + continue + + return environments + + def _load_cache(self) -> Optional[Dict[str, EnvironmentInfo]]: + """ + Load cached discovery results. + + Returns: + Dictionary of env_key -> EnvironmentInfo, or None if cache invalid + """ + if not self._cache_file.exists(): + return None + + try: + with open(self._cache_file, "r") as f: + cache_data = json.load(f) + + # Reconstruct EnvironmentInfo objects + cache = {} + for env_key, env_data in cache_data.items(): + cache[env_key] = EnvironmentInfo(**env_data) + + return cache + except Exception as e: + logger.warning(f"Failed to load discovery cache: {e}") + return None + + def _save_cache(self, environments: Dict[str, EnvironmentInfo]) -> None: + """ + Save discovery results to cache. + + Args: + environments: Dictionary of env_key -> EnvironmentInfo + """ + try: + cache_data = {} + for env_key, env_info in environments.items(): + cache_data[env_key] = asdict(env_info) + + with open(self._cache_file, "w") as f: + json.dump(cache_data, f, indent=2) + + except Exception as e: + logger.warning(f"Failed to save discovery cache: {e}") + + def discover(self, use_cache: bool = True) -> Dict[str, EnvironmentInfo]: + """ + Discover all installed OpenEnv environments. 
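+
+        Results are cached in memory and in a JSON file under the system temp
+        directory; pass ``use_cache=False`` to force a fresh scan of installed
+        packages.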
+ + Args: + use_cache: If True, try to load from cache first + + Returns: + Dictionary mapping env_key to EnvironmentInfo + + Examples: + >>> discovery = EnvironmentDiscovery() + >>> envs = discovery.discover() + >>> print(envs.keys()) + dict_keys(['echo', 'coding', ...]) + """ + # Try to load from memory cache first + if use_cache and self._cache is not None: + return self._cache + + # Try to load from file cache + if use_cache: + cached = self._load_cache() + if cached is not None: + self._cache = cached + return self._cache + + # Discover from installed packages + environments = self._discover_installed_packages() + + # Save to cache + self._save_cache(environments) + self._cache = environments + + return environments + + def get_environment(self, env_key: str) -> Optional[EnvironmentInfo]: + """ + Get information about a specific environment. + + Args: + env_key: Environment key (e.g., "echo", "coding") + + Returns: + EnvironmentInfo if found, None otherwise + + Examples: + >>> discovery = EnvironmentDiscovery() + >>> env = discovery.get_environment("echo") + >>> print(env.client_class_name) + 'EchoEnv' + """ + environments = self.discover() + return environments.get(env_key) + + def get_environment_by_name(self, name: str) -> Optional[EnvironmentInfo]: + """ + Get environment info by flexible name matching. + + Args: + name: Environment name (e.g., "echo", "echo-env", "echo_env") + + Returns: + EnvironmentInfo if found, None otherwise + """ + # Normalize name to env_key + normalized = _normalize_env_name(name) + env_key = normalized.replace("_env", "") + + return self.get_environment(env_key) + + def list_environments(self) -> None: + """ + Print a formatted list of all discovered environments. + + Examples: + >>> discovery = EnvironmentDiscovery() + >>> discovery.list_environments() + Available OpenEnv Environments: + ---------------------------------------------------------------------- + echo : Echo Environment (v0.1.0) - openenv-echo_env + coding : Coding Environment (v0.1.0) - openenv-coding_env + ... + """ + environments = self.discover() + + print("Available OpenEnv Environments:") + print("-" * 70) + + if not environments: + print(" No OpenEnv environments found.") + print(" Install environments with: pip install openenv-") + else: + for env_key in sorted(environments.keys()): + env = environments[env_key] + print(f" {env_key:<15}: {env.description} (v{env.version})") + print(f" Package: {env.package_name}") + + print("-" * 70) + print(f"Total: {len(environments)} environments") + + def clear_cache(self) -> None: + """Clear the discovery cache.""" + if self._cache_file.exists(): + self._cache_file.unlink() + self._cache = None + + +# Global discovery instance +_global_discovery: Optional[EnvironmentDiscovery] = None + + +def get_discovery() -> EnvironmentDiscovery: + """ + Get or create the global discovery instance. 
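+
+    A single instance is shared across the process so that discovery results
+    (and the on-disk cache) are reused between calls.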
+ + Returns: + Global EnvironmentDiscovery instance + + Examples: + >>> discovery = get_discovery() + >>> envs = discovery.discover() + """ + global _global_discovery + + if _global_discovery is None: + _global_discovery = EnvironmentDiscovery() + + return _global_discovery + + +def reset_discovery() -> None: + """Reset the global discovery instance (useful for testing).""" + global _global_discovery + if _global_discovery is not None: + _global_discovery.clear_cache() + _global_discovery = None diff --git a/src/core/openenv/auto/auto_action.py b/src/core/openenv/auto/auto_action.py new file mode 100644 index 0000000000000000000000000000000000000000..b097ad1d193a605fe834ff18dd9ccd8d913eab45 --- /dev/null +++ b/src/core/openenv/auto/auto_action.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +AutoAction - Automatic Action Class Selection +============================================== + +AutoAction provides a HuggingFace-style API for automatically retrieving the +correct Action class from installed packages or HuggingFace Hub. + +This module simplifies working with environment actions by automatically +detecting and returning the appropriate Action class without requiring +manual imports. + +Example: + >>> from openenv import AutoEnv, AutoAction + >>> + >>> # Get Action class from environment name + >>> CodeAction = AutoAction.from_env("coding") + >>> action = CodeAction(code="print('Hello!')") + >>> + >>> # From HuggingFace Hub + >>> CodeAction = AutoAction.from_env("meta-pytorch/coding-env") + >>> + >>> # Use with AutoEnv + >>> env = AutoEnv.from_env("coding-env") + >>> result = env.step(action) +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, Type + +from ._discovery import _is_hub_url, get_discovery +from .auto_env import AutoEnv + +logger = logging.getLogger(__name__) + + +class AutoAction: + """ + AutoAction automatically retrieves the correct Action class based on + environment names or HuggingFace Hub repositories. + + This class follows the HuggingFace AutoModel pattern, making it easy to + get the right Action class without needing to know which module to import. + + The class provides factory methods that look up the Action class and + return the class (not an instance) for you to instantiate. + + Example: + >>> # From installed package + >>> CodeAction = AutoAction.from_env("coding") + >>> action = CodeAction(code="print('test')") + >>> + >>> # From HuggingFace Hub + >>> CodeAction = AutoAction.from_env("meta-pytorch/coding-env") + >>> action = CodeAction(code="print('test')") + >>> + >>> # Use with AutoEnv for a complete workflow + >>> env = AutoEnv.from_env("coding-env") + >>> ActionClass = AutoAction.from_env("coding-env") + >>> action = ActionClass(code="print('Hello, AutoAction!')") + >>> result = env.step(action) + + Note: + AutoAction is not meant to be instantiated directly. Use the class + method from_env() instead. + """ + + def __init__(self): + """AutoAction should not be instantiated directly. Use class methods instead.""" + raise TypeError( + "AutoAction is a factory class and should not be instantiated directly. " + "Use AutoAction.from_hub() or AutoAction.from_env() instead." 
+ ) + + @classmethod + def from_env(cls, name: str, skip_install: bool = False) -> Type: + """ + Get the Action class from environment name or HuggingFace Hub repository. + + This method automatically: + 1. Checks if the name is a HuggingFace Hub URL/repo ID + 2. If Hub: downloads and installs the environment package + 3. If local: looks up the installed openenv-* package + 4. Imports and returns the Action class + + Args: + name: Environment name or HuggingFace Hub repo ID + Examples: + - "coding" / "coding-env" / "coding_env" + - "meta-pytorch/coding-env" (Hub repo ID) + - "https://huggingface.co/meta-pytorch/coding-env" (Hub URL) + skip_install: If True, skip package installation and return + GenericAction class instead. Use this when working with + GenericEnvClient to avoid installing remote packages. + + Returns: + Action class (not an instance!). Returns GenericAction when + skip_install=True. + + Raises: + ValueError: If environment not found (only when skip_install=False) + ImportError: If environment package is not installed (only when skip_install=False) + + Examples: + >>> # From installed package + >>> CodeAction = AutoAction.from_env("coding-env") + >>> action = CodeAction(code="print('Hello!')") + >>> + >>> # From HuggingFace Hub + >>> CodeAction = AutoAction.from_env("meta-pytorch/coding-env") + >>> action = CodeAction(code="print('Hello!')") + >>> + >>> # Skip installation, use GenericAction (for GenericEnvClient) + >>> ActionClass = AutoAction.from_env("user/repo", skip_install=True) + >>> action = ActionClass(code="print('Hello!')") # Returns GenericAction + >>> + >>> # Different name formats + >>> EchoAction = AutoAction.from_env("echo") + >>> EchoAction = AutoAction.from_env("echo-env") + >>> EchoAction = AutoAction.from_env("echo_env") + """ + # If skip_install is True, return GenericAction without any package lookup + if skip_install: + from openenv.core.generic_client import GenericAction + + logger.info( + f"Returning GenericAction for '{name}' (skip_install=True). 
" + f"Use keyword arguments to create actions: GenericAction(code='...')" + ) + return GenericAction + + # Check if it's a HuggingFace Hub URL or repo ID + if _is_hub_url(name): + # Ensure package is installed (reuse AutoEnv logic, downloads only if needed) + env_name = AutoEnv._ensure_package_from_hub(name) + else: + env_name = name + + # Get environment info from discovery + discovery = get_discovery() + env_info = discovery.get_environment_by_name(env_name) + + if not env_info: + # Environment not found - provide helpful error message + available_envs = discovery.discover() + + if not available_envs: + raise ValueError( + "No OpenEnv environments found.\n" + "Install an environment with: pip install openenv-\n" + "Or specify a HuggingFace Hub repository: AutoAction.from_env('openenv/echo_env')" + ) + + # Try to suggest similar environment names + from difflib import get_close_matches + + env_keys = list(available_envs.keys()) + suggestions = get_close_matches(env_name, env_keys, n=3, cutoff=0.6) + + error_msg = f"Unknown environment '{env_name}'.\n" + if suggestions: + error_msg += f"Did you mean: {', '.join(suggestions)}?\n" + error_msg += f"Available environments: {', '.join(sorted(env_keys))}" + + raise ValueError(error_msg) + + # Get the action class + try: + action_class = env_info.get_action_class() + return action_class + except ImportError as e: + raise ImportError( + f"Failed to import action class for '{env_name}'.\n" + f"Package '{env_info.package_name}' appears to be installed but the module cannot be imported.\n" + f"Try reinstalling: pip install --force-reinstall {env_info.package_name}\n" + f"Original error: {e}" + ) from e + + @classmethod + def from_hub(cls, env_name: str, skip_install: bool = False) -> Type: + """ + Get the Action class from environment name. + + This is an alias for from_env() for backward compatibility and clarity. + + Args: + env_name: Environment name (e.g., "coding", "echo") + skip_install: If True, skip package installation and return + GenericAction class instead. + + Returns: + Action class (not an instance!) + + Examples: + >>> CodeAction = AutoAction.from_hub("coding") + >>> action = CodeAction(code="print('Hello!')") + """ + return cls.from_env(env_name, skip_install=skip_install) + + @classmethod + def get_action_info(cls, name: str) -> Dict[str, Any]: + """ + Get detailed information about an action class. + + Args: + name: Environment name + + Returns: + Dictionary with action class metadata + + Raises: + ValueError: If environment not found + + Examples: + >>> info = AutoAction.get_action_info("coding") + >>> print(info['action_class']) + 'CodingAction' + >>> print(info['module']) + 'coding_env.client' + """ + discovery = get_discovery() + env_info = discovery.get_environment_by_name(name) + + if not env_info: + raise ValueError(f"Unknown environment: {name}") + + return { + "env_key": env_info.env_key, + "env_name": env_info.name, + "package": env_info.package_name, + "action_class": env_info.action_class_name, + "observation_class": env_info.observation_class_name, + "module": env_info.client_module_path, + } + + @classmethod + def list_actions(cls) -> None: + """ + Print a formatted list of all available action classes. + + This discovers all installed openenv-* packages and displays + their action class information in a user-friendly format. 
+ + Examples: + >>> AutoAction.list_actions() + Available Action Classes: + ---------------------------------------------------------------------- + echo : EchoAction (from openenv-echo-env) + coding : CodingAction (from openenv-coding_env) + ---------------------------------------------------------------------- + Total: 2 action classes + """ + discovery = get_discovery() + environments = discovery.discover() + + print("Available Action Classes:") + print("-" * 70) + + if not environments: + print(" No OpenEnv environments found.") + print(" Install environments with: pip install openenv-") + else: + for env_key in sorted(environments.keys()): + env = environments[env_key] + print(f" {env_key:<15}: {env.action_class_name}") + print(f" Package: {env.package_name}") + + print("-" * 70) + print(f"Total: {len(environments)} action classes") diff --git a/src/core/openenv/auto/auto_env.py b/src/core/openenv/auto/auto_env.py new file mode 100644 index 0000000000000000000000000000000000000000..be845565b651ec721a029505d754bb0e5328bfa6 --- /dev/null +++ b/src/core/openenv/auto/auto_env.py @@ -0,0 +1,897 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +AutoEnv - Automatic Environment Selection +========================================== + +AutoEnv provides a HuggingFace-style API for automatically selecting and +instantiating the correct environment client from installed packages or +HuggingFace Hub. + +This module simplifies environment creation by automatically detecting the +environment type from the name and instantiating the appropriate client class. + +Example: + >>> from openenv import AutoEnv, AutoAction + >>> + >>> # From installed package + >>> env = AutoEnv.from_env("coding-env") + >>> + >>> # From HuggingFace Hub + >>> env = AutoEnv.from_env("meta-pytorch/coding-env") + >>> + >>> # With configuration + >>> env = AutoEnv.from_env("coding", env_vars={"DEBUG": "1"}) +""" + +from __future__ import annotations + +import importlib +import logging +import os +import shutil +import subprocess +import sys +from typing import Any, Dict, Optional, TYPE_CHECKING + +import requests +from openenv.core.utils import run_async_safely + +from ._discovery import _is_hub_url, get_discovery + + +if TYPE_CHECKING: + from openenv.core.containers.runtime import ContainerProvider + from openenv.core.env_client import EnvClient + +logger = logging.getLogger(__name__) + +# Cache for repo ID → env_name mapping to avoid redundant downloads +_hub_env_name_cache: Dict[str, str] = {} + +# Environment variable to skip user confirmation for remote installs +OPENENV_TRUST_REMOTE_CODE = "OPENENV_TRUST_REMOTE_CODE" + + +def _has_uv() -> bool: + """Check if uv is available in the system.""" + return shutil.which("uv") is not None + + +def _get_pip_command() -> list[str]: + """ + Get the appropriate pip command (uv pip or pip). + + Returns: + List of command parts for pip installation + """ + if _has_uv(): + return ["uv", "pip"] + return [sys.executable, "-m", "pip"] + + +def _confirm_remote_install(repo_id: str) -> bool: + """ + Ask user for confirmation before installing remote code. + + This is a security measure since we're executing code from the internet. 
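+
+    Set ``OPENENV_TRUST_REMOTE_CODE=1`` to skip the prompt (e.g. in CI);
+    non-interactive sessions without that variable are rejected.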
+ + Args: + repo_id: The HuggingFace repo ID being installed + + Returns: + True if user confirms, False otherwise + """ + # Check environment variable for automated/CI environments + if os.environ.get(OPENENV_TRUST_REMOTE_CODE, "").lower() in ("1", "true", "yes"): + logger.info("Skipping confirmation (OPENENV_TRUST_REMOTE_CODE is set)") + return True + + # Check if we're in an interactive terminal + if not sys.stdin.isatty(): + logger.warning( + "Cannot prompt for confirmation in non-interactive mode. " + "Set OPENENV_TRUST_REMOTE_CODE=1 to allow remote installs." + ) + return False + + print(f"\n{'=' * 60}") + print("⚠️ SECURITY WARNING: Remote Code Installation") + print(f"{'=' * 60}") + print("You are about to install code from a remote repository:") + print(f" Repository: {repo_id}") + print(f" Source: https://huggingface.co/spaces/{repo_id}") + print("\nThis will execute code from the internet on your machine.") + print("Only proceed if you trust the source.") + print(f"{'=' * 60}\n") + + try: + response = input("Do you want to proceed? [y/N]: ").strip().lower() + return response in ("y", "yes") + except (EOFError, KeyboardInterrupt): + print("\nInstallation cancelled.") + return False + + +class AutoEnv: + """ + AutoEnv automatically selects and instantiates the correct environment client + based on environment names or HuggingFace Hub repositories. + + This class follows the HuggingFace AutoModel pattern, making it easy to work + with different environments without needing to import specific client classes. + + The class provides factory methods that: + 1. Check if name is a HuggingFace Hub URL/repo ID + 2. If Hub: download and install the environment package + 3. If local: look up the installed openenv-* package + 4. Import and instantiate the client class + + Example: + >>> # From installed package + >>> env = AutoEnv.from_env("coding-env") + >>> + >>> # From HuggingFace Hub + >>> env = AutoEnv.from_env("meta-pytorch/coding-env") + >>> + >>> # List available environments + >>> AutoEnv.list_environments() + + Note: + AutoEnv is not meant to be instantiated directly. Use the class method + from_env() instead. + """ + + def __init__(self): + """AutoEnv should not be instantiated directly. Use class methods instead.""" + raise TypeError( + "AutoEnv is a factory class and should not be instantiated directly. " + "Use AutoEnv.from_hub() or AutoEnv.from_env() instead." + ) + + @classmethod + def _resolve_space_url(cls, repo_id: str) -> str: + """ + Resolve HuggingFace Space repo ID to Space URL. + + Args: + repo_id: HuggingFace repo ID (e.g., "wukaixingxp/coding-env-test") + + Returns: + Space URL (e.g., "https://wukaixingxp-coding-env-test.hf.space") + + Examples: + >>> AutoEnv._resolve_space_url("wukaixingxp/coding-env-test") + 'https://wukaixingxp-coding-env-test.hf.space' + """ + # Clean up repo_id if it's a full URL + if "huggingface.co" in repo_id: + # Extract org/repo from URL + # https://huggingface.co/wukaixingxp/coding-env-test -> wukaixingxp/coding-env-test + parts = repo_id.split("/") + if len(parts) >= 2: + repo_id = f"{parts[-2]}/{parts[-1]}" + + # Convert user/space-name to user-space-name.hf.space + space_slug = repo_id.replace("/", "-") + return f"https://{space_slug}.hf.space" + + @classmethod + def _is_local_url(cls, url: str) -> bool: + """ + Check if a URL points to a local server. 
+ + Args: + url: URL to check + + Returns: + True if URL is localhost or 127.0.0.1, False otherwise + + Examples: + >>> AutoEnv._is_local_url("http://localhost:8000") + True + >>> AutoEnv._is_local_url("http://127.0.0.1:8000") + True + >>> AutoEnv._is_local_url("https://example.com") + False + """ + url_lower = url.lower() + return "localhost" in url_lower or "127.0.0.1" in url_lower + + @classmethod + def _check_server_availability(cls, base_url: str, timeout: float = 2.0) -> bool: + """ + Check if a server at the given URL is running and accessible. + + Args: + base_url: Server base URL to check + timeout: Request timeout in seconds + + Returns: + True if server is accessible, False otherwise + + Examples: + >>> AutoEnv._check_server_availability("http://localhost:8000") + True # if server is running + """ + try: + # Bypass proxy for localhost to avoid proxy issues + proxies = None + if cls._is_local_url(base_url): + proxies = {"http": None, "https": None} + + # Try to access the health endpoint + response = requests.get( + f"{base_url}/health", timeout=timeout, proxies=proxies + ) + if response.status_code == 200: + return True + + # If health endpoint doesn't exist, try root endpoint + response = requests.get(base_url, timeout=timeout, proxies=proxies) + return response.status_code == 200 + except (requests.RequestException, Exception) as e: + logger.debug(f"Server {base_url} not accessible: {e}") + return False + + @classmethod + def _check_space_availability(cls, space_url: str, timeout: float = 5.0) -> bool: + """ + Check if HuggingFace Space is running and accessible. + + Args: + space_url: Space URL to check + timeout: Request timeout in seconds + + Returns: + True if Space is accessible, False otherwise + + Examples: + >>> AutoEnv._check_space_availability("https://wukaixingxp-coding-env-test.hf.space") + True + """ + try: + # Try to access the health endpoint + response = requests.get(f"{space_url}/health", timeout=timeout) + if response.status_code == 200: + return True + + # If health endpoint doesn't exist, try root endpoint + response = requests.get(space_url, timeout=timeout) + return response.status_code == 200 + except (requests.RequestException, Exception) as e: + logger.debug(f"Space {space_url} not accessible: {e}") + return False + + @classmethod + def _get_hub_git_url(cls, repo_id: str) -> str: + """ + Get the git URL for a HuggingFace Space. + + Args: + repo_id: HuggingFace repo ID (e.g., "wukaixingxp/coding-env-test") + + Returns: + Git URL for pip installation (e.g., "git+https://huggingface.co/spaces/wukaixingxp/coding-env-test") + """ + # Clean up repo_id if it's a full URL + if "huggingface.co" in repo_id: + parts = repo_id.split("/") + if len(parts) >= 2: + repo_id = f"{parts[-2]}/{parts[-1]}" + + return f"git+https://huggingface.co/spaces/{repo_id}" + + @classmethod + def _install_from_hub(cls, repo_id: str, trust_remote_code: bool = False) -> str: + """ + Install environment package directly from HuggingFace Hub using git+. + + This is the preferred method as it avoids downloading the entire repo + and uses pip/uv's native git support. 
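+
+        For example, a repo ID like ``user/coding-env`` is installed from
+        ``git+https://huggingface.co/spaces/user/coding-env``.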
+ + Args: + repo_id: HuggingFace repo ID (e.g., "wukaixingxp/coding-env-test") + trust_remote_code: If True, skip user confirmation + + Returns: + Package name that was installed + + Raises: + ValueError: If installation fails or user declines + """ + # Security check - confirm with user before installing remote code + if not trust_remote_code and not _confirm_remote_install(repo_id): + raise ValueError( + "Installation cancelled by user.\n" + "To allow remote installs without prompting, set OPENENV_TRUST_REMOTE_CODE=1" + ) + + git_url = cls._get_hub_git_url(repo_id) + pip_cmd = _get_pip_command() + pip_name = "uv pip" if pip_cmd[0] == "uv" else "pip" + + logger.info(f"Installing from HuggingFace Space using {pip_name}: {repo_id}") + logger.info(f"Command: {' '.join(pip_cmd)} install {git_url}") + + try: + result = subprocess.run( + [*pip_cmd, "install", git_url], + check=True, + capture_output=True, + text=True, + ) + + # Try to extract package name from pip output + # Look for "Successfully installed -" + for line in result.stdout.split("\n"): + if "Successfully installed" in line: + # Parse package name from the line + parts = line.replace("Successfully installed", "").strip().split() + for part in parts: + if part.startswith("openenv-"): + # Remove version suffix (e.g., "openenv-coding_env-0.1.0" -> "openenv-coding_env") + # Check if last segment looks like a version number + last_segment = part.rsplit("-", 1)[-1] + if last_segment.replace(".", "").isdigit(): + package_name = "-".join(part.rsplit("-", 1)[:-1]) + else: + package_name = part + logger.info(f"Successfully installed: {package_name}") + return package_name + + # Fallback: try to determine package name from repo_id + # Convention: repo name like "coding-env-test" -> package "openenv-coding_env" + env_name = repo_id.split("/")[-1] # Get repo name from "user/repo" + env_name = env_name.replace("-", "_") + if not env_name.endswith("_env"): + env_name = f"{env_name}_env" + package_name = f"openenv-{env_name}" + + logger.info(f"Installed (inferred package name): {package_name}") + return package_name + + except subprocess.CalledProcessError as e: + error_msg = e.stderr or e.stdout or str(e) + raise ValueError( + f"Failed to install environment from HuggingFace Space: {repo_id}\n" + f"Command: {' '.join(pip_cmd)} install {git_url}\n" + f"Error: {error_msg}\n" + f"Make sure the repository exists and contains a valid Python package." + ) from e + + @classmethod + def _is_package_installed(cls, package_name: str) -> bool: + """ + Check if a package is already installed. + + Args: + package_name: Package name (e.g., "openenv-coding_env") + + Returns: + True if installed, False otherwise + """ + try: + import importlib.metadata + + importlib.metadata.distribution(package_name) + return True + except importlib.metadata.PackageNotFoundError: + return False + + @classmethod + def _ensure_package_from_hub( + cls, name: str, trust_remote_code: bool = False + ) -> str: + """ + Ensure package from HuggingFace Hub is installed. + + Uses git+ URLs for direct installation without downloading the entire repo. + Prompts user for confirmation before installing remote code. 
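+
+        Resolved environment names are cached per repo ID, so repeated calls
+        for the same repository skip the download and install entirely.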
+ + Args: + name: HuggingFace repo ID (e.g., "wukaixingxp/coding-env-test") + trust_remote_code: If True, skip user confirmation + + Returns: + Environment name (e.g., "coding_env") + """ + global _hub_env_name_cache + + # Check if we already resolved this repo ID + if name in _hub_env_name_cache: + env_name = _hub_env_name_cache[name] + logger.debug(f"Using cached env name for {name}: {env_name}") + return env_name + + # Try to infer expected package name from repo ID + # Convention: repo "user/coding-env" -> package "openenv-coding_env" + repo_name = name.split("/")[-1] if "/" in name else name + expected_env_name = repo_name.replace("-", "_") + if not expected_env_name.endswith("_env"): + expected_env_name = f"{expected_env_name}_env" + expected_package_name = f"openenv-{expected_env_name}" + + # Check if already installed + if cls._is_package_installed(expected_package_name): + logger.info(f"Package already installed: {expected_package_name}") + # Clear and refresh discovery cache to make sure it's detected + get_discovery().clear_cache() + get_discovery().discover(use_cache=False) + # Cache the result + _hub_env_name_cache[name] = expected_env_name + return expected_env_name + + # Not installed, install using git+ URL + logger.info(f"Package not found locally, installing from Hub: {name}") + + # Track existing packages before installation + get_discovery().clear_cache() + existing_envs = set(get_discovery().discover(use_cache=False).keys()) + + # Install the package + cls._install_from_hub(name, trust_remote_code=trust_remote_code) + + # Clear discovery cache to pick up the newly installed package + try: + importlib.invalidate_caches() + except Exception: + pass + get_discovery().clear_cache() + discovered_envs = get_discovery().discover(use_cache=False) + + # Find the newly installed environment by comparing before/after + new_envs = set(discovered_envs.keys()) - existing_envs + + if new_envs: + # Use the first newly discovered environment + env_name = next(iter(new_envs)) + logger.info(f"Found newly installed environment: '{env_name}'") + else: + # Fallback: try to find by matching module patterns + # Look for any env that might match the repo name pattern + repo_name = name.split("/")[-1] if "/" in name else name + repo_base = ( + repo_name.replace("-", "_").replace("_env", "").replace("_test", "") + ) + + env_name = None + for env_key, env_info in discovered_envs.items(): + # Check if env_key is a prefix/substring match + if env_key in repo_base or repo_base.startswith(env_key): + env_name = env_key + logger.info( + f"Found matching environment '{env_name}' for repo '{name}'" + ) + break + + if env_name is None: + # Last resort: use inferred name from repo + env_name = repo_name.replace("-", "_") + if not env_name.endswith("_env"): + env_name = f"{env_name}_env" + # Strip to get env_key + env_name = env_name.replace("_env", "") + logger.warning( + f"Could not find newly installed environment for repo '{name}', " + f"using inferred name: {env_name}" + ) + + # Cache the result to avoid redundant installs + _hub_env_name_cache[name] = env_name + + return env_name + + @classmethod + def from_env( + cls, + name: str, + base_url: Optional[str] = None, + docker_image: Optional[str] = None, + container_provider: Optional[ContainerProvider] = None, + wait_timeout: float = 30.0, + env_vars: Optional[Dict[str, str]] = None, + trust_remote_code: bool = False, + skip_install: bool = False, + **kwargs: Any, + ) -> "EnvClient": + """ + Create an environment client from a name or HuggingFace Hub 
repository. + + This method automatically: + 1. Checks if the name is a HuggingFace Hub URL/repo ID + 2. If Hub: installs the environment package using git+ URL + 3. If local: looks up the installed openenv-* package + 4. Imports the client class and instantiates it + + Args: + name: Environment name or HuggingFace Hub repo ID + Examples: + - "coding" / "coding-env" / "coding_env" + - "meta-pytorch/coding-env" (Hub repo ID) + - "https://huggingface.co/meta-pytorch/coding-env" (Hub URL) + base_url: Optional base URL for HTTP connection + docker_image: Optional Docker image name (overrides default) + container_provider: Optional container provider + wait_timeout: Timeout for container startup (seconds) + env_vars: Optional environment variables for the container + trust_remote_code: If True, skip user confirmation when installing + from HuggingFace Hub. Can also be set via OPENENV_TRUST_REMOTE_CODE + environment variable. + skip_install: If True, skip package installation and return a + GenericEnvClient for remote environments. Useful when you only + want to connect to a running server without installing any + remote code. When True: + - If base_url is provided: connects directly using GenericEnvClient + - If HF Space is running: connects to Space using GenericEnvClient + - If HF Space is not running: uses Docker from HF registry + **kwargs: Additional arguments passed to the client class + + Returns: + Instance of the environment client class + + Raises: + ValueError: If environment not found or cannot be loaded + ImportError: If environment package is not installed + + Examples: + >>> # From installed package + >>> env = AutoEnv.from_env("coding-env") + >>> + >>> # From HuggingFace Hub + >>> env = AutoEnv.from_env("meta-pytorch/coding-env") + >>> + >>> # With custom Docker image + >>> env = AutoEnv.from_env("coding", docker_image="my-coding-env:v2") + >>> + >>> # With environment variables + >>> env = AutoEnv.from_env( + ... "dipg", + ... env_vars={"DIPG_DATASET_PATH": "/data/dipg"} + ... ) + >>> + >>> # Skip package installation, use GenericEnvClient + >>> env = AutoEnv.from_env( + ... "user/my-env", + ... skip_install=True + ... ) + """ + from openenv.core import GenericEnvClient + + # Handle skip_install mode - return GenericEnvClient without package installation + if skip_install: + # If base_url is provided, connect directly + if base_url: + if cls._check_server_availability(base_url): + logger.info( + f"Using GenericEnvClient for {base_url} (skip_install=True)" + ) + return GenericEnvClient(base_url=base_url, **kwargs) + else: + raise ConnectionError( + f"Server not available at {base_url}. " + f"Please ensure the server is running." 
+ ) + + # If it's a Hub URL, try to connect to Space or use Docker + if _is_hub_url(name): + space_url = cls._resolve_space_url(name) + logger.info(f"Checking if HuggingFace Space is accessible: {space_url}") + + if cls._check_space_availability(space_url): + logger.info( + f"Using GenericEnvClient for Space {space_url} (skip_install=True)" + ) + return GenericEnvClient(base_url=space_url, **kwargs) + else: + # Space not running, use Docker from HF registry + logger.info( + f"Space not running at {space_url}, " + f"using GenericEnvClient with HF Docker registry" + ) + return run_async_safely( + GenericEnvClient.from_env( + name, + use_docker=True, + provider=container_provider, + env_vars=env_vars or {}, + **kwargs, + ) + ) + + # For local environments with skip_install, we need docker_image + if docker_image: + logger.info( + f"Using GenericEnvClient with Docker image {docker_image} " + f"(skip_install=True)" + ) + return run_async_safely( + GenericEnvClient.from_docker_image( + image=docker_image, + provider=container_provider, + wait_timeout=wait_timeout, + env_vars=env_vars or {}, + **kwargs, + ) + ) + else: + raise ValueError( + f"Cannot use skip_install=True for local environment '{name}' " + f"without providing base_url or docker_image. " + f"For local environments, either:\n" + f" 1. Provide base_url to connect to a running server\n" + f" 2. Provide docker_image to start a container\n" + f" 3. Set skip_install=False to use the installed package" + ) + + # Check if it's a HuggingFace Hub URL or repo ID + if _is_hub_url(name): + # Try to connect to Space directly first + space_url = cls._resolve_space_url(name) + logger.info(f"Checking if HuggingFace Space is accessible: {space_url}") + + space_is_available = cls._check_space_availability(space_url) + + if space_is_available and base_url is None: + # Space is accessible! 
We'll connect directly without Docker
+                logger.info(f"Space is accessible at: {space_url}")
+                logger.info("Installing package for client code (no Docker needed)...")
+
+                # Ensure package is installed (uses git+ URL)
+                env_name = cls._ensure_package_from_hub(
+                    name, trust_remote_code=trust_remote_code
+                )
+
+                # Set base_url to connect to remote Space
+                base_url = space_url
+                logger.info("Will connect to remote Space (no local Docker)")
+            else:
+                # Space not accessible or user provided explicit base_url
+                if not space_is_available:
+                    logger.info(f"Space not accessible at {space_url}")
+                    logger.info("Falling back to local Docker mode...")
+
+                # Ensure package is installed (uses git+ URL)
+                env_name = cls._ensure_package_from_hub(
+                    name, trust_remote_code=trust_remote_code
+                )
+        else:
+            env_name = name
+
+        # Get environment info from discovery
+        discovery = get_discovery()
+        env_info = discovery.get_environment_by_name(env_name)
+
+        if not env_info:
+            # Environment not found - provide helpful error message
+            available_envs = discovery.discover()
+
+            if not available_envs:
+                raise ValueError(
+                    "No OpenEnv environments found.\n"
+                    "Install an environment with: pip install openenv-<env-name>\n"
+                    "Or specify a HuggingFace Hub repository: AutoEnv.from_env('openenv/echo_env')"
+                )
+
+            # Try to suggest similar environment names
+            from difflib import get_close_matches
+
+            env_keys = list(available_envs.keys())
+            suggestions = get_close_matches(env_name, env_keys, n=3, cutoff=0.6)
+
+            error_msg = f"Unknown environment '{env_name}'.\n"
+            if suggestions:
+                error_msg += f"Did you mean: {', '.join(suggestions)}?\n"
+            error_msg += f"Available environments: {', '.join(sorted(env_keys))}"
+
+            raise ValueError(error_msg)
+
+        # Get the client class
+        try:
+            client_class = env_info.get_client_class()
+        except ImportError as e:
+            raise ImportError(
+                f"Failed to import environment client for '{env_name}'.\n"
+                f"Package '{env_info.package_name}' appears to be installed but the module cannot be imported.\n"
+                f"Try reinstalling: pip install --force-reinstall {env_info.package_name}\n"
+                f"Original error: {e}"
+            ) from e
+
+        # Determine Docker image to use
+        if docker_image is None:
+            docker_image = env_info.default_image
+
+        # Create client instance
+        try:
+            if base_url:
+                # Check if the server at base_url is available
+                is_local = cls._is_local_url(base_url)
+                server_available = cls._check_server_availability(base_url)
+
+                if server_available:
+                    # Server is running, connect directly
+                    logger.info(
+                        f"✅ Server available at {base_url}, connecting directly"
+                    )
+                    return client_class(base_url=base_url, provider=None, **kwargs)
+                elif is_local:
+                    # Local server not running, auto-start Docker container
+                    logger.info(f"❌ Server not available at {base_url}")
+                    logger.info(f"🐳 Auto-starting Docker container: {docker_image}")
+                    return run_async_safely(
+                        client_class.from_docker_image(
+                            image=docker_image,
+                            provider=container_provider,
+                            wait_timeout=wait_timeout,
+                            env_vars=env_vars or {},
+                            **kwargs,
+                        )
+                    )
+                else:
+                    # Remote server not available, cannot auto-start
+                    raise ConnectionError(
+                        f"Remote server not available at {base_url}. "
+                        f"Please ensure the server is running."
+ ) + else: + # No base_url provided, start new Docker container + return run_async_safely( + client_class.from_docker_image( + image=docker_image, + provider=container_provider, + wait_timeout=wait_timeout, + env_vars=env_vars or {}, + **kwargs, + ) + ) + except Exception as e: + raise ValueError( + f"Failed to create environment client for '{env_name}'.\n" + f"Client class: {client_class.__name__}\n" + f"Docker image: {docker_image}\n" + f"Error: {e}" + ) from e + + @classmethod + def from_hub( + cls, + name: str, + base_url: Optional[str] = None, + docker_image: Optional[str] = None, + container_provider: Optional["ContainerProvider"] = None, + wait_timeout: float = 30.0, + env_vars: Optional[Dict[str, str]] = None, + trust_remote_code: bool = False, + skip_install: bool = False, + **kwargs: Any, + ) -> "EnvClient": + """ + Create an environment client from a name or HuggingFace Hub repository. + + This is an alias for from_env() for backward compatibility. + + Args: + name: Environment name or HuggingFace Hub repo ID + base_url: Optional base URL for HTTP connection + docker_image: Optional Docker image name (overrides default) + container_provider: Optional container provider + wait_timeout: Timeout for container startup (seconds) + env_vars: Optional environment variables for the container + trust_remote_code: If True, skip user confirmation when installing + from HuggingFace Hub + skip_install: If True, skip package installation and return a + GenericEnvClient for remote environments + **kwargs: Additional arguments passed to the client class + + Returns: + Instance of the environment client class + + Examples: + >>> env = AutoEnv.from_hub("coding-env") + >>> env = AutoEnv.from_hub("meta-pytorch/coding-env") + """ + return cls.from_env( + name=name, + base_url=base_url, + docker_image=docker_image, + container_provider=container_provider, + wait_timeout=wait_timeout, + env_vars=env_vars, + trust_remote_code=trust_remote_code, + skip_install=skip_install, + **kwargs, + ) + + @classmethod + def get_env_class(cls, name: str): + """ + Get the environment client class without instantiating it. + + Args: + name: Environment name + + Returns: + The environment client class + + Raises: + ValueError: If environment not found + + Examples: + >>> CodingEnv = AutoEnv.get_env_class("coding") + >>> # Now you can instantiate it yourself + >>> env = CodingEnv(base_url="http://localhost:8000") + """ + discovery = get_discovery() + env_info = discovery.get_environment_by_name(name) + + if not env_info: + raise ValueError(f"Unknown environment: {name}") + + return env_info.get_client_class() + + @classmethod + def get_env_info(cls, name: str) -> Dict[str, Any]: + """ + Get detailed information about an environment. 
+ + Args: + name: Environment name + + Returns: + Dictionary with environment metadata + + Raises: + ValueError: If environment not found + + Examples: + >>> info = AutoEnv.get_env_info("coding") + >>> print(info['description']) + 'Coding environment for OpenEnv' + >>> print(info['default_image']) + 'coding-env:latest' + """ + discovery = get_discovery() + env_info = discovery.get_environment_by_name(name) + + if not env_info: + raise ValueError(f"Unknown environment: {name}") + + return { + "env_key": env_info.env_key, + "name": env_info.name, + "package": env_info.package_name, + "version": env_info.version, + "description": env_info.description, + "env_class": env_info.client_class_name, + "action_class": env_info.action_class_name, + "observation_class": env_info.observation_class_name, + "module": env_info.client_module_path, + "default_image": env_info.default_image, + "spec_version": env_info.spec_version, + } + + @classmethod + def list_environments(cls) -> None: + """ + Print a formatted list of all available environments. + + This discovers all installed openenv-* packages and displays + their metadata in a user-friendly format. + + Examples: + >>> AutoEnv.list_environments() + Available OpenEnv Environments: + ---------------------------------------------------------------------- + echo : Echo Environment (v0.1.0) + Package: openenv-echo-env + coding : Coding Environment (v0.1.0) + Package: openenv-coding_env + ---------------------------------------------------------------------- + Total: 2 environments + """ + discovery = get_discovery() + discovery.list_environments() diff --git a/src/core/openenv/cli/__init__.py b/src/core/openenv/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40bee4e3ecf31e272806785b2cd7e05ff0000564 --- /dev/null +++ b/src/core/openenv/cli/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenEnv CLI package.""" + +__version__ = "0.1.0" diff --git a/src/core/openenv/cli/__main__.py b/src/core/openenv/cli/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b457cb7e1f430771bd310c120972a57c06cd661 --- /dev/null +++ b/src/core/openenv/cli/__main__.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +OpenEnv CLI entry point. + +This module provides the main entry point for the OpenEnv command-line interface, +following the Hugging Face CLI pattern. 
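+
+Typical invocations:
+
+    $ openenv init my_env
+    $ openenv build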
+""" + +import sys + +import typer +from openenv.cli.commands import build, fork, init, push, serve, skills, validate + +# Create the main CLI app +app = typer.Typer( + name="openenv", + help="OpenEnv - An e2e framework for creating, deploying and using isolated execution environments for agentic RL training", + no_args_is_help=True, +) + +# Register commands +app.command(name="init", help="Initialize a new OpenEnv environment")(init.init) +app.command(name="build", help="Build Docker images for OpenEnv environments")( + build.build +) +app.command( + name="validate", help="Validate environment structure and deployment readiness" +)(validate.validate) +app.command( + name="push", + help="Push an OpenEnv environment to Hugging Face Spaces or custom registry", +)(push.push) +app.command(name="serve", help="Serve environments locally (TODO: Phase 4)")( + serve.serve +) +app.command( + name="fork", + help="Fork (duplicate) a Hugging Face Space to your account", +)(fork.fork) +app.add_typer( + skills.app, + name="skills", + help="Manage OpenEnv skills for AI assistants", +) + + +# Entry point for setuptools +def main() -> None: + """Main entry point for the CLI.""" + try: + app() + except KeyboardInterrupt: + print("\nOperation cancelled by user.") + sys.exit(130) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/src/core/openenv/cli/_cli_utils.py b/src/core/openenv/cli/_cli_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b781bb3e34c60842fd8fc8b0eef7b700eb8461e0 --- /dev/null +++ b/src/core/openenv/cli/_cli_utils.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""CLI utilities for OpenEnv command-line interface.""" + +from pathlib import Path +from typing import List + +from rich.console import Console + +# Create a console instance for CLI output +console = Console() + + +def validate_env_structure(env_dir: Path, strict: bool = False) -> List[str]: + """ + Validate that the directory follows OpenEnv environment structure. 
+ + Args: + env_dir: Path to environment directory + strict: If True, enforce all optional requirements + + Returns: + List of validation warnings (empty if all checks pass) + + Raises: + FileNotFoundError: If required files are missing + """ + warnings = [] + + # Required files + required_files = [ + "openenv.yaml", + "__init__.py", + "client.py", + "models.py", + "README.md", + ] + + for file in required_files: + if not (env_dir / file).exists(): + raise FileNotFoundError(f"Required file missing: {file}") + + # Dockerfile: must exist in server/ or at env root + has_root_dockerfile = (env_dir / "Dockerfile").exists() + has_server_dockerfile = (env_dir / "server" / "Dockerfile").exists() + + if not has_root_dockerfile and not has_server_dockerfile: + raise FileNotFoundError( + "Required file missing: server/Dockerfile or Dockerfile at env root" + ) + + # When no root Dockerfile, require the traditional server/ layout + if not has_root_dockerfile: + server_dir = env_dir / "server" + if not server_dir.exists() or not server_dir.is_dir(): + raise FileNotFoundError("Required directory missing: server/") + + for file in ["server/__init__.py", "server/app.py"]: + if not (env_dir / file).exists(): + raise FileNotFoundError(f"Required file missing: {file}") + + # Check for dependency management (pyproject.toml required) + has_pyproject = (env_dir / "pyproject.toml").exists() + + if not has_pyproject: + raise FileNotFoundError( + "No dependency specification found. 'pyproject.toml' is required." + ) + + # Warnings for recommended structure + + if not (env_dir / "outputs").exists(): + warnings.append("Recommended directory missing: outputs/") + + return warnings diff --git a/src/core/openenv/cli/_validation.py b/src/core/openenv/cli/_validation.py new file mode 100644 index 0000000000000000000000000000000000000000..60ea7cc58b3fb22c6247280ecbf64fe762433549 --- /dev/null +++ b/src/core/openenv/cli/_validation.py @@ -0,0 +1,594 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Validation utilities for multi-mode deployment readiness. + +This module provides functions to check if environments are properly +configured for multi-mode deployment (Docker, direct Python, notebooks, clusters). 
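+
+A minimal usage sketch (assumes a server is already listening on localhost:8000):
+
+    report = validate_running_environment("http://localhost:8000")
+    print(report["passed"], report["summary"]["failed_criteria"])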
+""" + +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +import requests + +try: + import tomllib +except ModuleNotFoundError: + import tomli as tomllib + + +def _make_criterion( + criterion_id: str, + description: str, + passed: bool, + *, + required: bool = True, + details: str | None = None, + expected: Any | None = None, + actual: Any | None = None, +) -> dict[str, Any]: + """Create a standard criterion result payload.""" + criterion: dict[str, Any] = { + "id": criterion_id, + "description": description, + "passed": passed, + "required": required, + } + if details is not None: + criterion["details"] = details + if expected is not None: + criterion["expected"] = expected + if actual is not None: + criterion["actual"] = actual + return criterion + + +def _normalize_runtime_url(base_url: str) -> str: + """Normalize and validate a runtime target URL.""" + target = base_url.strip() + if not target: + raise ValueError("Runtime URL cannot be empty") + + if "://" not in target: + target = f"http://{target}" + + parsed = urlparse(target) + if not parsed.scheme or not parsed.netloc: + raise ValueError(f"Invalid runtime URL: {base_url}") + + return target.rstrip("/") + + +def _runtime_standard_profile(api_version: str) -> str: + """Resolve the runtime standard profile for an API version.""" + if api_version.startswith("1."): + return "openenv-http/1.x" + return "openenv-http/unknown" + + +def _build_summary(criteria: list[dict[str, Any]]) -> dict[str, Any]: + """Build a compact pass/fail summary for a criteria list.""" + total_count = len(criteria) + passed_count = sum(1 for criterion in criteria if criterion.get("passed", False)) + failed_criteria = [ + criterion.get("id", "unknown") + for criterion in criteria + if not criterion.get("passed", False) + ] + required_criteria = [ + criterion for criterion in criteria if criterion.get("required", True) + ] + required_total_count = len(required_criteria) + required_passed_count = sum( + 1 for criterion in required_criteria if criterion.get("passed", False) + ) + + return { + "passed_count": passed_count, + "total_count": total_count, + "failed_criteria": failed_criteria, + "required_passed_count": required_passed_count, + "required_total_count": required_total_count, + } + + +def validate_running_environment( + base_url: str, timeout_s: float = 5.0 +) -> dict[str, Any]: + """ + Validate a running OpenEnv server against runtime API standards. + + The returned JSON report contains an overall pass/fail result and + per-criterion outcomes that can be consumed in CI. + """ + normalized_url = _normalize_runtime_url(base_url) + criteria: list[dict[str, Any]] = [] + + report: dict[str, Any] = { + "target": normalized_url, + "validation_type": "running_environment", + "standard_version": "unknown", + "standard_profile": "openenv-http/unknown", + "mode": "unknown", + "passed": False, + "summary": {}, + "criteria": criteria, + } + + openapi_paths: dict[str, Any] = {} + api_version = "unknown" + + # Criterion: OpenAPI endpoint reachable with a declared version. 
+ try: + openapi_response = requests.get( + f"{normalized_url}/openapi.json", timeout=timeout_s + ) + except requests.RequestException as exc: + criteria.append( + _make_criterion( + "openapi_version_available", + "GET /openapi.json returns OpenAPI info.version", + False, + details=f"Request failed: {type(exc).__name__}: {exc}", + expected={"status_code": 200, "info.version": "string"}, + ) + ) + else: + try: + openapi_json = openapi_response.json() + except ValueError: + openapi_json = None + + openapi_ok = ( + openapi_response.status_code == 200 + and isinstance(openapi_json, dict) + and isinstance(openapi_json.get("info"), dict) + and isinstance(openapi_json["info"].get("version"), str) + ) + + if openapi_ok: + api_version = str(openapi_json["info"]["version"]) + openapi_paths = openapi_json.get("paths", {}) + criteria.append( + _make_criterion( + "openapi_version_available", + "GET /openapi.json returns OpenAPI info.version", + True, + expected={"status_code": 200, "info.version": "string"}, + actual={ + "status_code": openapi_response.status_code, + "info.version": api_version, + }, + ) + ) + else: + criteria.append( + _make_criterion( + "openapi_version_available", + "GET /openapi.json returns OpenAPI info.version", + False, + details="Response missing required OpenAPI info.version field", + expected={"status_code": 200, "info.version": "string"}, + actual={ + "status_code": openapi_response.status_code, + "body_type": ( + type(openapi_json).__name__ + if openapi_json is not None + else "non_json" + ), + }, + ) + ) + + report["standard_version"] = api_version + report["standard_profile"] = _runtime_standard_profile(api_version) + + # Criterion: Health endpoint. + try: + health_response = requests.get(f"{normalized_url}/health", timeout=timeout_s) + except requests.RequestException as exc: + criteria.append( + _make_criterion( + "health_endpoint", + "GET /health returns healthy status", + False, + details=f"Request failed: {type(exc).__name__}: {exc}", + expected={"status_code": 200, "status": "healthy"}, + ) + ) + else: + try: + health_json = health_response.json() + except ValueError: + health_json = None + + health_ok = ( + health_response.status_code == 200 + and isinstance(health_json, dict) + and health_json.get("status") == "healthy" + ) + criteria.append( + _make_criterion( + "health_endpoint", + "GET /health returns healthy status", + health_ok, + expected={"status_code": 200, "status": "healthy"}, + actual={ + "status_code": health_response.status_code, + "status": ( + health_json.get("status") + if isinstance(health_json, dict) + else None + ), + }, + ) + ) + + # Criterion: Metadata endpoint has required fields. 
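+    # Assumed response shape: {"name": "...", "description": "..."}; the check
+    # below only requires status 200 plus string name and description fields.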
+ try: + metadata_response = requests.get( + f"{normalized_url}/metadata", timeout=timeout_s + ) + except requests.RequestException as exc: + criteria.append( + _make_criterion( + "metadata_endpoint", + "GET /metadata returns name and description", + False, + details=f"Request failed: {type(exc).__name__}: {exc}", + expected={"status_code": 200, "fields": ["name", "description"]}, + ) + ) + else: + try: + metadata_json = metadata_response.json() + except ValueError: + metadata_json = None + + metadata_ok = ( + metadata_response.status_code == 200 + and isinstance(metadata_json, dict) + and isinstance(metadata_json.get("name"), str) + and isinstance(metadata_json.get("description"), str) + ) + criteria.append( + _make_criterion( + "metadata_endpoint", + "GET /metadata returns name and description", + metadata_ok, + expected={"status_code": 200, "fields": ["name", "description"]}, + actual={ + "status_code": metadata_response.status_code, + "name": ( + metadata_json.get("name") + if isinstance(metadata_json, dict) + else None + ), + "description": ( + metadata_json.get("description") + if isinstance(metadata_json, dict) + else None + ), + }, + ) + ) + + # Criterion: Schema endpoint returns action/observation/state. + try: + schema_response = requests.get(f"{normalized_url}/schema", timeout=timeout_s) + except requests.RequestException as exc: + criteria.append( + _make_criterion( + "schema_endpoint", + "GET /schema returns action, observation, and state schemas", + False, + details=f"Request failed: {type(exc).__name__}: {exc}", + expected={ + "status_code": 200, + "fields": ["action", "observation", "state"], + }, + ) + ) + else: + try: + schema_json = schema_response.json() + except ValueError: + schema_json = None + + schema_ok = ( + schema_response.status_code == 200 + and isinstance(schema_json, dict) + and isinstance(schema_json.get("action"), dict) + and isinstance(schema_json.get("observation"), dict) + and isinstance(schema_json.get("state"), dict) + ) + criteria.append( + _make_criterion( + "schema_endpoint", + "GET /schema returns action, observation, and state schemas", + schema_ok, + expected={ + "status_code": 200, + "fields": ["action", "observation", "state"], + }, + actual={ + "status_code": schema_response.status_code, + "has_action": ( + isinstance(schema_json.get("action"), dict) + if isinstance(schema_json, dict) + else False + ), + "has_observation": ( + isinstance(schema_json.get("observation"), dict) + if isinstance(schema_json, dict) + else False + ), + "has_state": ( + isinstance(schema_json.get("state"), dict) + if isinstance(schema_json, dict) + else False + ), + }, + ) + ) + + # Criterion: MCP endpoint is reachable. 
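+    # An empty JSON object is POSTed; any JSON-RPC 2.0 reply (including an
+    # error object) passes, since only status 200 and jsonrpc == "2.0" are checked.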
+ try: + mcp_response = requests.post( + f"{normalized_url}/mcp", json={}, timeout=timeout_s + ) + except requests.RequestException as exc: + criteria.append( + _make_criterion( + "mcp_endpoint", + "POST /mcp is reachable and returns JSON-RPC payload", + False, + details=f"Request failed: {type(exc).__name__}: {exc}", + expected={"status_code": 200, "jsonrpc": "2.0"}, + ) + ) + else: + try: + mcp_json = mcp_response.json() + except ValueError: + mcp_json = None + + mcp_ok = ( + mcp_response.status_code == 200 + and isinstance(mcp_json, dict) + and mcp_json.get("jsonrpc") == "2.0" + ) + criteria.append( + _make_criterion( + "mcp_endpoint", + "POST /mcp is reachable and returns JSON-RPC payload", + mcp_ok, + expected={"status_code": 200, "jsonrpc": "2.0"}, + actual={ + "status_code": mcp_response.status_code, + "jsonrpc": ( + mcp_json.get("jsonrpc") if isinstance(mcp_json, dict) else None + ), + }, + ) + ) + + # Criterion: mode endpoint contract consistency via OpenAPI paths. + if isinstance(openapi_paths, dict) and openapi_paths: + has_reset = "/reset" in openapi_paths + has_step = "/step" in openapi_paths + has_state = "/state" in openapi_paths + + if has_reset: + report["mode"] = "simulation" + mode_ok = has_step and has_state + expected_paths = {"/reset": True, "/step": True, "/state": True} + else: + report["mode"] = "production" + mode_ok = not has_step and not has_state + expected_paths = {"/reset": False, "/step": False, "/state": False} + + criteria.append( + _make_criterion( + "mode_endpoint_consistency", + "OpenAPI endpoint set matches OpenEnv mode contract", + mode_ok, + expected=expected_paths, + actual={ + "/reset": has_reset, + "/step": has_step, + "/state": has_state, + }, + ) + ) + else: + criteria.append( + _make_criterion( + "mode_endpoint_consistency", + "OpenAPI endpoint set matches OpenEnv mode contract", + False, + details="Cannot determine mode without OpenAPI paths", + expected={"openapi.paths": "present"}, + actual={"openapi.paths": "missing"}, + ) + ) + + report["passed"] = all( + criterion["passed"] for criterion in criteria if criterion.get("required", True) + ) + report["summary"] = _build_summary(criteria) + return report + + +def validate_multi_mode_deployment(env_path: Path) -> tuple[bool, list[str]]: + """ + Validate that an environment is ready for multi-mode deployment. + + Checks: + 1. pyproject.toml exists + 2. uv.lock exists + 3. pyproject.toml has [project.scripts] with server entry point + 4. server/app.py has a main() function + 5. 
Required dependencies are present + + Returns: + Tuple of (is_valid, list of issues found) + """ + issues = [] + + # Check pyproject.toml exists + pyproject_path = env_path / "pyproject.toml" + if not pyproject_path.exists(): + issues.append("Missing pyproject.toml") + return False, issues + + # Check uv.lock exists + lockfile_path = env_path / "uv.lock" + if not lockfile_path.exists(): + issues.append("Missing uv.lock - run 'uv lock' to generate it") + + # Parse pyproject.toml + try: + with open(pyproject_path, "rb") as f: + pyproject = tomllib.load(f) + except Exception as e: + issues.append(f"Failed to parse pyproject.toml: {e}") + return False, issues + + # Check [project.scripts] section + scripts = pyproject.get("project", {}).get("scripts", {}) + if "server" not in scripts: + issues.append("Missing [project.scripts] server entry point") + + # Check server entry point format + server_entry = scripts.get("server", "") + if server_entry and ":main" not in server_entry: + issues.append( + f"Server entry point should reference main function, got: {server_entry}" + ) + + # Check required dependencies + deps = [dep.lower() for dep in pyproject.get("project", {}).get("dependencies", [])] + has_openenv = any( + dep.startswith("openenv") and not dep.startswith("openenv-core") for dep in deps + ) + has_legacy_core = any(dep.startswith("openenv-core") for dep in deps) + + if not (has_openenv or has_legacy_core): + issues.append( + "Missing required dependency: openenv-core>=0.2.0 (or openenv>=0.2.0)" + ) + + # Check server/app.py exists + server_app = env_path / "server" / "app.py" + if not server_app.exists(): + issues.append("Missing server/app.py") + else: + # Check for main() function (flexible - with or without parameters) + app_content = server_app.read_text(encoding="utf-8") + if "def main(" not in app_content: + issues.append("server/app.py missing main() function") + + # Check if main() is callable + if "__name__" not in app_content or "main()" not in app_content: + issues.append( + "server/app.py main() function not callable (missing if __name__ == '__main__')" + ) + + return len(issues) == 0, issues + + +def get_deployment_modes(env_path: Path) -> dict[str, bool]: + """ + Check which deployment modes are supported by the environment. + + Returns: + Dictionary with deployment mode names and whether they're supported + """ + modes = { + "docker": False, + "openenv_serve": False, + "uv_run": False, + "python_module": False, + } + + # Check Docker (Dockerfile may be in server/ or at env root) + modes["docker"] = (env_path / "server" / "Dockerfile").exists() or ( + env_path / "Dockerfile" + ).exists() + + # Check multi-mode deployment readiness + is_valid, _ = validate_multi_mode_deployment(env_path) + if is_valid: + modes["openenv_serve"] = True + modes["uv_run"] = True + modes["python_module"] = True + + return modes + + +def format_validation_report(env_name: str, is_valid: bool, issues: list[str]) -> str: + """ + Format a validation report for display. 
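+
+    Produces a single "[OK]" line when the environment is valid, otherwise a
+    "[FAIL]" header followed by one line per issue.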
+ + Returns: + Formatted report string + """ + if is_valid: + return f"[OK] {env_name}: Ready for multi-mode deployment" + + report = [f"[FAIL] {env_name}: Not ready for multi-mode deployment", ""] + report.append("Issues found:") + for issue in issues: + report.append(f" - {issue}") + + return "\n".join(report) + + +def build_local_validation_json_report( + env_name: str, + env_path: Path, + is_valid: bool, + issues: list[str], + deployment_modes: dict[str, bool] | None = None, +) -> dict[str, Any]: + """Build a JSON report for local environment validation.""" + criteria = [ + _make_criterion( + "multi_mode_deployment_readiness", + "Environment structure is ready for multi-mode deployment", + is_valid, + details="No issues found" if is_valid else f"{len(issues)} issue(s) found", + actual={"issues": issues}, + ) + ] + + if deployment_modes: + for mode, supported in deployment_modes.items(): + criteria.append( + _make_criterion( + f"deployment_mode_{mode}", + f"Deployment mode '{mode}' is supported", + supported, + required=False, + ) + ) + + return { + "target": str(env_path), + "environment": env_name, + "validation_type": "local_environment", + "standard_version": "local", + "standard_profile": "openenv-local", + "passed": is_valid, + "summary": _build_summary(criteria), + "criteria": criteria, + "issues": issues, + "deployment_modes": deployment_modes or {}, + } diff --git a/src/core/openenv/cli/commands/__init__.py b/src/core/openenv/cli/commands/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f351a32ff5b353b05b1019005253e0c7cdf71c57 --- /dev/null +++ b/src/core/openenv/cli/commands/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenEnv CLI commands.""" + +from . import build, fork, init, push, serve, skills, validate + +__all__ = ["build", "fork", "init", "push", "serve", "skills", "validate"] diff --git a/src/core/openenv/cli/commands/build.py b/src/core/openenv/cli/commands/build.py new file mode 100644 index 0000000000000000000000000000000000000000..3d4d91b0f80754b00943caaeab23774b09b6d987 --- /dev/null +++ b/src/core/openenv/cli/commands/build.py @@ -0,0 +1,461 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Build Docker images for OpenEnv environments.""" + +from __future__ import annotations + +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Annotated + +import typer + +from .._cli_utils import console + +app = typer.Typer(help="Build Docker images for OpenEnv environments") + + +def _detect_build_context(env_path: Path) -> tuple[str, Path, Path | None]: + """ + Detect whether we're building a standalone or in-repo environment. 
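+
+    An environment under <repo_root>/envs/ in a git checkout is treated as
+    in-repo; anything else (including non-git directories) builds standalone.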
+
+    Returns:
+        tuple: (build_mode, build_context_path, repo_root)
+        - build_mode: "standalone" or "in-repo"
+        - build_context_path: Path to use as Docker build context
+        - repo_root: Path to repo root (None for standalone)
+    """
+    # Ensure env_path is absolute for proper comparison
+    env_path = env_path.absolute()
+
+    # Check if we're in a git repository
+    current = env_path
+    repo_root = None
+
+    # Walk up to find .git directory
+    for parent in [current] + list(current.parents):
+        if (parent / ".git").exists():
+            repo_root = parent
+            break
+
+    if repo_root is None:
+        # Not in a git repo = standalone
+        return "standalone", env_path, None
+
+    # Check if environment is under envs/ (in-repo pattern)
+    try:
+        rel_path = env_path.relative_to(repo_root)
+        rel_str = str(rel_path)
+        if rel_str.startswith("envs/") or rel_str.startswith("envs\\"):
+            # In-repo environment
+            return "in-repo", repo_root, repo_root
+    except ValueError:
+        pass
+
+    # Otherwise, it's standalone (environment outside repo structure)
+    return "standalone", env_path, None
+
+
+def _prepare_standalone_build(env_path: Path, temp_dir: Path) -> Path:
+    """
+    Prepare a standalone environment for building.
+
+    For standalone builds:
+    1. Copy environment to temp directory
+    2. Ensure pyproject.toml depends on openenv
+
+    Returns:
+        Path to the prepared build directory
+    """
+    console.print("[cyan]Preparing standalone build...[/cyan]")
+
+    # Copy environment to temp directory
+    build_dir = temp_dir / env_path.name
+    shutil.copytree(env_path, build_dir, symlinks=True)
+
+    console.print(f"[cyan]Copied environment to:[/cyan] {build_dir}")
+
+    # Check if pyproject.toml has openenv dependency
+    pyproject_path = build_dir / "pyproject.toml"
+    if pyproject_path.exists():
+        with open(pyproject_path, "rb") as f:
+            try:
+                import tomli
+
+                pyproject = tomli.load(f)
+                deps = pyproject.get("project", {}).get("dependencies", [])
+
+                # Check if openenv dependency is declared
+                has_openenv = any(dep.startswith("openenv") for dep in deps)
+
+                if not has_openenv:
+                    console.print(
+                        "[yellow]Warning:[/yellow] pyproject.toml doesn't list the openenv dependency",
+                    )
+                    console.print(
+                        "[yellow]You may need to add:[/yellow] openenv>=0.2.0",
+                    )
+            except ImportError:
+                console.print(
+                    "[yellow]Warning:[/yellow] tomli not available, skipping dependency check",
+                )
+
+    return build_dir
+
+
+def _prepare_inrepo_build(env_path: Path, repo_root: Path, temp_dir: Path) -> Path:
+    """
+    Prepare an in-repo environment for building.
+
+    For in-repo builds:
+    1. Create temp directory with environment and core
+    2. Set up structure that matches expected layout
+
+    Returns:
+        Path to the prepared build directory
+    """
+    console.print("[cyan]Preparing in-repo build...[/cyan]")
+
+    # Copy environment to temp directory
+    build_dir = temp_dir / env_path.name
+    shutil.copytree(env_path, build_dir, symlinks=True)
+
+    # Copy OpenEnv package metadata + sources to temp directory.
+    # Keep the src/ layout since pyproject.toml uses package-dir = {"" = "src"}.
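+    # Illustrative resulting layout inside the build context:
+    #   <build_dir>/openenv/pyproject.toml
+    #   <build_dir>/openenv/README.md
+    #   <build_dir>/openenv/src/openenv/...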
+ package_src = repo_root / "src" / "openenv" + package_dest = build_dir / "openenv" + if package_src.exists(): + package_dest.mkdir(parents=True, exist_ok=True) + shutil.copytree(package_src, package_dest / "src" / "openenv", symlinks=True) + + for filename in ("pyproject.toml", "README.md"): + src_file = repo_root / filename + if src_file.exists(): + shutil.copy2(src_file, package_dest / filename) + + console.print(f"[cyan]Copied OpenEnv package to:[/cyan] {package_dest}") + + # Update pyproject.toml to reference local OpenEnv copy + pyproject_path = build_dir / "pyproject.toml" + if pyproject_path.exists(): + with open(pyproject_path, "rb") as f: + try: + import tomli + + pyproject = tomli.load(f) + deps = pyproject.get("project", {}).get("dependencies", []) + + # Replace openenv/openenv-core with local reference + new_deps = [] + for dep in deps: + if ( + dep.startswith("openenv-core") + or dep.startswith("openenv_core") + or dep.startswith("openenv") + ): + # Skip - we'll use local core + continue + new_deps.append(dep) + + # Write back with local core reference + pyproject["project"]["dependencies"] = new_deps + [ + "openenv-core @ file:///app/env/openenv" + ] + + # Write updated pyproject.toml + with open(pyproject_path, "wb") as out_f: + import tomli_w + + tomli_w.dump(pyproject, out_f) + + console.print( + "[cyan]Updated pyproject.toml to use local core[/cyan]" + ) + + # Remove old lockfile since dependencies changed + lockfile = build_dir / "uv.lock" + if lockfile.exists(): + lockfile.unlink() + console.print("[cyan]Removed outdated uv.lock[/cyan]") + + except ImportError: + console.print( + "[yellow]Warning:[/yellow] tomli/tomli_w not available, using pyproject.toml as-is", + ) + else: + console.print( + "[yellow]Warning:[/yellow] OpenEnv package not found, building without it" + ) + + console.print(f"[cyan]Build directory prepared:[/cyan] {build_dir}") + return build_dir + + +def _run_command( + cmd: list[str], + cwd: Path | None = None, + check: bool = True, +) -> subprocess.CompletedProcess: + """Run a shell command and handle errors.""" + console.print(f"[bold cyan]Running:[/bold cyan] {' '.join(cmd)}") + try: + result = subprocess.run( + cmd, cwd=cwd, check=check, capture_output=True, text=True + ) + if result.stdout: + console.print(result.stdout) + if result.stderr: + print(result.stderr, file=sys.stderr) + return result + except subprocess.CalledProcessError as e: + print(f"Error running command: {e}", file=sys.stderr) + if e.stdout: + console.print(e.stdout) + if e.stderr: + print(e.stderr, file=sys.stderr) + if check: + raise typer.Exit(1) from e + return e + + +def _build_docker_image( + env_path: Path, + tag: str | None = None, + context_path: Path | None = None, + dockerfile: Path | None = None, + build_args: dict[str, str] | None = None, + no_cache: bool = False, +) -> bool: + """Build Docker image for the environment with smart context detection.""" + + # Detect build context (standalone vs in-repo) + build_mode, detected_context, repo_root = _detect_build_context(env_path) + + console.print(f"[bold cyan]Build mode detected:[/bold cyan] {build_mode}") + + # Use detected context unless explicitly overridden + if context_path is None: + context_path = detected_context + + # Create temporary build directory + with tempfile.TemporaryDirectory() as temp_dir_str: + temp_dir = Path(temp_dir_str) + + # Prepare build directory based on mode + if build_mode == "standalone": + build_dir = _prepare_standalone_build(env_path, temp_dir) + else: # in-repo + build_dir = 
_prepare_inrepo_build(env_path, repo_root, temp_dir)
+
+        # Determine Dockerfile path
+        if dockerfile is None:
+            # Look for Dockerfile in server/ subdirectory
+            dockerfile = build_dir / "server" / "Dockerfile"
+            if not dockerfile.exists():
+                # Fallback to root of build directory
+                dockerfile = build_dir / "Dockerfile"
+
+        if not dockerfile.exists():
+            console.print(
+                f"[bold red]Error:[/bold red] Dockerfile not found at {dockerfile}",
+            )
+            return False
+
+        # Generate tag if not provided
+        if tag is None:
+            env_name = env_path.name
+            if env_name.endswith("_env"):
+                env_name = env_name[:-4]
+            tag = f"openenv-{env_name}"
+
+        console.print(f"[bold cyan]Building Docker image:[/bold cyan] {tag}")
+        console.print(f"[bold cyan]Build context:[/bold cyan] {build_dir}")
+        console.print(f"[bold cyan]Dockerfile:[/bold cyan] {dockerfile}")
+
+        # Prepare build args
+        if build_args is None:
+            build_args = {}
+
+        # Add build mode and env name to build args
+        build_args["BUILD_MODE"] = build_mode
+        build_args["ENV_NAME"] = env_path.name.replace("_env", "")
+
+        # Build Docker command
+        cmd = ["docker", "build", "-t", tag, "-f", str(dockerfile)]
+
+        if no_cache:
+            cmd.append("--no-cache")
+
+        for key, value in build_args.items():
+            cmd.extend(["--build-arg", f"{key}={value}"])
+
+        cmd.append(str(build_dir))
+
+        result = _run_command(cmd, check=False)
+        return result.returncode == 0
+
+
+def _push_docker_image(tag: str, registry: str | None = None) -> bool:
+    """Push Docker image to registry."""
+    if registry:
+        full_tag = f"{registry}/{tag}"
+        console.print(f"[bold cyan]Tagging image as {full_tag}[/bold cyan]")
+        _run_command(["docker", "tag", tag, full_tag])
+        tag = full_tag
+
+    console.print(f"[bold cyan]Pushing image:[/bold cyan] {tag}")
+    result = _run_command(["docker", "push", tag], check=False)
+    return result.returncode == 0
+
+
+@app.command()
+def build(
+    env_path: Annotated[
+        str | None,
+        typer.Argument(
+            help="Path to the environment directory (default: current directory)"
+        ),
+    ] = None,
+    tag: Annotated[
+        str | None,
+        typer.Option(
+            "--tag",
+            "-t",
+            help="Docker image tag (default: openenv-<env-name>)",
+        ),
+    ] = None,
+    context: Annotated[
+        str | None,
+        typer.Option(
+            "--context",
+            "-c",
+            help="Build context path (default: <env-path>/server)",
+        ),
+    ] = None,
+    dockerfile: Annotated[
+        str | None,
+        typer.Option(
+            "--dockerfile",
+            "-f",
+            help="Path to Dockerfile (default: <env-path>/Dockerfile)",
+        ),
+    ] = None,
+    no_cache: Annotated[
+        bool,
+        typer.Option(
+            "--no-cache",
+            help="Build without using cache",
+        ),
+    ] = False,
+    build_arg: Annotated[
+        list[str] | None,
+        typer.Option(
+            "--build-arg",
+            help="Build arguments (can be used multiple times, format: KEY=VALUE)",
+        ),
+    ] = None,
+) -> None:
+    """
+    Build Docker images for OpenEnv environments.
+
+    This command builds Docker images using the environment's pyproject.toml
+    and uv for dependency management. Run from the environment root directory.
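+
+    Standalone checkouts and in-repo environments (under envs/ in a git clone)
+    are detected automatically.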
+ + Examples: + # Build from environment root (recommended) + $ cd my_env + $ openenv build + + # Build with custom tag + $ openenv build -t my-custom-tag + + # Build without cache + $ openenv build --no-cache + + # Build with custom build arguments + $ openenv build --build-arg VERSION=1.0 --build-arg ENV=prod + + # Build from different directory + $ openenv build envs/echo_env + """ + # Determine environment path (default to current directory) + if env_path is None: + env_path_obj = Path.cwd() + else: + env_path_obj = Path(env_path) + + # Validate environment path + if not env_path_obj.exists(): + print( + f"Error: Environment path does not exist: {env_path_obj}", + file=sys.stderr, + ) + raise typer.Exit(1) + + if not env_path_obj.is_dir(): + print( + f"Error: Environment path is not a directory: {env_path_obj}", + file=sys.stderr, + ) + raise typer.Exit(1) + + # Check for openenv.yaml to confirm this is an environment directory + openenv_yaml = env_path_obj / "openenv.yaml" + if not openenv_yaml.exists(): + print( + f"Error: Not an OpenEnv environment directory (missing openenv.yaml): {env_path_obj}", + file=sys.stderr, + ) + print( + "Hint: Run this command from the environment root directory or specify the path", + file=sys.stderr, + ) + raise typer.Exit(1) + + console.print(f"[bold]Building Docker image for:[/bold] {env_path_obj.name}") + console.print("=" * 60) + + # Parse build args + build_args = {} + if build_arg: + for arg in build_arg: + if "=" in arg: + key, value = arg.split("=", 1) + build_args[key] = value + else: + print( + f"Warning: Invalid build arg format: {arg}", + file=sys.stderr, + ) + + # Convert string paths to Path objects + context_path_obj = Path(context) if context else None + dockerfile_path_obj = Path(dockerfile) if dockerfile else None + + # Build Docker image + success = _build_docker_image( + env_path=env_path_obj, + tag=tag, + context_path=context_path_obj, + dockerfile=dockerfile_path_obj, + build_args=build_args if build_args else None, + no_cache=no_cache, + ) + + if not success: + print("✗ Docker build failed", file=sys.stderr) + raise typer.Exit(1) + + console.print("[bold green]✓ Docker build successful[/bold green]") + console.print("\n[bold green]Done![/bold green]") diff --git a/src/core/openenv/cli/commands/fork.py b/src/core/openenv/cli/commands/fork.py new file mode 100644 index 0000000000000000000000000000000000000000..e06f41f1d8606874446f8edc07d675eb4680fa32 --- /dev/null +++ b/src/core/openenv/cli/commands/fork.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Fork (duplicate) a Hugging Face Space using the Hub API.""" + +from __future__ import annotations + +from typing import Annotated + +import typer +from huggingface_hub import HfApi, login, whoami + +from .._cli_utils import console + +app = typer.Typer( + help="Fork (duplicate) an OpenEnv environment on Hugging Face to your account" +) + + +def _parse_key_value(s: str) -> tuple[str, str]: + """Parse KEY=VALUE string. Raises BadParameter if no '='.""" + if "=" not in s: + raise typer.BadParameter( + f"Expected KEY=VALUE format, got: {s!r}. 
" + "Use --set-env KEY=VALUE or --set-secret KEY=VALUE" + ) + key, _, value = s.partition("=") + key = key.strip() + if not key: + raise typer.BadParameter(f"Empty key in: {s!r}") + return key, value.strip() + + +def _ensure_hf_authenticated() -> str: + """Ensure user is authenticated with Hugging Face. Returns username.""" + try: + user_info = whoami() + if isinstance(user_info, dict): + username = ( + user_info.get("name") + or user_info.get("fullname") + or user_info.get("username") + ) + else: + username = ( + getattr(user_info, "name", None) + or getattr(user_info, "fullname", None) + or getattr(user_info, "username", None) + ) + if not username: + raise ValueError("Could not extract username from whoami response") + console.print(f"[bold green]✓[/bold green] Authenticated as: {username}") + return username + except Exception: + console.print( + "[bold yellow]Not authenticated with Hugging Face. Please login...[/bold yellow]" + ) + try: + login() + user_info = whoami() + if isinstance(user_info, dict): + username = ( + user_info.get("name") + or user_info.get("fullname") + or user_info.get("username") + ) + else: + username = ( + getattr(user_info, "name", None) + or getattr(user_info, "fullname", None) + or getattr(user_info, "username", None) + ) + if not username: + raise ValueError("Could not extract username from whoami response") + console.print(f"[bold green]✓[/bold green] Authenticated as: {username}") + return username + except Exception as e: + raise typer.BadParameter( + f"Hugging Face authentication failed: {e}. Please run login manually." + ) from e + + +@app.command() +def fork( + source_space: Annotated[ + str, + typer.Argument( + help="Source Space ID in format 'owner/space-name' (e.g. org/my-openenv-space)" + ), + ], + repo_id: Annotated[ + str | None, + typer.Option( + "--repo-id", + "-r", + help="Target repo ID for the fork (default: created under your account with same name)", + ), + ] = None, + private: Annotated[ + bool, + typer.Option("--private", help="Create the forked Space as private"), + ] = False, + set_env: Annotated[ + list[str], + typer.Option( + "--set-env", + "-e", + help="Set Space variable (public). Can be repeated. Format: KEY=VALUE", + ), + ] = [], + set_secret: Annotated[ + list[str], + typer.Option( + "--set-secret", + "--secret", + "-s", + help="Set Space secret. Can be repeated. Format: KEY=VALUE", + ), + ] = [], + hardware: Annotated[ + str | None, + typer.Option( + "--hardware", + "-H", + help="Request hardware (e.g. t4-medium, cpu-basic). See Hub docs for options.", + ), + ] = None, +) -> None: + """ + Fork (duplicate) a Hugging Face Space to your account using the Hub API. + + Uses the Hugging Face duplicate_space API. You can set environment variables + and secrets, and request hardware/storage/sleep time at creation time. + + Examples: + $ openenv fork owner/source-space + $ openenv fork owner/source-space --private + $ openenv fork owner/source-space --repo-id myuser/my-fork + $ openenv fork owner/source-space --set-env MODEL_ID=user/model --set-secret HF_TOKEN=hf_xxx + $ openenv fork owner/source-space --hardware t4-medium + """ + if "/" not in source_space or source_space.count("/") != 1: + raise typer.BadParameter( + f"Invalid source Space ID: {source_space!r}. 
Expected format: 'owner/space-name'" + ) + + _ensure_hf_authenticated() + api = HfApi() + + # Build kwargs for duplicate_space (only pass what we have) + dup_kwargs: dict = { + "from_id": source_space, + "private": private, + } + if set_env: + dup_kwargs["variables"] = [ + {"key": k, "value": v} for k, v in (_parse_key_value(x) for x in set_env) + ] + if set_secret: + dup_kwargs["secrets"] = [ + {"key": k, "value": v} for k, v in (_parse_key_value(x) for x in set_secret) + ] + # HF API requires hardware when duplicating; default to free cpu-basic + dup_kwargs["hardware"] = hardware if hardware is not None else "cpu-basic" + if repo_id is not None: + if "/" not in repo_id or repo_id.count("/") != 1: + raise typer.BadParameter( + f"Invalid --repo-id: {repo_id!r}. Expected format: 'username/repo-name'" + ) + dup_kwargs["to_id"] = repo_id + + console.print(f"[bold cyan]Forking Space {source_space}...[/bold cyan]") + try: + result = api.duplicate_space(**dup_kwargs) + except Exception as e: + console.print(f"[bold red]✗[/bold red] Fork failed: {e}") + raise typer.Exit(1) from e + + # result is RepoUrl (str-like) or similar; get repo_id for display + if hasattr(result, "repo_id"): + new_repo_id = result.repo_id + elif isinstance(result, str): + # URL like https://huggingface.co/spaces/owner/name -> owner/name + if "/spaces/" in result: + new_repo_id = result.split("/spaces/")[-1].rstrip("/") + else: + new_repo_id = result + else: + new_repo_id = getattr(result, "repo_id", str(result)) + + console.print("[bold green]✓[/bold green] Space forked successfully") + console.print( + f"[bold]Space URL:[/bold] https://huggingface.co/spaces/{new_repo_id}" + ) diff --git a/src/core/openenv/cli/commands/init.py b/src/core/openenv/cli/commands/init.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf0fc7168f109657157f3d71600857e5b91f37e --- /dev/null +++ b/src/core/openenv/cli/commands/init.py @@ -0,0 +1,500 @@ +"""Initialize a new OpenEnv environment.""" + +from __future__ import annotations + +import random +import shutil +import subprocess +from importlib import resources +from pathlib import Path +from typing import Annotated, Dict, List, Tuple + +import typer + +from .._cli_utils import console + +app = typer.Typer(help="Initialize a new OpenEnv environment") + + +def _snake_to_pascal(snake_str: str) -> str: + """Convert snake_case to PascalCase (e.g., 'my_env' -> 'MyEnv').""" + return "".join(word.capitalize() for word in snake_str.split("_")) + + +def _get_env_prefix(env_name: str) -> str: + """Extract the prefix for class names (e.g., 'my_env' -> 'My', 'test_env' -> 'Test').""" + # Remove trailing '_env' if present + if env_name.endswith("_env"): + base = env_name[:-4] # Remove '_env' + else: + base = env_name + + # If empty or just one part, use the whole thing + if not base or "_" not in base: + return base.capitalize() if base else env_name.capitalize() + + # PascalCase all parts except the last + parts = base.split("_") + return "".join(word.capitalize() for word in parts) + + +def _snake_to_camel(snake_str: str) -> str: + """Convert snake_case to camelCase (e.g., 'my_env' -> 'myEnv').""" + parts = snake_str.split("_") + return parts[0] + "".join(word.capitalize() for word in parts[1:]) + + +def _snake_to_title(snake_str: str) -> str: + """Convert snake_case to Title Case (e.g., 'my_env' -> 'My Env').""" + return " ".join(word.capitalize() for word in snake_str.split("_")) + + +def _validate_env_name(name: str) -> str: + """Validate environment name (must be valid Python 
identifier in snake_case).""" + if not name: + raise typer.BadParameter("Environment name cannot be empty") + + # Check if it's a valid Python identifier + if not name.isidentifier(): + raise typer.BadParameter( + f"Environment name '{name}' is not a valid Python identifier. Use snake_case (e.g., 'my_env', 'game_env')." + ) + + # Check if it starts with a number + if name[0].isdigit(): + raise typer.BadParameter( + f"Environment name '{name}' cannot start with a number." + ) + + return name + + +def _get_random_hf_space_config() -> Dict[str, str]: + """ + Get random Hugging Face Space configuration values. + + Returns: + Dictionary with 'emoji', 'colorFrom', and 'colorTo' keys + """ + # Valid emojis (emoji-only characters) + emojis = [ + "🎮", + "🎯", + "🚀", + "🌟", + "🎨", + "🎪", + "🎭", + "🎬", + "🎤", + "🎧", + "🎵", + "🎶", + "🎸", + "🎹", + "🥁", + "🎺", + "🎻", + "🎼", + "🎯", + "🎲", + "🎳", + "🎰", + "🎴", + "🃏", + "🀄", + "🎴", + "🎨", + "🖼️", + "🎬", + "🎭", + "🎪", + "🎤", + "🎧", + "🎵", + "🎶", + "🎸", + "🎹", + "🎺", + "🎻", + "🥁", + "🎯", + "🎲", + "🎳", + "🎰", + "🏀", + "⚽", + "🏈", + "⚾", + "🎾", + "🏐", + "🏉", + "🎱", + "🏓", + "🏸", + "🥅", + "🏒", + "🏑", + "🏏", + "⛳", + "🏹", + "🎣", + "🥊", + "🥋", + "🎽", + "🏅", + "🎖️", + "🏆", + "🥇", + "🥈", + "🥉", + "🔊", + "🔉", + "🔈", + "🔇", + "📢", + "📣", + "📯", + "🔔", + "🔕", + "📻", + "📡", + "💻", + "🖥️", + "🖨️", + "⌨️", + "🖱️", + "🖲️", + "🕹️", + "🗜️", + "💾", + "💿", + "📀", + "📼", + "📷", + "📸", + "📹", + "🎥", + "📽️", + "🎞️", + "📞", + "☎️", + "📟", + "📠", + "📺", + "📻", + "🎙️", + "🎚️", + "🎛️", + "⏱️", + "⏲️", + "⏰", + "🕰️", + "⌚", + "📱", + "📲", + "💻", + "⌨️", + "🖥️", + "🖨️", + "🖱️", + ] + + # Valid colors from HF Spaces config reference + colors = ["red", "yellow", "green", "blue", "indigo", "purple", "pink", "gray"] + + return { + "emoji": random.choice(emojis), + "colorFrom": random.choice(colors), + "colorTo": random.choice(colors), + } + + +def _create_template_replacements(env_name: str) -> Dict[str, str]: + """ + Create comprehensive template replacement dictionary. 
+ + Supports all naming conventions: + - PascalCase for class names + - camelCase for variable names + - snake_case for module names, file paths + """ + env_prefix = _get_env_prefix(env_name) + env_camel = _snake_to_camel(env_name) + env_title = _snake_to_title(env_name) + + # Get random HF Space config values + hf_config = _get_random_hf_space_config() + + replacements = { + # Template placeholders (MUST come first - full class names before partial) + "__ENV_CLASS_NAME__Environment": f"{env_prefix}Environment", + "__ENV_CLASS_NAME__Action": f"{env_prefix}Action", + "__ENV_CLASS_NAME__Observation": f"{env_prefix}Observation", + "__ENV_CLASS_NAME__Env": f"{env_prefix}Env", + # Template placeholders (partial - must come after full replacements) + "__ENV_NAME__": env_name, + "__ENV_CLASS_NAME__": env_prefix, # Use prefix, not full PascalCase + "__ENV_TITLE_NAME__": env_title, + "__ENV_CAMEL_NAME__": env_camel, + # Hugging Face Space config placeholders + "__HF_EMOJI__": hf_config["emoji"], + "__HF_COLOR_FROM__": hf_config["colorFrom"], + "__HF_COLOR_TO__": hf_config["colorTo"], + } + + return replacements + + +def _replace_in_content(content: str, replacements: Dict[str, str]) -> str: + """Replace all occurrences in content using case-sensitive replacements.""" + result = content + # Sort by length (longest first) to avoid partial replacements + for old, new in sorted(replacements.items(), key=lambda x: len(x[0]), reverse=True): + result = result.replace(old, new) + return result + + +def _should_rename_file(filename: str, env_name: str) -> Tuple[bool, str]: + """ + Check if a file should be renamed and return the new name. + + Handles template placeholders in filenames like: + - `__ENV_NAME___environment.py` → `_environment.py` + """ + # Check for template placeholder + if "__ENV_NAME__" in filename: + new_name = filename.replace("__ENV_NAME__", env_name) + return True, new_name + + return False, filename + + +def _copy_and_template_file( + src_path: Path, + dest_path: Path, + replacements: Dict[str, str], +) -> None: + """Copy a file and apply template replacements.""" + dest_path.parent.mkdir(parents=True, exist_ok=True) + + try: + # Read source file + content = src_path.read_bytes() + + # Try to decode as text and apply replacements + try: + text = content.decode("utf-8") + # Normalize line endings to LF before applying replacements + text = text.replace("\r\n", "\n").replace("\r", "\n") + text = _replace_in_content(text, replacements) + dest_path.write_text(text, encoding="utf-8", newline="\n") + except UnicodeDecodeError: + # Binary file, just copy + dest_path.write_bytes(content) + except Exception as e: + raise RuntimeError( + f"Failed to copy template file {src_path} to {dest_path}: {e}" + ) from e + + +def _copy_template_directory( + template_pkg: str, + template_dir: str, + dest_dir: Path, + replacements: Dict[str, str], + env_name: str, +) -> List[Path]: + """Recursively copy template directory and apply replacements.""" + created_files: List[Path] = [] + + # Get the package path using importlib.resources but avoid importing the template package + # We'll use the package's __file__ to get the directory path + import importlib + + try: + # Import the parent package (not the template package itself) + if "." 
in template_pkg: + parent_pkg = ".".join(template_pkg.split(".")[:-1]) + pkg = importlib.import_module(parent_pkg) + template_path = Path(pkg.__file__).parent / template_pkg.split(".")[-1] + else: + pkg = importlib.import_module(template_pkg.split(".")[0]) + template_path = Path(pkg.__file__).parent / template_pkg.split(".")[-1] + except Exception: + # Fallback: try to use resources.files but handle import errors + try: + base = resources.files(template_pkg.split(".")[0]) + template_path = base.joinpath(*template_pkg.split(".")[1:]) + if not template_path.exists(): + raise FileNotFoundError(f"Template directory not found: {template_pkg}") + except Exception as e: + raise FileNotFoundError( + f"Template directory not found: {template_pkg}" + ) from e + + if template_dir: + template_path = template_path / template_dir + + if not template_path.exists() or not template_path.is_dir(): + raise FileNotFoundError( + f"Template directory not found: {template_pkg}.{template_dir}" + ) + + # Walk through all files in template directory using Path + for item in template_path.rglob("*"): + if item.is_file(): + rel_path = item.relative_to(template_path) + dest_path = dest_dir / rel_path + + # Apply filename templating + should_rename, new_name = _should_rename_file(dest_path.name, env_name) + if should_rename: + dest_path = dest_path.parent / new_name + + # Copy and apply replacements + _copy_and_template_file(item, dest_path, replacements) + created_files.append(dest_path) + + return created_files + + +def _generate_uv_lock(env_dir: Path) -> bool: + """Generate uv.lock from pyproject.toml using uv.""" + pyproject_path = env_dir / "pyproject.toml" + + if not pyproject_path.exists(): + return False + + try: + cmd = [ + "uv", + "lock", + "--directory", + str(env_dir), + ] + + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + + if result.stdout: + console.print(result.stdout) + + return True + + except subprocess.CalledProcessError as e: + console.print( + f"[yellow]Warning: Could not generate uv.lock: {e.stderr}[/yellow]" + ) + return False + except FileNotFoundError: + console.print( + "[yellow]Warning: 'uv' not found. Install it to generate uv.lock[/yellow]" + ) + return False + + +@app.command() +def init( + env_name: Annotated[ + str, + typer.Argument( + help="Name of the environment to create (snake_case, e.g., 'my_env')" + ), + ], + output_dir: Annotated[ + str | None, + typer.Option( + "--output-dir", + "-o", + help="Output directory (defaults to current working directory)", + ), + ] = None, +) -> None: + """ + Initialize a new OpenEnv environment. + + Creates a new directory with the environment name and generates all necessary + files based on the OpenEnv template structure. + + Example: + $ openenv init my_game_env + $ openenv init my_env --output-dir /path/to/projects + """ + # Validate environment name + env_name = _validate_env_name(env_name) + + # Determine output directory + base_dir = Path(output_dir).resolve() if output_dir else Path.cwd().resolve() + env_dir = base_dir / env_name + + # Check if directory already exists + if env_dir.exists(): + if env_dir.is_file(): + raise typer.BadParameter(f"Path '{env_dir}' exists and is a file") + if any(env_dir.iterdir()): + raise typer.BadParameter( + f"Directory '{env_dir}' already exists and is not empty. " + "Please choose a different name or remove the existing directory." 
+ ) + + try: + # Create template replacements + replacements = _create_template_replacements(env_name) + + # Create environment directory + env_dir.mkdir(parents=True, exist_ok=True) + + console.print( + f"[bold cyan]Creating OpenEnv environment '{env_name}'...[/bold cyan]" + ) + + # Copy template files from template structure + template_pkg = "openenv.cli.templates.openenv_env" + created_files = _copy_template_directory( + template_pkg, + "", + env_dir, + replacements, + env_name, + ) + + console.print(f"[bold green]✓[/bold green] Created {len(created_files)} files") + + # Generate uv.lock + console.print("\n[bold]Generating uv.lock...[/bold]") + if _generate_uv_lock(env_dir): + console.print("[green]✓[/green] Generated uv.lock") + else: + console.print("[yellow]⚠[/yellow] Could not generate uv.lock automatically") + console.print(" You can generate it manually with:") + console.print(f" cd {env_dir} && uv lock") + + console.print( + f"\n[bold green]Environment created successfully at: {env_dir}[/bold green]" + ) + console.print("\n[bold]Next steps:[/bold]") + console.print(f" cd {env_dir}") + console.print( + f" # Edit your environment implementation in server/{env_name}_environment.py" + ) + console.print(" # Edit your models in models.py") + console.print(" # Install dependencies: uv sync") + console.print("\n # To integrate into OpenEnv repo:") + console.print(f" # 1. Copy this directory to /envs/{env_name}_env") + console.print( + f" # 2. Build from repo root: docker build -t {env_name}_env:latest -f envs/{env_name}_env/server/Dockerfile ." + ) + console.print( + f" # 3. Run your image: docker run -p 8000:8000 {env_name}_env:latest" + ) + + except Exception as e: + # Cleanup on error + if env_dir.exists() and env_dir.is_dir(): + try: + shutil.rmtree(env_dir) + except Exception: + pass + + console.print(f"[bold red]Error:[/bold red] {e}") + raise typer.Exit(1) from e diff --git a/src/core/openenv/cli/commands/push.py b/src/core/openenv/cli/commands/push.py new file mode 100644 index 0000000000000000000000000000000000000000..beb571c239734d8f826a42e91f34cdd5845a44ff --- /dev/null +++ b/src/core/openenv/cli/commands/push.py @@ -0,0 +1,718 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Push an OpenEnv environment to Hugging Face Spaces.""" + +from __future__ import annotations + +import shutil +import sys +import tempfile +from fnmatch import fnmatch +from pathlib import Path +from typing import Annotated + +import typer +import yaml +from huggingface_hub import HfApi, login, whoami + +from .._cli_utils import console, validate_env_structure + +app = typer.Typer(help="Push an OpenEnv environment to Hugging Face Spaces") + + +DEFAULT_PUSH_IGNORE_PATTERNS = [".*", "__pycache__", "*.pyc"] + + +def _path_matches_pattern(relative_path: Path, pattern: str) -> bool: + """Return True if a relative path matches an exclude pattern.""" + normalized_pattern = pattern.strip() + if normalized_pattern.startswith("!"): + return False + + while normalized_pattern.startswith("./"): + normalized_pattern = normalized_pattern[2:] + + if normalized_pattern.startswith("/"): + normalized_pattern = normalized_pattern[1:] + + if not normalized_pattern: + return False + + posix_path = relative_path.as_posix() + pattern_candidates = [normalized_pattern] + if normalized_pattern.startswith("**/"): + # Gitignore-style "**/" can also match directly at the root. + pattern_candidates.append(normalized_pattern[3:]) + + # Support directory patterns such as "artifacts/" and "**/outputs/". + if normalized_pattern.endswith("/"): + dir_pattern_candidates: list[str] = [] + for candidate in pattern_candidates: + base = candidate.rstrip("/") + if not base: + continue + dir_pattern_candidates.extend([base, f"{base}/*"]) + + return any( + fnmatch(posix_path, candidate) for candidate in dir_pattern_candidates + ) + + # Match both full relative path and basename for convenience. + return any( + fnmatch(posix_path, candidate) for candidate in pattern_candidates + ) or any(fnmatch(relative_path.name, candidate) for candidate in pattern_candidates) + + +def _should_exclude_path(relative_path: Path, ignore_patterns: list[str]) -> bool: + """Return True when the path should be excluded from staging/upload.""" + return any( + _path_matches_pattern(relative_path, pattern) for pattern in ignore_patterns + ) + + +def _read_ignore_file(ignore_path: Path) -> tuple[list[str], int]: + """Read ignore patterns from a file and return (patterns, ignored_negations).""" + patterns: list[str] = [] + ignored_negations = 0 + + for line in ignore_path.read_text().splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + if stripped.startswith("!"): + ignored_negations += 1 + continue + patterns.append(stripped) + + return patterns, ignored_negations + + +def _load_ignore_patterns(env_dir: Path, exclude_file: str | None) -> list[str]: + """Load ignore patterns from defaults and an optional ignore file.""" + patterns = list(DEFAULT_PUSH_IGNORE_PATTERNS) + ignored_negations = 0 + + def _merge_ignore_file(ignore_path: Path, *, source_label: str) -> None: + nonlocal ignored_negations + file_patterns, skipped_negations = _read_ignore_file(ignore_path) + patterns.extend(file_patterns) + ignored_negations += skipped_negations + console.print( + f"[bold green]✓[/bold green] Loaded {len(file_patterns)} ignore patterns from {source_label}: {ignore_path}" + ) + + # Optional source: explicit exclude file from CLI. 
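+    #
+    # Illustrative example (hypothetical patterns, not shipped defaults): an
+    # ignore file containing
+    #
+    #     *.ckpt
+    #     outputs/
+    #     **/wandb/
+    #     !keep.ckpt
+    #
+    # excludes every *.ckpt file, the top-level outputs/ tree, and wandb/
+    # directories at any depth, while the negated "!keep.ckpt" line is
+    # counted and skipped because negation is not supported for push excludes.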
+ if exclude_file: + ignore_path = Path(exclude_file) + if not ignore_path.is_absolute(): + ignore_path = env_dir / ignore_path + ignore_path = ignore_path.resolve() + + if not ignore_path.exists() or not ignore_path.is_file(): + raise typer.BadParameter( + f"Exclude file not found or not a file: {ignore_path}" + ) + + _merge_ignore_file(ignore_path, source_label="--exclude") + + # Keep stable order while removing duplicates. + patterns = list(dict.fromkeys(patterns)) + + if ignored_negations > 0: + console.print( + f"[bold yellow]⚠[/bold yellow] Skipped {ignored_negations} negated ignore patterns ('!') because negation is not supported for push excludes" + ) + + return patterns + + +def _copytree_ignore_factory(env_dir: Path, ignore_patterns: list[str]): + """Build a shutil.copytree ignore callback from path-based patterns.""" + + def _ignore(path: str, names: list[str]) -> set[str]: + current_dir = Path(path) + ignored: set[str] = set() + + for name in names: + candidate = current_dir / name + try: + relative_path = candidate.relative_to(env_dir) + except ValueError: + # candidate is not under env_dir (e.g. symlink or + # copytree root differs from env_dir); skip filtering. + continue + if _should_exclude_path(relative_path, ignore_patterns): + ignored.add(name) + + return ignored + + return _ignore + + +def _validate_openenv_directory(directory: Path) -> tuple[str, dict]: + """ + Validate that the directory is an OpenEnv environment. + + Returns: + Tuple of (env_name, manifest_data) + """ + # Use the comprehensive validation function + try: + warnings = validate_env_structure(directory) + for warning in warnings: + console.print(f"[bold yellow]⚠[/bold yellow] {warning}") + except FileNotFoundError as e: + raise typer.BadParameter(f"Invalid OpenEnv environment structure: {e}") from e + + # Load and validate manifest + manifest_path = directory / "openenv.yaml" + try: + with open(manifest_path, "r") as f: + manifest = yaml.safe_load(f) + except Exception as e: + raise typer.BadParameter(f"Failed to parse openenv.yaml: {e}") from e + + if not isinstance(manifest, dict): + raise typer.BadParameter("openenv.yaml must be a YAML dictionary") + + env_name = manifest.get("name") + if not env_name: + raise typer.BadParameter("openenv.yaml must contain a 'name' field") + + return env_name, manifest + + +def _ensure_hf_authenticated() -> str: + """ + Ensure user is authenticated with Hugging Face. + + Returns: + Username of authenticated user + """ + try: + # Try to get current user + user_info = whoami() + # Handle both dict and object return types + if isinstance(user_info, dict): + username = ( + user_info.get("name") + or user_info.get("fullname") + or user_info.get("username") + ) + else: + # If it's an object, try to get name attribute + username = ( + getattr(user_info, "name", None) + or getattr(user_info, "fullname", None) + or getattr(user_info, "username", None) + ) + + if not username: + raise ValueError("Could not extract username from whoami response") + + console.print(f"[bold green]✓[/bold green] Authenticated as: {username}") + return username + except Exception: + # Not authenticated, prompt for login + console.print( + "[bold yellow]Not authenticated with Hugging Face. 
Please login...[/bold yellow]" + ) + + try: + login() + # Verify login worked + user_info = whoami() + # Handle both dict and object return types + if isinstance(user_info, dict): + username = ( + user_info.get("name") + or user_info.get("fullname") + or user_info.get("username") + ) + else: + username = ( + getattr(user_info, "name", None) + or getattr(user_info, "fullname", None) + or getattr(user_info, "username", None) + ) + + if not username: + raise ValueError("Could not extract username from whoami response") + + console.print(f"[bold green]✓[/bold green] Authenticated as: {username}") + return username + except Exception as e: + raise typer.BadParameter( + f"Hugging Face authentication failed: {e}. Please run login manually." + ) from e + + +def _prepare_staging_directory( + env_dir: Path, + env_name: str, + staging_dir: Path, + ignore_patterns: list[str], + base_image: str | None = None, + enable_interface: bool = True, +) -> None: + """ + Prepare files for deployment. + + This includes: + - Copying necessary files + - Modifying Dockerfile to optionally enable web interface and update base image + - Ensuring README has proper HF frontmatter (if interface enabled) + """ + # Create staging directory structure + staging_dir.mkdir(parents=True, exist_ok=True) + + # Copy all files from env directory + copy_ignore = _copytree_ignore_factory(env_dir, ignore_patterns) + for item in env_dir.iterdir(): + relative_path = item.relative_to(env_dir) + if _should_exclude_path(relative_path, ignore_patterns): + continue + + dest = staging_dir / item.name + if item.is_dir(): + shutil.copytree(item, dest, dirs_exist_ok=True, ignore=copy_ignore) + else: + shutil.copy2(item, dest) + + # Dockerfile must be at repo root for Hugging Face. Prefer root if present + # (it was copied there); otherwise move server/Dockerfile to root. 
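+    # For example (illustrative): a staged tree of {README.md, openenv.yaml,
+    # server/Dockerfile, ...} ends up with the Dockerfile at staging/Dockerfile,
+    # because Docker-SDK Spaces build from a root-level Dockerfile.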
+    dockerfile_server_path = staging_dir / "server" / "Dockerfile"
+    dockerfile_root_path = staging_dir / "Dockerfile"
+    dockerfile_path: Path | None = None
+
+    if dockerfile_root_path.exists():
+        dockerfile_path = dockerfile_root_path
+    elif dockerfile_server_path.exists():
+        dockerfile_server_path.rename(dockerfile_root_path)
+        console.print(
+            "[bold cyan]Moved Dockerfile to repository root for deployment[/bold cyan]"
+        )
+        dockerfile_path = dockerfile_root_path
+
+    # Modify Dockerfile to optionally enable web interface and update base image
+    if dockerfile_path and dockerfile_path.exists():
+        dockerfile_content = dockerfile_path.read_text()
+        lines = dockerfile_content.split("\n")
+        new_lines = []
+        cmd_found = False
+        base_image_updated = False
+        web_interface_env_exists = "ENABLE_WEB_INTERFACE" in dockerfile_content
+        last_instruction = None
+
+        for line in lines:
+            stripped = line.strip()
+            token = stripped.split(maxsplit=1)[0] if stripped else ""
+            current_instruction = token.upper()
+
+            is_healthcheck_continuation = last_instruction == "HEALTHCHECK"
+
+            # Update base image if specified. Rewrite every FROM line so all
+            # stages of a multi-stage build use the override, and preserve any
+            # trailing tokens such as "AS builder" so later
+            # "COPY --from=builder" steps keep working.
+            if base_image and current_instruction == "FROM":
+                parts = stripped.split(maxsplit=2)
+                suffix = f" {parts[2]}" if len(parts) > 2 else ""
+                new_lines.append(f"FROM {base_image}{suffix}")
+                base_image_updated = True
+                last_instruction = "FROM"
+                continue
+
+            if (
+                stripped.startswith("CMD")
+                and not cmd_found
+                and not web_interface_env_exists
+                and enable_interface
+                and not is_healthcheck_continuation
+            ):
+                new_lines.append("ENV ENABLE_WEB_INTERFACE=true")
+                cmd_found = True
+
+            new_lines.append(line)
+
+            if current_instruction:
+                last_instruction = current_instruction
+
+        if not cmd_found and not web_interface_env_exists and enable_interface:
+            new_lines.append("ENV ENABLE_WEB_INTERFACE=true")
+
+        if base_image and not base_image_updated:
+            new_lines.insert(0, f"FROM {base_image}")
+
+        dockerfile_path.write_text("\n".join(new_lines))
+
+        changes = []
+        if base_image and base_image_updated:
+            changes.append("updated base image")
+        if enable_interface and not web_interface_env_exists:
+            changes.append("enabled web interface")
+        if changes:
+            console.print(
+                f"[bold green]✓[/bold green] Updated Dockerfile: {', '.join(changes)}"
+            )
+    else:
+        console.print(
+            "[bold yellow]⚠[/bold yellow] No Dockerfile at server/ or repo root"
+        )
+
+    # Ensure README has proper HF frontmatter (only if interface enabled)
+    if enable_interface:
+        readme_path = staging_dir / "README.md"
+        if readme_path.exists():
+            readme_content = readme_path.read_text()
+            if "base_path: /web" not in readme_content:
+                # Check if frontmatter exists
+                if readme_content.startswith("---"):
+                    # Add base_path to existing frontmatter
+                    lines = readme_content.split("\n")
+                    new_lines = []
+                    _in_frontmatter = True
+                    for i, line in enumerate(lines):
+                        new_lines.append(line)
+                        if line.strip() == "---" and i > 0:
+                            # End of frontmatter, add base_path before this line
+                            if "base_path:" not in "\n".join(new_lines):
+                                new_lines.insert(-1, "base_path: /web")
+                            _in_frontmatter = False
+                    readme_path.write_text("\n".join(new_lines))
+                else:
+                    # No frontmatter, add it
+                    frontmatter = f"""---
+title: {env_name.replace("_", " ").title()} Environment Server
+emoji: 🔊
+colorFrom: '#00C9FF'
+colorTo: '#1B2845'
+sdk: docker
+pinned: false
+app_port: 8000
+base_path: /web
+tags:
+  - openenv
+---
+
+"""
+                    readme_path.write_text(frontmatter + readme_content)
+                console.print(
+                    "[bold green]✓[/bold green] Updated README with HF Space frontmatter"
+                )
+        else:
+            console.print("[bold yellow]⚠[/bold yellow] No README.md found")
+
+
+def 
_create_hf_space( + repo_id: str, + api: HfApi, + private: bool = False, +) -> None: + """Create a Hugging Face Space if it doesn't exist.""" + console.print(f"[bold cyan]Creating/verifying space: {repo_id}[/bold cyan]") + + try: + api.create_repo( + repo_id=repo_id, + repo_type="space", + space_sdk="docker", + private=private, + exist_ok=True, + ) + console.print(f"[bold green]✓[/bold green] Space {repo_id} is ready") + except Exception as e: + # Space might already exist, which is okay with exist_ok=True + # But if there's another error, log it + console.print(f"[bold yellow]⚠[/bold yellow] Space creation: {e}") + + +def _upload_to_hf_space( + repo_id: str, + staging_dir: Path, + api: HfApi, + ignore_patterns: list[str], + private: bool = False, + create_pr: bool = False, + commit_message: str | None = None, +) -> None: + """Upload files to Hugging Face Space.""" + if create_pr: + console.print( + f"[bold cyan]Uploading files to {repo_id} (will open a Pull Request)...[/bold cyan]" + ) + else: + console.print(f"[bold cyan]Uploading files to {repo_id}...[/bold cyan]") + + upload_kwargs: dict = { + "folder_path": str(staging_dir), + "repo_id": repo_id, + "repo_type": "space", + "create_pr": create_pr, + "ignore_patterns": ignore_patterns, + } + if commit_message: + upload_kwargs["commit_message"] = commit_message + + try: + result = api.upload_folder(**upload_kwargs) + console.print("[bold green]✓[/bold green] Upload completed successfully") + if create_pr and result is not None and hasattr(result, "pr_url"): + console.print(f"[bold]Pull request:[/bold] {result.pr_url}") + console.print( + f"[bold]Space URL:[/bold] https://huggingface.co/spaces/{repo_id}" + ) + except Exception as e: + console.print(f"[bold red]✗[/bold red] Upload failed: {e}") + raise typer.Exit(1) from e + + +@app.command() +def push( + directory: Annotated[ + str | None, + typer.Argument( + help="Directory containing the OpenEnv environment (default: current directory)" + ), + ] = None, + repo_id: Annotated[ + str | None, + typer.Option( + "--repo-id", + "-r", + help="Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)", + ), + ] = None, + base_image: Annotated[ + str | None, + typer.Option( + "--base-image", + "-b", + help="Base Docker image to use (overrides Dockerfile FROM)", + ), + ] = None, + interface: Annotated[ + bool, + typer.Option( + "--interface", + help="Enable web interface (default: True if no registry specified)", + ), + ] = None, + no_interface: Annotated[ + bool, + typer.Option( + "--no-interface", + help="Disable web interface", + ), + ] = False, + registry: Annotated[ + str | None, + typer.Option( + "--registry", + help="Custom registry URL (e.g., docker.io/username). Disables web interface by default.", + ), + ] = None, + private: Annotated[ + bool, + typer.Option( + "--private", + help="Deploy the space as private", + ), + ] = False, + create_pr: Annotated[ + bool, + typer.Option( + "--create-pr", + help="Create a Pull Request instead of pushing to the default branch", + ), + ] = False, + exclude: Annotated[ + str | None, + typer.Option( + "--exclude", + help="Optional additional ignore file with newline-separated glob patterns to exclude from Hugging Face uploads", + ), + ] = None, +) -> None: + """ + Push an OpenEnv environment to Hugging Face Spaces or a custom Docker registry. + + This command: + 1. Validates that the directory is an OpenEnv environment (openenv.yaml present) + 2. Builds and pushes to Hugging Face Spaces or custom Docker registry + 3. 
Optionally enables web interface for deployment
+
+    The web interface is enabled by default when pushing to HuggingFace Spaces,
+    but disabled by default when pushing to a custom Docker registry.
+
+    Examples:
+        # Push to HuggingFace Spaces from current directory (web interface enabled)
+        $ cd my_env
+        $ openenv push
+
+        # Push a specific directory to HuggingFace and open a Pull Request
+        $ openenv push ./my_env --create-pr
+        $ openenv push --repo-id my-org/my-env --create-pr
+
+        # Push to HuggingFace without web interface
+        $ openenv push --no-interface
+
+        # Push to Docker Hub
+        $ openenv push --registry docker.io/myuser
+
+        # Push to GitHub Container Registry
+        $ openenv push --registry ghcr.io/myorg
+
+        # Push to custom registry with web interface
+        $ openenv push --registry myregistry.io/path1/path2 --interface
+
+        # Push to specific HuggingFace repo
+        $ openenv push --repo-id my-org/my-env
+
+        # Push privately with custom base image
+        $ openenv push --private --base-image ghcr.io/meta-pytorch/openenv-base:latest
+    """
+    # Handle interface flag logic
+    if no_interface and interface:
+        console.print(
+            "[bold red]Error:[/bold red] Cannot specify both --interface and --no-interface"
+        )
+        raise typer.Exit(1)
+
+    # Determine if web interface should be enabled
+    if no_interface:
+        enable_interface = False
+    elif interface is not None:
+        enable_interface = interface
+    elif registry is not None:
+        # Custom registry: disable interface by default
+        enable_interface = False
+    else:
+        # HuggingFace: enable interface by default
+        enable_interface = True
+
+    # Determine directory
+    if directory:
+        env_dir = Path(directory).resolve()
+    else:
+        env_dir = Path.cwd().resolve()
+
+    if not env_dir.exists() or not env_dir.is_dir():
+        raise typer.BadParameter(f"Directory does not exist: {env_dir}")
+
+    # Check for openenv.yaml to confirm this is an environment directory
+    openenv_yaml = env_dir / "openenv.yaml"
+    if not openenv_yaml.exists():
+        console.print(
+            f"[bold red]Error:[/bold red] Not an OpenEnv environment directory (missing openenv.yaml): {env_dir}",
+        )
+        console.print(
+            "[yellow]Hint:[/yellow] Run this command from the environment root directory",
+        )
+        raise typer.Exit(1)
+
+    # Validate OpenEnv environment
+    console.print(
+        f"[bold cyan]Validating OpenEnv environment in {env_dir}...[/bold cyan]"
+    )
+    env_name, manifest = _validate_openenv_directory(env_dir)
+    console.print(f"[bold green]✓[/bold green] Found OpenEnv environment: {env_name}")
+
+    # Handle custom registry push
+    if registry:
+        console.print("[bold cyan]Preparing to push to custom registry...[/bold cyan]")
+        if enable_interface:
+            console.print("[bold cyan]Web interface will be enabled[/bold cyan]")
+
+        # Import build functions
+        from .build import _build_docker_image, _push_docker_image
+
+        # Prepare build args for custom registry deployment
+        build_args = {}
+        if enable_interface:
+            build_args["ENABLE_WEB_INTERFACE"] = "true"
+
+        # Build Docker image from the environment directory
+        tag = f"{registry}/{env_name}"
+        console.print(f"[bold cyan]Building Docker image: {tag}[/bold cyan]")
+
+        success = _build_docker_image(
+            env_path=env_dir,
+            tag=tag,
+            build_args=build_args if build_args else None,
+        )
+
+        if not success:
+            console.print("[bold red]✗ Docker build failed[/bold red]")
+            raise typer.Exit(1)
+
+        console.print("[bold green]✓ Docker build successful[/bold green]")
+
+        # Push to registry
+        console.print(f"[bold cyan]Pushing to registry: {registry}[/bold cyan]")
+
+        success = _push_docker_image(
+            tag, 
registry=None + ) # Tag already includes registry + + if not success: + console.print("[bold red]✗ Docker push failed[/bold red]") + raise typer.Exit(1) + + console.print("\n[bold green]✓ Deployment complete![/bold green]") + console.print(f"[bold]Image:[/bold] {tag}") + return + + ignore_patterns = _load_ignore_patterns(env_dir, exclude) + + # Ensure authentication for HuggingFace + username = _ensure_hf_authenticated() + + # Determine repo_id + if not repo_id: + repo_id = f"{username}/{env_name}" + + # Validate repo_id format + if "/" not in repo_id or repo_id.count("/") != 1: + raise typer.BadParameter( + f"Invalid repo-id format: {repo_id}. Expected format: 'username/repo-name'" + ) + + # Initialize Hugging Face API + api = HfApi() + + # Prepare staging directory + deployment_type = ( + "with web interface" if enable_interface else "without web interface" + ) + console.print( + f"[bold cyan]Preparing files for Hugging Face deployment ({deployment_type})...[/bold cyan]" + ) + with tempfile.TemporaryDirectory() as tmpdir: + staging_dir = Path(tmpdir) / "staging" + _prepare_staging_directory( + env_dir, + env_name, + staging_dir, + ignore_patterns=ignore_patterns, + base_image=base_image, + enable_interface=enable_interface, + ) + + # Create/verify space (no-op if exists; needed when pushing to own new repo) + if not create_pr: + _create_hf_space(repo_id, api, private=private) + # When create_pr we rely on upload_folder to create branch and PR + + # Upload files + _upload_to_hf_space( + repo_id, + staging_dir, + api, + private=private, + create_pr=create_pr, + ignore_patterns=ignore_patterns, + ) + + console.print("\n[bold green]✓ Deployment complete![/bold green]") + console.print(f"Visit your space at: https://huggingface.co/spaces/{repo_id}") diff --git a/src/core/openenv/cli/commands/serve.py b/src/core/openenv/cli/commands/serve.py new file mode 100644 index 0000000000000000000000000000000000000000..df2bfa5a34d83e07ea35e5df06e523d1c565cbc5 --- /dev/null +++ b/src/core/openenv/cli/commands/serve.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Serve OpenEnv environments locally (TO BE IMPLEMENTED).""" + +from __future__ import annotations + +from pathlib import Path +from typing import Annotated + +import typer + +from .._cli_utils import console + +app = typer.Typer(help="Serve OpenEnv environments locally") + + +@app.command() +def serve( + env_path: Annotated[ + str | None, + typer.Argument( + help="Path to the environment directory (default: current directory)" + ), + ] = None, + port: Annotated[ + int, + typer.Option("--port", "-p", help="Port to serve on"), + ] = 8000, + host: Annotated[ + str, + typer.Option("--host", help="Host to bind to"), + ] = "0.0.0.0", + reload: Annotated[ + bool, + typer.Option("--reload", help="Enable auto-reload on code changes"), + ] = False, +) -> None: + """ + Serve an OpenEnv environment locally. + + TODO: This command is currently not implemented and has been deferred for later. + + Planned functionality: + - Run environment server locally without Docker + - Support multiple deployment modes (local, notebook, cluster) + - Auto-reload for development + - Integration with environment's [project.scripts] entry point + + For now, use Docker-based serving: + 1. Build the environment: openenv build + 2. 
Run the container: docker run -p 8000:8000 <image>
+
+    Or use uv directly:
+        uv run --project . server --port 8000
+    """
+    console.print("[bold yellow]⚠ This command is not yet implemented[/bold yellow]\n")
+
+    console.print(
+        "The [bold cyan]openenv serve[/bold cyan] command has been deferred for later."
+    )
+
+    console.print("[bold]Alternative approaches:[/bold]\n")
+
+    console.print("[cyan]Option 1: Docker-based serving (recommended)[/cyan]")
+    console.print("  1. Build the environment:")
+    console.print("     [dim]$ openenv build[/dim]")
+    console.print("  2. Run the Docker container:")
+    console.print(
+        f"     [dim]$ docker run -p {port}:8000 openenv-<env-name>:latest[/dim]\n"
+    )
+
+    console.print("[cyan]Option 2: Direct execution with uv[/cyan]")
+
+    # Determine environment path
+    if env_path is None:
+        env_path_obj = Path.cwd()
+    else:
+        env_path_obj = Path(env_path)
+
+    # Check for openenv.yaml
+    openenv_yaml = env_path_obj / "openenv.yaml"
+    if openenv_yaml.exists():
+        console.print("  From your environment directory:")
+        console.print(f"  [dim]$ cd {env_path_obj}[/dim]")
+        console.print(f"  [dim]$ uv run --project . server --port {port}[/dim]\n")
+    else:
+        console.print("  From an environment directory with pyproject.toml:")
+        console.print(f"  [dim]$ uv run --project . server --port {port}[/dim]\n")
+
+    raise typer.Exit(0)
diff --git a/src/core/openenv/cli/commands/skills.py b/src/core/openenv/cli/commands/skills.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bb29db72e26a104e9eb75a6309fbc9ed39538eb
--- /dev/null
+++ b/src/core/openenv/cli/commands/skills.py
@@ -0,0 +1,200 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Commands to manage OpenEnv CLI skills for AI assistants."""
+
+from __future__ import annotations
+
+import os
+import shutil
+from pathlib import Path
+from typing import Annotated
+
+import typer
+
+DEFAULT_SKILL_ID = "openenv-cli"
+
+_SKILL_YAML_PREFIX = """\
+---
+name: openenv-cli
+description: "OpenEnv CLI (`openenv`) for scaffolding, validating, building, and pushing OpenEnv environments."
+---
+
+Install: `pip install openenv-core`
+
+The OpenEnv CLI command `openenv` is available.
+Use `openenv --help` to view available commands.
+"""
+
+_SKILL_TIPS = """
+## Tips
+
+- Start with `openenv init <env_name>` to scaffold a new environment
+- Validate projects with `openenv validate`
+- Build and deploy with `openenv build` and `openenv push`
+- Use `openenv --help` for command-specific options
+"""
+
+CENTRAL_LOCAL = Path(".agents/skills")
+CENTRAL_GLOBAL = Path("~/.agents/skills")
+
+GLOBAL_TARGETS = {
+    "codex": Path("~/.codex/skills"),
+    "claude": Path("~/.claude/skills"),
+    "cursor": Path("~/.cursor/skills"),
+    "opencode": Path("~/.config/opencode/skills"),
+}
+
+LOCAL_TARGETS = {
+    "codex": Path(".codex/skills"),
+    "claude": Path(".claude/skills"),
+    "cursor": Path(".cursor/skills"),
+    "opencode": Path(".opencode/skills"),
+}
+
+app = typer.Typer(help="Manage OpenEnv skills for AI assistants")
+
+
+def _build_skill_md() -> str:
+    """Generate SKILL.md content for the OpenEnv CLI skill."""
+    from openenv import __version__
+
+    lines = _SKILL_YAML_PREFIX.splitlines()
+    lines.append("")
+    lines.append(
+        f"Generated with `openenv-core v{__version__}`. Run `openenv skills add --force` to regenerate."
+ ) + lines.extend(_SKILL_TIPS.splitlines()) + return "\n".join(lines).strip() + "\n" + + +def _remove_existing(path: Path, force: bool) -> None: + """Remove existing file/directory/symlink if force is True, else fail.""" + if not (path.exists() or path.is_symlink()): + return + if not force: + raise typer.Exit(code=1) + + if path.is_dir() and not path.is_symlink(): + shutil.rmtree(path) + else: + path.unlink() + + +def _install_to(skills_dir: Path, force: bool) -> Path: + """Install the OpenEnv skill in a skills directory.""" + skills_dir = skills_dir.expanduser().resolve() + skills_dir.mkdir(parents=True, exist_ok=True) + dest = skills_dir / DEFAULT_SKILL_ID + + if dest.exists() or dest.is_symlink(): + if not force: + typer.echo( + f"Skill already exists at {dest}. Re-run with --force to overwrite." + ) + raise typer.Exit(code=1) + _remove_existing(dest, force=True) + + dest.mkdir() + (dest / "SKILL.md").write_text(_build_skill_md(), encoding="utf-8") + return dest + + +def _create_symlink( + agent_skills_dir: Path, central_skill_path: Path, force: bool +) -> Path: + """Create a relative symlink from agent directory to central skill location.""" + agent_skills_dir = agent_skills_dir.expanduser().resolve() + agent_skills_dir.mkdir(parents=True, exist_ok=True) + link_path = agent_skills_dir / DEFAULT_SKILL_ID + + if link_path.exists() or link_path.is_symlink(): + if not force: + typer.echo( + f"Skill already exists at {link_path}. Re-run with --force to overwrite." + ) + raise typer.Exit(code=1) + _remove_existing(link_path, force=True) + + link_path.symlink_to(os.path.relpath(central_skill_path, agent_skills_dir)) + return link_path + + +@app.command("preview") +def skills_preview() -> None: + """Print generated SKILL.md content.""" + typer.echo(_build_skill_md()) + + +@app.command("add") +def skills_add( + claude: Annotated[ + bool, + typer.Option("--claude", help="Install for Claude."), + ] = False, + codex: Annotated[ + bool, + typer.Option("--codex", help="Install for Codex."), + ] = False, + cursor: Annotated[ + bool, + typer.Option("--cursor", help="Install for Cursor."), + ] = False, + opencode: Annotated[ + bool, + typer.Option("--opencode", help="Install for OpenCode."), + ] = False, + global_: Annotated[ + bool, + typer.Option( + "--global", + "-g", + help=( + "Install globally (user-level) instead of in the current project directory." + ), + ), + ] = False, + dest: Annotated[ + Path | None, + typer.Option(help="Install into a custom destination (skills directory path)."), + ] = None, + force: Annotated[ + bool, + typer.Option("--force", help="Overwrite existing skills in the destination."), + ] = False, +) -> None: + """Install OpenEnv CLI skill for AI assistants.""" + if dest: + if claude or codex or cursor or opencode or global_: + typer.echo( + "--dest cannot be combined with --claude, --codex, --cursor, --opencode, or --global." 
+ ) + raise typer.Exit(code=1) + skill_dest = _install_to(dest, force) + typer.echo(f"Installed '{DEFAULT_SKILL_ID}' to {skill_dest}") + return + + central_path = CENTRAL_GLOBAL if global_ else CENTRAL_LOCAL + central_skill_path = _install_to(central_path, force) + typer.echo( + f"Installed '{DEFAULT_SKILL_ID}' to central location: {central_skill_path}" + ) + + targets = GLOBAL_TARGETS if global_ else LOCAL_TARGETS + agent_targets: list[Path] = [] + + if claude: + agent_targets.append(targets["claude"]) + if codex: + agent_targets.append(targets["codex"]) + if cursor: + agent_targets.append(targets["cursor"]) + if opencode: + agent_targets.append(targets["opencode"]) + + for agent_target in agent_targets: + link_path = _create_symlink(agent_target, central_skill_path, force) + typer.echo(f"Created symlink: {link_path}") diff --git a/src/core/openenv/cli/commands/validate.py b/src/core/openenv/cli/commands/validate.py new file mode 100644 index 0000000000000000000000000000000000000000..32abcc17e11f8cf22e08d04e570383f57ccc1199 --- /dev/null +++ b/src/core/openenv/cli/commands/validate.py @@ -0,0 +1,198 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +OpenEnv validate command. + +This module provides the 'openenv validate' command to check if environments +are properly configured for multi-mode deployment. +""" + +import json +from pathlib import Path +from typing import Annotated + +import typer +from openenv.cli._validation import ( + build_local_validation_json_report, + format_validation_report, + get_deployment_modes, + validate_multi_mode_deployment, + validate_running_environment, +) + + +def _looks_like_url(value: str) -> bool: + """Return True when the value appears to be a URL target.""" + candidate = value.strip().lower() + return candidate.startswith("http://") or candidate.startswith("https://") + + +def validate( + target: Annotated[ + str | None, + typer.Argument( + help=( + "Path to the environment directory (default: current directory) " + "or a running OpenEnv URL (http://... or https://...)" + ), + ), + ] = None, + url: Annotated[ + str | None, + typer.Option( + "--url", + help="Validate a running OpenEnv server by base URL (e.g. http://localhost:8000)", + ), + ] = None, + json_output: Annotated[ + bool, + typer.Option( + "--json", + help="Output local validation report as JSON (runtime validation is JSON by default)", + ), + ] = False, + timeout: Annotated[ + float, + typer.Option( + "--timeout", + help="HTTP timeout in seconds for runtime validation", + min=0.1, + ), + ] = 5.0, + verbose: Annotated[ + bool, typer.Option("--verbose", "-v", help="Show detailed information") + ] = False, +) -> None: + """ + Validate local environments and running OpenEnv servers. + + Local validation checks if an environment is properly configured with: + - Required files (pyproject.toml, openenv.yaml, server/app.py, etc.) + - Docker deployment support + - uv run server capability + - python -m module execution + + Runtime validation checks if a live OpenEnv server conforms to the + versioned runtime API contract and returns a criteria-based JSON report. 
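+
+    In both modes the command exits with a non-zero status when validation
+    fails, so it can be used directly as a CI gate.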
+ + Examples: + # Validate current directory (recommended) + $ cd my_env + $ openenv validate + + # Validate a running environment and return JSON criteria + $ openenv validate --url http://localhost:8000 + $ openenv validate https://my-env.hf.space + + # Validate with detailed output + $ openenv validate --verbose + + # Validate specific environment + $ openenv validate envs/echo_env + """ + runtime_target = url + if ( + runtime_target is not None + and target is not None + and not _looks_like_url(target) + ): + typer.echo( + "Error: Cannot combine a local path argument with --url runtime validation", + err=True, + ) + raise typer.Exit(1) + + if target is not None and _looks_like_url(target): + if runtime_target is not None and runtime_target != target: + typer.echo( + "Error: Conflicting runtime targets provided via argument and --url", + err=True, + ) + raise typer.Exit(1) + runtime_target = target + + if runtime_target is not None: + try: + report = validate_running_environment(runtime_target, timeout_s=timeout) + except ValueError as exc: + typer.echo(f"Error: {exc}", err=True) + raise typer.Exit(1) from exc + + typer.echo(json.dumps(report, indent=2)) + if not report.get("passed", False): + raise typer.Exit(1) + return + + # Determine environment path (default to current directory) + if target is None: + env_path_obj = Path.cwd() + else: + env_path_obj = Path(target) + + if not env_path_obj.exists(): + typer.echo(f"Error: Path does not exist: {env_path_obj}", err=True) + raise typer.Exit(1) + + if not env_path_obj.is_dir(): + typer.echo(f"Error: Path is not a directory: {env_path_obj}", err=True) + raise typer.Exit(1) + + # Check for openenv.yaml to confirm this is an environment directory + openenv_yaml = env_path_obj / "openenv.yaml" + if not openenv_yaml.exists(): + typer.echo( + f"Error: Not an OpenEnv environment directory (missing openenv.yaml): {env_path_obj}", + err=True, + ) + typer.echo( + "Hint: Run this command from the environment root directory or specify the path", + err=True, + ) + raise typer.Exit(1) + + env_name = env_path_obj.name + if env_name.endswith("_env"): + base_name = env_name[:-4] + else: + base_name = env_name + + # Run validation + is_valid, issues = validate_multi_mode_deployment(env_path_obj) + modes = get_deployment_modes(env_path_obj) + + if json_output: + report = build_local_validation_json_report( + env_name=base_name, + env_path=env_path_obj, + is_valid=is_valid, + issues=issues, + deployment_modes=modes if verbose else None, + ) + typer.echo(json.dumps(report, indent=2)) + if not is_valid: + raise typer.Exit(1) + return + + # Show validation report + report = format_validation_report(base_name, is_valid, issues) + typer.echo(report) + + # Show deployment modes if verbose + if verbose: + typer.echo("\nSupported deployment modes:") + for mode, supported in modes.items(): + status = "[YES]" if supported else "[NO]" + typer.echo(f" {status} {mode}") + + if is_valid: + typer.echo("\nUsage examples:") + typer.echo(f" cd {env_path_obj.name} && uv run server") + typer.echo(f" cd {env_path_obj.name} && openenv build") + typer.echo(f" cd {env_path_obj.name} && openenv push") + + if not is_valid: + raise typer.Exit(1) diff --git a/src/core/openenv/cli/templates/__init__.py b/src/core/openenv/cli/templates/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..452e81a7b8584c3447c6f83fc9560f6f9d334ced --- /dev/null +++ b/src/core/openenv/cli/templates/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""OpenEnv CLI templates package."""
diff --git a/src/core/openenv/cli/templates/openenv_env/.dockerignore b/src/core/openenv/cli/templates/openenv_env/.dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..fc288e5de90f4988be5e0ef73d17b2314786406f
--- /dev/null
+++ b/src/core/openenv/cli/templates/openenv_env/.dockerignore
@@ -0,0 +1,11 @@
+.venv
+.git
+.gitignore
+.env
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+*.pyw
+*.pyz
+*.pyzw
diff --git a/src/core/openenv/cli/templates/openenv_env/README.md b/src/core/openenv/cli/templates/openenv_env/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3f14526a0ce173408073358a6b94d15c85c9aa97
--- /dev/null
+++ b/src/core/openenv/cli/templates/openenv_env/README.md
@@ -0,0 +1,255 @@
+---
+title: __ENV_TITLE_NAME__ Environment Server
+emoji: __HF_EMOJI__
+colorFrom: __HF_COLOR_FROM__
+colorTo: __HF_COLOR_TO__
+sdk: docker
+pinned: false
+app_port: 8000
+base_path: /web
+tags:
+  - openenv
+---
+
+# __ENV_TITLE_NAME__ Environment
+
+A simple test environment that echoes back messages. Perfect for testing the env APIs as well as demonstrating environment usage patterns.
+
+## Quick Start
+
+The simplest way to use the __ENV_TITLE_NAME__ environment is through the `__ENV_CLASS_NAME__Env` class:
+
+```python
+from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env
+
+try:
+    # Create environment from Docker image
+    __ENV_NAME__env = __ENV_CLASS_NAME__Env.from_docker_image("__ENV_NAME__-env:latest")
+
+    # Reset
+    result = __ENV_NAME__env.reset()
+    print(f"Reset: {result.observation.echoed_message}")
+
+    # Send multiple messages
+    messages = ["Hello, World!", "Testing echo", "Final message"]
+
+    for msg in messages:
+        result = __ENV_NAME__env.step(__ENV_CLASS_NAME__Action(message=msg))
+        print(f"Sent: '{msg}'")
+        print(f"  → Echoed: '{result.observation.echoed_message}'")
+        print(f"  → Length: {result.observation.message_length}")
+        print(f"  → Reward: {result.reward}")
+
+finally:
+    # Always clean up
+    __ENV_NAME__env.close()
+```
+
+That's it! The `__ENV_CLASS_NAME__Env.from_docker_image()` method handles:
+- Starting the Docker container
+- Waiting for the server to be ready
+- Connecting to the environment
+- Container cleanup when you call `close()`
+
+## Building the Docker Image
+
+Before using the environment, you need to build the Docker image:
+
+```bash
+# From project root
+docker build -t __ENV_NAME__-env:latest -f server/Dockerfile .
+```
+
+## Deploying to Hugging Face Spaces
+
+You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
+
+```bash
+# From the environment directory (where openenv.yaml is located)
+openenv push
+
+# Or specify options
+openenv push --repo-id my-org/my-env --private
+```
+
+The `openenv push` command will:
+1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
+2. Prepare a custom build for Hugging Face Docker space (enables web interface)
+3. 
Upload to Hugging Face (ensuring you're logged in)
+
+### Prerequisites
+
+- Authenticate with Hugging Face: The command will prompt for login if not already authenticated
+
+### Options
+
+- `DIRECTORY` (positional argument): Directory containing the OpenEnv environment (defaults to current directory)
+- `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
+- `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
+- `--private`: Deploy the space as private (default: public)
+
+### Examples
+
+```bash
+# Push to your personal namespace (defaults to username/env-name from openenv.yaml)
+openenv push
+
+# Push to a specific repository
+openenv push --repo-id my-org/my-env
+
+# Push with a custom base image
+openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
+
+# Push as a private space
+openenv push --private
+
+# Combine options
+openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
+```
+
+After deployment, your space will be available at:
+`https://huggingface.co/spaces/<repo-id>`
+
+The deployed space includes:
+- **Web Interface** at `/web` - Interactive UI for exploring the environment
+- **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
+- **Health Check** at `/health` - Container health monitoring
+- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
+
+## Environment Details
+
+### Action
+**__ENV_CLASS_NAME__Action**: Contains a single field
+- `message` (str) - The message to echo back
+
+### Observation
+**__ENV_CLASS_NAME__Observation**: Contains the echo response and metadata
+- `echoed_message` (str) - The message echoed back
+- `message_length` (int) - Length of the message
+- `reward` (float) - Reward based on message length (length × 0.1)
+- `done` (bool) - Always False for echo environment
+- `metadata` (dict) - Additional info like step count
+
+### Reward
+The reward is calculated as: `message_length × 0.1`
+- "Hi" → reward: 0.2
+- "Hello, World!" → reward: 1.3
+- Empty message → reward: 0.0
+
+## Advanced Usage
+
+### Connecting to an Existing Server
+
+If you already have a __ENV_TITLE_NAME__ environment server running, you can connect directly:
+
+```python
+from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env
+
+# Connect to existing server
+__ENV_NAME__env = __ENV_CLASS_NAME__Env(base_url="<server-url>")
+
+# Use as normal
+result = __ENV_NAME__env.reset()
+result = __ENV_NAME__env.step(__ENV_CLASS_NAME__Action(message="Hello!"))
+```
+
+Note: When connecting to an existing server, `__ENV_NAME__env.close()` will NOT stop the server.
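+
+For example (a sketch, not part of the generated project), you can run the server
+directly with uvicorn in one terminal and connect from another:
+
+```bash
+# Terminal 1: start the server (assumes dependencies are installed)
+uvicorn server.app:app --host 0.0.0.0 --port 8000
+```
+
+```python
+# Terminal 2: connect, interact, and disconnect; the server keeps running
+from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env
+
+env = __ENV_CLASS_NAME__Env(base_url="http://localhost:8000")
+result = env.reset()
+result = env.step(__ENV_CLASS_NAME__Action(message="ping"))
+print(result.observation.echoed_message)  # -> "ping"
+env.close()  # closes the client connection only
+```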
+ +### Using the Context Manager + +The client supports context manager usage for automatic connection management: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env + +# Connect with context manager (auto-connects and closes) +with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as env: + result = env.reset() + print(f"Reset: {result.observation.echoed_message}") + # Multiple steps with low latency + for msg in ["Hello", "World", "!"]: + result = env.step(__ENV_CLASS_NAME__Action(message=msg)) + print(f"Echoed: {result.observation.echoed_message}") +``` + +The client uses WebSocket connections for: +- **Lower latency**: No HTTP connection overhead per request +- **Persistent session**: Server maintains your environment state +- **Efficient for episodes**: Better for many sequential steps + +### Concurrent WebSocket Sessions + +The server supports multiple concurrent WebSocket connections. To enable this, +modify `server/app.py` to use factory mode: + +```python +# In server/app.py - use factory mode for concurrent sessions +app = create_app( + __ENV_CLASS_NAME__Environment, # Pass class, not instance + __ENV_CLASS_NAME__Action, + __ENV_CLASS_NAME__Observation, + max_concurrent_envs=4, # Allow 4 concurrent sessions +) +``` + +Then multiple clients can connect simultaneously: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env +from concurrent.futures import ThreadPoolExecutor + +def run_episode(client_id: int): + with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as env: + result = env.reset() + for i in range(10): + result = env.step(__ENV_CLASS_NAME__Action(message=f"Client {client_id}, step {i}")) + return client_id, result.observation.message_length + +# Run 4 episodes concurrently +with ThreadPoolExecutor(max_workers=4) as executor: + results = list(executor.map(run_episode, range(4))) +``` + +## Development & Testing + +### Direct Environment Testing + +Test the environment logic directly without starting the HTTP server: + +```bash +# From the server directory +python3 server/__ENV_NAME___environment.py +``` + +This verifies that: +- Environment resets correctly +- Step executes actions properly +- State tracking works +- Rewards are calculated correctly + +### Running Locally + +Run the server locally for development: + +```bash +uvicorn server.app:app --reload +``` + +## Project Structure + +``` +__ENV_NAME__/ +├── .dockerignore # Docker build exclusions +├── __init__.py # Module exports +├── README.md # This file +├── openenv.yaml # OpenEnv manifest +├── pyproject.toml # Project metadata and dependencies +├── uv.lock # Locked dependencies (generated) +├── client.py # __ENV_CLASS_NAME__Env client +├── models.py # Action and Observation models +└── server/ + ├── __init__.py # Server module exports + ├── __ENV_NAME___environment.py # Core environment logic + ├── app.py # FastAPI application (HTTP + WebSocket endpoints) + └── Dockerfile # Container image definition +``` diff --git a/src/core/openenv/cli/templates/openenv_env/__init__.py b/src/core/openenv/cli/templates/openenv_env/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe07a082faf989d3ae22ece407c34364b394128 --- /dev/null +++ b/src/core/openenv/cli/templates/openenv_env/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""__ENV_TITLE_NAME__ Environment.""" + +from .client import __ENV_CLASS_NAME__Env +from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation + +__all__ = [ + "__ENV_CLASS_NAME__Action", + "__ENV_CLASS_NAME__Observation", + "__ENV_CLASS_NAME__Env", +] diff --git a/src/core/openenv/cli/templates/openenv_env/client.py b/src/core/openenv/cli/templates/openenv_env/client.py new file mode 100644 index 0000000000000000000000000000000000000000..720090431300aad0866c8a737f84a48a3df238b3 --- /dev/null +++ b/src/core/openenv/cli/templates/openenv_env/client.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""__ENV_TITLE_NAME__ Environment Client.""" + +from typing import Dict + +from openenv.core import EnvClient +from openenv.core.client_types import StepResult +from openenv.core.env_server.types import State + +from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation + + +class __ENV_CLASS_NAME__Env( + EnvClient[__ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation, State] +): + """ + Client for the __ENV_TITLE_NAME__ Environment. + + This client maintains a persistent WebSocket connection to the environment server, + enabling efficient multi-step interactions with lower latency. + Each client instance has its own dedicated environment session on the server. + + Example: + >>> # Connect to a running server + >>> with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.echoed_message) + ... + ... result = client.step(__ENV_CLASS_NAME__Action(message="Hello!")) + ... print(result.observation.echoed_message) + + Example with Docker: + >>> # Automatically start container and connect + >>> client = __ENV_CLASS_NAME__Env.from_docker_image("__ENV_NAME__-env:latest") + >>> try: + ... result = client.reset() + ... result = client.step(__ENV_CLASS_NAME__Action(message="Test")) + ... finally: + ... client.close() + """ + + def _step_payload(self, action: __ENV_CLASS_NAME__Action) -> Dict: + """ + Convert __ENV_CLASS_NAME__Action to JSON payload for step message. + + Args: + action: __ENV_CLASS_NAME__Action instance + + Returns: + Dictionary representation suitable for JSON encoding + """ + return { + "message": action.message, + } + + def _parse_result(self, payload: Dict) -> StepResult[__ENV_CLASS_NAME__Observation]: + """ + Parse server response into StepResult[__ENV_CLASS_NAME__Observation]. + + Args: + payload: JSON response data from server + + Returns: + StepResult with __ENV_CLASS_NAME__Observation + """ + obs_data = payload.get("observation", {}) + observation = __ENV_CLASS_NAME__Observation( + echoed_message=obs_data.get("echoed_message", ""), + message_length=obs_data.get("message_length", 0), + done=payload.get("done", False), + reward=payload.get("reward"), + metadata=obs_data.get("metadata", {}), + ) + + return StepResult( + observation=observation, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict) -> State: + """ + Parse server response into State object. 
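+
+        Example payload (illustrative):
+            {"episode_id": "a1b2c3", "step_count": 4}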
+ + Args: + payload: JSON response from state request + + Returns: + State object with episode_id and step_count + """ + return State( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + ) diff --git a/src/core/openenv/cli/templates/openenv_env/models.py b/src/core/openenv/cli/templates/openenv_env/models.py new file mode 100644 index 0000000000000000000000000000000000000000..5aea7f452a043602375620c48e65f0915ebf7f42 --- /dev/null +++ b/src/core/openenv/cli/templates/openenv_env/models.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Data models for the __ENV_TITLE_NAME__ Environment. + +The __ENV_NAME__ environment is a simple test environment that echoes back messages. +""" + +from openenv.core.env_server.types import Action, Observation +from pydantic import Field + + +class __ENV_CLASS_NAME__Action(Action): + """Action for the __ENV_TITLE_NAME__ environment - just a message to echo.""" + + message: str = Field(..., description="Message to echo back") + + +class __ENV_CLASS_NAME__Observation(Observation): + """Observation from the __ENV_TITLE_NAME__ environment - the echoed message.""" + + echoed_message: str = Field(default="", description="The echoed message") + message_length: int = Field(default=0, description="Length of the echoed message") diff --git a/src/core/openenv/cli/templates/openenv_env/openenv.yaml b/src/core/openenv/cli/templates/openenv_env/openenv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..828cc53b2b61c37bf6f860f25cbe2881825e3fd3 --- /dev/null +++ b/src/core/openenv/cli/templates/openenv_env/openenv.yaml @@ -0,0 +1,7 @@ +spec_version: 1 +name: __ENV_NAME__ +type: space +runtime: fastapi +app: server.app:app +port: 8000 + diff --git a/src/core/openenv/cli/templates/openenv_env/pyproject.toml b/src/core/openenv/cli/templates/openenv_env/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..a8e59fbfa3dbc8a0df7c84d479e79cef062d8e61 --- /dev/null +++ b/src/core/openenv/cli/templates/openenv_env/pyproject.toml @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +[build-system] +requires = ["setuptools>=45", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "openenv-__ENV_NAME__" +version = "0.1.0" +description = "__ENV_TITLE_NAME__ environment for OpenEnv" +requires-python = ">=3.10" +dependencies = [ + # Core OpenEnv runtime (provides FastAPI server + HTTP client types) + # install from github + # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git", + "openenv-core[core]>=0.2.1", + # Environment-specific dependencies + # Add all dependencies needed for your environment here + # Examples: + # "numpy>=1.19.0", + # "torch>=2.0.0", + # "gymnasium>=0.29.0", + # "openspiel>=1.0.0", + # "smolagents>=1.22.0,<2", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-cov>=4.0.0", +] + +[project.scripts] +# Server entry point - enables running via: uv run --project . 
server +# or: python -m __ENV_NAME__.server.app +server = "__ENV_NAME__.server.app:main" + +[tool.setuptools] +include-package-data = true +packages = ["__ENV_NAME__", "__ENV_NAME__.server"] +package-dir = { "__ENV_NAME__" = ".", "__ENV_NAME__.server" = "server" } \ No newline at end of file diff --git a/src/core/openenv/cli/templates/openenv_env/server/Dockerfile b/src/core/openenv/cli/templates/openenv_env/server/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..3d10ac76bf7e199e26fb77921f88d98f96120368 --- /dev/null +++ b/src/core/openenv/cli/templates/openenv_env/server/Dockerfile @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Multi-stage build using openenv-base +# This Dockerfile is flexible and works for both: +# - In-repo environments (with local OpenEnv sources) +# - Standalone environments (with openenv from PyPI/Git) +# The build script (openenv build) handles context detection and sets appropriate build args. + +ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest +FROM ${BASE_IMAGE} AS builder + +WORKDIR /app + +# Ensure git is available (required for installing dependencies from VCS) +RUN apt-get update && \ + apt-get install -y --no-install-recommends git && \ + rm -rf /var/lib/apt/lists/* + +# Build argument to control whether we're building standalone or in-repo +ARG BUILD_MODE=in-repo +ARG ENV_NAME=__ENV_NAME__ + +# Copy environment code (always at root of build context) +COPY . /app/env + +# For in-repo builds, openenv is already vendored in the build context +# For standalone builds, openenv will be installed via pyproject.toml +WORKDIR /app/env + +# Ensure uv is available (for local builds where base image lacks it) +RUN if ! 
command -v uv >/dev/null 2>&1; then \ + curl -LsSf https://astral.sh/uv/install.sh | sh && \ + mv /root/.local/bin/uv /usr/local/bin/uv && \ + mv /root/.local/bin/uvx /usr/local/bin/uvx; \ + fi + +# Install dependencies using uv sync +# If uv.lock exists, use it; otherwise resolve on the fly +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ -f uv.lock ]; then \ + uv sync --frozen --no-install-project --no-editable; \ + else \ + uv sync --no-install-project --no-editable; \ + fi + +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ -f uv.lock ]; then \ + uv sync --frozen --no-editable; \ + else \ + uv sync --no-editable; \ + fi + +# Final runtime stage +FROM ${BASE_IMAGE} + +WORKDIR /app + +# Copy the virtual environment from builder +COPY --from=builder /app/env/.venv /app/.venv + +# Copy the environment code +COPY --from=builder /app/env /app/env + +# Set PATH to use the virtual environment +ENV PATH="/app/.venv/bin:$PATH" + +# Set PYTHONPATH so imports work correctly +ENV PYTHONPATH="/app/env:$PYTHONPATH" + +# Health check +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Run the FastAPI server +# The module path is constructed to work with the /app/env structure +CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"] diff --git a/src/core/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py b/src/core/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py new file mode 100644 index 0000000000000000000000000000000000000000..bbde58219abbb880e79662bde49c6adab96f77eb --- /dev/null +++ b/src/core/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +__ENV_TITLE_NAME__ Environment Implementation. + +A simple test environment that echoes back messages sent to it. +Perfect for testing HTTP server infrastructure. +""" + +from uuid import uuid4 + +from openenv.core.env_server.interfaces import Environment +from openenv.core.env_server.types import State + +try: + from ..models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation +except ImportError: + from models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation + + +class __ENV_CLASS_NAME__Environment(Environment): + """ + A simple echo environment that echoes back messages. + + This environment is designed for testing the HTTP server infrastructure. + It maintains minimal state and simply echoes back whatever message it receives. + + Example: + >>> env = __ENV_CLASS_NAME__Environment() + >>> obs = env.reset() + >>> print(obs.echoed_message) # "__ENV_TITLE_NAME__ environment ready!" + >>> + >>> obs = env.step(__ENV_CLASS_NAME__Action(message="Hello")) + >>> print(obs.echoed_message) # "Hello" + >>> print(obs.message_length) # 5 + """ + + # Enable concurrent WebSocket sessions. + # Set to True if your environment isolates state between instances. + # When True, multiple WebSocket clients can connect simultaneously, each + # getting their own environment instance (when using factory mode in app.py). 
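+    # This echo template keeps all of its state on the instance (see __init__),
+    # so enabling concurrent sessions is safe here.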
+ SUPPORTS_CONCURRENT_SESSIONS: bool = True + + def __init__(self): + """Initialize the __ENV_NAME__ environment.""" + self._state = State(episode_id=str(uuid4()), step_count=0) + self._reset_count = 0 + + def reset(self) -> __ENV_CLASS_NAME__Observation: + """ + Reset the environment. + + Returns: + __ENV_CLASS_NAME__Observation with a ready message + """ + self._state = State(episode_id=str(uuid4()), step_count=0) + self._reset_count += 1 + + return __ENV_CLASS_NAME__Observation( + echoed_message="__ENV_TITLE_NAME__ environment ready!", + message_length=0, + done=False, + reward=0.0, + ) + + def step(self, action: __ENV_CLASS_NAME__Action) -> __ENV_CLASS_NAME__Observation: # type: ignore[override] + """ + Execute a step in the environment by echoing the message. + + Args: + action: __ENV_CLASS_NAME__Action containing the message to echo + + Returns: + __ENV_CLASS_NAME__Observation with the echoed message and its length + """ + self._state.step_count += 1 + + message = action.message + length = len(message) + + # Simple reward: longer messages get higher rewards + reward = length * 0.1 + + return __ENV_CLASS_NAME__Observation( + echoed_message=message, + message_length=length, + done=False, + reward=reward, + metadata={"original_message": message, "step": self._state.step_count}, + ) + + @property + def state(self) -> State: + """ + Get the current environment state. + + Returns: + Current State with episode_id and step_count + """ + return self._state diff --git a/src/core/openenv/cli/templates/openenv_env/server/__init__.py b/src/core/openenv/cli/templates/openenv_env/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..191fb655582f1cc13943574814ed4b39b5d60d7c --- /dev/null +++ b/src/core/openenv/cli/templates/openenv_env/server/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""__ENV_TITLE_NAME__ environment server components.""" + +from .__ENV_NAME___environment import __ENV_CLASS_NAME__Environment + +__all__ = ["__ENV_CLASS_NAME__Environment"] diff --git a/src/core/openenv/cli/templates/openenv_env/server/app.py b/src/core/openenv/cli/templates/openenv_env/server/app.py new file mode 100644 index 0000000000000000000000000000000000000000..898911a2a55495426d20b438c4de009ec103ccdd --- /dev/null +++ b/src/core/openenv/cli/templates/openenv_env/server/app.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +FastAPI application for the __ENV_TITLE_NAME__ Environment. + +This module creates an HTTP server that exposes the __ENV_CLASS_NAME__Environment +over HTTP and WebSocket endpoints, compatible with EnvClient. 
+
+Endpoints:
+    - POST /reset: Reset the environment
+    - POST /step: Execute an action
+    - GET /state: Get current environment state
+    - GET /schema: Get action/observation schemas
+    - WS /ws: WebSocket endpoint for persistent sessions
+
+Usage:
+    # Development (with auto-reload):
+    uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
+
+    # Production:
+    uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
+
+    # Or run directly:
+    python -m server.app
+"""
+
+try:
+    from openenv.core.env_server.http_server import create_app
+except Exception as e:  # pragma: no cover
+    raise ImportError(
+        "openenv is required for the web interface. Install dependencies with 'uv sync'."
+    ) from e
+
+try:
+    from ..models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation
+    from .__ENV_NAME___environment import __ENV_CLASS_NAME__Environment
+except ModuleNotFoundError:
+    from models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation
+    from server.__ENV_NAME___environment import __ENV_CLASS_NAME__Environment
+
+
+# Create the app with web interface and README integration
+app = create_app(
+    __ENV_CLASS_NAME__Environment,
+    __ENV_CLASS_NAME__Action,
+    __ENV_CLASS_NAME__Observation,
+    env_name="__ENV_NAME__",
+    max_concurrent_envs=1,  # increase this number to allow more concurrent WebSocket sessions
+)
+
+
+def main(host: str = "0.0.0.0", port: int = 8000):
+    """
+    Entry point for direct execution via uv run or python -m.
+
+    This function enables running the server without Docker:
+        uv run --project . server
+        uv run --project . server --port 8001
+        python -m __ENV_NAME__.server.app
+
+    Args:
+        host: Host address to bind to (default: "0.0.0.0")
+        port: Port number to listen on (default: 8000)
+
+    For production deployments, consider using uvicorn directly with
+    multiple workers:
+        uvicorn __ENV_NAME__.server.app:app --workers 4
+    """
+    import uvicorn
+
+    uvicorn.run(app, host=host, port=port)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--port", type=int, default=8000)
+    args = parser.parse_args()
+    main(port=args.port)
diff --git a/src/core/openenv/cli/templates/openenv_env/server/requirements.txt b/src/core/openenv/cli/templates/openenv_env/server/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..65b1c22b3db715ed9d63b9ad06cd4afb0d9412c5
--- /dev/null
+++ b/src/core/openenv/cli/templates/openenv_env/server/requirements.txt
@@ -0,0 +1,6 @@
+openenv[core]>=0.2.0
+fastapi>=0.115.0
+uvicorn>=0.24.0
+
+
+
diff --git a/src/core/openenv/core/README.md b/src/core/openenv/core/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5d153f1e4f72ce6c7b4e814c78c74e0e734c462b
--- /dev/null
+++ b/src/core/openenv/core/README.md
@@ -0,0 +1,212 @@
+# OpenEnv: Agentic Execution Environments
+
+An end-to-end framework for creating, deploying, and using isolated execution environments for agentic RL training. OpenEnv provides a standard for interacting with these environments through simple Gymnasium-style APIs - step(), reset(), state() - which slot directly into RL training loops.
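+
+For example, a training loop only ever touches those three calls. The sketch below assumes a hypothetical `EchoEnv` client and `EchoAction` action type built on `openenv.core.EnvClient` (the Quick Start below shows how to define such a client):
+
+```python
+import asyncio
+
+from envs.echo_env import EchoEnv, EchoAction  # hypothetical environment package
+
+
+async def rollout(num_steps: int = 3) -> None:
+    # Connect to a running environment server; the client is async by default.
+    async with EchoEnv(base_url="http://localhost:8000") as env:
+        result = await env.reset()  # start a fresh episode
+        for _ in range(num_steps):
+            result = await env.step(EchoAction(message="hello"))
+            print(result.observation.echoed_message, result.reward, result.done)
+        print(await env.state())  # episode metadata, e.g. episode_id and step_count
+
+
+asyncio.run(rollout())
+```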
+
+In addition to serving researchers and RL framework writers, OpenEnv gives environment creators tools for building richer environments, exposing them over familiar protocols like HTTP, and packaging them with canonical technologies like Docker. Environments created with the framework are isolated, secure, and easy to deploy and use.
+
+
+## Overview
+`openenv.core` provides the foundational building blocks for creating and interacting with containerized environments over HTTP. It enables you to build agent environments that can be deployed as Docker containers and accessed via a simple HTTP API.
+
+> ⚠️ **Early Development Warning** OpenEnv is currently in an experimental
+> stage. You should expect bugs, incomplete features, and APIs that may change
+> in future versions. The project welcomes bugfixes, but to make sure things are
+> well coordinated you should discuss any significant change before starting the
+> work. It's recommended that you signal your intention to contribute in the
+> issue tracker, either by filing a new issue or by claiming an existing one.
+
+
+# OpenEnv Core
+
+Core components for OpenEnv - a framework for building HTTP-based agentic environments.
+
+## Features
+
+- **EnvClient**: Async-first client for interacting with remote environments
+- **SyncEnvClient**: Synchronous wrapper via `.sync()` for sync codebases
+- **HTTPEnvServer**: FastAPI-based server wrapper for exposing environments over HTTP/WebSocket
+- **Container Providers**: Pluggable architecture for running containers (Docker, Kubernetes, etc.)
+- **Type System**: Strongly-typed Action/Observation/State interfaces
+- **Web Interface**: Optional web UI for interacting with environments
+
+## Installation
+
+```bash
+pip install "openenv[core]"
+```
+
+For development (editable install from a source checkout):
+```bash
+pip install -e ".[core]"
+```
+
+## Quick Start
+
+### Creating an Environment Client
+
+EnvClient is **async by default**.
+Use `async with` and `await` for all operations:
+
+```python
+import asyncio
+from openenv.core import EnvClient, StepResult
+from dataclasses import dataclass
+from typing import Any
+
+@dataclass
+class MyAction:
+    text: str
+
+@dataclass
+class MyObservation:
+    response: str
+
+class MyEnvClient(EnvClient[MyAction, MyObservation, Any]):
+    def _step_payload(self, action: MyAction) -> dict:
+        return {"text": action.text}
+
+    def _parse_result(self, payload: dict) -> StepResult[MyObservation]:
+        obs_data = payload["observation"]
+        return StepResult(
+            observation=MyObservation(**obs_data),
+            reward=payload.get("reward"),
+            done=payload.get("done", False)
+        )
+
+    def _parse_state(self, payload: dict) -> Any:
+        return payload
+
+# Async usage (recommended)
+async def main():
+    client = await MyEnvClient.from_docker_image("my-env:latest")
+    async with client:
+        result = await client.reset()
+        step_result = await client.step(MyAction(text="hello"))
+
+asyncio.run(main())
+
+# Sync usage (via .sync() wrapper)
+with MyEnvClient(base_url="http://localhost:8000").sync() as client:
+    result = client.reset()
+    step_result = client.step(MyAction(text="hello"))
+```
+
+### Creating an Environment Server
+
+```python
+from openenv.core.env_server import Environment, create_app
+from dataclasses import dataclass
+
+@dataclass
+class MyAction:
+    text: str
+
+@dataclass
+class MyObservation:
+    response: str
+    reward: float = 0.0
+    done: bool = False
+
+class MyEnvironment(Environment):
+    def reset(self) -> MyObservation:
+        return MyObservation(response="Ready")
+
+    def step(self, action: MyAction) -> MyObservation:
+        return MyObservation(
+            response=f"Echo: {action.text}",
+            reward=1.0,
+            done=False
+        )
+
+# Create FastAPI app (create_app takes the Environment class,
+# matching the generated templates)
+app = create_app(MyEnvironment, MyAction, MyObservation)
+
+# Run with: uvicorn module:app --host 0.0.0.0 --port 8000
+```
+
+## Container Providers
+
+OpenEnv Core supports multiple container providers:
+
+### Local Docker Provider
+
+```python
+from openenv.core.containers.runtime import LocalDockerProvider
+
+provider = LocalDockerProvider()
+base_url = provider.start_container("my-env:latest")
+provider.wait_for_ready(base_url)
+# Use environment...
+provider.stop_container()
+```
+
+### Kubernetes Provider (Coming Soon)
+
+```python
+from openenv.core.containers.runtime import KubernetesProvider
+
+provider = KubernetesProvider(namespace="envs")
+base_url = provider.start_container("my-env:latest")
+# Use environment...
+provider.stop_container()
+```
+
+
+## API Reference
+
+### EnvClient
+
+Async base class for environment clients. Key methods:
+
+- `async connect()`: Establish WebSocket connection
+- `async reset(**kwargs)`: Reset environment
+- `async step(action)`: Execute action
+- `async state()`: Get current state
+- `async close()`: Close connection and cleanup
+- `sync()`: Return a SyncEnvClient wrapper for synchronous usage
+
+Abstract methods to implement:
+- `_step_payload(action)`: Convert action to JSON
+- `_parse_result(payload)`: Parse response to StepResult
+- `_parse_state(payload)`: Parse state response
+
+### SyncEnvClient
+
+Synchronous wrapper around EnvClient.
Use `client.sync()` to get one: + +```python +sync_client = async_client.sync() +with sync_client: + result = sync_client.reset() + result = sync_client.step(action) +``` + +### HTTPEnvServer + +Server wrapper with these methods: + +- `register_routes(app)`: Register endpoints on FastAPI app +- `_deserialize_action(data)`: Convert JSON to Action +- `_serialize_observation(obs)`: Convert Observation to JSON + +### Environment Interface + +Base interface for environment implementations: + +- `reset()`: Reset environment and return initial observation +- `step(action)`: Execute action and return observation +- `state`: Property returning current environment state + +## License + +This project is licensed under the BSD-3-Clause License - see the LICENSE file for details. + +## Contributing + +Contributions are welcome! Please see the main OpenEnv repository for contribution guidelines. + +## Links + +- **Homepage**: https://github.com/meta-pytorch/OpenEnv +- **Documentation**: https://github.com/meta-pytorch/OpenEnv/blob/main/README.md +- **Bug Tracker**: https://github.com/meta-pytorch/OpenEnv/issues diff --git a/src/core/openenv/core/__init__.py b/src/core/openenv/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..96065d6a80463e2fe599de7728243fc2adad7135 --- /dev/null +++ b/src/core/openenv/core/__init__.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Core components for agentic environments.""" + +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING + +from . import env_server +from .env_server import * # noqa: F403 + +if TYPE_CHECKING: + from .env_client import EnvClient + from .generic_client import GenericAction, GenericEnvClient + from .llm_client import ( + AnthropicClient, + create_llm_client, + LLMClient, + LLMResponse, + OpenAIClient, + ToolCall, + ) + from .mcp_client import MCPClientBase, MCPToolClient + from .sync_client import SyncEnvClient + +__all__ = [ + "EnvClient", + "SyncEnvClient", + "GenericEnvClient", + "GenericAction", + "MCPClientBase", + "MCPToolClient", + "AnthropicClient", + "LLMClient", + "LLMResponse", + "OpenAIClient", + "ToolCall", + "create_llm_client", +] + env_server.__all__ # type: ignore + + +_LAZY_ATTRS = { + "EnvClient": (".env_client", "EnvClient"), + "SyncEnvClient": (".sync_client", "SyncEnvClient"), + "GenericEnvClient": (".generic_client", "GenericEnvClient"), + "GenericAction": (".generic_client", "GenericAction"), + "MCPClientBase": (".mcp_client", "MCPClientBase"), + "MCPToolClient": (".mcp_client", "MCPToolClient"), + "AnthropicClient": (".llm_client", "AnthropicClient"), + "LLMClient": (".llm_client", "LLMClient"), + "LLMResponse": (".llm_client", "LLMResponse"), + "OpenAIClient": (".llm_client", "OpenAIClient"), + "ToolCall": (".llm_client", "ToolCall"), + "create_llm_client": (".llm_client", "create_llm_client"), +} + + +def __getattr__(name: str): + if name in _LAZY_ATTRS: + module_path, attr_name = _LAZY_ATTRS[name] + module = import_module(module_path, __name__) + value = getattr(module, attr_name) + globals()[name] = value + return value + + try: + value = getattr(env_server, name) + except AttributeError as exc: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from exc + + globals()[name] = value + return value + + +def __dir__() -> 
list[str]: + return sorted(set(globals().keys()) | set(__all__)) diff --git a/src/core/openenv/core/client_types.py b/src/core/openenv/core/client_types.py new file mode 100644 index 0000000000000000000000000000000000000000..c7501c656b66a780f29bf23309aaf00fab8df432 --- /dev/null +++ b/src/core/openenv/core/client_types.py @@ -0,0 +1,23 @@ +# Type definitions for EnvTorch +from dataclasses import dataclass +from typing import Generic, Optional, TypeVar + +# Generic type for observations +ObsT = TypeVar("ObsT") +StateT = TypeVar("StateT") + + +@dataclass +class StepResult(Generic[ObsT]): + """ + Represents the result of one environment step. + + Attributes: + observation: The environment's observation after the action. + reward: Scalar reward for this step (optional). + done: Whether the episode is finished. + """ + + observation: ObsT + reward: Optional[float] = None + done: bool = False diff --git a/src/core/openenv/core/containers/__init__.py b/src/core/openenv/core/containers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38e67ef3cd60bf13a26ef7c8bf23986c3eb5990e --- /dev/null +++ b/src/core/openenv/core/containers/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Container management for environment servers.""" diff --git a/src/core/openenv/core/containers/images/Dockerfile b/src/core/openenv/core/containers/images/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..97bb1cf5e2ce0e58c82496cced3e58976baead4c --- /dev/null +++ b/src/core/openenv/core/containers/images/Dockerfile @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# +# OpenEnv Base Image +# +# This is the standard base image for all OpenEnv environment servers. +# It includes the minimal dependencies needed to run HTTP environment servers +# and uv for fast dependency management. +# +# Build from repo root: docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile . +# Tag: docker tag openenv-base:latest openenv-base:0.2.0 +# + +FROM ghcr.io/astral-sh/uv:0.5.27-python3.11-bookworm-slim AS builder + +# Set working directory +WORKDIR /app + +# Copy core pyproject.toml and lockfile for dependency installation +COPY pyproject.toml uv.lock* ./ + +# Install core dependencies using uv with cache mount +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r pyproject.toml + +# Final runtime stage +FROM python:3.11-slim + +# Set metadata +LABEL maintainer="OpenEnv Team" +LABEL description="Base image for OpenEnv based environment servers with uv" +LABEL version="0.2.0" + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +# Copy uv from builder +COPY --from=builder /usr/local/bin/uv /usr/local/bin/uvx /usr/local/bin/ + +# Copy installed Python packages from builder +COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages + +# Copy console scripts installed by pip (uvicorn, fastapi, etc.) 
+COPY --from=builder /usr/local/bin/uvicorn /usr/local/bin/fastapi /usr/local/bin/
+
+# Set working directory
+WORKDIR /app
+
+# Default environment variables
+ENV PYTHONPATH=/app/src
+ENV PYTHONUNBUFFERED=1
+ENV UV_SYSTEM_PYTHON=1
+
+# Default expose port (can be overridden)
+EXPOSE 8000
+
+# Note: CMD should be specified in child Dockerfiles
diff --git a/src/core/openenv/core/containers/images/README.md b/src/core/openenv/core/containers/images/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..69c387909fc487bf4bebb2a18dced2185ecf477d
--- /dev/null
+++ b/src/core/openenv/core/containers/images/README.md
@@ -0,0 +1,92 @@
+# OpenEnv Base Image
+
+Standard base image for all OpenEnv environment servers.
+
+## What's Included
+
+| Layer | Size | Contents |
+|-------|------|----------|
+| python:3.11-slim | 200 MB | Base Python runtime |
+| + Dependencies | 100 MB | FastAPI, uvicorn, requests |
+| **Total** | **~300 MB** | Ready for environment servers |
+
+## Image Sizes
+
+### Without Base Images (❌ Problem)
+```
+echo-env:latest        500 MB (python + fastapi + uvicorn + app)
+coding-env:latest      520 MB (python + fastapi + uvicorn + app + tools)
+another-env:latest     510 MB (python + fastapi + uvicorn + app)
+---
+Total: 1.5 GB (with lots of duplication)
+```
+
+### With Base Images (✅ Solution)
+```
+openenv-base:latest    300 MB (python + fastapi + uvicorn)
+echo-env:latest         50 MB (app only, uses base)
+coding-env:latest       70 MB (app + tools, uses base)
+another-env:latest      45 MB (app only, uses base)
+---
+Total: 465 MB (base shared, minimal duplication)
+```
+
+## Building the Base Image
+
+```bash
+# From project root
+docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
+```
+
+## Usage in Environment Dockerfiles
+
+Each environment Dockerfile should start with:
+
+```dockerfile
+FROM openenv-base:latest
+
+# Copy only environment-specific files
+COPY src/openenv/core/ /app/src/openenv/core/
+COPY envs/my_env/ /app/envs/my_env/
+
+# Run the server
+CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
+```
+
+## Base Image Contents
+
+- Python 3.11-slim
+- FastAPI >= 0.104.0
+- Uvicorn >= 0.24.0
+- Requests >= 2.25.0
+- curl (for health checks)
+
+## Example: Building Echo Environment
+
+```bash
+# Step 1: Build base image (do this once)
+docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
+
+# Step 2: Build echo environment (uses base)
+docker build -t echo-env:latest -f envs/echo_env/server/Dockerfile .
+
+# Step 3: Run echo environment
+docker run -p 8000:8000 echo-env:latest
+```
+
+## Updating the Base
+
+When dependencies need updating:
+
+1. Update `src/openenv/core/containers/images/Dockerfile`
+2. Rebuild base image
+3. Rebuild all environment images (they'll use new base)
+
+```bash
+# Update base
+docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
+
+# Rebuild environments (they automatically use new base)
+docker build -t echo-env:latest -f envs/echo_env/server/Dockerfile .
+```
diff --git a/src/core/openenv/core/containers/runtime/__init__.py b/src/core/openenv/core/containers/runtime/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd514dc2fb78007e4ee1bf1f2e9777864bc76b00
--- /dev/null
+++ b/src/core/openenv/core/containers/runtime/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Container runtime providers.""" + +from .providers import ( + ContainerProvider, + DockerSwarmProvider, + KubernetesProvider, + LocalDockerProvider, + RuntimeProvider, +) +from .uv_provider import UVProvider + +__all__ = [ + "ContainerProvider", + "DockerSwarmProvider", + "LocalDockerProvider", + "KubernetesProvider", + "RuntimeProvider", + "UVProvider", +] diff --git a/src/core/openenv/core/containers/runtime/daytona_provider.py b/src/core/openenv/core/containers/runtime/daytona_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..08c899fa3f16520dbe7cb8c0804e23250d97f605 --- /dev/null +++ b/src/core/openenv/core/containers/runtime/daytona_provider.py @@ -0,0 +1,572 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Daytona container provider for running OpenEnv environments in Daytona cloud sandboxes. + +Requires the ``daytona`` SDK: ``pip install daytona>=0.10`` +""" + +from __future__ import annotations + +import json +import os +import shlex +import time +from typing import Any, Callable, Dict, Optional + +import yaml + +from .providers import ContainerProvider + + +class DaytonaProvider(ContainerProvider): + """ + Container provider that runs environments in Daytona cloud sandboxes. + + Example: + >>> provider = DaytonaProvider(api_key="your-key") + >>> image = DaytonaProvider.image_from_dockerfile("envs/echo_env/server/Dockerfile") + >>> base_url = provider.start_container(image) + >>> provider.wait_for_ready(base_url) + >>> provider.stop_container() + """ + + _dockerfile_registry: Dict[str, Dict[str, Any]] = {} + + def __init__( + self, + *, + api_key: Optional[str] = None, + public: bool = False, + resources: Optional[Any] = None, + auto_stop_interval: int = 15, + target: Optional[str] = None, + on_snapshot_create_logs: Optional[Callable[[str], None]] = None, + cmd: Optional[str] = None, + create_timeout: float = 300, + ): + """ + Args: + api_key: Daytona API key. Falls back to ``DAYTONA_API_KEY`` env var. + public: If True, the sandbox preview is publicly accessible. + resources: Optional ``daytona.Resources`` instance for CPU/memory. + auto_stop_interval: Minutes of inactivity before auto-stop (0 disables). + target: Daytona target region (e.g. "us"). + on_snapshot_create_logs: Callback for snapshot build log lines. + cmd: Shell command to start the server inside the sandbox. + create_timeout: Seconds to wait for sandbox creation (default 300). + Heavy images (e.g. with Playwright/Chromium) may need more. 
+        """
+        from daytona import Daytona, DaytonaConfig
+
+        config_kwargs: Dict[str, Any] = {}
+        resolved_key = api_key or os.environ.get("DAYTONA_API_KEY")
+        if resolved_key:
+            config_kwargs["api_key"] = resolved_key
+        if target:
+            config_kwargs["target"] = target
+
+        self._daytona = Daytona(DaytonaConfig(**config_kwargs))
+        self._public = public
+        self._resources = resources
+        self._auto_stop_interval = auto_stop_interval
+        self._on_snapshot_create_logs = on_snapshot_create_logs
+        self._cmd = cmd
+        self._create_timeout = create_timeout
+        self._sandbox: Any = None
+        self._preview_url: Optional[str] = None
+
+    def _discover_server_cmd(self, sandbox: Any, port: int = 8000) -> str:
+        """Discover the server command from ``openenv.yaml`` inside *sandbox*.
+
+        Finds the file, reads the ``app`` field, and constructs a command
+        of the form ``cd <env_root> && python -m uvicorn <app> --host 0.0.0.0 --port <port>``.
+
+        Raises:
+            ValueError: If ``openenv.yaml`` is not found or lacks an ``app`` field.
+        """
+        yaml_path = self._find_openenv_yaml(sandbox)
+        if yaml_path is None:
+            raise ValueError(
+                "Could not find openenv.yaml inside the sandbox. "
+                "Pass an explicit cmd= to DaytonaProvider or start_container()."
+            )
+
+        cat_resp = sandbox.process.exec(f"cat {shlex.quote(yaml_path)}", timeout=10)
+        content = cat_resp.result if hasattr(cat_resp, "result") else str(cat_resp)
+        app = self._parse_app_field(content)
+        if app is None:
+            raise ValueError(
+                f"openenv.yaml at {yaml_path} does not contain an 'app' field. "
+                "Pass an explicit cmd= to DaytonaProvider or start_container()."
+            )
+
+        # The directory containing openenv.yaml is the env root
+        env_root = yaml_path.rsplit("/", 1)[0]
+        return (
+            f"cd {shlex.quote(env_root)} && "
+            f"python -m uvicorn {shlex.quote(app)} --host 0.0.0.0 --port {port}"
+        )
+
+    def _find_openenv_yaml(self, sandbox: Any) -> Optional[str]:
+        """Locate ``openenv.yaml`` inside the sandbox.
+
+        Tries the modern layout path ``/app/env/openenv.yaml`` first,
+        then falls back to a ``find`` command for the old layout.
+        """
+        # Fast path: modern Dockerfile layout
+        resp = sandbox.process.exec(
+            "test -f /app/env/openenv.yaml && echo found", timeout=10
+        )
+        out = resp.result if hasattr(resp, "result") else str(resp)
+        if "found" in (out or ""):
+            return "/app/env/openenv.yaml"
+
+        # Fallback: search for it (redirect stderr so error messages
+        # like "No such file or directory" don't get mistaken for paths).
+        resp = sandbox.process.exec(
+            "find /app -maxdepth 4 -name openenv.yaml -print -quit 2>/dev/null",
+            timeout=10,
+        )
+        path = ((resp.result if hasattr(resp, "result") else str(resp)) or "").strip()
+        if path and path.startswith("/"):
+            return path
+
+        return None
+
+    @staticmethod
+    def _parse_app_field(yaml_content: str) -> Optional[str]:
+        """Extract the ``app`` value from raw openenv.yaml content.
+
+        Uses PyYAML to handle comments, quotes, and nested keys correctly.
+        """
+        try:
+            data = yaml.safe_load(yaml_content) or {}
+        except Exception:
+            return None
+
+        if not isinstance(data, dict):
+            return None
+
+        value = data.get("app")
+        if isinstance(value, str):
+            value = value.strip()
+            return value if value else None
+        return None
+
+    @staticmethod
+    def _parse_dockerfile_cmd(dockerfile_content: str) -> Optional[str]:
+        """Extract the server command from the last ``CMD`` in a Dockerfile.
+
+        Handles exec form (``CMD ["prog", "arg"]``) and shell form
+        (``CMD prog arg``). When a Dockerfile has multiple ``CMD``
+        instructions (e.g.
multi-stage builds), the last one wins - same + semantics as Docker itself. Lines where ``CMD`` appears inside a + comment are ignored. + + Returns: + The command as a single string, or ``None`` if no ``CMD`` found. + """ + import re + + last_cmd: Optional[str] = None + for line in dockerfile_content.splitlines(): + stripped = line.strip() + # Skip comments + if stripped.startswith("#"): + continue + match = re.match(r"CMD\s+(.+)", stripped, flags=re.IGNORECASE) + if match: + last_cmd = match.group(1).strip() + + if last_cmd is None: + return None + + # Exec form: CMD ["executable", "param1", ...] + if last_cmd.startswith("["): + try: + parts = json.loads(last_cmd) + if isinstance(parts, list) and all(isinstance(p, str) for p in parts): + return " ".join(parts) + except (json.JSONDecodeError, TypeError): + pass + + # Shell form: CMD executable param1 ... + return last_cmd if last_cmd else None + + @staticmethod + def strip_buildkit_syntax(dockerfile_content: str) -> str: + """Remove BuildKit ``--mount=...`` flags from ``RUN`` instructions. + + Handles single-line flags, multi-line continuations, and multiple + ``--mount`` flags spread across continuation lines. Only leading + ``--mount`` flags are removed (before the actual command starts). + + Daytona's ``Image.from_dockerfile`` does not support BuildKit + ``--mount`` syntax. This helper strips the flags so that standard + Dockerfiles (like the ones generated by ``openenv build``) can + be used directly. + """ + import re + + def strip_leading_mounts(text: str) -> str: + remaining = text + while True: + match = re.match(r"\s*--mount=\S+\s*", remaining) + if not match: + return remaining + remaining = remaining[match.end() :] + + lines = dockerfile_content.split("\n") + result: list[str] = [] + in_run = False + in_mount_prefix = False + + for line in lines: + line_out = line + run_start = False + if re.match(r"\s*RUN(\s+|$)", line, flags=re.IGNORECASE): + in_run = True + in_mount_prefix = True + run_start = True + + if in_run and in_mount_prefix: + original_ends_with_slash = line_out.rstrip().endswith("\\") + if run_start: + match = re.match(r"(\s*RUN\s+)(.*)$", line_out, flags=re.IGNORECASE) + if match: + run_prefix, remainder = match.group(1), match.group(2) + else: + run_prefix, remainder = line_out, "" + new_remainder = strip_leading_mounts(remainder) + line_out = run_prefix + new_remainder + content_for_check = new_remainder + else: + new_remainder = strip_leading_mounts(line_out) + line_out = new_remainder + content_for_check = new_remainder + + if original_ends_with_slash and not line_out.rstrip().endswith("\\"): + line_out = line_out.rstrip() + " \\" + + if content_for_check.strip() not in ("", "\\"): + in_mount_prefix = False + + if in_run and not line_out.rstrip().endswith("\\"): + in_run = False + in_mount_prefix = False + + result.append(line_out) + + return "\n".join(result) + + @classmethod + def image_from_dockerfile( + cls, + dockerfile_path: str, + context_dir: str | None = None, + ) -> str: + """Validate a Dockerfile and return a ``dockerfile:`` URI for + :meth:`start_container`. + + Eagerly validates the Dockerfile (existence, COPY sources, + BuildKit stripping) and stores the processed content in an + internal registry. The actual ``daytona.Image`` is created + later inside ``start_container``. + + Args: + dockerfile_path: Path to the Dockerfile on disk. + context_dir: Build context directory. 
+                Defaults to the Dockerfile's grandparent directory,
+                matching the ``openenv init`` convention where Dockerfiles
+                live in ``<env>/server/Dockerfile`` and the build context is
+                ``<env>/``. Pass explicitly for non-standard layouts
+                (e.g. ``context_dir="."`` for repo-root contexts).
+
+        Returns:
+            A ``"dockerfile:<path>"`` string to pass to
+            ``start_container``.
+
+        Raises:
+            FileNotFoundError: If *dockerfile_path* does not exist.
+            ValueError: If *context_dir* is given but does not exist,
+                or if COPY sources in the Dockerfile cannot be found
+                under the resolved context directory.
+        """
+        import pathlib
+        import re
+
+        src = pathlib.Path(dockerfile_path).resolve()
+        if not src.is_file():
+            raise FileNotFoundError(f"Dockerfile not found: {dockerfile_path}")
+
+        if context_dir is not None:
+            ctx = pathlib.Path(context_dir)
+            if not ctx.is_dir():
+                raise ValueError(f"context_dir does not exist: {context_dir}")
+        else:
+            # Default: grandparent of the Dockerfile, matching the
+            # openenv init layout (<env>/server/Dockerfile -> <env>/).
+            ctx = src.parent.parent
+
+        content = src.read_text()
+        stripped = cls.strip_buildkit_syntax(content)
+
+        # Validate that COPY sources exist under the context directory.
+        # This catches mismatches early (e.g. a Dockerfile expecting repo
+        # root as context when we defaulted to the env directory).
+        for line in stripped.splitlines():
+            m = re.match(r"^\s*COPY\s+(?!--from=)(\S+)\s+", line, re.IGNORECASE)
+            if not m:
+                continue
+            copy_src = m.group(1)
+            if copy_src.startswith("/"):
+                continue
+            resolved = ctx / copy_src
+            if not resolved.exists() and not any(ctx.glob(copy_src)):
+                raise ValueError(
+                    f"Dockerfile COPY source '{copy_src}' not found "
+                    f"under context_dir '{ctx}'. This Dockerfile may "
+                    f"expect a different build context (e.g. the repo "
+                    f"root). Pass context_dir explicitly."
+                )
+
+        # Parse CMD from the original Dockerfile so start_container can
+        # use it as a fallback when openenv.yaml is unavailable.
+        parsed_cmd = cls._parse_dockerfile_cmd(content)
+
+        cls._dockerfile_registry[str(src)] = {
+            "stripped_content": stripped,
+            "context_dir": str(ctx),
+            "server_cmd": parsed_cmd,
+        }
+
+        return f"dockerfile:{src}"
+
+    def start_container(
+        self,
+        image: str,
+        port: Optional[int] = None,
+        env_vars: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Create a Daytona sandbox from a Docker image or snapshot.
+
+        Daytona does not execute the image's CMD (known bug — ENTRYPOINT
+        runs, CMD does not). The server command is resolved in order:
+
+        1. Explicit ``cmd`` passed to the constructor.
+        2. ``cmd`` key in ``**kwargs`` (popped before forwarding).
+        3. Auto-discovered from ``openenv.yaml`` inside the sandbox.
+        4. ``CMD`` parsed from the Dockerfile (when *image* came from
+           ``image_from_dockerfile``).
+
+        Args:
+            image: Docker image name (e.g. ``"echo-env:latest"``),
+                ``"snapshot:<name>"`` to create from a pre-built snapshot,
+                or ``"dockerfile:<path>"`` returned by
+                :meth:`image_from_dockerfile`.
+            port: Must be ``None`` or ``8000``. Daytona exposes port 8000
+                via its preview proxy; other ports raise ``ValueError``.
+            env_vars: Environment variables forwarded to the sandbox.
+            **kwargs: ``cmd`` (str) to override the server command;
+                remaining kwargs passed through to ``Daytona.create()``.
+
+        Returns:
+            HTTPS preview URL for the sandbox (base_url).
+        """
+        if port is not None and port != 8000:
+            raise ValueError(
+                f"DaytonaProvider only supports port 8000 (got {port}). "
+                "The Daytona preview proxy routes to port 8000 inside the sandbox."
+ ) + + # Resolve the server command (may be None; discovery happens after + # sandbox creation when we can inspect the filesystem). + cmd = kwargs.pop("cmd", None) or self._cmd + + # CMD parsed from Dockerfile (populated for "dockerfile:" images). + parsed_cmd: Optional[str] = None + + # Build creation params + create_kwargs: Dict[str, Any] = {} + if env_vars: + create_kwargs["env_vars"] = env_vars + if self._public: + create_kwargs["public"] = True + if self._auto_stop_interval != 15: + create_kwargs["auto_stop_interval"] = self._auto_stop_interval + + if image.startswith("snapshot:"): + from daytona import CreateSandboxFromSnapshotParams + + snapshot_name = image[len("snapshot:") :] + params = CreateSandboxFromSnapshotParams( + snapshot=snapshot_name, **create_kwargs + ) + elif image.startswith("dockerfile:"): + from daytona import CreateSandboxFromImageParams, Image + + dockerfile_path = image[len("dockerfile:") :] + meta = self._dockerfile_registry.get(dockerfile_path) + if meta is None: + raise ValueError( + f"No registered Dockerfile metadata for {dockerfile_path}. " + "Call DaytonaProvider.image_from_dockerfile() first." + ) + + parsed_cmd = meta.get("server_cmd") + + # Build the daytona Image from the pre-stripped content. + import pathlib + import uuid + + ctx = pathlib.Path(meta["context_dir"]) + tmp_name = f".daytona-{uuid.uuid4().hex[:8]}.dockerfile" + tmp_path = ctx / tmp_name + try: + tmp_path.write_text(meta["stripped_content"]) + daytona_image = Image.from_dockerfile(str(tmp_path)) + finally: + tmp_path.unlink(missing_ok=True) + + img_kwargs: Dict[str, Any] = { + "image": daytona_image, + **create_kwargs, + } + if self._resources is not None: + img_kwargs["resources"] = self._resources + params = CreateSandboxFromImageParams(**img_kwargs) + else: + from daytona import CreateSandboxFromImageParams + + img_kwargs = {"image": image, **create_kwargs} + if self._resources is not None: + img_kwargs["resources"] = self._resources + params = CreateSandboxFromImageParams(**img_kwargs) + + # Create sandbox + extra: Dict[str, Any] = dict(kwargs) + if self._on_snapshot_create_logs is not None: + extra["on_snapshot_create_logs"] = self._on_snapshot_create_logs + + self._sandbox = self._daytona.create( + params, timeout=self._create_timeout, **extra + ) + + try: + # Discover server command from openenv.yaml if not explicitly set. + if cmd is None: + try: + cmd = self._discover_server_cmd(self._sandbox) + except ValueError: + # Fall back to CMD parsed from Dockerfile (if available). + if parsed_cmd: + cmd = parsed_cmd + else: + raise + + # Wrap in bash -c so compound commands (cd ... && uvicorn ...) + # are handled correctly by nohup. Write PID so we can check + # if the process crashed later in wait_for_ready(). + escaped_cmd = shlex.quote(cmd) + self._sandbox.process.exec( + f"nohup bash -c {escaped_cmd} > /tmp/openenv-server.log 2>&1 &" + " echo $! > /tmp/openenv-server.pid", + timeout=10, + ) + + # Get a signed preview URL for port 8000. The token is + # embedded in the URL itself so no extra headers are needed. + signed = self._sandbox.create_signed_preview_url( + 8000, expires_in_seconds=86400 + ) + self._preview_url = signed.url + except Exception: + self.stop_container() + raise + + return self._preview_url + + def refresh_preview_url(self) -> str: + """Get a fresh signed preview URL (valid for 24h). + + Daytona signed URLs expire after at most 24 hours. Call this to + get a new one for long-running sessions. 
The returned URL points + to the same sandbox — clients will need to reconnect using it. + """ + if self._sandbox is None: + raise RuntimeError("No active sandbox to refresh URL for.") + signed = self._sandbox.create_signed_preview_url(8000, expires_in_seconds=86400) + self._preview_url = signed.url + return self._preview_url + + def stop_container(self) -> None: + """Delete the Daytona sandbox.""" + if self._sandbox is None: + return + + try: + self._daytona.delete(self._sandbox) + finally: + self._sandbox = None + self._preview_url = None + + def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None: + """ + Poll the /health endpoint until the sandbox is ready. + + Uses a longer default timeout (120s) than Docker providers because + Daytona sandboxes may have cold-start latency. + + Args: + base_url: Preview URL returned by ``start_container()``. + timeout_s: Maximum seconds to wait. + + Raises: + TimeoutError: If the sandbox doesn't become ready in time. + RuntimeError: If the server process died (detected via PID check). + """ + import requests + + health_url = f"{base_url}/health" + + deadline = time.time() + timeout_s + while time.time() < deadline: + try: + response = requests.get(health_url, timeout=5.0) + if response.status_code == 200: + return + except requests.RequestException: + pass + + # Early exit: if the server process died, raise immediately + # instead of waiting for the full health-check timeout. + if self._sandbox is not None: + resp = self._sandbox.process.exec( + "kill -0 $(cat /tmp/openenv-server.pid) 2>/dev/null" + " && echo RUNNING || echo DEAD", + timeout=10, + ) + out = resp.result if hasattr(resp, "result") else str(resp) + if "DEAD" in (out or ""): + log_resp = self._sandbox.process.exec( + "cat /tmp/openenv-server.log 2>/dev/null", timeout=10 + ) + log = ( + log_resp.result + if hasattr(log_resp, "result") + else str(log_resp) + ) + raise RuntimeError(f"Server process died.\nLog:\n{log}") + + time.sleep(1.0) + + raise TimeoutError( + f"Daytona sandbox at {base_url} did not become ready within {timeout_s}s" + ) diff --git a/src/core/openenv/core/containers/runtime/providers.py b/src/core/openenv/core/containers/runtime/providers.py new file mode 100644 index 0000000000000000000000000000000000000000..54232a2495746f89cc81590ca87d03e6e48e3d2b --- /dev/null +++ b/src/core/openenv/core/containers/runtime/providers.py @@ -0,0 +1,669 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Container provider abstractions for running environment servers. + +This module provides a pluggable architecture for different container providers +(local Docker, Kubernetes, cloud providers, etc.) to be used with EnvClient. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Sequence + + +class ContainerProvider(ABC): + """ + Abstract base class for container providers. + + Providers implement this interface to support different container platforms: + - LocalDockerProvider: Runs containers on local Docker daemon + - KubernetesProvider: Runs containers in Kubernetes cluster + - FargateProvider: Runs containers on AWS Fargate + - CloudRunProvider: Runs containers on Google Cloud Run + + The provider manages a single container lifecycle and provides the base URL + for connecting to it. 
+
+    Example:
+        >>> provider = LocalDockerProvider()
+        >>> base_url = provider.start_container("echo-env:latest")
+        >>> print(base_url)  # http://localhost:8000
+        >>> # Use the environment via base_url
+        >>> provider.stop_container()
+    """
+
+    @abstractmethod
+    def start_container(
+        self,
+        image: str,
+        port: Optional[int] = None,
+        env_vars: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Start a container from the specified image.
+
+        Args:
+            image: Container image name (e.g., "echo-env:latest")
+            port: Port to expose (if None, provider chooses)
+            env_vars: Environment variables to pass to container
+            **kwargs: Provider-specific options
+
+        Returns:
+            Base URL to connect to the container (e.g., "http://localhost:8000")
+
+        Raises:
+            RuntimeError: If container fails to start
+        """
+        pass
+
+    @abstractmethod
+    def stop_container(self) -> None:
+        """
+        Stop and remove the running container.
+
+        This cleans up the container that was started by start_container().
+        """
+        pass
+
+    @abstractmethod
+    def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None:
+        """
+        Wait for the container to be ready to accept requests.
+
+        This typically polls the /health endpoint until it returns 200.
+
+        Args:
+            base_url: Base URL of the container
+            timeout_s: Maximum time to wait
+
+        Raises:
+            TimeoutError: If container doesn't become ready in time
+        """
+        pass
+
+
+class LocalDockerProvider(ContainerProvider):
+    """
+    Container provider for local Docker daemon.
+
+    This provider runs containers on the local machine using Docker.
+    Useful for development and testing.
+
+    Example:
+        >>> provider = LocalDockerProvider()
+        >>> base_url = provider.start_container("echo-env:latest")
+        >>> # Container running on http://localhost:<port>
+        >>> provider.stop_container()
+    """
+
+    def __init__(self):
+        """Initialize the local Docker provider."""
+        self._container_id: Optional[str] = None
+        self._container_name: Optional[str] = None
+
+        # Check if Docker is available
+        import subprocess
+
+        try:
+            subprocess.run(
+                ["docker", "version"],
+                check=True,
+                capture_output=True,
+                timeout=5,
+            )
+        except (
+            subprocess.CalledProcessError,
+            FileNotFoundError,
+            subprocess.TimeoutExpired,
+        ):
+            raise RuntimeError(
+                "Docker is not available. Please install Docker Desktop or Docker Engine."
+            )
+
+    def start_container(
+        self,
+        image: str,
+        port: Optional[int] = None,
+        env_vars: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Start a Docker container locally.
+ + Args: + image: Docker image name + port: Port to expose (if None, finds available port) + env_vars: Environment variables for the container + **kwargs: Additional Docker run options + + Returns: + Base URL to connect to the container + """ + import subprocess + import time + + # Find available port if not specified + if port is None: + port = self._find_available_port() + + # Generate container name + self._container_name = self._generate_container_name(image) + + # Build docker run command + cmd = [ + "docker", + "run", + "-d", # Detached + "--name", + self._container_name, + "-p", + f"{port}:8000", # Map port + ] + + # Add environment variables + if env_vars: + for key, value in env_vars.items(): + cmd.extend(["-e", f"{key}={value}"]) + + # Add image + cmd.append(image) + + # Run container + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + self._container_id = result.stdout.strip() + except subprocess.CalledProcessError as e: + error_msg = f"Failed to start Docker container.\nCommand: {' '.join(cmd)}\nExit code: {e.returncode}\nStderr: {e.stderr}\nStdout: {e.stdout}" + raise RuntimeError(error_msg) from e + + # Wait a moment for container to start + time.sleep(1) + + base_url = f"http://localhost:{port}" + return base_url + + def stop_container(self) -> None: + """ + Stop and remove the Docker container. + """ + if self._container_id is None: + return + + import subprocess + + try: + # Stop container + subprocess.run( + ["docker", "stop", self._container_id], + capture_output=True, + check=True, + timeout=10, + ) + + # Remove container + subprocess.run( + ["docker", "rm", self._container_id], + capture_output=True, + check=True, + timeout=10, + ) + except subprocess.CalledProcessError: + # Container might already be stopped/removed + pass + finally: + self._container_id = None + self._container_name = None + + def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None: + """ + Wait for container to be ready by polling /health endpoint. + + Args: + base_url: Base URL of the container + timeout_s: Maximum time to wait + + Raises: + TimeoutError: If container doesn't become ready + """ + import time + + import requests + + start_time = time.time() + health_url = f"{base_url}/health" + + # Bypass proxy for localhost to avoid proxy issues + proxies = {"http": None, "https": None} + + while time.time() - start_time < timeout_s: + try: + response = requests.get(health_url, timeout=2.0, proxies=proxies) + if response.status_code == 200: + return + except requests.RequestException: + pass + + time.sleep(0.5) + + raise TimeoutError( + f"Container at {base_url} did not become ready within {timeout_s}s" + ) + + def _find_available_port(self) -> int: + """ + Find an available port on localhost. + + Returns: + An available port number + """ + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + def _generate_container_name(self, image: str) -> str: + """ + Generate a unique container name based on image name and timestamp. + + Args: + image: Docker image name + + Returns: + A unique container name + """ + import time + + clean_image = image.split("/")[-1].split(":")[0] + timestamp = int(time.time() * 1000) + return f"{clean_image}-{timestamp}" + + +class DockerSwarmProvider(ContainerProvider): + """ + Container provider that uses Docker Swarm services for local concurrency. 
+ + This provider creates a replicated Swarm service backed by the local Docker + engine. The built-in load-balancer fans requests across the replicas, + allowing multiple container instances to run concurrently on the developer + workstation (mirroring the workflow described in the Docker stack docs). + """ + + def __init__( + self, + *, + auto_init_swarm: bool = True, + overlay_network: Optional[str] = None, + ): + """ + Args: + auto_init_swarm: Whether to call ``docker swarm init`` when Swarm + is not active. Otherwise, user must manually initialize Swarm. + overlay_network: Optional overlay network name for the service. + When provided, the network is created with + ``docker network create --driver overlay --attachable`` if it + does not already exist. + """ + self._service_name: Optional[str] = None + self._service_id: Optional[str] = None + self._published_port: Optional[int] = None + self._overlay_network = overlay_network + self._auto_init_swarm = auto_init_swarm + + self._ensure_docker_available() + self._ensure_swarm_initialized() + if self._overlay_network: + self._ensure_overlay_network(self._overlay_network) + + def start_container( + self, + image: str, + port: Optional[int] = None, + env_vars: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> str: + """ + Start (or scale) a Swarm service for the given image. + + Supported kwargs: + replicas (int): Number of container replicas (default: 2). + cpu_limit (float | str): CPU limit passed to ``--limit-cpu``. + memory_limit (str): Memory limit passed to ``--limit-memory``. + constraints (Sequence[str]): Placement constraints. + labels (Dict[str, str]): Service labels. + command (Sequence[str] | str): Override container command. + """ + import shlex + import subprocess + import time + + allowed_kwargs = { + "replicas", + "cpu_limit", + "memory_limit", + "constraints", + "labels", + "command", + } + unknown = set(kwargs) - allowed_kwargs + if unknown: + raise ValueError(f"Unsupported kwargs for DockerSwarmProvider: {unknown}") + + replicas = int(kwargs.get("replicas", 2)) + cpu_limit = kwargs.get("cpu_limit") + memory_limit = kwargs.get("memory_limit") + constraints: Optional[Sequence[str]] = kwargs.get("constraints") + labels: Optional[Dict[str, str]] = kwargs.get("labels") + command_override = kwargs.get("command") + + if port is None: + port = self._find_available_port() + + self._service_name = self._generate_service_name(image) + self._published_port = port + + cmd = [ + "docker", + "service", + "create", + "--detach", + "--name", + self._service_name, + "--replicas", + str(max(1, replicas)), + "--publish", + f"{port}:8000", + ] + + if self._overlay_network: + cmd.extend(["--network", self._overlay_network]) + + if env_vars: + for key, value in env_vars.items(): + cmd.extend(["--env", f"{key}={value}"]) + + if cpu_limit is not None: + cmd.extend(["--limit-cpu", str(cpu_limit)]) + + if memory_limit is not None: + cmd.extend(["--limit-memory", str(memory_limit)]) + + if constraints: + for constraint in constraints: + cmd.extend(["--constraint", constraint]) + + if labels: + for key, value in labels.items(): + cmd.extend(["--label", f"{key}={value}"]) + + cmd.append(image) + + if command_override: + if isinstance(command_override, str): + cmd.extend(shlex.split(command_override)) + else: + cmd.extend(command_override) + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True, + ) + self._service_id = result.stdout.strip() + except subprocess.CalledProcessError as e: + error_msg = ( + "Failed 
to start Docker Swarm service.\n" + f"Command: {' '.join(cmd)}\n" + f"Exit code: {e.returncode}\n" + f"Stdout: {e.stdout}\n" + f"Stderr: {e.stderr}" + ) + raise RuntimeError(error_msg) from e + + # Give Swarm a brief moment to schedule the tasks. + time.sleep(1.0) + + return f"http://localhost:{port}" + + def stop_container(self) -> None: + """ + Remove the Swarm service (and keep the Swarm manager running). + """ + if not self._service_name: + return + + import subprocess + + try: + subprocess.run( + ["docker", "service", "rm", self._service_name], + capture_output=True, + check=True, + timeout=10, + ) + except subprocess.CalledProcessError: + # Service may already be gone; ignore. + pass + finally: + self._service_name = None + self._service_id = None + self._published_port = None + + def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None: + """ + Wait for at least one replica to become healthy by polling /health. + + Note: With Swarm's load balancer, requests round-robin across replicas, + so this only verifies that at least one replica is responding. Some + replicas may still be starting when this returns. + """ + import time + + import requests + + deadline = time.time() + timeout_s + health_url = f"{base_url}/health" + + # Bypass proxy for localhost to avoid proxy issues + proxies = {"http": None, "https": None} + + while time.time() < deadline: + try: + response = requests.get(health_url, timeout=2.0, proxies=proxies) + if response.status_code == 200: + return + except requests.RequestException: + pass + + time.sleep(0.5) + + raise TimeoutError( + f"Swarm service at {base_url} did not become ready within {timeout_s}s" + ) + + def _ensure_docker_available(self) -> None: + import subprocess + + try: + subprocess.run( + ["docker", "version"], + check=True, + capture_output=True, + timeout=5, + ) + except ( + subprocess.CalledProcessError, + FileNotFoundError, + subprocess.TimeoutExpired, + ) as exc: + raise RuntimeError( + "Docker is not available. Please install Docker Desktop or Docker Engine." + ) from exc + + def _ensure_swarm_initialized(self) -> None: + import subprocess + + try: + result = subprocess.run( + ["docker", "info", "--format", "{{.Swarm.LocalNodeState}}"], + capture_output=True, + text=True, + check=True, + timeout=5, + ) + state = result.stdout.strip().lower() + if state == "active": + return + except subprocess.CalledProcessError: + state = "unknown" + + if not self._auto_init_swarm: + raise RuntimeError( + f"Docker Swarm is not active (state={state}). Enable Swarm manually or pass auto_init_swarm=True." 
+            )
+
+        try:
+            subprocess.run(
+                ["docker", "swarm", "init"],
+                check=True,
+                capture_output=True,
+                timeout=10,
+            )
+        except subprocess.CalledProcessError as e:
+            raise RuntimeError("Failed to initialize Docker Swarm") from e
+
+    def _ensure_overlay_network(self, network: str) -> None:
+        import subprocess
+
+        inspect = subprocess.run(
+            ["docker", "network", "inspect", network],
+            capture_output=True,
+            text=True,
+            check=False,
+        )
+        if inspect.returncode == 0:
+            return
+
+        try:
+            subprocess.run(
+                [
+                    "docker",
+                    "network",
+                    "create",
+                    "--driver",
+                    "overlay",
+                    "--attachable",
+                    network,
+                ],
+                check=True,
+                capture_output=True,
+                timeout=10,
+            )
+        except subprocess.CalledProcessError as e:
+            raise RuntimeError(f"Failed to create overlay network '{network}'") from e
+
+    def _find_available_port(self) -> int:
+        import socket
+
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            s.listen(1)
+            port = s.getsockname()[1]
+        return port
+
+    def _generate_service_name(self, image: str) -> str:
+        import time
+
+        clean_image = image.split("/")[-1].split(":")[0]
+        timestamp = int(time.time() * 1000)
+        return f"{clean_image}-swarm-{timestamp}"
+
+
+class KubernetesProvider(ContainerProvider):
+    """
+    Container provider for Kubernetes clusters.
+
+    This provider is intended to create pods in a Kubernetes cluster and
+    expose them via services or port-forwarding. It is currently an
+    unimplemented stub.
+
+    Example:
+        >>> provider = KubernetesProvider(namespace="envtorch-dev")
+        >>> base_url = provider.start_container("echo-env:latest")
+        >>> # Pod running in k8s, accessible via service or port-forward
+        >>> provider.stop_container()
+    """
+
+    pass
+
+
+class RuntimeProvider(ABC):
+    """
+    Abstract base class for runtime providers that are not container providers.
+
+    Providers implement this interface to support different runtime platforms:
+    - UVProvider: Runs environments via `uv run`
+
+    The provider manages a single runtime lifecycle and provides the base URL
+    for connecting to it.
+
+    Example:
+        >>> provider = UVProvider(project_path="/path/to/env")
+        >>> base_url = provider.start()
+        >>> print(base_url)  # http://localhost:8000
+        >>> provider.stop()
+    """
+
+    @abstractmethod
+    def start(
+        self,
+        port: Optional[int] = None,
+        env_vars: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Start the runtime.
+
+        Args:
+            port: Port to expose (if None, provider chooses)
+            env_vars: Environment variables for the runtime
+            **kwargs: Additional runtime options
+
+        Returns:
+            Base URL for connecting to the runtime (e.g. ``http://localhost:8000``)
+        """
+
+    @abstractmethod
+    def stop(self) -> None:
+        """
+        Stop the runtime.
+        """
+        pass
+
+    @abstractmethod
+    def wait_for_ready(self, timeout_s: float = 30.0) -> None:
+        """
+        Wait for the runtime to be ready to accept requests.
+        """
+        pass
+
+    def __enter__(self) -> "RuntimeProvider":
+        """
+        Enter the runtime provider.
+        """
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc, tb) -> bool:
+        """
+        Exit the runtime provider.
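+        Stops the runtime via ``stop()``; exceptions raised inside the ``with`` block are not suppressed.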
+ """ + self.stop() + return False diff --git a/src/core/openenv/core/containers/runtime/uv_provider.py b/src/core/openenv/core/containers/runtime/uv_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..3ddc89b9bdccbd0d18604c3de5f49fd3cbc74612 --- /dev/null +++ b/src/core/openenv/core/containers/runtime/uv_provider.py @@ -0,0 +1,224 @@ +"""Providers for launching ASGI applications via ``uv run``.""" + +from __future__ import annotations + +import os +import socket +import subprocess +import time +from typing import Dict, Optional + +import requests + +from .providers import RuntimeProvider + + +def _check_uv_installed() -> None: + try: + subprocess.check_output(["uv", "--version"]) + except FileNotFoundError as exc: + raise RuntimeError( + "`uv` executable not found. Install uv from https://docs.astral.sh and ensure it is on PATH." + ) from exc + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", 0)) + sock.listen(1) + return sock.getsockname()[1] + + +def _create_uv_command( + *, + host: str, + port: int, + reload: bool, + workers: int, + app: str, + project_path: str, +) -> list[str]: + command: list[str] = ["uv", "run", "--isolated", "--project", project_path] + + command.append("--") + command.extend( + [ + "uvicorn", + app, + "--host", + host, + "--port", + str(port), + "--workers", + str(workers), + ] + ) + + if reload: + command.append("--reload") + + return command + + +def _poll_health(health_url: str, timeout_s: float) -> None: + """Poll a health endpoint until it returns HTTP 200 or times out.""" + + deadline = time.time() + timeout_s + while time.time() < deadline: + try: + timeout = max(0.0001, min(deadline - time.time(), 2.0)) + response = requests.get(health_url, timeout=timeout) + if response.status_code == 200: + return + except requests.RequestException: + continue + + time.sleep(0.5) + + raise TimeoutError(f"Server did not become ready within {timeout_s:.1f} seconds") + + +class UVProvider(RuntimeProvider): + """ + RuntimeProvider implementation backed by ``uv run``. + + Args: + project_path: Local path to a uv project (passed to ``uv run --project``) + app: ASGI application path for uvicorn (defaults to ``server.app:app``) + host: Host interface to bind to (defaults to ``0.0.0.0``) + reload: Whether to enable uvicorn's reload mode + env_vars: Environment variables to pass through to the spawned process + context_timeout_s: How long to wait for the environment to become ready + + Example: + >>> provider = UVProvider(project_path="/path/to/env") + >>> base_url = provider.start() + >>> print(base_url) # http://localhost:8000 + >>> # Use the environment via base_url + >>> provider.stop() + """ + + def __init__( + self, + *, + project_path: str, + app: str = "server.app:app", + host: str = "0.0.0.0", + reload: bool = False, + env_vars: Optional[Dict[str, str]] = None, + context_timeout_s: float = 60.0, + ): + """Initialize the UVProvider.""" + self.project_path = os.path.abspath(project_path) + self.app = app + self.host = host + self.reload = reload + self.env_vars = env_vars + self.context_timeout_s = context_timeout_s + _check_uv_installed() + self._process = None + self._base_url = None + + def start( + self, + port: Optional[int] = None, + env_vars: Optional[Dict[str, str]] = None, + workers: int = 1, + **_: Dict[str, str], + ) -> str: + """ + Start the environment via `uv run`. 
+
+        Args:
+            port: The port to bind the environment to
+            env_vars: Environment variables to pass to the environment
+            workers: The number of workers to use
+
+        Returns:
+            The base URL of the environment
+
+        Raises:
+            RuntimeError: If the environment is already running
+        """
+        if self._process is not None and self._process.poll() is None:
+            raise RuntimeError("UVProvider is already running")
+
+        bind_port = port or _find_free_port()
+
+        command = _create_uv_command(
+            host=self.host,
+            port=bind_port,
+            reload=self.reload,
+            workers=workers,
+            app=self.app,
+            project_path=self.project_path,
+        )
+
+        env = os.environ.copy()
+
+        if self.env_vars:
+            env.update(self.env_vars)
+        if env_vars:
+            env.update(env_vars)
+
+        try:
+            self._process = subprocess.Popen(command, env=env)
+        except OSError as exc:
+            raise RuntimeError(f"Failed to launch `uv run`: {exc}") from exc
+
+        client_host = "127.0.0.1" if self.host in {"0.0.0.0", "::"} else self.host
+        self._base_url = f"http://{client_host}:{bind_port}"
+        return self._base_url
+
+    def wait_for_ready(self, timeout_s: float = 60.0) -> None:
+        """
+        Wait for the environment to become ready.
+
+        Args:
+            timeout_s: The timeout to wait for the environment to become ready
+
+        Raises:
+            RuntimeError: If the environment is not running
+            TimeoutError: If the environment does not become ready within the timeout
+        """
+        if self._base_url is None:
+            raise RuntimeError("UVProvider has not been started")
+
+        if self._process and self._process.poll() is not None:
+            code = self._process.returncode
+            raise RuntimeError(f"uv process exited prematurely with code {code}")
+
+        _poll_health(f"{self._base_url}/health", timeout_s=timeout_s)
+
+    def stop(self) -> None:
+        """
+        Stop the environment.
+
+        This is a no-op if the environment is not running.
+        """
+        if self._process is None:
+            return
+
+        if self._process.poll() is None:
+            self._process.terminate()
+            try:
+                self._process.wait(timeout=10.0)
+            except subprocess.TimeoutExpired:
+                self._process.kill()
+                self._process.wait(timeout=5.0)
+
+        self._process = None
+        self._base_url = None
+
+    @property
+    def base_url(self) -> str:
+        """
+        The base URL of the environment.
+
+        Returns:
+            The base URL of the environment
+
+        Raises:
+            RuntimeError: If the environment is not running
+        """
+        if self._base_url is None:
+            raise RuntimeError("UVProvider has not been started")
+        return self._base_url
diff --git a/src/core/openenv/core/containers/test_local_docker_provider.py b/src/core/openenv/core/containers/test_local_docker_provider.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac520a4b68afa699894dd68c0508b1e41936704c
--- /dev/null
+++ b/src/core/openenv/core/containers/test_local_docker_provider.py
@@ -0,0 +1,260 @@
+#!/usr/bin/env python3
+"""
+End-to-end test for LocalDockerProvider.
+
+This script tests the complete flow:
+1. Start a container using LocalDockerProvider
+2. Wait for it to be ready
+3. Make HTTP requests to test the environment
+4. 
Clean up the container
+"""
+
+import sys
+from pathlib import Path
+
+# Make the `openenv` package importable when this file is run directly
+# (the package root, src/core, is four levels up from this file)
+sys.path.insert(0, str(Path(__file__).parents[3]))
+
+import requests
+from openenv.core.containers.runtime import LocalDockerProvider
+
+
+# TODO: Remove this test or make it a functional test since this will be tested in the e2e test for the echo env
+def test_local_docker_provider():
+    """Test LocalDockerProvider end-to-end."""
+    print("=" * 60)
+    print("LocalDockerProvider End-to-End Test")
+    print("=" * 60)
+    print()
+
+    provider = None
+
+    try:
+        # Step 1: Create provider
+        print("Step 1: Creating LocalDockerProvider...")
+        provider = LocalDockerProvider()
+        print("✓ Provider created\n")
+
+        # Step 2: Start container
+        print("Step 2: Starting echo-env container...")
+        base_url = provider.start_container("echo-env:latest")
+        print(f"✓ Container started at: {base_url}")
+        if provider._container_id:
+            print(f"  Container ID: {provider._container_id[:12]}...")
+        if provider._container_name:
+            print(f"  Container name: {provider._container_name}\n")
+
+        # Step 3: Wait for ready
+        print("Step 3: Waiting for container to be ready...")
+        provider.wait_for_ready(base_url, timeout_s=30.0)
+        print("✓ Container is ready!\n")
+
+        # Step 4: Test health endpoint
+        print("Step 4: Testing /health endpoint...")
+        response = requests.get(f"{base_url}/health")
+        print(f"  Status: {response.status_code}")
+        print(f"  Response: {response.json()}")
+        assert response.status_code == 200
+        assert response.json()["status"] == "healthy"
+        print("✓ Health check passed\n")
+
+        # Step 5: Test reset endpoint
+        print("Step 5: Testing /reset endpoint...")
+        response = requests.post(
+            f"{base_url}/reset",
+            json={},
+            headers={"Content-Type": "application/json"},
+        )
+        print(f"  Status: {response.status_code}")
+        data = response.json()
+        print(f"  Message: {data['observation']['echoed_message']}")
+        print(f"  Reward: {data['reward']}")
+        print(f"  Done: {data['done']}")
+        assert response.status_code == 200
+        assert data["observation"]["echoed_message"] == "Echo environment ready!"
+        print("✓ Reset test passed\n")
+
+        # Step 6: Test step endpoint
+        print("Step 6: Testing /step endpoint...")
+        response = requests.post(
+            f"{base_url}/step",
+            json={"action": {"message": "Hello from LocalDockerProvider!"}},
+            headers={"Content-Type": "application/json"},
+        )
+        print(f"  Status: {response.status_code}")
+        data = response.json()
+        print(f"  Echoed: {data['observation']['echoed_message']}")
+        print(f"  Length: {data['observation']['message_length']}")
+        print(f"  Reward: {data['reward']}")
+        assert response.status_code == 200
+        assert (
+            data["observation"]["echoed_message"] == "Hello from LocalDockerProvider!" 
+ ) + assert data["observation"]["message_length"] == 31 + print("✓ Step test passed\n") + + # Step 7: Test state endpoint + print("Step 7: Testing /state endpoint...") + response = requests.get(f"{base_url}/state") + print(f" Status: {response.status_code}") + data = response.json() + print(f" Episode ID: {data['episode_id']}") + print(f" Step count: {data['step_count']}") + assert response.status_code == 200 + assert data["step_count"] == 1 # One step from above + print("✓ State test passed\n") + + # Step 8: Multiple steps + print("Step 8: Testing multiple steps...") + for i in range(3): + response = requests.post( + f"{base_url}/step", + json={"action": {"message": f"Message {i + 1}"}}, + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 200 + print(f" Step {i + 1}: ✓") + + # Check state updated + response = requests.get(f"{base_url}/state") + data = response.json() + assert data["step_count"] == 4 # 1 + 3 more steps + print(f" Final step count: {data['step_count']}") + print("✓ Multiple steps test passed\n") + + print("=" * 60) + print("✓ All tests passed!") + print("=" * 60) + print() + + return True + + except Exception as e: + print(f"\n❌ Test failed: {e}") + import traceback + + traceback.print_exc() + return False + + finally: + # Step 9: Cleanup + if provider is not None: + print("\nStep 9: Cleaning up container...") + try: + provider.stop_container() + print("✓ Container stopped and removed\n") + except Exception as e: + print(f"⚠️ Cleanup warning: {e}\n") + + +def test_provider_with_custom_port(): + """Test provider with custom port.""" + print("=" * 60) + print("LocalDockerProvider with Custom Port Test") + print("=" * 60) + print() + + provider = None + + try: + provider = LocalDockerProvider() + + print("Starting container on custom port 8123...") + base_url = provider.start_container("echo-env:latest", port=8123) + print(f"✓ Started at: {base_url}") + assert ":8123" in base_url + + print("Waiting for ready...") + provider.wait_for_ready(base_url) + print("✓ Ready!") + + print("Testing health...") + response = requests.get(f"{base_url}/health") + assert response.status_code == 200 + print("✓ Health check passed") + + print("\n✓ Custom port test passed!\n") + return True + + except Exception as e: + print(f"\n❌ Test failed: {e}") + return False + + finally: + if provider is not None: + provider.stop_container() + print("✓ Cleaned up\n") + + +def test_provider_with_env_vars(): + """Test provider with environment variables.""" + print("=" * 60) + print("LocalDockerProvider with Environment Variables Test") + print("=" * 60) + print() + + provider = None + + try: + provider = LocalDockerProvider() + + print("Starting container with environment variables...") + base_url = provider.start_container( + "echo-env:latest", env_vars={"DEBUG": "true", "LOG_LEVEL": "info"} + ) + print(f"✓ Started at: {base_url}") + + print("Waiting for ready...") + provider.wait_for_ready(base_url) + print("✓ Ready!") + + print("Testing health...") + response = requests.get(f"{base_url}/health") + assert response.status_code == 200 + print("✓ Health check passed") + + print("\n✓ Environment variables test passed!\n") + return True + + except Exception as e: + print(f"\n❌ Test failed: {e}") + return False + + finally: + if provider is not None: + provider.stop_container() + print("✓ Cleaned up\n") + + +if __name__ == "__main__": + print() + print("🐳 LocalDockerProvider Test Suite") + print() + + results = [] + + # Run basic test + results.append(("Basic End-to-End", 
test_local_docker_provider())) + + # Run custom port test + results.append(("Custom Port", test_provider_with_custom_port())) + + # Run environment variables test + results.append(("Environment Variables", test_provider_with_env_vars())) + + # Summary + print("=" * 60) + print("Test Summary") + print("=" * 60) + for name, passed in results: + status = "✓ PASSED" if passed else "✗ FAILED" + print(f"{name:25} {status}") + print("=" * 60) + + all_passed = all(result for _, result in results) + if all_passed: + print("\n🎉 All tests passed!") + exit(0) + else: + print("\n❌ Some tests failed") + exit(1) diff --git a/src/core/openenv/core/env_client.py b/src/core/openenv/core/env_client.py new file mode 100644 index 0000000000000000000000000000000000000000..4ceb344bca20d55d2f9e7ba9aa39595ef61fca30 --- /dev/null +++ b/src/core/openenv/core/env_client.py @@ -0,0 +1,484 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Environment client for persistent sessions. + +This module provides a WebSocket-based client that maintains a persistent connection +to an environment server, enabling efficient multi-step interactions without +the overhead of HTTP request/response cycles. + +The client is async by default. For synchronous usage, use the `.sync()` method +to get a `SyncEnvClient` wrapper. + +Example (async): + >>> async with GenericEnvClient(base_url="ws://localhost:8000") as env: + ... result = await env.reset() + ... result = await env.step({"code": "print('hello')"}) + +Example (sync wrapper): + >>> env = GenericEnvClient(base_url="ws://localhost:8000").sync() + >>> with env: + ... result = env.reset() + ... result = env.step({"code": "print('hello')"}) +""" + +from __future__ import annotations + +import asyncio +import json +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar + +from .client_types import StateT, StepResult +from .containers.runtime import LocalDockerProvider, UVProvider +from .utils import convert_to_ws_url + +if TYPE_CHECKING: + from websockets.asyncio.client import ClientConnection + + from .containers.runtime import ContainerProvider, RuntimeProvider + from .sync_client import SyncEnvClient + +from websockets.asyncio.client import connect as ws_connect + +ActT = TypeVar("ActT") +ObsT = TypeVar("ObsT") +EnvClientT = TypeVar("EnvClientT", bound="EnvClient") + + +class EnvClient(ABC, Generic[ActT, ObsT, StateT]): + """ + Async environment client for persistent sessions. + + This client maintains a persistent WebSocket connection to an environment + server, enabling efficient multi-step interactions. Each client instance + corresponds to a dedicated environment session on the server. + + The client is async by default. For synchronous usage, use the `.sync()` + method to get a `SyncEnvClient` wrapper. + + Features: + - Lower latency for sequential interactions + - Session state is maintained server-side + - Better suited for long-running episodes + - Async by default for modern Python async/await patterns + + Example (async): + >>> from envs.coding_env.client import CodingEnv + >>> + >>> # Connect to a server using async context manager + >>> async with CodingEnv(base_url="ws://localhost:8000") as env: + ... result = await env.reset(seed=42) + ... while not result.done: + ... action = agent.predict(result.observation) + ... 
result = await env.step(action) + + Example (sync wrapper): + >>> env = CodingEnv(base_url="ws://localhost:8000").sync() + >>> with env: + ... result = env.reset(seed=42) + ... result = env.step(action) + """ + + def __init__( + self, + base_url: str, + connect_timeout_s: float = 10.0, + message_timeout_s: float = 60.0, + max_message_size_mb: float = 100.0, + provider: Optional["ContainerProvider | RuntimeProvider"] = None, + mode: Optional[str] = None, + ): + """ + Initialize environment client. + + Args: + base_url: Base URL of the environment server (http:// or ws://). + Will be converted to ws:// if http:// is provided. + connect_timeout_s: Timeout for establishing WebSocket connection + message_timeout_s: Timeout for receiving responses to messages + max_message_size_mb: Maximum WebSocket message size in megabytes. + Default 100MB to handle large observations (screenshots, DOM, etc.) + provider: Optional container/runtime provider for lifecycle management. + Can be a ContainerProvider (Docker) or RuntimeProvider (UV). + mode: Communication mode: 'simulation' for Gym-style API (default) or + 'production' for MCP JSON-RPC protocol. Can also be set via the + OPENENV_CLIENT_MODE environment variable. Constructor parameter + takes precedence over environment variable. Case-insensitive. + """ + # Determine mode (constructor > env var > default) + if mode is None: + mode = os.environ.get("OPENENV_CLIENT_MODE", "simulation") + + # Normalize and validate mode + mode = mode.lower() + if mode not in ("simulation", "production"): + raise ValueError( + f"Invalid mode: '{mode}'. Must be 'simulation' or 'production'. " + f"Set via constructor parameter or OPENENV_CLIENT_MODE environment variable." + ) + + # Store mode (use object.__setattr__ to bypass immutability) + object.__setattr__(self, "_mode", mode) + + # Convert HTTP URL to WebSocket URL + ws_url = convert_to_ws_url(base_url) + + self._ws_url = f"{ws_url}/ws" + self._connect_timeout = connect_timeout_s + self._message_timeout = message_timeout_s + self._max_message_size = int( + max_message_size_mb * 1024 * 1024 + ) # Convert MB to bytes + self._provider = provider + self._ws: Optional[ClientConnection] = None + + def __setattr__(self, name: str, value: Any) -> None: + """Prevent modification of _mode after initialization.""" + if name == "_mode" and hasattr(self, "_mode"): + raise AttributeError("Cannot modify mode after initialization") + super().__setattr__(name, value) + + async def connect(self) -> "EnvClient": + """ + Establish WebSocket connection to the server. 
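+        The client connects to ``{base_url}/ws``; ``http(s)://`` URLs are converted to WebSocket URLs automatically, and proxies are bypassed for localhost servers.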
+ + Returns: + self for method chaining + + Raises: + ConnectionError: If connection cannot be established + """ + if self._ws is not None: + return self + + # Bypass proxy for localhost connections + ws_url_lower = self._ws_url.lower() + is_localhost = "localhost" in ws_url_lower or "127.0.0.1" in ws_url_lower + + old_no_proxy = os.environ.get("NO_PROXY") + if is_localhost: + # Set NO_PROXY to bypass proxy for localhost + current_no_proxy = old_no_proxy or "" + if "localhost" not in current_no_proxy.lower(): + os.environ["NO_PROXY"] = ( + f"{current_no_proxy},localhost,127.0.0.1" + if current_no_proxy + else "localhost,127.0.0.1" + ) + + try: + self._ws = await ws_connect( + self._ws_url, + open_timeout=self._connect_timeout, + max_size=self._max_message_size, + ) + except Exception as e: + raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e + finally: + # Restore original NO_PROXY value + if is_localhost: + if old_no_proxy is None: + os.environ.pop("NO_PROXY", None) + else: + os.environ["NO_PROXY"] = old_no_proxy + + return self + + async def disconnect(self) -> None: + """Close the WebSocket connection.""" + if self._ws is not None: + try: + # Send close message + await self._send({"type": "close"}) + except Exception: + pass # Best effort + try: + await self._ws.close() + except Exception: + pass + self._ws = None + + async def _ensure_connected(self) -> None: + """Ensure WebSocket connection is established.""" + if self._ws is None: + await self.connect() + + async def _send(self, message: Dict[str, Any]) -> None: + """Send a message over the WebSocket.""" + await self._ensure_connected() + assert self._ws is not None + await self._ws.send(json.dumps(message)) + + async def _receive(self) -> Dict[str, Any]: + """Receive and parse a message from the WebSocket.""" + assert self._ws is not None + raw = await asyncio.wait_for(self._ws.recv(), timeout=self._message_timeout) + return json.loads(raw) + + async def _send_and_receive(self, message: Dict[str, Any]) -> Dict[str, Any]: + """Send a message and wait for response.""" + await self._send(message) + response = await self._receive() + + # Check for error response + if response.get("type") == "error": + error_data = response.get("data", {}) + raise RuntimeError( + f"Server error: {error_data.get('message', 'Unknown error')} " + f"(code: {error_data.get('code', 'UNKNOWN')})" + ) + + return response + + @classmethod + async def from_docker_image( + cls: Type[EnvClientT], + image: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> EnvClientT: + """ + Create an environment client by spinning up a Docker container. + + Args: + image: Docker image name to run (e.g., "coding-env:latest") + provider: Container provider to use (defaults to LocalDockerProvider) + **kwargs: Additional arguments to pass to provider.start_container() + + Returns: + Connected client instance + """ + if provider is None: + provider = LocalDockerProvider() + + # Start container + base_url = provider.start_container(image, **kwargs) + + # Wait for server to be ready + provider.wait_for_ready(base_url) + + # Create and connect client + client = cls(base_url=base_url, provider=provider) + await client.connect() + + return client + + @classmethod + async def from_env( + cls: Type[EnvClientT], + repo_id: str, + *, + use_docker: bool = True, + provider: Optional["ContainerProvider | RuntimeProvider"] = None, + **provider_kwargs: Any, + ) -> EnvClientT: + """ + Create a client from a Hugging Face Space. 
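+        In Docker mode the Space image is resolved as ``registry.hf.space/{org}-{space}:{tag}``; in UV mode the Space (or a local checkout) is run via ``uv run``.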
+ + Args: + repo_id: Hugging Face space identifier ``{org}/{space}``. + use_docker: When ``True`` (default) pull from the HF registry and + launch via :class:`LocalDockerProvider`. When ``False`` run the + space locally with :class:`UVProvider`. + provider: Optional provider instance to reuse. Must be a + :class:`ContainerProvider` when ``use_docker=True`` and a + :class:`RuntimeProvider` otherwise. + provider_kwargs: Additional keyword arguments forwarded to + either the container provider's ``start_container`` (docker) + or to the ``UVProvider`` constructor/start (uv). When + ``use_docker=False``, the ``project_path`` argument can be + used to override the default git URL + (``git+https://huggingface.co/spaces/{repo_id}``). + + Returns: + Connected client instance + + Examples: + >>> # Pull and run from HF Docker registry + >>> env = await MyEnv.from_env("openenv/echo-env") + >>> + >>> # Run locally with UV (clones the space) + >>> env = await MyEnv.from_env("openenv/echo-env", use_docker=False) + >>> + >>> # Run from a local checkout + >>> env = await MyEnv.from_env( + ... "openenv/echo-env", + ... use_docker=False, + ... project_path="/path/to/local/checkout" + ... ) + """ + # Extract start args that apply to both providers + start_args = {} + for key in ("port", "env_vars", "workers"): + if key in provider_kwargs: + start_args[key] = provider_kwargs.pop(key) + + if use_docker: + # Docker mode: pull from HF registry + docker_provider = provider or LocalDockerProvider() + tag = provider_kwargs.pop("tag", "latest") + image = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" + base_url = docker_provider.start_container( + image, **start_args, **provider_kwargs + ) + docker_provider.wait_for_ready(base_url) + + client = cls(base_url=base_url, provider=docker_provider) + await client.connect() + return client + else: + # UV mode: clone and run with uv + if provider is None: + uv_kwargs = dict(provider_kwargs) + project_path = uv_kwargs.pop("project_path", None) + if project_path is None: + project_path = f"git+https://huggingface.co/spaces/{repo_id}" + + provider = UVProvider(project_path=project_path, **uv_kwargs) + else: + if provider_kwargs: + raise ValueError( + "provider_kwargs cannot be used when supplying a provider instance" + ) + + base_url = provider.start(**start_args) + provider.wait_for_ready() + + client = cls(base_url=base_url, provider=provider) + await client.connect() + return client + + @abstractmethod + def _step_payload(self, action: ActT) -> Dict[str, Any]: + """Convert an Action object to the JSON data expected by the env server.""" + raise NotImplementedError + + @abstractmethod + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: + """Convert a JSON response from the env server to StepResult[ObsT].""" + raise NotImplementedError + + @abstractmethod + def _parse_state(self, payload: Dict[str, Any]) -> StateT: + """Convert a JSON response from the state endpoint to a State object.""" + raise NotImplementedError + + async def reset(self, **kwargs: Any) -> StepResult[ObsT]: + """ + Reset the environment with optional parameters. + + Args: + **kwargs: Optional parameters passed to the environment's reset method. 
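+                They are forwarded verbatim as the ``data`` payload of the WebSocket ``reset`` message.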
+ Common parameters include: + - seed: Random seed for reproducibility + - episode_id: Custom episode identifier + + Returns: + StepResult containing initial observation + """ + message = { + "type": "reset", + "data": kwargs, + } + response = await self._send_and_receive(message) + return self._parse_result(response.get("data", {})) + + async def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: + """ + Execute an action in the environment. + + Args: + action: The action to execute + **kwargs: Optional parameters (currently ignored) + + Returns: + StepResult containing observation, reward, and done status + """ + message = { + "type": "step", + "data": self._step_payload(action), + } + response = await self._send_and_receive(message) + return self._parse_result(response.get("data", {})) + + async def state(self) -> StateT: + """ + Get the current environment state from the server. + + Returns: + State object with environment state information + """ + message = {"type": "state"} + response = await self._send_and_receive(message) + return self._parse_state(response.get("data", {})) + + async def close(self) -> None: + """ + Close the WebSocket connection and clean up resources. + + If this client was created via from_docker_image() or from_env(), + this will also stop and remove the associated container/process. + """ + await self.disconnect() + + if self._provider is not None: + # Handle both ContainerProvider and RuntimeProvider + if hasattr(self._provider, "stop_container"): + self._provider.stop_container() + elif hasattr(self._provider, "stop"): + self._provider.stop() + + async def __aenter__(self) -> "EnvClient": + """Enter async context manager, ensuring connection is established.""" + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit async context manager, closing connection.""" + await self.close() + + def __enter__(self) -> "EnvClient": + """Sync context manager entry - raises error suggesting async usage.""" + raise TypeError( + "EnvClient is async by default. Use 'async with' instead of 'with', " + "or call .sync() to get a synchronous wrapper:\n" + " async with client: # async usage\n" + " with client.sync(): # sync wrapper" + ) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Sync context manager exit - should not be reached.""" + pass # pragma: no cover + + def sync(self) -> "SyncEnvClient": + """ + Return a synchronous wrapper around this async client. + + Use this method when you need synchronous access to the environment + without async/await syntax. This is useful for: + - Integration with synchronous codebases + - Interactive/REPL usage + - Stopping async from "infecting" the call stack + + Returns: + SyncEnvClient wrapper that provides synchronous methods + + Example: + >>> # Create async client and get sync wrapper + >>> async_client = GenericEnvClient(base_url="http://localhost:8000") + >>> sync_client = async_client.sync() + >>> + >>> # Use synchronous API + >>> with sync_client: + ... result = sync_client.reset() + ... result = sync_client.step({"code": "print('hello')"}) + """ + from .sync_client import SyncEnvClient + + return SyncEnvClient(self) diff --git a/src/core/openenv/core/env_server/__init__.py b/src/core/openenv/core/env_server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0f1f2845f09ec758c1fcedb16dbb771059156b --- /dev/null +++ b/src/core/openenv/core/env_server/__init__.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Core environment interfaces and types.""" + +from .base_transforms import CompositeTransform, NullTransform +from .exceptions import ( + ConcurrencyConfigurationError, + EnvironmentFactoryError, + OpenEnvError, + SessionCapacityError, + SessionCreationError, + SessionNotFoundError, +) +from .http_server import create_app, create_fastapi_app, HTTPEnvServer +from .interfaces import Environment, Message, ModelTokenizer, Transform + +try: + from .mcp_environment import MCPEnvironment +except ModuleNotFoundError: + MCPEnvironment = None # type: ignore[assignment] + +from .mcp_types import ( + CallToolAction, + CallToolObservation, + JsonRpcError, + # JSON-RPC types + JsonRpcErrorCode, + JsonRpcRequest, + JsonRpcResponse, + ListToolsAction, + ListToolsObservation, + McpMethod, + RESERVED_TOOL_NAMES, + Tool, + ToolError, + ToolErrorType, + WSMCPMessage, + WSMCPResponse, +) +from .route_config import GetEndpointConfig +from .serialization import ( + deserialize_action, + deserialize_action_with_preprocessing, + serialize_observation, +) +from .types import ( + Action, + BaseMessage, + ConcurrencyConfig, + HealthResponse, + HealthStatus, + Observation, + SchemaResponse, + ServerCapacityStatus, + ServerMode, + SessionInfo, + State, + WSCloseMessage, + WSErrorCode, + WSErrorResponse, + WSIncomingMessage, + WSObservationResponse, + WSResetMessage, + WSStateMessage, + WSStateResponse, + WSStepMessage, +) + +try: + from .web_interface import create_web_interface_app, WebInterfaceManager +except ModuleNotFoundError: + create_web_interface_app = None # type: ignore[assignment] + WebInterfaceManager = None # type: ignore[assignment] + +__all__ = [ + # Core interfaces + "Environment", + "Transform", + "Message", + "ModelTokenizer", + # Types + "Action", + "Observation", + "State", + "SchemaResponse", + "HealthResponse", + # Enums + "HealthStatus", + "ServerMode", + "WSErrorCode", + # WebSocket message types + "BaseMessage", + "WSIncomingMessage", + "WSResetMessage", + "WSStepMessage", + "WSStateMessage", + "WSCloseMessage", + "WSObservationResponse", + "WSStateResponse", + "WSErrorResponse", + # Concurrency types + "ConcurrencyConfig", + "ServerCapacityStatus", + "SessionInfo", + # Exceptions + "OpenEnvError", + "ConcurrencyConfigurationError", + "SessionCapacityError", + "SessionNotFoundError", + "SessionCreationError", + "EnvironmentFactoryError", + # Base transforms + "CompositeTransform", + "NullTransform", + # HTTP Server + "HTTPEnvServer", + "create_app", + "create_fastapi_app", + # Web Interface + "create_web_interface_app", + "WebInterfaceManager", + # Serialization utilities + "deserialize_action", + "deserialize_action_with_preprocessing", + "serialize_observation", + # Route configuration + "GetEndpointConfig", + # MCP types + "Tool", + "ToolError", + "ToolErrorType", + "ListToolsAction", + "CallToolAction", + "ListToolsObservation", + "CallToolObservation", + "WSMCPMessage", + "WSMCPResponse", + "RESERVED_TOOL_NAMES", + "MCPEnvironment", + # JSON-RPC types + "JsonRpcErrorCode", + "JsonRpcError", + "JsonRpcRequest", + "JsonRpcResponse", + "McpMethod", +] diff --git a/src/core/openenv/core/env_server/base_transforms.py b/src/core/openenv/core/env_server/base_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..ab48ebb48b58962ff56d282713a1d63907b0f390 --- /dev/null +++ 
b/src/core/openenv/core/env_server/base_transforms.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Base transform implementations for composing environment-specific transforms.""" + +from .interfaces import Transform +from .types import Observation + + +class CompositeTransform(Transform): + """Combines multiple transforms into a single transform.""" + + def __init__(self, transforms: list[Transform]): + self.transforms = transforms + + def __call__(self, observation: Observation) -> Observation: + for transform in self.transforms: + observation = transform(observation) + return observation + + +class NullTransform(Transform): + """Default transform that passes through unchanged.""" + + def __call__(self, observation: Observation) -> Observation: + return observation diff --git a/src/core/openenv/core/env_server/exceptions.py b/src/core/openenv/core/env_server/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..5701913e0bcac67e6f84d3861d57c4949665677a --- /dev/null +++ b/src/core/openenv/core/env_server/exceptions.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Custom exceptions for environment server operations.""" + +from typing import Optional + + +class OpenEnvError(Exception): + """Base exception for all OpenEnv errors.""" + + pass + + +class ConcurrencyConfigurationError(OpenEnvError): + """ + Raised when an environment is misconfigured for concurrent sessions. + + This error is raised during server startup when max_concurrent_envs > 1 + is specified for an environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. + """ + + def __init__( + self, + environment_name: str, + max_concurrent_envs: int, + message: Optional[str] = None, + ): + self.environment_name = environment_name + self.max_concurrent_envs = max_concurrent_envs + + if message is None: + message = ( + f"Environment '{environment_name}' is not marked as SUPPORTS_CONCURRENT_SESSIONS. " + f"Cannot run with max_concurrent_envs={max_concurrent_envs}. " + f"Either set max_concurrent_envs=1 or ensure the environment " + f"properly isolates session state and set SUPPORTS_CONCURRENT_SESSIONS=True." + ) + + super().__init__(message) + + +class SessionCapacityError(OpenEnvError): + """ + Raised when the server cannot accept new sessions due to capacity limits. + + This error is raised when a new WebSocket connection is attempted but + the server has already reached max_concurrent_envs active sessions. + """ + + def __init__( + self, + active_sessions: int, + max_sessions: int, + message: Optional[str] = None, + ): + self.active_sessions = active_sessions + self.max_sessions = max_sessions + + if message is None: + message = ( + f"Server at capacity: {active_sessions}/{max_sessions} sessions active. " + f"Cannot accept new connections." + ) + + super().__init__(message) + + +class SessionNotFoundError(OpenEnvError): + """Raised when attempting to access a session that does not exist.""" + + def __init__(self, session_id: str, message: Optional[str] = None): + self.session_id = session_id + + if message is None: + message = f"Session '{session_id}' not found." 
+ + super().__init__(message) + + +class SessionCreationError(OpenEnvError): + """Raised when a session cannot be created.""" + + def __init__(self, reason: str, message: Optional[str] = None): + self.reason = reason + + if message is None: + message = f"Failed to create session: {reason}" + + super().__init__(message) + + +class EnvironmentFactoryError(OpenEnvError): + """Raised when the environment factory fails to create an instance.""" + + def __init__(self, factory_name: str, message: Optional[str] = None): + self.factory_name = factory_name + + if message is None: + message = f"Environment factory '{factory_name}' failed to create instance." + + super().__init__(message) diff --git a/src/core/openenv/core/env_server/gradio_theme.py b/src/core/openenv/core/env_server/gradio_theme.py new file mode 100644 index 0000000000000000000000000000000000000000..7cebea2284d8d19e41d5954b498bcc3bb7ff39a4 --- /dev/null +++ b/src/core/openenv/core/env_server/gradio_theme.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unified terminal-style theme for OpenEnv Gradio UI (light/dark).""" + +from __future__ import annotations + +import gradio as gr + +_MONO_FONTS = ( + "JetBrains Mono", + "Fira Code", + "Cascadia Code", + "Consolas", + "ui-monospace", + "monospace", +) + +_CORE_FONT = ( + "Lato", + "Inter", + "Arial", + "Helvetica", + "sans-serif", +) + +_ZERO_RADIUS = gr.themes.Size( + xxs="0px", + xs="0px", + sm="0px", + md="0px", + lg="0px", + xl="0px", + xxl="0px", +) + +_GREEN_HUE = gr.themes.Color( + c50="#e6f4ea", + c100="#ceead6", + c200="#a8dab5", + c300="#6fcc8b", + c400="#3fb950", + c500="#238636", + c600="#1a7f37", + c700="#116329", + c800="#0a4620", + c900="#033a16", + c950="#04200d", +) + +_NEUTRAL_HUE = gr.themes.Color( + c50="#f6f8fa", + c100="#eaeef2", + c200="#d0d7de", + c300="#afb8c1", + c400="#8c959f", + c500="#6e7781", + c600="#57606a", + c700="#424a53", + c800="#32383f", + c900="#24292f", + c950="#1b1f24", +) + +OPENENV_GRADIO_THEME = gr.themes.Base( + primary_hue=_GREEN_HUE, + secondary_hue=_NEUTRAL_HUE, + neutral_hue=_NEUTRAL_HUE, + font=_CORE_FONT, + font_mono=_MONO_FONTS, + radius_size=_ZERO_RADIUS, +).set( + body_background_fill="#ffffff", + background_fill_primary="#ffffff", + background_fill_secondary="#f6f8fa", + block_background_fill="#ffffff", + block_border_color="#ffffff", + block_label_text_color="#57606a", + block_title_text_color="#24292f", + border_color_primary="#d0d7de", + input_background_fill="#ffffff", + input_border_color="#d0d7de", + button_primary_background_fill="#1a7f37", + button_primary_background_fill_hover="#116329", + button_primary_text_color="#ffffff", + button_secondary_background_fill="#f6f8fa", + button_secondary_background_fill_hover="#eaeef2", + button_secondary_text_color="#24292f", + button_secondary_border_color="#d0d7de", + body_background_fill_dark="#0d1117", + background_fill_primary_dark="#0d1117", + background_fill_secondary_dark="#0d1117", + block_background_fill_dark="#0d1117", + block_border_color_dark="#0d1117", + block_label_text_color_dark="#8b949e", + block_title_text_color_dark="#c9d1d9", + border_color_primary_dark="#30363d", + input_background_fill_dark="#0d1117", + input_border_color_dark="#30363d", + button_primary_background_fill_dark="#30363d", + button_primary_background_fill_hover_dark="#484f58", + 
button_primary_text_color_dark="#c9d1d9", + button_secondary_background_fill_dark="#21262d", + button_secondary_background_fill_hover_dark="#30363d", + button_secondary_text_color_dark="#c9d1d9", + button_secondary_border_color_dark="#30363d", +) + +OPENENV_GRADIO_CSS = """ +* { border-radius: 0 !important; } +.col-left { padding: 16px !important; } +.col-right { padding: 16px !important; } +.prose, .markdown-text, .md, +.prose > *, .markdown-text > * { + background: transparent !important; + border: none !important; + box-shadow: none !important; +} +.dark .col-left { + border-left-color: rgba(139, 148, 158, 0.4) !important; +} +.dark .col-right { + border-left-color: rgba(201, 209, 217, 0.3) !important; +} +""" diff --git a/src/core/openenv/core/env_server/gradio_ui.py b/src/core/openenv/core/env_server/gradio_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..dc1a630bd1db39588304b42520f08bb45f477e81 --- /dev/null +++ b/src/core/openenv/core/env_server/gradio_ui.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Gradio-based web UI for OpenEnv environments. + +Replaces the legacy HTML/JavaScript interface when ENABLE_WEB_INTERFACE is set. +Mount at /web via gr.mount_gradio_app() from create_web_interface_app(). +""" + +from __future__ import annotations + +import json +import re +from typing import Any, Dict, List, Optional + +import gradio as gr + +from .types import EnvironmentMetadata + + +def _escape_md(text: str) -> str: + """Escape Markdown special characters in user-controlled content.""" + return re.sub(r"([\\`*_\{\}\[\]()#+\-.!|~>])", r"\\\1", str(text)) + + +def _format_observation(data: Dict[str, Any]) -> str: + """Format reset/step response for Markdown display.""" + lines: List[str] = [] + obs = data.get("observation", {}) + if isinstance(obs, dict): + if obs.get("prompt"): + lines.append(f"**Prompt:**\n\n{_escape_md(obs['prompt'])}\n") + messages = obs.get("messages", []) + if messages: + lines.append("**Messages:**\n") + for msg in messages: + sender = _escape_md(str(msg.get("sender_id", "?"))) + content = _escape_md(str(msg.get("content", ""))) + cat = _escape_md(str(msg.get("category", ""))) + lines.append(f"- `[{cat}]` Player {sender}: {content}") + lines.append("") + reward = data.get("reward") + done = data.get("done") + if reward is not None: + lines.append(f"**Reward:** `{reward}`") + if done is not None: + lines.append(f"**Done:** `{done}`") + return "\n".join(lines) if lines else "*No observation data*" + + +def _readme_section(metadata: Optional[EnvironmentMetadata]) -> str: + """README content for the left panel.""" + if not metadata or not metadata.readme_content: + return "*No README available.*" + return metadata.readme_content + + +def get_gradio_display_title( + metadata: Optional[EnvironmentMetadata], + fallback: str = "OpenEnv Environment", +) -> str: + """Return the title used for the Gradio app (browser tab and Blocks).""" + name = metadata.name if metadata else fallback + return f"OpenEnv Agentic Environment: {name}" + + +def build_gradio_app( + web_manager: Any, + action_fields: List[Dict[str, Any]], + metadata: Optional[EnvironmentMetadata], + is_chat_env: bool, + title: str = "OpenEnv Environment", + quick_start_md: Optional[str] = None, +) -> gr.Blocks: + """ + Build a Gradio Blocks app for the OpenEnv web interface. 
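+    The layout pairs README/Quick Start accordions on the left with a playground column (Reset, Step, Get state) on the right.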
+
+    Args:
+        web_manager: WebInterfaceManager (reset/step_environment, get_state).
+        action_fields: Field dicts from _extract_action_fields(action_cls).
+        metadata: Environment metadata for README/name.
+        is_chat_env: If True, single message textbox; else form from action_fields.
+        title: App title (overridden by metadata.name when present; see get_gradio_display_title).
+        quick_start_md: Optional Quick Start markdown (class names already replaced).
+
+    Returns:
+        gr.Blocks to mount with gr.mount_gradio_app(app, blocks, path="/web").
+    """
+    readme_content = _readme_section(metadata)
+    display_title = get_gradio_display_title(metadata, fallback=title)
+
+    async def reset_env():
+        try:
+            data = await web_manager.reset_environment()
+            obs_md = _format_observation(data)
+            return (
+                obs_md,
+                json.dumps(data, indent=2),
+                "Environment reset successfully.",
+            )
+        except Exception as e:
+            return ("", "", f"Error: {e}")
+
+    def _step_with_action(action_data: Dict[str, Any]):
+        async def _run():
+            try:
+                data = await web_manager.step_environment(action_data)
+                obs_md = _format_observation(data)
+                return (
+                    obs_md,
+                    json.dumps(data, indent=2),
+                    "Step complete.",
+                )
+            except Exception as e:
+                return ("", "", f"Error: {e}")
+
+        return _run
+
+    async def step_chat(message: str):
+        # Reject empty and whitespace-only input.
+        if not (message and str(message).strip()):
+            return ("", "", "Please enter an action message.")
+        action = {"message": str(message).strip()}
+        return await _step_with_action(action)()
+
+    def get_state_sync():
+        try:
+            data = web_manager.get_state()
+            return json.dumps(data, indent=2)
+        except Exception as e:
+            return f"Error: {e}"
+
+    with gr.Blocks(title=display_title) as demo:
+        with gr.Row():
+            with gr.Column(scale=1, elem_classes="col-left"):
+                if quick_start_md:
+                    with gr.Accordion("Quick Start", open=True):
+                        gr.Markdown(quick_start_md)
+                with gr.Accordion("README", open=False):
+                    gr.Markdown(readme_content)
+
+            with gr.Column(scale=2, elem_classes="col-right"):
+                obs_display = gr.Markdown(
+                    value=("# Playground\n\nClick **Reset** to start a new episode."),
+                )
+                with gr.Group():
+                    if is_chat_env:
+                        action_input = gr.Textbox(
+                            label="Action message",
+                            placeholder="e.g. 
Enter your message...", + ) + step_inputs = [action_input] + step_fn = step_chat + else: + step_inputs = [] + for field in action_fields: + name = field["name"] + field_type = field.get("type", "text") + label = name.replace("_", " ").title() + placeholder = field.get("placeholder", "") + if field_type == "checkbox": + inp = gr.Checkbox(label=label) + elif field_type == "number": + inp = gr.Number(label=label) + elif field_type == "select": + choices = field.get("choices") or [] + inp = gr.Dropdown( + choices=choices, + label=label, + allow_custom_value=False, + ) + elif field_type in ("textarea", "tensor"): + inp = gr.Textbox( + label=label, + placeholder=placeholder, + lines=3, + ) + else: + inp = gr.Textbox( + label=label, + placeholder=placeholder, + ) + step_inputs.append(inp) + + async def step_form(*values): + if not action_fields: + return await _step_with_action({})() + action_data = {} + for i, field in enumerate(action_fields): + if i >= len(values): + break + name = field["name"] + val = values[i] + if field.get("type") == "checkbox": + action_data[name] = bool(val) + elif val is not None and val != "": + action_data[name] = val + return await _step_with_action(action_data)() + + step_fn = step_form + + with gr.Row(): + step_btn = gr.Button("Step", variant="primary") + reset_btn = gr.Button("Reset", variant="secondary") + state_btn = gr.Button("Get state", variant="secondary") + with gr.Row(): + status = gr.Textbox( + label="Status", + interactive=False, + ) + raw_json = gr.Code( + label="Raw JSON response", + language="json", + interactive=False, + ) + + reset_btn.click( + fn=reset_env, + outputs=[obs_display, raw_json, status], + ) + step_btn.click( + fn=step_fn, + inputs=step_inputs, + outputs=[obs_display, raw_json, status], + ) + if is_chat_env: + action_input.submit( + fn=step_fn, + inputs=step_inputs, + outputs=[obs_display, raw_json, status], + ) + state_btn.click( + fn=get_state_sync, + outputs=[raw_json], + ) + + return demo diff --git a/src/core/openenv/core/env_server/http_server.py b/src/core/openenv/core/env_server/http_server.py new file mode 100644 index 0000000000000000000000000000000000000000..658f63ef98bf78d278b8926271c217da23c79a37 --- /dev/null +++ b/src/core/openenv/core/env_server/http_server.py @@ -0,0 +1,1391 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +HTTP server wrapper for Environment instances. + +This module provides utilities to wrap any Environment subclass and expose it +over HTTP and WebSocket endpoints that EnvClient can consume. 
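+A wrapped environment exposes ``/health``, ``/reset``, ``/step``, and ``/state`` routes plus a ``/ws`` WebSocket endpoint for persistent sessions.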
+""" + +from __future__ import annotations + +import asyncio +import inspect +import json +import os +import time +import uuid +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Callable, Dict, Optional, Type + +from fastapi import ( + Body, + FastAPI, + HTTPException, + Request, + status, + WebSocket, + WebSocketDisconnect, +) +from pydantic import ValidationError + +from .interfaces import Environment +from .mcp_environment import get_server_tools +from .mcp_types import ( + JsonRpcErrorCode, + JsonRpcRequest, + JsonRpcResponse, + McpMethod, + WSMCPMessage, + WSMCPResponse, +) +from .route_config import GetEndpointConfig, register_get_endpoints +from .serialization import deserialize_action, serialize_observation +from .types import ( + Action, + ConcurrencyConfig, + EnvironmentMetadata, + HealthResponse, + HealthStatus, + Observation, + ResetRequest, + ResetResponse, + SchemaResponse, + ServerCapacityStatus, + ServerMode, + SessionInfo, + State, + StepRequest, + StepResponse, + WSCloseMessage, + WSErrorCode, + WSErrorResponse, + WSObservationResponse, + WSResetMessage, + WSStateMessage, + WSStateResponse, + WSStepMessage, +) + + +def _make_json_serializable(obj: Any) -> Any: + """ + Convert an object to a JSON-serializable form. + + Handles Pydantic models, dataclasses, and other common types. + + Args: + obj: The object to convert + + Returns: + A JSON-serializable representation of the object + """ + if obj is None: + return None + if isinstance(obj, (str, int, float, bool)): + return obj + if isinstance(obj, (list, tuple)): + return [_make_json_serializable(item) for item in obj] + if isinstance(obj, dict): + return {k: _make_json_serializable(v) for k, v in obj.items()} + if hasattr(obj, "model_dump"): + # Pydantic model + return obj.model_dump() + if hasattr(obj, "__dict__"): + # Object with __dict__ + return {k: _make_json_serializable(v) for k, v in obj.__dict__.items()} + # Fallback to string representation + return str(obj) + + +from .exceptions import ( + ConcurrencyConfigurationError, + EnvironmentFactoryError, + SessionCapacityError, +) + + +class HTTPEnvServer: + """ + HTTP server wrapper for Environment instances. + + This class wraps an Environment and exposes its reset(), step(), and state + methods as HTTP and WebSocket endpoints compatible with EnvClient. + + The server expects: + - Action deserialization: Converts JSON dict to Action subclass + - Observation serialization: Converts Observation subclass to JSON dict + + Example: + >>> from core.env_server import HTTPEnvServer + >>> from envs.coding_env.server import CodeExecutionEnvironment + >>> from envs.coding_env.models import CodeAction, CodeObservation + >>> + >>> # Pass environment class (factory pattern) + >>> server = HTTPEnvServer( + ... env=CodeExecutionEnvironment, + ... action_cls=CodeAction, + ... observation_cls=CodeObservation, + ... max_concurrent_envs=4, + ... ) + >>> + >>> # Register routes with FastAPI + >>> from fastapi import FastAPI + >>> app = FastAPI() + >>> server.register_routes(app) + """ + + def __init__( + self, + env: Callable[[], Environment], + action_cls: Type[Action], + observation_cls: Type[Observation], + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, + ): + """ + Initialize HTTP server wrapper. + + Args: + env: Environment factory (callable) that creates new instances. + Will be called to create a new environment for each WebSocket session. 
+ action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + max_concurrent_envs: Maximum number of concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. + + Raises: + ValueError: If both max_concurrent_envs and concurrency_config are provided. + ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an + environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. + """ + # Validate that env is callable + if not callable(env): + raise TypeError( + f"env must be a callable (class or factory function), got {type(env)}. " + f"Pass the environment class (e.g., MyEnvironment) not an instance (e.g., MyEnvironment())." + ) + + self._env_factory: Callable[[], Environment] = env + + # Handle concurrency configuration + if max_concurrent_envs is not None and concurrency_config is not None: + raise ValueError( + "Cannot specify both 'max_concurrent_envs' and 'concurrency_config'. " + "Please use only one method to configure concurrency." + ) + + if concurrency_config is not None: + self._concurrency_config = concurrency_config + elif max_concurrent_envs is not None: + self._concurrency_config = ConcurrencyConfig( + max_concurrent_envs=max_concurrent_envs, + session_timeout=None, + ) + else: + # Default configuration + self._concurrency_config = ConcurrencyConfig( + max_concurrent_envs=1, + session_timeout=None, + ) + + self._max_concurrent_envs = self._concurrency_config.max_concurrent_envs + + # Validate concurrency configuration + self._validate_concurrency_safety() + + self.action_cls = action_cls + self.observation_cls = observation_cls + + # Session management for WebSocket connections + self._sessions: Dict[str, Environment] = {} + self._session_executors: Dict[str, ThreadPoolExecutor] = {} + self._session_info: Dict[str, SessionInfo] = {} + self._session_lock = asyncio.Lock() + + # Create thread pool for running sync code in async context + # This is needed for environments using sync libraries (e.g., Playwright) + self._executor = ThreadPoolExecutor(max_workers=32) + + def _validate_concurrency_safety(self) -> None: + """ + Validate that the environment supports the configured concurrency level. + + Raises: + ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an + environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. + """ + if self._max_concurrent_envs <= 1: + return + + if inspect.isclass(self._env_factory): + env_cls = self._env_factory + else: + _temp_env = self._env_factory() + env_cls = type(_temp_env) + _temp_env.close() + del _temp_env + + if not getattr(env_cls, "SUPPORTS_CONCURRENT_SESSIONS", False): + raise ConcurrencyConfigurationError( + environment_name=env_cls.__name__, + max_concurrent_envs=self._max_concurrent_envs, + ) + + def get_capacity_status(self) -> ServerCapacityStatus: + """ + Get the current capacity status of the server. + + Returns: + ServerCapacityStatus with current session counts and availability. 
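+
+        Example (illustrative; the attribute names are assumed to mirror
+        the from_counts() arguments used below):
+            >>> status = server.get_capacity_status()   # doctest: +SKIP
+            >>> (status.active, status.max_sessions)    # doctest: +SKIP
+            (0, 1)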
+ """ + return ServerCapacityStatus.from_counts( + active=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) + + async def _run_sync_in_thread_pool( + self, func: Callable[..., Observation], *args, **kwargs + ) -> Observation: + """Run a synchronous function in the thread pool executor.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs)) + + def _get_valid_kwargs( + self, + sig: inspect.Signature, + kwargs: Dict[str, Any], + skip_params: Optional[set[str]] = None, + ) -> Dict[str, Any]: + """Filter kwargs to only include parameters accepted by the function signature.""" + if skip_params is None: + skip_params = set() + + valid_kwargs = {} + + has_kwargs = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + + for k, v in kwargs.items(): + if k in sig.parameters or has_kwargs: + if k not in skip_params: + valid_kwargs[k] = v + + return valid_kwargs + + async def _create_session(self) -> tuple[str, Environment]: + """ + Create a new WebSocket session with its own environment instance. + + Returns: + Tuple of (session_id, environment) + + Raises: + SessionCapacityError: If max concurrent sessions reached + EnvironmentFactoryError: If the factory fails to create an environment + """ + async with self._session_lock: + if len(self._sessions) >= self._max_concurrent_envs: + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) + + session_id = str(uuid.uuid4()) + current_time = time.time() + + # Create executor and reserve slot so capacity is not exceeded while + # we create the env outside the lock (avoids blocking other sessions) + executor = ThreadPoolExecutor(max_workers=1) + self._session_executors[session_id] = executor + self._sessions[session_id] = None # placeholder until env is ready + + try: + # Create environment in the executor thread (outside lock) + loop = asyncio.get_event_loop() + env = await loop.run_in_executor(executor, self._env_factory) + except Exception as e: + async with self._session_lock: + executor.shutdown(wait=False) + self._session_executors.pop(session_id, None) + self._sessions.pop(session_id, None) + factory_name = getattr( + self._env_factory, "__name__", str(self._env_factory) + ) + raise EnvironmentFactoryError(factory_name) from e + + async with self._session_lock: + self._sessions[session_id] = env + self._session_info[session_id] = SessionInfo( + session_id=session_id, + created_at=current_time, + last_activity_at=current_time, + step_count=0, + environment_type=type(env).__name__, + ) + + return session_id, env + + async def _destroy_session(self, session_id: str) -> None: + """ + Destroy a WebSocket session and cleanup resources. 
+ + Args: + session_id: The session ID to destroy + """ + async with self._session_lock: + env = self._sessions.pop(session_id, None) + executor = self._session_executors.pop(session_id, None) + self._session_info.pop(session_id, None) + + # Run close() in the same executor where the env was created + # This is required for thread-sensitive libraries like Playwright/greenlet + if env is not None: + if executor is not None: + try: + loop = asyncio.get_event_loop() + await loop.run_in_executor(executor, env.close) + except Exception: + # If executor close fails, try direct close as fallback + try: + env.close() + except Exception: + pass # Best effort cleanup + else: + try: + env.close() + except Exception: + pass # Best effort cleanup + + # Shutdown executor after close is done + if executor is not None: + executor.shutdown(wait=False) + + def _update_session_activity( + self, session_id: str, increment_step: bool = False + ) -> None: + """ + Update session activity timestamp and optionally increment step count. + + Args: + session_id: The session ID to update + increment_step: If True, increment the step count + """ + if session_id in self._session_info: + self._session_info[session_id].last_activity_at = time.time() + if increment_step: + self._session_info[session_id].step_count += 1 + + def get_session_info(self, session_id: str) -> Optional[SessionInfo]: + """ + Get information about a specific session. + + Args: + session_id: The session ID to query + + Returns: + SessionInfo if the session exists, None otherwise + """ + return self._session_info.get(session_id) + + async def _run_in_session_executor( + self, session_id: str, func: Callable[..., Observation], *args, **kwargs + ) -> Observation: + """Run a synchronous function in the session's thread pool executor.""" + executor = self._session_executors.get(session_id, self._executor) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(executor, lambda: func(*args, **kwargs)) + + @property + def active_sessions(self) -> int: + """Return the number of active WebSocket sessions.""" + return len(self._sessions) + + @property + def max_concurrent_envs(self) -> int: + """Return the maximum number of concurrent environments.""" + return self._max_concurrent_envs + + @property + def is_concurrency_safe(self) -> bool: + """Return whether the environment is marked as concurrency safe.""" + import inspect + + if inspect.isclass(self._env_factory): + return getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False) + else: + _temp_env = self._env_factory() + result = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False) + _temp_env.close() + del _temp_env + return result + + @property + def concurrency_config(self) -> ConcurrencyConfig: + """Return the concurrency configuration.""" + return self._concurrency_config + + def register_routes( + self, app: FastAPI, mode: ServerMode | str = ServerMode.SIMULATION + ) -> None: + """ + Register HTTP routes on a FastAPI application. + + Args: + app: FastAPI application instance + mode: Server mode - either SIMULATION or PRODUCTION (or string equivalents). + In production mode, simulation control endpoints (/reset, /step, /state) + are NOT registered. Only safe endpoints (/health, /schema, /metadata, /ws) + are available. Defaults to SIMULATION for backwards compatibility. + + Raises: + ValueError: If mode is not a valid ServerMode or string equivalent. 
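+
+        Example (an illustrative sketch; `server` is an HTTPEnvServer):
+            >>> from fastapi import FastAPI
+            >>> app = FastAPI()
+            >>> server.register_routes(app, mode="production")  # doctest: +SKIP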
+ """ + # Convert string to ServerMode enum for backwards compatibility + if isinstance(mode, str): + try: + mode = ServerMode(mode.lower()) + except ValueError: + valid_modes = [m.value for m in ServerMode] + raise ValueError( + f"Invalid mode: '{mode}'. Must be one of: {valid_modes}" + ) + + # Helper function to handle reset endpoint + async def reset_handler( + request: ResetRequest = Body(default_factory=ResetRequest), + ) -> ResetResponse: + """Reset endpoint - returns initial observation.""" + _env = self._env_factory() + + try: + kwargs = request.model_dump(exclude_unset=True) + + is_async = _env.reset_async.__func__ is not Environment.reset_async + + if is_async: + sig = inspect.signature(_env.reset_async) + else: + sig = inspect.signature(_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, kwargs) + + if is_async: + observation = await _env.reset_async(**valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool( + _env.reset, **valid_kwargs + ) + return ResetResponse(**serialize_observation(observation)) + finally: + _env.close() + + # Helper function to handle step endpoint + async def step_handler(request: StepRequest) -> StepResponse: + """Step endpoint - executes action and returns observation.""" + action_data = request.action + + try: + action = deserialize_action(action_data, self.action_cls) + except ValidationError as e: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors() + ) + + _env = self._env_factory() + + try: + kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) + + is_async = _env.step_async.__func__ is not Environment.step_async + + if is_async: + sig = inspect.signature(_env.step_async) + else: + sig = inspect.signature(_env.step) + valid_kwargs = self._get_valid_kwargs( + sig, kwargs, skip_params={"action"} + ) + + if is_async: + observation = await _env.step_async(action, **valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool( + _env.step, action, **valid_kwargs + ) + + return StepResponse(**serialize_observation(observation)) + finally: + _env.close() + + # Helper function to handle MCP endpoint + async def mcp_handler( + request: JsonRpcRequest, session_env: Optional[Environment] = None + ) -> JsonRpcResponse: + """ + Handle MCP JSON-RPC requests. + + Supports tools/list and tools/call methods in JSON-RPC 2.0 format. 
+ """ + method = request.method + request_id = request.id + + # Use provided session environment or create temporary one + if session_env is not None: + _env = session_env + should_close = False + else: + _env = self._env_factory() + should_close = True + try: + if method == McpMethod.TOOLS_LIST: + # Check if environment is MCP-enabled + if not hasattr(_env, "mcp_client"): + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "Environment does not support MCP", + request_id=request_id, + ) + + # Use async context manager for MCP client + async with _env.mcp_client: + tools = await _env.mcp_client.list_tools() + + return JsonRpcResponse.success( + result={ + "tools": [ + t.model_dump() if hasattr(t, "model_dump") else dict(t) + for t in tools + ] + }, + request_id=request_id, + ) + + elif method == McpMethod.TOOLS_CALL: + params = request.params + tool_name = params.get("name") + arguments = params.get("arguments", {}) + + if not hasattr(_env, "mcp_client"): + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "Environment does not support MCP", + request_id=request_id, + ) + + if not tool_name: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + "Missing 'name' in params", + request_id=request_id, + ) + + # Use async context manager for MCP client + async with _env.mcp_client: + result = await _env.mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + + # Ensure result is JSON serializable + serializable_result = _make_json_serializable(result) + + return JsonRpcResponse.success( + result=serializable_result, + request_id=request_id, + ) + + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.METHOD_NOT_FOUND, + f"Method not found: {method}", + request_id=request_id, + ) + + except Exception as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + request_id=request_id, + ) + finally: + if should_close: + _env.close() + + # Register MCP WebSocket endpoint (available in both production and simulation modes) + @app.websocket("/mcp") + async def mcp_websocket_endpoint(websocket: WebSocket): + """ + WebSocket endpoint for MCP JSON-RPC requests. + + Each WebSocket connection gets its own environment instance for MCP operations. 
+ + Message Protocol: + - Client sends: JSON-RPC 2.0 request (tools/list, tools/call) + - Server responds: JSON-RPC 2.0 response (result or error) + """ + await websocket.accept() + + session_id = None + session_env = None + + try: + # Create session with dedicated environment + session_id, session_env = await self._create_session() + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + jsonrpc_dict = json.loads(raw_message) + jsonrpc_request = JsonRpcRequest(**jsonrpc_dict) + except json.JSONDecodeError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.PARSE_ERROR, + f"Parse error: {e}", + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + except ValidationError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Invalid request: {e}", + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + try: + # Call mcp_handler with session environment + response = await mcp_handler( + jsonrpc_request, session_env=session_env + ) + await websocket.send_text(response.model_dump_json()) + except Exception as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + request_id=jsonrpc_request.id, + ) + await websocket.send_text(error_resp.model_dump_json()) + + except WebSocketDisconnect: + pass + except SessionCapacityError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + data={ + "active_sessions": e.active_sessions, + "max_sessions": e.max_sessions, + }, + ) + await websocket.send_text(error_resp.model_dump_json()) + except EnvironmentFactoryError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + data={"factory_name": e.factory_name}, + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + ) + await websocket.send_text(error_resp.model_dump_json()) + finally: + if session_id: + await self._destroy_session(session_id) + try: + await websocket.close() + except RuntimeError: + pass + + # Register simulation control routes only in simulation mode + if mode == ServerMode.SIMULATION: + + @app.post( + "/reset", + response_model=ResetResponse, + tags=["Environment Control"], + summary="Reset the environment", + description=""" +Reset the environment to its initial state and return the first observation. + +You can optionally provide a seed for reproducibility and an episode_id for tracking. + """, + responses={ + 200: { + "description": "Environment reset successfully", + "content": { + "application/json": { + "example": { + "observation": {"status": "ready", "data": {}}, + "reward": None, + "done": False, + } + } + }, + } + }, + ) + async def reset( + request: ResetRequest = Body(default_factory=ResetRequest), + ) -> ResetResponse: + return await reset_handler(request) + + @app.post( + "/step", + response_model=StepResponse, + tags=["Environment Control"], + summary="Execute an action in the environment", + description=""" +Execute an action in the environment and receive the resulting observation. + +The action must conform to the environment's action schema, which can be +retrieved from the `/schema` endpoint. If the action is invalid, +the endpoint will return HTTP 422 with detailed validation errors. 
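+
+Example request body (illustrative; the action fields are defined by this
+environment's action schema): `{"action": {"message": "hello"}}`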
+ +The response includes: +- **observation**: The environment's response to the action +- **reward**: Optional reward signal (float or None) +- **done**: Boolean indicating if the episode has terminated + """, + responses={ + 200: { + "description": "Action executed successfully", + "content": { + "application/json": { + "example": { + "observation": {"status": "success", "data": {}}, + "reward": 1.0, + "done": False, + } + } + }, + }, + 422: { + "description": "Validation error - invalid action format or values", + "content": { + "application/json": { + "example": { + "detail": [ + { + "type": "string_too_short", + "loc": ["body", "action", "message"], + "msg": "String should have at least 1 character", + "input": "", + } + ] + } + } + }, + }, + 500: { + "description": "Internal server error during action execution" + }, + }, + ) + async def step(request: StepRequest) -> StepResponse: + return await step_handler(request) + + def get_state_handler() -> State: + _env = self._env_factory() + try: + return _env.state + finally: + _env.close() + + def get_metadata_handler() -> EnvironmentMetadata: + _env = self._env_factory() + try: + return _env.get_metadata() + finally: + _env.close() + + # Build list of GET endpoints based on mode + get_endpoints = [ + GetEndpointConfig( + path="/metadata", + handler=get_metadata_handler, + response_model=EnvironmentMetadata, + tag="Environment Info", + summary="Get environment metadata", + description=""" +Get metadata about this environment. + +Returns information about the environment including name, description, +version, author, and documentation links. + """, + ), + GetEndpointConfig( + path="/health", + handler=lambda: HealthResponse(status=HealthStatus.HEALTHY), + response_model=HealthResponse, + tag="Health", + summary="Health check", + description="Check if the environment server is running and healthy.", + ), + ] + + # Only register /state endpoint in simulation mode + if mode == ServerMode.SIMULATION: + get_endpoints.insert( + 0, + GetEndpointConfig( + path="/state", + handler=get_state_handler, + response_model=State, + tag="State Management", + summary="Get current environment state", + description=""" +Retrieve the current internal state of the environment. + +The structure of the state object is defined by the environment's State model. + """, + ), + ) + + register_get_endpoints(app, get_endpoints) + + # Register combined schema endpoint + @app.get( + "/schema", + response_model=SchemaResponse, + tags=["Schema"], + summary="Get all JSON schemas", + description=""" +Get JSON schemas for actions, observations, and state in a single response. + +Returns a combined schema object containing: +- **action**: JSON schema for actions accepted by this environment +- **observation**: JSON schema for observations returned by this environment +- **state**: JSON schema for environment state objects + +This is more efficient than calling individual schema endpoints and provides +all schema information needed to interact with the environment. 
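+
+Example (an illustrative Python client; assumes the server listens on
+localhost:8000 and the `requests` package is installed):
+
+```python
+import requests
+
+schemas = requests.get("http://localhost:8000/schema").json()
+print(schemas["action"])  # JSON schema for this environment's actions
+```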
+ """, + responses={ + 200: { + "description": "Combined schemas retrieved successfully", + "content": { + "application/json": { + "example": { + "action": { + "type": "object", + "properties": {"message": {"type": "string"}}, + }, + "observation": { + "type": "object", + "properties": {"response": {"type": "string"}}, + }, + "state": { + "type": "object", + "properties": {"step_count": {"type": "integer"}}, + }, + } + } + }, + } + }, + ) + async def get_schemas() -> SchemaResponse: + """Return all schemas in one response.""" + return SchemaResponse( + action=self.action_cls.model_json_schema(), + observation=self.observation_cls.model_json_schema(), + state=State.model_json_schema(), + ) + + # Register MCP endpoint for production mode (direct MCP access) + @app.post("/mcp") + async def mcp_endpoint(request_raw: Request) -> Dict[str, Any]: + """ + MCP JSON-RPC endpoint for production mode. + + Bypasses step() overhead and provides direct access to MCP tools. + Supports tools/list and tools/call methods. + """ + # Parse JSON manually to handle parse errors gracefully + try: + body = await request_raw.body() + request_dict = json.loads(body) + request = JsonRpcRequest(**request_dict) + except json.JSONDecodeError: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.PARSE_ERROR + ).model_dump() + except ValidationError as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Invalid request: {e}", + ).model_dump() + except Exception: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.PARSE_ERROR + ).model_dump() + + method = request.method + params = request.params + request_id = request.id + + # Create a temporary environment for MCP access + _env = self._env_factory() + + try: + # Check if environment supports MCP + if not hasattr(_env, "mcp_client") and not hasattr(_env, "mcp_server"): + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "Environment does not support MCP", + request_id=request_id, + ).model_dump() + + if method == McpMethod.TOOLS_LIST: + # List tools from MCP server + if hasattr(_env, "mcp_client") and _env.mcp_client: + async with _env.mcp_client: + tools = await _env.mcp_client.list_tools() + return JsonRpcResponse.success( + result={ + "tools": [ + t.model_dump() + if hasattr(t, "model_dump") + else dict(t) + for t in tools + ] + }, + request_id=request_id, + ).model_dump() + elif hasattr(_env, "mcp_server") and _env.mcp_server: + # Use server directly + tools = [] + for tool_name, tool in get_server_tools( + _env.mcp_server + ).items(): + tool_dict = { + "name": tool.name, + "description": tool.description or "", + "inputSchema": tool.parameters or {}, + } + tools.append(tool_dict) + return JsonRpcResponse.success( + result={"tools": tools}, + request_id=request_id, + ).model_dump() + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "MCP server not available", + request_id=request_id, + ).model_dump() + + elif method == McpMethod.TOOLS_CALL: + tool_name = params.get("name") + arguments = params.get("arguments", {}) + + if not tool_name: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + "Invalid params - 'name' is required", + request_id=request_id, + ).model_dump() + + # Call tool via MCP + if hasattr(_env, "mcp_client") and _env.mcp_client: + async with _env.mcp_client: + result = await _env.mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + elif hasattr(_env, "mcp_server") and _env.mcp_server: + # Call tool directly on 
FastMCP server + server_tools = get_server_tools(_env.mcp_server) + if tool_name in server_tools: + tool = server_tools[tool_name] + result = tool.fn(**arguments) + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + f"Tool not found: {tool_name}", + request_id=request_id, + ).model_dump() + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "MCP server not available", + request_id=request_id, + ).model_dump() + + # Make result JSON serializable + serializable_result = _make_json_serializable(result) + + return JsonRpcResponse.success( + result=serializable_result, + request_id=request_id, + ).model_dump() + + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.METHOD_NOT_FOUND, + f"Method not found: {method}", + request_id=request_id, + ).model_dump() + + except Exception as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + request_id=request_id, + ).model_dump() + finally: + _env.close() + + # Register WebSocket endpoint for persistent sessions + @app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket): + """ + WebSocket endpoint for persistent environment sessions. + + Each WebSocket connection gets its own environment instance. + + Message Protocol: + - Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage + - Server responds: WSObservationResponse | WSStateResponse | WSErrorResponse + """ + await websocket.accept() + + session_id = None + session_env = None + + try: + # Create session with dedicated environment + session_id, session_env = await self._create_session() + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + message_dict = json.loads(raw_message) + except json.JSONDecodeError as e: + error_resp = WSErrorResponse( + data={ + "message": f"Invalid JSON: {e}", + "code": WSErrorCode.INVALID_JSON, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + msg_type = message_dict.get("type", "") + + try: + match msg_type: + case "reset": + msg = WSResetMessage(**message_dict) + + is_async = ( + session_env.reset_async.__func__ + is not Environment.reset_async + ) + + if is_async: + sig = inspect.signature(session_env.reset_async) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await session_env.reset_async( + **valid_kwargs + ) + else: + sig = inspect.signature(session_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await self._run_in_session_executor( + session_id, session_env.reset, **valid_kwargs + ) + + self._update_session_activity(session_id) + + response = WSObservationResponse( + data=serialize_observation(observation), + ) + + case "step": + msg = WSStepMessage(**message_dict) + action = deserialize_action(msg.data, self.action_cls) + + is_async = ( + session_env.step_async.__func__ + is not Environment.step_async + ) + + if is_async: + observation = await session_env.step_async(action) + else: + observation = await self._run_in_session_executor( + session_id, session_env.step, action + ) + + self._update_session_activity( + session_id, increment_step=True + ) + + response = WSObservationResponse( + data=serialize_observation(observation) + ) + + case "state": + msg = WSStateMessage(**message_dict) + state = session_env.state + if hasattr(state, "model_dump"): + state_data = state.model_dump() + else: + state_data = dict(state) if state else {} + + response = 
WSStateResponse(data=state_data) + + case "close": + msg = WSCloseMessage(**message_dict) + break + + case "mcp": + msg = WSMCPMessage(**message_dict) + try: + rpc_request = JsonRpcRequest(**msg.data) + except (ValidationError, Exception) as e: + rpc_response = JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Invalid request: {e}", + ) + else: + rpc_response = await mcp_handler( + rpc_request, + session_env=session_env, + ) + response = WSMCPResponse(data=rpc_response.model_dump()) + + case _: + response = WSErrorResponse( + data={ + "message": f"Unknown message type: {msg_type}", + "code": WSErrorCode.UNKNOWN_TYPE, + } + ) + + await websocket.send_text(response.model_dump_json()) + + except ValidationError as e: + error_resp = WSErrorResponse( + data={ + "message": "Invalid message", + "code": WSErrorCode.VALIDATION_ERROR, + "errors": e.errors(), + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception as e: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": WSErrorCode.EXECUTION_ERROR, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + + except WebSocketDisconnect: + pass + except SessionCapacityError as e: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": WSErrorCode.CAPACITY_REACHED, + "active_sessions": e.active_sessions, + "max_sessions": e.max_sessions, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except EnvironmentFactoryError as e: + error_resp = WSErrorResponse( + data={ + "message": str(e), + "code": WSErrorCode.FACTORY_ERROR, + "factory_name": e.factory_name, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception as e: + error_resp = WSErrorResponse( + data={"message": str(e), "code": WSErrorCode.SESSION_ERROR} + ) + await websocket.send_text(error_resp.model_dump_json()) + finally: + if session_id: + await self._destroy_session(session_id) + try: + await websocket.close() + except RuntimeError: + pass + + +def create_app( + env: Callable[[], Environment], + action_cls: Type[Action], + observation_cls: Type[Observation], + env_name: Optional[str] = None, + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, + gradio_builder: Optional[Callable[..., Any]] = None, +) -> FastAPI: + """ + Create a FastAPI application with or without web interface. + + This function creates a FastAPI app with the web interface enabled by default, + including README integration for better user experience. + + Args: + env: Environment factory (callable) that creates new instances + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + env_name: Optional environment name for README loading + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. + gradio_builder: Optional callable to build a custom Gradio UI at /web. + Signature: (web_manager, action_fields, metadata, is_chat_env, title, + quick_start_md) -> gr.Blocks. When None, the default Gradio app is used. + See docs/customizing-web-ui.md. 
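+
+    Example (illustrative; MyEnv, MyAction and MyObs are stand-ins):
+        app = create_app(MyEnv, MyAction, MyObs, env_name="my_env")
+        # serve with: uvicorn my_module:app --host 0.0.0.0 --port 8000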
+ + Returns: + FastAPI application instance with or without web interface and README integration + """ + # Check if web interface should be enabled + # This can be controlled via environment variable or build argument + enable_web = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ( + "true", + "1", + "yes", + ) + + if enable_web: + # Gradio-based web UI (gradio is a core dependency) + from .web_interface import create_web_interface_app + + return create_web_interface_app( + env, + action_cls, + observation_cls, + env_name, + max_concurrent_envs, + concurrency_config, + gradio_builder=gradio_builder, + ) + else: + # Use standard FastAPI app without web interface + return create_fastapi_app( + env, action_cls, observation_cls, max_concurrent_envs, concurrency_config + ) + + +def create_fastapi_app( + env: Callable[[], Environment], + action_cls: Type[Action], + observation_cls: Type[Observation], + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, +) -> FastAPI: + """ + Create a FastAPI application with comprehensive documentation. + + Args: + env: Environment factory (callable) that creates new instances + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. + + Returns: + FastAPI application instance + """ + try: + from fastapi import FastAPI + except ImportError: + raise ImportError( + "FastAPI is required. Install with: pip install fastapi uvicorn" + ) + + app = FastAPI( + title="OpenEnv Environment HTTP API", + version="1.0.0", + description=""" +# OpenEnv Environment HTTP API + +HTTP API for interacting with OpenEnv environments through a standardized interface. + +## Features + +* **Environment Reset**: Initialize or restart episodes +* **Action Execution**: Send actions and receive observations +* **State Inspection**: Query current environment state +* **Schema Access**: Retrieve JSON schemas for actions and observations + +## Workflow + +1. Call `/reset` to start a new episode and get initial observation +2. Call `/step` repeatedly with actions to interact with environment +3. Episode ends when observation returns `done: true` +4. 
Call `/state` anytime to inspect current environment state + +## Documentation + +* **Swagger UI**: Available at `/docs` +* **ReDoc**: Available at `/redoc` +* **OpenAPI Schema**: Available at `/openapi.json` + """, + openapi_tags=[ + { + "name": "Environment Control", + "description": "Core operations for environment interaction (reset, step)", + }, + { + "name": "State Management", + "description": "Operations for inspecting environment state", + }, + { + "name": "Environment Info", + "description": "Information about the environment", + }, + { + "name": "Schema", + "description": "JSON Schema endpoints for actions, observations, and state", + }, + {"name": "Health", "description": "Service health and status checks"}, + ], + docs_url="/docs", + redoc_url="/redoc", + openapi_url="/openapi.json", + contact={ + "name": "OpenEnv Team", + "url": "https://github.com/meta-pytorch/OpenEnv", + }, + license_info={ + "name": "BSD-3-Clause", + "url": "https://github.com/meta-pytorch/OpenEnv/blob/main/LICENSE", + }, + ) + + server = HTTPEnvServer( + env, + action_cls, + observation_cls, + max_concurrent_envs, + concurrency_config=concurrency_config, + ) + server.register_routes(app) + return app diff --git a/src/core/openenv/core/env_server/interfaces.py b/src/core/openenv/core/env_server/interfaces.py new file mode 100644 index 0000000000000000000000000000000000000000..9fa837549aa1e2bf1c439f1d7a52e845a556ae18 --- /dev/null +++ b/src/core/openenv/core/env_server/interfaces.py @@ -0,0 +1,297 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +from abc import ABC, abstractmethod +from typing import Any, Generic, Optional, Protocol, TYPE_CHECKING, TypeVar + +from typing_extensions import TypedDict + +from .types import Action, EnvironmentMetadata, Observation, State + +if TYPE_CHECKING: + from openenv.core.rubrics import Rubric + +ActT = TypeVar("ActT", bound=Action) +ObsT = TypeVar("ObsT", bound=Observation) +StateT = TypeVar("StateT", bound=State) + + +class Message(TypedDict): + """A message in a conversation. + + Compatible with Huggingface chat template format. + """ + + role: str + content: str + + +class ModelTokenizer(Protocol): + """Protocol for tokenizers that support chat templates. + + This protocol defines the interface that tokenizers must implement + to work with chat-based environments. It's compatible with + Huggingface transformers tokenizers. + """ + + def apply_chat_template( + self, + conversation: list[Message], + tokenize: bool = True, + return_tensors: str | None = None, + **kwargs: Any, + ) -> Any: + """Apply a chat template to format and optionally tokenize a conversation. + + Args: + conversation: List of message dictionaries with 'role' and 'content' + tokenize: Whether to tokenize the output + return_tensors: Format for returned tensors ('pt' for PyTorch) + **kwargs: Additional arguments + + Returns: + Formatted and optionally tokenized conversation + """ + ... + + def decode( + self, token_ids: Any, skip_special_tokens: bool = False, **kwargs: Any + ) -> str: + """Decode token IDs back to text. + + Args: + token_ids: Token IDs to decode + skip_special_tokens: Whether to skip special tokens in output + **kwargs: Additional arguments + + Returns: + Decoded text string + """ + ... + + +class Transform(ABC, Generic[ObsT]): + """Transform observations to add rewards, metrics, or other modifications. 
+ + Transforms follow the TorchRL pattern where they take an observation + and return a (potentially modified) observation. This allows for + flexible reward computation and observation augmentation. + """ + + @abstractmethod + def __call__(self, observation: ObsT) -> ObsT: + """Transform an observation. + + Args: + observation: The input observation + + Returns: + The transformed observation + """ + pass + + +class Environment(ABC, Generic[ActT, ObsT, StateT]): + """Base class for all environment servers following Gym/Gymnasium API. + + Args: + transform: Optional transform to apply to observations + rubric: Optional rubric for reward computation. When provided, the + rubric's output can be used to set the observation's reward in step(). + + Class Attributes: + SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions. + When True, multiple WebSocket connections can each have their own + environment instance (up to max_concurrent_envs). When False (default), + the environment should only be used with a single session at a time. + + Set this to True in your Environment subclass if: + - The environment uses proper session isolation (e.g., unique working dirs) + - No shared mutable state exists between instances + - External resources (databases, APIs) can handle concurrent access + + Attributes: + rubric: Optional rubric for computing rewards. Environments can set this + in __init__ and use it in step() to compute observation rewards. + Training infrastructure can access it for introspection: + for name, r in env.rubric.named_rubrics(): + print(f"{name}: {r.last_score}") + + See RFC 004 for rubric design: rfcs/004-rubrics.md + """ + + # Class-level flag indicating whether this environment supports concurrent sessions + SUPPORTS_CONCURRENT_SESSIONS: bool = False + + # Optional rubric for reward computation + rubric: Optional["Rubric"] + + def __init__( + self, + transform: Optional[Transform[ObsT]] = None, + rubric: Optional["Rubric"] = None, + ): + self.transform = transform + self.rubric = rubric + + @abstractmethod + def reset( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> ObsT: + """Reset the environment and return initial observation.""" + pass + + async def reset_async( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> ObsT: + """Async version of reset. Default implementation calls sync reset. + + Override to provide true async implementation. + """ + return self.reset(seed=seed, episode_id=episode_id, **kwargs) + + @abstractmethod + def step( + self, + action: ActT, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> ObsT: + """Take a step in the environment.""" + pass + + async def step_async( + self, + action: ActT, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> ObsT: + """Async version of step. Default implementation calls sync step. + + Override to provide true async implementation. + """ + return self.step(action, timeout_s=timeout_s, **kwargs) + + @property + @abstractmethod + def state(self) -> StateT: + """Get the current environment state.""" + pass + + def get_metadata(self) -> EnvironmentMetadata: + """ + Get metadata about this environment. + + Override this method to provide custom metadata for the environment. + Default implementation returns basic metadata derived from class name. 
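+
+        Example override (illustrative):
+            def get_metadata(self) -> EnvironmentMetadata:
+                return EnvironmentMetadata(
+                    name="echo_env",
+                    description="Echoes agent messages back",
+                    version="0.1.0",
+                )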
+ + Returns: + EnvironmentMetadata with environment information + """ + return EnvironmentMetadata( + name=self.__class__.__name__, + description=f"{self.__class__.__name__} environment", + version="1.0.0", + ) + + def _apply_transform(self, observation: ObsT) -> ObsT: + """Apply transform if one is provided.""" + if self.transform is not None: + return self.transform(observation) + return observation + + def _apply_rubric(self, action: ActT, observation: ObsT) -> float: + """Apply rubric if one is provided. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value from the rubric, or 0.0 if no rubric is set. + + Usage in step(): + def step(self, action: MyAction, ...) -> MyObservation: + # ... execute action and create observation ... + observation.reward = self._apply_rubric(action, observation) + return observation + """ + if self.rubric is not None: + return self.rubric(action, observation) + return 0.0 + + async def _apply_rubric_async(self, action: ActT, observation: ObsT) -> float: + """Apply rubric asynchronously if one is provided. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value from the rubric, or 0.0 if no rubric is set. + + Usage in step_async(): + async def step_async(self, action: MyAction, ...) -> MyObservation: + # ... execute action and create observation ... + observation.reward = await self._apply_rubric_async(action, observation) + return observation + """ + if self.rubric is not None: + result = self.rubric(action, observation) + # If rubric returns a coroutine, await it + if inspect.iscoroutine(result): + return await result + return result + return 0.0 + + def _reset_rubric(self) -> None: + """Reset the rubric state if one is provided. + + Call this in reset() to clear any trajectory state in the rubric. + + Usage in reset(): + def reset(self, ...) -> MyObservation: + self._reset_rubric() + # ... create initial observation ... + return observation + """ + if self.rubric is not None: + self.rubric.reset() + + async def _reset_rubric_async(self) -> None: + """Reset the rubric state asynchronously if one is provided. + + Call this in reset_async() to clear any trajectory state in the rubric. + + Usage in reset_async(): + async def reset_async(self, ...) -> MyObservation: + await self._reset_rubric_async() + # ... create initial observation ... + return observation + """ + if self.rubric is not None: + # Check if rubric has async reset method + if hasattr(self.rubric, "reset_async"): + result = self.rubric.reset_async() + if inspect.iscoroutine(result): + await result + else: + self.rubric.reset() + + def close(self) -> None: + """Clean up resources used by the environment. + + Override this method to implement custom cleanup logic. + Called when the environment is being destroyed or reset. + """ + pass diff --git a/src/core/openenv/core/env_server/mcp_environment.py b/src/core/openenv/core/env_server/mcp_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..03f66e37897ec81796d468f3d0590d465deddea1 --- /dev/null +++ b/src/core/openenv/core/env_server/mcp_environment.py @@ -0,0 +1,624 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +MCP Environment base class for OpenEnv. 
+ +This module provides the MCPEnvironment base class that integrates FastMCP servers +with OpenEnv's Gym-style Environment interface. It handles MCP tool discovery +and invocation through the step() API, following RFC 003. + +Key features: +- Automatic routing of ListToolsAction and CallToolAction to MCP server +- Reserved tool name validation (reset, step, state, close are protected) +- Timeout handling for tool calls +- Proper error categorization (tool not found, execution errors, timeouts) +- Mode-aware tool registration (production vs simulation) +- Code mode support via get_callables() and execute_code() + +Usage: + from fastmcp import FastMCP + from openenv.core.env_server.mcp_environment import MCPEnvironment + + class MyMCPEnv(MCPEnvironment): + def __init__(self): + mcp = FastMCP("my-server") + + # Register mode-specific tools + @self.tool(mode="production") + def my_tool(arg: str) -> str: + return f"Production: {arg}" + + @self.tool(mode="simulation") + def my_tool(arg: str) -> str: + return f"Simulation: {arg}" + + super().__init__(mcp) + + def reset(self, seed=None, episode_id=None, **kwargs): + # Reset logic here + ... + + def _step_impl(self, action): + # Handle non-MCP actions + ... + + @property + def state(self): + # Return current state + ... +""" + +import asyncio +import inspect +from abc import abstractmethod +from collections import defaultdict +from typing import Any, Callable, Dict, Optional + +from fastmcp import Client +from fastmcp.client.client import CallToolResult +from mcp.types import TextContent + +from ..utils import run_async_safely +from .interfaces import Environment +from .mcp_types import ( + CallToolAction, + CallToolObservation, + ListToolsAction, + ListToolsObservation, + RESERVED_TOOL_NAMES, + Tool, + ToolError, + ToolErrorType, +) +from .types import Action, Observation + + +# Default timeout for MCP tool calls in seconds +MCP_TOOL_CALL_TIMEOUT = 30.0 + +# Valid modes for tool registration +VALID_MODES = {"production", "simulation"} + + +def get_server_tools(mcp_server: Any) -> Dict[str, Any]: + """ + Get tools from a FastMCP server, compatible with both 2.x and 3.x. + + Returns: + Dictionary mapping tool names to tool objects. + """ + # FastMCP 2.x: get_tools() returns dict {name: Tool} + if hasattr(mcp_server, "get_tools"): + result = run_async_safely(mcp_server.get_tools()) + if isinstance(result, dict): + return result + # FastMCP 3.x: list_tools() returns list of Tool objects + if hasattr(mcp_server, "list_tools"): + tools_list = run_async_safely(mcp_server.list_tools()) + return {t.name: t for t in tools_list} + return {} + + +class MCPEnvironment(Environment): + """ + Base class for environments that expose tools via MCP (Model Context Protocol). + + MCPEnvironment bridges FastMCP servers with OpenEnv's Gym-style API, allowing + agents to discover and invoke MCP tools through the standard step() interface. + + The class automatically handles: + - ListToolsAction: Returns available tools from the MCP server + - CallToolAction: Invokes a specific tool with arguments + + All other actions are delegated to the abstract _step_impl() method, + which subclasses must implement. + + Args: + mcp_server: A FastMCP server instance containing tool definitions. + The server's tools will be validated against reserved names. + transform: Optional transform to apply to observations (inherited from Environment). + + Raises: + ValueError: If any tool in the MCP server uses a reserved name + (reset, step, state, close). 
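+
+    Note:
+        Subclasses should call super().__init__(mcp_server) before registering
+        tools with the self.tool(...) decorator, since the decorator relies on
+        self.mcp_server and the mode-tool registries created in __init__.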
+ + Example: + >>> from fastmcp import FastMCP + >>> mcp = FastMCP("calculator") + >>> @mcp.tool() + ... def add(a: int, b: int) -> int: + ... return a + b + >>> env = MyMCPEnvironment(mcp) + >>> obs = env.step(ListToolsAction()) + >>> obs.tools[0].name + 'add' + """ + + def __init__(self, mcp_server: Any, transform: Optional[Any] = None) -> None: + """ + Initialize the MCP environment. + + Args: + mcp_server: A FastMCP server instance with tool definitions. + transform: Optional transform to apply to observations. + + Raises: + ValueError: If any tool uses a reserved name (reset, step, state, close). + """ + super().__init__(transform=transform) + + # Validate tool names before storing + self._validate_tool_names(mcp_server) + + self.mcp_server = mcp_server + self.mcp_client = Client(mcp_server) + + # Track mode-specific tools: {tool_name: {mode: func}} + # mode can be "production", "simulation", or None (available in all modes) + self._mode_tools = defaultdict(dict) + + # Track tool schemas for list_tools: {tool_name: {mode: schema}} + self._mode_tool_schemas = defaultdict(dict) + + @property + def supports_code_mode(self) -> bool: + """Check if this environment supports code mode (execute_code).""" + return True + + def _get_server_tools(self, mcp_server: Any) -> Dict[str, Any]: + """ + Get tools from a FastMCP server, compatible with both 2.x and 3.x. + + Returns: + Dictionary mapping tool names to tool objects. + """ + return get_server_tools(mcp_server) + + def get_callables(self) -> Dict[str, Callable]: + """ + Get callable functions for code mode. + + Returns tool functions as direct Python callables, enabling code mode + where agents write Python code that calls tools directly (no JSON-RPC + overhead). Mode-specific tools are filtered by the current mode. + + Returns: + Dictionary mapping tool names to callables. + """ + callables: Dict[str, Callable] = {} + current_mode = getattr(self, "_mode", None) + + # Extract callables from FastMCP server using public API + for tool_name, tool in self._get_server_tools(self.mcp_server).items(): + if hasattr(tool, "fn") and callable(tool.fn): + callables[tool_name] = tool.fn + + # Add mode-specific tools available in current mode + for tool_name, mode_funcs in self._mode_tools.items(): + if None in mode_funcs: + # Tool available in all modes (already in FastMCP if registered there) + if tool_name not in callables: + callables[tool_name] = mode_funcs[None] + elif current_mode in mode_funcs: + # Tool available in current mode only + callables[tool_name] = mode_funcs[current_mode] + + return callables + + def execute_code(self, code: str) -> Observation: + """ + Execute Python code with tools available as callables. + + This enables the CodeAct pattern where agents write Python code + that calls tools directly as functions, avoiding JSON-RPC overhead. + + Args: + code: Python code to execute. Tools are available as functions + in the execution namespace. Set a variable named 'result' + to capture the return value. + + Returns: + Observation with result in metadata["result"] or error in + metadata["error"]. 
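+
+        Example (illustrative, assuming an 'add' tool is registered as in
+        the class docstring):
+            >>> obs = env.execute_code("result = add(2, 3)")  # doctest: +SKIP
+            >>> obs.metadata["result"]                        # doctest: +SKIP
+            5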
+ """ + namespace = self.get_callables() + + result_dict: Dict[str, Any] = {} + try: + exec(code, namespace, result_dict) + result = result_dict.get("result") + return Observation(done=False, reward=0.0, metadata={"result": result}) + except SyntaxError as e: + return Observation( + done=False, reward=0.0, metadata={"error": f"Syntax error: {str(e)}"} + ) + except Exception as e: + return Observation(done=False, reward=0.0, metadata={"error": str(e)}) + + def _validate_tool_names(self, mcp_server: Any) -> None: + """ + Validate that no tools use reserved names. + + Reserved names (reset, step, state, close) are protected to maintain + the dual API boundary between infrastructure and agent APIs. + + Args: + mcp_server: The FastMCP server to validate. + + Raises: + ValueError: If any tool uses a reserved name. + """ + tools_dict = self._get_server_tools(mcp_server) + if tools_dict: + tool_names = set(tools_dict.keys()) + conflicts = tool_names & RESERVED_TOOL_NAMES + if conflicts: + raise ValueError( + f"MCP tools cannot use reserved names: {sorted(conflicts)}. " + f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}" + ) + + def tool(self, mode: Optional[str] = None) -> Callable: + """ + Decorator for registering mode-aware tools. + + Args: + mode: Optional mode for the tool ("production" or "simulation"). + If None, tool is available in all modes. + + Returns: + A decorator function for registering tools. + + Raises: + ValueError: If mode is not None, "production", or "simulation". + """ + if mode is not None and mode not in VALID_MODES: + raise ValueError( + f"Invalid mode '{mode}'. Mode must be 'production', 'simulation', or None." + ) + + def decorator(func: Callable) -> Callable: + tool_name = func.__name__ + # Validate tool name is not reserved + if tool_name in RESERVED_TOOL_NAMES: + raise ValueError( + f"Tool name '{tool_name}' is reserved and cannot be used. " + f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}" + ) + + # If mode is None, register with FastMCP as usual + if mode is None: + decorated_func = self.mcp_server.tool()(func) + self._mode_tools[tool_name][None] = func + return decorated_func + + # For mode-specific tools, don't register with FastMCP + # Instead, track them ourselves + self._mode_tools[tool_name][mode] = func + + # Extract schema information from function signature + sig = inspect.signature(func) + schema = { + "type": "object", + "properties": {}, + "required": [], + } + + for param_name, param in sig.parameters.items(): + # Get type annotation + param_type = param.annotation + json_type = "string" # default + if param_type in (int, "int"): + json_type = "integer" + elif param_type in (float, "float"): + json_type = "number" + elif param_type in (bool, "bool"): + json_type = "boolean" + + schema["properties"][param_name] = {"type": json_type} + + # If no default value, it's required + if param.default == inspect.Parameter.empty: + schema["required"].append(param_name) + + # Store the schema for this mode-specific tool + self._mode_tool_schemas[tool_name][mode] = { + "name": tool_name, + "description": func.__doc__ or "", + "input_schema": schema, + } + + return func + + return decorator + + def step( + self, + action: Action, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> Observation: + """ + Execute an action in the environment. + + This method routes MCP-specific actions (ListToolsAction, CallToolAction) + to the appropriate handlers, while delegating all other actions to + the subclass's _step_impl() method. 
+ + Args: + action: The action to execute. Can be: + - ListToolsAction: Returns available MCP tools + - CallToolAction: Invokes a specific MCP tool + - Any other Action: Delegated to _step_impl() + timeout_s: Optional timeout in seconds for the action. + Defaults to MCP_TOOL_CALL_TIMEOUT (30s) for MCP actions. + **kwargs: Additional arguments passed to handlers. + + Returns: + Observation appropriate to the action type: + - ListToolsObservation for ListToolsAction + - CallToolObservation for CallToolAction + - Subclass-defined Observation for other actions + """ + if isinstance(action, ListToolsAction): + return self._handle_list_tools() + elif isinstance(action, CallToolAction): + return self._handle_call_tool(action, timeout_s=timeout_s) + else: + return self._step_impl(action, timeout_s=timeout_s, **kwargs) + + def _handle_list_tools(self) -> ListToolsObservation: + """ + Handle a ListToolsAction by querying the MCP server. + + Returns: + ListToolsObservation containing all available tools with their + names, descriptions, and input schemas, filtered by current mode. + """ + try: + # Get current mode + current_mode = getattr(self, "_mode", None) + + # Start with tools from FastMCP server (mode=None tools) + tools_result = run_async_safely(self._async_list_tools()) + + # Build list of Tool objects + tools = [] + + # Add FastMCP tools that are not mode-specific + for tool in tools_result: + if tool.name not in self._mode_tool_schemas: + tools.append( + Tool( + name=tool.name, + description=tool.description or "", + input_schema=tool.inputSchema + if hasattr(tool, "inputSchema") + else {}, + ) + ) + + # Add mode-specific tools available in current mode + for tool_name, mode_schemas in self._mode_tool_schemas.items(): + if None in mode_schemas: + # Tool available in all modes + schema = mode_schemas[None] + tools.append( + Tool( + name=schema["name"], + description=schema["description"], + input_schema=schema["input_schema"], + ) + ) + elif current_mode in mode_schemas: + # Tool available in current mode + schema = mode_schemas[current_mode] + tools.append( + Tool( + name=schema["name"], + description=schema["description"], + input_schema=schema["input_schema"], + ) + ) + + return ListToolsObservation(tools=tools) + + except Exception as e: + # Return an observation with error in metadata + return ListToolsObservation( + tools=[], + metadata={ + "error": str(e), + "error_type": "list_tools_failed", + }, + ) + + async def _async_list_tools(self) -> list: + """ + Async helper to list tools from the MCP client. + + Returns: + List of tool objects from the MCP server. + """ + async with self.mcp_client: + return await self.mcp_client.list_tools() + + def _handle_call_tool( + self, + action: CallToolAction, + timeout_s: Optional[float] = None, + ) -> CallToolObservation: + """ + Handle a CallToolAction by invoking the specified tool. + + Args: + action: The CallToolAction containing tool_name and arguments. + timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s). + + Returns: + CallToolObservation with the tool's result or an error. + """ + timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT + + # Check if this is a mode-specific tool + tool_name = action.tool_name + current_mode = getattr(self, "_mode", None) + + if tool_name in self._mode_tools: + mode_info = self._mode_tools[tool_name] + + # Check if tool is available in current mode + # Tool is available if: + # 1. It has a None mode (available in all modes), OR + # 2. 
It has an implementation for the current mode + if None in mode_info: + # Use the mode-agnostic version + func = mode_info[None] + elif current_mode in mode_info: + # Use the mode-specific version + func = mode_info[current_mode] + else: + # Tool not available in current mode + return CallToolObservation( + tool_name=tool_name, + result=None, + error=ToolError( + error_type=ToolErrorType.TOOL_NOT_FOUND, + message=f"Tool '{tool_name}' not available in {current_mode} mode", + ), + ) + + # Call the mode-specific function directly + try: + # Check if function is async and await if necessary + if inspect.iscoroutinefunction(func): + result = run_async_safely(func(**action.arguments)) + else: + result = func(**action.arguments) + + # Wrap result in CallToolResult format to match FastMCP behavior + return CallToolObservation( + tool_name=tool_name, + result=CallToolResult( + content=[TextContent(type="text", text=str(result))], + structured_content={"result": result}, + meta=None, + data=result, + is_error=False, + ), + ) + except Exception as e: + return CallToolObservation( + tool_name=tool_name, + result=None, + error=ToolError( + error_type=ToolErrorType.EXECUTION_ERROR, + message=str(e), + ), + ) + + # Not a mode-specific tool, use FastMCP + try: + # Run the async call_tool with timeout + # Use run_async_safely to handle both sync and async contexts + result = run_async_safely( + asyncio.wait_for( + self._async_call_tool(action.tool_name, action.arguments), + timeout=timeout, + ) + ) + + return CallToolObservation( + tool_name=action.tool_name, + result=result, + ) + + except asyncio.TimeoutError: + return CallToolObservation( + tool_name=action.tool_name, + result=None, + error=ToolError( + error_type=ToolErrorType.TIMEOUT, + message=f"Tool '{action.tool_name}' timed out after {timeout} seconds", + ), + ) + + except Exception as e: + error_message = str(e) + + # Determine error type based on the exception + if ( + "not found" in error_message.lower() + or "unknown tool" in error_message.lower() + ): + error_type = ToolErrorType.TOOL_NOT_FOUND + elif ( + "invalid" in error_message.lower() + or "argument" in error_message.lower() + ): + error_type = ToolErrorType.INVALID_ARGS + else: + error_type = ToolErrorType.EXECUTION_ERROR + + return CallToolObservation( + tool_name=action.tool_name, + result=None, + error=ToolError( + error_type=error_type, + message=error_message, + ), + ) + + async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any: + """ + Async helper to call a tool on the MCP server. + + Args: + tool_name: Name of the tool to invoke. + arguments: Dictionary of arguments to pass to the tool. + + Returns: + The result from the tool execution. + """ + async with self.mcp_client: + return await self.mcp_client.call_tool(tool_name, arguments) + + @abstractmethod + def _step_impl( + self, + action: Action, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> Observation: + """ + Handle non-MCP actions in the environment. + + Subclasses must implement this method to handle any actions that are + not ListToolsAction or CallToolAction. This is where environment-specific + action processing should occur. + + Args: + action: The action to execute (guaranteed not to be an MCP action). + timeout_s: Optional timeout in seconds. + **kwargs: Additional arguments. + + Returns: + An Observation appropriate for the action. + """ + pass + + def close(self) -> None: + """ + Clean up resources used by the environment. 
+ + This method cleans up the MCP client and any other resources. + Subclasses should call super().close() if they override this method. + """ + # The MCP client uses async context manager, so cleanup happens + # automatically when the context exits. We just clear references. + self.mcp_client = None + self.mcp_server = None diff --git a/src/core/openenv/core/env_server/mcp_types.py b/src/core/openenv/core/env_server/mcp_types.py new file mode 100644 index 0000000000000000000000000000000000000000..6aa5b7449e2fa60dea46efc6b0992a6359146b2b --- /dev/null +++ b/src/core/openenv/core/env_server/mcp_types.py @@ -0,0 +1,321 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +MCP (Model Context Protocol) type definitions for OpenEnv. + +This module defines strongly typed models for MCP tool discovery and invocation, +following RFC 003. These types map MCP's REST-like API (tools/list, tools/call) +to Gym-style action types. + +Key design decisions: +- Tool discovery (list_tools) does NOT require reset() first +- Reserved tool names (reset, step, state, close) are prohibited +- Both step() and WebSocket /mcp paths are supported +""" + +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field + +from .types import Action, BaseMessage, Observation + + +# ============================================================================= +# JSON-RPC 2.0 Types +# ============================================================================= + + +class JsonRpcErrorCode(int, Enum): + """ + Standard JSON-RPC 2.0 error codes. + + See: https://www.jsonrpc.org/specification#error_object + """ + + # Standard JSON-RPC errors + PARSE_ERROR = -32700 # Invalid JSON was received + INVALID_REQUEST = -32600 # JSON is not a valid Request object + METHOD_NOT_FOUND = -32601 # Method does not exist / is not available + INVALID_PARAMS = -32602 # Invalid method parameter(s) + INTERNAL_ERROR = -32603 # Internal JSON-RPC error + + # Server errors (reserved for implementation-defined errors) + SERVER_ERROR = -32000 # Generic server error + + +class McpMethod(str, Enum): + """Supported MCP method names.""" + + TOOLS_LIST = "tools/list" + TOOLS_CALL = "tools/call" + + +class JsonRpcError(BaseModel): + """ + JSON-RPC 2.0 error object. + + See: https://www.jsonrpc.org/specification#error_object + """ + + model_config = ConfigDict(extra="forbid") + + code: int = Field(description="Error code indicating the error type") + message: str = Field(description="Short description of the error") + data: Optional[Any] = Field( + default=None, description="Additional error information" + ) + + @classmethod + def from_code( + cls, code: JsonRpcErrorCode, message: Optional[str] = None, data: Any = None + ) -> "JsonRpcError": + """Create an error from a standard error code.""" + default_messages = { + JsonRpcErrorCode.PARSE_ERROR: "Parse error", + JsonRpcErrorCode.INVALID_REQUEST: "Invalid Request", + JsonRpcErrorCode.METHOD_NOT_FOUND: "Method not found", + JsonRpcErrorCode.INVALID_PARAMS: "Invalid params", + JsonRpcErrorCode.INTERNAL_ERROR: "Internal error", + JsonRpcErrorCode.SERVER_ERROR: "Server error", + } + return cls( + code=code.value, + message=message or default_messages.get(code, "Unknown error"), + data=data, + ) + + +class JsonRpcRequest(BaseModel): + """ + JSON-RPC 2.0 request object. 
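+
+    Example payload (illustrative)::
+
+        {"jsonrpc": "2.0", "method": "tools/list", "params": {}, "id": 1}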
+ + See: https://www.jsonrpc.org/specification#request_object + """ + + model_config = ConfigDict(extra="forbid") + + jsonrpc: Literal["2.0"] = Field(description="JSON-RPC version, must be '2.0'") + method: str = Field(description="Name of the method to be invoked") + params: Dict[str, Any] = Field( + default_factory=dict, description="Parameter values for the method" + ) + id: Optional[Union[str, int]] = Field( + default=None, description="Request identifier established by the client" + ) + + +class JsonRpcResponse(BaseModel): + """ + JSON-RPC 2.0 response object. + + Per JSON-RPC 2.0 spec, a response has either 'result' or 'error', not both. + This model excludes None values during serialization to comply with the spec. + + See: https://www.jsonrpc.org/specification#response_object + """ + + model_config = ConfigDict(extra="forbid") + + jsonrpc: Literal["2.0"] = Field(default="2.0", description="JSON-RPC version") + result: Optional[Any] = Field( + default=None, description="Result of the method invocation" + ) + error: Optional[JsonRpcError] = Field( + default=None, description="Error object if method invocation failed" + ) + id: Optional[Union[str, int]] = Field( + default=None, description="Request identifier from the request" + ) + + def model_dump(self, **kwargs) -> Dict[str, Any]: + """Serialize to dict, excluding result or error when None (JSON-RPC compliance).""" + # Always include jsonrpc and id, but only include result OR error + data: Dict[str, Any] = {"jsonrpc": self.jsonrpc, "id": self.id} + if self.error is not None: + data["error"] = ( + self.error.model_dump() + if hasattr(self.error, "model_dump") + else self.error + ) + else: + # Only include result if there's no error + data["result"] = self.result + return data + + def model_dump_json(self, **kwargs) -> str: + """Serialize to JSON string, excluding result or error when None (JSON-RPC compliance).""" + import json + + return json.dumps(self.model_dump()) + + @classmethod + def success( + cls, result: Any, request_id: Optional[Union[str, int]] = None + ) -> "JsonRpcResponse": + """Create a success response.""" + return cls(result=result, id=request_id) + + @classmethod + def error_response( + cls, + code: JsonRpcErrorCode, + message: Optional[str] = None, + data: Any = None, + request_id: Optional[Union[str, int]] = None, + ) -> "JsonRpcResponse": + """Create an error response from a standard error code.""" + return cls( + error=JsonRpcError.from_code(code, message, data), + id=request_id, + ) + + +# ============================================================================= +# MCP Tool Types +# ============================================================================= + + +class Tool(BaseModel): + """ + Strongly typed MCP tool specification. + + Follows the MCP ToolSpec format for tool discovery. 
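+
+    Example (illustrative; the ``add`` tool is hypothetical)::
+
+        Tool(
+            name="add",
+            description="Add two integers.",
+            input_schema={
+                "type": "object",
+                "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
+                "required": ["a", "b"],
+            },
+        )
+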
+ See: https://modelcontextprotocol.io/specification/2025-06-18/server/tools + """ + + model_config = ConfigDict(extra="forbid") + + name: str = Field(description="Unique identifier for the tool") + description: str = Field( + description="Human-readable description of what the tool does" + ) + input_schema: Dict[str, Any] = Field( + description="JSON Schema for the tool's input parameters" + ) + + +class ToolErrorType(str, Enum): + """Types of errors that can occur during tool execution.""" + + EXECUTION_ERROR = "execution_error" # Tool ran but failed + INVALID_ARGS = "invalid_args" # Invalid arguments provided + TRANSPORT_ERROR = "transport_error" # Communication failure + TOOL_NOT_FOUND = "tool_not_found" # Tool doesn't exist + TIMEOUT = "timeout" # Operation timed out + + +class ToolError(BaseModel): + """ + Structured error for tool execution failures. + + This is used for transport/framework errors, NOT for errors returned + by the tool itself (those go in the result field). + """ + + model_config = ConfigDict(extra="forbid") + + error_type: ToolErrorType = Field(description="Category of the error") + message: str = Field(description="Human-readable error message") + + +# --- MCP Actions --- + + +class ListToolsAction(Action): + """ + Request list of available tools from the environment. + + This action triggers MCP's tools/list operation and returns + all available tools with their schemas. + + Note: Does NOT require reset() to be called first. + """ + + type: Literal["list_tools"] = Field( + default="list_tools", description="Action type discriminator" + ) + + +class CallToolAction(Action): + """ + Call a specific tool via MCP. + + This action triggers MCP's tools/call operation with the + specified tool name and arguments. + """ + + type: Literal["call_tool"] = Field( + default="call_tool", description="Action type discriminator" + ) + tool_name: str = Field(description="Name of the tool to call") + arguments: Dict[str, Any] = Field( + default_factory=dict, description="Arguments to pass to the tool" + ) + + +# --- MCP Observations --- + + +class ListToolsObservation(Observation): + """ + Response containing available tools. + + Returned when processing a ListToolsAction. + """ + + tools: List[Tool] = Field(description="List of available tools with their schemas") + + +class CallToolObservation(Observation): + """ + Response from tool execution. + + Contains the tool's result or an error if the call failed. + Tool-specific errors (from the tool itself) are included in the result. + Transport/framework errors use the error field. + """ + + tool_name: str = Field(description="Name of the tool that was called") + result: Any = Field( + default=None, description="Tool-specific result (may include tool errors)" + ) + error: Optional[ToolError] = Field( + default=None, description="Transport/framework error if call failed" + ) + + +# --- WebSocket Message Types for MCP --- + + +class WSMCPMessage(BaseMessage): + """ + WebSocket message for MCP JSON-RPC requests. + + Allows direct MCP access via WebSocket for production inference, + bypassing the step() API. + """ + + type: Literal["mcp"] = Field(default="mcp", description="Message type") + data: Dict[str, Any] = Field(description="JSON-RPC payload (method, params, id)") + + +class WSMCPResponse(BaseModel): + """ + WebSocket response for MCP JSON-RPC. + + Contains the JSON-RPC response from the MCP server. 
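+
+    Example ``data`` payload (illustrative)::
+
+        {"jsonrpc": "2.0", "result": {"tools": []}, "id": 1}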
+ """ + + model_config = ConfigDict(extra="forbid") + + type: str = Field(default="mcp", description="Response type") + data: Dict[str, Any] = Field(description="JSON-RPC response payload") + + +# Reserved tool names that cannot be used (protects dual API boundary) +RESERVED_TOOL_NAMES = frozenset(["reset", "step", "state", "close"]) diff --git a/src/core/openenv/core/env_server/route_config.py b/src/core/openenv/core/env_server/route_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d74a7f202be0731400a6b954dfd37d9012c1f8f7 --- /dev/null +++ b/src/core/openenv/core/env_server/route_config.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Route configuration utilities for declarative FastAPI route registration. + +This module provides utilities to reduce boilerplate in route registration +by using configuration objects instead of repeated function calls. +""" + +from dataclasses import dataclass +from typing import Callable, List, Type + +from fastapi import FastAPI +from pydantic import BaseModel + + +@dataclass +class GetEndpointConfig: + """Configuration for a simple GET endpoint.""" + + path: str + handler: Callable[[], BaseModel | dict] + response_model: Type[BaseModel] | type[dict] + tag: str + summary: str + description: str + + +def register_get_endpoints(app: FastAPI, configs: List[GetEndpointConfig]) -> None: + """ + Register multiple GET endpoints from configuration. + + Args: + app: FastAPI application instance + configs: List of GET endpoint configurations + """ + for config in configs: + # Capture handler in a closure to avoid non-serializable default parameter + def make_endpoint( + handler: Callable[[], BaseModel | dict], + ) -> Callable[[], BaseModel | dict]: + async def endpoint() -> BaseModel | dict: + return handler() + + return endpoint + + app.get( + config.path, + response_model=config.response_model, + tags=[config.tag], + summary=config.summary, + description=config.description, + )(make_endpoint(config.handler)) diff --git a/src/core/openenv/core/env_server/serialization.py b/src/core/openenv/core/env_server/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b50d9aeb873794044e77ee398a7f2b5fca8093 --- /dev/null +++ b/src/core/openenv/core/env_server/serialization.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared serialization and deserialization utilities for OpenEnv HTTP servers. + +This module provides common utilities for converting between JSON dictionaries +and Pydantic models (Action/Observation) to eliminate code duplication across +HTTP server and web interface implementations. +""" + +from typing import Any, Dict, Type + +from .types import Action, Observation + + +def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action: + """ + Convert JSON dict to Action instance using Pydantic validation. + + This is a basic deserialization that works for most environments. + For special cases (e.g., tensor fields, custom type conversions), + use deserialize_action_with_preprocessing(). 
+ + Args: + action_data: Dictionary containing action data + action_cls: The Action subclass to instantiate + + Returns: + Action instance + + Raises: + ValidationError: If action_data is invalid for the action class + + Note: + This uses Pydantic's model_validate() for automatic validation. + """ + return action_cls.model_validate(action_data) + + +def deserialize_action_with_preprocessing( + action_data: Dict[str, Any], action_cls: Type[Action] +) -> Action: + """ + Convert JSON dict to Action instance with preprocessing for special types. + + This version handles common type conversions needed for web interfaces: + - Converting lists/strings to tensors for 'tokens' field + - Converting string action_id to int + - Other custom preprocessing as needed + + Args: + action_data: Dictionary containing action data + action_cls: The Action subclass to instantiate + + Returns: + Action instance + + Raises: + ValidationError: If action_data is invalid for the action class + """ + processed_data = {} + + for key, value in action_data.items(): + if key == "tokens" and isinstance(value, (list, str)): + # Convert list or string to tensor + if isinstance(value, str): + # If it's a string, try to parse it as a list of numbers + try: + import json + + value = json.loads(value) + except Exception: + # If parsing fails, treat as empty list + value = [] + if isinstance(value, list): + try: + import torch # type: ignore + + processed_data[key] = torch.tensor(value, dtype=torch.long) + except ImportError: + # If torch not available, keep as list + processed_data[key] = value + else: + processed_data[key] = value + elif key == "action_id" and isinstance(value, str): + # Convert action_id from string to int + try: + processed_data[key] = int(value) + except ValueError: + # If conversion fails, keep original value + processed_data[key] = value + else: + processed_data[key] = value + + return action_cls.model_validate(processed_data) + + +def serialize_observation(observation: Observation) -> Dict[str, Any]: + """ + Convert Observation instance to JSON-compatible dict using Pydantic. + + Args: + observation: Observation instance + + Returns: + Dictionary compatible with EnvClient._parse_result() + + The format matches what EnvClient expects: + { + "observation": {...}, # Observation fields + "reward": float | None, + "done": bool, + } + """ + # Use Pydantic's model_dump() for serialization + obs_dict = observation.model_dump( + exclude={ + "reward", + "done", + "metadata", + } # Exclude these from observation dict + ) + + # Extract reward and done directly from the observation + reward = observation.reward + done = observation.done + + # Return in EnvClient expected format + return { + "observation": obs_dict, + "reward": reward, + "done": done, + } diff --git a/src/core/openenv/core/env_server/types.py b/src/core/openenv/core/env_server/types.py new file mode 100644 index 0000000000000000000000000000000000000000..34a198013442e5000f7fbf75b7f24157b6c04683 --- /dev/null +++ b/src/core/openenv/core/env_server/types.py @@ -0,0 +1,387 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
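+
+"""Core type definitions for OpenEnv environment servers.
+
+Defines the base ``Action``/``Observation``/``State`` Pydantic models, the
+HTTP request/response models, and the WebSocket message types shared by
+server implementations.
+"""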
+ +from enum import Enum +from typing import Annotated, Any, Dict, Literal, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field, model_validator + + +# Type aliases +Scalar = Union[int, float, bool] + + +# ============================================================================= +# Enums for Type Safety +# ============================================================================= + + +class ServerMode(str, Enum): + """Server operation mode.""" + + SIMULATION = "simulation" + PRODUCTION = "production" + + +class HealthStatus(str, Enum): + """Server health status values.""" + + HEALTHY = "healthy" + UNHEALTHY = "unhealthy" + DEGRADED = "degraded" + + +class WSErrorCode(str, Enum): + """WebSocket error codes for structured error handling.""" + + INVALID_JSON = "INVALID_JSON" + UNKNOWN_TYPE = "UNKNOWN_TYPE" + VALIDATION_ERROR = "VALIDATION_ERROR" + EXECUTION_ERROR = "EXECUTION_ERROR" + CAPACITY_REACHED = "CAPACITY_REACHED" + FACTORY_ERROR = "FACTORY_ERROR" + SESSION_ERROR = "SESSION_ERROR" + + +# ============================================================================= +# Core Types +# ============================================================================= + + +class Action(BaseModel): + """Base class for all environment actions. + + All action subclasses should inherit from this base class. + Uses Pydantic for automatic validation and serialization. + """ + + model_config = ConfigDict( + extra="forbid", # Reject unknown fields + validate_assignment=True, # Validate on field assignment + arbitrary_types_allowed=True, # Allow numpy arrays, torch tensors, etc. + ) + + metadata: Dict[str, Any] = Field( + default_factory=dict, description="Additional metadata for the action" + ) + + +class Observation(BaseModel): + """Base class for all environment observations. + + All observation subclasses should inherit from this base class. + Uses Pydantic for automatic validation and serialization. 
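+
+    Example of a concrete subclass (illustrative)::
+
+        class EchoObservation(Observation):
+            message: str = ""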
+ """ + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + arbitrary_types_allowed=True, + ) + + done: bool = Field(default=False, description="Whether the episode has terminated") + reward: bool | int | float | None = Field( + default=None, description="Reward signal from the last action" + ) + metadata: Dict[str, Any] = Field( + default_factory=dict, description="Additional metadata for the observation" + ) + + +class ResetRequest(BaseModel): + """Request model for environment reset.""" + + model_config = ConfigDict( + extra="allow", # Allow extra fields for custom reset parameters + json_schema_extra={"examples": [{"seed": 42, "episode_id": "episode-001"}, {}]}, + ) + + seed: Optional[int] = Field( + default=None, ge=0, description="Random seed for reproducible episodes" + ) + episode_id: Optional[str] = Field( + default=None, max_length=255, description="Custom episode identifier" + ) + + +class ResetResponse(BaseModel): + """Response model for environment reset.""" + + model_config = ConfigDict(extra="forbid") + + observation: Dict[str, Any] = Field( + ..., description="Initial observation from the environment" + ) + reward: Optional[float] = Field( + default=None, description="Initial reward (typically None at reset)" + ) + done: bool = Field( + default=False, description="Whether episode is already done (typically False)" + ) + + +class StepRequest(BaseModel): + """Request model for environment step.""" + + model_config = ConfigDict( + extra="allow", # Allow extra fields for custom step parameters + json_schema_extra={ + "examples": [ + {"action": {"value": 1}, "timeout_s": 30.0}, + {"action": {"value": 1}, "render": True, "verbose": False}, + ] + }, + ) + + action: Dict[str, Any] = Field( + ..., + description="Action to execute, must conform to environment's action schema", + ) + timeout_s: Optional[float] = Field( + default=None, + gt=0, + description="Optional timeout in seconds for action execution", + ) + request_id: Optional[str] = Field( + default=None, + max_length=255, + description="Optional request identifier for tracking", + ) + + +class StepResponse(BaseModel): + """Response model for environment step.""" + + model_config = ConfigDict(extra="forbid") + + observation: Dict[str, Any] = Field( + ..., description="Observation resulting from the action" + ) + reward: Optional[float] = Field( + default=None, description="Reward signal from the action" + ) + done: bool = Field(default=False, description="Whether the episode has terminated") + + +class BaseMessage(BaseModel): + """Base class for WebSocket messages with shared configuration.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + +class State(BaseModel): + """Base class for environment state. + + Represents internal environment state, separate from observations. 
+ """ + + model_config = ConfigDict( + extra="allow", # Allow extra fields for flexibility + validate_assignment=True, + arbitrary_types_allowed=True, + ) + + episode_id: Optional[str] = Field( + default=None, description="Unique identifier for the current episode" + ) + step_count: int = Field( + default=0, + ge=0, # Greater than or equal to 0 + description="Number of steps taken in the current episode", + ) + + +class CodeExecResult(BaseMessage): + """Result of code execution containing stdout, stderr, and exit code.""" + + stdout: str = Field(description="Standard output from code execution") + stderr: str = Field(description="Standard error from code execution") + exit_code: int = Field(description="Exit code from code execution") + + +class EnvironmentMetadata(BaseMessage): + """Metadata about an environment for documentation and UI purposes.""" + + name: str = Field(description="Name of the environment") + description: str = Field(description="Description of what the environment does") + readme_content: Optional[str] = Field( + default=None, description="Content of the README file for the environment" + ) + version: Optional[str] = Field( + default=None, description="Version of the environment" + ) + author: Optional[str] = Field(default=None, description="Author of the environment") + documentation_url: Optional[str] = Field( + default=None, description="URL to the environment's documentation" + ) + + +class SchemaResponse(BaseMessage): + """Response model for the combined schema endpoint.""" + + action: Dict[str, Any] = Field( + description="JSON schema for actions accepted by this environment" + ) + observation: Dict[str, Any] = Field( + description="JSON schema for observations returned by this environment" + ) + state: Dict[str, Any] = Field( + description="JSON schema for environment state objects" + ) + + +class HealthResponse(BaseMessage): + """Response model for health check endpoint.""" + + status: HealthStatus = Field( + default=HealthStatus.HEALTHY, + description="Health status of the environment server", + ) + + +class WSResetMessage(BaseMessage): + """WebSocket message to reset the environment.""" + + type: Literal["reset"] = Field(default="reset", description="Message type") + data: Dict[str, Any] = Field( + default_factory=dict, + description="Optional reset parameters (seed, episode_id, etc.)", + ) + + +class WSStepMessage(BaseMessage): + """WebSocket message to execute a step.""" + + type: Literal["step"] = Field(default="step", description="Message type") + data: Dict[str, Any] = Field( + ..., description="Action data conforming to environment's action schema" + ) + + +class WSStateMessage(BaseMessage): + """WebSocket message to request current state.""" + + type: Literal["state"] = Field(default="state", description="Message type") + + +class WSCloseMessage(BaseMessage): + """WebSocket message to close the session.""" + + type: Literal["close"] = Field(default="close", description="Message type") + + +# Discriminated union for incoming WebSocket messages +# Note: WSMCPMessage is defined in mcp_types.py to avoid circular imports +# The union here covers the core message types; MCP messages are handled separately +WSIncomingMessage = Annotated[ + WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage, + Field(discriminator="type"), +] + + +class WSObservationResponse(BaseModel): + """WebSocket response containing an observation.""" + + model_config = ConfigDict(extra="forbid") + + type: Literal["observation"] = Field( + default="observation", 
description="Response type" + ) + data: Dict[str, Any] = Field(description="Observation data") + + +class WSStateResponse(BaseModel): + """WebSocket response containing environment state.""" + + model_config = ConfigDict(extra="forbid") + + type: Literal["state"] = Field(default="state", description="Response type") + data: Dict[str, Any] = Field(description="State data") + + +class WSErrorResponse(BaseModel): + """WebSocket response for errors.""" + + model_config = ConfigDict(extra="forbid") + + type: Literal["error"] = Field(default="error", description="Response type") + data: Dict[str, Any] = Field(description="Error details including message and code") + + +class ConcurrencyConfig(BaseMessage): + """Configuration for concurrent environment sessions.""" + + max_concurrent_envs: int = Field( + default=1, + ge=1, + description="Maximum number of concurrent WebSocket sessions allowed", + ) + session_timeout: Optional[float] = Field( + default=None, + gt=0, + description="Timeout in seconds for inactive sessions. None means no timeout.", + ) + + +class ServerCapacityStatus(BaseMessage): + """Status of server capacity for concurrent sessions.""" + + active_sessions: int = Field( + ge=0, + description="Number of currently active sessions", + ) + max_sessions: int = Field( + ge=1, + description="Maximum number of allowed sessions", + ) + + @model_validator(mode="after") + def check_capacity_bounds(self) -> "ServerCapacityStatus": + if self.active_sessions > self.max_sessions: + raise ValueError( + f"active_sessions ({self.active_sessions}) cannot exceed " + f"max_sessions ({self.max_sessions})" + ) + return self + + @property + def available_slots(self) -> int: + """Number of available session slots.""" + return self.max_sessions - self.active_sessions + + @property + def is_at_capacity(self) -> bool: + """Whether the server has reached maximum capacity.""" + return self.available_slots == 0 + + @classmethod + def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus": + """Create status from active and max session counts.""" + return cls( + active_sessions=active, + max_sessions=max_sessions, + ) + + +class SessionInfo(BaseMessage): + """Information about an active session.""" + + session_id: str = Field(description="Unique identifier for the session") + created_at: float = Field(description="Unix timestamp when the session was created") + last_activity_at: float = Field( + description="Unix timestamp of the last activity in the session" + ) + step_count: int = Field( + default=0, + ge=0, + description="Number of steps executed in this session", + ) + environment_type: str = Field( + description="Environment type for this session (e.g. `CodingEnv`)" + ) diff --git a/src/core/openenv/core/env_server/web_interface.py b/src/core/openenv/core/env_server/web_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..284740eb408b8e2b798037918967b7a50abee72d --- /dev/null +++ b/src/core/openenv/core/env_server/web_interface.py @@ -0,0 +1,644 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Web interface for OpenEnv environments. + +When ENABLE_WEB_INTERFACE is set, the server exposes a Gradio UI at /web for +reset, step, and state observation. Controlled by the CLI enable_interface +option (e.g. openenv push --enable-interface) or ENABLE_WEB_INTERFACE env var. 
+""" + +from __future__ import annotations + +import asyncio +import json +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, Type + +import gradio as gr +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from pydantic import BaseModel, ConfigDict, Field + +from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME +from .gradio_ui import build_gradio_app, get_gradio_display_title +from .interfaces import Environment +from .serialization import deserialize_action_with_preprocessing, serialize_observation +from .types import Action, EnvironmentMetadata, Observation, State + +# Quick Start markdown template; placeholders match init suffixes (__ENV_NAME__, __ENV_CLASS_NAME__*). +DEFAULT_QUICK_START_MARKDOWN = """ +### Connect to this environment + +Connect from Python using `__ENV_CLASS_NAME__Env`: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env + +with __ENV_CLASS_NAME__Env.from_env("") as env: + result = await env.step(__ENV_CLASS_NAME__Action(message="...")) +``` + +Or connect directly to a running server: + +```python +env = __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") +``` + +### Contribute to this environment + +Submit improvements via pull request on the Hugging Face Hub. + +```bash +openenv fork --repo-id / +``` + +Then make your changes and submit a pull request: + +```bash +cd +openenv push --create-pr +``` + +For more information, see the [OpenEnv documentation](https://meta-pytorch.org/OpenEnv/). +""" + + +def get_quick_start_markdown( + metadata: Optional[EnvironmentMetadata], + action_cls: Type[Action], + observation_cls: Type[Observation], +) -> str: + """ + Build Quick Start markdown with class names replaced from current env (init-style suffixes). + + Uses the same placeholder names as the init template so that __ENV_CLASS_NAME__Env, + __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation and __ENV_NAME__ are + replaced with the actual class/package names. + """ + import os + + # Prefix from action class (e.g. EchoAction -> Echo) + action_name = getattr(action_cls, "__name__", "Action") + if action_name.endswith("Action"): + prefix = action_name[: -len("Action")] + else: + prefix = action_name.replace("Action", "").strip() or "Env" + + env_client_name = f"{prefix}Env" + obs_name = getattr(observation_cls, "__name__", "Observation") + pkg_name = (metadata.name if metadata else "env").replace(" ", "_").lower() + + space_id = os.environ.get("SPACE_ID", "/") + + content = DEFAULT_QUICK_START_MARKDOWN + content = content.replace("__ENV_CLASS_NAME__Env", env_client_name) + content = content.replace("__ENV_CLASS_NAME__Action", action_name) + content = content.replace("__ENV_CLASS_NAME__Observation", obs_name) + content = content.replace("__ENV_CLASS_NAME__", prefix) + content = content.replace("__ENV_NAME__", pkg_name) + content = content.replace("", space_id) + return content.strip() + + +def load_environment_metadata( + env: Environment, env_name: Optional[str] = None +) -> EnvironmentMetadata: + """ + Load environment metadata including README content. + + Args: + env: The environment instance, class, or factory function. 
+ - If a class: used as a factory, won't call instance methods + - If a function: used as a factory, won't call instance methods + - If an instance: may call get_metadata() if available + env_name: Optional environment name for README file lookup + + Returns: + EnvironmentMetadata with loaded information + """ + import inspect + + # Determine what type of env we received: + # 1. A class (used as factory) - e.g., PythonCodeActEnv + # 2. A function (factory function) - e.g., create_chat_environment + # 3. An actual instance - e.g., SnakeEnvironment() + is_class = inspect.isclass(env) + is_function = inspect.isfunction(env) or inspect.ismethod(env) + is_factory = is_class or is_function + + # Try to get metadata from environment if it's an instance with get_metadata + if not is_factory and hasattr(env, "get_metadata"): + return env.get_metadata() + + # Determine the class name for default metadata + if is_class: + # env is the class itself + class_name = env.__name__ + elif is_function: + # env is a factory function - use its name or derive from env_name + class_name = env_name or env.__name__ + else: + # env is an instance + class_name = env.__class__.__name__ + + # Default metadata + metadata = EnvironmentMetadata( + name=env_name or class_name, + description=f"{class_name} environment", + version="1.0.0", + ) + + # Try to load README from file system + readme_content = _load_readme_from_filesystem(env_name) + if readme_content: + metadata.readme_content = readme_content + + return metadata + + +def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]: + """ + Load README content from the filesystem. + + Tries multiple locations: + 1. Container filesystem: /app/README.md + 2. Local development: src/envs/{env_name}/README.md + 3. Environment variable: ENV_README_PATH + """ + import os + from pathlib import Path + + # Try container filesystem first + container_readme = Path("/app/README.md") + if container_readme.exists(): + try: + return container_readme.read_text(encoding="utf-8") + except Exception: + pass + + # Try environment variable path + custom_path = os.environ.get("ENV_README_PATH") + if custom_path and Path(custom_path).exists(): + try: + return Path(custom_path).read_text(encoding="utf-8") + except Exception: + pass + + # Try local development path + if env_name: + local_readme = Path(f"src/envs/{env_name}/README.md") + if local_readme.exists(): + try: + return local_readme.read_text(encoding="utf-8") + except Exception: + pass + + return None + + +class ActionLog(BaseModel): + """Log entry for an action taken.""" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + timestamp: str = Field(description="Timestamp when action was taken") + action: Dict[str, Any] = Field(description="Action that was taken") + observation: Dict[str, Any] = Field(description="Observation returned from action") + reward: Optional[float] = Field( + default=None, description="Reward received from action" + ) + done: bool = Field(description="Whether the episode is done after this action") + step_count: int = Field(description="Step count when this action was taken") + + +class EpisodeState(BaseModel): + """Current episode state for the web interface.""" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + episode_id: Optional[str] = Field(default=None, description="Current episode ID") + step_count: int = Field(description="Current step count in episode") + current_observation: Optional[Dict[str, Any]] = Field( + default=None, 
description="Current observation" + ) + action_logs: List[ActionLog] = Field( + default_factory=list, description="List of action logs" + ) + is_reset: bool = Field( + default=True, description="Whether the episode has been reset" + ) + + +class WebInterfaceManager: + """Manages the web interface for an environment.""" + + MAX_ACTION_LOGS = 1000 + + def __init__( + self, + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + metadata: Optional[EnvironmentMetadata] = None, + ): + import inspect + + # If env is a class or factory function, instantiate it + if inspect.isclass(env) or inspect.isfunction(env): + self.env = env() + else: + self.env = env + self.action_cls = action_cls + self.observation_cls = observation_cls + self.metadata = metadata or EnvironmentMetadata( + name=env.__class__.__name__, + description=f"{env.__class__.__name__} environment", + ) + self.episode_state = EpisodeState( + episode_id=None, + step_count=0, + current_observation=None, + action_logs=[], + ) + self.connected_clients: List[WebSocket] = [] + # Thread pool for running sync code (e.g., Playwright sync API) in async context + self._executor = ThreadPoolExecutor(max_workers=1) + + async def _run_sync_in_thread_pool(self, func, *args, **kwargs): + """Run a synchronous function in the thread pool executor. + + This is needed for environments using sync libraries (e.g., Playwright sync API) + that cannot be called directly from an async context. + """ + loop = asyncio.get_event_loop() + # Use default arguments to capture values at lambda definition time + # to avoid closure issues with late binding + return await loop.run_in_executor( + self._executor, lambda f=func, a=args, kw=kwargs: f(*a, **kw) + ) + + async def connect_websocket(self, websocket: WebSocket): + """Connect a new WebSocket client.""" + await websocket.accept() + self.connected_clients.append(websocket) + + # Send current state to the new client + await self._send_state_update() + + async def disconnect_websocket(self, websocket: WebSocket): + """Disconnect a WebSocket client.""" + if websocket in self.connected_clients: + self.connected_clients.remove(websocket) + + async def _send_state_update(self): + """Send current state to all connected clients.""" + if not self.connected_clients: + return + + state_data = { + "type": "state_update", + "episode_state": self.episode_state.model_dump(), + } + + # Send to all connected clients + disconnected_clients = [] + for client in self.connected_clients: + try: + await client.send_text(json.dumps(state_data)) + except Exception: + disconnected_clients.append(client) + + # Remove disconnected clients + for client in disconnected_clients: + self.connected_clients.remove(client) + + async def reset_environment(self) -> Dict[str, Any]: + """Reset the environment and update state.""" + # Run sync reset in thread pool to avoid blocking event loop + # and to support environments using sync libraries (e.g., Playwright) + observation: Observation = await self._run_sync_in_thread_pool(self.env.reset) + state: State = self.env.state + + # Serialize observation once using shared utility + serialized = serialize_observation(observation) + + # Update episode state + self.episode_state.episode_id = state.episode_id + self.episode_state.step_count = 0 + self.episode_state.current_observation = serialized["observation"] + self.episode_state.action_logs = [] + self.episode_state.is_reset = True + + # Send state update + await self._send_state_update() + + return serialized + + async def 
step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]: + """Execute a step in the environment and update state.""" + # Deserialize action with preprocessing for web interface special cases + action: Action = deserialize_action_with_preprocessing( + action_data, self.action_cls + ) + + # Run sync step in thread pool to avoid blocking event loop + # and to support environments using sync libraries (e.g., Playwright) + observation: Observation = await self._run_sync_in_thread_pool( + self.env.step, action + ) + state: State = self.env.state + + # Serialize observation once using shared utility + serialized = serialize_observation(observation) + + # Create action log + action_log = ActionLog( + timestamp=datetime.now().isoformat(), + action=action.model_dump(exclude={"metadata"}), + observation=serialized["observation"], + reward=observation.reward, + done=observation.done, + step_count=state.step_count, + ) + + # Update episode state + self.episode_state.episode_id = state.episode_id + self.episode_state.step_count = state.step_count + self.episode_state.current_observation = serialized["observation"] + self.episode_state.action_logs.append(action_log) + if len(self.episode_state.action_logs) > self.MAX_ACTION_LOGS: + self.episode_state.action_logs = self.episode_state.action_logs[ + -self.MAX_ACTION_LOGS : + ] + self.episode_state.is_reset = False + + # Send state update + await self._send_state_update() + + return serialized + + def get_state(self) -> Dict[str, Any]: + """Get current environment state.""" + state: State = self.env.state + return state.model_dump() + + +def create_web_interface_app( + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + env_name: Optional[str] = None, + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[Any] = None, + gradio_builder: Optional[Callable[..., Any]] = None, +) -> FastAPI: + """ + Create a FastAPI application with web interface for the given environment. + + Args: + env: The Environment instance to serve + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + env_name: Optional environment name for README loading + max_concurrent_envs: Maximum concurrent WebSocket sessions + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings + gradio_builder: Optional callable (web_manager, action_fields, metadata, + is_chat_env, title, quick_start_md) -> gr.Blocks to use instead of the + default Gradio UI. Lets envs replace or customize the /web interface. + + Returns: + FastAPI application instance with web interface + """ + from .http_server import create_fastapi_app + + # Create the base environment app + app = create_fastapi_app( + env, action_cls, observation_cls, max_concurrent_envs, concurrency_config + ) + + # Load environment metadata + metadata = load_environment_metadata(env, env_name) + + # Create web interface manager + web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) + + # Web API routes first (so they take precedence over Gradio mount at /web) + @app.get("/web/metadata") + async def web_metadata(): + """Get environment metadata.""" + return web_manager.metadata.model_dump() + + @app.websocket("/ws/ui") + async def websocket_ui_endpoint(websocket: WebSocket): + """WebSocket endpoint for web UI real-time updates. + + Note: Uses /ws/ui to avoid conflict with /ws in http_server.py + which is used for concurrent environment sessions. 
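+
+        After each reset/step, connected clients receive a JSON message of
+        the form ``{"type": "state_update", "episode_state": {...}}`` (see
+        ``_send_state_update``).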
+ """ + await web_manager.connect_websocket(websocket) + try: + while True: + # Keep connection alive + await websocket.receive_text() + except WebSocketDisconnect: + await web_manager.disconnect_websocket(websocket) + + @app.post("/web/reset") + async def web_reset(): + """Reset endpoint for web interface.""" + return await web_manager.reset_environment() + + @app.post("/web/step") + async def web_step(request: Dict[str, Any]): + """Step endpoint for web interface.""" + # Check if this is a message-based request (chat environment) + if "message" in request: + message = request["message"] + if hasattr(web_manager.env, "message_to_action"): + action = web_manager.env.message_to_action(message) + if hasattr(action, "tokens"): + action_data = {"tokens": action.tokens.tolist()} + else: + action_data = action.model_dump(exclude={"metadata"}) + else: + action_data = {"message": message} + else: + action_data = request.get("action", {}) + + return await web_manager.step_environment(action_data) + + @app.get("/web/state") + async def web_state(): + """State endpoint for web interface.""" + return web_manager.get_state() + + action_fields = _extract_action_fields(action_cls) + is_chat_env = _is_chat_env(action_cls) + quick_start_md = get_quick_start_markdown(metadata, action_cls, observation_cls) + + default_blocks = build_gradio_app( + web_manager, + action_fields, + metadata, + is_chat_env, + title=metadata.name, + quick_start_md=quick_start_md, + ) + if gradio_builder is not None: + custom_blocks = gradio_builder( + web_manager, + action_fields, + metadata, + is_chat_env, + metadata.name, + quick_start_md, + ) + if not isinstance(custom_blocks, gr.Blocks): + raise TypeError( + f"gradio_builder must return a gr.Blocks instance, " + f"got {type(custom_blocks).__name__}" + ) + gradio_blocks = gr.TabbedInterface( + [default_blocks, custom_blocks], + tab_names=["Playground", "Visualization"], + title=get_gradio_display_title(metadata), + ) + else: + gradio_blocks = default_blocks + app = gr.mount_gradio_app( + app, + gradio_blocks, + path="/web", + theme=OPENENV_GRADIO_THEME, + css=OPENENV_GRADIO_CSS, + ) + + return app + + +def _is_chat_env(action_cls: Type[Action]) -> bool: + """Return True if the action class is a chat-style env (tokens field).""" + if hasattr(action_cls, "model_fields"): + for field_name, field_info in action_cls.model_fields.items(): + if ( + field_name == "tokens" + and hasattr(field_info.annotation, "__name__") + and "Tensor" in str(field_info.annotation) + ): + return True + return False + + +def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]: + """Extract enhanced field metadata from Action class for form generation.""" + # Use Pydantic's JSON schema generation for robust metadata extraction + try: + schema = action_cls.model_json_schema() + except AttributeError: + # Fallback for non-Pydantic v2 models or if something goes wrong + return [] + + properties = schema.get("properties", {}) + required_fields = schema.get("required", []) + + action_fields = [] + + for field_name, field_info in properties.items(): + if field_name == "metadata": + continue + + # JSON schema "type" can be a string or list/undefined + # Determine our internal input type + input_type = _determine_input_type_from_schema(field_info, field_name) + + is_required = field_name in required_fields + + action_fields.append( + { + "name": field_name, + "type": input_type, + "required": is_required, + "description": field_info.get("description", ""), + "default_value": 
field_info.get("default"), + "choices": field_info.get("enum"), + "min_value": field_info.get("minimum"), + "max_value": field_info.get("maximum"), + "min_length": field_info.get("minLength"), + "max_length": field_info.get("maxLength"), + "pattern": field_info.get("pattern"), + "placeholder": _generate_placeholder(field_name, field_info), + "help_text": _generate_help_text(field_name, field_info), + } + ) + + return action_fields + + +def _determine_input_type_from_schema( + field_info: Dict[str, Any], field_name: str +) -> str: + """Determine input type from JSON schema for form generation (Gradio UI).""" + schema_type = field_info.get("type") + + # Check for specific tensor field convention + if "tokens" in field_name.lower(): + return "tensor" + + if "enum" in field_info: + return "select" + + if schema_type == "boolean": + return "checkbox" + + if schema_type == "integer" or schema_type == "number": + return "number" + + if schema_type == "string": + # Check if it should be a textarea + if ( + field_info.get("maxLength", 0) > 100 + or "message" in field_name.lower() + or "code" in field_name.lower() + ): + return "textarea" + return "text" + + # Default fallback + return "text" + + +def _generate_placeholder(field_name: str, field_info: Dict[str, Any]) -> str: + """Generate placeholder text.""" + if "message" in field_name.lower(): + return f"Enter {field_name.replace('_', ' ')}..." + elif "code" in field_name.lower(): + return "Enter Python code here..." + elif "tokens" in field_name.lower(): + return "Enter comma-separated token IDs (e.g., 1,2,3,4,5)" + else: + return f"Enter {field_name.replace('_', ' ')}..." + + +def _generate_help_text(field_name: str, field_info: Dict[str, Any]) -> str: + """Generate help text.""" + description = field_info.get("description", "") + if description: + return description + + if "action_id" in field_name.lower(): + return "The action ID to execute in environment" + elif "game_name" in field_name.lower(): + return "Name of game or environment" + elif "tokens" in field_name.lower(): + return "Token IDs as a comma-separated list of integers" + elif "code" in field_name.lower(): + return "Python code to execute in environment" + elif "message" in field_name.lower(): + return "Text message to send" + + return "" diff --git a/src/core/openenv/core/evals/__init__.py b/src/core/openenv/core/evals/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52e564a09b5e4976f2cd5a8c1fe1c7848bb47ecb --- /dev/null +++ b/src/core/openenv/core/evals/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Evaluation harness support for OpenEnv.""" + +from openenv.core.evals.base import EvalHarness +from openenv.core.evals.inspect_harness import InspectAIHarness +from openenv.core.evals.types import EvalConfig, EvalResult + +__all__ = [ + "EvalHarness", + "EvalConfig", + "EvalResult", + "InspectAIHarness", +] diff --git a/src/core/openenv/core/evals/base.py b/src/core/openenv/core/evals/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e457d8adb740569ad79143cbf70bc58b05a8cef9 --- /dev/null +++ b/src/core/openenv/core/evals/base.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Base class for evaluation harnesses.""" + +from abc import ABC, abstractmethod +from typing import Any, Dict + +from openenv.core.evals.types import EvalConfig, EvalResult + + +class EvalHarness(ABC): + """Abstract base class for evaluation harnesses. + + Subclasses implement run() to define evaluation logic. + """ + + @abstractmethod + def run( + self, + harness_version: str, + library_versions: Dict[str, str], + dataset: str, + eval_parameters: Dict[str, Any], + ) -> Dict[str, Any]: + """Run the evaluation and return scores. + + Args: + harness_version: Version of the evaluation harness. + library_versions: Versions of libraries used in the evaluation. + dataset: Name of the dataset to evaluate on. + eval_parameters: Parameters for the evaluation. + + Returns: + Dictionary of scores from the evaluation. + """ + raise NotImplementedError + + def run_from_config(self, config: EvalConfig) -> EvalResult: + """Run evaluation from an EvalConfig and return an EvalResult. + + Args: + config: Configuration for the evaluation. + + Returns: + EvalResult containing the config and scores. + """ + scores = self.run( + harness_version=config.harness_version, + library_versions=config.library_versions, + dataset=config.dataset, + eval_parameters=config.eval_parameters, + ) + return EvalResult(config=config, scores=scores) + + @property + def name(self) -> str: + """Return the name of the harness (class name).""" + return self.__class__.__name__ diff --git a/src/core/openenv/core/evals/inspect_harness.py b/src/core/openenv/core/evals/inspect_harness.py new file mode 100644 index 0000000000000000000000000000000000000000..6bf91105db6cf325587623891905e5cbc71c124e --- /dev/null +++ b/src/core/openenv/core/evals/inspect_harness.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Inspect AI harness integration for OpenEnv. + +Requires the ``inspect-ai`` package: ``pip install 'inspect-ai>=0.3.0'`` +""" + +from __future__ import annotations + +from typing import Any, Dict, Optional + +from openenv.core.evals.base import EvalHarness + + +class InspectAIHarness(EvalHarness): + """Evaluation harness wrapping Inspect AI's ``eval()`` function. + + All ``inspect_ai`` imports are deferred to :meth:`run` so this class is + importable without inspect-ai installed. An ``ImportError`` with a clear + message is raised at call time if the dependency is missing. + + Args: + log_dir: Directory for evaluation log output. Defaults to None + (Inspect AI writes logs to its default location). + + ``eval_parameters`` keys accepted by :meth:`run`: + + +--------------------------+----------+-----------------+-----------------------------------+ + | Key | Type | Default | Purpose | + +==========================+==========+=================+===================================+ + | ``model`` | str | *required* | Model string, e.g. 
"openai/gpt-4o"| + | ``task`` | str|None | ``dataset`` arg | Task file path or task string | + | ``task_args`` | dict | ``{}`` | Arguments to pass to the task | + | ``max_samples`` | int|None | None | Limit samples per task | + | ``temperature`` | float|None| None | Model generation temperature | + | ``max_tokens`` | int|None | None | Max generation tokens | + | ``epochs`` | int|None | None | Number of evaluation epochs | + | ``solver`` | list|None| None | Solver pipeline override | + | ``scorer`` | list|None| None | Scorer override | + | ``model_args`` | dict | ``{}`` | Provider-specific model kwargs | + +--------------------------+----------+-----------------+-----------------------------------+ + """ + + def __init__( + self, + *, + log_dir: Optional[str] = None, + ): + self.log_dir = log_dir + + def run( + self, + harness_version: str, + library_versions: Dict[str, str], + dataset: str, + eval_parameters: Dict[str, Any], + ) -> Dict[str, Any]: + """Run an Inspect AI evaluation. + + Args: + harness_version: Version of inspect-ai being used. + library_versions: Versions of supporting libraries. + dataset: Default task string (used when ``task`` is not specified + in *eval_parameters*). + eval_parameters: See class docstring for accepted keys. + + Returns: + Dictionary mapping metric names to scores. + + Raises: + ImportError: If ``inspect-ai`` is not installed. + ValueError: If ``model`` is missing from *eval_parameters*. + RuntimeError: If the evaluation fails (log status is not "success"). + """ + try: + from inspect_ai import eval as inspect_eval + except ImportError: + raise ImportError( + "inspect-ai is required for InspectAIHarness. " + "Install it with: pip install 'inspect-ai>=0.3.0'" + ) + + # Extract required model parameter + model = eval_parameters.get("model") + if model is None: + raise ValueError( + "eval_parameters must include 'model' " + "(e.g. 'openai/gpt-4o', 'hf/meta-llama/...')." + ) + + # Task: explicit parameter or fall back to dataset + task = eval_parameters.get("task", dataset) + + # Build eval kwargs + eval_kwargs: Dict[str, Any] = {} + + task_args = eval_parameters.get("task_args", {}) + if task_args: + eval_kwargs["task_args"] = task_args + + model_args = eval_parameters.get("model_args", {}) + if model_args: + eval_kwargs["model_args"] = model_args + + for key in ("max_samples", "temperature", "max_tokens", "epochs"): + value = eval_parameters.get(key) + if value is not None: + eval_kwargs[key] = value + + if eval_parameters.get("solver") is not None: + eval_kwargs["solver"] = eval_parameters["solver"] + + if eval_parameters.get("scorer") is not None: + eval_kwargs["scorer"] = eval_parameters["scorer"] + + if self.log_dir is not None: + eval_kwargs["log_dir"] = self.log_dir + + # Run evaluation + logs = inspect_eval(task, model=model, **eval_kwargs) + + # Extract results from the first log + if not logs: + raise RuntimeError( + "Inspect AI evaluation returned no logs. " + "Check that the task and model arguments are valid." + ) + log = logs[0] + if log.status != "success": + raise RuntimeError( + f"Inspect AI evaluation failed with status: {log.status}" + ) + + return self._extract_scores(log) + + def _extract_scores(self, log: Any) -> Dict[str, Any]: + """Parse an EvalLog's results into a flat score dictionary. + + Iterates over ``log.results.scores`` (a list of ``EvalScore``), + flattening each scorer's ``metrics`` dict into a single output dict. + + Args: + log: An ``inspect_ai`` ``EvalLog`` object. 
+ + Returns: + Dictionary mapping metric names to their values. + """ + scores: Dict[str, Any] = {} + if log.results is None: + return scores + + for eval_score in log.results.scores: + for metric_name, metric in eval_score.metrics.items(): + scores[metric_name] = metric.value + + return scores diff --git a/src/core/openenv/core/evals/types.py b/src/core/openenv/core/evals/types.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6b14f762624c607c345e5dff1bc77faa5b4b56 --- /dev/null +++ b/src/core/openenv/core/evals/types.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Pydantic models for eval configuration and results.""" + +from typing import Any, Dict + +from pydantic import BaseModel, ConfigDict, Field + + +class EvalConfig(BaseModel): + """Configuration for running an evaluation.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + harness_name: str = Field(description="Name of the evaluation harness") + harness_version: str = Field(description="Version of the evaluation harness") + library_versions: Dict[str, str] = Field( + description="Versions of libraries used in the evaluation" + ) + dataset: str = Field(description="Name of the dataset to evaluate on") + eval_parameters: Dict[str, Any] = Field(description="Parameters for the evaluation") + + +class EvalResult(BaseModel): + """Result of running an evaluation.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + config: EvalConfig = Field(description="Configuration used for the evaluation") + scores: Dict[str, Any] = Field(description="Scores from the evaluation") diff --git a/src/core/openenv/core/generic_client.py b/src/core/openenv/core/generic_client.py new file mode 100644 index 0000000000000000000000000000000000000000..17576862293feeebf68b4a90d6a4a80de369dd34 --- /dev/null +++ b/src/core/openenv/core/generic_client.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Generic environment client that works with raw dictionaries. + +This module provides a GenericEnvClient that doesn't require installing +environment-specific packages. It's useful for connecting to remote servers +without running any untrusted code locally. +""" + +from typing import Any, Dict + +from .client_types import StepResult +from .env_client import EnvClient + + +class GenericEnvClient(EnvClient[Dict[str, Any], Dict[str, Any], Dict[str, Any]]): + """ + Environment client that works with raw dictionaries instead of typed classes. + + This client doesn't require installing environment-specific packages, making it + ideal for: + - Connecting to remote servers without installing their packages + - Quick prototyping and testing + - Environments where type safety isn't needed + - Security-conscious scenarios where you don't want to run remote code + + The trade-off is that you lose type safety and IDE autocomplete for actions + and observations. Instead of typed objects, you work with plain dictionaries. + + Example: + >>> # Direct connection to a running server (no installation needed) + >>> with GenericEnvClient(base_url="http://localhost:8000") as env: + ... result = env.reset() + ... 
result = env.step({"code": "print('hello')"}) + ... print(result.observation) # Dict[str, Any] + ... print(result.observation.get("output")) + + >>> # From local Docker image + >>> env = GenericEnvClient.from_docker_image("coding-env:latest") + >>> result = env.reset() + >>> result = env.step({"code": "x = 1 + 2"}) + >>> env.close() + + >>> # From HuggingFace Hub (pulls Docker image, no pip install) + >>> env = GenericEnvClient.from_env("user/my-env", use_docker=True) + >>> result = env.reset() + >>> env.close() + + Note: + GenericEnvClient inherits `from_docker_image()` and `from_env()` from + EnvClient, so you can use it with Docker containers and HuggingFace + Spaces without any package installation. + """ + + def _step_payload(self, action: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert action to payload for the server. + + For GenericEnvClient, this handles both raw dictionaries and + typed Action objects (Pydantic models). If a Pydantic model is + passed, it will be converted to a dictionary using model_dump(). + + Args: + action: Action as a dictionary or Pydantic BaseModel + + Returns: + The action as a dictionary for the server + """ + # If it's already a dict, return as-is + if isinstance(action, dict): + return action + + # If it's a Pydantic model (Action subclass), convert to dict + if hasattr(action, "model_dump"): + return action.model_dump() + + # Fallback for other objects with __dict__ + if hasattr(action, "__dict__"): + return vars(action) + + # Last resort: try to convert to dict + return dict(action) + + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Dict[str, Any]]: + """ + Parse server response into a StepResult. + + Extracts the observation, reward, and done fields from the + server response. + + Args: + payload: Response payload from the server + + Returns: + StepResult with observation as a dictionary + """ + return StepResult( + observation=payload.get("observation", {}), + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """ + Parse state response from the server. + + For GenericEnvClient, this returns the payload as-is since + we're working with dictionaries. + + Args: + payload: State payload from the server + + Returns: + The state as a dictionary + """ + return payload + + +class GenericAction(Dict[str, Any]): + """ + A dictionary subclass for creating actions when using GenericEnvClient. + + This provides a semantic wrapper around dictionaries to make code more + readable when working with GenericEnvClient. It behaves exactly like a + dict but signals intent that this is an action for an environment. + + Example: + >>> # Without GenericAction (works fine) + >>> env.step({"code": "print('hello')"}) + + >>> # With GenericAction (more explicit) + >>> action = GenericAction(code="print('hello')") + >>> env.step(action) + + >>> # With multiple fields + >>> action = GenericAction(code="x = 1", timeout=30, metadata={"tag": "test"}) + >>> env.step(action) + + Note: + GenericAction is just a dict with a constructor that accepts keyword + arguments. It's provided for symmetry with typed Action classes and + to make code more readable. + """ + + def __init__(self, **kwargs: Any) -> None: + """ + Create a GenericAction from keyword arguments. 
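+        Each keyword becomes a key in the underlying dict.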
+ + Args: + **kwargs: Action fields as keyword arguments + + Example: + >>> action = GenericAction(code="print(1)", timeout=30) + >>> action["code"] + 'print(1)' + """ + super().__init__(kwargs) + + def __repr__(self) -> str: + """Return a readable representation.""" + items = ", ".join(f"{k}={v!r}" for k, v in self.items()) + return f"GenericAction({items})" diff --git a/src/core/openenv/core/llm_client.py b/src/core/openenv/core/llm_client.py new file mode 100644 index 0000000000000000000000000000000000000000..9df2ff27ae7c2054108ff159b9dec8e4c9dd238c --- /dev/null +++ b/src/core/openenv/core/llm_client.py @@ -0,0 +1,506 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""LLM client abstraction for calling LLM endpoints. + +Provides a generic RPC abstraction: point it at an endpoint/port, tell it the +protocol, and it works. OpenAI-compatible API is the first implementation, +covering OpenAI, vLLM, TGI, Ollama, HuggingFace Inference API, etc. +Anthropic's native API is supported via ``AnthropicClient``. + +Usage: + client = OpenAIClient("http://localhost", 8000, model="meta-llama/...") + response = await client.complete("What is 2+2?") + + # Or use the factory for hosted APIs: + client = create_llm_client("openai", model="gpt-4", api_key="sk-...") + response = await client.complete_with_tools(messages, tools) +""" + +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + +from openai import AsyncOpenAI + + +@dataclass +class ToolCall: + """A single tool/function call returned by the model.""" + + id: str + name: str + args: dict[str, Any] + + +@dataclass +class LLMResponse: + """Normalized response from an LLM, with optional tool calls.""" + + content: str + tool_calls: list[ToolCall] = field(default_factory=list) + + def to_message_dict(self) -> dict[str, Any]: + """Convert to an OpenAI-format assistant message dict.""" + msg: dict[str, Any] = {"role": "assistant", "content": self.content} + if self.tool_calls: + msg["tool_calls"] = [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": json.dumps(tc.args), + }, + } + for tc in self.tool_calls + ] + return msg + + +class LLMClient(ABC): + """Abstract base for LLM endpoint clients. + + Subclass and implement ``complete()`` for your protocol. + + Args: + endpoint: The base URL of the LLM service (e.g. "http://localhost"). + port: The port the service listens on. + """ + + def __init__(self, endpoint: str, port: int): + self.endpoint = endpoint + self.port = port + + @abstractmethod + async def complete(self, prompt: str, **kwargs) -> str: + """Send a prompt, return the text response. + + Args: + prompt: The user prompt to send. + **kwargs: Override default parameters (temperature, max_tokens, etc.). + + Returns: + The model's text response. + """ + ... + + async def complete_with_tools( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]], + **kwargs: Any, + ) -> LLMResponse: + """Send messages with tool definitions, return a normalized response. + + Messages use OpenAI-format dicts (``{"role": "...", "content": "..."}``). + Tools use MCP tool definitions; they are converted internally. + + Args: + messages: Conversation history as OpenAI-format message dicts. + tools: MCP tool definitions. 
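+                Each definition carries ``name``, ``description``, and
+                ``inputSchema`` keys, which are consumed by the converters.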
+ **kwargs: Override default parameters (temperature, max_tokens, etc.). + + Returns: + An ``LLMResponse`` with the model's text and any tool calls. + """ + raise NotImplementedError( + f"{type(self).__name__} does not support tool calling" + ) + + @property + def base_url(self) -> str: + """Construct base URL from endpoint and port.""" + return f"{self.endpoint}:{self.port}" + + +class OpenAIClient(LLMClient): + """Client for OpenAI-compatible APIs. + + Works with: OpenAI, vLLM, TGI, Ollama, HuggingFace Inference API, + or any endpoint that speaks the OpenAI chat completions format. + + Args: + endpoint: The base URL (e.g. "http://localhost"). + port: The port number. + model: Model name to pass to the API. + api_key: API key. Defaults to "not-needed" for local endpoints. + system_prompt: Optional system message prepended to every request. + temperature: Default sampling temperature. + max_tokens: Default max tokens in the response. + """ + + def __init__( + self, + endpoint: str, + port: int, + model: str, + api_key: str | None = None, + system_prompt: str | None = None, + temperature: float = 0.0, + max_tokens: int = 256, + ): + super().__init__(endpoint, port) + self.model = model + self.system_prompt = system_prompt + self.temperature = temperature + self.max_tokens = max_tokens + + self._client = AsyncOpenAI( + base_url=f"{self.base_url}/v1", + api_key=api_key if api_key is not None else "not-needed", + ) + + async def complete(self, prompt: str, **kwargs) -> str: + """Send a chat completion request. + + Args: + prompt: The user message. + **kwargs: Overrides for temperature, max_tokens. + + Returns: + The assistant's response text. + """ + messages = [] + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + messages.append({"role": "user", "content": prompt}) + + response = await self._client.chat.completions.create( + model=self.model, + messages=messages, + temperature=kwargs.get("temperature", self.temperature), + max_tokens=kwargs.get("max_tokens", self.max_tokens), + ) + return response.choices[0].message.content or "" + + async def complete_with_tools( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]], + **kwargs: Any, + ) -> LLMResponse: + create_kwargs: dict[str, Any] = { + "model": self.model, + "messages": messages, + "temperature": kwargs.get("temperature", self.temperature), + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + } + openai_tools = _mcp_tools_to_openai(tools) + if openai_tools: + create_kwargs["tools"] = openai_tools + + response = await self._client.chat.completions.create(**create_kwargs) + msg = response.choices[0].message + + tool_calls = [] + if msg.tool_calls: + for tc in msg.tool_calls: + tool_calls.append( + ToolCall( + id=tc.id, + name=tc.function.name, + args=json.loads(tc.function.arguments), + ) + ) + + return LLMResponse(content=msg.content or "", tool_calls=tool_calls) + + +class AnthropicClient(LLMClient): + """Client for Anthropic's Messages API. + + Requires the ``anthropic`` package (lazy-imported at construction time). + + Args: + endpoint: The base URL (e.g. "https://api.anthropic.com"). + port: The port number. + model: Model name (e.g. "claude-sonnet-4-20250514"). + api_key: Anthropic API key. + system_prompt: Optional system message prepended to every request. + temperature: Default sampling temperature. + max_tokens: Default max tokens in the response. 
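+
+    Example (illustrative; the API key is a placeholder)::
+
+        client = AnthropicClient(
+            "https://api.anthropic.com",
+            443,
+            model="claude-sonnet-4-20250514",
+            api_key="sk-ant-...",
+        )
+        text = await client.complete("What is 2+2?")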
+ """ + + def __init__( + self, + endpoint: str, + port: int, + model: str, + api_key: str | None = None, + system_prompt: str | None = None, + temperature: float = 0.0, + max_tokens: int = 256, + ): + super().__init__(endpoint, port) + self.model = model + self.system_prompt = system_prompt + self.temperature = temperature + self.max_tokens = max_tokens + + try: + from anthropic import AsyncAnthropic + except ImportError as exc: + raise ImportError( + "AnthropicClient requires the 'anthropic' package. " + "Install it with: pip install anthropic" + ) from exc + + self._client = AsyncAnthropic( + base_url=self.base_url, + api_key=api_key if api_key is not None else "not-needed", + ) + + async def complete(self, prompt: str, **kwargs) -> str: + create_kwargs: dict[str, Any] = { + "model": self.model, + "messages": [{"role": "user", "content": prompt}], + "temperature": kwargs.get("temperature", self.temperature), + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + } + if self.system_prompt: + create_kwargs["system"] = self.system_prompt + + response = await self._client.messages.create(**create_kwargs) + return "".join(block.text for block in response.content if block.type == "text") + + async def complete_with_tools( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]], + **kwargs: Any, + ) -> LLMResponse: + system, anthropic_msgs = _openai_msgs_to_anthropic(messages) + + create_kwargs: dict[str, Any] = { + "model": self.model, + "messages": anthropic_msgs, + "temperature": kwargs.get("temperature", self.temperature), + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + } + system_text = system or self.system_prompt + if system_text: + create_kwargs["system"] = system_text + anthropic_tools = _mcp_tools_to_anthropic(tools) + if anthropic_tools: + create_kwargs["tools"] = anthropic_tools + + response = await self._client.messages.create(**create_kwargs) + + content = "" + tool_calls = [] + for block in response.content: + if block.type == "text": + content += block.text + elif block.type == "tool_use": + tool_calls.append( + ToolCall(id=block.id, name=block.name, args=block.input) + ) + + return LLMResponse(content=content, tool_calls=tool_calls) + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + +_HOSTED_PROVIDERS: dict[str, tuple[str, int, type[LLMClient]]] = { + "openai": ("https://api.openai.com", 443, OpenAIClient), + "anthropic": ("https://api.anthropic.com", 443, AnthropicClient), +} + + +def create_llm_client( + provider: str, + model: str, + api_key: str, + *, + system_prompt: str | None = None, + temperature: float = 0.0, + max_tokens: int = 4096, +) -> LLMClient: + """Create an LLM client for a hosted provider. + + Args: + provider: Provider name ("openai" or "anthropic"). + model: Model identifier. + api_key: API key for the provider. + system_prompt: Optional system message prepended to every request. + temperature: Sampling temperature. + max_tokens: Maximum tokens in the response. + + Returns: + A configured ``LLMClient`` instance. + """ + key = provider.lower() + if key not in _HOSTED_PROVIDERS: + raise ValueError( + f"Unsupported provider: {provider!r}. 
" + f"Supported: {sorted(_HOSTED_PROVIDERS)}" + ) + endpoint, port, cls = _HOSTED_PROVIDERS[key] + return cls( + endpoint, + port, + model, + api_key=api_key, + system_prompt=system_prompt, + temperature=temperature, + max_tokens=max_tokens, + ) + + +# --------------------------------------------------------------------------- +# MCP tool-schema helpers +# --------------------------------------------------------------------------- + + +def _clean_mcp_schema(schema: dict[str, Any]) -> dict[str, Any]: + """Normalize an MCP tool ``inputSchema`` for LLM function-calling APIs.""" + if not isinstance(schema, dict): + return {"type": "object", "properties": {}, "required": []} + + # Shallow copy to avoid mutating the caller's schema dict. + schema = dict(schema) + + if "oneOf" in schema: + for option in schema["oneOf"]: + if isinstance(option, dict) and option.get("type") == "object": + schema = option + break + else: + return {"type": "object", "properties": {}, "required": []} + + if "allOf" in schema: + merged: dict[str, Any] = {"type": "object", "properties": {}, "required": []} + for sub in schema["allOf"]: + if isinstance(sub, dict): + if "properties" in sub: + merged["properties"].update(sub["properties"]) + if "required" in sub: + merged["required"].extend(sub["required"]) + schema = merged + + if "anyOf" in schema: + for option in schema["anyOf"]: + if isinstance(option, dict) and option.get("type") == "object": + schema = option + break + else: + return {"type": "object", "properties": {}, "required": []} + + schema.setdefault("type", "object") + if schema.get("type") == "object" and "properties" not in schema: + schema["properties"] = {} + return schema + + +def _mcp_tools_to_openai( + mcp_tools: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Convert MCP tool definitions to OpenAI function-calling format.""" + result = [] + for tool in mcp_tools: + input_schema = tool.get( + "inputSchema", {"type": "object", "properties": {}, "required": []} + ) + result.append( + { + "type": "function", + "function": { + "name": tool["name"], + "description": tool.get("description", ""), + "parameters": _clean_mcp_schema(input_schema), + }, + } + ) + return result + + +def _mcp_tools_to_anthropic( + mcp_tools: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Convert MCP tool definitions to Anthropic tool format.""" + result = [] + for tool in mcp_tools: + input_schema = tool.get( + "inputSchema", {"type": "object", "properties": {}, "required": []} + ) + result.append( + { + "name": tool["name"], + "description": tool.get("description", ""), + "input_schema": _clean_mcp_schema(input_schema), + } + ) + return result + + +def _openai_msgs_to_anthropic( + messages: list[dict[str, Any]], +) -> tuple[str, list[dict[str, Any]]]: + """Convert OpenAI-format messages to Anthropic format. + + Returns ``(system_text, anthropic_messages)``. System-role messages are + extracted and concatenated; tool-result messages are converted to + Anthropic's ``tool_result`` content blocks inside user turns. 
+ """ + system_parts: list[str] = [] + anthropic_msgs: list[dict[str, Any]] = [] + + for msg in messages: + role = msg["role"] + + if role == "system": + system_parts.append(msg["content"]) + + elif role == "user": + anthropic_msgs.append({"role": "user", "content": msg["content"]}) + + elif role == "assistant": + if msg.get("tool_calls"): + content: list[dict[str, Any]] = [] + if msg.get("content"): + content.append({"type": "text", "text": msg["content"]}) + for tc in msg["tool_calls"]: + args = tc["function"]["arguments"] + if isinstance(args, str): + args = json.loads(args) + content.append( + { + "type": "tool_use", + "id": tc["id"], + "name": tc["function"]["name"], + "input": args, + } + ) + anthropic_msgs.append({"role": "assistant", "content": content}) + else: + anthropic_msgs.append( + {"role": "assistant", "content": msg.get("content", "")} + ) + + elif role == "tool": + tool_result = { + "type": "tool_result", + "tool_use_id": msg["tool_call_id"], + "content": msg["content"], + } + # Anthropic requires tool results in user turns; merge if possible. + if ( + anthropic_msgs + and anthropic_msgs[-1]["role"] == "user" + and isinstance(anthropic_msgs[-1]["content"], list) + ): + anthropic_msgs[-1]["content"].append(tool_result) + else: + anthropic_msgs.append({"role": "user", "content": [tool_result]}) + + system = "\n\n".join(system_parts) + return system, anthropic_msgs diff --git a/src/core/openenv/core/mcp_client.py b/src/core/openenv/core/mcp_client.py new file mode 100644 index 0000000000000000000000000000000000000000..edac3529d3a34e798781d86cf4d2495dc9611713 --- /dev/null +++ b/src/core/openenv/core/mcp_client.py @@ -0,0 +1,377 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +MCP Client classes for tool-calling environments. + +This module provides async client classes for interacting with MCP-enabled environments: +- MCPClientBase: Base class with shared tool discovery +- MCPToolClient: Client for tool-calling style (one tool per step) + +These clients abstract away the MCP protocol details, providing a clean interface +for listing and calling tools on remote environments. All clients are async by default. + +Architecture Overview:: + + ┌─────────────────────────────────────────────────────────┐ + │ HTTPEnvServer │ + ├─────────────────────────────────────────────────────────┤ + │ Simulation Mode (default): │ + │ /ws → OpenEnv protocol (reset/step/state) │ + │ /mcp → MCP JSON-RPC (tools/list, tools/call) │ + │ /reset, /step, /state → HTTP endpoints │ + ├─────────────────────────────────────────────────────────┤ + │ Production Mode (use_production_mode=True): │ + │ /mcp → MCP JSON-RPC (tools/list, tools/call) │ + │ Bypasses step() for direct tool access │ + └─────────────────────────────────────────────────────────┘ + + Client Usage: + MCPToolClient (default) → /ws (step-based, with rewards) + MCPToolClient (production) → /mcp (direct tool access, no rewards) + +Example (async): + >>> from openenv.core.mcp_client import MCPToolClient + >>> + >>> async with MCPToolClient(base_url="http://localhost:8000") as env: + ... # Discover available tools + ... tools = await env.list_tools() + ... print([t.name for t in tools]) + ... + ... # Call a tool + ... result = await env.call_tool("echo_message", message="Hello!") + ... 
print(result)
+
+Example (sync wrapper):
+    >>> env = MCPToolClient(base_url="http://localhost:8000").sync()
+    >>> with env:
+    ...     tools = env.list_tools()
+    ...     result = env.call_tool("echo_message", message="Hello!")
+"""
+
+from typing import Any, Dict, List, Optional
+
+from .client_types import StepResult
+from .env_client import EnvClient
+from .env_server.mcp_types import (
+    CallToolAction,
+    CallToolObservation,
+    ListToolsAction,
+    ListToolsObservation,
+    Tool,
+    ToolError,
+)
+from .env_server.types import Observation, State
+
+
+class MCPClientBase(EnvClient[Any, Observation, State]):
+    """
+    Base class for MCP clients with tool discovery.
+
+    This class provides the common `list_tools()` method for discovering
+    available tools from an MCP-enabled environment. Subclasses implement
+    specific interaction patterns (tool-calling or CodeAct).
+
+    Attributes:
+        _tools_cache: Cached list of tools (populated on first `list_tools()` call)
+    """
+
+    def __init__(
+        self,
+        base_url: str,
+        connect_timeout_s: float = 10.0,
+        message_timeout_s: float = 60.0,
+        provider: Optional[Any] = None,
+        mode: Optional[str] = None,
+    ):
+        """
+        Initialize MCP client.
+
+        Args:
+            base_url: Base URL of the environment server (http:// or ws://).
+            connect_timeout_s: Timeout for establishing WebSocket connection.
+            message_timeout_s: Timeout for receiving responses to messages.
+            provider: Optional container/runtime provider for lifecycle management.
+            mode: Communication mode. Must be 'production' for MCP clients. Defaults to 'production'.
+        """
+        # MCPClientBase defaults to production mode, but allow override for validation
+        if mode is None:
+            mode = "production"
+
+        # Validate that mode is production
+        mode_lower = mode.lower()
+        if mode_lower != "production":
+            raise ValueError(
+                f"MCPToolClient only supports 'production' mode, got '{mode}'. "
+                f"Use GenericEnvClient for simulation mode."
+            )
+
+        super().__init__(
+            base_url=base_url,
+            connect_timeout_s=connect_timeout_s,
+            message_timeout_s=message_timeout_s,
+            provider=provider,
+            mode=mode,
+        )
+        self._tools_cache: Optional[List[Tool]] = None
+        # Mode is validated as 'production' above, so route tool discovery
+        # through the HTTP /mcp endpoint rather than the step-based protocol.
+        self.use_production_mode = mode_lower == "production"
+
+    async def list_tools(self, use_cache: bool = True) -> List[Tool]:
+        """
+        Discover available tools from the environment.
+
+        Args:
+            use_cache: If True, return cached tools if available.
+                Set to False to force a fresh request.
+
+        Returns:
+            List of Tool objects with name, description, and input_schema.
+
+        Example:
+            >>> tools = await env.list_tools()
+            >>> for tool in tools:
+            ...     print(f"{tool.name}: {tool.description}")
+        """
+        if use_cache and self._tools_cache is not None:
+            return self._tools_cache
+
+        # Use production mode HTTP endpoint if enabled
+        if self.use_production_mode:
+            import requests
+
+            # Convert ws:// URL to http:// URL
+            url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://")
+            # Remove the /ws suffix if present and add /mcp. Note: str.rstrip
+            # strips characters, not suffixes, so test the suffix explicitly.
+            if url.endswith("/ws"):
+                url = url[: -len("/ws")]
+            url = url.rstrip("/") + "/mcp"
+
+            try:
+                response = requests.post(
+                    url,
+                    json={
+                        "jsonrpc": "2.0",
+                        "method": "tools/list",
+                        "params": {},
+                        "id": 1,
+                    },
+                    timeout=30,  # bound the request so a dead server cannot hang discovery
+                )
+                data = response.json()
+                if "result" in data and "tools" in data["result"]:
+                    tools = [
+                        Tool(
+                            name=t.get("name", ""),
+                            description=t.get("description", ""),
+                            input_schema=t.get(
+                                "input_schema", t.get("inputSchema", {})
+                            ),
+                        )
+                        for t in data["result"]["tools"]
+                    ]
+                    self._tools_cache = tools
+                    return tools
+            except Exception:
+                # If HTTP request fails, return empty list
+                pass
+            return []
+
+        result = await self.step(ListToolsAction())
+        self._tools_cache = result.observation.tools
+        return self._tools_cache
+
+    def _step_payload(self, action: Any) -> Dict[str, Any]:
+        """Convert an Action object to the JSON data expected by the env server."""
+        if isinstance(action, ListToolsAction):
+            return {"type": "list_tools"}
+        elif isinstance(action, CallToolAction):
+            return {
+                "type": "call_tool",
+                "tool_name": action.tool_name,
+                "arguments": action.arguments,
+            }
+        else:
+            # For unknown actions, try to serialize as dict
+            if hasattr(action, "model_dump"):
+                return action.model_dump()
+            return {"action": str(action)}
+
+    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]:
+        """Convert a JSON response from the env server to StepResult[Observation]."""
+        obs_data = payload.get("observation", {})
+
+        # Check if this is a ListToolsObservation
+        if "tools" in obs_data:
+            tools = [
+                Tool(
+                    name=t.get("name", ""),
+                    description=t.get("description", ""),
+                    input_schema=t.get("input_schema", t.get("inputSchema", {})),
+                )
+                for t in obs_data.get("tools", [])
+            ]
+            observation = ListToolsObservation(
+                tools=tools,
+                done=payload.get("done", False),
+                reward=payload.get("reward"),
+                metadata=obs_data.get("metadata", {}),
+            )
+        # Check if this is a CallToolObservation
+        elif "tool_name" in obs_data:
+            error = None
+            if obs_data.get("error"):
+                error = ToolError(**obs_data["error"])
+
+            observation = CallToolObservation(
+                tool_name=obs_data.get("tool_name", ""),
+                result=obs_data.get("result"),
+                error=error,
+                done=payload.get("done", False),
+                reward=payload.get("reward"),
+                metadata=obs_data.get("metadata", {}),
+            )
+        else:
+            # Generic observation
+            observation = Observation(
+                done=payload.get("done", False),
+                reward=payload.get("reward"),
+                metadata=obs_data.get("metadata", {}),
+            )
+
+        return StepResult(
+            observation=observation,
+            reward=payload.get("reward"),
+            done=payload.get("done", False),
+        )
+
+    def _parse_state(self, payload: Dict[str, Any]) -> State:
+        """Convert a JSON response from the state endpoint to a State object."""
+        return State(
+            episode_id=payload.get("episode_id"),
+            step_count=payload.get("step_count", 0),
+        )
+
+
+class MCPToolClient(MCPClientBase):
+    """
+    Async client for tool-calling style MCP interactions.
+
+    Each step invokes a single tool. Use this for traditional function-calling
+    agent patterns where the agent decides which tool to call next.
+ + This client provides convenience methods for tool discovery and invocation: + - `list_tools()`: Get all available tools with their schemas + - `call_tool(name, **kwargs)`: Invoke a tool by name with arguments + + Example (async): + >>> async with MCPToolClient(base_url="http://localhost:8000") as env: + ... # Reset the environment + ... await env.reset() + ... + ... # Discover available tools + ... tools = await env.list_tools() + ... print([t.name for t in tools]) # ['echo_message', 'echo_with_length'] + ... + ... # Call a tool directly + ... result = await env.call_tool("echo_message", message="Hello!") + ... print(result) # "Hello!" + ... + ... # Or use the full action interface + ... from openenv.core.env_server.mcp_types import CallToolAction + ... step_result = await env.step(CallToolAction( + ... tool_name="echo_with_length", + ... arguments={"message": "Test"} + ... )) + ... print(step_result.observation.result) + + Example (sync wrapper): + >>> env = MCPToolClient(base_url="http://localhost:8000").sync() + >>> with env: + ... tools = env.list_tools() + ... result = env.call_tool("echo_message", message="Hello!") + """ + + async def call_tool(self, name: str, **kwargs: Any) -> Any: + """ + Call a tool by name. + + This is a convenience method that creates a CallToolAction, executes it, + and returns the result directly. For more control, use `step()` with + a CallToolAction directly. + + Args: + name: Name of the tool to invoke (must match a tool from `list_tools()`). + **kwargs: Arguments to pass to the tool. Must match the tool's input_schema. + + Returns: + The tool's result. The type depends on the tool being called. + + Raises: + RuntimeError: If the server returns an error response. + + Example: + >>> result = await env.call_tool("add", a=5, b=3) + >>> print(result) # 8 + >>> + >>> result = await env.call_tool("greet", name="Claude") + >>> print(result) # "Hello, Claude!" + """ + action = CallToolAction(tool_name=name, arguments=kwargs) + result = await self.step(action) + obs = result.observation + + # Check for transport/framework errors + if isinstance(obs, CallToolObservation) and obs.error is not None: + raise RuntimeError( + f"Tool '{name}' failed: {obs.error.message} " + f"(type: {obs.error.error_type.value})" + ) + + # Return the result + if isinstance(obs, CallToolObservation): + result = obs.result + # Handle FastMCP CallToolResult objects + # - As object: has .data attribute + # - As dict (from JSON): has "data" key + if hasattr(result, "data"): + return result.data + if isinstance(result, dict) and "data" in result: + return result["data"] + return result + + # Fallback for unexpected observation types + return obs + + async def get_tool(self, name: str) -> Optional[Tool]: + """ + Get a specific tool by name. + + Args: + name: Name of the tool to find. + + Returns: + The Tool object if found, None otherwise. + + Example: + >>> tool = await env.get_tool("echo_message") + >>> if tool: + ... print(tool.description) + ... print(tool.input_schema) + """ + tools = await self.list_tools() + for tool in tools: + if tool.name == name: + return tool + return None + + async def has_tool(self, name: str) -> bool: + """ + Check if a tool exists. + + Args: + name: Name of the tool to check. + + Returns: + True if the tool exists, False otherwise. 
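+
+        Example:
+            >>> if await env.has_tool("echo_message"):
+            ...     result = await env.call_tool("echo_message", message="Hello!")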
+ """ + return await self.get_tool(name) is not None diff --git a/src/core/openenv/core/rubrics/__init__.py b/src/core/openenv/core/rubrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abe368494b70cfbabb04e86cb2277aa8c838bdf7 --- /dev/null +++ b/src/core/openenv/core/rubrics/__init__.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Rubrics for reward computation. + +See RFC 004 for full design: rfcs/004-rubrics.md +""" + +from openenv.core.rubrics.base import Rubric +from openenv.core.rubrics.containers import ( + Gate, + RubricDict, + RubricList, + Sequential, + WeightedSum, +) +from openenv.core.rubrics.llm_judge import LLMJudge +from openenv.core.rubrics.trajectory import ( + ExponentialDiscountingTrajectoryRubric, + TrajectoryRubric, +) + +__all__ = [ + # Base + "Rubric", + # Containers + "Sequential", + "Gate", + "WeightedSum", + "RubricList", + "RubricDict", + # Trajectory + "TrajectoryRubric", + "ExponentialDiscountingTrajectoryRubric", + # LLM Judge + "LLMJudge", +] diff --git a/src/core/openenv/core/rubrics/base.py b/src/core/openenv/core/rubrics/base.py new file mode 100644 index 0000000000000000000000000000000000000000..38c7a381bc4f40a7bc1dac832902e9e6ac93a282 --- /dev/null +++ b/src/core/openenv/core/rubrics/base.py @@ -0,0 +1,195 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Base Rubric class for reward computation. + +Rubrics compute rewards from actions and observations. The API is modeled +after PyTorch's nn.Module: users implement forward(), and the framework +handles child registration and hooks. + +See RFC 004 for full design: rfcs/004-rubrics.md +""" + +import inspect +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple + + +class Rubric(ABC): + """Abstract base class for reward computation. + + A Rubric computes a reward signal from an action and observation. + Subclasses implement forward() to define the reward logic. + + Usage: + class MyRubric(Rubric): + def forward(self, action, observation) -> float: + return 1.0 if action.valid else 0.0 + + rubric = MyRubric() + reward = rubric(action, observation) + + Child rubrics are auto-registered when assigned as attributes, + enabling hierarchical composition and introspection. + """ + + _rubric_children: Dict[str, "Rubric"] + _forward_hooks: List[Callable] + _forward_pre_hooks: List[Callable] + last_score: Optional[float] + + def __init__(self): + # Use object.__setattr__ to avoid triggering __setattr__ during init + object.__setattr__(self, "_rubric_children", {}) + object.__setattr__(self, "_forward_hooks", []) + object.__setattr__(self, "_forward_pre_hooks", []) + object.__setattr__(self, "last_score", None) + + def __setattr__(self, name: str, value: Any) -> None: + # Auto-register child rubrics when assigned as attributes + if isinstance(value, Rubric): + self._rubric_children[name] = value + object.__setattr__(self, name, value) + + def __call__(self, action: Any, observation: Any): + """Evaluate the rubric with hooks. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value (typically 0.0 to 1.0). 
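+
+        Note:
+            When ``forward`` is a coroutine function, this returns an
+            awaitable that must be awaited to obtain the score.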
+ """ + # Check if forward method is async BEFORE calling it + if inspect.iscoroutinefunction(self.forward): + # Async path - pre-hooks will be called in _call_async + result = self.forward(action, observation) + return self._call_async(action, observation, result) + else: + # Sync path - call pre-hooks BEFORE forward() + for hook in self._forward_pre_hooks: + hook(self, action, observation) + result = self.forward(action, observation) + return self._call_sync(action, observation, result) + + def _call_sync(self, action: Any, observation: Any, result: float) -> float: + """Synchronous call path.""" + self.last_score = result + + # Post-forward hooks + for hook in self._forward_hooks: + hook(self, action, observation, result) + + return result + + async def _call_async(self, action: Any, observation: Any, result_coro) -> float: + """Asynchronous call path.""" + # Pre-forward hooks + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + # Await the forward result + result = await result_coro + self.last_score = result + + # Post-forward hooks + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + + return result + + @abstractmethod + def forward(self, action: Any, observation: Any) -> float: + """Compute the reward. Implement this in subclasses. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value (typically 0.0 to 1.0). + """ + raise NotImplementedError + + def register_forward_hook( + self, hook: Callable[["Rubric", Any, Any, float], None] + ) -> None: + """Register a hook called after forward(). + + Args: + hook: Callable with signature (rubric, action, observation, result). + """ + self._forward_hooks.append(hook) + + def register_forward_pre_hook( + self, hook: Callable[["Rubric", Any, Any], None] + ) -> None: + """Register a hook called before forward(). + + Args: + hook: Callable with signature (rubric, action, observation). + """ + self._forward_pre_hooks.append(hook) + + def children(self) -> Iterator["Rubric"]: + """Iterate over immediate child rubrics.""" + yield from self._rubric_children.values() + + def named_children(self) -> Iterator[Tuple[str, "Rubric"]]: + """Iterate over immediate child rubrics with names.""" + yield from self._rubric_children.items() + + def rubrics(self) -> Iterator["Rubric"]: + """Iterate over all descendant rubrics (depth-first).""" + for child in self._rubric_children.values(): + yield child + yield from child.rubrics() + + def named_rubrics(self, prefix: str = "") -> Iterator[Tuple[str, "Rubric"]]: + """Iterate over all descendant rubrics with dot-separated names.""" + for name, child in self._rubric_children.items(): + full_name = f"{prefix}.{name}" if prefix else name + yield full_name, child + yield from child.named_rubrics(full_name) + + def get_rubric(self, path: str) -> "Rubric": + """Access a nested rubric by dot-separated path. + + Args: + path: Dot-separated path (e.g., "code.syntax"). + + Returns: + The rubric at the specified path. + + Raises: + KeyError: If the path does not exist. 
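+
+        Example:
+            >>> rubric.get_rubric("code.syntax")  # the "syntax" child of "code"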
+ """ + parts = path.split(".") + current = self + for part in parts: + if part not in current._rubric_children: + raise KeyError(f"Rubric path not found: {path}") + current = current._rubric_children[part] + return current + + def reset(self) -> None: + """Reset any internal state. Override in subclasses if needed.""" + pass + + def state_dict(self) -> Dict[str, Any]: + """Serialize rubric configuration for checkpointing.""" + return {} + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load rubric configuration from checkpoint.""" + pass diff --git a/src/core/openenv/core/rubrics/containers.py b/src/core/openenv/core/rubrics/containers.py new file mode 100644 index 0000000000000000000000000000000000000000..7a587ee7885efdf71b03d644b54524f1855474d9 --- /dev/null +++ b/src/core/openenv/core/rubrics/containers.py @@ -0,0 +1,574 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Container rubrics for composing reward computations. + +These containers provide common aggregation patterns for rubrics, +similar to how PyTorch provides nn.Sequential alongside nn.Module. + +See RFC 004 for full design: rfcs/004-rubrics.md +""" + +import asyncio +import inspect +from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union + +from openenv.core.rubrics.base import Rubric + + +def _in_async_context() -> bool: + """Check if we're currently in an async context.""" + try: + asyncio.get_running_loop() + return True + except RuntimeError: + return False + + +class Sequential(Rubric): + """Run rubrics in order, fail-fast on zero. + + Runs child rubrics in order. If any returns 0, stops immediately + and returns 0. This implements hierarchical gating patterns where + syntax checks run before execution checks. + + Usage: + rubric = Sequential( + Gate(Compiles()), + Gate(PassesTests(), threshold=0.5), + WeightedSum([PassesTests(), StyleRubric()], weights=[0.7, 0.3]) + ) + """ + + def __init__(self, *rubrics: Rubric): + """Initialize with rubrics to run in sequence. + + Args: + *rubrics: Rubrics to run in order. Stops and returns 0 if any + child returns 0. + """ + super().__init__() + for i, rubric in enumerate(rubrics): + setattr(self, f"rubric_{i}", rubric) + self._rubric_list = list(rubrics) + + def forward(self, action: Any, observation: Any) -> float: + """Run rubrics in order, return 0 if any returns 0. 
Sync version.""" + result = 1.0 + for rubric in self._rubric_list: + score = rubric(action, observation) + if score == 0.0: + return 0.0 + result = score + return result + + def __call__(self, action: Any, observation: Any): + """Override to choose sync or async path based on children.""" + # Empty case - check if in async context + if not self._rubric_list: + if _in_async_context(): + return self._empty_async(action, observation) + else: + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + result = 1.0 + self.last_score = result + for hook in self._forward_hooks: + hook(self, action, observation, result) + return result + + # Call first rubric to see if it's async + first_result = self._rubric_list[0](action, observation) + if inspect.iscoroutine(first_result): + # At least one child is async, use async path + return self._call_async_detected(action, observation, first_result) + else: + # Continue with sync path + if first_result == 0.0: + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + self.last_score = 0.0 + for hook in self._forward_hooks: + hook(self, action, observation, 0.0) + return 0.0 + + final_result = first_result + for i, rubric in enumerate(self._rubric_list[1:], start=1): + score = rubric(action, observation) + if inspect.iscoroutine(score): + # Found async mid-way, switch to async + # We already called rubric at index i, so pass the coroutine and remaining rubrics + return self._call_async_mid( + action, + observation, + final_result, + score, + self._rubric_list[i + 1 :], + ) + if score == 0.0: + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + self.last_score = 0.0 + for hook in self._forward_hooks: + hook(self, action, observation, 0.0) + return 0.0 + final_result = score + + # All sync - check if in async context + if _in_async_context(): + return self._wrap_sync_result(action, observation, final_result) + else: + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + self.last_score = final_result + for hook in self._forward_hooks: + hook(self, action, observation, final_result) + return final_result + + async def _empty_async(self, action, observation): + """Async path for empty sequential.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + result = 1.0 + self.last_score = result + + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + async def _wrap_sync_result(self, action, observation, result): + """Wrap sync result for async context.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + self.last_score = result + + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + async def _call_async_detected(self, action, observation, first_coro): + """Async path when first child is async.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + result = await first_coro + if result == 0.0: + self.last_score = 0.0 + for hook 
in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return 0.0 + + for rubric in self._rubric_list[1:]: + score = rubric(action, observation) + if inspect.iscoroutine(score): + score = await score + if score == 0.0: + self.last_score = 0.0 + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, 0.0) + else: + hook(self, action, observation, 0.0) + return 0.0 + result = score + + self.last_score = result + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + async def _call_async_mid( + self, action, observation, current_result, first_async_coro, remaining + ): + """Async path when async detected mid-execution.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + # Await the first async rubric (already called) + result = await first_async_coro + if result == 0.0: + self.last_score = 0.0 + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, 0.0) + else: + hook(self, action, observation, 0.0) + return 0.0 + + # Continue with remaining rubrics + for rubric in remaining: + score = rubric(action, observation) + if inspect.iscoroutine(score): + score = await score + if score == 0.0: + self.last_score = 0.0 + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, 0.0) + else: + hook(self, action, observation, 0.0) + return 0.0 + result = score + + self.last_score = result + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + def __len__(self) -> int: + return len(self._rubric_list) + + def __getitem__(self, index: int) -> Rubric: + return self._rubric_list[index] + + +class Gate(Rubric): + """Threshold wrapper - returns 0 if child score is below threshold. + + Useful for hard constraints like "must pass 50% of tests". + + Usage: + rubric = Gate(PassesTests(), threshold=0.5) + # Returns PassesTests() score if >= 0.5, else 0.0 + """ + + def __init__(self, rubric: Rubric, threshold: float = 1.0): + """Initialize with a rubric and threshold. + + Args: + rubric: The rubric to gate. + threshold: Minimum score required. If child returns less than + this, Gate returns 0. Default is 1.0 (must pass completely). + """ + super().__init__() + self.rubric = rubric + self.threshold = threshold + + def forward(self, action: Any, observation: Any) -> float: + """Return child score if >= threshold, else 0. 
Sync version.""" + score = self.rubric(action, observation) + if score < self.threshold: + return 0.0 + return score + + def __call__(self, action: Any, observation: Any): + """Override to handle async child.""" + # Call child + score = self.rubric(action, observation) + + if inspect.iscoroutine(score): + # Child is async + return self._call_async(action, observation, score) + else: + # Child is sync + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + result = 0.0 if score < self.threshold else score + self.last_score = result + for hook in self._forward_hooks: + hook(self, action, observation, result) + return result + + async def _call_async(self, action, observation, score_coro): + """Async path.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + score = await score_coro + result = 0.0 if score < self.threshold else score + self.last_score = result + + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + +class WeightedSum(Rubric): + """Weighted combination of child rubrics. + + Standard aggregation pattern for multi-criteria evaluation. + + Usage: + rubric = WeightedSum( + [PassesTests(), StyleRubric()], + weights=[0.7, 0.3] + ) + """ + + def __init__(self, rubrics: List[Rubric], weights: List[float]): + """Initialize with rubrics and weights. + + Args: + rubrics: List of rubrics to combine. + weights: Weight for each rubric. Must sum to 1.0. + + Raises: + ValueError: If lengths don't match or weights don't sum to 1.0. + """ + super().__init__() + if len(rubrics) != len(weights): + raise ValueError( + f"Number of rubrics ({len(rubrics)}) must match " + f"number of weights ({len(weights)})" + ) + if abs(sum(weights) - 1.0) > 1e-6: + raise ValueError(f"Weights must sum to 1.0, got {sum(weights)}") + + for i, rubric in enumerate(rubrics): + setattr(self, f"rubric_{i}", rubric) + self._rubric_list = list(rubrics) + self._weights = list(weights) + + def forward(self, action: Any, observation: Any) -> float: + """Return weighted sum of child scores. 
Sync version.""" + total = 0.0 + for rubric, weight in zip(self._rubric_list, self._weights): + score = rubric(action, observation) + total += score * weight + return total + + def __call__(self, action: Any, observation: Any): + """Override to handle async children with parallel execution.""" + # Call all rubrics + results = [rubric(action, observation) for rubric in self._rubric_list] + + # Check if any are async + has_async = any(inspect.iscoroutine(r) for r in results) + + if has_async: + # Use async path + return self._call_async(action, observation, results) + else: + # Sync path + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + total = 0.0 + for score, weight in zip(results, self._weights): + total += score * weight + self.last_score = total + for hook in self._forward_hooks: + hook(self, action, observation, total) + return total + + async def _call_async(self, action, observation, results): + """Async path with parallel execution.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + # Separate sync and async results + async_tasks = [] + async_indices = [] + scores = [None] * len(results) + + for i, result in enumerate(results): + if inspect.iscoroutine(result): + async_tasks.append(result) + async_indices.append(i) + else: + scores[i] = result + + # Await all async tasks in parallel + if async_tasks: + async_scores = await asyncio.gather(*async_tasks) + for i, score in zip(async_indices, async_scores): + scores[i] = score + + # Compute weighted sum + total = 0.0 + for score, weight in zip(scores, self._weights): + total += score * weight + + self.last_score = total + + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, total) + else: + hook(self, action, observation, total) + return total + + @property + def weights(self) -> List[float]: + """Get the weights (read-only copy).""" + return list(self._weights) + + +class RubricList(Rubric): + """Container for dynamic lists of rubrics. + + Analogous to nn.ModuleList. Does not define aggregation - use within + a parent rubric that implements custom logic. + + Usage: + class MultiGameRubric(Rubric): + def __init__(self, games: List[str]): + super().__init__() + self.games = RubricList([GameRubric(g) for g in games]) + + def forward(self, action, obs) -> float: + return self.games[obs.game_index](action, obs) + """ + + def __init__(self, rubrics: List[Rubric] = None): + """Initialize with optional list of rubrics. + + Args: + rubrics: Optional list of rubrics to start with. + """ + super().__init__() + self._rubrics: List[Rubric] = [] + if rubrics is not None: + for i, rubric in enumerate(rubrics): + self.append(rubric) + + def forward(self, action: Any, observation: Any) -> float: + """RubricList does not define aggregation - override in parent.""" + raise NotImplementedError( + "RubricList.forward() is not implemented. " + "Use RubricList within a parent rubric that defines aggregation." 
+ ) + + def append(self, rubric: Rubric) -> None: + """Add a rubric to the list.""" + index = len(self._rubrics) + setattr(self, f"rubric_{index}", rubric) + self._rubrics.append(rubric) + + def extend(self, rubrics: List[Rubric]) -> None: + """Add multiple rubrics to the list.""" + for rubric in rubrics: + self.append(rubric) + + def __len__(self) -> int: + return len(self._rubrics) + + def __getitem__(self, index: int) -> Rubric: + return self._rubrics[index] + + def __iter__(self) -> Iterator[Rubric]: + return iter(self._rubrics) + + +class RubricDict(Rubric): + """Container for named rubrics with keyed access. + + Analogous to nn.ModuleDict. Enables keyed access for multi-task + environments where different tasks require different rubrics. + + Usage: + class AtariRubric(Rubric): + def __init__(self): + super().__init__() + self.games = RubricDict({ + "pong": PongRubric(), + "breakout": BreakoutRubric(), + "space_invaders": SpaceInvadersRubric(), + }) + + def forward(self, action, obs) -> float: + return self.games[obs.game_id](action, obs) + + # Access: env.rubric.games["pong"] + """ + + def __init__(self, rubrics: Dict[str, Rubric] = None): + """Initialize with optional dictionary of rubrics. + + Args: + rubrics: Optional dictionary mapping names to rubrics. + """ + super().__init__() + self._rubric_dict: Dict[str, Rubric] = {} + if rubrics is not None: + for name, rubric in rubrics.items(): + self[name] = rubric + + def forward(self, action: Any, observation: Any) -> float: + """RubricDict does not define aggregation - override in parent.""" + raise NotImplementedError( + "RubricDict.forward() is not implemented. " + "Use RubricDict within a parent rubric that defines aggregation." + ) + + def __setitem__(self, key: str, rubric: Rubric) -> None: + """Add a rubric with the given key.""" + setattr(self, key, rubric) + self._rubric_dict[key] = rubric + + def __getitem__(self, key: str) -> Rubric: + """Get rubric by key.""" + return self._rubric_dict[key] + + def __contains__(self, key: str) -> bool: + """Check if key exists.""" + return key in self._rubric_dict + + def __len__(self) -> int: + return len(self._rubric_dict) + + def __iter__(self) -> Iterator[str]: + return iter(self._rubric_dict) + + def keys(self) -> Iterator[str]: + """Iterate over keys.""" + return iter(self._rubric_dict.keys()) + + def values(self) -> Iterator[Rubric]: + """Iterate over rubrics.""" + return iter(self._rubric_dict.values()) + + def items(self) -> Iterator[Tuple[str, Rubric]]: + """Iterate over (key, rubric) pairs.""" + return iter(self._rubric_dict.items()) + + def update(self, rubrics: Union[Dict[str, Rubric], Mapping[str, Rubric]]) -> None: + """Update with rubrics from a dictionary.""" + for name, rubric in rubrics.items(): + self[name] = rubric diff --git a/src/core/openenv/core/rubrics/llm_judge.py b/src/core/openenv/core/rubrics/llm_judge.py new file mode 100644 index 0000000000000000000000000000000000000000..4963956eb4a51270c03809f9f0e14f1c66b91958 --- /dev/null +++ b/src/core/openenv/core/rubrics/llm_judge.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""LLM-as-a-judge rubric for reward computation. + +Uses an LLM endpoint (via LLMClient) to evaluate agent actions/observations. 
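+Because ``forward`` is a coroutine function here, calling the judge returns an
+awaitable (``Rubric.__call__`` routes async forwards automatically), as the
+usage example below shows.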
+ +Usage: + client = OpenAIClient("http://localhost", 8000, model="meta-llama/...") + judge = LLMJudge( + prompt_template="Rate this code solution:\\n{action}\\n\\nScore (0-1):", + client=client, + ) + score = await judge(action, observation) + +See RFC 004 for full design: rfcs/004-rubrics.md +""" + +import re +from typing import Any, Dict + +from openenv.core.llm_client import LLMClient +from openenv.core.rubrics.base import Rubric + + +class LLMJudge(Rubric): + """Rubric that uses an LLM to evaluate agent actions/observations. + + The prompt template is formatted with ``{action}`` and ``{observation}`` + placeholders. The LLM response is parsed for a numeric score. + + Args: + prompt_template: Template string with {action} and {observation} placeholders. + client: An LLMClient instance for making LLM calls. + score_pattern: Regex to extract the score from the LLM response. + Defaults to matching the first decimal number. + default_score: Score returned when parsing fails. + normalize: If True, clamp extracted score to [0, 1]. + """ + + def __init__( + self, + prompt_template: str, + client: LLMClient, + *, + score_pattern: str | None = None, + default_score: float = 0.0, + normalize: bool = True, + ): + super().__init__() + self.prompt_template = prompt_template + self._client = client + self._score_pattern = re.compile(score_pattern or r"(\d+\.?\d*)") + self.default_score = default_score + self.normalize = normalize + + async def forward(self, action: Any, observation: Any) -> float: + """Evaluate by sending a prompt to the LLM and parsing the score. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Parsed score from the LLM response. + """ + prompt = self._render_prompt(action, observation) + response = await self._client.complete(prompt) + return self._parse_score(response) + + def _render_prompt(self, action: Any, observation: Any) -> str: + """Format the prompt template with action and observation. + + Override in subclasses for custom prompt construction. + """ + return self.prompt_template.format(action=action, observation=observation) + + def _parse_score(self, response: str) -> float: + """Extract a numeric score from the LLM response. + + Uses the configured regex pattern to find the first match. + Returns default_score if no match is found. 
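+
+        Example (illustrative): with the default pattern, a response like
+        "I would rate this solution 0.8 overall" parses to 0.8, while a
+        stricter pattern such as r"Score:\s*([\d.]+)" matches only an
+        explicit "Score:" field.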
+ """ + match = self._score_pattern.search(response) + if match is None: + return self.default_score + try: + # Use first capture group if present, otherwise full match + text = match.group(1) if match.lastindex else match.group(0) + score = float(text) + except (ValueError, IndexError): + return self.default_score + if self.normalize: + score = max(0.0, min(1.0, score)) + return score + + def state_dict(self) -> Dict[str, Any]: + """Serialize rubric configuration.""" + return { + "prompt_template": self.prompt_template, + "score_pattern": self._score_pattern.pattern, + "default_score": self.default_score, + "normalize": self.normalize, + } + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load rubric configuration from checkpoint.""" + if "prompt_template" in state: + self.prompt_template = state["prompt_template"] + if "score_pattern" in state: + self._score_pattern = re.compile(state["score_pattern"]) + if "default_score" in state: + self.default_score = state["default_score"] + if "normalize" in state: + self.normalize = state["normalize"] diff --git a/src/core/openenv/core/rubrics/trajectory.py b/src/core/openenv/core/rubrics/trajectory.py new file mode 100644 index 0000000000000000000000000000000000000000..b3bb9aa9172047a24f89fae1fee6917abb861257 --- /dev/null +++ b/src/core/openenv/core/rubrics/trajectory.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Trajectory-based rubrics for delayed reward computation. + +These rubrics accumulate trajectory data and compute rewards based on +episode outcomes rather than individual steps. This supports scenarios +where reward signals depend on future events: + +- Terminal games (chess, Go): Win/loss known only at game end +- Plan execution: Plan quality depends on execution success +- Multi-agent games: One player's action quality depends on opponent response + +See RFC 004 "Delayed Rewards" section for design rationale. +""" + +from abc import abstractmethod +from typing import Any, Dict, List, Tuple + +from openenv.core.rubrics.base import Rubric + + +class TrajectoryRubric(Rubric): + """Abstract base for rubrics that score based on full trajectories. + + Subclasses implement: + - score_trajectory(): Compute final score from trajectory + - compute_step_rewards(): Define credit assignment strategy + + The __call__ method accumulates steps and returns rewards according + to the subclass's implementation. + + IMPORTANT: Trajectories are stored in CPU memory to avoid GPU pressure. + Environments with GPU tensors in observations must move them to CPU + before returning from step(). + + Known limitation: Very long episodes (thousands of steps) may consume + significant CPU memory. For such cases, consider streaming rubrics. 
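+
+    Call reset() between episodes: forward() always appends to the stored
+    trajectory, so without a reset consecutive episodes accumulate into one.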
+ + Usage: + class WinLossRubric(TrajectoryRubric): + def score_trajectory(self, trajectory): + _, final_obs = trajectory[-1] + return 1.0 if final_obs.metadata.get('won') else 0.0 + + def compute_step_rewards(self): + # Equal credit to all steps + score = self.score_trajectory(self._trajectory) + return [score] * len(self._trajectory) + + rubric = WinLossRubric() + for action, obs in episode: + reward = rubric(action, obs) # 0.0 until done + step_rewards = rubric.compute_step_rewards() # Credit assignment + """ + + _trajectory: List[Tuple[Any, Any]] + intermediate_reward: float + + def __init__(self, intermediate_reward: float = 0.0): + """Initialize trajectory rubric. + + Args: + intermediate_reward: Value to return for non-terminal steps. + Defaults to 0.0. + """ + super().__init__() + self.intermediate_reward = intermediate_reward + self._trajectory = [] + + def forward(self, action: Any, observation: Any) -> float: + """Accumulate step and return reward. + + Returns intermediate_reward until done, then computes trajectory score. + + Args: + action: The action taken. + observation: The resulting observation. Must have a 'done' attribute. + + Returns: + intermediate_reward if not done, else score_trajectory() result. + """ + self._trajectory.append((action, observation)) + + if getattr(observation, "done", False): + return self.score_trajectory(self._trajectory) + else: + return self.intermediate_reward + + @abstractmethod + def score_trajectory(self, trajectory: List[Tuple[Any, Any]]) -> float: + """Score the complete trajectory. Return 0.0-1.0. + + Called when observation.done=True. + + Args: + trajectory: List of (action, observation) tuples. + + Returns: + Final trajectory score (typically 0.0 to 1.0). + """ + raise NotImplementedError + + @abstractmethod + def compute_step_rewards(self) -> List[float]: + """Compute per-step rewards from the accumulated trajectory. + + Returns: + List of rewards, one per step. Length matches len(trajectory). + + Define your credit assignment strategy here (e.g., discounting, + assigning all credit to specific steps, etc.). + """ + raise NotImplementedError + + def reset(self) -> None: + """Clear accumulated trajectory. Call on env.reset().""" + self._trajectory = [] + + @property + def trajectory(self) -> List[Tuple[Any, Any]]: + """Current trajectory (read-only copy).""" + return list(self._trajectory) + + def state_dict(self) -> Dict[str, Any]: + """Serialize configuration (not trajectory data).""" + return {"intermediate_reward": self.intermediate_reward} + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load configuration from checkpoint.""" + if "intermediate_reward" in state: + self.intermediate_reward = state["intermediate_reward"] + + +class ExponentialDiscountingTrajectoryRubric(TrajectoryRubric): + """TrajectoryRubric with exponential discounting for credit assignment. + + Per-step reward: r_t = gamma^(T-1-t) * R_final + + With gamma=0.99, later steps get higher reward (they're "closer" to the outcome). + With gamma=1.0, all steps get equal reward. + With gamma=0.0, only the final step gets reward. + + This is the standard temporal discounting used in reinforcement learning, + applied retroactively once the episode outcome is known. 
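+
+    Worked example: with T=3 steps, gamma=0.9, and a final score of 1.0,
+    compute_step_rewards() returns [0.81, 0.9, 1.0], i.e. gamma^2, gamma^1,
+    and gamma^0 times the final score.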
+ + Usage: + class ChessRubric(ExponentialDiscountingTrajectoryRubric): + def score_trajectory(self, trajectory): + _, final_obs = trajectory[-1] + outcome = final_obs.metadata.get('winner') + if outcome == 'agent': return 1.0 + elif outcome == 'opponent': return 0.0 + else: return 0.5 # Draw + + rubric = ChessRubric(gamma=0.99) + reward = rubric(action, obs) # 0.0 until done, then final score + step_rewards = rubric.compute_step_rewards() # Discounted per-step rewards + """ + + gamma: float + + def __init__(self, gamma: float = 0.99, intermediate_reward: float = 0.0): + """Initialize with discount factor. + + Args: + gamma: Discount factor in [0, 1]. Higher values give more credit + to early moves. 0.99 is a common choice. + intermediate_reward: Value to return for non-terminal steps. + """ + super().__init__(intermediate_reward=intermediate_reward) + if not 0.0 <= gamma <= 1.0: + raise ValueError(f"gamma must be in [0, 1], got {gamma}") + self.gamma = gamma + + def compute_step_rewards(self) -> List[float]: + """Apply exponential discounting from final reward. + + Returns: + List of discounted rewards. step_rewards[t] = gamma^(T-1-t) * R_final + where T is the trajectory length and R_final is score_trajectory(). + """ + if not self._trajectory: + return [] + + final_score = self.score_trajectory(self._trajectory) + T = len(self._trajectory) + return [final_score * (self.gamma ** (T - 1 - t)) for t in range(T)] + + def state_dict(self) -> Dict[str, Any]: + """Serialize configuration.""" + state = super().state_dict() + state["gamma"] = self.gamma + return state + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load configuration from checkpoint.""" + super().load_state_dict(state) + if "gamma" in state: + self.gamma = state["gamma"] diff --git a/src/core/openenv/core/sync_client.py b/src/core/openenv/core/sync_client.py new file mode 100644 index 0000000000000000000000000000000000000000..4c5eb5da6151cea692ae447e1d4caba40a95fdaa --- /dev/null +++ b/src/core/openenv/core/sync_client.py @@ -0,0 +1,263 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Synchronous wrapper for async EnvClient. + +This module provides a SyncEnvClient that wraps an async EnvClient, +allowing synchronous usage while the underlying client uses async I/O. + +Example: + >>> from openenv.core import GenericEnvClient + >>> + >>> # Create async client and get sync wrapper + >>> async_client = GenericEnvClient(base_url="http://localhost:8000") + >>> sync_client = async_client.sync() + >>> + >>> # Use synchronous API + >>> with sync_client: + ... result = sync_client.reset() + ... result = sync_client.step({"code": "print('hello')"}) +""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +import inspect +import threading +from typing import Any, Dict, Generic, TYPE_CHECKING, TypeVar + +from .client_types import StateT, StepResult + +if TYPE_CHECKING: + from .env_client import EnvClient + +ActT = TypeVar("ActT") +ObsT = TypeVar("ObsT") + + +class SyncEnvClient(Generic[ActT, ObsT, StateT]): + """ + Synchronous wrapper around an async EnvClient. + + This class provides a synchronous interface to an async EnvClient, + making it easier to use in synchronous code or to stop async from + "infecting" the entire call stack. 
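+
+    Every public method blocks the calling thread until the corresponding
+    coroutine completes on the background loop, so callers need no async code.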
+ + The wrapper executes async operations on a dedicated background event loop + so connection state remains bound to a single loop. + + Cleanup note: + For guaranteed resource cleanup, use `with SyncEnvClient(...)` or call + `close()` explicitly. `__del__` is best-effort only and may not run + reliably (for example, during interpreter shutdown). + + Example: + >>> # From an async client + >>> async_client = GenericEnvClient(base_url="http://localhost:8000") + >>> sync_client = async_client.sync() + >>> + >>> # Use synchronous context manager + >>> with sync_client: + ... result = sync_client.reset() + ... result = sync_client.step({"action": "test"}) + + Attributes: + _async: The wrapped async EnvClient instance + """ + + def __init__(self, async_client: "EnvClient[ActT, ObsT, StateT]"): + """ + Initialize sync wrapper around an async client. + + Args: + async_client: The async EnvClient to wrap + """ + self._async = async_client + self._loop: asyncio.AbstractEventLoop | None = None + self._loop_thread: threading.Thread | None = None + self._loop_ready = threading.Event() + self._loop_init_lock = threading.Lock() + self._async_wrapper_cache: Dict[str, Any] = {} + + def _run_loop_forever(self) -> None: + """Run a dedicated event loop for this sync client.""" + loop = asyncio.new_event_loop() + self._loop = loop + asyncio.set_event_loop(loop) + self._loop_ready.set() + loop.run_forever() + loop.close() + + def _ensure_loop(self) -> asyncio.AbstractEventLoop: + """Start background loop thread on first use.""" + if ( + self._loop is not None + and self._loop_thread + and self._loop_thread.is_alive() + ): + return self._loop + + # Protect loop initialization when multiple threads race on first use. + with self._loop_init_lock: + if ( + self._loop is not None + and self._loop_thread + and self._loop_thread.is_alive() + ): + return self._loop + + self._loop_ready.clear() + self._loop_thread = threading.Thread( + target=self._run_loop_forever, + name="openenv-sync-client-loop", + daemon=True, + ) + self._loop_thread.start() + if not self._loop_ready.wait(timeout=5): + raise RuntimeError("Timed out starting sync client event loop") + assert self._loop is not None + return self._loop + + def _run(self, coro: Any) -> Any: + """Run coroutine on dedicated loop and block for result.""" + loop = self._ensure_loop() + future: concurrent.futures.Future[Any] = asyncio.run_coroutine_threadsafe( + coro, loop + ) + return future.result() + + def _stop_loop(self) -> None: + """Stop and join background loop thread.""" + loop = self._loop + thread = self._loop_thread + if loop is None: + return + + if loop.is_running(): + loop.call_soon_threadsafe(loop.stop) + if thread is not None: + thread.join(timeout=5) + + self._loop = None + self._loop_thread = None + + @property + def async_client(self) -> "EnvClient[ActT, ObsT, StateT]": + """Access the underlying async client.""" + return self._async + + def connect(self) -> "SyncEnvClient[ActT, ObsT, StateT]": + """ + Establish connection to the server. + + Returns: + self for method chaining + """ + self._run(self._async.connect()) + return self + + def disconnect(self) -> None: + """Close the connection.""" + self._run(self._async.disconnect()) + + def reset(self, **kwargs: Any) -> StepResult[ObsT]: + """ + Reset the environment. 
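+
+        Equivalent to awaiting `async_client.reset(**kwargs)`, executed on
+        the wrapper's background event loop.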
+ + Args: + **kwargs: Optional parameters passed to the environment's reset method + + Returns: + StepResult containing initial observation + """ + return self._run(self._async.reset(**kwargs)) + + def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: + """ + Execute an action in the environment. + + Args: + action: The action to execute + **kwargs: Optional parameters + + Returns: + StepResult containing observation, reward, and done status + """ + return self._run(self._async.step(action, **kwargs)) + + def state(self) -> StateT: + """ + Get the current environment state. + + Returns: + State object with environment state information + """ + return self._run(self._async.state()) + + def close(self) -> None: + """Close the connection and clean up resources.""" + try: + self._run(self._async.close()) + finally: + self._stop_loop() + + def __enter__(self) -> "SyncEnvClient[ActT, ObsT, StateT]": + """Enter context manager, establishing connection.""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit context manager, closing connection.""" + self.close() + + def __del__(self) -> None: + """ + Best-effort cleanup for background loop thread. + + Do not rely on this for deterministic cleanup; prefer context-manager + usage or an explicit `close()` call. + """ + try: + self._stop_loop() + except Exception: + pass + + def __getattr__(self, name: str) -> Any: + """ + Delegate unknown attributes to the async client. + + Async methods are wrapped to run on the sync client's dedicated loop. + """ + attr = getattr(self._async, name) + + if inspect.iscoroutinefunction(attr): + cached = self._async_wrapper_cache.get(name) + if cached is not None: + return cached + + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + method = getattr(self._async, name) + return self._run(method(*args, **kwargs)) + + self._async_wrapper_cache[name] = sync_wrapper + return sync_wrapper + + return attr + + # Delegate abstract method implementations to the wrapped client + def _step_payload(self, action: ActT) -> Dict[str, Any]: + """Delegate to async client's _step_payload.""" + return self._async._step_payload(action) + + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: + """Delegate to async client's _parse_result.""" + return self._async._parse_result(payload) + + def _parse_state(self, payload: Dict[str, Any]) -> StateT: + """Delegate to async client's _parse_state.""" + return self._async._parse_state(payload) diff --git a/src/core/openenv/core/tools/__init__.py b/src/core/openenv/core/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0193b2619fc14f14152e3276f54aa0d4aed8ca2c --- /dev/null +++ b/src/core/openenv/core/tools/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Core tools for code execution and other utilities.""" + +from .git_server_client import GitServerClient, RepoInfo + +try: + from .local_python_executor import PyExecutor +except ModuleNotFoundError: + # smolagents is optional for environments that only need Git tooling. 
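+    # Callers should check `PyExecutor is not None` before instantiating it.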
+ PyExecutor = None # type: ignore[assignment] + +__all__ = [ + "PyExecutor", + "GitServerClient", + "RepoInfo", +] diff --git a/src/core/openenv/core/tools/git_server_client.py b/src/core/openenv/core/tools/git_server_client.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc3379f6b675178cc7aa94914c31f66bc846aed --- /dev/null +++ b/src/core/openenv/core/tools/git_server_client.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python3 +""" +Git Server Client for connecting to external Gitea instance. + +This module provides a lightweight client for interacting with a shared +Gitea service, optimized for task-based isolation where multiple environment +instances share the same Gitea server but have isolated workspaces. +""" + +import json +import os +import shutil +import subprocess +import time +from dataclasses import dataclass +from pathlib import Path +from urllib.parse import urlparse + + +@dataclass +class RepoInfo: + """Information about a repository.""" + + name: str + url: str + commit: str + clone_url: str + + +class GitServerClient: + """ + Client for connecting to an external Gitea server. + + This client is optimized for task-based isolation where: + - Multiple tasks share the same Gitea instance + - Each task has its own isolated workspace + - Fast reset() via git operations (no server restart) + - Repos are pre-migrated to Gitea once + + Args: + gitea_url: URL of the Gitea server (e.g., "http://gitea:3000") + username: Gitea username for authentication + password: Gitea password for authentication + workspace_dir: Local workspace directory for cloning repos + + Example: + >>> # Connect to shared Gitea (credentials from environment) + >>> import os + >>> client = GitServerClient( + ... gitea_url=os.getenv("GITEA_URL"), + ... username=os.getenv("GITEA_USERNAME"), + ... password=os.getenv("GITEA_PASSWORD") + ... ) + >>> client.wait_for_ready() + >>> # Clone repo to workspace + >>> path = client.clone_to_workspace("my-repo", commit="abc123") + >>> # Fast reset to base state + >>> client.reset_workspace("my-repo", commit="abc123") + """ + + def __init__( + self, + gitea_url: str, + username: str, + password: str, + workspace_dir: str = "/workspace", + ): + """Initialize Git Server Client.""" + self.gitea_url = gitea_url.rstrip("/") + self.username = username + self.password = password + self.workspace_dir = Path(workspace_dir) + self.is_ready = False + + # Parse Gitea URL + parsed = urlparse(self.gitea_url) + self.domain = parsed.hostname or "localhost" + self.port = parsed.port or 3000 + + # Ensure workspace exists + os.makedirs(self.workspace_dir, exist_ok=True) + + # Configure git credentials + self._configure_git() + + def _configure_git(self): + """Configure git credentials for automatic authentication.""" + home_dir = Path.home() + + # Git config + git_config = f"""[user] + name = {self.username} + email = {self.username}@local.env +[init] + defaultBranch = main +[credential] + helper = store +""" + gitconfig_path = home_dir / ".gitconfig" + gitconfig_path.write_text(git_config) + + # Git credentials + git_credentials = ( + f"http://{self.username}:{self.password}@{self.domain}:{self.port}\n" + ) + gitcreds_path = home_dir / ".git-credentials" + gitcreds_path.write_text(git_credentials) + gitcreds_path.chmod(0o600) + + def wait_for_ready(self, timeout: int = 30) -> bool: + """ + Wait for Gitea server to be ready. 
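+
+        Polls the server root via curl once per second until it responds
+        successfully or the timeout elapses.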
+ + Args: + timeout: Maximum seconds to wait + + Returns: + True if server is ready, False otherwise + """ + start_time = time.time() + while time.time() - start_time < timeout: + try: + result = subprocess.run( + ["curl", "-sf", f"{self.gitea_url}/"], + capture_output=True, + timeout=5, + ) + if result.returncode == 0: + self.is_ready = True + return True + except subprocess.TimeoutExpired: + pass + except Exception: + pass + + time.sleep(1) + + return False + + def list_repositories(self) -> list[dict[str, str]]: + """ + List all repositories in Gitea. + + Returns: + List of repository information dictionaries + """ + if not self.is_ready: + raise RuntimeError("Gitea server is not ready") + + result = subprocess.run( + [ + "curl", + "-s", + f"{self.gitea_url}/api/v1/user/repos", + "-u", + f"{self.username}:{self.password}", + ], + capture_output=True, + text=True, + ) + + if result.returncode != 0: + return [] + + try: + repos = json.loads(result.stdout) + return [ + { + "name": repo["name"], + "full_name": repo["full_name"], + "clone_url": repo["clone_url"], + "description": repo.get("description", ""), + } + for repo in repos + ] + except (json.JSONDecodeError, KeyError): + return [] + + def clone_to_workspace( + self, repo_name: str, target_dir: str | None = None, commit: str = "main" + ) -> str: + """ + Clone a repository to the workspace at a specific commit. + + This creates a fresh clone optimized for task isolation. + + Args: + repo_name: Name of repository to clone + target_dir: Target directory name (defaults to repo_name) + commit: Commit hash or branch to check out + + Returns: + Path to cloned repository + + Raises: + RuntimeError: If clone fails + """ + if not self.is_ready: + raise RuntimeError("Gitea server is not ready") + + target_dir = target_dir or repo_name + target_path = self.workspace_dir / target_dir + + # Remove existing directory if present + if target_path.exists(): + shutil.rmtree(target_path) + + clone_url = f"{self.gitea_url}/{self.username}/{repo_name}.git" + + # Clone repository + result = subprocess.run( + ["git", "clone", clone_url, str(target_path)], + capture_output=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(f"Clone failed: {result.stderr}") + + # Checkout specific commit + if commit != "main": + result = subprocess.run( + ["git", "checkout", commit], + cwd=str(target_path), + capture_output=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(f"Checkout failed: {result.stderr}") + + return str(target_path) + + def reset_workspace(self, repo_name: str, commit: str = "main") -> bool: + """ + Fast reset of workspace to base state (optimized for task resets). + + This is much faster than re-cloning. It: + 1. Checks out the target commit + 2. Resets to that commit (hard) + 3. 
Cleans untracked files
+
+        Args:
+            repo_name: Name of repository (directory in workspace)
+            commit: Commit hash or branch to reset to
+
+        Returns:
+            True if reset successful
+
+        Raises:
+            RuntimeError: If reset fails
+        """
+        repo_path = self.workspace_dir / repo_name
+
+        if not repo_path.exists():
+            raise RuntimeError(f"Repository not found in workspace: {repo_name}")
+
+        # Fetch latest (in case commit is new)
+        subprocess.run(
+            ["git", "fetch", "--all"],
+            cwd=str(repo_path),
+            capture_output=True,
+        )
+
+        # Checkout and hard reset to commit
+        result = subprocess.run(
+            ["git", "checkout", commit],
+            cwd=str(repo_path),
+            capture_output=True,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"Checkout failed: {result.stderr}")
+
+        result = subprocess.run(
+            [
+                "git",
+                "reset",
+                "--hard",
+                f"origin/{commit}" if commit != "main" else commit,
+            ],
+            cwd=str(repo_path),
+            capture_output=True,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            # Try without origin/ prefix
+            result = subprocess.run(
+                ["git", "reset", "--hard", commit],
+                cwd=str(repo_path),
+                capture_output=True,
+                text=True,
+            )
+            if result.returncode != 0:
+                raise RuntimeError(f"Reset failed: {result.stderr}")
+
+        # Clean untracked files and directories
+        subprocess.run(
+            ["git", "clean", "-fdx"],
+            cwd=str(repo_path),
+            capture_output=True,
+        )
+
+        return True
+
+    def execute_git_command(
+        self, command: str, working_dir: str = ""
+    ) -> tuple[int, str, str]:
+        """
+        Execute a git command in the workspace.
+
+        Args:
+            command: Git command to execute (without 'git' prefix)
+            working_dir: Working directory relative to workspace
+
+        Returns:
+            Tuple of (exit_code, stdout, stderr)
+        """
+        work_path = (
+            self.workspace_dir / working_dir if working_dir else self.workspace_dir
+        )
+
+        if not work_path.exists():
+            return (1, "", f"Working directory does not exist: {work_path}")
+
+        # Split with shlex so quoted arguments survive intact
+        import shlex
+
+        cmd_parts = ["git"] + shlex.split(command)
+
+        result = subprocess.run(
+            cmd_parts,
+            cwd=str(work_path),
+            capture_output=True,
+            text=True,
+        )
+
+        return (result.returncode, result.stdout, result.stderr)
+
+    def get_current_commit(self, repo_name: str) -> str:
+        """
+        Get current commit hash of a workspace repository.
+
+        Args:
+            repo_name: Name of repository in workspace
+
+        Returns:
+            Commit hash
+        """
+        repo_path = self.workspace_dir / repo_name
+
+        if not repo_path.exists():
+            raise RuntimeError(f"Repository not found: {repo_name}")
+
+        result = subprocess.run(
+            ["git", "rev-parse", "HEAD"],
+            cwd=str(repo_path),
+            capture_output=True,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"Failed to get commit: {result.stderr}")
+
+        return result.stdout.strip()
+
+    def workspace_exists(self, repo_name: str) -> bool:
+        """Check if a repository exists in workspace."""
+        return (self.workspace_dir / repo_name).exists()
diff --git a/src/core/openenv/core/tools/local_python_executor.py b/src/core/openenv/core/tools/local_python_executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb18052b309b3c214bcf0e5c2645416734575fa1
--- /dev/null
+++ b/src/core/openenv/core/tools/local_python_executor.py
@@ -0,0 +1,157 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Local Python Executor (enhanced).
+ +This module provides a safer wrapper around smolagents.LocalPythonExecutor +with improved exception handling and a few helpful tools registered with +the executor to make debugging executed code easier. + +Key improvements: +- Register a few helper utilities via send_tools so user code can use + them for reporting (e.g. `format_exc`). +- More robust extraction of stdout/stderr/exit codes from the executor + result object, tolerant to different versions of smolagents. +- Detailed stderr on unexpected exceptions including full traceback. +- Structured logging for operational visibility. +""" + +from __future__ import annotations + +import json +import logging +import traceback + +from openenv.core.env_server.types import CodeExecResult +from smolagents import LocalPythonExecutor + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class PyExecutor: + """Wrapper around smolagents LocalPythonExecutor. + + The wrapper registers a few non-privileged helper tools to the + LocalPythonExecutor that can be used by the executed code to + format exceptions and to safely stringify results for improved + error reporting. + """ + + def __init__(self, additional_imports: list[str] | None = None): + if additional_imports is None: + additional_imports = [] + + self._executor = LocalPythonExecutor( + additional_authorized_imports=additional_imports + ) + + # Register helpful utilities exposed to the execution environment. + # These are intentionally small, read-only helpers. + tools = { + # Provide a small helper to format the current exception in the + # executed context. This is a *string formatting* helper only. + "format_exc": traceback.format_exc, + # Safe JSON dumps with a fallback for non-serializable objects. + "safe_json_dumps": lambda obj: json.dumps(obj, default=lambda o: repr(o)), + } + + # `send_tools` is the public API on LocalPythonExecutor to make + # helper callables available to the sandboxed runtime. We don't + # provide any builtins that could change the environment. + try: + self._executor.send_tools(tools) + except Exception: + # If the LocalPythonExecutor implementation doesn't support + # send_tools or fails, log and continue — the executor is still usable. + logger.debug( + "LocalPythonExecutor.send_tools failed; continuing without extra tools", + exc_info=True, + ) + + def run(self, code: str) -> CodeExecResult: + """Execute Python code and return a CodeExecResult. + + This method is intentionally defensive: it attempts to extract + meaningful stdout/stderr/exit_code information from a variety of + possible return shapes that different versions of smolagents + may provide. 
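+
+        Example (a sketch; exact stdout contents vary across smolagents
+        versions):
+            executor = PyExecutor(additional_imports=["math"])
+            result = executor.run("import math\nprint(math.pi)")
+            # result.exit_code == 0; result.stdout includes the printed value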
+ """ + try: + exec_result = self._executor(code) + + # Default values + stdout_parts: list[str] = [] + stderr_parts: list[str] = [] + exit_code = 0 + + # Extract logs/prints + try: + logs = getattr(exec_result, "logs", None) + if logs: + stdout_parts.append(str(logs)) + except Exception: + logger.debug("Failed to read exec_result.logs", exc_info=True) + + # Extract the result / output value + try: + if hasattr(exec_result, "output"): + out_val = exec_result.output + # If the output is not None, stringify it in a safe way + if out_val is not None: + # Prefer JSON if possible, otherwise repr + try: + stdout_parts.append(json.dumps(out_val)) + except Exception: + stdout_parts.append(repr(out_val)) + except Exception: + logger.debug("Failed to read exec_result.output", exc_info=True) + + # Some runtime implementations may put errors on `error` or `exception` + try: + err = getattr(exec_result, "error", None) + if err: + stderr_parts.append(str(err)) + except Exception: + logger.debug("Failed to read exec_result.error", exc_info=True) + + try: + ex = getattr(exec_result, "exception", None) + if ex: + stderr_parts.append(str(ex)) + except Exception: + logger.debug("Failed to read exec_result.exception", exc_info=True) + + # Determine exit code if provided + try: + if hasattr(exec_result, "exit_code"): + exit_code = ( + int(exec_result.exit_code) + if exec_result.exit_code is not None + else 0 + ) + elif hasattr(exec_result, "success"): + # Some versions use `success` boolean + exit_code = 0 if exec_result.success else 1 + else: + # Fallback: if there were any stderr parts, treat as non-zero + exit_code = 1 if stderr_parts else 0 + except Exception: + logger.debug("Failed to determine exec_result exit code", exc_info=True) + exit_code = 1 if stderr_parts else 0 + + # Compose the final stdout/stderr strings + stdout = "\n".join(part for part in stdout_parts if part is not None) + stderr = "\n".join(part for part in stderr_parts if part is not None) + + return CodeExecResult(stdout=stdout, stderr=stderr, exit_code=exit_code) + + except Exception: + # Any unexpected exception from the LocalPythonExecutor is + # returned with a full traceback to make debugging easier. + tb = traceback.format_exc() + logger.exception("LocalPythonExecutor raised an exception during run") + return CodeExecResult(stdout="", stderr=tb, exit_code=1) diff --git a/src/core/openenv/core/utils.py b/src/core/openenv/core/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e86b3ae9c3e6ec0a19cd6f4868e4e3cdfee66bbc --- /dev/null +++ b/src/core/openenv/core/utils.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Utility functions for OpenEnv core.""" + +import asyncio +import concurrent.futures + + +def run_async_safely(coro): + """ + Run an async coroutine safely from any context. + + This handles the case where we may already be inside an async event loop + (e.g., when called from an async framework). In that case, asyncio.run() + would fail, so we use a ThreadPoolExecutor to run in a separate thread. 
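+
+    Example:
+        async def answer():
+            return 42
+
+        value = run_async_safely(answer())  # value == 42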
+
+    Args:
+        coro: The coroutine to run
+
+    Returns:
+        The result of the coroutine
+    """
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        loop = None
+
+    if loop is not None:
+        # Already in async context - run in a thread pool
+        with concurrent.futures.ThreadPoolExecutor() as pool:
+            future = pool.submit(asyncio.run, coro)
+            return future.result()
+    else:
+        # No async context - use asyncio.run() directly
+        return asyncio.run(coro)
+
+
+def convert_to_ws_url(url: str) -> str:
+    """
+    Convert an HTTP/HTTPS URL to a WS/WSS URL.
+
+    Args:
+        url: The URL to convert.
+
+    Returns:
+        The converted WebSocket URL.
+    """
+    ws_url = url.rstrip("/")
+    if ws_url.startswith("http://"):
+        ws_url = "ws://" + ws_url[7:]
+    elif ws_url.startswith("https://"):
+        ws_url = "wss://" + ws_url[8:]
+    elif not ws_url.startswith("ws://") and not ws_url.startswith("wss://"):
+        ws_url = "ws://" + ws_url
+    return ws_url
+ """ + + _rubric_children: Dict[str, "Rubric"] + _forward_hooks: List[Callable] + _forward_pre_hooks: List[Callable] + last_score: Optional[float] + + def __init__(self): + # Use object.__setattr__ to avoid triggering __setattr__ during init + object.__setattr__(self, "_rubric_children", {}) + object.__setattr__(self, "_forward_hooks", []) + object.__setattr__(self, "_forward_pre_hooks", []) + object.__setattr__(self, "last_score", None) + + def __setattr__(self, name: str, value: Any) -> None: + # Auto-register child rubrics when assigned as attributes + if isinstance(value, Rubric): + self._rubric_children[name] = value + object.__setattr__(self, name, value) + + def __call__(self, action: Any, observation: Any): + """Evaluate the rubric with hooks. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value (typically 0.0 to 1.0). + """ + # Check if forward method is async BEFORE calling it + if inspect.iscoroutinefunction(self.forward): + # Async path - pre-hooks will be called in _call_async + result = self.forward(action, observation) + return self._call_async(action, observation, result) + else: + # Sync path - call pre-hooks BEFORE forward() + for hook in self._forward_pre_hooks: + hook(self, action, observation) + result = self.forward(action, observation) + return self._call_sync(action, observation, result) + + def _call_sync(self, action: Any, observation: Any, result: float) -> float: + """Synchronous call path.""" + self.last_score = result + + # Post-forward hooks + for hook in self._forward_hooks: + hook(self, action, observation, result) + + return result + + async def _call_async(self, action: Any, observation: Any, result_coro) -> float: + """Asynchronous call path.""" + # Pre-forward hooks + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + # Await the forward result + result = await result_coro + self.last_score = result + + # Post-forward hooks + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + + return result + + @abstractmethod + def forward(self, action: Any, observation: Any) -> float: + """Compute the reward. Implement this in subclasses. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value (typically 0.0 to 1.0). + """ + raise NotImplementedError + + def register_forward_hook( + self, hook: Callable[["Rubric", Any, Any, float], None] + ) -> None: + """Register a hook called after forward(). + + Args: + hook: Callable with signature (rubric, action, observation, result). + """ + self._forward_hooks.append(hook) + + def register_forward_pre_hook( + self, hook: Callable[["Rubric", Any, Any], None] + ) -> None: + """Register a hook called before forward(). + + Args: + hook: Callable with signature (rubric, action, observation). 
+ """ + self._forward_pre_hooks.append(hook) + + def children(self) -> Iterator["Rubric"]: + """Iterate over immediate child rubrics.""" + yield from self._rubric_children.values() + + def named_children(self) -> Iterator[Tuple[str, "Rubric"]]: + """Iterate over immediate child rubrics with names.""" + yield from self._rubric_children.items() + + def rubrics(self) -> Iterator["Rubric"]: + """Iterate over all descendant rubrics (depth-first).""" + for child in self._rubric_children.values(): + yield child + yield from child.rubrics() + + def named_rubrics(self, prefix: str = "") -> Iterator[Tuple[str, "Rubric"]]: + """Iterate over all descendant rubrics with dot-separated names.""" + for name, child in self._rubric_children.items(): + full_name = f"{prefix}.{name}" if prefix else name + yield full_name, child + yield from child.named_rubrics(full_name) + + def get_rubric(self, path: str) -> "Rubric": + """Access a nested rubric by dot-separated path. + + Args: + path: Dot-separated path (e.g., "code.syntax"). + + Returns: + The rubric at the specified path. + + Raises: + KeyError: If the path does not exist. + """ + parts = path.split(".") + current = self + for part in parts: + if part not in current._rubric_children: + raise KeyError(f"Rubric path not found: {path}") + current = current._rubric_children[part] + return current + + def reset(self) -> None: + """Reset any internal state. Override in subclasses if needed.""" + pass + + def state_dict(self) -> Dict[str, Any]: + """Serialize rubric configuration for checkpointing.""" + return {} + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load rubric configuration from checkpoint.""" + pass diff --git a/src/core/rubrics/containers.py b/src/core/rubrics/containers.py new file mode 100644 index 0000000000000000000000000000000000000000..7a587ee7885efdf71b03d644b54524f1855474d9 --- /dev/null +++ b/src/core/rubrics/containers.py @@ -0,0 +1,574 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Container rubrics for composing reward computations. + +These containers provide common aggregation patterns for rubrics, +similar to how PyTorch provides nn.Sequential alongside nn.Module. + +See RFC 004 for full design: rfcs/004-rubrics.md +""" + +import asyncio +import inspect +from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union + +from openenv.core.rubrics.base import Rubric + + +def _in_async_context() -> bool: + """Check if we're currently in an async context.""" + try: + asyncio.get_running_loop() + return True + except RuntimeError: + return False + + +class Sequential(Rubric): + """Run rubrics in order, fail-fast on zero. + + Runs child rubrics in order. If any returns 0, stops immediately + and returns 0. This implements hierarchical gating patterns where + syntax checks run before execution checks. + + Usage: + rubric = Sequential( + Gate(Compiles()), + Gate(PassesTests(), threshold=0.5), + WeightedSum([PassesTests(), StyleRubric()], weights=[0.7, 0.3]) + ) + """ + + def __init__(self, *rubrics: Rubric): + """Initialize with rubrics to run in sequence. + + Args: + *rubrics: Rubrics to run in order. Stops and returns 0 if any + child returns 0. 
+ """ + super().__init__() + for i, rubric in enumerate(rubrics): + setattr(self, f"rubric_{i}", rubric) + self._rubric_list = list(rubrics) + + def forward(self, action: Any, observation: Any) -> float: + """Run rubrics in order, return 0 if any returns 0. Sync version.""" + result = 1.0 + for rubric in self._rubric_list: + score = rubric(action, observation) + if score == 0.0: + return 0.0 + result = score + return result + + def __call__(self, action: Any, observation: Any): + """Override to choose sync or async path based on children.""" + # Empty case - check if in async context + if not self._rubric_list: + if _in_async_context(): + return self._empty_async(action, observation) + else: + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + result = 1.0 + self.last_score = result + for hook in self._forward_hooks: + hook(self, action, observation, result) + return result + + # Call first rubric to see if it's async + first_result = self._rubric_list[0](action, observation) + if inspect.iscoroutine(first_result): + # At least one child is async, use async path + return self._call_async_detected(action, observation, first_result) + else: + # Continue with sync path + if first_result == 0.0: + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + self.last_score = 0.0 + for hook in self._forward_hooks: + hook(self, action, observation, 0.0) + return 0.0 + + final_result = first_result + for i, rubric in enumerate(self._rubric_list[1:], start=1): + score = rubric(action, observation) + if inspect.iscoroutine(score): + # Found async mid-way, switch to async + # We already called rubric at index i, so pass the coroutine and remaining rubrics + return self._call_async_mid( + action, + observation, + final_result, + score, + self._rubric_list[i + 1 :], + ) + if score == 0.0: + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + self.last_score = 0.0 + for hook in self._forward_hooks: + hook(self, action, observation, 0.0) + return 0.0 + final_result = score + + # All sync - check if in async context + if _in_async_context(): + return self._wrap_sync_result(action, observation, final_result) + else: + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + self.last_score = final_result + for hook in self._forward_hooks: + hook(self, action, observation, final_result) + return final_result + + async def _empty_async(self, action, observation): + """Async path for empty sequential.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + result = 1.0 + self.last_score = result + + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + async def _wrap_sync_result(self, action, observation, result): + """Wrap sync result for async context.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + self.last_score = result + + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + async def _call_async_detected(self, action, observation, first_coro): + """Async path when first child 
is async.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + result = await first_coro + if result == 0.0: + self.last_score = 0.0 + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return 0.0 + + for rubric in self._rubric_list[1:]: + score = rubric(action, observation) + if inspect.iscoroutine(score): + score = await score + if score == 0.0: + self.last_score = 0.0 + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, 0.0) + else: + hook(self, action, observation, 0.0) + return 0.0 + result = score + + self.last_score = result + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + async def _call_async_mid( + self, action, observation, current_result, first_async_coro, remaining + ): + """Async path when async detected mid-execution.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + # Await the first async rubric (already called) + result = await first_async_coro + if result == 0.0: + self.last_score = 0.0 + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, 0.0) + else: + hook(self, action, observation, 0.0) + return 0.0 + + # Continue with remaining rubrics + for rubric in remaining: + score = rubric(action, observation) + if inspect.iscoroutine(score): + score = await score + if score == 0.0: + self.last_score = 0.0 + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, 0.0) + else: + hook(self, action, observation, 0.0) + return 0.0 + result = score + + self.last_score = result + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + def __len__(self) -> int: + return len(self._rubric_list) + + def __getitem__(self, index: int) -> Rubric: + return self._rubric_list[index] + + +class Gate(Rubric): + """Threshold wrapper - returns 0 if child score is below threshold. + + Useful for hard constraints like "must pass 50% of tests". + + Usage: + rubric = Gate(PassesTests(), threshold=0.5) + # Returns PassesTests() score if >= 0.5, else 0.0 + """ + + def __init__(self, rubric: Rubric, threshold: float = 1.0): + """Initialize with a rubric and threshold. + + Args: + rubric: The rubric to gate. + threshold: Minimum score required. If child returns less than + this, Gate returns 0. Default is 1.0 (must pass completely). + """ + super().__init__() + self.rubric = rubric + self.threshold = threshold + + def forward(self, action: Any, observation: Any) -> float: + """Return child score if >= threshold, else 0. 
Sync version.""" + score = self.rubric(action, observation) + if score < self.threshold: + return 0.0 + return score + + def __call__(self, action: Any, observation: Any): + """Override to handle async child.""" + # Call child + score = self.rubric(action, observation) + + if inspect.iscoroutine(score): + # Child is async + return self._call_async(action, observation, score) + else: + # Child is sync + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + result = 0.0 if score < self.threshold else score + self.last_score = result + for hook in self._forward_hooks: + hook(self, action, observation, result) + return result + + async def _call_async(self, action, observation, score_coro): + """Async path.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + score = await score_coro + result = 0.0 if score < self.threshold else score + self.last_score = result + + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + +class WeightedSum(Rubric): + """Weighted combination of child rubrics. + + Standard aggregation pattern for multi-criteria evaluation. + + Usage: + rubric = WeightedSum( + [PassesTests(), StyleRubric()], + weights=[0.7, 0.3] + ) + """ + + def __init__(self, rubrics: List[Rubric], weights: List[float]): + """Initialize with rubrics and weights. + + Args: + rubrics: List of rubrics to combine. + weights: Weight for each rubric. Must sum to 1.0. + + Raises: + ValueError: If lengths don't match or weights don't sum to 1.0. + """ + super().__init__() + if len(rubrics) != len(weights): + raise ValueError( + f"Number of rubrics ({len(rubrics)}) must match " + f"number of weights ({len(weights)})" + ) + if abs(sum(weights) - 1.0) > 1e-6: + raise ValueError(f"Weights must sum to 1.0, got {sum(weights)}") + + for i, rubric in enumerate(rubrics): + setattr(self, f"rubric_{i}", rubric) + self._rubric_list = list(rubrics) + self._weights = list(weights) + + def forward(self, action: Any, observation: Any) -> float: + """Return weighted sum of child scores. 
Sync version.""" + total = 0.0 + for rubric, weight in zip(self._rubric_list, self._weights): + score = rubric(action, observation) + total += score * weight + return total + + def __call__(self, action: Any, observation: Any): + """Override to handle async children with parallel execution.""" + # Call all rubrics + results = [rubric(action, observation) for rubric in self._rubric_list] + + # Check if any are async + has_async = any(inspect.iscoroutine(r) for r in results) + + if has_async: + # Use async path + return self._call_async(action, observation, results) + else: + # Sync path + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + total = 0.0 + for score, weight in zip(results, self._weights): + total += score * weight + self.last_score = total + for hook in self._forward_hooks: + hook(self, action, observation, total) + return total + + async def _call_async(self, action, observation, results): + """Async path with parallel execution.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + # Separate sync and async results + async_tasks = [] + async_indices = [] + scores = [None] * len(results) + + for i, result in enumerate(results): + if inspect.iscoroutine(result): + async_tasks.append(result) + async_indices.append(i) + else: + scores[i] = result + + # Await all async tasks in parallel + if async_tasks: + async_scores = await asyncio.gather(*async_tasks) + for i, score in zip(async_indices, async_scores): + scores[i] = score + + # Compute weighted sum + total = 0.0 + for score, weight in zip(scores, self._weights): + total += score * weight + + self.last_score = total + + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, total) + else: + hook(self, action, observation, total) + return total + + @property + def weights(self) -> List[float]: + """Get the weights (read-only copy).""" + return list(self._weights) + + +class RubricList(Rubric): + """Container for dynamic lists of rubrics. + + Analogous to nn.ModuleList. Does not define aggregation - use within + a parent rubric that implements custom logic. + + Usage: + class MultiGameRubric(Rubric): + def __init__(self, games: List[str]): + super().__init__() + self.games = RubricList([GameRubric(g) for g in games]) + + def forward(self, action, obs) -> float: + return self.games[obs.game_index](action, obs) + """ + + def __init__(self, rubrics: List[Rubric] = None): + """Initialize with optional list of rubrics. + + Args: + rubrics: Optional list of rubrics to start with. + """ + super().__init__() + self._rubrics: List[Rubric] = [] + if rubrics is not None: + for i, rubric in enumerate(rubrics): + self.append(rubric) + + def forward(self, action: Any, observation: Any) -> float: + """RubricList does not define aggregation - override in parent.""" + raise NotImplementedError( + "RubricList.forward() is not implemented. " + "Use RubricList within a parent rubric that defines aggregation." 
+ ) + + def append(self, rubric: Rubric) -> None: + """Add a rubric to the list.""" + index = len(self._rubrics) + setattr(self, f"rubric_{index}", rubric) + self._rubrics.append(rubric) + + def extend(self, rubrics: List[Rubric]) -> None: + """Add multiple rubrics to the list.""" + for rubric in rubrics: + self.append(rubric) + + def __len__(self) -> int: + return len(self._rubrics) + + def __getitem__(self, index: int) -> Rubric: + return self._rubrics[index] + + def __iter__(self) -> Iterator[Rubric]: + return iter(self._rubrics) + + +class RubricDict(Rubric): + """Container for named rubrics with keyed access. + + Analogous to nn.ModuleDict. Enables keyed access for multi-task + environments where different tasks require different rubrics. + + Usage: + class AtariRubric(Rubric): + def __init__(self): + super().__init__() + self.games = RubricDict({ + "pong": PongRubric(), + "breakout": BreakoutRubric(), + "space_invaders": SpaceInvadersRubric(), + }) + + def forward(self, action, obs) -> float: + return self.games[obs.game_id](action, obs) + + # Access: env.rubric.games["pong"] + """ + + def __init__(self, rubrics: Dict[str, Rubric] = None): + """Initialize with optional dictionary of rubrics. + + Args: + rubrics: Optional dictionary mapping names to rubrics. + """ + super().__init__() + self._rubric_dict: Dict[str, Rubric] = {} + if rubrics is not None: + for name, rubric in rubrics.items(): + self[name] = rubric + + def forward(self, action: Any, observation: Any) -> float: + """RubricDict does not define aggregation - override in parent.""" + raise NotImplementedError( + "RubricDict.forward() is not implemented. " + "Use RubricDict within a parent rubric that defines aggregation." + ) + + def __setitem__(self, key: str, rubric: Rubric) -> None: + """Add a rubric with the given key.""" + setattr(self, key, rubric) + self._rubric_dict[key] = rubric + + def __getitem__(self, key: str) -> Rubric: + """Get rubric by key.""" + return self._rubric_dict[key] + + def __contains__(self, key: str) -> bool: + """Check if key exists.""" + return key in self._rubric_dict + + def __len__(self) -> int: + return len(self._rubric_dict) + + def __iter__(self) -> Iterator[str]: + return iter(self._rubric_dict) + + def keys(self) -> Iterator[str]: + """Iterate over keys.""" + return iter(self._rubric_dict.keys()) + + def values(self) -> Iterator[Rubric]: + """Iterate over rubrics.""" + return iter(self._rubric_dict.values()) + + def items(self) -> Iterator[Tuple[str, Rubric]]: + """Iterate over (key, rubric) pairs.""" + return iter(self._rubric_dict.items()) + + def update(self, rubrics: Union[Dict[str, Rubric], Mapping[str, Rubric]]) -> None: + """Update with rubrics from a dictionary.""" + for name, rubric in rubrics.items(): + self[name] = rubric diff --git a/src/core/rubrics/llm_judge.py b/src/core/rubrics/llm_judge.py new file mode 100644 index 0000000000000000000000000000000000000000..4963956eb4a51270c03809f9f0e14f1c66b91958 --- /dev/null +++ b/src/core/rubrics/llm_judge.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""LLM-as-a-judge rubric for reward computation. + +Uses an LLM endpoint (via LLMClient) to evaluate agent actions/observations. 
+ +Usage: + client = OpenAIClient("http://localhost", 8000, model="meta-llama/...") + judge = LLMJudge( + prompt_template="Rate this code solution:\\n{action}\\n\\nScore (0-1):", + client=client, + ) + score = await judge(action, observation) + +See RFC 004 for full design: rfcs/004-rubrics.md +""" + +import re +from typing import Any, Dict + +from openenv.core.llm_client import LLMClient +from openenv.core.rubrics.base import Rubric + + +class LLMJudge(Rubric): + """Rubric that uses an LLM to evaluate agent actions/observations. + + The prompt template is formatted with ``{action}`` and ``{observation}`` + placeholders. The LLM response is parsed for a numeric score. + + Args: + prompt_template: Template string with {action} and {observation} placeholders. + client: An LLMClient instance for making LLM calls. + score_pattern: Regex to extract the score from the LLM response. + Defaults to matching the first decimal number. + default_score: Score returned when parsing fails. + normalize: If True, clamp extracted score to [0, 1]. + """ + + def __init__( + self, + prompt_template: str, + client: LLMClient, + *, + score_pattern: str | None = None, + default_score: float = 0.0, + normalize: bool = True, + ): + super().__init__() + self.prompt_template = prompt_template + self._client = client + self._score_pattern = re.compile(score_pattern or r"(\d+\.?\d*)") + self.default_score = default_score + self.normalize = normalize + + async def forward(self, action: Any, observation: Any) -> float: + """Evaluate by sending a prompt to the LLM and parsing the score. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Parsed score from the LLM response. + """ + prompt = self._render_prompt(action, observation) + response = await self._client.complete(prompt) + return self._parse_score(response) + + def _render_prompt(self, action: Any, observation: Any) -> str: + """Format the prompt template with action and observation. + + Override in subclasses for custom prompt construction. + """ + return self.prompt_template.format(action=action, observation=observation) + + def _parse_score(self, response: str) -> float: + """Extract a numeric score from the LLM response. + + Uses the configured regex pattern to find the first match. + Returns default_score if no match is found. 
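+ + For example, with the default pattern a response of "Score: 0.85 out of 1" parses to 0.85, while "I rate this an 8" parses to 8.0 and is clamped to 1.0 when ``normalize`` is True.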
+ """ + match = self._score_pattern.search(response) + if match is None: + return self.default_score + try: + # Use first capture group if present, otherwise full match + text = match.group(1) if match.lastindex else match.group(0) + score = float(text) + except (ValueError, IndexError): + return self.default_score + if self.normalize: + score = max(0.0, min(1.0, score)) + return score + + def state_dict(self) -> Dict[str, Any]: + """Serialize rubric configuration.""" + return { + "prompt_template": self.prompt_template, + "score_pattern": self._score_pattern.pattern, + "default_score": self.default_score, + "normalize": self.normalize, + } + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load rubric configuration from checkpoint.""" + if "prompt_template" in state: + self.prompt_template = state["prompt_template"] + if "score_pattern" in state: + self._score_pattern = re.compile(state["score_pattern"]) + if "default_score" in state: + self.default_score = state["default_score"] + if "normalize" in state: + self.normalize = state["normalize"] diff --git a/src/core/rubrics/trajectory.py b/src/core/rubrics/trajectory.py new file mode 100644 index 0000000000000000000000000000000000000000..b3bb9aa9172047a24f89fae1fee6917abb861257 --- /dev/null +++ b/src/core/rubrics/trajectory.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Trajectory-based rubrics for delayed reward computation. + +These rubrics accumulate trajectory data and compute rewards based on +episode outcomes rather than individual steps. This supports scenarios +where reward signals depend on future events: + +- Terminal games (chess, Go): Win/loss known only at game end +- Plan execution: Plan quality depends on execution success +- Multi-agent games: One player's action quality depends on opponent response + +See RFC 004 "Delayed Rewards" section for design rationale. +""" + +from abc import abstractmethod +from typing import Any, Dict, List, Tuple + +from openenv.core.rubrics.base import Rubric + + +class TrajectoryRubric(Rubric): + """Abstract base for rubrics that score based on full trajectories. + + Subclasses implement: + - score_trajectory(): Compute final score from trajectory + - compute_step_rewards(): Define credit assignment strategy + + The __call__ method accumulates steps and returns rewards according + to the subclass's implementation. + + IMPORTANT: Trajectories are stored in CPU memory to avoid GPU pressure. + Environments with GPU tensors in observations must move them to CPU + before returning from step(). + + Known limitation: Very long episodes (thousands of steps) may consume + significant CPU memory. For such cases, consider streaming rubrics. 
+ + Usage: + class WinLossRubric(TrajectoryRubric): + def score_trajectory(self, trajectory): + _, final_obs = trajectory[-1] + return 1.0 if final_obs.metadata.get('won') else 0.0 + + def compute_step_rewards(self): + # Equal credit to all steps + score = self.score_trajectory(self._trajectory) + return [score] * len(self._trajectory) + + rubric = WinLossRubric() + for action, obs in episode: + reward = rubric(action, obs) # 0.0 until done + step_rewards = rubric.compute_step_rewards() # Credit assignment + """ + + _trajectory: List[Tuple[Any, Any]] + intermediate_reward: float + + def __init__(self, intermediate_reward: float = 0.0): + """Initialize trajectory rubric. + + Args: + intermediate_reward: Value to return for non-terminal steps. + Defaults to 0.0. + """ + super().__init__() + self.intermediate_reward = intermediate_reward + self._trajectory = [] + + def forward(self, action: Any, observation: Any) -> float: + """Accumulate step and return reward. + + Returns intermediate_reward until done, then computes trajectory score. + + Args: + action: The action taken. + observation: The resulting observation. Must have a 'done' attribute. + + Returns: + intermediate_reward if not done, else score_trajectory() result. + """ + self._trajectory.append((action, observation)) + + if getattr(observation, "done", False): + return self.score_trajectory(self._trajectory) + else: + return self.intermediate_reward + + @abstractmethod + def score_trajectory(self, trajectory: List[Tuple[Any, Any]]) -> float: + """Score the complete trajectory. Return 0.0-1.0. + + Called when observation.done=True. + + Args: + trajectory: List of (action, observation) tuples. + + Returns: + Final trajectory score (typically 0.0 to 1.0). + """ + raise NotImplementedError + + @abstractmethod + def compute_step_rewards(self) -> List[float]: + """Compute per-step rewards from the accumulated trajectory. + + Returns: + List of rewards, one per step. Length matches len(trajectory). + + Define your credit assignment strategy here (e.g., discounting, + assigning all credit to specific steps, etc.). + """ + raise NotImplementedError + + def reset(self) -> None: + """Clear accumulated trajectory. Call on env.reset().""" + self._trajectory = [] + + @property + def trajectory(self) -> List[Tuple[Any, Any]]: + """Current trajectory (read-only copy).""" + return list(self._trajectory) + + def state_dict(self) -> Dict[str, Any]: + """Serialize configuration (not trajectory data).""" + return {"intermediate_reward": self.intermediate_reward} + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load configuration from checkpoint.""" + if "intermediate_reward" in state: + self.intermediate_reward = state["intermediate_reward"] + + +class ExponentialDiscountingTrajectoryRubric(TrajectoryRubric): + """TrajectoryRubric with exponential discounting for credit assignment. + + Per-step reward: r_t = gamma^(T-1-t) * R_final + + With gamma=0.99, later steps get higher reward (they're "closer" to the outcome). + With gamma=1.0, all steps get equal reward. + With gamma=0.0, only the final step gets reward. + + This is the standard temporal discounting used in reinforcement learning, + applied retroactively once the episode outcome is known. 
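+ + For example, a three-step episode with gamma=0.9 and a final score of 1.0 yields step rewards [0.81, 0.9, 1.0].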
+ + Usage: + class ChessRubric(ExponentialDiscountingTrajectoryRubric): + def score_trajectory(self, trajectory): + _, final_obs = trajectory[-1] + outcome = final_obs.metadata.get('winner') + if outcome == 'agent': return 1.0 + elif outcome == 'opponent': return 0.0 + else: return 0.5 # Draw + + rubric = ChessRubric(gamma=0.99) + reward = rubric(action, obs) # 0.0 until done, then final score + step_rewards = rubric.compute_step_rewards() # Discounted per-step rewards + """ + + gamma: float + + def __init__(self, gamma: float = 0.99, intermediate_reward: float = 0.0): + """Initialize with discount factor. + + Args: + gamma: Discount factor in [0, 1]. Higher values give more credit + to early moves. 0.99 is a common choice. + intermediate_reward: Value to return for non-terminal steps. + """ + super().__init__(intermediate_reward=intermediate_reward) + if not 0.0 <= gamma <= 1.0: + raise ValueError(f"gamma must be in [0, 1], got {gamma}") + self.gamma = gamma + + def compute_step_rewards(self) -> List[float]: + """Apply exponential discounting from final reward. + + Returns: + List of discounted rewards. step_rewards[t] = gamma^(T-1-t) * R_final + where T is the trajectory length and R_final is score_trajectory(). + """ + if not self._trajectory: + return [] + + final_score = self.score_trajectory(self._trajectory) + T = len(self._trajectory) + return [final_score * (self.gamma ** (T - 1 - t)) for t in range(T)] + + def state_dict(self) -> Dict[str, Any]: + """Serialize configuration.""" + state = super().state_dict() + state["gamma"] = self.gamma + return state + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load configuration from checkpoint.""" + super().load_state_dict(state) + if "gamma" in state: + self.gamma = state["gamma"] diff --git a/src/core/sync_client.py b/src/core/sync_client.py new file mode 100644 index 0000000000000000000000000000000000000000..4c5eb5da6151cea692ae447e1d4caba40a95fdaa --- /dev/null +++ b/src/core/sync_client.py @@ -0,0 +1,263 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Synchronous wrapper for async EnvClient. + +This module provides a SyncEnvClient that wraps an async EnvClient, +allowing synchronous usage while the underlying client uses async I/O. + +Example: + >>> from openenv.core import GenericEnvClient + >>> + >>> # Create async client and get sync wrapper + >>> async_client = GenericEnvClient(base_url="http://localhost:8000") + >>> sync_client = async_client.sync() + >>> + >>> # Use synchronous API + >>> with sync_client: + ... result = sync_client.reset() + ... result = sync_client.step({"code": "print('hello')"}) +""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +import inspect +import threading +from typing import Any, Dict, Generic, TYPE_CHECKING, TypeVar + +from .client_types import StateT, StepResult + +if TYPE_CHECKING: + from .env_client import EnvClient + +ActT = TypeVar("ActT") +ObsT = TypeVar("ObsT") + + +class SyncEnvClient(Generic[ActT, ObsT, StateT]): + """ + Synchronous wrapper around an async EnvClient. + + This class provides a synchronous interface to an async EnvClient, + making it easier to use in synchronous code or to stop async from + "infecting" the entire call stack. 
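+ Each method call blocks the calling thread until the wrapped coroutine completes.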
+ + The wrapper executes async operations on a dedicated background event loop + so connection state remains bound to a single loop. + + Cleanup note: + For guaranteed resource cleanup, use `with SyncEnvClient(...)` or call + `close()` explicitly. `__del__` is best-effort only and may not run + reliably (for example, during interpreter shutdown). + + Example: + >>> # From an async client + >>> async_client = GenericEnvClient(base_url="http://localhost:8000") + >>> sync_client = async_client.sync() + >>> + >>> # Use synchronous context manager + >>> with sync_client: + ... result = sync_client.reset() + ... result = sync_client.step({"action": "test"}) + + Attributes: + _async: The wrapped async EnvClient instance + """ + + def __init__(self, async_client: "EnvClient[ActT, ObsT, StateT]"): + """ + Initialize sync wrapper around an async client. + + Args: + async_client: The async EnvClient to wrap + """ + self._async = async_client + self._loop: asyncio.AbstractEventLoop | None = None + self._loop_thread: threading.Thread | None = None + self._loop_ready = threading.Event() + self._loop_init_lock = threading.Lock() + self._async_wrapper_cache: Dict[str, Any] = {} + + def _run_loop_forever(self) -> None: + """Run a dedicated event loop for this sync client.""" + loop = asyncio.new_event_loop() + self._loop = loop + asyncio.set_event_loop(loop) + self._loop_ready.set() + loop.run_forever() + loop.close() + + def _ensure_loop(self) -> asyncio.AbstractEventLoop: + """Start background loop thread on first use.""" + if ( + self._loop is not None + and self._loop_thread + and self._loop_thread.is_alive() + ): + return self._loop + + # Protect loop initialization when multiple threads race on first use. + with self._loop_init_lock: + if ( + self._loop is not None + and self._loop_thread + and self._loop_thread.is_alive() + ): + return self._loop + + self._loop_ready.clear() + self._loop_thread = threading.Thread( + target=self._run_loop_forever, + name="openenv-sync-client-loop", + daemon=True, + ) + self._loop_thread.start() + if not self._loop_ready.wait(timeout=5): + raise RuntimeError("Timed out starting sync client event loop") + assert self._loop is not None + return self._loop + + def _run(self, coro: Any) -> Any: + """Run coroutine on dedicated loop and block for result.""" + loop = self._ensure_loop() + future: concurrent.futures.Future[Any] = asyncio.run_coroutine_threadsafe( + coro, loop + ) + return future.result() + + def _stop_loop(self) -> None: + """Stop and join background loop thread.""" + loop = self._loop + thread = self._loop_thread + if loop is None: + return + + if loop.is_running(): + loop.call_soon_threadsafe(loop.stop) + if thread is not None: + thread.join(timeout=5) + + self._loop = None + self._loop_thread = None + + @property + def async_client(self) -> "EnvClient[ActT, ObsT, StateT]": + """Access the underlying async client.""" + return self._async + + def connect(self) -> "SyncEnvClient[ActT, ObsT, StateT]": + """ + Establish connection to the server. + + Returns: + self for method chaining + """ + self._run(self._async.connect()) + return self + + def disconnect(self) -> None: + """Close the connection.""" + self._run(self._async.disconnect()) + + def reset(self, **kwargs: Any) -> StepResult[ObsT]: + """ + Reset the environment. 
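+ + Equivalent to awaiting ``async_client.reset(**kwargs)`` on the background loop.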
+ + Args: + **kwargs: Optional parameters passed to the environment's reset method + + Returns: + StepResult containing initial observation + """ + return self._run(self._async.reset(**kwargs)) + + def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: + """ + Execute an action in the environment. + + Args: + action: The action to execute + **kwargs: Optional parameters + + Returns: + StepResult containing observation, reward, and done status + """ + return self._run(self._async.step(action, **kwargs)) + + def state(self) -> StateT: + """ + Get the current environment state. + + Returns: + State object with environment state information + """ + return self._run(self._async.state()) + + def close(self) -> None: + """Close the connection and clean up resources.""" + try: + self._run(self._async.close()) + finally: + self._stop_loop() + + def __enter__(self) -> "SyncEnvClient[ActT, ObsT, StateT]": + """Enter context manager, establishing connection.""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit context manager, closing connection.""" + self.close() + + def __del__(self) -> None: + """ + Best-effort cleanup for background loop thread. + + Do not rely on this for deterministic cleanup; prefer context-manager + usage or an explicit `close()` call. + """ + try: + self._stop_loop() + except Exception: + pass + + def __getattr__(self, name: str) -> Any: + """ + Delegate unknown attributes to the async client. + + Async methods are wrapped to run on the sync client's dedicated loop. + """ + attr = getattr(self._async, name) + + if inspect.iscoroutinefunction(attr): + cached = self._async_wrapper_cache.get(name) + if cached is not None: + return cached + + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + method = getattr(self._async, name) + return self._run(method(*args, **kwargs)) + + self._async_wrapper_cache[name] = sync_wrapper + return sync_wrapper + + return attr + + # Delegate abstract method implementations to the wrapped client + def _step_payload(self, action: ActT) -> Dict[str, Any]: + """Delegate to async client's _step_payload.""" + return self._async._step_payload(action) + + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: + """Delegate to async client's _parse_result.""" + return self._async._parse_result(payload) + + def _parse_state(self, payload: Dict[str, Any]) -> StateT: + """Delegate to async client's _parse_state.""" + return self._async._parse_state(payload) diff --git a/src/core/tools/__init__.py b/src/core/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0193b2619fc14f14152e3276f54aa0d4aed8ca2c --- /dev/null +++ b/src/core/tools/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Core tools for code execution and other utilities.""" + +from .git_server_client import GitServerClient, RepoInfo + +try: + from .local_python_executor import PyExecutor +except ModuleNotFoundError: + # smolagents is optional for environments that only need Git tooling. 
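+ # Callers should check that PyExecutor is not None before using it.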
+ PyExecutor = None # type: ignore[assignment] + +__all__ = [ + "PyExecutor", + "GitServerClient", + "RepoInfo", +] diff --git a/src/core/tools/git_server_client.py b/src/core/tools/git_server_client.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc3379f6b675178cc7aa94914c31f66bc846aed --- /dev/null +++ b/src/core/tools/git_server_client.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python3 +""" +Git Server Client for connecting to external Gitea instance. + +This module provides a lightweight client for interacting with a shared +Gitea service, optimized for task-based isolation where multiple environment +instances share the same Gitea server but have isolated workspaces. +""" + +import json +import os +import shutil +import subprocess +import time +from dataclasses import dataclass +from pathlib import Path +from urllib.parse import urlparse + + +@dataclass +class RepoInfo: + """Information about a repository.""" + + name: str + url: str + commit: str + clone_url: str + + +class GitServerClient: + """ + Client for connecting to an external Gitea server. + + This client is optimized for task-based isolation where: + - Multiple tasks share the same Gitea instance + - Each task has its own isolated workspace + - Fast reset() via git operations (no server restart) + - Repos are pre-migrated to Gitea once + + Args: + gitea_url: URL of the Gitea server (e.g., "http://gitea:3000") + username: Gitea username for authentication + password: Gitea password for authentication + workspace_dir: Local workspace directory for cloning repos + + Example: + >>> # Connect to shared Gitea (credentials from environment) + >>> import os + >>> client = GitServerClient( + ... gitea_url=os.getenv("GITEA_URL"), + ... username=os.getenv("GITEA_USERNAME"), + ... password=os.getenv("GITEA_PASSWORD") + ... ) + >>> client.wait_for_ready() + >>> # Clone repo to workspace + >>> path = client.clone_to_workspace("my-repo", commit="abc123") + >>> # Fast reset to base state + >>> client.reset_workspace("my-repo", commit="abc123") + """ + + def __init__( + self, + gitea_url: str, + username: str, + password: str, + workspace_dir: str = "/workspace", + ): + """Initialize Git Server Client.""" + self.gitea_url = gitea_url.rstrip("/") + self.username = username + self.password = password + self.workspace_dir = Path(workspace_dir) + self.is_ready = False + + # Parse Gitea URL + parsed = urlparse(self.gitea_url) + self.domain = parsed.hostname or "localhost" + self.port = parsed.port or 3000 + + # Ensure workspace exists + os.makedirs(self.workspace_dir, exist_ok=True) + + # Configure git credentials + self._configure_git() + + def _configure_git(self): + """Configure git credentials for automatic authentication.""" + home_dir = Path.home() + + # Git config + git_config = f"""[user] + name = {self.username} + email = {self.username}@local.env +[init] + defaultBranch = main +[credential] + helper = store +""" + gitconfig_path = home_dir / ".gitconfig" + gitconfig_path.write_text(git_config) + + # Git credentials + git_credentials = ( + f"http://{self.username}:{self.password}@{self.domain}:{self.port}\n" + ) + gitcreds_path = home_dir / ".git-credentials" + gitcreds_path.write_text(git_credentials) + gitcreds_path.chmod(0o600) + + def wait_for_ready(self, timeout: int = 30) -> bool: + """ + Wait for Gitea server to be ready. 
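+ + Polls the server root with curl once per second until it responds or the timeout elapses.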
+ + Args: + timeout: Maximum seconds to wait + + Returns: + True if server is ready, False otherwise + """ + start_time = time.time() + while time.time() - start_time < timeout: + try: + result = subprocess.run( + ["curl", "-sf", f"{self.gitea_url}/"], + capture_output=True, + timeout=5, + ) + if result.returncode == 0: + self.is_ready = True + return True + except subprocess.TimeoutExpired: + pass + except Exception: + pass + + time.sleep(1) + + return False + + def list_repositories(self) -> list[dict[str, str]]: + """ + List all repositories in Gitea. + + Returns: + List of repository information dictionaries + """ + if not self.is_ready: + raise RuntimeError("Gitea server is not ready") + + result = subprocess.run( + [ + "curl", + "-s", + f"{self.gitea_url}/api/v1/user/repos", + "-u", + f"{self.username}:{self.password}", + ], + capture_output=True, + text=True, + ) + + if result.returncode != 0: + return [] + + try: + repos = json.loads(result.stdout) + return [ + { + "name": repo["name"], + "full_name": repo["full_name"], + "clone_url": repo["clone_url"], + "description": repo.get("description", ""), + } + for repo in repos + ] + except (json.JSONDecodeError, KeyError): + return [] + + def clone_to_workspace( + self, repo_name: str, target_dir: str | None = None, commit: str = "main" + ) -> str: + """ + Clone a repository to the workspace at a specific commit. + + This creates a fresh clone optimized for task isolation. + + Args: + repo_name: Name of repository to clone + target_dir: Target directory name (defaults to repo_name) + commit: Commit hash or branch to check out + + Returns: + Path to cloned repository + + Raises: + RuntimeError: If clone fails + """ + if not self.is_ready: + raise RuntimeError("Gitea server is not ready") + + target_dir = target_dir or repo_name + target_path = self.workspace_dir / target_dir + + # Remove existing directory if present + if target_path.exists(): + shutil.rmtree(target_path) + + clone_url = f"{self.gitea_url}/{self.username}/{repo_name}.git" + + # Clone repository + result = subprocess.run( + ["git", "clone", clone_url, str(target_path)], + capture_output=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(f"Clone failed: {result.stderr}") + + # Checkout specific commit + if commit != "main": + result = subprocess.run( + ["git", "checkout", commit], + cwd=str(target_path), + capture_output=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(f"Checkout failed: {result.stderr}") + + return str(target_path) + + def reset_workspace(self, repo_name: str, commit: str = "main") -> bool: + """ + Fast reset of workspace to base state (optimized for task resets). + + This is much faster than re-cloning. It: + 1. Checks out the target commit + 2. Resets to that commit (hard) + 3. 
Cleans untracked files + + Args: + repo_name: Name of repository (directory in workspace) + commit: Commit hash or branch to reset to + + Returns: + True if reset successful + + Raises: + RuntimeError: If reset fails + """ + repo_path = self.workspace_dir / repo_name + + if not repo_path.exists(): + raise RuntimeError(f"Repository not found in workspace: {repo_name}") + + # Fetch latest (in case commit is new) + subprocess.run( + ["git", "fetch", "--all"], + cwd=str(repo_path), + capture_output=True, + ) + + # Checkout and hard reset to commit + result = subprocess.run( + ["git", "checkout", commit], + cwd=str(repo_path), + capture_output=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(f"Checkout failed: {result.stderr}") + + result = subprocess.run( + [ + "git", + "reset", + "--hard", + f"origin/{commit}" if commit != "main" else commit, + ], + cwd=str(repo_path), + capture_output=True, + text=True, + ) + + if result.returncode != 0: + # Try without origin/ prefix + result = subprocess.run( + ["git", "reset", "--hard", commit], + cwd=str(repo_path), + capture_output=True, + text=True, + ) + if result.returncode != 0: + raise RuntimeError(f"Reset failed: {result.stderr}") + + # Clean untracked files and directories + subprocess.run( + ["git", "clean", "-fdx"], + cwd=str(repo_path), + capture_output=True, + ) + + return True + + def execute_git_command( + self, command: str, working_dir: str = "" + ) -> tuple[int, str, str]: + """ + Execute a git command in the workspace. + + Args: + command: Git command to execute (without 'git' prefix) + working_dir: Working directory relative to workspace + + Returns: + Tuple of (exit_code, stdout, stderr) + """ + work_path = ( + self.workspace_dir / working_dir if working_dir else self.workspace_dir + ) + + if not work_path.exists(): + return (1, "", f"Working directory does not exist: {work_path}") + + # Split command safely + cmd_parts = ["git"] + command.split() + + result = subprocess.run( + cmd_parts, + cwd=str(work_path), + capture_output=True, + text=True, + ) + + return (result.returncode, result.stdout, result.stderr) + + def get_current_commit(self, repo_name: str) -> str: + """ + Get current commit hash of a workspace repository. + + Args: + repo_name: Name of repository in workspace + + Returns: + Commit hash + """ + repo_path = self.workspace_dir / repo_name + + if not repo_path.exists(): + raise RuntimeError(f"Repository not found: {repo_name}") + + result = subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=str(repo_path), + capture_output=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(f"Failed to get commit: {result.stderr}") + + return result.stdout.strip() + + def workspace_exists(self, repo_name: str) -> bool: + """Check if a repository exists in workspace.""" + return (self.workspace_dir / repo_name).exists() diff --git a/src/core/tools/local_python_executor.py b/src/core/tools/local_python_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..bb18052b309b3c214bcf0e5c2645416734575fa1 --- /dev/null +++ b/src/core/tools/local_python_executor.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Local Python Executor (enhanced). 
+ +This module provides a safer wrapper around smolagents.LocalPythonExecutor +with improved exception handling and a few helpful tools registered with +the executor to make debugging executed code easier. + +Key improvements: +- Register a few helper utilities via send_tools so user code can use + them for reporting (e.g. `format_exc`). +- More robust extraction of stdout/stderr/exit codes from the executor + result object, tolerant to different versions of smolagents. +- Detailed stderr on unexpected exceptions including full traceback. +- Structured logging for operational visibility. +""" + +from __future__ import annotations + +import json +import logging +import traceback + +from openenv.core.env_server.types import CodeExecResult +from smolagents import LocalPythonExecutor + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class PyExecutor: + """Wrapper around smolagents LocalPythonExecutor. + + The wrapper registers a few non-privileged helper tools to the + LocalPythonExecutor that can be used by the executed code to + format exceptions and to safely stringify results for improved + error reporting. + """ + + def __init__(self, additional_imports: list[str] | None = None): + if additional_imports is None: + additional_imports = [] + + self._executor = LocalPythonExecutor( + additional_authorized_imports=additional_imports + ) + + # Register helpful utilities exposed to the execution environment. + # These are intentionally small, read-only helpers. + tools = { + # Provide a small helper to format the current exception in the + # executed context. This is a *string formatting* helper only. + "format_exc": traceback.format_exc, + # Safe JSON dumps with a fallback for non-serializable objects. + "safe_json_dumps": lambda obj: json.dumps(obj, default=lambda o: repr(o)), + } + + # `send_tools` is the public API on LocalPythonExecutor to make + # helper callables available to the sandboxed runtime. We don't + # provide any builtins that could change the environment. + try: + self._executor.send_tools(tools) + except Exception: + # If the LocalPythonExecutor implementation doesn't support + # send_tools or fails, log and continue — the executor is still usable. + logger.debug( + "LocalPythonExecutor.send_tools failed; continuing without extra tools", + exc_info=True, + ) + + def run(self, code: str) -> CodeExecResult: + """Execute Python code and return a CodeExecResult. + + This method is intentionally defensive: it attempts to extract + meaningful stdout/stderr/exit_code information from a variety of + possible return shapes that different versions of smolagents + may provide. 
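+ + Example (illustrative; exact fields depend on the smolagents version): + >>> executor = PyExecutor() + >>> result = executor.run("print('hi')") + >>> result.exit_code + 0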
+ """ + try: + exec_result = self._executor(code) + + # Default values + stdout_parts: list[str] = [] + stderr_parts: list[str] = [] + exit_code = 0 + + # Extract logs/prints + try: + logs = getattr(exec_result, "logs", None) + if logs: + stdout_parts.append(str(logs)) + except Exception: + logger.debug("Failed to read exec_result.logs", exc_info=True) + + # Extract the result / output value + try: + if hasattr(exec_result, "output"): + out_val = exec_result.output + # If the output is not None, stringify it in a safe way + if out_val is not None: + # Prefer JSON if possible, otherwise repr + try: + stdout_parts.append(json.dumps(out_val)) + except Exception: + stdout_parts.append(repr(out_val)) + except Exception: + logger.debug("Failed to read exec_result.output", exc_info=True) + + # Some runtime implementations may put errors on `error` or `exception` + try: + err = getattr(exec_result, "error", None) + if err: + stderr_parts.append(str(err)) + except Exception: + logger.debug("Failed to read exec_result.error", exc_info=True) + + try: + ex = getattr(exec_result, "exception", None) + if ex: + stderr_parts.append(str(ex)) + except Exception: + logger.debug("Failed to read exec_result.exception", exc_info=True) + + # Determine exit code if provided + try: + if hasattr(exec_result, "exit_code"): + exit_code = ( + int(exec_result.exit_code) + if exec_result.exit_code is not None + else 0 + ) + elif hasattr(exec_result, "success"): + # Some versions use `success` boolean + exit_code = 0 if exec_result.success else 1 + else: + # Fallback: if there were any stderr parts, treat as non-zero + exit_code = 1 if stderr_parts else 0 + except Exception: + logger.debug("Failed to determine exec_result exit code", exc_info=True) + exit_code = 1 if stderr_parts else 0 + + # Compose the final stdout/stderr strings + stdout = "\n".join(part for part in stdout_parts if part is not None) + stderr = "\n".join(part for part in stderr_parts if part is not None) + + return CodeExecResult(stdout=stdout, stderr=stderr, exit_code=exit_code) + + except Exception: + # Any unexpected exception from the LocalPythonExecutor is + # returned with a full traceback to make debugging easier. + tb = traceback.format_exc() + logger.exception("LocalPythonExecutor raised an exception during run") + return CodeExecResult(stdout="", stderr=tb, exit_code=1) diff --git a/src/core/utils.py b/src/core/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e86b3ae9c3e6ec0a19cd6f4868e4e3cdfee66bbc --- /dev/null +++ b/src/core/utils.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Utility functions for OpenEnv core.""" + +import asyncio +import concurrent.futures + + +def run_async_safely(coro): + """ + Run an async coroutine safely from any context. + + This handles the case where we may already be inside an async event loop + (e.g., when called from an async framework). In that case, asyncio.run() + would fail, so we use a ThreadPoolExecutor to run in a separate thread. 
+ + Args: + coro: The coroutine to run + + Returns: + The result of the coroutine + """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in async context - run in a thread pool + with concurrent.futures.ThreadPoolExecutor() as pool: + future = pool.submit(asyncio.run, coro) + return future.result() + else: + # No async context - use asyncio.run() directly + return asyncio.run(coro) + + +def convert_to_ws_url(url: str) -> str: + """ + Convert an HTTP/HTTPS URL to a WS/WSS URL. + + Args: + url: The URL to convert. + + Returns: + The converted WebSocket URL. + """ + ws_url = url.rstrip("/") + if ws_url.startswith("http://"): + ws_url = "ws://" + ws_url[7:] + elif ws_url.startswith("https://"): + ws_url = "wss://" + ws_url[8:] + elif not ws_url.startswith("ws://") and not ws_url.startswith("wss://"): + ws_url = "ws://" + ws_url + return ws_url diff --git a/src/openenv.egg-info/PKG-INFO b/src/openenv.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..2d5c7aaef9a35d5234c6e55122bd0b620fef6273 --- /dev/null +++ b/src/openenv.egg-info/PKG-INFO @@ -0,0 +1,337 @@ +Metadata-Version: 2.4 +Name: openenv +Version: 0.2.0 +Summary: A unified framework for reinforcement learning environments +Requires-Python: >=3.10 +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: fastapi>=0.104.0 +Requires-Dist: pydantic>=2.0.0 +Requires-Dist: uvicorn>=0.24.0 +Requires-Dist: requests>=2.25.0 +Requires-Dist: typer>=0.9.0 +Requires-Dist: rich>=13.0.0 +Requires-Dist: pyyaml>=6.0 +Requires-Dist: huggingface_hub>=0.20.0 +Requires-Dist: openai>=2.7.2 +Requires-Dist: tomli>=2.3.0 +Requires-Dist: tomli-w>=1.2.0 +Requires-Dist: websockets>=15.0.1 +Provides-Extra: core +Requires-Dist: fastapi>=0.104.0; extra == "core" +Requires-Dist: pydantic>=2.0.0; extra == "core" +Requires-Dist: uvicorn>=0.24.0; extra == "core" +Requires-Dist: requests>=2.25.0; extra == "core" +Requires-Dist: websockets>=15.0.1; extra == "core" +Provides-Extra: cli +Requires-Dist: typer>=0.9.0; extra == "cli" +Requires-Dist: rich>=13.0.0; extra == "cli" +Requires-Dist: pyyaml>=6.0; extra == "cli" +Requires-Dist: huggingface_hub>=0.20.0; extra == "cli" +Requires-Dist: openai>=2.7.2; extra == "cli" +Requires-Dist: tomli>=2.3.0; extra == "cli" +Requires-Dist: tomli-w>=1.2.0; extra == "cli" +Provides-Extra: all +Requires-Dist: openenv[core]; extra == "all" +Requires-Dist: openenv[cli]; extra == "all" +Dynamic: license-file + +# OpenEnv: Agentic Execution Environments + +An end-to-end framework for creating, deploying, and using isolated execution environments for agentic RL training, built on simple Gymnasium-style APIs.
+ +[![PyPI](https://img.shields.io/pypi/v/openenv?color=blue)](https://pypi.org/project/openenv/) +[![Discord](https://img.shields.io/badge/Discord-OpenEnv-7289da?style=flat&logo=discord&logoColor=white)](https://discord.gg/YsTYBh6PD9) +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-pytorch/OpenEnv/blob/main/examples/OpenEnv_Tutorial.ipynb) +[![Docs](https://img.shields.io/badge/Docs-Explore-blue?logo=readthedocs&logoColor=white)](https://meta-pytorch.org/OpenEnv/) + +--- + +**🚀 Featured Example:** Train LLMs to play BlackJack using [torchforge](https://github.com/meta-pytorch/torchforge) (PyTorch's agentic RL framework): [`examples/grpo_blackjack/`](examples/grpo_blackjack/) + +## OpenEnv on partner platforms + +- [Lightning AI Studio](https://lightning.ai/environments?section=featured) +- [TRL example](https://huggingface.co/docs/trl/main/en/openenv) +- [Unsloth Google Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/OpenEnv_gpt_oss_(20B)_Reinforcement_Learning_2048_Game.ipynb) +- [ART example](https://art.openpipe.ai/integrations/openenv-integration) +- [Oumi example](https://github.com/oumi-ai/oumi/blob/main/notebooks/Oumi%20-%20OpenEnv%20GRPO%20with%20trl.ipynb) + +## Overview + +OpenEnv provides a standard for interacting with agentic execution environments through simple Gymnasium-style APIs: `step()`, `reset()`, and `state()`. RL training loops drive any OpenEnv environment through these same three calls. + +Beyond serving researchers and RL framework authors, OpenEnv gives environment creators tools to build richer environments and make them available over familiar protocols like HTTP, packaged with standard technologies like Docker. Environments built with the framework are isolated, secure, and easy to deploy and use. + +The OpenEnv CLI (`openenv`) provides commands to initialize new environments and deploy them to Hugging Face Spaces. + +> ⚠️ **Early Development Warning** OpenEnv is currently in an experimental +> stage. You should expect bugs, incomplete features, and APIs that may change +> in future versions. The project welcomes bugfixes, but to make sure things are +> well coordinated you should discuss any significant change before starting the +> work. It's recommended that you signal your intention to contribute in the +> issue tracker, either by filing a new issue or by claiming an existing one. + +### RFCs + +Below is a list of active and historical RFCs for OpenEnv. RFCs are proposals for major changes or features. Please review and contribute!
+ +- [RFC 001: Baseline API and Interface Specifications](https://github.com/meta-pytorch/OpenEnv/pull/26) + +## Architecture + +### Component Overview + +``` +┌─────────────────────────────────────────────────────────┐ +│ Client Application │ +│ ┌────────────────┐ ┌──────────────────┐ │ +│ │ EchoEnv │ │ CodingEnv │ │ +│ │ (HTTPEnvClient)│ │ (HTTPEnvClient) │ │ +│ └────────┬───────┘ └────────┬─────────┘ │ +└───────────┼───────────────────────────────┼─────────────┘ + │ HTTP │ HTTP + │ (reset, step, state) │ +┌───────────▼───────────────────────────────▼─────────────┐ +│ Docker Containers (Isolated) │ +│ ┌──────────────────────┐ ┌──────────────────────┐ │ +│ │ FastAPI Server │ │ FastAPI Server │ │ +│ │ EchoEnvironment │ │ PythonCodeActEnv │ │ +│ │ (Environment base) │ │ (Environment base) │ │ +│ └──────────────────────┘ └──────────────────────┘ │ +└─────────────────────────────────────────────────────────┘ +``` + +### Core Components + +#### 1. Web Interface + +OpenEnv includes a built-in web interface for interactive environment exploration and debugging. The web interface provides: + +- **Two-Pane Layout**: HumanAgent interaction on the left, state observation on the right +- **Real-time Updates**: WebSocket-based live updates without page refresh +- **Dynamic Forms**: Automatically generated action forms based on environment Action types +- **Action History**: Complete log of all actions taken and their results + +The web interface is **conditionally enabled** based on environment variables: + +- **Local Development**: Disabled by default for lightweight development +- **Manual Override**: Enable with `ENABLE_WEB_INTERFACE=true` + +To use the web interface: + +```python +from openenv.core.env_server import create_web_interface_app +from your_env.models import YourAction, YourObservation +from your_env.server.your_environment import YourEnvironment + +env = YourEnvironment() +app = create_web_interface_app(env, YourAction, YourObservation) +``` + +When enabled, open `http://localhost:8000/web` in your browser to interact with the environment. + +#### 2. Environment (Server-Side) +Base class for implementing environment logic: +- **`reset()`**: Initialize a new episode, returns initial `Observation` +- **`step(action)`**: Execute an `Action`, returns resulting `Observation` +- **`state()`**: Access episode metadata (`State` with episode_id, step_count, etc.) + +#### 3. HTTPEnvClient (Client-Side) +Base class for HTTP communication: +- Handles HTTP requests to environment server +- Contains a utility to spin up a docker container locally for the corresponding environment +- Type-safe action/observation parsing + +#### 4. Container Providers +Manage container deployment: +- `LocalDockerProvider`: Run containers on local Docker daemon +- `KubernetesProvider`: Deploy to K8s clusters (future) + +#### 5. 
Models +Type-safe data structures: +- `Action`: Base class for environment actions +- `Observation`: Base class for environment observations +- `State`: Episode state tracking +- `StepResult`: Combines observation, reward, done flag + +## Project Structure + +### For Environment Creators + +Use the CLI to quickly scaffold a new environment: + +```bash +openenv init my_env +``` + +This creates the following structure: + +``` +my_env/ +├── .dockerignore # Docker build exclusions +├── __init__.py # Export YourAction, YourObservation, YourEnv +├── models.py # Define Action, Observation, State dataclasses +├── client.py # Implement YourEnv(HTTPEnvClient) +├── README.md # Document your environment +├── openenv.yaml # Environment manifest +├── pyproject.toml # Dependencies and package configuration +├── outputs/ # Runtime outputs (logs, evals) - gitignored +│ ├── logs/ +│ └── evals/ +└── server/ + ├── your_environment.py # Implement YourEnvironment(Environment) + ├── app.py # Create FastAPI app + ├── requirements.txt # Dependencies for Docker (can be generated) + └── Dockerfile # Define container image +``` + +#### Dependency Management + +OpenEnv uses `pyproject.toml` as the primary dependency specification: + +- **Environment-level `pyproject.toml`**: Each environment defines its own dependencies +- **Root-level `pyproject.toml`**: Contains shared core dependencies (fastapi, pydantic, uvicorn) +- **Server `requirements.txt`**: Can be auto-generated from `pyproject.toml` for Docker builds + +**Development Workflow:** + +```bash +# Install environment in editable mode +cd my_env +pip install -e . + +# Or using uv (faster) +uv pip install -e . + +# Run server locally without Docker +uv run server --host 0.0.0.0 --port 8000 +``` + +**Benefits:** +- ✅ **Client-side extensions**: Modify client classes locally without repo changes +- ✅ **Better dependency management**: Clear separation between environments +- ✅ **Flexible workflows**: Use pip, uv, or Docker for different scenarios +- ✅ **CI/CD ready**: Automated dependency generation and validation + +See [`envs/README.md`](envs/README.md) for a complete guide on building environments. + +### For Environment Users + +To use an environment: +1. Import from `envs.your_env`: `from envs.echo_env import EchoAction, EchoEnv` +2. Create client: `client = EchoEnv.from_docker_image("echo-env:latest")` +3. Interact: `client.reset()`, `client.step(action)`, `client.state()` +4. Cleanup: `client.close()` + +See the example scripts in the `examples/` directory. + +## CLI Commands + +The OpenEnv CLI provides commands to manage environments: + +- **`openenv init <env_name>`** - Initialize a new environment from template +- **`openenv push [--repo-id <repo_id>] [--private]`** - Deploy environment to Hugging Face Spaces + +### Quick Start + +```bash +# Create a new environment +openenv init my_game_env + +# Deploy to Hugging Face (will prompt for login if needed) +cd my_game_env +openenv push +``` + +For detailed options: `openenv init --help` and `openenv push --help`. + +## Design Principles + +1. **Separation of Concerns**: Clear client-server boundaries +2. **Type Safety**: Strongly-typed actions, observations, and state +3. **Container Isolation**: Each environment runs in its own container +4.
**Simple APIs**: Minimal, intuitive interfaces + +## Quick Start + +### Using the Echo Environment (Example) + +```python +from envs.echo_env import EchoAction, EchoEnv + +# Automatically start container and connect +client = EchoEnv.from_docker_image("echo-env:latest") + +# Reset the environment +result = client.reset() +print(result.observation.echoed_message) # "Echo environment ready!" + +# Send messages +result = client.step(EchoAction(message="Hello, World!")) +print(result.observation.echoed_message) # "Hello, World!" +print(result.reward) # 1.3 (based on message length) + +# Cleanup +client.close() # Stops and removes container +``` + +## Requirements + +- Python 3.11+ +- Docker Desktop or Docker Engine +- FastAPI >= 0.104.0 +- Uvicorn >= 0.24.0 +- Requests >= 2.25.0 +- smolagents (for coding environment) + +## Supported RL Tools +The goal of this project is to support a broad set of open and closed tools to help standardize tooling across the agentic RL community. If you have a project that supports OpenEnv environments, please put up a PR to add your tool name along with a link to your documentation. + +### torchforge +See GRPO BlackJack training example: [`examples/grpo_blackjack/`](examples/grpo_blackjack/) + +### TRL +See the [TRL example](https://huggingface.co/docs/trl/main/en/openenv) on how to integrate OpenEnv environments with GRPO training. + +### Unsloth +See the 2048 game example based on gpt-oss: [Colab notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/OpenEnv_gpt_oss_(20B)_Reinforcement_Learning_2048_Game.ipynb) + +### SkyRL +See the [SkyRL example](https://skyrl.readthedocs.io/en/latest/examples/openenv.html) on how to train on OpenEnv environments with SkyRL. + +### ART +See the [ART example](https://art.openpipe.ai/integrations/openenv-integration) on how OpenEnv environments can be used to train models with ART. + +### Oumi +See the [Oumi example](https://github.com/oumi-ai/oumi/blob/main/notebooks/Oumi%20-%20OpenEnv%20GRPO%20with%20trl.ipynb) on how OpenEnv environments can be used to train models with Oumi. + +## Example Environments + +### Echo Environment +A simple environment that echoes back messages with metadata. Perfect for: +- Testing the HTTP server infrastructure +- Learning the framework basics +- Verifying container deployment + +See: [`envs/echo_env/README.md`](envs/echo_env/README.md) + +### Coding Environment +Executes arbitrary Python code in a sandboxed environment. Features: +- Safe code execution using smolagents +- Capture stdout, stderr, and exit codes +- Persistent execution context within episodes +- Error handling with detailed messages + +See: [`envs/coding_env/README.md`](envs/coding_env/README.md) + +## Community Support & Acknowledgments +This is an open and community-centric project. If you would like to add your name here, please put up a pull request and tag @jspisak for review. Ty!! + +Supporters include: Meta-PyTorch, Hugging Face, [Patronus AI](https://patronus.ai), [Surge AI](https://surgehq.ai), [LastMile AI](https://www.lastmileai.dev), Unsloth AI, Reflection AI, vLLM, SkyRL (UC-Berkeley), LightningAI, Axolotl AI, Stanford Scaling Intelligence Lab, Mithril, [OpenMined](https://openmined.org/), [Fleet AI](https://fleetai.com), [Halluminate](https://halluminate.ai/), [Turing](https://www.turing.com/) .. + +And we'd also like to acknowledge the team at Farama Foundation as the OpenEnv API was heavily inspired by the work you all have done on Gymnasium. Cheers!
+ +## License + +BSD 3-Clause License (see [LICENSE](./LICENSE) file) diff --git a/src/openenv.egg-info/SOURCES.txt b/src/openenv.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..195f9366d2dad0e6330359cc55a84859679f5e99 --- /dev/null +++ b/src/openenv.egg-info/SOURCES.txt @@ -0,0 +1,142 @@ +LICENSE +README.md +pyproject.toml +envs/atari_env/__init__.py +envs/atari_env/client.py +envs/atari_env/models.py +envs/atari_env/server/__init__.py +envs/atari_env/server/app.py +envs/atari_env/server/atari_environment.py +envs/browsergym_env/__init__.py +envs/browsergym_env/client.py +envs/browsergym_env/models.py +envs/browsergym_env/server/__init__.py +envs/browsergym_env/server/app.py +envs/browsergym_env/server/browsergym_environment.py +envs/chat_env/__init__.py +envs/chat_env/client.py +envs/chat_env/models.py +envs/chat_env/server/__init__.py +envs/chat_env/server/app.py +envs/chat_env/server/chat_environment.py +envs/chat_env/server/test_chat_env.py +envs/coding_env/__init__.py +envs/coding_env/client.py +envs/coding_env/models.py +envs/coding_env/server/__init__.py +envs/coding_env/server/app.py +envs/coding_env/server/python_codeact_env.py +envs/coding_env/server/python_executor.py +envs/coding_env/server/transforms.py +envs/connect4_env/__init__.py +envs/connect4_env/client.py +envs/connect4_env/models.py +envs/connect4_env/server/__init__.py +envs/connect4_env/server/app.py +envs/connect4_env/server/connect4_environment.py +envs/dipg_safety_env/__init__.py +envs/dipg_safety_env/client.py +envs/dipg_safety_env/models.py +envs/dipg_safety_env/server/__init__.py +envs/dipg_safety_env/server/app.py +envs/dipg_safety_env/server/dipg_environment.py +envs/echo_env/__init__.py +envs/echo_env/client.py +envs/echo_env/models.py +envs/echo_env/build/lib/server/__init__.py +envs/echo_env/build/lib/server/app.py +envs/echo_env/build/lib/server/echo_environment.py +envs/echo_env/server/__init__.py +envs/echo_env/server/app.py +envs/echo_env/server/echo_environment.py +envs/finrl_env/__init__.py +envs/finrl_env/client.py +envs/finrl_env/models.py +envs/finrl_env/server/__init__.py +envs/finrl_env/server/app.py +envs/finrl_env/server/finrl_environment.py +envs/git_env/__init__.py +envs/git_env/client.py +envs/git_env/models.py +envs/git_env/server/__init__.py +envs/git_env/server/app.py +envs/git_env/server/git_task_environment.py +envs/openspiel_env/__init__.py +envs/openspiel_env/client.py +envs/openspiel_env/models.py +envs/openspiel_env/server/__init__.py +envs/openspiel_env/server/app.py +envs/openspiel_env/server/openspiel_environment.py +envs/openspiel_env/server/opponent_policies.py +envs/play/build/lib/server/__init__.py +envs/play/build/lib/server/app.py +envs/play/build/lib/server/play_environment.py +envs/sumo_rl_env/__init__.py +envs/sumo_rl_env/client.py +envs/sumo_rl_env/models.py +envs/sumo_rl_env/server/__init__.py +envs/sumo_rl_env/server/app.py +envs/sumo_rl_env/server/sumo_environment.py +envs/textarena_env/__init__.py +envs/textarena_env/client.py +envs/textarena_env/models.py +envs/textarena_env/rewards.py +envs/textarena_env/build/lib/server/__init__.py +envs/textarena_env/build/lib/server/app.py +envs/textarena_env/build/lib/server/environment.py +envs/textarena_env/server/__init__.py +envs/textarena_env/server/app.py +envs/textarena_env/server/environment.py +src/openenv/__init__.py +src/openenv.egg-info/PKG-INFO +src/openenv.egg-info/SOURCES.txt +src/openenv.egg-info/dependency_links.txt +src/openenv.egg-info/entry_points.txt 
+src/openenv.egg-info/requires.txt +src/openenv.egg-info/top_level.txt +src/openenv/cli/__init__.py +src/openenv/cli/__main__.py +src/openenv/cli/_cli_utils.py +src/openenv/cli/_validation.py +src/openenv/cli/commands/__init__.py +src/openenv/cli/commands/build.py +src/openenv/cli/commands/init.py +src/openenv/cli/commands/push.py +src/openenv/cli/commands/serve.py +src/openenv/cli/commands/validate.py +src/openenv/cli/templates/__init__.py +src/openenv/cli/templates/__pycache__/__init__.cpython-311.pyc +src/openenv/cli/templates/__pycache__/__init__.cpython-313.pyc +src/openenv/cli/templates/openenv_env/README.md +src/openenv/cli/templates/openenv_env/__init__.py +src/openenv/cli/templates/openenv_env/client.py +src/openenv/cli/templates/openenv_env/models.py +src/openenv/cli/templates/openenv_env/openenv.yaml +src/openenv/cli/templates/openenv_env/pyproject.toml +src/openenv/cli/templates/openenv_env/server/Dockerfile +src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py +src/openenv/cli/templates/openenv_env/server/__init__.py +src/openenv/cli/templates/openenv_env/server/app.py +src/openenv/cli/templates/openenv_env/server/requirements.txt +src/openenv/core/__init__.py +src/openenv/core/client_types.py +src/openenv/core/env_client.py +src/openenv/core/utils.py +src/openenv/core/containers/__init__.py +src/openenv/core/containers/test_local_docker_provider.py +src/openenv/core/containers/runtime/__init__.py +src/openenv/core/containers/runtime/providers.py +src/openenv/core/containers/runtime/uv_provider.py +src/openenv/core/env_server/__init__.py +src/openenv/core/env_server/base_transforms.py +src/openenv/core/env_server/exceptions.py +src/openenv/core/env_server/http_server.py +src/openenv/core/env_server/interfaces.py +src/openenv/core/env_server/route_config.py +src/openenv/core/env_server/serialization.py +src/openenv/core/env_server/types.py +src/openenv/core/env_server/web_interface.py +src/openenv/core/tools/__init__.py +src/openenv/core/tools/git_server_client.py +src/openenv/core/tools/local_python_executor.py +src/openenv_core/__init__.py \ No newline at end of file diff --git a/src/openenv.egg-info/dependency_links.txt b/src/openenv.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/openenv.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/openenv.egg-info/entry_points.txt b/src/openenv.egg-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..a771c213d08379ccc0be28741bcfaccc3f64193a --- /dev/null +++ b/src/openenv.egg-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +openenv = openenv.cli.__main__:main diff --git a/src/openenv.egg-info/requires.txt b/src/openenv.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..9ce491cc8741582ffd01406446576aa2675847fe --- /dev/null +++ b/src/openenv.egg-info/requires.txt @@ -0,0 +1,32 @@ +fastapi>=0.104.0 +pydantic>=2.0.0 +uvicorn>=0.24.0 +requests>=2.25.0 +typer>=0.9.0 +rich>=13.0.0 +pyyaml>=6.0 +huggingface_hub>=0.20.0 +openai>=2.7.2 +tomli>=2.3.0 +tomli-w>=1.2.0 +websockets>=15.0.1 + +[all] +openenv[core] +openenv[cli] + +[cli] +typer>=0.9.0 +rich>=13.0.0 +pyyaml>=6.0 +huggingface_hub>=0.20.0 +openai>=2.7.2 +tomli>=2.3.0 +tomli-w>=1.2.0 + +[core] +fastapi>=0.104.0 +pydantic>=2.0.0 +uvicorn>=0.24.0 +requests>=2.25.0 +websockets>=15.0.1 diff --git a/src/openenv.egg-info/top_level.txt 
b/src/openenv.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..e540ace7b2bb39dfb2445c526e67c391e8249487 --- /dev/null +++ b/src/openenv.egg-info/top_level.txt @@ -0,0 +1,2 @@ +openenv +openenv_core diff --git a/src/openenv/__init__.py b/src/openenv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cabe2abc6a70dacafe04f0583b27b2552bab1e47 --- /dev/null +++ b/src/openenv/__init__.py @@ -0,0 +1,54 @@ +"""Unified OpenEnv package bundling the CLI and core runtime.""" + +from __future__ import annotations + +from importlib import import_module, metadata + +__all__ = [ + "core", + "cli", + "AutoEnv", + "AutoAction", + "GenericEnvClient", + "GenericAction", + "SyncEnvClient", +] + +try: + __version__ = metadata.version("openenv") # type: ignore[arg-type] +except metadata.PackageNotFoundError: # pragma: no cover - local dev + __version__ = "0.0.0" + + +_LAZY_MODULES = { + "core": ".core", + "cli": ".cli", +} + +_LAZY_ATTRS = { + "AutoEnv": (".auto", "AutoEnv"), + "AutoAction": (".auto", "AutoAction"), + "GenericEnvClient": (".core", "GenericEnvClient"), + "GenericAction": (".core", "GenericAction"), + "SyncEnvClient": (".core", "SyncEnvClient"), +} + + +def __getattr__(name: str): + if name in _LAZY_MODULES: + module = import_module(_LAZY_MODULES[name], __name__) + globals()[name] = module + return module + + if name in _LAZY_ATTRS: + module_path, attr_name = _LAZY_ATTRS[name] + module = import_module(module_path, __name__) + value = getattr(module, attr_name) + globals()[name] = value + return value + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__() -> list[str]: + return sorted(set(globals().keys()) | set(__all__)) diff --git a/src/openenv/auto/__init__.py b/src/openenv/auto/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a154570d50d01ed430ea221c1896c06cbc1b7f1c --- /dev/null +++ b/src/openenv/auto/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +OpenEnv Auto Module +=================== + +Provides HuggingFace-style auto-discovery API for OpenEnv environments. + +This module enables automatic environment and action class loading without +manual imports: + + >>> from openenv import AutoEnv, AutoAction + >>> + >>> # Load environment from installed package or HuggingFace Hub + >>> env = AutoEnv.from_name("coding-env") + >>> + >>> # Get action class + >>> CodeAction = AutoAction.from_name("coding") + >>> action = CodeAction(code="print('Hello!')") + +Classes: + AutoEnv: Automatic environment client selection and instantiation + AutoAction: Automatic action class selection + +The auto-discovery system works by: +1. Discovering installed openenv-* packages via importlib.metadata +2. Loading environment manifests (openenv.yaml) from package resources +3. Supporting HuggingFace Hub repositories for remote environments +4. 
Caching discovery results for performance +""" + +from .auto_action import AutoAction +from .auto_env import AutoEnv + +__all__ = ["AutoEnv", "AutoAction"] diff --git a/src/openenv/auto/_discovery.py b/src/openenv/auto/_discovery.py new file mode 100644 index 0000000000000000000000000000000000000000..9dda19f4a393a38f74ac4e2508d5edc0e19f0990 --- /dev/null +++ b/src/openenv/auto/_discovery.py @@ -0,0 +1,584 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Environment Auto-Discovery System +================================== + +This module provides automatic discovery of OpenEnv environments by: +1. Discovering installed openenv-* packages using importlib.metadata +2. Loading manifests (openenv.yaml) from package resources +3. Caching results for performance +4. Supporting HuggingFace Hub downloads + +This enables AutoEnv to work without coupling to src/envs/ directory. +""" + +import importlib +import importlib.metadata +import importlib.resources +import json +import logging +import re +import tempfile +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Optional, Type + +import yaml + +logger = logging.getLogger(__name__) + + +@dataclass +class EnvironmentInfo: + """ + Rich information about a discovered environment. + + Attributes: + env_key: Environment key (e.g., "echo", "coding") + name: Full environment name (e.g., "echo_env") + package_name: Package name (e.g., "openenv-echo_env") + version: Version string + description: Human-readable description + client_module_path: Full module path to client (e.g., "echo_env.client") + client_class_name: Client class name (e.g., "EchoEnv") + action_class_name: Action class name (e.g., "EchoAction") + observation_class_name: Observation class name (e.g., "EchoObservation") + default_image: Default Docker image name (e.g., "echo-env:latest") + spec_version: OpenEnv spec version (from openenv.yaml) + manifest: Original manifest data + """ + + env_key: str + name: str + package_name: str + version: str + description: str + client_module_path: str + client_class_name: str + action_class_name: str + observation_class_name: str + default_image: str + spec_version: Optional[int] = None + manifest: Optional[Dict[str, Any]] = None + + def get_client_class(self) -> Type: + """ + Dynamically import and return the client class. + + Returns: + Client class (e.g., EchoEnv) + + Raises: + ImportError: If module or class cannot be imported + """ + try: + module = importlib.import_module(self.client_module_path) + return getattr(module, self.client_class_name) + except ImportError as e: + raise ImportError( + f"Failed to import {self.client_class_name} from {self.client_module_path}: {e}\n" + f"Make sure the package '{self.package_name}' is installed: " + f"pip install {self.package_name}" + ) from e + except AttributeError as e: + raise ImportError( + f"Class {self.client_class_name} not found in {self.client_module_path}: {e}" + ) from e + + def get_action_class(self) -> Type: + """ + Dynamically import and return the action class. 
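+
+        Illustrative sketch (assumes an installed ``openenv-echo_env``
+        package whose ``echo_env.client`` module defines ``EchoAction``):
+
+            >>> info = get_discovery().get_environment("echo")
+            >>> EchoAction = info.get_action_class()
+            >>> EchoAction.__name__
+            'EchoAction'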
+ + Returns: + Action class (e.g., EchoAction) + + Raises: + ImportError: If module or class cannot be imported + """ + try: + module = importlib.import_module(self.client_module_path) + return getattr(module, self.action_class_name) + except ImportError as e: + raise ImportError( + f"Failed to import {self.action_class_name} from {self.client_module_path}: {e}\n" + f"Make sure the package '{self.package_name}' is installed: " + f"pip install {self.package_name}" + ) from e + except AttributeError as e: + raise ImportError( + f"Class {self.action_class_name} not found in {self.client_module_path}: {e}" + ) from e + + def get_observation_class(self) -> Type: + """ + Dynamically import and return the observation class. + + Returns: + Observation class (e.g., EchoObservation) + + Raises: + ImportError: If module or class cannot be imported + """ + try: + module = importlib.import_module(self.client_module_path) + return getattr(module, self.observation_class_name) + except ImportError as e: + raise ImportError( + f"Failed to import {self.observation_class_name} from {self.client_module_path}: {e}\n" + f"Make sure the package '{self.package_name}' is installed: " + f"pip install {self.package_name}" + ) from e + except AttributeError as e: + raise ImportError( + f"Class {self.observation_class_name} not found in {self.client_module_path}: {e}" + ) from e + + +def _normalize_env_name(name: str) -> str: + """ + Normalize environment name to standard format. + + Args: + name: Input name (e.g., "echo", "echo-env", "echo_env") + + Returns: + Normalized name (e.g., "echo_env") + + Examples: + >>> _normalize_env_name("echo") + 'echo_env' + >>> _normalize_env_name("echo-env") + 'echo_env' + >>> _normalize_env_name("echo_env") + 'echo_env' + """ + # Remove common suffixes + name = re.sub(r"[-_]env$", "", name) + # Convert hyphens to underscores + name = name.replace("-", "_") + # Add _env suffix if not present + if not name.endswith("_env"): + name = f"{name}_env" + return name + + +def _is_hub_url(name: str) -> bool: + """ + Check if name is a HuggingFace Hub URL or repo ID. + + Args: + name: Input name + + Returns: + True if it looks like a Hub URL + + Examples: + >>> _is_hub_url("meta-pytorch/echo_env") + True + >>> _is_hub_url("https://huggingface.co/meta-pytorch/echo_env") + True + >>> _is_hub_url("echo") + False + """ + # Contains org/repo pattern or huggingface.co domain + return "/" in name or "huggingface.co" in name + + +def _infer_class_name(env_name: str, class_type: str) -> str: + """ + Infer class name from environment name using simple conventions. + + Args: + env_name: Environment name (e.g., "echo_env") + class_type: Type of class ("client", "action", "observation") + + Returns: + Inferred class name + + Examples: + >>> _infer_class_name("echo_env", "client") + 'EchoEnv' + >>> _infer_class_name("echo_env", "action") + 'EchoAction' + """ + # Remove _env suffix for base name + base_name = env_name.replace("_env", "") + + # Convert to PascalCase + pascal_name = "".join(word.capitalize() for word in base_name.split("_")) + + # Add suffix based on type + if class_type == "client": + return f"{pascal_name}Env" + elif class_type == "action": + return f"{pascal_name}Action" + elif class_type == "observation": + return f"{pascal_name}Observation" + else: + raise ValueError(f"Unknown class type: {class_type}") + + +def _load_manifest_from_package( + package_name: str, module_name: str +) -> Optional[Dict[str, Any]]: + """ + Load openenv.yaml manifest from an installed package. 
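+
+    Illustrative manifest shape (assumed keys, matching what
+    ``_create_env_info_from_package`` reads below; a real ``openenv.yaml``
+    may carry more fields)::
+
+        name: echo_env
+        description: Echo environment
+        spec_version: 1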
+ + Args: + package_name: Package name (e.g., "openenv-echo_env") + module_name: Module name (e.g., "echo_env") + + Returns: + Parsed manifest dictionary, or None if not found + + """ + try: + # Try to read openenv.yaml from package + if hasattr(importlib.resources, "files"): + # Python 3.9+ + package_files = importlib.resources.files(module_name) + if (package_files / "openenv.yaml").is_file(): + manifest_text = (package_files / "openenv.yaml").read_text() + return yaml.safe_load(manifest_text) + else: + # Python 3.7-3.8 fallback + with importlib.resources.open_text(module_name, "openenv.yaml") as f: + return yaml.safe_load(f) + except (FileNotFoundError, ModuleNotFoundError, AttributeError): + logger.debug(f"No openenv.yaml found in {module_name}") + return None + except Exception as e: + logger.warning(f"Failed to load openenv.yaml from {module_name}: {e}") + return None + + +def _create_env_info_from_package( + package_name: str, module_name: str, version: str +) -> Optional[EnvironmentInfo]: + """ + Create EnvironmentInfo from an installed package. + + Args: + package_name: Package name (e.g., "openenv-echo_env") + module_name: Module name (e.g., "echo_env") + version: Package version + + Returns: + EnvironmentInfo instance, or None if invalid + """ + # Load manifest + manifest = _load_manifest_from_package(package_name, module_name) + + # Get environment name + if manifest and "name" in manifest: + env_name = manifest["name"] + else: + # Infer from module name + env_name = module_name + + # Normalize to ensure _env suffix + if not env_name.endswith("_env"): + env_name = f"{env_name}_env" + + # Determine env_key (e.g., "echo_env" → "echo") + env_key = env_name.replace("_env", "") if env_name.endswith("_env") else env_name + + # Get description + description = ( + manifest.get("description", f"{env_name} environment") + if manifest + else f"{env_name} environment" + ) + + # Get spec version + spec_version = manifest.get("spec_version") if manifest else None + + # Determine class names + # Check if manifest has custom class names (custom format) + if manifest and "action" in manifest and "observation" in manifest: + # Custom format (like coding_env) + client_class_name = _infer_class_name(env_name, "client") + action_class_name = manifest.get( + "action", _infer_class_name(env_name, "action") + ) + observation_class_name = manifest.get( + "observation", _infer_class_name(env_name, "observation") + ) + else: + # Use conventions + client_class_name = _infer_class_name(env_name, "client") + action_class_name = _infer_class_name(env_name, "action") + observation_class_name = _infer_class_name(env_name, "observation") + + # Module path is just module_name.client + client_module_path = f"{module_name}.client" + + # Determine default Docker image name + image_name = env_name.replace("_", "-") + default_image = f"{image_name}:latest" + + return EnvironmentInfo( + env_key=env_key, + name=env_name, + package_name=package_name, + version=version, + description=description, + client_module_path=client_module_path, + client_class_name=client_class_name, + action_class_name=action_class_name, + observation_class_name=observation_class_name, + default_image=default_image, + spec_version=spec_version, + manifest=manifest, + ) + + +class EnvironmentDiscovery: + """ + Auto-discovery system for OpenEnv environments using installed packages. + + This class discovers installed openenv-* packages and loads their metadata. 
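+
+    Example (illustrative; assumes at least one ``openenv-*`` environment
+    package is installed):
+
+        >>> discovery = EnvironmentDiscovery()
+        >>> envs = discovery.discover()
+        >>> "echo" in envs  # True when openenv-echo_env is installed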
+ """ + + def __init__(self): + """Initialize discovery system.""" + self._cache: Optional[Dict[str, EnvironmentInfo]] = None + self._cache_file = Path(tempfile.gettempdir()) / "openenv_discovery_cache.json" + + def _discover_installed_packages(self) -> Dict[str, EnvironmentInfo]: + """ + Discover all installed openenv-* packages. + + Returns: + Dictionary mapping env_key to EnvironmentInfo + """ + environments = {} + + # Invalidate import caches to ensure we pick up newly installed packages + importlib.invalidate_caches() + + # Get all installed packages + try: + distributions = importlib.metadata.distributions() + except Exception as e: + logger.warning(f"Failed to get installed packages: {e}") + return environments + + # Filter for openenv-* packages (exclude openenv-core) + for dist in distributions: + package_name = dist.metadata["Name"] + + if not package_name.startswith("openenv-"): + continue + + if package_name == "openenv-core": + continue + + # Get module name (e.g., "openenv-echo_env" → "echo_env") + module_name = package_name.replace("openenv-", "").replace("-", "_") + + # Get version + version = dist.version + + try: + # Create environment info + env_info = _create_env_info_from_package( + package_name, module_name, version + ) + + if env_info: + environments[env_info.env_key] = env_info + logger.debug( + f"Discovered environment: {env_info.env_key} ({package_name})" + ) + + except Exception as e: + logger.warning(f"Failed to load environment from {package_name}: {e}") + continue + + return environments + + def _load_cache(self) -> Optional[Dict[str, EnvironmentInfo]]: + """ + Load cached discovery results. + + Returns: + Dictionary of env_key -> EnvironmentInfo, or None if cache invalid + """ + if not self._cache_file.exists(): + return None + + try: + with open(self._cache_file, "r") as f: + cache_data = json.load(f) + + # Reconstruct EnvironmentInfo objects + cache = {} + for env_key, env_data in cache_data.items(): + cache[env_key] = EnvironmentInfo(**env_data) + + return cache + except Exception as e: + logger.warning(f"Failed to load discovery cache: {e}") + return None + + def _save_cache(self, environments: Dict[str, EnvironmentInfo]) -> None: + """ + Save discovery results to cache. + + Args: + environments: Dictionary of env_key -> EnvironmentInfo + """ + try: + cache_data = {} + for env_key, env_info in environments.items(): + cache_data[env_key] = asdict(env_info) + + with open(self._cache_file, "w") as f: + json.dump(cache_data, f, indent=2) + + except Exception as e: + logger.warning(f"Failed to save discovery cache: {e}") + + def discover(self, use_cache: bool = True) -> Dict[str, EnvironmentInfo]: + """ + Discover all installed OpenEnv environments. 
+ + Args: + use_cache: If True, try to load from cache first + + Returns: + Dictionary mapping env_key to EnvironmentInfo + + Examples: + >>> discovery = EnvironmentDiscovery() + >>> envs = discovery.discover() + >>> print(envs.keys()) + dict_keys(['echo', 'coding', ...]) + """ + # Try to load from memory cache first + if use_cache and self._cache is not None: + return self._cache + + # Try to load from file cache + if use_cache: + cached = self._load_cache() + if cached is not None: + self._cache = cached + return self._cache + + # Discover from installed packages + environments = self._discover_installed_packages() + + # Save to cache + self._save_cache(environments) + self._cache = environments + + return environments + + def get_environment(self, env_key: str) -> Optional[EnvironmentInfo]: + """ + Get information about a specific environment. + + Args: + env_key: Environment key (e.g., "echo", "coding") + + Returns: + EnvironmentInfo if found, None otherwise + + Examples: + >>> discovery = EnvironmentDiscovery() + >>> env = discovery.get_environment("echo") + >>> print(env.client_class_name) + 'EchoEnv' + """ + environments = self.discover() + return environments.get(env_key) + + def get_environment_by_name(self, name: str) -> Optional[EnvironmentInfo]: + """ + Get environment info by flexible name matching. + + Args: + name: Environment name (e.g., "echo", "echo-env", "echo_env") + + Returns: + EnvironmentInfo if found, None otherwise + """ + # Normalize name to env_key + normalized = _normalize_env_name(name) + env_key = normalized.replace("_env", "") + + return self.get_environment(env_key) + + def list_environments(self) -> None: + """ + Print a formatted list of all discovered environments. + + Examples: + >>> discovery = EnvironmentDiscovery() + >>> discovery.list_environments() + Available OpenEnv Environments: + ---------------------------------------------------------------------- + echo : Echo Environment (v0.1.0) - openenv-echo_env + coding : Coding Environment (v0.1.0) - openenv-coding_env + ... + """ + environments = self.discover() + + print("Available OpenEnv Environments:") + print("-" * 70) + + if not environments: + print(" No OpenEnv environments found.") + print(" Install environments with: pip install openenv-") + else: + for env_key in sorted(environments.keys()): + env = environments[env_key] + print(f" {env_key:<15}: {env.description} (v{env.version})") + print(f" Package: {env.package_name}") + + print("-" * 70) + print(f"Total: {len(environments)} environments") + + def clear_cache(self) -> None: + """Clear the discovery cache.""" + if self._cache_file.exists(): + self._cache_file.unlink() + self._cache = None + + +# Global discovery instance +_global_discovery: Optional[EnvironmentDiscovery] = None + + +def get_discovery() -> EnvironmentDiscovery: + """ + Get or create the global discovery instance. 
+ + Returns: + Global EnvironmentDiscovery instance + + Examples: + >>> discovery = get_discovery() + >>> envs = discovery.discover() + """ + global _global_discovery + + if _global_discovery is None: + _global_discovery = EnvironmentDiscovery() + + return _global_discovery + + +def reset_discovery() -> None: + """Reset the global discovery instance (useful for testing).""" + global _global_discovery + if _global_discovery is not None: + _global_discovery.clear_cache() + _global_discovery = None diff --git a/src/openenv/auto/auto_action.py b/src/openenv/auto/auto_action.py new file mode 100644 index 0000000000000000000000000000000000000000..b097ad1d193a605fe834ff18dd9ccd8d913eab45 --- /dev/null +++ b/src/openenv/auto/auto_action.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +AutoAction - Automatic Action Class Selection +============================================== + +AutoAction provides a HuggingFace-style API for automatically retrieving the +correct Action class from installed packages or HuggingFace Hub. + +This module simplifies working with environment actions by automatically +detecting and returning the appropriate Action class without requiring +manual imports. + +Example: + >>> from openenv import AutoEnv, AutoAction + >>> + >>> # Get Action class from environment name + >>> CodeAction = AutoAction.from_env("coding") + >>> action = CodeAction(code="print('Hello!')") + >>> + >>> # From HuggingFace Hub + >>> CodeAction = AutoAction.from_env("meta-pytorch/coding-env") + >>> + >>> # Use with AutoEnv + >>> env = AutoEnv.from_env("coding-env") + >>> result = env.step(action) +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, Type + +from ._discovery import _is_hub_url, get_discovery +from .auto_env import AutoEnv + +logger = logging.getLogger(__name__) + + +class AutoAction: + """ + AutoAction automatically retrieves the correct Action class based on + environment names or HuggingFace Hub repositories. + + This class follows the HuggingFace AutoModel pattern, making it easy to + get the right Action class without needing to know which module to import. + + The class provides factory methods that look up the Action class and + return the class (not an instance) for you to instantiate. + + Example: + >>> # From installed package + >>> CodeAction = AutoAction.from_env("coding") + >>> action = CodeAction(code="print('test')") + >>> + >>> # From HuggingFace Hub + >>> CodeAction = AutoAction.from_env("meta-pytorch/coding-env") + >>> action = CodeAction(code="print('test')") + >>> + >>> # Use with AutoEnv for a complete workflow + >>> env = AutoEnv.from_env("coding-env") + >>> ActionClass = AutoAction.from_env("coding-env") + >>> action = ActionClass(code="print('Hello, AutoAction!')") + >>> result = env.step(action) + + Note: + AutoAction is not meant to be instantiated directly. Use the class + method from_env() instead. + """ + + def __init__(self): + """AutoAction should not be instantiated directly. Use class methods instead.""" + raise TypeError( + "AutoAction is a factory class and should not be instantiated directly. " + "Use AutoAction.from_hub() or AutoAction.from_env() instead." + ) + + @classmethod + def from_env(cls, name: str, skip_install: bool = False) -> Type: + """ + Get the Action class from environment name or HuggingFace Hub repository. 
+ + This method automatically: + 1. Checks if the name is a HuggingFace Hub URL/repo ID + 2. If Hub: downloads and installs the environment package + 3. If local: looks up the installed openenv-* package + 4. Imports and returns the Action class + + Args: + name: Environment name or HuggingFace Hub repo ID + Examples: + - "coding" / "coding-env" / "coding_env" + - "meta-pytorch/coding-env" (Hub repo ID) + - "https://huggingface.co/meta-pytorch/coding-env" (Hub URL) + skip_install: If True, skip package installation and return + GenericAction class instead. Use this when working with + GenericEnvClient to avoid installing remote packages. + + Returns: + Action class (not an instance!). Returns GenericAction when + skip_install=True. + + Raises: + ValueError: If environment not found (only when skip_install=False) + ImportError: If environment package is not installed (only when skip_install=False) + + Examples: + >>> # From installed package + >>> CodeAction = AutoAction.from_env("coding-env") + >>> action = CodeAction(code="print('Hello!')") + >>> + >>> # From HuggingFace Hub + >>> CodeAction = AutoAction.from_env("meta-pytorch/coding-env") + >>> action = CodeAction(code="print('Hello!')") + >>> + >>> # Skip installation, use GenericAction (for GenericEnvClient) + >>> ActionClass = AutoAction.from_env("user/repo", skip_install=True) + >>> action = ActionClass(code="print('Hello!')") # Returns GenericAction + >>> + >>> # Different name formats + >>> EchoAction = AutoAction.from_env("echo") + >>> EchoAction = AutoAction.from_env("echo-env") + >>> EchoAction = AutoAction.from_env("echo_env") + """ + # If skip_install is True, return GenericAction without any package lookup + if skip_install: + from openenv.core.generic_client import GenericAction + + logger.info( + f"Returning GenericAction for '{name}' (skip_install=True). 
" + f"Use keyword arguments to create actions: GenericAction(code='...')" + ) + return GenericAction + + # Check if it's a HuggingFace Hub URL or repo ID + if _is_hub_url(name): + # Ensure package is installed (reuse AutoEnv logic, downloads only if needed) + env_name = AutoEnv._ensure_package_from_hub(name) + else: + env_name = name + + # Get environment info from discovery + discovery = get_discovery() + env_info = discovery.get_environment_by_name(env_name) + + if not env_info: + # Environment not found - provide helpful error message + available_envs = discovery.discover() + + if not available_envs: + raise ValueError( + "No OpenEnv environments found.\n" + "Install an environment with: pip install openenv-\n" + "Or specify a HuggingFace Hub repository: AutoAction.from_env('openenv/echo_env')" + ) + + # Try to suggest similar environment names + from difflib import get_close_matches + + env_keys = list(available_envs.keys()) + suggestions = get_close_matches(env_name, env_keys, n=3, cutoff=0.6) + + error_msg = f"Unknown environment '{env_name}'.\n" + if suggestions: + error_msg += f"Did you mean: {', '.join(suggestions)}?\n" + error_msg += f"Available environments: {', '.join(sorted(env_keys))}" + + raise ValueError(error_msg) + + # Get the action class + try: + action_class = env_info.get_action_class() + return action_class + except ImportError as e: + raise ImportError( + f"Failed to import action class for '{env_name}'.\n" + f"Package '{env_info.package_name}' appears to be installed but the module cannot be imported.\n" + f"Try reinstalling: pip install --force-reinstall {env_info.package_name}\n" + f"Original error: {e}" + ) from e + + @classmethod + def from_hub(cls, env_name: str, skip_install: bool = False) -> Type: + """ + Get the Action class from environment name. + + This is an alias for from_env() for backward compatibility and clarity. + + Args: + env_name: Environment name (e.g., "coding", "echo") + skip_install: If True, skip package installation and return + GenericAction class instead. + + Returns: + Action class (not an instance!) + + Examples: + >>> CodeAction = AutoAction.from_hub("coding") + >>> action = CodeAction(code="print('Hello!')") + """ + return cls.from_env(env_name, skip_install=skip_install) + + @classmethod + def get_action_info(cls, name: str) -> Dict[str, Any]: + """ + Get detailed information about an action class. + + Args: + name: Environment name + + Returns: + Dictionary with action class metadata + + Raises: + ValueError: If environment not found + + Examples: + >>> info = AutoAction.get_action_info("coding") + >>> print(info['action_class']) + 'CodingAction' + >>> print(info['module']) + 'coding_env.client' + """ + discovery = get_discovery() + env_info = discovery.get_environment_by_name(name) + + if not env_info: + raise ValueError(f"Unknown environment: {name}") + + return { + "env_key": env_info.env_key, + "env_name": env_info.name, + "package": env_info.package_name, + "action_class": env_info.action_class_name, + "observation_class": env_info.observation_class_name, + "module": env_info.client_module_path, + } + + @classmethod + def list_actions(cls) -> None: + """ + Print a formatted list of all available action classes. + + This discovers all installed openenv-* packages and displays + their action class information in a user-friendly format. 
+ + Examples: + >>> AutoAction.list_actions() + Available Action Classes: + ---------------------------------------------------------------------- + echo : EchoAction (from openenv-echo-env) + coding : CodingAction (from openenv-coding_env) + ---------------------------------------------------------------------- + Total: 2 action classes + """ + discovery = get_discovery() + environments = discovery.discover() + + print("Available Action Classes:") + print("-" * 70) + + if not environments: + print(" No OpenEnv environments found.") + print(" Install environments with: pip install openenv-") + else: + for env_key in sorted(environments.keys()): + env = environments[env_key] + print(f" {env_key:<15}: {env.action_class_name}") + print(f" Package: {env.package_name}") + + print("-" * 70) + print(f"Total: {len(environments)} action classes") diff --git a/src/openenv/auto/auto_env.py b/src/openenv/auto/auto_env.py new file mode 100644 index 0000000000000000000000000000000000000000..be845565b651ec721a029505d754bb0e5328bfa6 --- /dev/null +++ b/src/openenv/auto/auto_env.py @@ -0,0 +1,897 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +AutoEnv - Automatic Environment Selection +========================================== + +AutoEnv provides a HuggingFace-style API for automatically selecting and +instantiating the correct environment client from installed packages or +HuggingFace Hub. + +This module simplifies environment creation by automatically detecting the +environment type from the name and instantiating the appropriate client class. + +Example: + >>> from openenv import AutoEnv, AutoAction + >>> + >>> # From installed package + >>> env = AutoEnv.from_env("coding-env") + >>> + >>> # From HuggingFace Hub + >>> env = AutoEnv.from_env("meta-pytorch/coding-env") + >>> + >>> # With configuration + >>> env = AutoEnv.from_env("coding", env_vars={"DEBUG": "1"}) +""" + +from __future__ import annotations + +import importlib +import logging +import os +import shutil +import subprocess +import sys +from typing import Any, Dict, Optional, TYPE_CHECKING + +import requests +from openenv.core.utils import run_async_safely + +from ._discovery import _is_hub_url, get_discovery + + +if TYPE_CHECKING: + from openenv.core.containers.runtime import ContainerProvider + from openenv.core.env_client import EnvClient + +logger = logging.getLogger(__name__) + +# Cache for repo ID → env_name mapping to avoid redundant downloads +_hub_env_name_cache: Dict[str, str] = {} + +# Environment variable to skip user confirmation for remote installs +OPENENV_TRUST_REMOTE_CODE = "OPENENV_TRUST_REMOTE_CODE" + + +def _has_uv() -> bool: + """Check if uv is available in the system.""" + return shutil.which("uv") is not None + + +def _get_pip_command() -> list[str]: + """ + Get the appropriate pip command (uv pip or pip). + + Returns: + List of command parts for pip installation + """ + if _has_uv(): + return ["uv", "pip"] + return [sys.executable, "-m", "pip"] + + +def _confirm_remote_install(repo_id: str) -> bool: + """ + Ask user for confirmation before installing remote code. + + This is a security measure since we're executing code from the internet. 
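+
+    Example (illustrative; shows the non-interactive bypass intended for CI):
+
+        >>> import os
+        >>> os.environ[OPENENV_TRUST_REMOTE_CODE] = "1"
+        >>> _confirm_remote_install("openenv/echo_env")
+        True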
+ + Args: + repo_id: The HuggingFace repo ID being installed + + Returns: + True if user confirms, False otherwise + """ + # Check environment variable for automated/CI environments + if os.environ.get(OPENENV_TRUST_REMOTE_CODE, "").lower() in ("1", "true", "yes"): + logger.info("Skipping confirmation (OPENENV_TRUST_REMOTE_CODE is set)") + return True + + # Check if we're in an interactive terminal + if not sys.stdin.isatty(): + logger.warning( + "Cannot prompt for confirmation in non-interactive mode. " + "Set OPENENV_TRUST_REMOTE_CODE=1 to allow remote installs." + ) + return False + + print(f"\n{'=' * 60}") + print("⚠️ SECURITY WARNING: Remote Code Installation") + print(f"{'=' * 60}") + print("You are about to install code from a remote repository:") + print(f" Repository: {repo_id}") + print(f" Source: https://huggingface.co/spaces/{repo_id}") + print("\nThis will execute code from the internet on your machine.") + print("Only proceed if you trust the source.") + print(f"{'=' * 60}\n") + + try: + response = input("Do you want to proceed? [y/N]: ").strip().lower() + return response in ("y", "yes") + except (EOFError, KeyboardInterrupt): + print("\nInstallation cancelled.") + return False + + +class AutoEnv: + """ + AutoEnv automatically selects and instantiates the correct environment client + based on environment names or HuggingFace Hub repositories. + + This class follows the HuggingFace AutoModel pattern, making it easy to work + with different environments without needing to import specific client classes. + + The class provides factory methods that: + 1. Check if name is a HuggingFace Hub URL/repo ID + 2. If Hub: download and install the environment package + 3. If local: look up the installed openenv-* package + 4. Import and instantiate the client class + + Example: + >>> # From installed package + >>> env = AutoEnv.from_env("coding-env") + >>> + >>> # From HuggingFace Hub + >>> env = AutoEnv.from_env("meta-pytorch/coding-env") + >>> + >>> # List available environments + >>> AutoEnv.list_environments() + + Note: + AutoEnv is not meant to be instantiated directly. Use the class method + from_env() instead. + """ + + def __init__(self): + """AutoEnv should not be instantiated directly. Use class methods instead.""" + raise TypeError( + "AutoEnv is a factory class and should not be instantiated directly. " + "Use AutoEnv.from_hub() or AutoEnv.from_env() instead." + ) + + @classmethod + def _resolve_space_url(cls, repo_id: str) -> str: + """ + Resolve HuggingFace Space repo ID to Space URL. + + Args: + repo_id: HuggingFace repo ID (e.g., "wukaixingxp/coding-env-test") + + Returns: + Space URL (e.g., "https://wukaixingxp-coding-env-test.hf.space") + + Examples: + >>> AutoEnv._resolve_space_url("wukaixingxp/coding-env-test") + 'https://wukaixingxp-coding-env-test.hf.space' + """ + # Clean up repo_id if it's a full URL + if "huggingface.co" in repo_id: + # Extract org/repo from URL + # https://huggingface.co/wukaixingxp/coding-env-test -> wukaixingxp/coding-env-test + parts = repo_id.split("/") + if len(parts) >= 2: + repo_id = f"{parts[-2]}/{parts[-1]}" + + # Convert user/space-name to user-space-name.hf.space + space_slug = repo_id.replace("/", "-") + return f"https://{space_slug}.hf.space" + + @classmethod + def _is_local_url(cls, url: str) -> bool: + """ + Check if a URL points to a local server. 
+ + Args: + url: URL to check + + Returns: + True if URL is localhost or 127.0.0.1, False otherwise + + Examples: + >>> AutoEnv._is_local_url("http://localhost:8000") + True + >>> AutoEnv._is_local_url("http://127.0.0.1:8000") + True + >>> AutoEnv._is_local_url("https://example.com") + False + """ + url_lower = url.lower() + return "localhost" in url_lower or "127.0.0.1" in url_lower + + @classmethod + def _check_server_availability(cls, base_url: str, timeout: float = 2.0) -> bool: + """ + Check if a server at the given URL is running and accessible. + + Args: + base_url: Server base URL to check + timeout: Request timeout in seconds + + Returns: + True if server is accessible, False otherwise + + Examples: + >>> AutoEnv._check_server_availability("http://localhost:8000") + True # if server is running + """ + try: + # Bypass proxy for localhost to avoid proxy issues + proxies = None + if cls._is_local_url(base_url): + proxies = {"http": None, "https": None} + + # Try to access the health endpoint + response = requests.get( + f"{base_url}/health", timeout=timeout, proxies=proxies + ) + if response.status_code == 200: + return True + + # If health endpoint doesn't exist, try root endpoint + response = requests.get(base_url, timeout=timeout, proxies=proxies) + return response.status_code == 200 + except (requests.RequestException, Exception) as e: + logger.debug(f"Server {base_url} not accessible: {e}") + return False + + @classmethod + def _check_space_availability(cls, space_url: str, timeout: float = 5.0) -> bool: + """ + Check if HuggingFace Space is running and accessible. + + Args: + space_url: Space URL to check + timeout: Request timeout in seconds + + Returns: + True if Space is accessible, False otherwise + + Examples: + >>> AutoEnv._check_space_availability("https://wukaixingxp-coding-env-test.hf.space") + True + """ + try: + # Try to access the health endpoint + response = requests.get(f"{space_url}/health", timeout=timeout) + if response.status_code == 200: + return True + + # If health endpoint doesn't exist, try root endpoint + response = requests.get(space_url, timeout=timeout) + return response.status_code == 200 + except (requests.RequestException, Exception) as e: + logger.debug(f"Space {space_url} not accessible: {e}") + return False + + @classmethod + def _get_hub_git_url(cls, repo_id: str) -> str: + """ + Get the git URL for a HuggingFace Space. + + Args: + repo_id: HuggingFace repo ID (e.g., "wukaixingxp/coding-env-test") + + Returns: + Git URL for pip installation (e.g., "git+https://huggingface.co/spaces/wukaixingxp/coding-env-test") + """ + # Clean up repo_id if it's a full URL + if "huggingface.co" in repo_id: + parts = repo_id.split("/") + if len(parts) >= 2: + repo_id = f"{parts[-2]}/{parts[-1]}" + + return f"git+https://huggingface.co/spaces/{repo_id}" + + @classmethod + def _install_from_hub(cls, repo_id: str, trust_remote_code: bool = False) -> str: + """ + Install environment package directly from HuggingFace Hub using git+. + + This is the preferred method as it avoids downloading the entire repo + and uses pip/uv's native git support. 
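+
+        Illustrative flow (assumes ``uv`` is on PATH; otherwise the module
+        falls back to ``python -m pip``):
+
+            >>> AutoEnv._get_hub_git_url("openenv/echo_env")
+            'git+https://huggingface.co/spaces/openenv/echo_env'
+            >>> # _install_from_hub then runs:
+            >>> #   uv pip install git+https://huggingface.co/spaces/openenv/echo_env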
+ + Args: + repo_id: HuggingFace repo ID (e.g., "wukaixingxp/coding-env-test") + trust_remote_code: If True, skip user confirmation + + Returns: + Package name that was installed + + Raises: + ValueError: If installation fails or user declines + """ + # Security check - confirm with user before installing remote code + if not trust_remote_code and not _confirm_remote_install(repo_id): + raise ValueError( + "Installation cancelled by user.\n" + "To allow remote installs without prompting, set OPENENV_TRUST_REMOTE_CODE=1" + ) + + git_url = cls._get_hub_git_url(repo_id) + pip_cmd = _get_pip_command() + pip_name = "uv pip" if pip_cmd[0] == "uv" else "pip" + + logger.info(f"Installing from HuggingFace Space using {pip_name}: {repo_id}") + logger.info(f"Command: {' '.join(pip_cmd)} install {git_url}") + + try: + result = subprocess.run( + [*pip_cmd, "install", git_url], + check=True, + capture_output=True, + text=True, + ) + + # Try to extract package name from pip output + # Look for "Successfully installed -" + for line in result.stdout.split("\n"): + if "Successfully installed" in line: + # Parse package name from the line + parts = line.replace("Successfully installed", "").strip().split() + for part in parts: + if part.startswith("openenv-"): + # Remove version suffix (e.g., "openenv-coding_env-0.1.0" -> "openenv-coding_env") + # Check if last segment looks like a version number + last_segment = part.rsplit("-", 1)[-1] + if last_segment.replace(".", "").isdigit(): + package_name = "-".join(part.rsplit("-", 1)[:-1]) + else: + package_name = part + logger.info(f"Successfully installed: {package_name}") + return package_name + + # Fallback: try to determine package name from repo_id + # Convention: repo name like "coding-env-test" -> package "openenv-coding_env" + env_name = repo_id.split("/")[-1] # Get repo name from "user/repo" + env_name = env_name.replace("-", "_") + if not env_name.endswith("_env"): + env_name = f"{env_name}_env" + package_name = f"openenv-{env_name}" + + logger.info(f"Installed (inferred package name): {package_name}") + return package_name + + except subprocess.CalledProcessError as e: + error_msg = e.stderr or e.stdout or str(e) + raise ValueError( + f"Failed to install environment from HuggingFace Space: {repo_id}\n" + f"Command: {' '.join(pip_cmd)} install {git_url}\n" + f"Error: {error_msg}\n" + f"Make sure the repository exists and contains a valid Python package." + ) from e + + @classmethod + def _is_package_installed(cls, package_name: str) -> bool: + """ + Check if a package is already installed. + + Args: + package_name: Package name (e.g., "openenv-coding_env") + + Returns: + True if installed, False otherwise + """ + try: + import importlib.metadata + + importlib.metadata.distribution(package_name) + return True + except importlib.metadata.PackageNotFoundError: + return False + + @classmethod + def _ensure_package_from_hub( + cls, name: str, trust_remote_code: bool = False + ) -> str: + """ + Ensure package from HuggingFace Hub is installed. + + Uses git+ URLs for direct installation without downloading the entire repo. + Prompts user for confirmation before installing remote code. 
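+
+        Illustrative name resolution (mirrors the inference logic below):
+
+            >>> repo_name = "coding-env"
+            >>> env_name = repo_name.replace("-", "_")
+            >>> env_name if env_name.endswith("_env") else f"{env_name}_env"
+            'coding_env'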
+ + Args: + name: HuggingFace repo ID (e.g., "wukaixingxp/coding-env-test") + trust_remote_code: If True, skip user confirmation + + Returns: + Environment name (e.g., "coding_env") + """ + global _hub_env_name_cache + + # Check if we already resolved this repo ID + if name in _hub_env_name_cache: + env_name = _hub_env_name_cache[name] + logger.debug(f"Using cached env name for {name}: {env_name}") + return env_name + + # Try to infer expected package name from repo ID + # Convention: repo "user/coding-env" -> package "openenv-coding_env" + repo_name = name.split("/")[-1] if "/" in name else name + expected_env_name = repo_name.replace("-", "_") + if not expected_env_name.endswith("_env"): + expected_env_name = f"{expected_env_name}_env" + expected_package_name = f"openenv-{expected_env_name}" + + # Check if already installed + if cls._is_package_installed(expected_package_name): + logger.info(f"Package already installed: {expected_package_name}") + # Clear and refresh discovery cache to make sure it's detected + get_discovery().clear_cache() + get_discovery().discover(use_cache=False) + # Cache the result + _hub_env_name_cache[name] = expected_env_name + return expected_env_name + + # Not installed, install using git+ URL + logger.info(f"Package not found locally, installing from Hub: {name}") + + # Track existing packages before installation + get_discovery().clear_cache() + existing_envs = set(get_discovery().discover(use_cache=False).keys()) + + # Install the package + cls._install_from_hub(name, trust_remote_code=trust_remote_code) + + # Clear discovery cache to pick up the newly installed package + try: + importlib.invalidate_caches() + except Exception: + pass + get_discovery().clear_cache() + discovered_envs = get_discovery().discover(use_cache=False) + + # Find the newly installed environment by comparing before/after + new_envs = set(discovered_envs.keys()) - existing_envs + + if new_envs: + # Use the first newly discovered environment + env_name = next(iter(new_envs)) + logger.info(f"Found newly installed environment: '{env_name}'") + else: + # Fallback: try to find by matching module patterns + # Look for any env that might match the repo name pattern + repo_name = name.split("/")[-1] if "/" in name else name + repo_base = ( + repo_name.replace("-", "_").replace("_env", "").replace("_test", "") + ) + + env_name = None + for env_key, env_info in discovered_envs.items(): + # Check if env_key is a prefix/substring match + if env_key in repo_base or repo_base.startswith(env_key): + env_name = env_key + logger.info( + f"Found matching environment '{env_name}' for repo '{name}'" + ) + break + + if env_name is None: + # Last resort: use inferred name from repo + env_name = repo_name.replace("-", "_") + if not env_name.endswith("_env"): + env_name = f"{env_name}_env" + # Strip to get env_key + env_name = env_name.replace("_env", "") + logger.warning( + f"Could not find newly installed environment for repo '{name}', " + f"using inferred name: {env_name}" + ) + + # Cache the result to avoid redundant installs + _hub_env_name_cache[name] = env_name + + return env_name + + @classmethod + def from_env( + cls, + name: str, + base_url: Optional[str] = None, + docker_image: Optional[str] = None, + container_provider: Optional[ContainerProvider] = None, + wait_timeout: float = 30.0, + env_vars: Optional[Dict[str, str]] = None, + trust_remote_code: bool = False, + skip_install: bool = False, + **kwargs: Any, + ) -> "EnvClient": + """ + Create an environment client from a name or HuggingFace Hub 
repository. + + This method automatically: + 1. Checks if the name is a HuggingFace Hub URL/repo ID + 2. If Hub: installs the environment package using git+ URL + 3. If local: looks up the installed openenv-* package + 4. Imports the client class and instantiates it + + Args: + name: Environment name or HuggingFace Hub repo ID + Examples: + - "coding" / "coding-env" / "coding_env" + - "meta-pytorch/coding-env" (Hub repo ID) + - "https://huggingface.co/meta-pytorch/coding-env" (Hub URL) + base_url: Optional base URL for HTTP connection + docker_image: Optional Docker image name (overrides default) + container_provider: Optional container provider + wait_timeout: Timeout for container startup (seconds) + env_vars: Optional environment variables for the container + trust_remote_code: If True, skip user confirmation when installing + from HuggingFace Hub. Can also be set via OPENENV_TRUST_REMOTE_CODE + environment variable. + skip_install: If True, skip package installation and return a + GenericEnvClient for remote environments. Useful when you only + want to connect to a running server without installing any + remote code. When True: + - If base_url is provided: connects directly using GenericEnvClient + - If HF Space is running: connects to Space using GenericEnvClient + - If HF Space is not running: uses Docker from HF registry + **kwargs: Additional arguments passed to the client class + + Returns: + Instance of the environment client class + + Raises: + ValueError: If environment not found or cannot be loaded + ImportError: If environment package is not installed + + Examples: + >>> # From installed package + >>> env = AutoEnv.from_env("coding-env") + >>> + >>> # From HuggingFace Hub + >>> env = AutoEnv.from_env("meta-pytorch/coding-env") + >>> + >>> # With custom Docker image + >>> env = AutoEnv.from_env("coding", docker_image="my-coding-env:v2") + >>> + >>> # With environment variables + >>> env = AutoEnv.from_env( + ... "dipg", + ... env_vars={"DIPG_DATASET_PATH": "/data/dipg"} + ... ) + >>> + >>> # Skip package installation, use GenericEnvClient + >>> env = AutoEnv.from_env( + ... "user/my-env", + ... skip_install=True + ... ) + """ + from openenv.core import GenericEnvClient + + # Handle skip_install mode - return GenericEnvClient without package installation + if skip_install: + # If base_url is provided, connect directly + if base_url: + if cls._check_server_availability(base_url): + logger.info( + f"Using GenericEnvClient for {base_url} (skip_install=True)" + ) + return GenericEnvClient(base_url=base_url, **kwargs) + else: + raise ConnectionError( + f"Server not available at {base_url}. " + f"Please ensure the server is running." 
+ ) + + # If it's a Hub URL, try to connect to Space or use Docker + if _is_hub_url(name): + space_url = cls._resolve_space_url(name) + logger.info(f"Checking if HuggingFace Space is accessible: {space_url}") + + if cls._check_space_availability(space_url): + logger.info( + f"Using GenericEnvClient for Space {space_url} (skip_install=True)" + ) + return GenericEnvClient(base_url=space_url, **kwargs) + else: + # Space not running, use Docker from HF registry + logger.info( + f"Space not running at {space_url}, " + f"using GenericEnvClient with HF Docker registry" + ) + return run_async_safely( + GenericEnvClient.from_env( + name, + use_docker=True, + provider=container_provider, + env_vars=env_vars or {}, + **kwargs, + ) + ) + + # For local environments with skip_install, we need docker_image + if docker_image: + logger.info( + f"Using GenericEnvClient with Docker image {docker_image} " + f"(skip_install=True)" + ) + return run_async_safely( + GenericEnvClient.from_docker_image( + image=docker_image, + provider=container_provider, + wait_timeout=wait_timeout, + env_vars=env_vars or {}, + **kwargs, + ) + ) + else: + raise ValueError( + f"Cannot use skip_install=True for local environment '{name}' " + f"without providing base_url or docker_image. " + f"For local environments, either:\n" + f" 1. Provide base_url to connect to a running server\n" + f" 2. Provide docker_image to start a container\n" + f" 3. Set skip_install=False to use the installed package" + ) + + # Check if it's a HuggingFace Hub URL or repo ID + if _is_hub_url(name): + # Try to connect to Space directly first + space_url = cls._resolve_space_url(name) + logger.info(f"Checking if HuggingFace Space is accessible: {space_url}") + + space_is_available = cls._check_space_availability(space_url) + + if space_is_available and base_url is None: + # Space is accessible! 
We'll connect directly without Docker + logger.info(f"Space is accessible at: {space_url}") + logger.info("Installing package for client code (no Docker needed)...") + + # Ensure package is installed (uses git+ URL) + env_name = cls._ensure_package_from_hub( + name, trust_remote_code=trust_remote_code + ) + + # Set base_url to connect to remote Space + base_url = space_url + logger.info("Will connect to remote Space (no local Docker)") + else: + # Space not accessible or user provided explicit base_url + if not space_is_available: + logger.info(f"Space not accessible at {space_url}") + logger.info("Falling back to local Docker mode...") + + # Ensure package is installed (uses git+ URL) + env_name = cls._ensure_package_from_hub( + name, trust_remote_code=trust_remote_code + ) + else: + env_name = name + + # Get environment info from discovery + discovery = get_discovery() + env_info = discovery.get_environment_by_name(env_name) + + if not env_info: + # Environment not found - provide helpful error message + available_envs = discovery.discover() + + if not available_envs: + raise ValueError( + "No OpenEnv environments found.\n" + "Install an environment with: pip install openenv-\n" + "Or specify a HuggingFace Hub repository: AutoEnv.from_env('openenv/echo_env')" + ) + + # Try to suggest similar environment names + from difflib import get_close_matches + + env_keys = list(available_envs.keys()) + suggestions = get_close_matches(env_name, env_keys, n=3, cutoff=0.6) + + error_msg = f"Unknown environment '{env_name}'.\n" + if suggestions: + error_msg += f"Did you mean: {', '.join(suggestions)}?\n" + error_msg += f"Available environments: {', '.join(sorted(env_keys))}" + + raise ValueError(error_msg) + + # Get the client class + try: + client_class = env_info.get_client_class() + except ImportError as e: + raise ImportError( + f"Failed to import environment client for '{env_name}'.\n" + f"Package '{env_info.package_name}' appears to be installed but the module cannot be imported.\n" + f"Try reinstalling: pip install --force-reinstall {env_info.package_name}\n" + f"Original error: {e}" + ) from e + + # Determine Docker image to use + if docker_image is None: + docker_image = env_info.default_image + + # Create client instance + try: + if base_url: + # Check if the server at base_url is available + is_local = cls._is_local_url(base_url) + server_available = cls._check_server_availability(base_url) + + if server_available: + # Server is running, connect directly + logger.info( + f"✅ Server available at {base_url}, connecting directly" + ) + return client_class(base_url=base_url, provider=None, **kwargs) + elif is_local: + # Local server not running, auto-start Docker container + logger.info(f"❌ Server not available at {base_url}") + logger.info(f"🐳 Auto-starting Docker container: {docker_image}") + return run_async_safely( + client_class.from_docker_image( + image=docker_image, + provider=container_provider, + wait_timeout=wait_timeout, + env_vars=env_vars or {}, + **kwargs, + ) + ) + else: + # Remote server not available, cannot auto-start + raise ConnectionError( + f"Remote server not available at {base_url}. " + f"Please ensure the server is running." 
+ ) + else: + # No base_url provided, start new Docker container + return run_async_safely( + client_class.from_docker_image( + image=docker_image, + provider=container_provider, + wait_timeout=wait_timeout, + env_vars=env_vars or {}, + **kwargs, + ) + ) + except Exception as e: + raise ValueError( + f"Failed to create environment client for '{env_name}'.\n" + f"Client class: {client_class.__name__}\n" + f"Docker image: {docker_image}\n" + f"Error: {e}" + ) from e + + @classmethod + def from_hub( + cls, + name: str, + base_url: Optional[str] = None, + docker_image: Optional[str] = None, + container_provider: Optional["ContainerProvider"] = None, + wait_timeout: float = 30.0, + env_vars: Optional[Dict[str, str]] = None, + trust_remote_code: bool = False, + skip_install: bool = False, + **kwargs: Any, + ) -> "EnvClient": + """ + Create an environment client from a name or HuggingFace Hub repository. + + This is an alias for from_env() for backward compatibility. + + Args: + name: Environment name or HuggingFace Hub repo ID + base_url: Optional base URL for HTTP connection + docker_image: Optional Docker image name (overrides default) + container_provider: Optional container provider + wait_timeout: Timeout for container startup (seconds) + env_vars: Optional environment variables for the container + trust_remote_code: If True, skip user confirmation when installing + from HuggingFace Hub + skip_install: If True, skip package installation and return a + GenericEnvClient for remote environments + **kwargs: Additional arguments passed to the client class + + Returns: + Instance of the environment client class + + Examples: + >>> env = AutoEnv.from_hub("coding-env") + >>> env = AutoEnv.from_hub("meta-pytorch/coding-env") + """ + return cls.from_env( + name=name, + base_url=base_url, + docker_image=docker_image, + container_provider=container_provider, + wait_timeout=wait_timeout, + env_vars=env_vars, + trust_remote_code=trust_remote_code, + skip_install=skip_install, + **kwargs, + ) + + @classmethod + def get_env_class(cls, name: str): + """ + Get the environment client class without instantiating it. + + Args: + name: Environment name + + Returns: + The environment client class + + Raises: + ValueError: If environment not found + + Examples: + >>> CodingEnv = AutoEnv.get_env_class("coding") + >>> # Now you can instantiate it yourself + >>> env = CodingEnv(base_url="http://localhost:8000") + """ + discovery = get_discovery() + env_info = discovery.get_environment_by_name(name) + + if not env_info: + raise ValueError(f"Unknown environment: {name}") + + return env_info.get_client_class() + + @classmethod + def get_env_info(cls, name: str) -> Dict[str, Any]: + """ + Get detailed information about an environment. 
+ + Args: + name: Environment name + + Returns: + Dictionary with environment metadata + + Raises: + ValueError: If environment not found + + Examples: + >>> info = AutoEnv.get_env_info("coding") + >>> print(info['description']) + 'Coding environment for OpenEnv' + >>> print(info['default_image']) + 'coding-env:latest' + """ + discovery = get_discovery() + env_info = discovery.get_environment_by_name(name) + + if not env_info: + raise ValueError(f"Unknown environment: {name}") + + return { + "env_key": env_info.env_key, + "name": env_info.name, + "package": env_info.package_name, + "version": env_info.version, + "description": env_info.description, + "env_class": env_info.client_class_name, + "action_class": env_info.action_class_name, + "observation_class": env_info.observation_class_name, + "module": env_info.client_module_path, + "default_image": env_info.default_image, + "spec_version": env_info.spec_version, + } + + @classmethod + def list_environments(cls) -> None: + """ + Print a formatted list of all available environments. + + This discovers all installed openenv-* packages and displays + their metadata in a user-friendly format. + + Examples: + >>> AutoEnv.list_environments() + Available OpenEnv Environments: + ---------------------------------------------------------------------- + echo : Echo Environment (v0.1.0) + Package: openenv-echo-env + coding : Coding Environment (v0.1.0) + Package: openenv-coding_env + ---------------------------------------------------------------------- + Total: 2 environments + """ + discovery = get_discovery() + discovery.list_environments() diff --git a/src/openenv/cli/__init__.py b/src/openenv/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40bee4e3ecf31e272806785b2cd7e05ff0000564 --- /dev/null +++ b/src/openenv/cli/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenEnv CLI package.""" + +__version__ = "0.1.0" diff --git a/src/openenv/cli/__main__.py b/src/openenv/cli/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b457cb7e1f430771bd310c120972a57c06cd661 --- /dev/null +++ b/src/openenv/cli/__main__.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +OpenEnv CLI entry point. + +This module provides the main entry point for the OpenEnv command-line interface, +following the Hugging Face CLI pattern. 
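+
+Typical invocations (illustrative; the full command surface is registered
+below):
+
+    openenv init
+    openenv build
+    openenv validate
+    openenv push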
+""" + +import sys + +import typer +from openenv.cli.commands import build, fork, init, push, serve, skills, validate + +# Create the main CLI app +app = typer.Typer( + name="openenv", + help="OpenEnv - An e2e framework for creating, deploying and using isolated execution environments for agentic RL training", + no_args_is_help=True, +) + +# Register commands +app.command(name="init", help="Initialize a new OpenEnv environment")(init.init) +app.command(name="build", help="Build Docker images for OpenEnv environments")( + build.build +) +app.command( + name="validate", help="Validate environment structure and deployment readiness" +)(validate.validate) +app.command( + name="push", + help="Push an OpenEnv environment to Hugging Face Spaces or custom registry", +)(push.push) +app.command(name="serve", help="Serve environments locally (TODO: Phase 4)")( + serve.serve +) +app.command( + name="fork", + help="Fork (duplicate) a Hugging Face Space to your account", +)(fork.fork) +app.add_typer( + skills.app, + name="skills", + help="Manage OpenEnv skills for AI assistants", +) + + +# Entry point for setuptools +def main() -> None: + """Main entry point for the CLI.""" + try: + app() + except KeyboardInterrupt: + print("\nOperation cancelled by user.") + sys.exit(130) + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/src/openenv/cli/_cli_utils.py b/src/openenv/cli/_cli_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b781bb3e34c60842fd8fc8b0eef7b700eb8461e0 --- /dev/null +++ b/src/openenv/cli/_cli_utils.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""CLI utilities for OpenEnv command-line interface.""" + +from pathlib import Path +from typing import List + +from rich.console import Console + +# Create a console instance for CLI output +console = Console() + + +def validate_env_structure(env_dir: Path, strict: bool = False) -> List[str]: + """ + Validate that the directory follows OpenEnv environment structure. 
+ + Args: + env_dir: Path to environment directory + strict: If True, enforce all optional requirements + + Returns: + List of validation warnings (empty if all checks pass) + + Raises: + FileNotFoundError: If required files are missing + """ + warnings = [] + + # Required files + required_files = [ + "openenv.yaml", + "__init__.py", + "client.py", + "models.py", + "README.md", + ] + + for file in required_files: + if not (env_dir / file).exists(): + raise FileNotFoundError(f"Required file missing: {file}") + + # Dockerfile: must exist in server/ or at env root + has_root_dockerfile = (env_dir / "Dockerfile").exists() + has_server_dockerfile = (env_dir / "server" / "Dockerfile").exists() + + if not has_root_dockerfile and not has_server_dockerfile: + raise FileNotFoundError( + "Required file missing: server/Dockerfile or Dockerfile at env root" + ) + + # When no root Dockerfile, require the traditional server/ layout + if not has_root_dockerfile: + server_dir = env_dir / "server" + if not server_dir.exists() or not server_dir.is_dir(): + raise FileNotFoundError("Required directory missing: server/") + + for file in ["server/__init__.py", "server/app.py"]: + if not (env_dir / file).exists(): + raise FileNotFoundError(f"Required file missing: {file}") + + # Check for dependency management (pyproject.toml required) + has_pyproject = (env_dir / "pyproject.toml").exists() + + if not has_pyproject: + raise FileNotFoundError( + "No dependency specification found. 'pyproject.toml' is required." + ) + + # Warnings for recommended structure + + if not (env_dir / "outputs").exists(): + warnings.append("Recommended directory missing: outputs/") + + return warnings diff --git a/src/openenv/cli/_validation.py b/src/openenv/cli/_validation.py new file mode 100644 index 0000000000000000000000000000000000000000..60ea7cc58b3fb22c6247280ecbf64fe762433549 --- /dev/null +++ b/src/openenv/cli/_validation.py @@ -0,0 +1,594 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Validation utilities for multi-mode deployment readiness. + +This module provides functions to check if environments are properly +configured for multi-mode deployment (Docker, direct Python, notebooks, clusters). 
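+
+Typical usage (illustrative; assumes a server is listening on localhost:8000):
+
+    report = validate_running_environment("http://localhost:8000")
+    for criterion in report["criteria"]:
+        print(criterion["id"], criterion["passed"])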
+""" + +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +import requests + +try: + import tomllib +except ModuleNotFoundError: + import tomli as tomllib + + +def _make_criterion( + criterion_id: str, + description: str, + passed: bool, + *, + required: bool = True, + details: str | None = None, + expected: Any | None = None, + actual: Any | None = None, +) -> dict[str, Any]: + """Create a standard criterion result payload.""" + criterion: dict[str, Any] = { + "id": criterion_id, + "description": description, + "passed": passed, + "required": required, + } + if details is not None: + criterion["details"] = details + if expected is not None: + criterion["expected"] = expected + if actual is not None: + criterion["actual"] = actual + return criterion + + +def _normalize_runtime_url(base_url: str) -> str: + """Normalize and validate a runtime target URL.""" + target = base_url.strip() + if not target: + raise ValueError("Runtime URL cannot be empty") + + if "://" not in target: + target = f"http://{target}" + + parsed = urlparse(target) + if not parsed.scheme or not parsed.netloc: + raise ValueError(f"Invalid runtime URL: {base_url}") + + return target.rstrip("/") + + +def _runtime_standard_profile(api_version: str) -> str: + """Resolve the runtime standard profile for an API version.""" + if api_version.startswith("1."): + return "openenv-http/1.x" + return "openenv-http/unknown" + + +def _build_summary(criteria: list[dict[str, Any]]) -> dict[str, Any]: + """Build a compact pass/fail summary for a criteria list.""" + total_count = len(criteria) + passed_count = sum(1 for criterion in criteria if criterion.get("passed", False)) + failed_criteria = [ + criterion.get("id", "unknown") + for criterion in criteria + if not criterion.get("passed", False) + ] + required_criteria = [ + criterion for criterion in criteria if criterion.get("required", True) + ] + required_total_count = len(required_criteria) + required_passed_count = sum( + 1 for criterion in required_criteria if criterion.get("passed", False) + ) + + return { + "passed_count": passed_count, + "total_count": total_count, + "failed_criteria": failed_criteria, + "required_passed_count": required_passed_count, + "required_total_count": required_total_count, + } + + +def validate_running_environment( + base_url: str, timeout_s: float = 5.0 +) -> dict[str, Any]: + """ + Validate a running OpenEnv server against runtime API standards. + + The returned JSON report contains an overall pass/fail result and + per-criterion outcomes that can be consumed in CI. + """ + normalized_url = _normalize_runtime_url(base_url) + criteria: list[dict[str, Any]] = [] + + report: dict[str, Any] = { + "target": normalized_url, + "validation_type": "running_environment", + "standard_version": "unknown", + "standard_profile": "openenv-http/unknown", + "mode": "unknown", + "passed": False, + "summary": {}, + "criteria": criteria, + } + + openapi_paths: dict[str, Any] = {} + api_version = "unknown" + + # Criterion: OpenAPI endpoint reachable with a declared version. 
+ try: + openapi_response = requests.get( + f"{normalized_url}/openapi.json", timeout=timeout_s + ) + except requests.RequestException as exc: + criteria.append( + _make_criterion( + "openapi_version_available", + "GET /openapi.json returns OpenAPI info.version", + False, + details=f"Request failed: {type(exc).__name__}: {exc}", + expected={"status_code": 200, "info.version": "string"}, + ) + ) + else: + try: + openapi_json = openapi_response.json() + except ValueError: + openapi_json = None + + openapi_ok = ( + openapi_response.status_code == 200 + and isinstance(openapi_json, dict) + and isinstance(openapi_json.get("info"), dict) + and isinstance(openapi_json["info"].get("version"), str) + ) + + if openapi_ok: + api_version = str(openapi_json["info"]["version"]) + openapi_paths = openapi_json.get("paths", {}) + criteria.append( + _make_criterion( + "openapi_version_available", + "GET /openapi.json returns OpenAPI info.version", + True, + expected={"status_code": 200, "info.version": "string"}, + actual={ + "status_code": openapi_response.status_code, + "info.version": api_version, + }, + ) + ) + else: + criteria.append( + _make_criterion( + "openapi_version_available", + "GET /openapi.json returns OpenAPI info.version", + False, + details="Response missing required OpenAPI info.version field", + expected={"status_code": 200, "info.version": "string"}, + actual={ + "status_code": openapi_response.status_code, + "body_type": ( + type(openapi_json).__name__ + if openapi_json is not None + else "non_json" + ), + }, + ) + ) + + report["standard_version"] = api_version + report["standard_profile"] = _runtime_standard_profile(api_version) + + # Criterion: Health endpoint. + try: + health_response = requests.get(f"{normalized_url}/health", timeout=timeout_s) + except requests.RequestException as exc: + criteria.append( + _make_criterion( + "health_endpoint", + "GET /health returns healthy status", + False, + details=f"Request failed: {type(exc).__name__}: {exc}", + expected={"status_code": 200, "status": "healthy"}, + ) + ) + else: + try: + health_json = health_response.json() + except ValueError: + health_json = None + + health_ok = ( + health_response.status_code == 200 + and isinstance(health_json, dict) + and health_json.get("status") == "healthy" + ) + criteria.append( + _make_criterion( + "health_endpoint", + "GET /health returns healthy status", + health_ok, + expected={"status_code": 200, "status": "healthy"}, + actual={ + "status_code": health_response.status_code, + "status": ( + health_json.get("status") + if isinstance(health_json, dict) + else None + ), + }, + ) + ) + + # Criterion: Metadata endpoint has required fields. 
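+    # A conforming response is expected to look like, e.g.,
+    # {"name": "echo_env", "description": "Echoes actions back"}
+    # (illustrative values, not a real environment).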
+ try: + metadata_response = requests.get( + f"{normalized_url}/metadata", timeout=timeout_s + ) + except requests.RequestException as exc: + criteria.append( + _make_criterion( + "metadata_endpoint", + "GET /metadata returns name and description", + False, + details=f"Request failed: {type(exc).__name__}: {exc}", + expected={"status_code": 200, "fields": ["name", "description"]}, + ) + ) + else: + try: + metadata_json = metadata_response.json() + except ValueError: + metadata_json = None + + metadata_ok = ( + metadata_response.status_code == 200 + and isinstance(metadata_json, dict) + and isinstance(metadata_json.get("name"), str) + and isinstance(metadata_json.get("description"), str) + ) + criteria.append( + _make_criterion( + "metadata_endpoint", + "GET /metadata returns name and description", + metadata_ok, + expected={"status_code": 200, "fields": ["name", "description"]}, + actual={ + "status_code": metadata_response.status_code, + "name": ( + metadata_json.get("name") + if isinstance(metadata_json, dict) + else None + ), + "description": ( + metadata_json.get("description") + if isinstance(metadata_json, dict) + else None + ), + }, + ) + ) + + # Criterion: Schema endpoint returns action/observation/state. + try: + schema_response = requests.get(f"{normalized_url}/schema", timeout=timeout_s) + except requests.RequestException as exc: + criteria.append( + _make_criterion( + "schema_endpoint", + "GET /schema returns action, observation, and state schemas", + False, + details=f"Request failed: {type(exc).__name__}: {exc}", + expected={ + "status_code": 200, + "fields": ["action", "observation", "state"], + }, + ) + ) + else: + try: + schema_json = schema_response.json() + except ValueError: + schema_json = None + + schema_ok = ( + schema_response.status_code == 200 + and isinstance(schema_json, dict) + and isinstance(schema_json.get("action"), dict) + and isinstance(schema_json.get("observation"), dict) + and isinstance(schema_json.get("state"), dict) + ) + criteria.append( + _make_criterion( + "schema_endpoint", + "GET /schema returns action, observation, and state schemas", + schema_ok, + expected={ + "status_code": 200, + "fields": ["action", "observation", "state"], + }, + actual={ + "status_code": schema_response.status_code, + "has_action": ( + isinstance(schema_json.get("action"), dict) + if isinstance(schema_json, dict) + else False + ), + "has_observation": ( + isinstance(schema_json.get("observation"), dict) + if isinstance(schema_json, dict) + else False + ), + "has_state": ( + isinstance(schema_json.get("state"), dict) + if isinstance(schema_json, dict) + else False + ), + }, + ) + ) + + # Criterion: MCP endpoint is reachable. 
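+    # Only the JSON-RPC 2.0 envelope is asserted below, e.g. a payload shaped
+    # like {"jsonrpc": "2.0", ...} (illustrative); the method-level behavior
+    # of the MCP endpoint is not exercised here.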
+ try: + mcp_response = requests.post( + f"{normalized_url}/mcp", json={}, timeout=timeout_s + ) + except requests.RequestException as exc: + criteria.append( + _make_criterion( + "mcp_endpoint", + "POST /mcp is reachable and returns JSON-RPC payload", + False, + details=f"Request failed: {type(exc).__name__}: {exc}", + expected={"status_code": 200, "jsonrpc": "2.0"}, + ) + ) + else: + try: + mcp_json = mcp_response.json() + except ValueError: + mcp_json = None + + mcp_ok = ( + mcp_response.status_code == 200 + and isinstance(mcp_json, dict) + and mcp_json.get("jsonrpc") == "2.0" + ) + criteria.append( + _make_criterion( + "mcp_endpoint", + "POST /mcp is reachable and returns JSON-RPC payload", + mcp_ok, + expected={"status_code": 200, "jsonrpc": "2.0"}, + actual={ + "status_code": mcp_response.status_code, + "jsonrpc": ( + mcp_json.get("jsonrpc") if isinstance(mcp_json, dict) else None + ), + }, + ) + ) + + # Criterion: mode endpoint contract consistency via OpenAPI paths. + if isinstance(openapi_paths, dict) and openapi_paths: + has_reset = "/reset" in openapi_paths + has_step = "/step" in openapi_paths + has_state = "/state" in openapi_paths + + if has_reset: + report["mode"] = "simulation" + mode_ok = has_step and has_state + expected_paths = {"/reset": True, "/step": True, "/state": True} + else: + report["mode"] = "production" + mode_ok = not has_step and not has_state + expected_paths = {"/reset": False, "/step": False, "/state": False} + + criteria.append( + _make_criterion( + "mode_endpoint_consistency", + "OpenAPI endpoint set matches OpenEnv mode contract", + mode_ok, + expected=expected_paths, + actual={ + "/reset": has_reset, + "/step": has_step, + "/state": has_state, + }, + ) + ) + else: + criteria.append( + _make_criterion( + "mode_endpoint_consistency", + "OpenAPI endpoint set matches OpenEnv mode contract", + False, + details="Cannot determine mode without OpenAPI paths", + expected={"openapi.paths": "present"}, + actual={"openapi.paths": "missing"}, + ) + ) + + report["passed"] = all( + criterion["passed"] for criterion in criteria if criterion.get("required", True) + ) + report["summary"] = _build_summary(criteria) + return report + + +def validate_multi_mode_deployment(env_path: Path) -> tuple[bool, list[str]]: + """ + Validate that an environment is ready for multi-mode deployment. + + Checks: + 1. pyproject.toml exists + 2. uv.lock exists + 3. pyproject.toml has [project.scripts] with server entry point + 4. server/app.py has a main() function + 5. 
Required dependencies are present
+
+    Returns:
+        Tuple of (is_valid, list of issues found)
+    """
+    issues = []
+
+    # Check pyproject.toml exists
+    pyproject_path = env_path / "pyproject.toml"
+    if not pyproject_path.exists():
+        issues.append("Missing pyproject.toml")
+        return False, issues
+
+    # Check uv.lock exists
+    lockfile_path = env_path / "uv.lock"
+    if not lockfile_path.exists():
+        issues.append("Missing uv.lock - run 'uv lock' to generate it")
+
+    # Parse pyproject.toml
+    try:
+        with open(pyproject_path, "rb") as f:
+            pyproject = tomllib.load(f)
+    except Exception as e:
+        issues.append(f"Failed to parse pyproject.toml: {e}")
+        return False, issues
+
+    # Check [project.scripts] section
+    scripts = pyproject.get("project", {}).get("scripts", {})
+    if "server" not in scripts:
+        issues.append("Missing [project.scripts] server entry point")
+
+    # Check server entry point format
+    server_entry = scripts.get("server", "")
+    if server_entry and ":main" not in server_entry:
+        issues.append(
+            f"Server entry point should reference a main() function (module:main), got: {server_entry}"
+        )
+
+    # Check required dependencies
+    deps = [dep.lower() for dep in pyproject.get("project", {}).get("dependencies", [])]
+    has_openenv = any(
+        dep.startswith("openenv") and not dep.startswith("openenv-core") for dep in deps
+    )
+    has_legacy_core = any(dep.startswith("openenv-core") for dep in deps)
+
+    if not (has_openenv or has_legacy_core):
+        issues.append(
+            "Missing required dependency: openenv-core>=0.2.0 (or openenv>=0.2.0)"
+        )
+
+    # Check server/app.py exists
+    server_app = env_path / "server" / "app.py"
+    if not server_app.exists():
+        issues.append("Missing server/app.py")
+    else:
+        # Check for main() function (flexible - with or without parameters)
+        app_content = server_app.read_text(encoding="utf-8")
+        if "def main(" not in app_content:
+            issues.append("server/app.py missing main() function")
+
+        # Check that main() is actually invoked under the __main__ guard
+        if "__name__" not in app_content or "main()" not in app_content:
+            issues.append(
+                "server/app.py never invokes main() (missing if __name__ == '__main__' guard)"
+            )
+
+    return len(issues) == 0, issues
+
+
+def get_deployment_modes(env_path: Path) -> dict[str, bool]:
+    """
+    Check which deployment modes are supported by the environment.
+
+    Returns:
+        Dictionary with deployment mode names and whether they're supported
+    """
+    modes = {
+        "docker": False,
+        "openenv_serve": False,
+        "uv_run": False,
+        "python_module": False,
+    }
+
+    # Check Docker (Dockerfile may be in server/ or at env root)
+    modes["docker"] = (env_path / "server" / "Dockerfile").exists() or (
+        env_path / "Dockerfile"
+    ).exists()
+
+    # Check multi-mode deployment readiness
+    is_valid, _ = validate_multi_mode_deployment(env_path)
+    if is_valid:
+        modes["openenv_serve"] = True
+        modes["uv_run"] = True
+        modes["python_module"] = True
+
+    return modes
+
+
+def format_validation_report(env_name: str, is_valid: bool, issues: list[str]) -> str:
+    """
+    Format a validation report for display. 
+ + Returns: + Formatted report string + """ + if is_valid: + return f"[OK] {env_name}: Ready for multi-mode deployment" + + report = [f"[FAIL] {env_name}: Not ready for multi-mode deployment", ""] + report.append("Issues found:") + for issue in issues: + report.append(f" - {issue}") + + return "\n".join(report) + + +def build_local_validation_json_report( + env_name: str, + env_path: Path, + is_valid: bool, + issues: list[str], + deployment_modes: dict[str, bool] | None = None, +) -> dict[str, Any]: + """Build a JSON report for local environment validation.""" + criteria = [ + _make_criterion( + "multi_mode_deployment_readiness", + "Environment structure is ready for multi-mode deployment", + is_valid, + details="No issues found" if is_valid else f"{len(issues)} issue(s) found", + actual={"issues": issues}, + ) + ] + + if deployment_modes: + for mode, supported in deployment_modes.items(): + criteria.append( + _make_criterion( + f"deployment_mode_{mode}", + f"Deployment mode '{mode}' is supported", + supported, + required=False, + ) + ) + + return { + "target": str(env_path), + "environment": env_name, + "validation_type": "local_environment", + "standard_version": "local", + "standard_profile": "openenv-local", + "passed": is_valid, + "summary": _build_summary(criteria), + "criteria": criteria, + "issues": issues, + "deployment_modes": deployment_modes or {}, + } diff --git a/src/openenv/cli/commands/__init__.py b/src/openenv/cli/commands/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f351a32ff5b353b05b1019005253e0c7cdf71c57 --- /dev/null +++ b/src/openenv/cli/commands/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""OpenEnv CLI commands.""" + +from . import build, fork, init, push, serve, skills, validate + +__all__ = ["build", "fork", "init", "push", "serve", "skills", "validate"] diff --git a/src/openenv/cli/commands/build.py b/src/openenv/cli/commands/build.py new file mode 100644 index 0000000000000000000000000000000000000000..3d4d91b0f80754b00943caaeab23774b09b6d987 --- /dev/null +++ b/src/openenv/cli/commands/build.py @@ -0,0 +1,461 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Build Docker images for OpenEnv environments.""" + +from __future__ import annotations + +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Annotated + +import typer + +from .._cli_utils import console + +app = typer.Typer(help="Build Docker images for OpenEnv environments") + + +def _detect_build_context(env_path: Path) -> tuple[str, Path, Path | None]: + """ + Detect whether we're building a standalone or in-repo environment. 
+
+    Returns:
+        tuple: (build_mode, build_context_path, repo_root)
+        - build_mode: "standalone" or "in-repo"
+        - build_context_path: Path to use as Docker build context
+        - repo_root: Path to repo root (None for standalone)
+    """
+    # Ensure env_path is absolute for proper comparison
+    env_path = env_path.absolute()
+
+    # Check if we're in a git repository
+    current = env_path
+    repo_root = None
+
+    # Walk up to find .git directory
+    for parent in [current] + list(current.parents):
+        if (parent / ".git").exists():
+            repo_root = parent
+            break
+
+    if repo_root is None:
+        # Not in a git repo = standalone
+        return "standalone", env_path, None
+
+    # Check if environment is under envs/ (in-repo pattern)
+    try:
+        rel_path = env_path.relative_to(repo_root)
+        rel_str = str(rel_path)
+        if rel_str.startswith(("envs/", "envs\\")):
+            # In-repo environment
+            return "in-repo", repo_root, repo_root
+    except ValueError:
+        pass
+
+    # Otherwise, it's standalone (environment outside repo structure)
+    return "standalone", env_path, None
+
+
+def _prepare_standalone_build(env_path: Path, temp_dir: Path) -> Path:
+    """
+    Prepare a standalone environment for building.
+
+    For standalone builds:
+    1. Copy environment to temp directory
+    2. Ensure pyproject.toml depends on openenv
+
+    Returns:
+        Path to the prepared build directory
+    """
+    console.print("[cyan]Preparing standalone build...[/cyan]")
+
+    # Copy environment to temp directory
+    build_dir = temp_dir / env_path.name
+    shutil.copytree(env_path, build_dir, symlinks=True)
+
+    console.print(f"[cyan]Copied environment to:[/cyan] {build_dir}")
+
+    # Check if pyproject.toml has openenv dependency
+    pyproject_path = build_dir / "pyproject.toml"
+    if pyproject_path.exists():
+        with open(pyproject_path, "rb") as f:
+            try:
+                try:
+                    import tomllib  # Python 3.11+
+                except ModuleNotFoundError:
+                    import tomli as tomllib
+
+                pyproject = tomllib.load(f)
+                deps = pyproject.get("project", {}).get("dependencies", [])
+
+                # Check if openenv dependency is declared
+                has_openenv = any(dep.startswith("openenv") for dep in deps)
+
+                if not has_openenv:
+                    console.print(
+                        "[yellow]Warning:[/yellow] pyproject.toml doesn't list the openenv dependency",
+                    )
+                    console.print(
+                        "[yellow]You may need to add:[/yellow] openenv>=0.2.0",
+                    )
+            except ImportError:
+                console.print(
+                    "[yellow]Warning:[/yellow] tomllib/tomli not available, skipping dependency check",
+                )
+
+    return build_dir
+
+
+def _prepare_inrepo_build(env_path: Path, repo_root: Path, temp_dir: Path) -> Path:
+    """
+    Prepare an in-repo environment for building.
+
+    For in-repo builds:
+    1. Create temp directory with environment and core
+    2. Set up structure that matches expected layout
+
+    Returns:
+        Path to the prepared build directory
+    """
+    console.print("[cyan]Preparing in-repo build...[/cyan]")
+
+    # Copy environment to temp directory
+    build_dir = temp_dir / env_path.name
+    shutil.copytree(env_path, build_dir, symlinks=True)
+
+    # Copy OpenEnv package metadata + sources to temp directory.
+    # Keep the src/ layout since pyproject.toml uses package-dir = {"" = "src"}. 
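+    # Resulting staging layout (illustrative):
+    #   <build_dir>/
+    #     openenv/pyproject.toml      <- package metadata copied below
+    #     openenv/src/openenv/...     <- sources, src/ layout preserved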
+    package_src = repo_root / "src" / "openenv"
+    package_dest = build_dir / "openenv"
+    if package_src.exists():
+        package_dest.mkdir(parents=True, exist_ok=True)
+        shutil.copytree(package_src, package_dest / "src" / "openenv", symlinks=True)
+
+        for filename in ("pyproject.toml", "README.md"):
+            src_file = repo_root / filename
+            if src_file.exists():
+                shutil.copy2(src_file, package_dest / filename)
+
+        console.print(f"[cyan]Copied OpenEnv package to:[/cyan] {package_dest}")
+
+        # Update pyproject.toml to reference local OpenEnv copy
+        pyproject_path = build_dir / "pyproject.toml"
+        if pyproject_path.exists():
+            with open(pyproject_path, "rb") as f:
+                try:
+                    try:
+                        import tomllib  # Python 3.11+
+                    except ModuleNotFoundError:
+                        import tomli as tomllib
+
+                    import tomli_w
+
+                    pyproject = tomllib.load(f)
+                    deps = pyproject.get("project", {}).get("dependencies", [])
+
+                    # Replace openenv/openenv-core with local reference
+                    new_deps = []
+                    for dep in deps:
+                        if dep.startswith("openenv"):
+                            # Skip openenv, openenv-core, and openenv_core;
+                            # we'll use the local core copied above
+                            continue
+                        new_deps.append(dep)
+
+                    # Write back with local core reference
+                    pyproject["project"]["dependencies"] = new_deps + [
+                        "openenv-core @ file:///app/env/openenv"
+                    ]
+
+                    # Write updated pyproject.toml
+                    with open(pyproject_path, "wb") as out_f:
+                        tomli_w.dump(pyproject, out_f)
+
+                    console.print(
+                        "[cyan]Updated pyproject.toml to use local core[/cyan]"
+                    )
+
+                    # Remove old lockfile since dependencies changed
+                    lockfile = build_dir / "uv.lock"
+                    if lockfile.exists():
+                        lockfile.unlink()
+                        console.print("[cyan]Removed outdated uv.lock[/cyan]")
+
+                except ImportError:
+                    console.print(
+                        "[yellow]Warning:[/yellow] tomllib/tomli_w not available, using pyproject.toml as-is",
+                    )
+    else:
+        console.print(
+            "[yellow]Warning:[/yellow] OpenEnv package not found, building without it"
+        )
+
+    console.print(f"[cyan]Build directory prepared:[/cyan] {build_dir}")
+    return build_dir
+
+
+def _run_command(
+    cmd: list[str],
+    cwd: Path | None = None,
+    check: bool = True,
+) -> subprocess.CompletedProcess:
+    """Run a shell command and handle errors."""
+    console.print(f"[bold cyan]Running:[/bold cyan] {' '.join(cmd)}")
+    try:
+        result = subprocess.run(
+            cmd, cwd=cwd, check=check, capture_output=True, text=True
+        )
+        if result.stdout:
+            console.print(result.stdout)
+        if result.stderr:
+            print(result.stderr, file=sys.stderr)
+        return result
+    except subprocess.CalledProcessError as e:
+        print(f"Error running command: {e}", file=sys.stderr)
+        if e.stdout:
+            console.print(e.stdout)
+        if e.stderr:
+            print(e.stderr, file=sys.stderr)
+        if check:
+            raise typer.Exit(1) from e
+        return e
+
+
+def _build_docker_image(
+    env_path: Path,
+    tag: str | None = None,
+    context_path: Path | None = None,
+    dockerfile: Path | None = None,
+    build_args: dict[str, str] | None = None,
+    no_cache: bool = False,
+) -> bool:
+    """Build Docker image for the environment with smart context detection."""
+
+    # Detect build context (standalone vs in-repo)
+    build_mode, detected_context, repo_root = _detect_build_context(env_path)
+
+    console.print(f"[bold cyan]Build mode detected:[/bold cyan] {build_mode}")
+
+    # Use detected context unless explicitly overridden
+    if context_path is None:
+        context_path = detected_context
+
+    # Create temporary build directory
+    with tempfile.TemporaryDirectory() as temp_dir_str:
+        temp_dir = Path(temp_dir_str)
+
+        # Prepare build directory based on mode
+        if build_mode == "standalone":
+            build_dir = _prepare_standalone_build(env_path, temp_dir)
+        else:  # in-repo
+            build_dir = _prepare_inrepo_build(env_path, repo_root, temp_dir)
+
+        # Determine Dockerfile path
+        if dockerfile is None:
+            # Look for Dockerfile in server/ subdirectory
+            dockerfile = build_dir / "server" / "Dockerfile"
+            if not dockerfile.exists():
+                # Fallback to root of build directory
+                dockerfile = build_dir / "Dockerfile"
+
+        if not dockerfile.exists():
+            console.print(
+                f"[bold red]Error:[/bold red] Dockerfile not found at {dockerfile}",
+            )
+            return False
+
+        # Generate tag if not provided
+        if tag is None:
+            env_name = env_path.name
+            if env_name.endswith("_env"):
+                env_name = env_name[:-4]
+            tag = f"openenv-{env_name}"
+
+        console.print(f"[bold cyan]Building Docker image:[/bold cyan] {tag}")
+        console.print(f"[bold cyan]Build context:[/bold cyan] {build_dir}")
+        console.print(f"[bold cyan]Dockerfile:[/bold cyan] {dockerfile}")
+
+        # Prepare build args
+        if build_args is None:
+            build_args = {}
+
+        # Add build mode and env name to build args
+        build_args["BUILD_MODE"] = build_mode
+        build_args["ENV_NAME"] = env_path.name.replace("_env", "")
+
+        # Build Docker command
+        cmd = ["docker", "build", "-t", tag, "-f", str(dockerfile)]
+
+        if no_cache:
+            cmd.append("--no-cache")
+
+        for key, value in build_args.items():
+            cmd.extend(["--build-arg", f"{key}={value}"])
+
+        cmd.append(str(build_dir))
+
+        result = _run_command(cmd, check=False)
+        return result.returncode == 0
+
+
+def _push_docker_image(tag: str, registry: str | None = None) -> bool:
+    """Push Docker image to registry."""
+    if registry:
+        full_tag = f"{registry}/{tag}"
+        console.print(f"[bold cyan]Tagging image as {full_tag}[/bold cyan]")
+        _run_command(["docker", "tag", tag, full_tag])
+        tag = full_tag
+
+    console.print(f"[bold cyan]Pushing image:[/bold cyan] {tag}")
+    result = _run_command(["docker", "push", tag], check=False)
+    return result.returncode == 0
+
+
+@app.command()
+def build(
+    env_path: Annotated[
+        str | None,
+        typer.Argument(
+            help="Path to the environment directory (default: current directory)"
+        ),
+    ] = None,
+    tag: Annotated[
+        str | None,
+        typer.Option(
+            "--tag",
+            "-t",
+            help="Docker image tag (default: openenv-<env-name>)",
+        ),
+    ] = None,
+    context: Annotated[
+        str | None,
+        typer.Option(
+            "--context",
+            "-c",
+            help="Build context path (default: auto-detected; the build runs from a prepared temporary copy of the environment)",
+        ),
+    ] = None,
+    dockerfile: Annotated[
+        str | None,
+        typer.Option(
+            "--dockerfile",
+            "-f",
+            help="Path to Dockerfile (default: server/Dockerfile, falling back to Dockerfile at the environment root)",
+        ),
+    ] = None,
+    no_cache: Annotated[
+        bool,
+        typer.Option(
+            "--no-cache",
+            help="Build without using cache",
+        ),
+    ] = False,
+    build_arg: Annotated[
+        list[str] | None,
+        typer.Option(
+            "--build-arg",
+            help="Build arguments (can be used multiple times, format: KEY=VALUE)",
+        ),
+    ] = None,
+) -> None:
+    """
+    Build Docker images for OpenEnv environments.
+
+    This command builds Docker images using the environment's pyproject.toml
+    and uv for dependency management. Run from the environment root directory.
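+
+    The build context (standalone vs in-repo) is detected automatically: an
+    environment under envs/ inside a git checkout is built in-repo with the
+    local openenv package bundled in; anything else is built standalone.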
+ + Examples: + # Build from environment root (recommended) + $ cd my_env + $ openenv build + + # Build with custom tag + $ openenv build -t my-custom-tag + + # Build without cache + $ openenv build --no-cache + + # Build with custom build arguments + $ openenv build --build-arg VERSION=1.0 --build-arg ENV=prod + + # Build from different directory + $ openenv build envs/echo_env + """ + # Determine environment path (default to current directory) + if env_path is None: + env_path_obj = Path.cwd() + else: + env_path_obj = Path(env_path) + + # Validate environment path + if not env_path_obj.exists(): + print( + f"Error: Environment path does not exist: {env_path_obj}", + file=sys.stderr, + ) + raise typer.Exit(1) + + if not env_path_obj.is_dir(): + print( + f"Error: Environment path is not a directory: {env_path_obj}", + file=sys.stderr, + ) + raise typer.Exit(1) + + # Check for openenv.yaml to confirm this is an environment directory + openenv_yaml = env_path_obj / "openenv.yaml" + if not openenv_yaml.exists(): + print( + f"Error: Not an OpenEnv environment directory (missing openenv.yaml): {env_path_obj}", + file=sys.stderr, + ) + print( + "Hint: Run this command from the environment root directory or specify the path", + file=sys.stderr, + ) + raise typer.Exit(1) + + console.print(f"[bold]Building Docker image for:[/bold] {env_path_obj.name}") + console.print("=" * 60) + + # Parse build args + build_args = {} + if build_arg: + for arg in build_arg: + if "=" in arg: + key, value = arg.split("=", 1) + build_args[key] = value + else: + print( + f"Warning: Invalid build arg format: {arg}", + file=sys.stderr, + ) + + # Convert string paths to Path objects + context_path_obj = Path(context) if context else None + dockerfile_path_obj = Path(dockerfile) if dockerfile else None + + # Build Docker image + success = _build_docker_image( + env_path=env_path_obj, + tag=tag, + context_path=context_path_obj, + dockerfile=dockerfile_path_obj, + build_args=build_args if build_args else None, + no_cache=no_cache, + ) + + if not success: + print("✗ Docker build failed", file=sys.stderr) + raise typer.Exit(1) + + console.print("[bold green]✓ Docker build successful[/bold green]") + console.print("\n[bold green]Done![/bold green]") diff --git a/src/openenv/cli/commands/fork.py b/src/openenv/cli/commands/fork.py new file mode 100644 index 0000000000000000000000000000000000000000..e06f41f1d8606874446f8edc07d675eb4680fa32 --- /dev/null +++ b/src/openenv/cli/commands/fork.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Fork (duplicate) a Hugging Face Space using the Hub API.""" + +from __future__ import annotations + +from typing import Annotated + +import typer +from huggingface_hub import HfApi, login, whoami + +from .._cli_utils import console + +app = typer.Typer( + help="Fork (duplicate) an OpenEnv environment on Hugging Face to your account" +) + + +def _parse_key_value(s: str) -> tuple[str, str]: + """Parse KEY=VALUE string. Raises BadParameter if no '='.""" + if "=" not in s: + raise typer.BadParameter( + f"Expected KEY=VALUE format, got: {s!r}. 
" + "Use --set-env KEY=VALUE or --set-secret KEY=VALUE" + ) + key, _, value = s.partition("=") + key = key.strip() + if not key: + raise typer.BadParameter(f"Empty key in: {s!r}") + return key, value.strip() + + +def _ensure_hf_authenticated() -> str: + """Ensure user is authenticated with Hugging Face. Returns username.""" + try: + user_info = whoami() + if isinstance(user_info, dict): + username = ( + user_info.get("name") + or user_info.get("fullname") + or user_info.get("username") + ) + else: + username = ( + getattr(user_info, "name", None) + or getattr(user_info, "fullname", None) + or getattr(user_info, "username", None) + ) + if not username: + raise ValueError("Could not extract username from whoami response") + console.print(f"[bold green]✓[/bold green] Authenticated as: {username}") + return username + except Exception: + console.print( + "[bold yellow]Not authenticated with Hugging Face. Please login...[/bold yellow]" + ) + try: + login() + user_info = whoami() + if isinstance(user_info, dict): + username = ( + user_info.get("name") + or user_info.get("fullname") + or user_info.get("username") + ) + else: + username = ( + getattr(user_info, "name", None) + or getattr(user_info, "fullname", None) + or getattr(user_info, "username", None) + ) + if not username: + raise ValueError("Could not extract username from whoami response") + console.print(f"[bold green]✓[/bold green] Authenticated as: {username}") + return username + except Exception as e: + raise typer.BadParameter( + f"Hugging Face authentication failed: {e}. Please run login manually." + ) from e + + +@app.command() +def fork( + source_space: Annotated[ + str, + typer.Argument( + help="Source Space ID in format 'owner/space-name' (e.g. org/my-openenv-space)" + ), + ], + repo_id: Annotated[ + str | None, + typer.Option( + "--repo-id", + "-r", + help="Target repo ID for the fork (default: created under your account with same name)", + ), + ] = None, + private: Annotated[ + bool, + typer.Option("--private", help="Create the forked Space as private"), + ] = False, + set_env: Annotated[ + list[str], + typer.Option( + "--set-env", + "-e", + help="Set Space variable (public). Can be repeated. Format: KEY=VALUE", + ), + ] = [], + set_secret: Annotated[ + list[str], + typer.Option( + "--set-secret", + "--secret", + "-s", + help="Set Space secret. Can be repeated. Format: KEY=VALUE", + ), + ] = [], + hardware: Annotated[ + str | None, + typer.Option( + "--hardware", + "-H", + help="Request hardware (e.g. t4-medium, cpu-basic). See Hub docs for options.", + ), + ] = None, +) -> None: + """ + Fork (duplicate) a Hugging Face Space to your account using the Hub API. + + Uses the Hugging Face duplicate_space API. You can set environment variables + and secrets, and request hardware/storage/sleep time at creation time. + + Examples: + $ openenv fork owner/source-space + $ openenv fork owner/source-space --private + $ openenv fork owner/source-space --repo-id myuser/my-fork + $ openenv fork owner/source-space --set-env MODEL_ID=user/model --set-secret HF_TOKEN=hf_xxx + $ openenv fork owner/source-space --hardware t4-medium + """ + if "/" not in source_space or source_space.count("/") != 1: + raise typer.BadParameter( + f"Invalid source Space ID: {source_space!r}. 
Expected format: 'owner/space-name'" + ) + + _ensure_hf_authenticated() + api = HfApi() + + # Build kwargs for duplicate_space (only pass what we have) + dup_kwargs: dict = { + "from_id": source_space, + "private": private, + } + if set_env: + dup_kwargs["variables"] = [ + {"key": k, "value": v} for k, v in (_parse_key_value(x) for x in set_env) + ] + if set_secret: + dup_kwargs["secrets"] = [ + {"key": k, "value": v} for k, v in (_parse_key_value(x) for x in set_secret) + ] + # HF API requires hardware when duplicating; default to free cpu-basic + dup_kwargs["hardware"] = hardware if hardware is not None else "cpu-basic" + if repo_id is not None: + if "/" not in repo_id or repo_id.count("/") != 1: + raise typer.BadParameter( + f"Invalid --repo-id: {repo_id!r}. Expected format: 'username/repo-name'" + ) + dup_kwargs["to_id"] = repo_id + + console.print(f"[bold cyan]Forking Space {source_space}...[/bold cyan]") + try: + result = api.duplicate_space(**dup_kwargs) + except Exception as e: + console.print(f"[bold red]✗[/bold red] Fork failed: {e}") + raise typer.Exit(1) from e + + # result is RepoUrl (str-like) or similar; get repo_id for display + if hasattr(result, "repo_id"): + new_repo_id = result.repo_id + elif isinstance(result, str): + # URL like https://huggingface.co/spaces/owner/name -> owner/name + if "/spaces/" in result: + new_repo_id = result.split("/spaces/")[-1].rstrip("/") + else: + new_repo_id = result + else: + new_repo_id = getattr(result, "repo_id", str(result)) + + console.print("[bold green]✓[/bold green] Space forked successfully") + console.print( + f"[bold]Space URL:[/bold] https://huggingface.co/spaces/{new_repo_id}" + ) diff --git a/src/openenv/cli/commands/init.py b/src/openenv/cli/commands/init.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf0fc7168f109657157f3d71600857e5b91f37e --- /dev/null +++ b/src/openenv/cli/commands/init.py @@ -0,0 +1,500 @@ +"""Initialize a new OpenEnv environment.""" + +from __future__ import annotations + +import random +import shutil +import subprocess +from importlib import resources +from pathlib import Path +from typing import Annotated, Dict, List, Tuple + +import typer + +from .._cli_utils import console + +app = typer.Typer(help="Initialize a new OpenEnv environment") + + +def _snake_to_pascal(snake_str: str) -> str: + """Convert snake_case to PascalCase (e.g., 'my_env' -> 'MyEnv').""" + return "".join(word.capitalize() for word in snake_str.split("_")) + + +def _get_env_prefix(env_name: str) -> str: + """Extract the prefix for class names (e.g., 'my_env' -> 'My', 'test_env' -> 'Test').""" + # Remove trailing '_env' if present + if env_name.endswith("_env"): + base = env_name[:-4] # Remove '_env' + else: + base = env_name + + # If empty or just one part, use the whole thing + if not base or "_" not in base: + return base.capitalize() if base else env_name.capitalize() + + # PascalCase all parts except the last + parts = base.split("_") + return "".join(word.capitalize() for word in parts) + + +def _snake_to_camel(snake_str: str) -> str: + """Convert snake_case to camelCase (e.g., 'my_env' -> 'myEnv').""" + parts = snake_str.split("_") + return parts[0] + "".join(word.capitalize() for word in parts[1:]) + + +def _snake_to_title(snake_str: str) -> str: + """Convert snake_case to Title Case (e.g., 'my_env' -> 'My Env').""" + return " ".join(word.capitalize() for word in snake_str.split("_")) + + +def _validate_env_name(name: str) -> str: + """Validate environment name (must be valid Python identifier in 
snake_case).""" + if not name: + raise typer.BadParameter("Environment name cannot be empty") + + # Check if it's a valid Python identifier + if not name.isidentifier(): + raise typer.BadParameter( + f"Environment name '{name}' is not a valid Python identifier. Use snake_case (e.g., 'my_env', 'game_env')." + ) + + # Check if it starts with a number + if name[0].isdigit(): + raise typer.BadParameter( + f"Environment name '{name}' cannot start with a number." + ) + + return name + + +def _get_random_hf_space_config() -> Dict[str, str]: + """ + Get random Hugging Face Space configuration values. + + Returns: + Dictionary with 'emoji', 'colorFrom', and 'colorTo' keys + """ + # Valid emojis (emoji-only characters) + emojis = [ + "🎮", + "🎯", + "🚀", + "🌟", + "🎨", + "🎪", + "🎭", + "🎬", + "🎤", + "🎧", + "🎵", + "🎶", + "🎸", + "🎹", + "🥁", + "🎺", + "🎻", + "🎼", + "🎯", + "🎲", + "🎳", + "🎰", + "🎴", + "🃏", + "🀄", + "🎴", + "🎨", + "🖼️", + "🎬", + "🎭", + "🎪", + "🎤", + "🎧", + "🎵", + "🎶", + "🎸", + "🎹", + "🎺", + "🎻", + "🥁", + "🎯", + "🎲", + "🎳", + "🎰", + "🏀", + "⚽", + "🏈", + "⚾", + "🎾", + "🏐", + "🏉", + "🎱", + "🏓", + "🏸", + "🥅", + "🏒", + "🏑", + "🏏", + "⛳", + "🏹", + "🎣", + "🥊", + "🥋", + "🎽", + "🏅", + "🎖️", + "🏆", + "🥇", + "🥈", + "🥉", + "🔊", + "🔉", + "🔈", + "🔇", + "📢", + "📣", + "📯", + "🔔", + "🔕", + "📻", + "📡", + "💻", + "🖥️", + "🖨️", + "⌨️", + "🖱️", + "🖲️", + "🕹️", + "🗜️", + "💾", + "💿", + "📀", + "📼", + "📷", + "📸", + "📹", + "🎥", + "📽️", + "🎞️", + "📞", + "☎️", + "📟", + "📠", + "📺", + "📻", + "🎙️", + "🎚️", + "🎛️", + "⏱️", + "⏲️", + "⏰", + "🕰️", + "⌚", + "📱", + "📲", + "💻", + "⌨️", + "🖥️", + "🖨️", + "🖱️", + ] + + # Valid colors from HF Spaces config reference + colors = ["red", "yellow", "green", "blue", "indigo", "purple", "pink", "gray"] + + return { + "emoji": random.choice(emojis), + "colorFrom": random.choice(colors), + "colorTo": random.choice(colors), + } + + +def _create_template_replacements(env_name: str) -> Dict[str, str]: + """ + Create comprehensive template replacement dictionary. 
+
+    Supports all naming conventions:
+    - PascalCase for class names
+    - camelCase for variable names
+    - snake_case for module names, file paths
+    """
+    env_prefix = _get_env_prefix(env_name)
+    env_camel = _snake_to_camel(env_name)
+    env_title = _snake_to_title(env_name)
+
+    # Get random HF Space config values
+    hf_config = _get_random_hf_space_config()
+
+    replacements = {
+        # Template placeholders (MUST come first - full class names before partial)
+        "__ENV_CLASS_NAME__Environment": f"{env_prefix}Environment",
+        "__ENV_CLASS_NAME__Action": f"{env_prefix}Action",
+        "__ENV_CLASS_NAME__Observation": f"{env_prefix}Observation",
+        "__ENV_CLASS_NAME__Env": f"{env_prefix}Env",
+        # Template placeholders (partial - must come after full replacements)
+        "__ENV_NAME__": env_name,
+        "__ENV_CLASS_NAME__": env_prefix,  # Use prefix, not full PascalCase
+        "__ENV_TITLE_NAME__": env_title,
+        "__ENV_CAMEL_NAME__": env_camel,
+        # Hugging Face Space config placeholders
+        "__HF_EMOJI__": hf_config["emoji"],
+        "__HF_COLOR_FROM__": hf_config["colorFrom"],
+        "__HF_COLOR_TO__": hf_config["colorTo"],
+    }
+
+    return replacements
+
+
+def _replace_in_content(content: str, replacements: Dict[str, str]) -> str:
+    """Replace all occurrences in content using case-sensitive replacements."""
+    result = content
+    # Sort by length (longest first) to avoid partial replacements
+    for old, new in sorted(replacements.items(), key=lambda x: len(x[0]), reverse=True):
+        result = result.replace(old, new)
+    return result
+
+
+def _should_rename_file(filename: str, env_name: str) -> Tuple[bool, str]:
+    """
+    Check if a file should be renamed and return the new name.
+
+    Handles template placeholders in filenames like:
+    - `__ENV_NAME___environment.py` → `my_env_environment.py` (for
+      env_name "my_env")
+    """
+    # Check for template placeholder
+    if "__ENV_NAME__" in filename:
+        new_name = filename.replace("__ENV_NAME__", env_name)
+        return True, new_name
+
+    return False, filename
+
+
+def _copy_and_template_file(
+    src_path: Path,
+    dest_path: Path,
+    replacements: Dict[str, str],
+) -> None:
+    """Copy a file and apply template replacements."""
+    dest_path.parent.mkdir(parents=True, exist_ok=True)
+
+    try:
+        # Read source file
+        content = src_path.read_bytes()
+
+        # Try to decode as text and apply replacements
+        try:
+            text = content.decode("utf-8")
+            # Normalize line endings to LF before applying replacements
+            text = text.replace("\r\n", "\n").replace("\r", "\n")
+            text = _replace_in_content(text, replacements)
+            dest_path.write_text(text, encoding="utf-8", newline="\n")
+        except UnicodeDecodeError:
+            # Binary file, just copy
+            dest_path.write_bytes(content)
+    except Exception as e:
+        raise RuntimeError(
+            f"Failed to copy template file {src_path} to {dest_path}: {e}"
+        ) from e
+
+
+def _copy_template_directory(
+    template_pkg: str,
+    template_dir: str,
+    dest_dir: Path,
+    replacements: Dict[str, str],
+    env_name: str,
+) -> List[Path]:
+    """Recursively copy template directory and apply replacements."""
+    created_files: List[Path] = []
+
+    # Get the package path using importlib.resources but avoid importing the template package
+    # We'll use the package's __file__ to get the directory path
+    import importlib
+
+    try:
+        # Import the parent package (not the template package itself)
+        if "." in template_pkg:
+            parent_pkg = ".".join(template_pkg.split(".")[:-1])
+            pkg = importlib.import_module(parent_pkg)
+            template_path = Path(pkg.__file__).parent / template_pkg.split(".")[-1]
+        else:
+            # Top-level package: the template directory is the package itself
+            pkg = importlib.import_module(template_pkg)
+            template_path = Path(pkg.__file__).parent
+    except Exception:
+        # Fallback: try to use resources.files but handle import errors
+        try:
+            base = resources.files(template_pkg.split(".")[0])
+            template_path = base.joinpath(*template_pkg.split(".")[1:])
+            if not template_path.exists():
+                raise FileNotFoundError(f"Template directory not found: {template_pkg}")
+        except Exception as e:
+            raise FileNotFoundError(
+                f"Template directory not found: {template_pkg}"
+            ) from e
+
+    if template_dir:
+        template_path = template_path / template_dir
+
+    if not template_path.exists() or not template_path.is_dir():
+        raise FileNotFoundError(
+            f"Template directory not found: {template_pkg}.{template_dir}"
+        )
+
+    # Walk through all files in template directory using Path
+    for item in template_path.rglob("*"):
+        if item.is_file():
+            rel_path = item.relative_to(template_path)
+            dest_path = dest_dir / rel_path
+
+            # Apply filename templating
+            should_rename, new_name = _should_rename_file(dest_path.name, env_name)
+            if should_rename:
+                dest_path = dest_path.parent / new_name
+
+            # Copy and apply replacements
+            _copy_and_template_file(item, dest_path, replacements)
+            created_files.append(dest_path)
+
+    return created_files
+
+
+def _generate_uv_lock(env_dir: Path) -> bool:
+    """Generate uv.lock from pyproject.toml using uv."""
+    pyproject_path = env_dir / "pyproject.toml"
+
+    if not pyproject_path.exists():
+        return False
+
+    try:
+        cmd = [
+            "uv",
+            "lock",
+            "--directory",
+            str(env_dir),
+        ]
+
+        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
+
+        if result.stdout:
+            console.print(result.stdout)
+
+        return True
+
+    except subprocess.CalledProcessError as e:
+        console.print(
+            f"[yellow]Warning: Could not generate uv.lock: {e.stderr}[/yellow]"
+        )
+        return False
+    except FileNotFoundError:
+        console.print(
+            "[yellow]Warning: 'uv' not found. Install it to generate uv.lock[/yellow]"
+        )
+        return False
+
+
+@app.command()
+def init(
+    env_name: Annotated[
+        str,
+        typer.Argument(
+            help="Name of the environment to create (snake_case, e.g., 'my_env')"
+        ),
+    ],
+    output_dir: Annotated[
+        str | None,
+        typer.Option(
+            "--output-dir",
+            "-o",
+            help="Output directory (defaults to current working directory)",
+        ),
+    ] = None,
+) -> None:
+    """
+    Initialize a new OpenEnv environment.
+
+    Creates a new directory with the environment name and generates all necessary
+    files based on the OpenEnv template structure.
+
+    Example:
+        $ openenv init my_game_env
+        $ openenv init my_env --output-dir /path/to/projects
+    """
+    # Validate environment name
+    env_name = _validate_env_name(env_name)
+
+    # Determine output directory
+    base_dir = Path(output_dir).resolve() if output_dir else Path.cwd().resolve()
+    env_dir = base_dir / env_name
+
+    # Check if directory already exists
+    if env_dir.exists():
+        if env_dir.is_file():
+            raise typer.BadParameter(f"Path '{env_dir}' exists and is a file")
+        if any(env_dir.iterdir()):
+            raise typer.BadParameter(
+                f"Directory '{env_dir}' already exists and is not empty. "
+                "Please choose a different name or remove the existing directory."
+ ) + + try: + # Create template replacements + replacements = _create_template_replacements(env_name) + + # Create environment directory + env_dir.mkdir(parents=True, exist_ok=True) + + console.print( + f"[bold cyan]Creating OpenEnv environment '{env_name}'...[/bold cyan]" + ) + + # Copy template files from template structure + template_pkg = "openenv.cli.templates.openenv_env" + created_files = _copy_template_directory( + template_pkg, + "", + env_dir, + replacements, + env_name, + ) + + console.print(f"[bold green]✓[/bold green] Created {len(created_files)} files") + + # Generate uv.lock + console.print("\n[bold]Generating uv.lock...[/bold]") + if _generate_uv_lock(env_dir): + console.print("[green]✓[/green] Generated uv.lock") + else: + console.print("[yellow]⚠[/yellow] Could not generate uv.lock automatically") + console.print(" You can generate it manually with:") + console.print(f" cd {env_dir} && uv lock") + + console.print( + f"\n[bold green]Environment created successfully at: {env_dir}[/bold green]" + ) + console.print("\n[bold]Next steps:[/bold]") + console.print(f" cd {env_dir}") + console.print( + f" # Edit your environment implementation in server/{env_name}_environment.py" + ) + console.print(" # Edit your models in models.py") + console.print(" # Install dependencies: uv sync") + console.print("\n # To integrate into OpenEnv repo:") + console.print(f" # 1. Copy this directory to /envs/{env_name}_env") + console.print( + f" # 2. Build from repo root: docker build -t {env_name}_env:latest -f envs/{env_name}_env/server/Dockerfile ." + ) + console.print( + f" # 3. Run your image: docker run -p 8000:8000 {env_name}_env:latest" + ) + + except Exception as e: + # Cleanup on error + if env_dir.exists() and env_dir.is_dir(): + try: + shutil.rmtree(env_dir) + except Exception: + pass + + console.print(f"[bold red]Error:[/bold red] {e}") + raise typer.Exit(1) from e diff --git a/src/openenv/cli/commands/push.py b/src/openenv/cli/commands/push.py new file mode 100644 index 0000000000000000000000000000000000000000..beb571c239734d8f826a42e91f34cdd5845a44ff --- /dev/null +++ b/src/openenv/cli/commands/push.py @@ -0,0 +1,718 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Push an OpenEnv environment to Hugging Face Spaces.""" + +from __future__ import annotations + +import shutil +import sys +import tempfile +from fnmatch import fnmatch +from pathlib import Path +from typing import Annotated + +import typer +import yaml +from huggingface_hub import HfApi, login, whoami + +from .._cli_utils import console, validate_env_structure + +app = typer.Typer(help="Push an OpenEnv environment to Hugging Face Spaces") + + +DEFAULT_PUSH_IGNORE_PATTERNS = [".*", "__pycache__", "*.pyc"] + + +def _path_matches_pattern(relative_path: Path, pattern: str) -> bool: + """Return True if a relative path matches an exclude pattern.""" + normalized_pattern = pattern.strip() + if normalized_pattern.startswith("!"): + return False + + while normalized_pattern.startswith("./"): + normalized_pattern = normalized_pattern[2:] + + if normalized_pattern.startswith("/"): + normalized_pattern = normalized_pattern[1:] + + if not normalized_pattern: + return False + + posix_path = relative_path.as_posix() + pattern_candidates = [normalized_pattern] + if normalized_pattern.startswith("**/"): + # Gitignore-style "**/" can also match directly at the root. 
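+        # e.g. the pattern "**/outputs/" should exclude both "outputs/foo.txt"
+        # and "runs/outputs/foo.txt" (illustrative paths).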
+ pattern_candidates.append(normalized_pattern[3:]) + + # Support directory patterns such as "artifacts/" and "**/outputs/". + if normalized_pattern.endswith("/"): + dir_pattern_candidates: list[str] = [] + for candidate in pattern_candidates: + base = candidate.rstrip("/") + if not base: + continue + dir_pattern_candidates.extend([base, f"{base}/*"]) + + return any( + fnmatch(posix_path, candidate) for candidate in dir_pattern_candidates + ) + + # Match both full relative path and basename for convenience. + return any( + fnmatch(posix_path, candidate) for candidate in pattern_candidates + ) or any(fnmatch(relative_path.name, candidate) for candidate in pattern_candidates) + + +def _should_exclude_path(relative_path: Path, ignore_patterns: list[str]) -> bool: + """Return True when the path should be excluded from staging/upload.""" + return any( + _path_matches_pattern(relative_path, pattern) for pattern in ignore_patterns + ) + + +def _read_ignore_file(ignore_path: Path) -> tuple[list[str], int]: + """Read ignore patterns from a file and return (patterns, ignored_negations).""" + patterns: list[str] = [] + ignored_negations = 0 + + for line in ignore_path.read_text().splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + if stripped.startswith("!"): + ignored_negations += 1 + continue + patterns.append(stripped) + + return patterns, ignored_negations + + +def _load_ignore_patterns(env_dir: Path, exclude_file: str | None) -> list[str]: + """Load ignore patterns from defaults and an optional ignore file.""" + patterns = list(DEFAULT_PUSH_IGNORE_PATTERNS) + ignored_negations = 0 + + def _merge_ignore_file(ignore_path: Path, *, source_label: str) -> None: + nonlocal ignored_negations + file_patterns, skipped_negations = _read_ignore_file(ignore_path) + patterns.extend(file_patterns) + ignored_negations += skipped_negations + console.print( + f"[bold green]✓[/bold green] Loaded {len(file_patterns)} ignore patterns from {source_label}: {ignore_path}" + ) + + # Optional source: explicit exclude file from CLI. + if exclude_file: + ignore_path = Path(exclude_file) + if not ignore_path.is_absolute(): + ignore_path = env_dir / ignore_path + ignore_path = ignore_path.resolve() + + if not ignore_path.exists() or not ignore_path.is_file(): + raise typer.BadParameter( + f"Exclude file not found or not a file: {ignore_path}" + ) + + _merge_ignore_file(ignore_path, source_label="--exclude") + + # Keep stable order while removing duplicates. + patterns = list(dict.fromkeys(patterns)) + + if ignored_negations > 0: + console.print( + f"[bold yellow]⚠[/bold yellow] Skipped {ignored_negations} negated ignore patterns ('!') because negation is not supported for push excludes" + ) + + return patterns + + +def _copytree_ignore_factory(env_dir: Path, ignore_patterns: list[str]): + """Build a shutil.copytree ignore callback from path-based patterns.""" + + def _ignore(path: str, names: list[str]) -> set[str]: + current_dir = Path(path) + ignored: set[str] = set() + + for name in names: + candidate = current_dir / name + try: + relative_path = candidate.relative_to(env_dir) + except ValueError: + # candidate is not under env_dir (e.g. symlink or + # copytree root differs from env_dir); skip filtering. + continue + if _should_exclude_path(relative_path, ignore_patterns): + ignored.add(name) + + return ignored + + return _ignore + + +def _validate_openenv_directory(directory: Path) -> tuple[str, dict]: + """ + Validate that the directory is an OpenEnv environment. 
+ + Returns: + Tuple of (env_name, manifest_data) + """ + # Use the comprehensive validation function + try: + warnings = validate_env_structure(directory) + for warning in warnings: + console.print(f"[bold yellow]⚠[/bold yellow] {warning}") + except FileNotFoundError as e: + raise typer.BadParameter(f"Invalid OpenEnv environment structure: {e}") from e + + # Load and validate manifest + manifest_path = directory / "openenv.yaml" + try: + with open(manifest_path, "r") as f: + manifest = yaml.safe_load(f) + except Exception as e: + raise typer.BadParameter(f"Failed to parse openenv.yaml: {e}") from e + + if not isinstance(manifest, dict): + raise typer.BadParameter("openenv.yaml must be a YAML dictionary") + + env_name = manifest.get("name") + if not env_name: + raise typer.BadParameter("openenv.yaml must contain a 'name' field") + + return env_name, manifest + + +def _ensure_hf_authenticated() -> str: + """ + Ensure user is authenticated with Hugging Face. + + Returns: + Username of authenticated user + """ + try: + # Try to get current user + user_info = whoami() + # Handle both dict and object return types + if isinstance(user_info, dict): + username = ( + user_info.get("name") + or user_info.get("fullname") + or user_info.get("username") + ) + else: + # If it's an object, try to get name attribute + username = ( + getattr(user_info, "name", None) + or getattr(user_info, "fullname", None) + or getattr(user_info, "username", None) + ) + + if not username: + raise ValueError("Could not extract username from whoami response") + + console.print(f"[bold green]✓[/bold green] Authenticated as: {username}") + return username + except Exception: + # Not authenticated, prompt for login + console.print( + "[bold yellow]Not authenticated with Hugging Face. Please login...[/bold yellow]" + ) + + try: + login() + # Verify login worked + user_info = whoami() + # Handle both dict and object return types + if isinstance(user_info, dict): + username = ( + user_info.get("name") + or user_info.get("fullname") + or user_info.get("username") + ) + else: + username = ( + getattr(user_info, "name", None) + or getattr(user_info, "fullname", None) + or getattr(user_info, "username", None) + ) + + if not username: + raise ValueError("Could not extract username from whoami response") + + console.print(f"[bold green]✓[/bold green] Authenticated as: {username}") + return username + except Exception as e: + raise typer.BadParameter( + f"Hugging Face authentication failed: {e}. Please run login manually." + ) from e + + +def _prepare_staging_directory( + env_dir: Path, + env_name: str, + staging_dir: Path, + ignore_patterns: list[str], + base_image: str | None = None, + enable_interface: bool = True, +) -> None: + """ + Prepare files for deployment. + + This includes: + - Copying necessary files + - Modifying Dockerfile to optionally enable web interface and update base image + - Ensuring README has proper HF frontmatter (if interface enabled) + """ + # Create staging directory structure + staging_dir.mkdir(parents=True, exist_ok=True) + + # Copy all files from env directory + copy_ignore = _copytree_ignore_factory(env_dir, ignore_patterns) + for item in env_dir.iterdir(): + relative_path = item.relative_to(env_dir) + if _should_exclude_path(relative_path, ignore_patterns): + continue + + dest = staging_dir / item.name + if item.is_dir(): + shutil.copytree(item, dest, dirs_exist_ok=True, ignore=copy_ignore) + else: + shutil.copy2(item, dest) + + # Dockerfile must be at repo root for Hugging Face. 
Prefer root if present
+    # (it was copied there); otherwise move server/Dockerfile to root.
+    dockerfile_server_path = staging_dir / "server" / "Dockerfile"
+    dockerfile_root_path = staging_dir / "Dockerfile"
+    dockerfile_path: Path | None = None
+
+    if dockerfile_root_path.exists():
+        dockerfile_path = dockerfile_root_path
+    elif dockerfile_server_path.exists():
+        dockerfile_server_path.rename(dockerfile_root_path)
+        console.print(
+            "[bold cyan]Moved Dockerfile to repository root for deployment[/bold cyan]"
+        )
+        dockerfile_path = dockerfile_root_path
+
+    # Modify Dockerfile to optionally enable web interface and update base image
+    if dockerfile_path and dockerfile_path.exists():
+        dockerfile_content = dockerfile_path.read_text()
+        lines = dockerfile_content.split("\n")
+        new_lines = []
+        cmd_found = False
+        base_image_updated = False
+        web_interface_env_exists = "ENABLE_WEB_INTERFACE" in dockerfile_content
+        last_instruction = None
+
+        for line in lines:
+            stripped = line.strip()
+            token = stripped.split(maxsplit=1)[0] if stripped else ""
+            current_instruction = token.upper()
+
+            is_healthcheck_continuation = last_instruction == "HEALTHCHECK"
+
+            # Update base image if specified
+            if base_image and stripped.startswith("FROM") and not base_image_updated:
+                new_lines.append(f"FROM {base_image}")
+                base_image_updated = True
+                last_instruction = "FROM"
+                continue
+
+            if (
+                stripped.startswith("CMD")
+                and not cmd_found
+                and not web_interface_env_exists
+                and enable_interface
+                and not is_healthcheck_continuation
+            ):
+                new_lines.append("ENV ENABLE_WEB_INTERFACE=true")
+                cmd_found = True
+
+            new_lines.append(line)
+
+            if current_instruction:
+                last_instruction = current_instruction
+
+        if not cmd_found and not web_interface_env_exists and enable_interface:
+            new_lines.append("ENV ENABLE_WEB_INTERFACE=true")
+
+        if base_image and not base_image_updated:
+            new_lines.insert(0, f"FROM {base_image}")
+
+        dockerfile_path.write_text("\n".join(new_lines))
+
+        changes = []
+        if base_image and base_image_updated:
+            changes.append("updated base image")
+        if enable_interface and not web_interface_env_exists:
+            changes.append("enabled web interface")
+        if changes:
+            console.print(
+                f"[bold green]✓[/bold green] Updated Dockerfile: {', '.join(changes)}"
+            )
+    else:
+        console.print(
+            "[bold yellow]⚠[/bold yellow] No Dockerfile at server/ or repo root"
+        )
+
+    # Ensure README has proper HF frontmatter (only if interface enabled)
+    if enable_interface:
+        readme_path = staging_dir / "README.md"
+        if readme_path.exists():
+            readme_content = readme_path.read_text()
+            if "base_path: /web" not in readme_content:
+                # Check if frontmatter exists
+                if readme_content.startswith("---"):
+                    # Add base_path to existing frontmatter
+                    lines = readme_content.split("\n")
+                    new_lines = []
+                    _in_frontmatter = True
+                    for i, line in enumerate(lines):
+                        new_lines.append(line)
+                        if line.strip() == "---" and i > 0:
+                            # End of frontmatter, add base_path before this line
+                            if "base_path:" not in "\n".join(new_lines):
+                                new_lines.insert(-1, "base_path: /web")
+                            _in_frontmatter = False
+                    readme_path.write_text("\n".join(new_lines))
+                else:
+                    # No frontmatter, add it (colorFrom/colorTo must be one of
+                    # the named colors HF Spaces accepts, not hex values)
+                    frontmatter = f"""---
+title: {env_name.replace("_", " ").title()} Environment Server
+emoji: 🔊
+colorFrom: blue
+colorTo: indigo
+sdk: docker
+pinned: false
+app_port: 8000
+base_path: /web
+tags:
+  - openenv
+---
+
+"""
+                    readme_path.write_text(frontmatter + readme_content)
+                console.print(
+                    "[bold green]✓[/bold green] Updated README with HF Space frontmatter"
+                )
+        else:
+            console.print("[bold yellow]⚠[/bold yellow] No README.md found")
+
+
+def _create_hf_space(
+    repo_id: str,
+    api: HfApi,
+    private: bool = False,
+) -> None:
+    """Create a Hugging Face Space if it doesn't exist."""
+    console.print(f"[bold cyan]Creating/verifying space: {repo_id}[/bold cyan]")
+
+    try:
+        api.create_repo(
+            repo_id=repo_id,
+            repo_type="space",
+            space_sdk="docker",
+            private=private,
+            exist_ok=True,
+        )
+        console.print(f"[bold green]✓[/bold green] Space {repo_id} is ready")
+    except Exception as e:
+        # Space might already exist, which is okay with exist_ok=True
+        # But if there's another error, log it
+        console.print(f"[bold yellow]⚠[/bold yellow] Space creation: {e}")
+
+
+def _upload_to_hf_space(
+    repo_id: str,
+    staging_dir: Path,
+    api: HfApi,
+    ignore_patterns: list[str],
+    private: bool = False,
+    create_pr: bool = False,
+    commit_message: str | None = None,
+) -> None:
+    """Upload files to Hugging Face Space."""
+    if create_pr:
+        console.print(
+            f"[bold cyan]Uploading files to {repo_id} (will open a Pull Request)...[/bold cyan]"
+        )
+    else:
+        console.print(f"[bold cyan]Uploading files to {repo_id}...[/bold cyan]")
+
+    upload_kwargs: dict = {
+        "folder_path": str(staging_dir),
+        "repo_id": repo_id,
+        "repo_type": "space",
+        "create_pr": create_pr,
+        "ignore_patterns": ignore_patterns,
+    }
+    if commit_message:
+        upload_kwargs["commit_message"] = commit_message
+
+    try:
+        result = api.upload_folder(**upload_kwargs)
+        console.print("[bold green]✓[/bold green] Upload completed successfully")
+        if create_pr and result is not None and hasattr(result, "pr_url"):
+            console.print(f"[bold]Pull request:[/bold] {result.pr_url}")
+        console.print(
+            f"[bold]Space URL:[/bold] https://huggingface.co/spaces/{repo_id}"
+        )
+    except Exception as e:
+        console.print(f"[bold red]✗[/bold red] Upload failed: {e}")
+        raise typer.Exit(1) from e
+
+
+@app.command()
+def push(
+    directory: Annotated[
+        str | None,
+        typer.Argument(
+            help="Directory containing the OpenEnv environment (default: current directory)"
+        ),
+    ] = None,
+    repo_id: Annotated[
+        str | None,
+        typer.Option(
+            "--repo-id",
+            "-r",
+            help="Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)",
+        ),
+    ] = None,
+    base_image: Annotated[
+        str | None,
+        typer.Option(
+            "--base-image",
+            "-b",
+            help="Base Docker image to use (overrides Dockerfile FROM)",
+        ),
+    ] = None,
+    interface: Annotated[
+        bool | None,
+        typer.Option(
+            "--interface",
+            help="Enable web interface (default: True if no registry specified)",
+        ),
+    ] = None,
+    no_interface: Annotated[
+        bool,
+        typer.Option(
+            "--no-interface",
+            help="Disable web interface",
+        ),
+    ] = False,
+    registry: Annotated[
+        str | None,
+        typer.Option(
+            "--registry",
+            help="Custom registry URL (e.g., docker.io/username). Disables web interface by default.",
+        ),
+    ] = None,
+    private: Annotated[
+        bool,
+        typer.Option(
+            "--private",
+            help="Deploy the space as private",
+        ),
+    ] = False,
+    create_pr: Annotated[
+        bool,
+        typer.Option(
+            "--create-pr",
+            help="Create a Pull Request instead of pushing to the default branch",
+        ),
+    ] = False,
+    exclude: Annotated[
+        str | None,
+        typer.Option(
+            "--exclude",
+            help="Optional additional ignore file with newline-separated glob patterns to exclude from Hugging Face uploads",
+        ),
+    ] = None,
+) -> None:
+    """
+    Push an OpenEnv environment to Hugging Face Spaces or a custom Docker registry.
+
+    This command:
+    1. Validates that the directory is an OpenEnv environment (openenv.yaml present)
+    2. Builds and pushes to Hugging Face Spaces or custom Docker registry
+    3. Optionally enables web interface for deployment
+
+    The web interface is enabled by default when pushing to HuggingFace Spaces,
+    but disabled by default when pushing to a custom Docker registry.
+
+    Examples:
+        # Push to HuggingFace Spaces from current directory (web interface enabled)
+        $ cd my_env
+        $ openenv push
+
+        # Push to HuggingFace repo and open a Pull Request
+        $ openenv push --repo-id my-org/my-env --create-pr
+
+        # Push to HuggingFace without web interface
+        $ openenv push --no-interface
+
+        # Push to Docker Hub
+        $ openenv push --registry docker.io/myuser
+
+        # Push to GitHub Container Registry
+        $ openenv push --registry ghcr.io/myorg
+
+        # Push to custom registry with web interface
+        $ openenv push --registry myregistry.io/path1/path2 --interface
+
+        # Push to specific HuggingFace repo
+        $ openenv push --repo-id my-org/my-env
+
+        # Push privately with custom base image
+        $ openenv push --private --base-image ghcr.io/meta-pytorch/openenv-base:latest
+    """
+    # Handle interface flag logic
+    if no_interface and interface:
+        console.print(
+            "[bold red]Error:[/bold red] Cannot specify both --interface and --no-interface"
+        )
+        raise typer.Exit(1)
+
+    # Determine if web interface should be enabled
+    if no_interface:
+        enable_interface = False
+    elif interface is not None:
+        enable_interface = interface
+    elif registry is not None:
+        # Custom registry: disable interface by default
+        enable_interface = False
+    else:
+        # HuggingFace: enable interface by default
+        enable_interface = True
+
+    # Determine directory
+    if directory:
+        env_dir = Path(directory).resolve()
+    else:
+        env_dir = Path.cwd().resolve()
+
+    if not env_dir.exists() or not env_dir.is_dir():
+        raise typer.BadParameter(f"Directory does not exist: {env_dir}")
+
+    # Check for openenv.yaml to confirm this is an environment directory
+    openenv_yaml = env_dir / "openenv.yaml"
+    if not openenv_yaml.exists():
+        console.print(
+            f"[bold red]Error:[/bold red] Not an OpenEnv environment directory (missing openenv.yaml): {env_dir}",
+        )
+        console.print(
+            "[yellow]Hint:[/yellow] Run this command from the environment root directory",
+        )
+        raise typer.Exit(1)
+
+    # Validate OpenEnv environment
+    console.print(
+        f"[bold cyan]Validating OpenEnv environment in {env_dir}...[/bold cyan]"
+    )
+    env_name, manifest = _validate_openenv_directory(env_dir)
+    console.print(f"[bold green]✓[/bold green] Found OpenEnv environment: {env_name}")
+
+    # Handle custom registry push
+    if registry:
+        console.print("[bold cyan]Preparing to push to custom registry...[/bold cyan]")
+        if enable_interface:
+            console.print("[bold cyan]Web interface will be enabled[/bold cyan]")
+
+        # Import build functions
+        from .build import _build_docker_image, _push_docker_image
+
+        # Prepare build args for custom registry deployment
+        build_args = {}
+        if enable_interface:
+            build_args["ENABLE_WEB_INTERFACE"] = "true"
+
+        # Build Docker image from the environment directory
+        tag = f"{registry}/{env_name}"
+        console.print(f"[bold cyan]Building Docker image: {tag}[/bold cyan]")
+
+        success = _build_docker_image(
+            env_path=env_dir,
+            tag=tag,
+            build_args=build_args if build_args else None,
+        )
+
+        if not success:
+            console.print("[bold red]✗ Docker build failed[/bold red]")
+            raise typer.Exit(1)
+
+        console.print("[bold green]✓ Docker build 
successful[/bold green]") + + # Push to registry + console.print(f"[bold cyan]Pushing to registry: {registry}[/bold cyan]") + + success = _push_docker_image( + tag, registry=None + ) # Tag already includes registry + + if not success: + console.print("[bold red]✗ Docker push failed[/bold red]") + raise typer.Exit(1) + + console.print("\n[bold green]✓ Deployment complete![/bold green]") + console.print(f"[bold]Image:[/bold] {tag}") + return + + ignore_patterns = _load_ignore_patterns(env_dir, exclude) + + # Ensure authentication for HuggingFace + username = _ensure_hf_authenticated() + + # Determine repo_id + if not repo_id: + repo_id = f"{username}/{env_name}" + + # Validate repo_id format + if "/" not in repo_id or repo_id.count("/") != 1: + raise typer.BadParameter( + f"Invalid repo-id format: {repo_id}. Expected format: 'username/repo-name'" + ) + + # Initialize Hugging Face API + api = HfApi() + + # Prepare staging directory + deployment_type = ( + "with web interface" if enable_interface else "without web interface" + ) + console.print( + f"[bold cyan]Preparing files for Hugging Face deployment ({deployment_type})...[/bold cyan]" + ) + with tempfile.TemporaryDirectory() as tmpdir: + staging_dir = Path(tmpdir) / "staging" + _prepare_staging_directory( + env_dir, + env_name, + staging_dir, + ignore_patterns=ignore_patterns, + base_image=base_image, + enable_interface=enable_interface, + ) + + # Create/verify space (no-op if exists; needed when pushing to own new repo) + if not create_pr: + _create_hf_space(repo_id, api, private=private) + # When create_pr we rely on upload_folder to create branch and PR + + # Upload files + _upload_to_hf_space( + repo_id, + staging_dir, + api, + private=private, + create_pr=create_pr, + ignore_patterns=ignore_patterns, + ) + + console.print("\n[bold green]✓ Deployment complete![/bold green]") + console.print(f"Visit your space at: https://huggingface.co/spaces/{repo_id}") diff --git a/src/openenv/cli/commands/serve.py b/src/openenv/cli/commands/serve.py new file mode 100644 index 0000000000000000000000000000000000000000..df2bfa5a34d83e07ea35e5df06e523d1c565cbc5 --- /dev/null +++ b/src/openenv/cli/commands/serve.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Serve OpenEnv environments locally (TO BE IMPLEMENTED).""" + +from __future__ import annotations + +from pathlib import Path +from typing import Annotated + +import typer + +from .._cli_utils import console + +app = typer.Typer(help="Serve OpenEnv environments locally") + + +@app.command() +def serve( + env_path: Annotated[ + str | None, + typer.Argument( + help="Path to the environment directory (default: current directory)" + ), + ] = None, + port: Annotated[ + int, + typer.Option("--port", "-p", help="Port to serve on"), + ] = 8000, + host: Annotated[ + str, + typer.Option("--host", help="Host to bind to"), + ] = "0.0.0.0", + reload: Annotated[ + bool, + typer.Option("--reload", help="Enable auto-reload on code changes"), + ] = False, +) -> None: + """ + Serve an OpenEnv environment locally. + + TODO: This command is currently not implemented and has been deferred for later. 
+ + Planned functionality: + - Run environment server locally without Docker + - Support multiple deployment modes (local, notebook, cluster) + - Auto-reload for development + - Integration with environment's [project.scripts] entry point + + For now, use Docker-based serving: + 1. Build the environment: openenv build + 2. Run the container: docker run -p 8000:8000 + + Or use uv directly: + uv run --project . server --port 8000 + """ + console.print("[bold yellow]⚠ This command is not yet implemented[/bold yellow]\n") + + console.print( + "The [bold cyan]openenv serve[/bold cyan] command has been deferred for later." + ) + + console.print("[bold]Alternative approaches:[/bold]\n") + + console.print("[cyan]Option 1: Docker-based serving (recommended)[/cyan]") + console.print(" 1. Build the environment:") + console.print(" [dim]$ openenv build[/dim]") + console.print(" 2. Run the Docker container:") + console.print( + f" [dim]$ docker run -p {port}:{port} openenv-:latest[/dim]\n" + ) + + console.print("[cyan]Option 2: Direct execution with uv[/cyan]") + + # Determine environment path + if env_path is None: + env_path_obj = Path.cwd() + else: + env_path_obj = Path(env_path) + + # Check for openenv.yaml + openenv_yaml = env_path_obj / "openenv.yaml" + if openenv_yaml.exists(): + console.print(" From your environment directory:") + console.print(f" [dim]$ cd {env_path_obj}[/dim]") + console.print(f" [dim]$ uv run --project . server --port {port}[/dim]\n") + else: + console.print(" From an environment directory with pyproject.toml:") + console.print(f" [dim]$ uv run --project . server --port {port}[/dim]\n") + + raise typer.Exit(0) diff --git a/src/openenv/cli/commands/skills.py b/src/openenv/cli/commands/skills.py new file mode 100644 index 0000000000000000000000000000000000000000..0bb29db72e26a104e9eb75a6309fbc9ed39538eb --- /dev/null +++ b/src/openenv/cli/commands/skills.py @@ -0,0 +1,200 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Commands to manage OpenEnv CLI skills for AI assistants.""" + +from __future__ import annotations + +import os +import shutil +from pathlib import Path +from typing import Annotated + +import typer + +DEFAULT_SKILL_ID = "openenv-cli" + +_SKILL_YAML_PREFIX = """\ +--- +name: openenv-cli +description: "OpenEnv CLI (`openenv`) for scaffolding, validating, building, and pushing OpenEnv environments." +--- + +Install: `pip install openenv-core` + +The OpenEnv CLI command `openenv` is available. +Use `openenv --help` to view available commands. 
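+
+A typical end-to-end flow (illustrative; `my_env` is a placeholder project name):
+
+```
+openenv init my_env
+cd my_env
+openenv validate
+openenv build
+openenv push
+```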
+""" + +_SKILL_TIPS = """ +## Tips + +- Start with `openenv init ` to scaffold a new environment +- Validate projects with `openenv validate` +- Build and deploy with `openenv build` and `openenv push` +- Use `openenv --help` for command-specific options +""" + +CENTRAL_LOCAL = Path(".agents/skills") +CENTRAL_GLOBAL = Path("~/.agents/skills") + +GLOBAL_TARGETS = { + "codex": Path("~/.codex/skills"), + "claude": Path("~/.claude/skills"), + "cursor": Path("~/.cursor/skills"), + "opencode": Path("~/.config/opencode/skills"), +} + +LOCAL_TARGETS = { + "codex": Path(".codex/skills"), + "claude": Path(".claude/skills"), + "cursor": Path(".cursor/skills"), + "opencode": Path(".opencode/skills"), +} + +app = typer.Typer(help="Manage OpenEnv skills for AI assistants") + + +def _build_skill_md() -> str: + """Generate SKILL.md content for the OpenEnv CLI skill.""" + from openenv import __version__ + + lines = _SKILL_YAML_PREFIX.splitlines() + lines.append("") + lines.append( + f"Generated with `openenv-core v{__version__}`. Run `openenv skills add --force` to regenerate." + ) + lines.extend(_SKILL_TIPS.splitlines()) + return "\n".join(lines).strip() + "\n" + + +def _remove_existing(path: Path, force: bool) -> None: + """Remove existing file/directory/symlink if force is True, else fail.""" + if not (path.exists() or path.is_symlink()): + return + if not force: + raise typer.Exit(code=1) + + if path.is_dir() and not path.is_symlink(): + shutil.rmtree(path) + else: + path.unlink() + + +def _install_to(skills_dir: Path, force: bool) -> Path: + """Install the OpenEnv skill in a skills directory.""" + skills_dir = skills_dir.expanduser().resolve() + skills_dir.mkdir(parents=True, exist_ok=True) + dest = skills_dir / DEFAULT_SKILL_ID + + if dest.exists() or dest.is_symlink(): + if not force: + typer.echo( + f"Skill already exists at {dest}. Re-run with --force to overwrite." + ) + raise typer.Exit(code=1) + _remove_existing(dest, force=True) + + dest.mkdir() + (dest / "SKILL.md").write_text(_build_skill_md(), encoding="utf-8") + return dest + + +def _create_symlink( + agent_skills_dir: Path, central_skill_path: Path, force: bool +) -> Path: + """Create a relative symlink from agent directory to central skill location.""" + agent_skills_dir = agent_skills_dir.expanduser().resolve() + agent_skills_dir.mkdir(parents=True, exist_ok=True) + link_path = agent_skills_dir / DEFAULT_SKILL_ID + + if link_path.exists() or link_path.is_symlink(): + if not force: + typer.echo( + f"Skill already exists at {link_path}. Re-run with --force to overwrite." + ) + raise typer.Exit(code=1) + _remove_existing(link_path, force=True) + + link_path.symlink_to(os.path.relpath(central_skill_path, agent_skills_dir)) + return link_path + + +@app.command("preview") +def skills_preview() -> None: + """Print generated SKILL.md content.""" + typer.echo(_build_skill_md()) + + +@app.command("add") +def skills_add( + claude: Annotated[ + bool, + typer.Option("--claude", help="Install for Claude."), + ] = False, + codex: Annotated[ + bool, + typer.Option("--codex", help="Install for Codex."), + ] = False, + cursor: Annotated[ + bool, + typer.Option("--cursor", help="Install for Cursor."), + ] = False, + opencode: Annotated[ + bool, + typer.Option("--opencode", help="Install for OpenCode."), + ] = False, + global_: Annotated[ + bool, + typer.Option( + "--global", + "-g", + help=( + "Install globally (user-level) instead of in the current project directory." 
+ ), + ), + ] = False, + dest: Annotated[ + Path | None, + typer.Option(help="Install into a custom destination (skills directory path)."), + ] = None, + force: Annotated[ + bool, + typer.Option("--force", help="Overwrite existing skills in the destination."), + ] = False, +) -> None: + """Install OpenEnv CLI skill for AI assistants.""" + if dest: + if claude or codex or cursor or opencode or global_: + typer.echo( + "--dest cannot be combined with --claude, --codex, --cursor, --opencode, or --global." + ) + raise typer.Exit(code=1) + skill_dest = _install_to(dest, force) + typer.echo(f"Installed '{DEFAULT_SKILL_ID}' to {skill_dest}") + return + + central_path = CENTRAL_GLOBAL if global_ else CENTRAL_LOCAL + central_skill_path = _install_to(central_path, force) + typer.echo( + f"Installed '{DEFAULT_SKILL_ID}' to central location: {central_skill_path}" + ) + + targets = GLOBAL_TARGETS if global_ else LOCAL_TARGETS + agent_targets: list[Path] = [] + + if claude: + agent_targets.append(targets["claude"]) + if codex: + agent_targets.append(targets["codex"]) + if cursor: + agent_targets.append(targets["cursor"]) + if opencode: + agent_targets.append(targets["opencode"]) + + for agent_target in agent_targets: + link_path = _create_symlink(agent_target, central_skill_path, force) + typer.echo(f"Created symlink: {link_path}") diff --git a/src/openenv/cli/commands/validate.py b/src/openenv/cli/commands/validate.py new file mode 100644 index 0000000000000000000000000000000000000000..32abcc17e11f8cf22e08d04e570383f57ccc1199 --- /dev/null +++ b/src/openenv/cli/commands/validate.py @@ -0,0 +1,198 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +OpenEnv validate command. + +This module provides the 'openenv validate' command to check if environments +are properly configured for multi-mode deployment. +""" + +import json +from pathlib import Path +from typing import Annotated + +import typer +from openenv.cli._validation import ( + build_local_validation_json_report, + format_validation_report, + get_deployment_modes, + validate_multi_mode_deployment, + validate_running_environment, +) + + +def _looks_like_url(value: str) -> bool: + """Return True when the value appears to be a URL target.""" + candidate = value.strip().lower() + return candidate.startswith("http://") or candidate.startswith("https://") + + +def validate( + target: Annotated[ + str | None, + typer.Argument( + help=( + "Path to the environment directory (default: current directory) " + "or a running OpenEnv URL (http://... or https://...)" + ), + ), + ] = None, + url: Annotated[ + str | None, + typer.Option( + "--url", + help="Validate a running OpenEnv server by base URL (e.g. http://localhost:8000)", + ), + ] = None, + json_output: Annotated[ + bool, + typer.Option( + "--json", + help="Output local validation report as JSON (runtime validation is JSON by default)", + ), + ] = False, + timeout: Annotated[ + float, + typer.Option( + "--timeout", + help="HTTP timeout in seconds for runtime validation", + min=0.1, + ), + ] = 5.0, + verbose: Annotated[ + bool, typer.Option("--verbose", "-v", help="Show detailed information") + ] = False, +) -> None: + """ + Validate local environments and running OpenEnv servers. + + Local validation checks if an environment is properly configured with: + - Required files (pyproject.toml, openenv.yaml, server/app.py, etc.) 
+ - Docker deployment support + - uv run server capability + - python -m module execution + + Runtime validation checks if a live OpenEnv server conforms to the + versioned runtime API contract and returns a criteria-based JSON report. + + Examples: + # Validate current directory (recommended) + $ cd my_env + $ openenv validate + + # Validate a running environment and return JSON criteria + $ openenv validate --url http://localhost:8000 + $ openenv validate https://my-env.hf.space + + # Validate with detailed output + $ openenv validate --verbose + + # Validate specific environment + $ openenv validate envs/echo_env + """ + runtime_target = url + if ( + runtime_target is not None + and target is not None + and not _looks_like_url(target) + ): + typer.echo( + "Error: Cannot combine a local path argument with --url runtime validation", + err=True, + ) + raise typer.Exit(1) + + if target is not None and _looks_like_url(target): + if runtime_target is not None and runtime_target != target: + typer.echo( + "Error: Conflicting runtime targets provided via argument and --url", + err=True, + ) + raise typer.Exit(1) + runtime_target = target + + if runtime_target is not None: + try: + report = validate_running_environment(runtime_target, timeout_s=timeout) + except ValueError as exc: + typer.echo(f"Error: {exc}", err=True) + raise typer.Exit(1) from exc + + typer.echo(json.dumps(report, indent=2)) + if not report.get("passed", False): + raise typer.Exit(1) + return + + # Determine environment path (default to current directory) + if target is None: + env_path_obj = Path.cwd() + else: + env_path_obj = Path(target) + + if not env_path_obj.exists(): + typer.echo(f"Error: Path does not exist: {env_path_obj}", err=True) + raise typer.Exit(1) + + if not env_path_obj.is_dir(): + typer.echo(f"Error: Path is not a directory: {env_path_obj}", err=True) + raise typer.Exit(1) + + # Check for openenv.yaml to confirm this is an environment directory + openenv_yaml = env_path_obj / "openenv.yaml" + if not openenv_yaml.exists(): + typer.echo( + f"Error: Not an OpenEnv environment directory (missing openenv.yaml): {env_path_obj}", + err=True, + ) + typer.echo( + "Hint: Run this command from the environment root directory or specify the path", + err=True, + ) + raise typer.Exit(1) + + env_name = env_path_obj.name + if env_name.endswith("_env"): + base_name = env_name[:-4] + else: + base_name = env_name + + # Run validation + is_valid, issues = validate_multi_mode_deployment(env_path_obj) + modes = get_deployment_modes(env_path_obj) + + if json_output: + report = build_local_validation_json_report( + env_name=base_name, + env_path=env_path_obj, + is_valid=is_valid, + issues=issues, + deployment_modes=modes if verbose else None, + ) + typer.echo(json.dumps(report, indent=2)) + if not is_valid: + raise typer.Exit(1) + return + + # Show validation report + report = format_validation_report(base_name, is_valid, issues) + typer.echo(report) + + # Show deployment modes if verbose + if verbose: + typer.echo("\nSupported deployment modes:") + for mode, supported in modes.items(): + status = "[YES]" if supported else "[NO]" + typer.echo(f" {status} {mode}") + + if is_valid: + typer.echo("\nUsage examples:") + typer.echo(f" cd {env_path_obj.name} && uv run server") + typer.echo(f" cd {env_path_obj.name} && openenv build") + typer.echo(f" cd {env_path_obj.name} && openenv push") + + if not is_valid: + raise typer.Exit(1) diff --git a/src/openenv/cli/templates/__init__.py b/src/openenv/cli/templates/__init__.py new file mode 
100644
index 0000000000000000000000000000000000000000..452e81a7b8584c3447c6f83fc9560f6f9d334ced
--- /dev/null
+++ b/src/openenv/cli/templates/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""OpenEnv CLI templates package."""
diff --git a/src/openenv/cli/templates/openenv_env/.dockerignore b/src/openenv/cli/templates/openenv_env/.dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..fc288e5de90f4988be5e0ef73d17b2314786406f
--- /dev/null
+++ b/src/openenv/cli/templates/openenv_env/.dockerignore
@@ -0,0 +1,13 @@
+.venv
+.git
+.gitignore
+.env
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+*.pyw
+*.pyz
+*.pyzw
+
+
diff --git a/src/openenv/cli/templates/openenv_env/README.md b/src/openenv/cli/templates/openenv_env/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3f14526a0ce173408073358a6b94d15c85c9aa97
--- /dev/null
+++ b/src/openenv/cli/templates/openenv_env/README.md
@@ -0,0 +1,255 @@
+---
+title: __ENV_TITLE_NAME__ Environment Server
+emoji: __HF_EMOJI__
+colorFrom: __HF_COLOR_FROM__
+colorTo: __HF_COLOR_TO__
+sdk: docker
+pinned: false
+app_port: 8000
+base_path: /web
+tags:
+  - openenv
+---
+
+# __ENV_TITLE_NAME__ Environment
+
+A simple test environment that echoes back messages. Perfect for testing the env APIs as well as demonstrating environment usage patterns.
+
+## Quick Start
+
+The simplest way to use the __ENV_TITLE_NAME__ environment is through the `__ENV_CLASS_NAME__Env` class:
+
+```python
+from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env
+
+# Create environment from Docker image
+__ENV_NAME__env = __ENV_CLASS_NAME__Env.from_docker_image("__ENV_NAME__-env:latest")
+
+try:
+    # Reset
+    result = __ENV_NAME__env.reset()
+    print(f"Reset: {result.observation.echoed_message}")
+
+    # Send multiple messages
+    messages = ["Hello, World!", "Testing echo", "Final message"]
+
+    for msg in messages:
+        result = __ENV_NAME__env.step(__ENV_CLASS_NAME__Action(message=msg))
+        print(f"Sent: '{msg}'")
+        print(f"  → Echoed: '{result.observation.echoed_message}'")
+        print(f"  → Length: {result.observation.message_length}")
+        print(f"  → Reward: {result.reward}")
+
+finally:
+    # Always clean up
+    __ENV_NAME__env.close()
+```
+
+That's it! The `__ENV_CLASS_NAME__Env.from_docker_image()` method handles:
+- Starting the Docker container
+- Waiting for the server to be ready
+- Connecting to the environment
+- Container cleanup when you call `close()`
+
+## Building the Docker Image
+
+Before using the environment, you need to build the Docker image:
+
+```bash
+# From project root
+docker build -t __ENV_NAME__-env:latest -f server/Dockerfile .
+```
+
+## Deploying to Hugging Face Spaces
+
+You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
+
+```bash
+# From the environment directory (where openenv.yaml is located)
+openenv push
+
+# Or specify options
+openenv push --repo-id my-org/my-env --private
+```
+
+The `openenv push` command will:
+1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
+2. Prepare a custom build for the Hugging Face Docker space (enables the web interface)
+3. Upload to Hugging Face (ensuring you're logged in)
+
+### Prerequisites
+
+- Authenticate with Hugging Face: the command will prompt for login if you are not already authenticated
+
+### Options
+
+- `DIRECTORY` (positional): Directory containing the OpenEnv environment (defaults to the current directory)
+- `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
+- `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
+- `--private`: Deploy the space as private (default: public)
+
+### Examples
+
+```bash
+# Push to your personal namespace (defaults to username/env-name from openenv.yaml)
+openenv push
+
+# Push to a specific repository
+openenv push --repo-id my-org/my-env
+
+# Push with a custom base image
+openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
+
+# Push as a private space
+openenv push --private
+
+# Combine options
+openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
+```
+
+After deployment, your space will be available at:
+`https://huggingface.co/spaces/<repo-id>`
+
+The deployed space includes:
+- **Web Interface** at `/web` - Interactive UI for exploring the environment
+- **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
+- **Health Check** at `/health` - Container health monitoring
+- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
+
+## Environment Details
+
+### Action
+**__ENV_CLASS_NAME__Action**: Contains a single field
+- `message` (str) - The message to echo back
+
+### Observation
+**__ENV_CLASS_NAME__Observation**: Contains the echo response and metadata
+- `echoed_message` (str) - The message echoed back
+- `message_length` (int) - Length of the message
+- `reward` (float) - Reward based on message length (length × 0.1)
+- `done` (bool) - Always False for the echo environment
+- `metadata` (dict) - Additional info like step count
+
+### Reward
+The reward is calculated as: `message_length × 0.1`
+- "Hi" → reward: 0.2
+- "Hello, World!" → reward: 1.3
+- Empty message → reward: 0.0
+
+## Advanced Usage
+
+### Connecting to an Existing Server
+
+If you already have a __ENV_TITLE_NAME__ environment server running, you can connect directly:
+
+```python
+from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env
+
+# Connect to an existing server
+__ENV_NAME__env = __ENV_CLASS_NAME__Env(base_url="http://localhost:8000")
+
+# Use as normal
+result = __ENV_NAME__env.reset()
+result = __ENV_NAME__env.step(__ENV_CLASS_NAME__Action(message="Hello!"))
+```
+
+Note: When connecting to an existing server, `__ENV_NAME__env.close()` will NOT stop the server.
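+
+The same pattern works against a deployed Hugging Face Space. A minimal sketch (the URL is a placeholder for your own Space deployment):
+
+```python
+from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env
+
+# Placeholder URL -- substitute your own Space's direct URL
+env = __ENV_CLASS_NAME__Env(base_url="https://username-__ENV_NAME__.hf.space")
+
+result = env.reset()
+result = env.step(__ENV_CLASS_NAME__Action(message="Hello, Space!"))
+print(result.observation.echoed_message)
+
+env.close()  # leaves the remote Space running
+```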
+ +### Using the Context Manager + +The client supports context manager usage for automatic connection management: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env + +# Connect with context manager (auto-connects and closes) +with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as env: + result = env.reset() + print(f"Reset: {result.observation.echoed_message}") + # Multiple steps with low latency + for msg in ["Hello", "World", "!"]: + result = env.step(__ENV_CLASS_NAME__Action(message=msg)) + print(f"Echoed: {result.observation.echoed_message}") +``` + +The client uses WebSocket connections for: +- **Lower latency**: No HTTP connection overhead per request +- **Persistent session**: Server maintains your environment state +- **Efficient for episodes**: Better for many sequential steps + +### Concurrent WebSocket Sessions + +The server supports multiple concurrent WebSocket connections. To enable this, +modify `server/app.py` to use factory mode: + +```python +# In server/app.py - use factory mode for concurrent sessions +app = create_app( + __ENV_CLASS_NAME__Environment, # Pass class, not instance + __ENV_CLASS_NAME__Action, + __ENV_CLASS_NAME__Observation, + max_concurrent_envs=4, # Allow 4 concurrent sessions +) +``` + +Then multiple clients can connect simultaneously: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env +from concurrent.futures import ThreadPoolExecutor + +def run_episode(client_id: int): + with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as env: + result = env.reset() + for i in range(10): + result = env.step(__ENV_CLASS_NAME__Action(message=f"Client {client_id}, step {i}")) + return client_id, result.observation.message_length + +# Run 4 episodes concurrently +with ThreadPoolExecutor(max_workers=4) as executor: + results = list(executor.map(run_episode, range(4))) +``` + +## Development & Testing + +### Direct Environment Testing + +Test the environment logic directly without starting the HTTP server: + +```bash +# From the server directory +python3 server/__ENV_NAME___environment.py +``` + +This verifies that: +- Environment resets correctly +- Step executes actions properly +- State tracking works +- Rewards are calculated correctly + +### Running Locally + +Run the server locally for development: + +```bash +uvicorn server.app:app --reload +``` + +## Project Structure + +``` +__ENV_NAME__/ +├── .dockerignore # Docker build exclusions +├── __init__.py # Module exports +├── README.md # This file +├── openenv.yaml # OpenEnv manifest +├── pyproject.toml # Project metadata and dependencies +├── uv.lock # Locked dependencies (generated) +├── client.py # __ENV_CLASS_NAME__Env client +├── models.py # Action and Observation models +└── server/ + ├── __init__.py # Server module exports + ├── __ENV_NAME___environment.py # Core environment logic + ├── app.py # FastAPI application (HTTP + WebSocket endpoints) + └── Dockerfile # Container image definition +``` diff --git a/src/openenv/cli/templates/openenv_env/__init__.py b/src/openenv/cli/templates/openenv_env/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe07a082faf989d3ae22ece407c34364b394128 --- /dev/null +++ b/src/openenv/cli/templates/openenv_env/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""__ENV_TITLE_NAME__ Environment.""" + +from .client import __ENV_CLASS_NAME__Env +from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation + +__all__ = [ + "__ENV_CLASS_NAME__Action", + "__ENV_CLASS_NAME__Observation", + "__ENV_CLASS_NAME__Env", +] diff --git a/src/openenv/cli/templates/openenv_env/client.py b/src/openenv/cli/templates/openenv_env/client.py new file mode 100644 index 0000000000000000000000000000000000000000..720090431300aad0866c8a737f84a48a3df238b3 --- /dev/null +++ b/src/openenv/cli/templates/openenv_env/client.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""__ENV_TITLE_NAME__ Environment Client.""" + +from typing import Dict + +from openenv.core import EnvClient +from openenv.core.client_types import StepResult +from openenv.core.env_server.types import State + +from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation + + +class __ENV_CLASS_NAME__Env( + EnvClient[__ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation, State] +): + """ + Client for the __ENV_TITLE_NAME__ Environment. + + This client maintains a persistent WebSocket connection to the environment server, + enabling efficient multi-step interactions with lower latency. + Each client instance has its own dedicated environment session on the server. + + Example: + >>> # Connect to a running server + >>> with __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") as client: + ... result = client.reset() + ... print(result.observation.echoed_message) + ... + ... result = client.step(__ENV_CLASS_NAME__Action(message="Hello!")) + ... print(result.observation.echoed_message) + + Example with Docker: + >>> # Automatically start container and connect + >>> client = __ENV_CLASS_NAME__Env.from_docker_image("__ENV_NAME__-env:latest") + >>> try: + ... result = client.reset() + ... result = client.step(__ENV_CLASS_NAME__Action(message="Test")) + ... finally: + ... client.close() + """ + + def _step_payload(self, action: __ENV_CLASS_NAME__Action) -> Dict: + """ + Convert __ENV_CLASS_NAME__Action to JSON payload for step message. + + Args: + action: __ENV_CLASS_NAME__Action instance + + Returns: + Dictionary representation suitable for JSON encoding + """ + return { + "message": action.message, + } + + def _parse_result(self, payload: Dict) -> StepResult[__ENV_CLASS_NAME__Observation]: + """ + Parse server response into StepResult[__ENV_CLASS_NAME__Observation]. + + Args: + payload: JSON response data from server + + Returns: + StepResult with __ENV_CLASS_NAME__Observation + """ + obs_data = payload.get("observation", {}) + observation = __ENV_CLASS_NAME__Observation( + echoed_message=obs_data.get("echoed_message", ""), + message_length=obs_data.get("message_length", 0), + done=payload.get("done", False), + reward=payload.get("reward"), + metadata=obs_data.get("metadata", {}), + ) + + return StepResult( + observation=observation, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict) -> State: + """ + Parse server response into State object. 
+ + Args: + payload: JSON response from state request + + Returns: + State object with episode_id and step_count + """ + return State( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + ) diff --git a/src/openenv/cli/templates/openenv_env/models.py b/src/openenv/cli/templates/openenv_env/models.py new file mode 100644 index 0000000000000000000000000000000000000000..5aea7f452a043602375620c48e65f0915ebf7f42 --- /dev/null +++ b/src/openenv/cli/templates/openenv_env/models.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Data models for the __ENV_TITLE_NAME__ Environment. + +The __ENV_NAME__ environment is a simple test environment that echoes back messages. +""" + +from openenv.core.env_server.types import Action, Observation +from pydantic import Field + + +class __ENV_CLASS_NAME__Action(Action): + """Action for the __ENV_TITLE_NAME__ environment - just a message to echo.""" + + message: str = Field(..., description="Message to echo back") + + +class __ENV_CLASS_NAME__Observation(Observation): + """Observation from the __ENV_TITLE_NAME__ environment - the echoed message.""" + + echoed_message: str = Field(default="", description="The echoed message") + message_length: int = Field(default=0, description="Length of the echoed message") diff --git a/src/openenv/cli/templates/openenv_env/openenv.yaml b/src/openenv/cli/templates/openenv_env/openenv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..828cc53b2b61c37bf6f860f25cbe2881825e3fd3 --- /dev/null +++ b/src/openenv/cli/templates/openenv_env/openenv.yaml @@ -0,0 +1,7 @@ +spec_version: 1 +name: __ENV_NAME__ +type: space +runtime: fastapi +app: server.app:app +port: 8000 + diff --git a/src/openenv/cli/templates/openenv_env/pyproject.toml b/src/openenv/cli/templates/openenv_env/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..a8e59fbfa3dbc8a0df7c84d479e79cef062d8e61 --- /dev/null +++ b/src/openenv/cli/templates/openenv_env/pyproject.toml @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +[build-system] +requires = ["setuptools>=45", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "openenv-__ENV_NAME__" +version = "0.1.0" +description = "__ENV_TITLE_NAME__ environment for OpenEnv" +requires-python = ">=3.10" +dependencies = [ + # Core OpenEnv runtime (provides FastAPI server + HTTP client types) + # install from github + # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git", + "openenv-core[core]>=0.2.1", + # Environment-specific dependencies + # Add all dependencies needed for your environment here + # Examples: + # "numpy>=1.19.0", + # "torch>=2.0.0", + # "gymnasium>=0.29.0", + # "openspiel>=1.0.0", + # "smolagents>=1.22.0,<2", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-cov>=4.0.0", +] + +[project.scripts] +# Server entry point - enables running via: uv run --project . 
server +# or: python -m __ENV_NAME__.server.app +server = "__ENV_NAME__.server.app:main" + +[tool.setuptools] +include-package-data = true +packages = ["__ENV_NAME__", "__ENV_NAME__.server"] +package-dir = { "__ENV_NAME__" = ".", "__ENV_NAME__.server" = "server" } \ No newline at end of file diff --git a/src/openenv/cli/templates/openenv_env/server/Dockerfile b/src/openenv/cli/templates/openenv_env/server/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..3d10ac76bf7e199e26fb77921f88d98f96120368 --- /dev/null +++ b/src/openenv/cli/templates/openenv_env/server/Dockerfile @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Multi-stage build using openenv-base +# This Dockerfile is flexible and works for both: +# - In-repo environments (with local OpenEnv sources) +# - Standalone environments (with openenv from PyPI/Git) +# The build script (openenv build) handles context detection and sets appropriate build args. + +ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest +FROM ${BASE_IMAGE} AS builder + +WORKDIR /app + +# Ensure git is available (required for installing dependencies from VCS) +RUN apt-get update && \ + apt-get install -y --no-install-recommends git && \ + rm -rf /var/lib/apt/lists/* + +# Build argument to control whether we're building standalone or in-repo +ARG BUILD_MODE=in-repo +ARG ENV_NAME=__ENV_NAME__ + +# Copy environment code (always at root of build context) +COPY . /app/env + +# For in-repo builds, openenv is already vendored in the build context +# For standalone builds, openenv will be installed via pyproject.toml +WORKDIR /app/env + +# Ensure uv is available (for local builds where base image lacks it) +RUN if ! 
command -v uv >/dev/null 2>&1; then \ + curl -LsSf https://astral.sh/uv/install.sh | sh && \ + mv /root/.local/bin/uv /usr/local/bin/uv && \ + mv /root/.local/bin/uvx /usr/local/bin/uvx; \ + fi + +# Install dependencies using uv sync +# If uv.lock exists, use it; otherwise resolve on the fly +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ -f uv.lock ]; then \ + uv sync --frozen --no-install-project --no-editable; \ + else \ + uv sync --no-install-project --no-editable; \ + fi + +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ -f uv.lock ]; then \ + uv sync --frozen --no-editable; \ + else \ + uv sync --no-editable; \ + fi + +# Final runtime stage +FROM ${BASE_IMAGE} + +WORKDIR /app + +# Copy the virtual environment from builder +COPY --from=builder /app/env/.venv /app/.venv + +# Copy the environment code +COPY --from=builder /app/env /app/env + +# Set PATH to use the virtual environment +ENV PATH="/app/.venv/bin:$PATH" + +# Set PYTHONPATH so imports work correctly +ENV PYTHONPATH="/app/env:$PYTHONPATH" + +# Health check +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + +# Run the FastAPI server +# The module path is constructed to work with the /app/env structure +CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"] diff --git a/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py b/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py new file mode 100644 index 0000000000000000000000000000000000000000..bbde58219abbb880e79662bde49c6adab96f77eb --- /dev/null +++ b/src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +__ENV_TITLE_NAME__ Environment Implementation. + +A simple test environment that echoes back messages sent to it. +Perfect for testing HTTP server infrastructure. +""" + +from uuid import uuid4 + +from openenv.core.env_server.interfaces import Environment +from openenv.core.env_server.types import State + +try: + from ..models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation +except ImportError: + from models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation + + +class __ENV_CLASS_NAME__Environment(Environment): + """ + A simple echo environment that echoes back messages. + + This environment is designed for testing the HTTP server infrastructure. + It maintains minimal state and simply echoes back whatever message it receives. + + Example: + >>> env = __ENV_CLASS_NAME__Environment() + >>> obs = env.reset() + >>> print(obs.echoed_message) # "__ENV_TITLE_NAME__ environment ready!" + >>> + >>> obs = env.step(__ENV_CLASS_NAME__Action(message="Hello")) + >>> print(obs.echoed_message) # "Hello" + >>> print(obs.message_length) # 5 + """ + + # Enable concurrent WebSocket sessions. + # Set to True if your environment isolates state between instances. + # When True, multiple WebSocket clients can connect simultaneously, each + # getting their own environment instance (when using factory mode in app.py). 
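+    # A client-side example of running concurrent sessions is sketched in this
+    # template's README under "Concurrent WebSocket Sessions".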
+ SUPPORTS_CONCURRENT_SESSIONS: bool = True + + def __init__(self): + """Initialize the __ENV_NAME__ environment.""" + self._state = State(episode_id=str(uuid4()), step_count=0) + self._reset_count = 0 + + def reset(self) -> __ENV_CLASS_NAME__Observation: + """ + Reset the environment. + + Returns: + __ENV_CLASS_NAME__Observation with a ready message + """ + self._state = State(episode_id=str(uuid4()), step_count=0) + self._reset_count += 1 + + return __ENV_CLASS_NAME__Observation( + echoed_message="__ENV_TITLE_NAME__ environment ready!", + message_length=0, + done=False, + reward=0.0, + ) + + def step(self, action: __ENV_CLASS_NAME__Action) -> __ENV_CLASS_NAME__Observation: # type: ignore[override] + """ + Execute a step in the environment by echoing the message. + + Args: + action: __ENV_CLASS_NAME__Action containing the message to echo + + Returns: + __ENV_CLASS_NAME__Observation with the echoed message and its length + """ + self._state.step_count += 1 + + message = action.message + length = len(message) + + # Simple reward: longer messages get higher rewards + reward = length * 0.1 + + return __ENV_CLASS_NAME__Observation( + echoed_message=message, + message_length=length, + done=False, + reward=reward, + metadata={"original_message": message, "step": self._state.step_count}, + ) + + @property + def state(self) -> State: + """ + Get the current environment state. + + Returns: + Current State with episode_id and step_count + """ + return self._state diff --git a/src/openenv/cli/templates/openenv_env/server/__init__.py b/src/openenv/cli/templates/openenv_env/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..191fb655582f1cc13943574814ed4b39b5d60d7c --- /dev/null +++ b/src/openenv/cli/templates/openenv_env/server/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""__ENV_TITLE_NAME__ environment server components.""" + +from .__ENV_NAME___environment import __ENV_CLASS_NAME__Environment + +__all__ = ["__ENV_CLASS_NAME__Environment"] diff --git a/src/openenv/cli/templates/openenv_env/server/app.py b/src/openenv/cli/templates/openenv_env/server/app.py new file mode 100644 index 0000000000000000000000000000000000000000..898911a2a55495426d20b438c4de009ec103ccdd --- /dev/null +++ b/src/openenv/cli/templates/openenv_env/server/app.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +FastAPI application for the __ENV_TITLE_NAME__ Environment. + +This module creates an HTTP server that exposes the __ENV_CLASS_NAME__Environment +over HTTP and WebSocket endpoints, compatible with EnvClient. 
+
+Endpoints:
+    - POST /reset: Reset the environment
+    - POST /step: Execute an action
+    - GET /state: Get current environment state
+    - GET /schema: Get action/observation schemas
+    - WS /ws: WebSocket endpoint for persistent sessions
+
+Usage:
+    # Development (with auto-reload):
+    uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
+
+    # Production:
+    uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
+
+    # Or run directly:
+    python -m server.app
+"""
+
+try:
+    from openenv.core.env_server.http_server import create_app
+except Exception as e:  # pragma: no cover
+    raise ImportError(
+        "openenv is required for the web interface. Install dependencies with 'uv sync'."
+    ) from e
+
+try:
+    from ..models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation
+    from .__ENV_NAME___environment import __ENV_CLASS_NAME__Environment
+except ModuleNotFoundError:
+    from models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation
+    from server.__ENV_NAME___environment import __ENV_CLASS_NAME__Environment
+
+
+# Create the app with web interface and README integration
+app = create_app(
+    __ENV_CLASS_NAME__Environment,
+    __ENV_CLASS_NAME__Action,
+    __ENV_CLASS_NAME__Observation,
+    env_name="__ENV_NAME__",
+    max_concurrent_envs=1,  # increase this number to allow more concurrent WebSocket sessions
+)
+
+
+def main(host: str = "0.0.0.0", port: int = 8000):
+    """
+    Entry point for direct execution via uv run or python -m.
+
+    This function enables running the server without Docker:
+        uv run --project . server
+        uv run --project . server --port 8001
+        python -m __ENV_NAME__.server.app
+
+    Args:
+        host: Host address to bind to (default: "0.0.0.0")
+        port: Port number to listen on (default: 8000)
+
+    For production deployments, consider using uvicorn directly with
+    multiple workers:
+        uvicorn __ENV_NAME__.server.app:app --workers 4
+    """
+    import uvicorn
+
+    uvicorn.run(app, host=host, port=port)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--port", type=int, default=8000)
+    args = parser.parse_args()
+    main(port=args.port)
diff --git a/src/openenv/cli/templates/openenv_env/server/requirements.txt b/src/openenv/cli/templates/openenv_env/server/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..65b1c22b3db715ed9d63b9ad06cd4afb0d9412c5
--- /dev/null
+++ b/src/openenv/cli/templates/openenv_env/server/requirements.txt
@@ -0,0 +1,6 @@
+openenv[core]>=0.2.0
+fastapi>=0.115.0
+uvicorn>=0.24.0
+
+
+
diff --git a/src/openenv/core/README.md b/src/openenv/core/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5d153f1e4f72ce6c7b4e814c78c74e0e734c462b
--- /dev/null
+++ b/src/openenv/core/README.md
@@ -0,0 +1,212 @@
+# OpenEnv: Agentic Execution Environments
+
+An end-to-end framework for creating, deploying, and using isolated execution environments for agentic RL training, built around simple Gymnasium-style APIs. OpenEnv provides a standard for interacting with agentic execution environments through three familiar calls: `step()`, `reset()`, and `state()`. RL training loops can drive any OpenEnv environment through these same APIs.
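+
+For example, a rollout loop needs nothing beyond those three calls. A minimal sketch (assuming the `MyEnvClient` and `MyAction` classes defined in the Quick Start below; the image name and the trivial policy are placeholders):
+
+```python
+import asyncio
+
+
+def policy(observation):
+    # Trivial stand-in policy; a real one would map observations to actions
+    return MyAction(text="hello")
+
+
+async def rollout() -> None:
+    # Start a container and connect (the image name is a placeholder)
+    client = await MyEnvClient.from_docker_image("my-env:latest")
+    async with client:
+        result = await client.reset()
+        for _ in range(10):  # or loop until result.done for episodic envs
+            result = await client.step(policy(result.observation))
+        print(await client.state())
+
+
+asyncio.run(rollout())
+```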
+
+In addition to making life easier for researchers and RL framework writers, we provide tools for environment creators, making it easier to build richer environments and to expose them over familiar protocols like HTTP, packaged with canonical technologies like Docker. Environment creators can use the OpenEnv framework to create environments that are isolated, secure, and easy to deploy and use.
+
+
+## Overview
+`openenv.core` provides the foundational building blocks for creating and interacting with containerized environments over HTTP. It enables you to build agent environments that can be deployed as Docker containers and accessed via a simple HTTP API.
+
+> ⚠️ **Early Development Warning** OpenEnv is currently in an experimental
+> stage. You should expect bugs, incomplete features, and APIs that may change
+> in future versions. The project welcomes bugfixes, but to make sure things are
+> well coordinated you should discuss any significant change before starting the
+> work. It's recommended that you signal your intention to contribute in the
+> issue tracker, either by filing a new issue or by claiming an existing one.
+
+
+# OpenEnv Core
+
+Core components for OpenEnv - a framework for building HTTP-based agentic environments.
+
+## Features
+
+- **EnvClient**: Async-first client for interacting with remote environments
+- **SyncEnvClient**: Synchronous wrapper via `.sync()` for sync codebases
+- **HTTPEnvServer**: FastAPI-based server wrapper for exposing environments over HTTP/WebSocket
+- **Container Providers**: Pluggable architecture for running containers (Docker, Kubernetes, etc.)
+- **Type System**: Strongly-typed Action/Observation/State interfaces
+- **Web Interface**: Optional web UI for interacting with environments
+
+## Installation
+
+```bash
+pip install "openenv[core]"
+```
+
+## Quick Start
+
+### Creating an Environment Client
+
+EnvClient is **async by default**.
Use `async with` and `await` for all operations: + +```python +import asyncio +from openenv.core import EnvClient, StepResult +from dataclasses import dataclass +from typing import Any + +@dataclass +class MyAction: + text: str + +@dataclass +class MyObservation: + response: str + +class MyEnvClient(EnvClient[MyAction, MyObservation, Any]): + def _step_payload(self, action: MyAction) -> dict: + return {"text": action.text} + + def _parse_result(self, payload: dict) -> StepResult[MyObservation]: + obs_data = payload["observation"] + return StepResult( + observation=MyObservation(**obs_data), + reward=payload.get("reward"), + done=payload.get("done", False) + ) + + def _parse_state(self, payload: dict) -> Any: + return payload + +# Async usage (recommended) +async def main(): + client = await MyEnvClient.from_docker_image("my-env:latest") + async with client: + result = await client.reset() + step_result = await client.step(MyAction(text="hello")) + +asyncio.run(main()) + +# Sync usage (via .sync() wrapper) +with MyEnvClient(base_url="http://localhost:8000").sync() as client: + result = client.reset() + step_result = client.step(MyAction(text="hello")) +``` + +### Creating an Environment Server + +```python +from openenv.core.env_server import Environment, HTTPEnvServer, create_app +from dataclasses import dataclass + +@dataclass +class MyAction: + text: str + +@dataclass +class MyObservation: + response: str + reward: float = 0.0 + done: bool = False + +class MyEnvironment(Environment): + def reset(self) -> MyObservation: + return MyObservation(response="Ready") + + def step(self, action: MyAction) -> MyObservation: + return MyObservation( + response=f"Echo: {action.text}", + reward=1.0, + done=False + ) + +# Create FastAPI app +env = MyEnvironment() +app = create_app(env, MyAction, MyObservation) + +# Run with: uvicorn module:app --host 0.0.0.0 --port 8000 +``` + +## Container Providers + +OpenEnv Core supports multiple container providers: + +### Local Docker Provider + +```python +from openenv.core.containers.runtime import LocalDockerProvider + +provider = LocalDockerProvider() +base_url = provider.start_container("my-env:latest") +provider.wait_for_ready(base_url) +# Use environment... +provider.stop_container() +``` + +### Kubernetes Provider (Coming Soon) + +```python +from openenv.core.containers.runtime import KubernetesProvider + +provider = KubernetesProvider(namespace="envs") +base_url = provider.start_container("my-env:latest") +# Use environment... +provider.stop_container() +``` + + +## API Reference + +### EnvClient + +Async base class for environment clients. Key methods: + +- `async connect()`: Establish WebSocket connection +- `async reset(**kwargs)`: Reset environment +- `async step(action)`: Execute action +- `async state()`: Get current state +- `async close()`: Close connection and cleanup +- `sync()`: Return a SyncEnvClient wrapper for synchronous usage + +Abstract methods to implement: +- `_step_payload(action)`: Convert action to JSON +- `_parse_result(payload)`: Parse response to StepResult +- `_parse_state(payload)`: Parse state response + +### SyncEnvClient + +Synchronous wrapper around EnvClient. 
Use `client.sync()` to get one: + +```python +sync_client = async_client.sync() +with sync_client: + result = sync_client.reset() + result = sync_client.step(action) +``` + +### HTTPEnvServer + +Server wrapper with these methods: + +- `register_routes(app)`: Register endpoints on FastAPI app +- `_deserialize_action(data)`: Convert JSON to Action +- `_serialize_observation(obs)`: Convert Observation to JSON + +### Environment Interface + +Base interface for environment implementations: + +- `reset()`: Reset environment and return initial observation +- `step(action)`: Execute action and return observation +- `state`: Property returning current environment state + +## License + +This project is licensed under the BSD-3-Clause License - see the LICENSE file for details. + +## Contributing + +Contributions are welcome! Please see the main OpenEnv repository for contribution guidelines. + +## Links + +- **Homepage**: https://github.com/meta-pytorch/OpenEnv +- **Documentation**: https://github.com/meta-pytorch/OpenEnv/blob/main/README.md +- **Bug Tracker**: https://github.com/meta-pytorch/OpenEnv/issues diff --git a/src/openenv/core/__init__.py b/src/openenv/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..96065d6a80463e2fe599de7728243fc2adad7135 --- /dev/null +++ b/src/openenv/core/__init__.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Core components for agentic environments.""" + +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING + +from . import env_server +from .env_server import * # noqa: F403 + +if TYPE_CHECKING: + from .env_client import EnvClient + from .generic_client import GenericAction, GenericEnvClient + from .llm_client import ( + AnthropicClient, + create_llm_client, + LLMClient, + LLMResponse, + OpenAIClient, + ToolCall, + ) + from .mcp_client import MCPClientBase, MCPToolClient + from .sync_client import SyncEnvClient + +__all__ = [ + "EnvClient", + "SyncEnvClient", + "GenericEnvClient", + "GenericAction", + "MCPClientBase", + "MCPToolClient", + "AnthropicClient", + "LLMClient", + "LLMResponse", + "OpenAIClient", + "ToolCall", + "create_llm_client", +] + env_server.__all__ # type: ignore + + +_LAZY_ATTRS = { + "EnvClient": (".env_client", "EnvClient"), + "SyncEnvClient": (".sync_client", "SyncEnvClient"), + "GenericEnvClient": (".generic_client", "GenericEnvClient"), + "GenericAction": (".generic_client", "GenericAction"), + "MCPClientBase": (".mcp_client", "MCPClientBase"), + "MCPToolClient": (".mcp_client", "MCPToolClient"), + "AnthropicClient": (".llm_client", "AnthropicClient"), + "LLMClient": (".llm_client", "LLMClient"), + "LLMResponse": (".llm_client", "LLMResponse"), + "OpenAIClient": (".llm_client", "OpenAIClient"), + "ToolCall": (".llm_client", "ToolCall"), + "create_llm_client": (".llm_client", "create_llm_client"), +} + + +def __getattr__(name: str): + if name in _LAZY_ATTRS: + module_path, attr_name = _LAZY_ATTRS[name] + module = import_module(module_path, __name__) + value = getattr(module, attr_name) + globals()[name] = value + return value + + try: + value = getattr(env_server, name) + except AttributeError as exc: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from exc + + globals()[name] = value + return value + + +def __dir__() -> list[str]: + return 
sorted(set(globals().keys()) | set(__all__))
diff --git a/src/openenv/core/client_types.py b/src/openenv/core/client_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7501c656b66a780f29bf23309aaf00fab8df432
--- /dev/null
+++ b/src/openenv/core/client_types.py
@@ -0,0 +1,23 @@
+# Type definitions for OpenEnv
+from dataclasses import dataclass
+from typing import Generic, Optional, TypeVar
+
+# Generic type for observations
+ObsT = TypeVar("ObsT")
+StateT = TypeVar("StateT")
+
+
+@dataclass
+class StepResult(Generic[ObsT]):
+    """
+    Represents the result of one environment step.
+
+    Attributes:
+        observation: The environment's observation after the action.
+        reward: Scalar reward for this step (optional).
+        done: Whether the episode is finished.
+    """
+
+    observation: ObsT
+    reward: Optional[float] = None
+    done: bool = False
diff --git a/src/openenv/core/containers/__init__.py b/src/openenv/core/containers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38e67ef3cd60bf13a26ef7c8bf23986c3eb5990e
--- /dev/null
+++ b/src/openenv/core/containers/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Container management for environment servers."""
diff --git a/src/openenv/core/containers/images/Dockerfile b/src/openenv/core/containers/images/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..97bb1cf5e2ce0e58c82496cced3e58976baead4c
--- /dev/null
+++ b/src/openenv/core/containers/images/Dockerfile
@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#
+# OpenEnv Base Image
+#
+# This is the standard base image for all OpenEnv environment servers.
+# It includes the minimal dependencies needed to run HTTP environment servers
+# and uv for fast dependency management.
+#
+# Build from repo root: docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
+# Tag: docker tag openenv-base:latest openenv-base:0.2.0
+#
+
+FROM ghcr.io/astral-sh/uv:0.5.27-python3.11-bookworm-slim AS builder
+
+# Set working directory
+WORKDIR /app
+
+# Copy core pyproject.toml and lockfile for dependency installation
+COPY pyproject.toml uv.lock* ./
+
+# Install core dependencies using uv with cache mount
+RUN --mount=type=cache,target=/root/.cache/uv \
+    uv pip install --system -r pyproject.toml
+
+# Final runtime stage
+FROM python:3.11-slim
+
+# Set metadata
+LABEL maintainer="OpenEnv Team"
+LABEL description="Base image for OpenEnv based environment servers with uv"
+LABEL version="0.2.0"
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    curl \
+    ca-certificates \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy uv from builder
+COPY --from=builder /usr/local/bin/uv /usr/local/bin/uvx /usr/local/bin/
+
+# Copy installed Python packages from builder
+COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
+
+# Copy console scripts installed by pip (uvicorn, fastapi, etc.)
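+# (Only the uvicorn and fastapi entry points are copied below; if the core
+# dependencies gain other console scripts, they must be added to this list.)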
+COPY --from=builder /usr/local/bin/uvicorn /usr/local/bin/fastapi /usr/local/bin/
+
+# Set working directory
+WORKDIR /app
+
+# Default environment variables
+ENV PYTHONPATH=/app/src
+ENV PYTHONUNBUFFERED=1
+ENV UV_SYSTEM_PYTHON=1
+
+# Default expose port (can be overridden)
+EXPOSE 8000
+
+# Note: CMD should be specified in child Dockerfiles
diff --git a/src/openenv/core/containers/images/README.md b/src/openenv/core/containers/images/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..69c387909fc487bf4bebb2a18dced2185ecf477d
--- /dev/null
+++ b/src/openenv/core/containers/images/README.md
@@ -0,0 +1,92 @@
+# OpenEnv Base Image
+
+Standard base image for all OpenEnv environment servers.
+
+## What's Included
+
+| Layer | Size | Contents |
+|-------|------|----------|
+| python:3.11-slim | 200 MB | Base Python runtime |
+| + Dependencies | 100 MB | FastAPI, uvicorn, requests |
+| **Total** | **~300 MB** | Ready for environment servers |
+
+## Image Sizes
+
+### Without Base Images (❌ Problem)
+```
+echo-env:latest      500 MB  (python + fastapi + uvicorn + app)
+coding-env:latest    520 MB  (python + fastapi + uvicorn + app + tools)
+another-env:latest   510 MB  (python + fastapi + uvicorn + app)
+---
+Total: 1.5 GB (with lots of duplication)
+```
+
+### With Base Images (✅ Solution)
+```
+openenv-base:latest  300 MB  (python + fastapi + uvicorn)
+echo-env:latest       50 MB  (app only, uses base)
+coding-env:latest     70 MB  (app + tools, uses base)
+another-env:latest    45 MB  (app only, uses base)
+---
+Total: 465 MB (base shared, minimal duplication)
+```
+
+## Building the Base Image
+
+```bash
+# From project root
+docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
+```
+
+## Usage in Environment Dockerfiles
+
+Each environment Dockerfile should start with:
+
+```dockerfile
+FROM openenv-base:latest
+
+# Copy only environment-specific files
+COPY src/openenv/core/ /app/src/openenv/core/
+COPY envs/my_env/ /app/envs/my_env/
+
+# Run the server
+CMD ["uvicorn", "envs.my_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
+```
+
+## Base Image Contents
+
+- Python 3.11-slim
+- FastAPI >= 0.104.0
+- Uvicorn >= 0.24.0
+- Requests >= 2.25.0
+- curl (for health checks)
+
+## Example: Building Echo Environment
+
+```bash
+# Step 1: Build base image (do this once)
+docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
+
+# Step 2: Build echo environment (uses base)
+docker build -t echo-env:latest -f envs/echo_env/server/Dockerfile .
+
+# Step 3: Run echo environment
+docker run -p 8000:8000 echo-env:latest
+```
+
+## Updating the Base
+
+When dependencies need updating:
+
+1. Update `src/openenv/core/containers/images/Dockerfile`
+2. Rebuild base image
+3. Rebuild all environment images (they'll use new base)
+
+```bash
+# Update base
+docker build -t openenv-base:latest -f src/openenv/core/containers/images/Dockerfile .
+
+# Rebuild environments (they automatically use new base)
+docker build -t echo-env:latest -f envs/echo_env/server/Dockerfile .
+```
diff --git a/src/openenv/core/containers/runtime/__init__.py b/src/openenv/core/containers/runtime/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd514dc2fb78007e4ee1bf1f2e9777864bc76b00
--- /dev/null
+++ b/src/openenv/core/containers/runtime/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Container runtime providers.""" + +from .providers import ( + ContainerProvider, + DockerSwarmProvider, + KubernetesProvider, + LocalDockerProvider, + RuntimeProvider, +) +from .uv_provider import UVProvider + +__all__ = [ + "ContainerProvider", + "DockerSwarmProvider", + "LocalDockerProvider", + "KubernetesProvider", + "RuntimeProvider", + "UVProvider", +] diff --git a/src/openenv/core/containers/runtime/daytona_provider.py b/src/openenv/core/containers/runtime/daytona_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..08c899fa3f16520dbe7cb8c0804e23250d97f605 --- /dev/null +++ b/src/openenv/core/containers/runtime/daytona_provider.py @@ -0,0 +1,572 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Daytona container provider for running OpenEnv environments in Daytona cloud sandboxes. + +Requires the ``daytona`` SDK: ``pip install daytona>=0.10`` +""" + +from __future__ import annotations + +import json +import os +import shlex +import time +from typing import Any, Callable, Dict, Optional + +import yaml + +from .providers import ContainerProvider + + +class DaytonaProvider(ContainerProvider): + """ + Container provider that runs environments in Daytona cloud sandboxes. + + Example: + >>> provider = DaytonaProvider(api_key="your-key") + >>> image = DaytonaProvider.image_from_dockerfile("envs/echo_env/server/Dockerfile") + >>> base_url = provider.start_container(image) + >>> provider.wait_for_ready(base_url) + >>> provider.stop_container() + """ + + _dockerfile_registry: Dict[str, Dict[str, Any]] = {} + + def __init__( + self, + *, + api_key: Optional[str] = None, + public: bool = False, + resources: Optional[Any] = None, + auto_stop_interval: int = 15, + target: Optional[str] = None, + on_snapshot_create_logs: Optional[Callable[[str], None]] = None, + cmd: Optional[str] = None, + create_timeout: float = 300, + ): + """ + Args: + api_key: Daytona API key. Falls back to ``DAYTONA_API_KEY`` env var. + public: If True, the sandbox preview is publicly accessible. + resources: Optional ``daytona.Resources`` instance for CPU/memory. + auto_stop_interval: Minutes of inactivity before auto-stop (0 disables). + target: Daytona target region (e.g. "us"). + on_snapshot_create_logs: Callback for snapshot build log lines. + cmd: Shell command to start the server inside the sandbox. + create_timeout: Seconds to wait for sandbox creation (default 300). + Heavy images (e.g. with Playwright/Chromium) may need more. + """ + from daytona import Daytona, DaytonaConfig + + config_kwargs: Dict[str, Any] = {} + resolved_key = api_key or os.environ.get("DAYTONA_API_KEY") + if resolved_key: + config_kwargs["api_key"] = resolved_key + if target: + config_kwargs["target"] = target + + self._daytona = Daytona(DaytonaConfig(**config_kwargs)) + self._public = public + self._resources = resources + self._auto_stop_interval = auto_stop_interval + self._on_snapshot_create_logs = on_snapshot_create_logs + self._cmd = cmd + self._create_timeout = create_timeout + self._sandbox: Any = None + self._preview_url: Optional[str] = None + + def _discover_server_cmd(self, sandbox: Any, port: int = 8000) -> str: + """Discover the server command from ``openenv.yaml`` inside *sandbox*. 
+
+        Finds the file, reads the ``app`` field, and constructs a command of
+        the form ``cd <env_root> && python -m uvicorn <app> --host 0.0.0.0
+        --port <port>``.
+
+        Raises:
+            ValueError: If ``openenv.yaml`` is not found or lacks an ``app`` field.
+        """
+        yaml_path = self._find_openenv_yaml(sandbox)
+        if yaml_path is None:
+            raise ValueError(
+                "Could not find openenv.yaml inside the sandbox. "
+                "Pass an explicit cmd= to DaytonaProvider or start_container()."
+            )
+
+        cat_resp = sandbox.process.exec(f"cat {shlex.quote(yaml_path)}", timeout=10)
+        content = cat_resp.result if hasattr(cat_resp, "result") else str(cat_resp)
+        app = self._parse_app_field(content)
+        if app is None:
+            raise ValueError(
+                f"openenv.yaml at {yaml_path} does not contain an 'app' field. "
+                "Pass an explicit cmd= to DaytonaProvider or start_container()."
+            )
+
+        # The directory containing openenv.yaml is the env root
+        env_root = yaml_path.rsplit("/", 1)[0]
+        return (
+            f"cd {shlex.quote(env_root)} && "
+            f"python -m uvicorn {shlex.quote(app)} --host 0.0.0.0 --port {port}"
+        )
+
+    def _find_openenv_yaml(self, sandbox: Any) -> Optional[str]:
+        """Locate ``openenv.yaml`` inside the sandbox.
+
+        Tries the modern layout path ``/app/env/openenv.yaml`` first,
+        then falls back to a ``find`` command for the old layout.
+        """
+        # Fast path: modern Dockerfile layout
+        resp = sandbox.process.exec(
+            "test -f /app/env/openenv.yaml && echo found", timeout=10
+        )
+        out = resp.result if hasattr(resp, "result") else str(resp)
+        if "found" in (out or ""):
+            return "/app/env/openenv.yaml"
+
+        # Fallback: search for it (redirect stderr so error messages
+        # like "No such file or directory" don't get mistaken for paths).
+        resp = sandbox.process.exec(
+            "find /app -maxdepth 4 -name openenv.yaml -print -quit 2>/dev/null",
+            timeout=10,
+        )
+        path = (resp.result if hasattr(resp, "result") else str(resp) or "").strip()
+        if path and path.startswith("/"):
+            return path
+
+        return None
+
+    @staticmethod
+    def _parse_app_field(yaml_content: str) -> Optional[str]:
+        """Extract the ``app`` value from raw openenv.yaml content.
+
+        Uses PyYAML to handle comments, quotes, and nested keys correctly.
+        """
+        try:
+            data = yaml.safe_load(yaml_content) or {}
+        except Exception:
+            return None
+
+        if not isinstance(data, dict):
+            return None
+
+        value = data.get("app")
+        if isinstance(value, str):
+            value = value.strip()
+            return value if value else None
+        return None
+
+    @staticmethod
+    def _parse_dockerfile_cmd(dockerfile_content: str) -> Optional[str]:
+        """Extract the server command from the last ``CMD`` in a Dockerfile.
+
+        Handles exec form (``CMD ["prog", "arg"]``) and shell form
+        (``CMD prog arg``). When a Dockerfile has multiple ``CMD``
+        instructions (e.g. multi-stage builds), the last one wins - same
+        semantics as Docker itself. Lines where ``CMD`` appears inside a
+        comment are ignored.
+
+        Returns:
+            The command as a single string, or ``None`` if no ``CMD`` found.
+        """
+        import re
+
+        last_cmd: Optional[str] = None
+        for line in dockerfile_content.splitlines():
+            stripped = line.strip()
+            # Skip comments
+            if stripped.startswith("#"):
+                continue
+            match = re.match(r"CMD\s+(.+)", stripped, flags=re.IGNORECASE)
+            if match:
+                last_cmd = match.group(1).strip()
+
+        if last_cmd is None:
+            return None
+
+        # Exec form: CMD ["executable", "param1", ...]
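+        # json.loads parses the JSON-array exec form (handling quotes and
+        # escapes); anything that is not a list of strings falls through to
+        # the shell-form return below.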
+        if last_cmd.startswith("["):
+            try:
+                parts = json.loads(last_cmd)
+                if isinstance(parts, list) and all(isinstance(p, str) for p in parts):
+                    return " ".join(parts)
+            except (json.JSONDecodeError, TypeError):
+                pass
+
+        # Shell form: CMD executable param1 ...
+        return last_cmd if last_cmd else None
+
+    @staticmethod
+    def strip_buildkit_syntax(dockerfile_content: str) -> str:
+        """Remove BuildKit ``--mount=...`` flags from ``RUN`` instructions.
+
+        Handles single-line flags, multi-line continuations, and multiple
+        ``--mount`` flags spread across continuation lines. Only leading
+        ``--mount`` flags are removed (before the actual command starts).
+
+        Daytona's ``Image.from_dockerfile`` does not support BuildKit
+        ``--mount`` syntax. This helper strips the flags so that standard
+        Dockerfiles (like the ones generated by ``openenv build``) can
+        be used directly.
+        """
+        import re
+
+        def strip_leading_mounts(text: str) -> str:
+            remaining = text
+            while True:
+                match = re.match(r"\s*--mount=\S+\s*", remaining)
+                if not match:
+                    return remaining
+                remaining = remaining[match.end() :]
+
+        lines = dockerfile_content.split("\n")
+        result: list[str] = []
+        in_run = False
+        in_mount_prefix = False
+
+        for line in lines:
+            line_out = line
+            run_start = False
+            if re.match(r"\s*RUN(\s+|$)", line, flags=re.IGNORECASE):
+                in_run = True
+                in_mount_prefix = True
+                run_start = True
+
+            if in_run and in_mount_prefix:
+                original_ends_with_slash = line_out.rstrip().endswith("\\")
+                if run_start:
+                    match = re.match(r"(\s*RUN\s+)(.*)$", line_out, flags=re.IGNORECASE)
+                    if match:
+                        run_prefix, remainder = match.group(1), match.group(2)
+                    else:
+                        run_prefix, remainder = line_out, ""
+                    new_remainder = strip_leading_mounts(remainder)
+                    line_out = run_prefix + new_remainder
+                    content_for_check = new_remainder
+                else:
+                    new_remainder = strip_leading_mounts(line_out)
+                    line_out = new_remainder
+                    content_for_check = new_remainder
+
+                if original_ends_with_slash and not line_out.rstrip().endswith("\\"):
+                    line_out = line_out.rstrip() + " \\"
+
+                if content_for_check.strip() not in ("", "\\"):
+                    in_mount_prefix = False
+
+            if in_run and not line_out.rstrip().endswith("\\"):
+                in_run = False
+                in_mount_prefix = False
+
+            result.append(line_out)
+
+        return "\n".join(result)
+
+    @classmethod
+    def image_from_dockerfile(
+        cls,
+        dockerfile_path: str,
+        context_dir: str | None = None,
+    ) -> str:
+        """Validate a Dockerfile and return a ``dockerfile:<path>`` URI for
+        :meth:`start_container`.
+
+        Eagerly validates the Dockerfile (existence, COPY sources,
+        BuildKit stripping) and stores the processed content in an
+        internal registry. The actual ``daytona.Image`` is created
+        later inside ``start_container``.
+
+        Args:
+            dockerfile_path: Path to the Dockerfile on disk.
+            context_dir: Build context directory. Defaults to the
+                Dockerfile's grandparent directory, matching the
+                ``openenv init`` convention where Dockerfiles live in
+                ``<env>/server/Dockerfile`` and the build context is
+                ``<env>/``. Pass explicitly for non-standard layouts
+                (e.g. ``context_dir="."`` for repo-root contexts).
+
+        Returns:
+            A ``"dockerfile:<path>"`` string to pass to
+            ``start_container``.
+
+        Raises:
+            FileNotFoundError: If *dockerfile_path* does not exist.
+            ValueError: If *context_dir* is given but does not exist,
+                or if COPY sources in the Dockerfile cannot be found
+                under the resolved context directory.
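+
+        Example (paths are illustrative; mirrors the class-level example):
+            >>> uri = DaytonaProvider.image_from_dockerfile(
+            ...     "envs/echo_env/server/Dockerfile"
+            ... )
+            >>> base_url = DaytonaProvider().start_container(uri)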
+        """
+        import pathlib
+        import re
+
+        src = pathlib.Path(dockerfile_path).resolve()
+        if not src.is_file():
+            raise FileNotFoundError(f"Dockerfile not found: {dockerfile_path}")
+
+        if context_dir is not None:
+            ctx = pathlib.Path(context_dir)
+            if not ctx.is_dir():
+                raise ValueError(f"context_dir does not exist: {context_dir}")
+        else:
+            # Default: grandparent of the Dockerfile, matching the
+            # openenv init layout (<env>/server/Dockerfile -> <env>/).
+            ctx = src.parent.parent
+
+        content = src.read_text()
+        stripped = cls.strip_buildkit_syntax(content)
+
+        # Validate that COPY sources exist under the context directory.
+        # This catches mismatches early (e.g. a Dockerfile expecting repo
+        # root as context when we defaulted to the env directory).
+        for line in stripped.splitlines():
+            m = re.match(r"^\s*COPY\s+(?!--from=)(\S+)\s+", line, re.IGNORECASE)
+            if not m:
+                continue
+            copy_src = m.group(1)
+            if copy_src.startswith("/"):
+                continue
+            resolved = ctx / copy_src
+            if not resolved.exists() and not any(ctx.glob(copy_src)):
+                raise ValueError(
+                    f"Dockerfile COPY source '{copy_src}' not found "
+                    f"under context_dir '{ctx}'. This Dockerfile may "
+                    f"expect a different build context (e.g. the repo "
+                    f"root). Pass context_dir explicitly."
+                )
+
+        # Parse CMD from the original Dockerfile so start_container can
+        # use it as a fallback when openenv.yaml is unavailable.
+        parsed_cmd = cls._parse_dockerfile_cmd(content)
+
+        cls._dockerfile_registry[str(src)] = {
+            "stripped_content": stripped,
+            "context_dir": str(ctx),
+            "server_cmd": parsed_cmd,
+        }
+
+        return f"dockerfile:{src}"
+
+    def start_container(
+        self,
+        image: str,
+        port: Optional[int] = None,
+        env_vars: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Create a Daytona sandbox from a Docker image or snapshot.
+
+        Daytona does not execute the image's CMD (known bug — ENTRYPOINT
+        runs, CMD does not). The server command is resolved in order:
+
+        1. Explicit ``cmd`` passed to the constructor.
+        2. ``cmd`` key in ``**kwargs`` (popped before forwarding).
+        3. Auto-discovered from ``openenv.yaml`` inside the sandbox.
+        4. ``CMD`` parsed from the Dockerfile (when *image* came from
+           ``image_from_dockerfile``).
+
+        Args:
+            image: Docker image name (e.g. ``"echo-env:latest"``),
+                ``"snapshot:<name>"`` to create from a pre-built snapshot,
+                or ``"dockerfile:<path>"`` returned by
+                :meth:`image_from_dockerfile`.
+            port: Must be ``None`` or ``8000``. Daytona exposes port 8000
+                via its preview proxy; other ports raise ``ValueError``.
+            env_vars: Environment variables forwarded to the sandbox.
+            **kwargs: ``cmd`` (str) to override the server command;
+                remaining kwargs passed through to ``Daytona.create()``.
+
+        Returns:
+            HTTPS preview URL for the sandbox (base_url).
+        """
+        if port is not None and port != 8000:
+            raise ValueError(
+                f"DaytonaProvider only supports port 8000 (got {port}). "
+                "The Daytona preview proxy routes to port 8000 inside the sandbox."
+            )
+
+        # Resolve the server command (may be None; discovery happens after
+        # sandbox creation when we can inspect the filesystem).
+        cmd = kwargs.pop("cmd", None) or self._cmd
+
+        # CMD parsed from Dockerfile (populated for "dockerfile:<path>" images).
+ parsed_cmd: Optional[str] = None + + # Build creation params + create_kwargs: Dict[str, Any] = {} + if env_vars: + create_kwargs["env_vars"] = env_vars + if self._public: + create_kwargs["public"] = True + if self._auto_stop_interval != 15: + create_kwargs["auto_stop_interval"] = self._auto_stop_interval + + if image.startswith("snapshot:"): + from daytona import CreateSandboxFromSnapshotParams + + snapshot_name = image[len("snapshot:") :] + params = CreateSandboxFromSnapshotParams( + snapshot=snapshot_name, **create_kwargs + ) + elif image.startswith("dockerfile:"): + from daytona import CreateSandboxFromImageParams, Image + + dockerfile_path = image[len("dockerfile:") :] + meta = self._dockerfile_registry.get(dockerfile_path) + if meta is None: + raise ValueError( + f"No registered Dockerfile metadata for {dockerfile_path}. " + "Call DaytonaProvider.image_from_dockerfile() first." + ) + + parsed_cmd = meta.get("server_cmd") + + # Build the daytona Image from the pre-stripped content. + import pathlib + import uuid + + ctx = pathlib.Path(meta["context_dir"]) + tmp_name = f".daytona-{uuid.uuid4().hex[:8]}.dockerfile" + tmp_path = ctx / tmp_name + try: + tmp_path.write_text(meta["stripped_content"]) + daytona_image = Image.from_dockerfile(str(tmp_path)) + finally: + tmp_path.unlink(missing_ok=True) + + img_kwargs: Dict[str, Any] = { + "image": daytona_image, + **create_kwargs, + } + if self._resources is not None: + img_kwargs["resources"] = self._resources + params = CreateSandboxFromImageParams(**img_kwargs) + else: + from daytona import CreateSandboxFromImageParams + + img_kwargs = {"image": image, **create_kwargs} + if self._resources is not None: + img_kwargs["resources"] = self._resources + params = CreateSandboxFromImageParams(**img_kwargs) + + # Create sandbox + extra: Dict[str, Any] = dict(kwargs) + if self._on_snapshot_create_logs is not None: + extra["on_snapshot_create_logs"] = self._on_snapshot_create_logs + + self._sandbox = self._daytona.create( + params, timeout=self._create_timeout, **extra + ) + + try: + # Discover server command from openenv.yaml if not explicitly set. + if cmd is None: + try: + cmd = self._discover_server_cmd(self._sandbox) + except ValueError: + # Fall back to CMD parsed from Dockerfile (if available). + if parsed_cmd: + cmd = parsed_cmd + else: + raise + + # Wrap in bash -c so compound commands (cd ... && uvicorn ...) + # are handled correctly by nohup. Write PID so we can check + # if the process crashed later in wait_for_ready(). + escaped_cmd = shlex.quote(cmd) + self._sandbox.process.exec( + f"nohup bash -c {escaped_cmd} > /tmp/openenv-server.log 2>&1 &" + " echo $! > /tmp/openenv-server.pid", + timeout=10, + ) + + # Get a signed preview URL for port 8000. The token is + # embedded in the URL itself so no extra headers are needed. + signed = self._sandbox.create_signed_preview_url( + 8000, expires_in_seconds=86400 + ) + self._preview_url = signed.url + except Exception: + self.stop_container() + raise + + return self._preview_url + + def refresh_preview_url(self) -> str: + """Get a fresh signed preview URL (valid for 24h). + + Daytona signed URLs expire after at most 24 hours. Call this to + get a new one for long-running sessions. The returned URL points + to the same sandbox — clients will need to reconnect using it. 
+ """ + if self._sandbox is None: + raise RuntimeError("No active sandbox to refresh URL for.") + signed = self._sandbox.create_signed_preview_url(8000, expires_in_seconds=86400) + self._preview_url = signed.url + return self._preview_url + + def stop_container(self) -> None: + """Delete the Daytona sandbox.""" + if self._sandbox is None: + return + + try: + self._daytona.delete(self._sandbox) + finally: + self._sandbox = None + self._preview_url = None + + def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None: + """ + Poll the /health endpoint until the sandbox is ready. + + Uses a longer default timeout (120s) than Docker providers because + Daytona sandboxes may have cold-start latency. + + Args: + base_url: Preview URL returned by ``start_container()``. + timeout_s: Maximum seconds to wait. + + Raises: + TimeoutError: If the sandbox doesn't become ready in time. + RuntimeError: If the server process died (detected via PID check). + """ + import requests + + health_url = f"{base_url}/health" + + deadline = time.time() + timeout_s + while time.time() < deadline: + try: + response = requests.get(health_url, timeout=5.0) + if response.status_code == 200: + return + except requests.RequestException: + pass + + # Early exit: if the server process died, raise immediately + # instead of waiting for the full health-check timeout. + if self._sandbox is not None: + resp = self._sandbox.process.exec( + "kill -0 $(cat /tmp/openenv-server.pid) 2>/dev/null" + " && echo RUNNING || echo DEAD", + timeout=10, + ) + out = resp.result if hasattr(resp, "result") else str(resp) + if "DEAD" in (out or ""): + log_resp = self._sandbox.process.exec( + "cat /tmp/openenv-server.log 2>/dev/null", timeout=10 + ) + log = ( + log_resp.result + if hasattr(log_resp, "result") + else str(log_resp) + ) + raise RuntimeError(f"Server process died.\nLog:\n{log}") + + time.sleep(1.0) + + raise TimeoutError( + f"Daytona sandbox at {base_url} did not become ready within {timeout_s}s" + ) diff --git a/src/openenv/core/containers/runtime/providers.py b/src/openenv/core/containers/runtime/providers.py new file mode 100644 index 0000000000000000000000000000000000000000..54232a2495746f89cc81590ca87d03e6e48e3d2b --- /dev/null +++ b/src/openenv/core/containers/runtime/providers.py @@ -0,0 +1,669 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Container provider abstractions for running environment servers. + +This module provides a pluggable architecture for different container providers +(local Docker, Kubernetes, cloud providers, etc.) to be used with EnvClient. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Sequence + + +class ContainerProvider(ABC): + """ + Abstract base class for container providers. + + Providers implement this interface to support different container platforms: + - LocalDockerProvider: Runs containers on local Docker daemon + - KubernetesProvider: Runs containers in Kubernetes cluster + - FargateProvider: Runs containers on AWS Fargate + - CloudRunProvider: Runs containers on Google Cloud Run + + The provider manages a single container lifecycle and provides the base URL + for connecting to it. 
+
+    Example:
+        >>> provider = LocalDockerProvider()
+        >>> base_url = provider.start_container("echo-env:latest")
+        >>> print(base_url)  # http://localhost:8000
+        >>> # Use the environment via base_url
+        >>> provider.stop_container()
+    """
+
+    @abstractmethod
+    def start_container(
+        self,
+        image: str,
+        port: Optional[int] = None,
+        env_vars: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Start a container from the specified image.
+
+        Args:
+            image: Container image name (e.g., "echo-env:latest")
+            port: Port to expose (if None, provider chooses)
+            env_vars: Environment variables to pass to container
+            **kwargs: Provider-specific options
+
+        Returns:
+            Base URL to connect to the container (e.g., "http://localhost:8000")
+
+        Raises:
+            RuntimeError: If container fails to start
+        """
+        pass
+
+    @abstractmethod
+    def stop_container(self) -> None:
+        """
+        Stop and remove the running container.
+
+        This cleans up the container that was started by start_container().
+        """
+        pass
+
+    @abstractmethod
+    def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None:
+        """
+        Wait for the container to be ready to accept requests.
+
+        This typically polls the /health endpoint until it returns 200.
+
+        Args:
+            base_url: Base URL of the container
+            timeout_s: Maximum time to wait
+
+        Raises:
+            TimeoutError: If container doesn't become ready in time
+        """
+        pass
+
+
+class LocalDockerProvider(ContainerProvider):
+    """
+    Container provider for local Docker daemon.
+
+    This provider runs containers on the local machine using Docker.
+    Useful for development and testing.
+
+    Example:
+        >>> provider = LocalDockerProvider()
+        >>> base_url = provider.start_container("echo-env:latest")
+        >>> # Container running on http://localhost:<port>
+        >>> provider.stop_container()
+    """
+
+    def __init__(self):
+        """Initialize the local Docker provider."""
+        self._container_id: Optional[str] = None
+        self._container_name: Optional[str] = None
+
+        # Check if Docker is available
+        import subprocess
+
+        try:
+            subprocess.run(
+                ["docker", "version"],
+                check=True,
+                capture_output=True,
+                timeout=5,
+            )
+        except (
+            subprocess.CalledProcessError,
+            FileNotFoundError,
+            subprocess.TimeoutExpired,
+        ):
+            raise RuntimeError(
+                "Docker is not available. Please install Docker Desktop or Docker Engine."
+            )
+
+    def start_container(
+        self,
+        image: str,
+        port: Optional[int] = None,
+        env_vars: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Start a Docker container locally.
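+
+        The container's port 8000 is published on *port*; when *port* is
+        None, a free local port is chosen automatically.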
+ + Args: + image: Docker image name + port: Port to expose (if None, finds available port) + env_vars: Environment variables for the container + **kwargs: Additional Docker run options + + Returns: + Base URL to connect to the container + """ + import subprocess + import time + + # Find available port if not specified + if port is None: + port = self._find_available_port() + + # Generate container name + self._container_name = self._generate_container_name(image) + + # Build docker run command + cmd = [ + "docker", + "run", + "-d", # Detached + "--name", + self._container_name, + "-p", + f"{port}:8000", # Map port + ] + + # Add environment variables + if env_vars: + for key, value in env_vars.items(): + cmd.extend(["-e", f"{key}={value}"]) + + # Add image + cmd.append(image) + + # Run container + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + self._container_id = result.stdout.strip() + except subprocess.CalledProcessError as e: + error_msg = f"Failed to start Docker container.\nCommand: {' '.join(cmd)}\nExit code: {e.returncode}\nStderr: {e.stderr}\nStdout: {e.stdout}" + raise RuntimeError(error_msg) from e + + # Wait a moment for container to start + time.sleep(1) + + base_url = f"http://localhost:{port}" + return base_url + + def stop_container(self) -> None: + """ + Stop and remove the Docker container. + """ + if self._container_id is None: + return + + import subprocess + + try: + # Stop container + subprocess.run( + ["docker", "stop", self._container_id], + capture_output=True, + check=True, + timeout=10, + ) + + # Remove container + subprocess.run( + ["docker", "rm", self._container_id], + capture_output=True, + check=True, + timeout=10, + ) + except subprocess.CalledProcessError: + # Container might already be stopped/removed + pass + finally: + self._container_id = None + self._container_name = None + + def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None: + """ + Wait for container to be ready by polling /health endpoint. + + Args: + base_url: Base URL of the container + timeout_s: Maximum time to wait + + Raises: + TimeoutError: If container doesn't become ready + """ + import time + + import requests + + start_time = time.time() + health_url = f"{base_url}/health" + + # Bypass proxy for localhost to avoid proxy issues + proxies = {"http": None, "https": None} + + while time.time() - start_time < timeout_s: + try: + response = requests.get(health_url, timeout=2.0, proxies=proxies) + if response.status_code == 200: + return + except requests.RequestException: + pass + + time.sleep(0.5) + + raise TimeoutError( + f"Container at {base_url} did not become ready within {timeout_s}s" + ) + + def _find_available_port(self) -> int: + """ + Find an available port on localhost. + + Returns: + An available port number + """ + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + def _generate_container_name(self, image: str) -> str: + """ + Generate a unique container name based on image name and timestamp. + + Args: + image: Docker image name + + Returns: + A unique container name + """ + import time + + clean_image = image.split("/")[-1].split(":")[0] + timestamp = int(time.time() * 1000) + return f"{clean_image}-{timestamp}" + + +class DockerSwarmProvider(ContainerProvider): + """ + Container provider that uses Docker Swarm services for local concurrency. 
+ + This provider creates a replicated Swarm service backed by the local Docker + engine. The built-in load-balancer fans requests across the replicas, + allowing multiple container instances to run concurrently on the developer + workstation (mirroring the workflow described in the Docker stack docs). + """ + + def __init__( + self, + *, + auto_init_swarm: bool = True, + overlay_network: Optional[str] = None, + ): + """ + Args: + auto_init_swarm: Whether to call ``docker swarm init`` when Swarm + is not active. Otherwise, user must manually initialize Swarm. + overlay_network: Optional overlay network name for the service. + When provided, the network is created with + ``docker network create --driver overlay --attachable`` if it + does not already exist. + """ + self._service_name: Optional[str] = None + self._service_id: Optional[str] = None + self._published_port: Optional[int] = None + self._overlay_network = overlay_network + self._auto_init_swarm = auto_init_swarm + + self._ensure_docker_available() + self._ensure_swarm_initialized() + if self._overlay_network: + self._ensure_overlay_network(self._overlay_network) + + def start_container( + self, + image: str, + port: Optional[int] = None, + env_vars: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> str: + """ + Start (or scale) a Swarm service for the given image. + + Supported kwargs: + replicas (int): Number of container replicas (default: 2). + cpu_limit (float | str): CPU limit passed to ``--limit-cpu``. + memory_limit (str): Memory limit passed to ``--limit-memory``. + constraints (Sequence[str]): Placement constraints. + labels (Dict[str, str]): Service labels. + command (Sequence[str] | str): Override container command. + """ + import shlex + import subprocess + import time + + allowed_kwargs = { + "replicas", + "cpu_limit", + "memory_limit", + "constraints", + "labels", + "command", + } + unknown = set(kwargs) - allowed_kwargs + if unknown: + raise ValueError(f"Unsupported kwargs for DockerSwarmProvider: {unknown}") + + replicas = int(kwargs.get("replicas", 2)) + cpu_limit = kwargs.get("cpu_limit") + memory_limit = kwargs.get("memory_limit") + constraints: Optional[Sequence[str]] = kwargs.get("constraints") + labels: Optional[Dict[str, str]] = kwargs.get("labels") + command_override = kwargs.get("command") + + if port is None: + port = self._find_available_port() + + self._service_name = self._generate_service_name(image) + self._published_port = port + + cmd = [ + "docker", + "service", + "create", + "--detach", + "--name", + self._service_name, + "--replicas", + str(max(1, replicas)), + "--publish", + f"{port}:8000", + ] + + if self._overlay_network: + cmd.extend(["--network", self._overlay_network]) + + if env_vars: + for key, value in env_vars.items(): + cmd.extend(["--env", f"{key}={value}"]) + + if cpu_limit is not None: + cmd.extend(["--limit-cpu", str(cpu_limit)]) + + if memory_limit is not None: + cmd.extend(["--limit-memory", str(memory_limit)]) + + if constraints: + for constraint in constraints: + cmd.extend(["--constraint", constraint]) + + if labels: + for key, value in labels.items(): + cmd.extend(["--label", f"{key}={value}"]) + + cmd.append(image) + + if command_override: + if isinstance(command_override, str): + cmd.extend(shlex.split(command_override)) + else: + cmd.extend(command_override) + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True, + ) + self._service_id = result.stdout.strip() + except subprocess.CalledProcessError as e: + error_msg = ( + "Failed 
to start Docker Swarm service.\n" + f"Command: {' '.join(cmd)}\n" + f"Exit code: {e.returncode}\n" + f"Stdout: {e.stdout}\n" + f"Stderr: {e.stderr}" + ) + raise RuntimeError(error_msg) from e + + # Give Swarm a brief moment to schedule the tasks. + time.sleep(1.0) + + return f"http://localhost:{port}" + + def stop_container(self) -> None: + """ + Remove the Swarm service (and keep the Swarm manager running). + """ + if not self._service_name: + return + + import subprocess + + try: + subprocess.run( + ["docker", "service", "rm", self._service_name], + capture_output=True, + check=True, + timeout=10, + ) + except subprocess.CalledProcessError: + # Service may already be gone; ignore. + pass + finally: + self._service_name = None + self._service_id = None + self._published_port = None + + def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None: + """ + Wait for at least one replica to become healthy by polling /health. + + Note: With Swarm's load balancer, requests round-robin across replicas, + so this only verifies that at least one replica is responding. Some + replicas may still be starting when this returns. + """ + import time + + import requests + + deadline = time.time() + timeout_s + health_url = f"{base_url}/health" + + # Bypass proxy for localhost to avoid proxy issues + proxies = {"http": None, "https": None} + + while time.time() < deadline: + try: + response = requests.get(health_url, timeout=2.0, proxies=proxies) + if response.status_code == 200: + return + except requests.RequestException: + pass + + time.sleep(0.5) + + raise TimeoutError( + f"Swarm service at {base_url} did not become ready within {timeout_s}s" + ) + + def _ensure_docker_available(self) -> None: + import subprocess + + try: + subprocess.run( + ["docker", "version"], + check=True, + capture_output=True, + timeout=5, + ) + except ( + subprocess.CalledProcessError, + FileNotFoundError, + subprocess.TimeoutExpired, + ) as exc: + raise RuntimeError( + "Docker is not available. Please install Docker Desktop or Docker Engine." + ) from exc + + def _ensure_swarm_initialized(self) -> None: + import subprocess + + try: + result = subprocess.run( + ["docker", "info", "--format", "{{.Swarm.LocalNodeState}}"], + capture_output=True, + text=True, + check=True, + timeout=5, + ) + state = result.stdout.strip().lower() + if state == "active": + return + except subprocess.CalledProcessError: + state = "unknown" + + if not self._auto_init_swarm: + raise RuntimeError( + f"Docker Swarm is not active (state={state}). Enable Swarm manually or pass auto_init_swarm=True." 
+ ) + + try: + subprocess.run( + ["docker", "swarm", "init"], + check=True, + capture_output=True, + timeout=10, + ) + except subprocess.CalledProcessError as e: + raise RuntimeError("Failed to initialize Docker Swarm") from e + + def _ensure_overlay_network(self, network: str) -> None: + import subprocess + + inspect = subprocess.run( + ["docker", "network", "inspect", network], + capture_output=True, + text=True, + check=False, + ) + if inspect.returncode == 0: + return + + try: + subprocess.run( + [ + "docker", + "network", + "create", + "--driver", + "overlay", + "--attachable", + network, + ], + check=True, + capture_output=True, + timeout=10, + ) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Failed to create overlay network '{network}'") from e + + def _find_available_port(self) -> int: + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + def _generate_service_name(self, image: str) -> str: + import time + + clean_image = image.split("/")[-1].split(":")[0] + timestamp = int(time.time() * 1000) + return f"{clean_image}-swarm-{timestamp}" + + +class KubernetesProvider(ContainerProvider): + """ + Container provider for Kubernetes clusters. + + This provider creates pods in a Kubernetes cluster and exposes them + via services or port-forwarding. + + Example: + >>> provider = KubernetesProvider(namespace="envtorch-dev") + >>> base_url = provider.start_container("echo-env:latest") + >>> # Pod running in k8s, accessible via service or port-forward + >>> provider.stop_container() + """ + + pass + + +class RuntimeProvider(ABC): + """ + Abstract base class for runtime providers that are not container providers. + Providers implement this interface to support different runtime platforms: + - UVProvider: Runs environments via `uv run` + + The provider manages a single runtime lifecycle and provides the base URL + for connecting to it. + + Example: + >>> provider = UVProvider(project_path="/path/to/env") + >>> base_url = provider.start() + >>> print(base_url) # http://localhost:8000 + >>> provider.stop() + """ + + @abstractmethod + def start( + self, + port: Optional[int] = None, + env_vars: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> str: + """ + Start a runtime from the specified image. + + Args: + image: Runtime image name + port: Port to expose (if None, provider chooses) + env_vars: Environment variables for the runtime + **kwargs: Additional runtime options + """ + + @abstractmethod + def stop(self) -> None: + """ + Stop the runtime. + """ + pass + + @abstractmethod + def wait_for_ready(self, timeout_s: float = 30.0) -> None: + """ + Wait for the runtime to be ready to accept requests. + """ + pass + + def __enter__(self) -> "RuntimeProvider": + """ + Enter the runtime provider. + """ + self.start() + return self + + def __exit__(self, exc_type, exc, tb) -> None: + """ + Exit the runtime provider. 
+ """ + self.stop() + return False diff --git a/src/openenv/core/containers/runtime/uv_provider.py b/src/openenv/core/containers/runtime/uv_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..3ddc89b9bdccbd0d18604c3de5f49fd3cbc74612 --- /dev/null +++ b/src/openenv/core/containers/runtime/uv_provider.py @@ -0,0 +1,224 @@ +"""Providers for launching ASGI applications via ``uv run``.""" + +from __future__ import annotations + +import os +import socket +import subprocess +import time +from typing import Dict, Optional + +import requests + +from .providers import RuntimeProvider + + +def _check_uv_installed() -> None: + try: + subprocess.check_output(["uv", "--version"]) + except FileNotFoundError as exc: + raise RuntimeError( + "`uv` executable not found. Install uv from https://docs.astral.sh and ensure it is on PATH." + ) from exc + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", 0)) + sock.listen(1) + return sock.getsockname()[1] + + +def _create_uv_command( + *, + host: str, + port: int, + reload: bool, + workers: int, + app: str, + project_path: str, +) -> list[str]: + command: list[str] = ["uv", "run", "--isolated", "--project", project_path] + + command.append("--") + command.extend( + [ + "uvicorn", + app, + "--host", + host, + "--port", + str(port), + "--workers", + str(workers), + ] + ) + + if reload: + command.append("--reload") + + return command + + +def _poll_health(health_url: str, timeout_s: float) -> None: + """Poll a health endpoint until it returns HTTP 200 or times out.""" + + deadline = time.time() + timeout_s + while time.time() < deadline: + try: + timeout = max(0.0001, min(deadline - time.time(), 2.0)) + response = requests.get(health_url, timeout=timeout) + if response.status_code == 200: + return + except requests.RequestException: + continue + + time.sleep(0.5) + + raise TimeoutError(f"Server did not become ready within {timeout_s:.1f} seconds") + + +class UVProvider(RuntimeProvider): + """ + RuntimeProvider implementation backed by ``uv run``. + + Args: + project_path: Local path to a uv project (passed to ``uv run --project``) + app: ASGI application path for uvicorn (defaults to ``server.app:app``) + host: Host interface to bind to (defaults to ``0.0.0.0``) + reload: Whether to enable uvicorn's reload mode + env_vars: Environment variables to pass through to the spawned process + context_timeout_s: How long to wait for the environment to become ready + + Example: + >>> provider = UVProvider(project_path="/path/to/env") + >>> base_url = provider.start() + >>> print(base_url) # http://localhost:8000 + >>> # Use the environment via base_url + >>> provider.stop() + """ + + def __init__( + self, + *, + project_path: str, + app: str = "server.app:app", + host: str = "0.0.0.0", + reload: bool = False, + env_vars: Optional[Dict[str, str]] = None, + context_timeout_s: float = 60.0, + ): + """Initialize the UVProvider.""" + self.project_path = os.path.abspath(project_path) + self.app = app + self.host = host + self.reload = reload + self.env_vars = env_vars + self.context_timeout_s = context_timeout_s + _check_uv_installed() + self._process = None + self._base_url = None + + def start( + self, + port: Optional[int] = None, + env_vars: Optional[Dict[str, str]] = None, + workers: int = 1, + **_: Dict[str, str], + ) -> str: + """ + Start the environment via `uv run`. 
+
+        Args:
+            port: The port to bind the environment to
+            env_vars: Environment variables to pass to the environment
+            workers: The number of workers to use
+
+        Returns:
+            The base URL of the environment
+
+        Raises:
+            RuntimeError: If the environment is already running
+        """
+        if self._process is not None and self._process.poll() is None:
+            raise RuntimeError("UVProvider is already running")
+
+        bind_port = port or _find_free_port()
+
+        command = _create_uv_command(
+            host=self.host,
+            port=bind_port,
+            reload=self.reload,
+            workers=workers,
+            app=self.app,
+            project_path=self.project_path,
+        )
+
+        env = os.environ.copy()
+
+        if self.env_vars:
+            env.update(self.env_vars)
+        if env_vars:
+            env.update(env_vars)
+
+        try:
+            self._process = subprocess.Popen(command, env=env)
+        except OSError as exc:
+            raise RuntimeError(f"Failed to launch `uv run`: {exc}") from exc
+
+        client_host = "127.0.0.1" if self.host in {"0.0.0.0", "::"} else self.host
+        self._base_url = f"http://{client_host}:{bind_port}"
+        return self._base_url
+
+    def wait_for_ready(self, timeout_s: float = 60.0) -> None:
+        """
+        Wait for the environment to become ready.
+
+        Args:
+            timeout_s: The timeout to wait for the environment to become ready
+
+        Raises:
+            RuntimeError: If the uv process exited prematurely
+            TimeoutError: If the environment does not become ready within the timeout
+        """
+        if self._process and self._process.poll() is not None:
+            code = self._process.returncode
+            raise RuntimeError(f"uv process exited prematurely with code {code}")
+
+        _poll_health(f"{self._base_url}/health", timeout_s=timeout_s)
+
+    def stop(self) -> None:
+        """
+        Stop the environment.
+
+        This is a no-op if the environment is not running.
+        """
+        if self._process is None:
+            return
+
+        if self._process.poll() is None:
+            self._process.terminate()
+            try:
+                self._process.wait(timeout=10.0)
+            except subprocess.TimeoutExpired:
+                self._process.kill()
+                self._process.wait(timeout=5.0)
+
+        self._process = None
+        self._base_url = None
+
+    @property
+    def base_url(self) -> str:
+        """
+        The base URL of the environment.
+
+        Returns:
+            The base URL of the environment
+
+        Raises:
+            RuntimeError: If the environment is not running
+        """
+        if self._base_url is None:
+            raise RuntimeError("UVProvider has not been started")
+        return self._base_url
diff --git a/src/openenv/core/containers/test_local_docker_provider.py b/src/openenv/core/containers/test_local_docker_provider.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac520a4b68afa699894dd68c0508b1e41936704c
--- /dev/null
+++ b/src/openenv/core/containers/test_local_docker_provider.py
@@ -0,0 +1,260 @@
+#!/usr/bin/env python3
+"""
+End-to-end test for LocalDockerProvider.
+
+This script tests the complete flow:
+1. Start a container using LocalDockerProvider
+2. Wait for it to be ready
+3. Make HTTP requests to test the environment
+4. Clean up the container
+"""
+
+import sys
+from pathlib import Path
+
+# Add src/ to path so the openenv package is importable
+sys.path.insert(0, str(Path(__file__).parents[3]))
+
+import requests
+from openenv.core.containers.runtime import LocalDockerProvider
+
+
+# TODO: Remove this test or make it a functional test since this will be covered by the e2e test for the echo env
+def test_local_docker_provider():
+    """Test LocalDockerProvider end-to-end."""
+    print("=" * 60)
+    print("LocalDockerProvider End-to-End Test")
+    print("=" * 60)
+    print()
+
+    provider = None
+
+    try:
+        # Step 1: Create provider
+        print("Step 1: Creating LocalDockerProvider...")
+        provider = LocalDockerProvider()
+        print("✓ Provider created\n")
+
+        # Step 2: Start container
+        print("Step 2: Starting echo-env container...")
+        base_url = provider.start_container("echo-env:latest")
+        print(f"✓ Container started at: {base_url}")
+        if provider._container_id:
+            print(f"  Container ID: {provider._container_id[:12]}...")
+        if provider._container_name:
+            print(f"  Container name: {provider._container_name}\n")
+
+        # Step 3: Wait for ready
+        print("Step 3: Waiting for container to be ready...")
+        provider.wait_for_ready(base_url, timeout_s=30.0)
+        print("✓ Container is ready!\n")
+
+        # Step 4: Test health endpoint
+        print("Step 4: Testing /health endpoint...")
+        response = requests.get(f"{base_url}/health")
+        print(f"  Status: {response.status_code}")
+        print(f"  Response: {response.json()}")
+        assert response.status_code == 200
+        assert response.json()["status"] == "healthy"
+        print("✓ Health check passed\n")
+
+        # Step 5: Test reset endpoint
+        print("Step 5: Testing /reset endpoint...")
+        response = requests.post(
+            f"{base_url}/reset",
+            json={},
+            headers={"Content-Type": "application/json"},
+        )
+        print(f"  Status: {response.status_code}")
+        data = response.json()
+        print(f"  Message: {data['observation']['echoed_message']}")
+        print(f"  Reward: {data['reward']}")
+        print(f"  Done: {data['done']}")
+        assert response.status_code == 200
+        assert data["observation"]["echoed_message"] == "Echo environment ready!"
+        print("✓ Reset test passed\n")
+
+        # Step 6: Test step endpoint
+        print("Step 6: Testing /step endpoint...")
+        response = requests.post(
+            f"{base_url}/step",
+            json={"action": {"message": "Hello from LocalDockerProvider!"}},
+            headers={"Content-Type": "application/json"},
+        )
+        print(f"  Status: {response.status_code}")
+        data = response.json()
+        print(f"  Echoed: {data['observation']['echoed_message']}")
+        print(f"  Length: {data['observation']['message_length']}")
+        print(f"  Reward: {data['reward']}")
+        assert response.status_code == 200
+        assert (
+            data["observation"]["echoed_message"] == "Hello from LocalDockerProvider!"
+ ) + assert data["observation"]["message_length"] == 31 + print("✓ Step test passed\n") + + # Step 7: Test state endpoint + print("Step 7: Testing /state endpoint...") + response = requests.get(f"{base_url}/state") + print(f" Status: {response.status_code}") + data = response.json() + print(f" Episode ID: {data['episode_id']}") + print(f" Step count: {data['step_count']}") + assert response.status_code == 200 + assert data["step_count"] == 1 # One step from above + print("✓ State test passed\n") + + # Step 8: Multiple steps + print("Step 8: Testing multiple steps...") + for i in range(3): + response = requests.post( + f"{base_url}/step", + json={"action": {"message": f"Message {i + 1}"}}, + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 200 + print(f" Step {i + 1}: ✓") + + # Check state updated + response = requests.get(f"{base_url}/state") + data = response.json() + assert data["step_count"] == 4 # 1 + 3 more steps + print(f" Final step count: {data['step_count']}") + print("✓ Multiple steps test passed\n") + + print("=" * 60) + print("✓ All tests passed!") + print("=" * 60) + print() + + return True + + except Exception as e: + print(f"\n❌ Test failed: {e}") + import traceback + + traceback.print_exc() + return False + + finally: + # Step 9: Cleanup + if provider is not None: + print("\nStep 9: Cleaning up container...") + try: + provider.stop_container() + print("✓ Container stopped and removed\n") + except Exception as e: + print(f"⚠️ Cleanup warning: {e}\n") + + +def test_provider_with_custom_port(): + """Test provider with custom port.""" + print("=" * 60) + print("LocalDockerProvider with Custom Port Test") + print("=" * 60) + print() + + provider = None + + try: + provider = LocalDockerProvider() + + print("Starting container on custom port 8123...") + base_url = provider.start_container("echo-env:latest", port=8123) + print(f"✓ Started at: {base_url}") + assert ":8123" in base_url + + print("Waiting for ready...") + provider.wait_for_ready(base_url) + print("✓ Ready!") + + print("Testing health...") + response = requests.get(f"{base_url}/health") + assert response.status_code == 200 + print("✓ Health check passed") + + print("\n✓ Custom port test passed!\n") + return True + + except Exception as e: + print(f"\n❌ Test failed: {e}") + return False + + finally: + if provider is not None: + provider.stop_container() + print("✓ Cleaned up\n") + + +def test_provider_with_env_vars(): + """Test provider with environment variables.""" + print("=" * 60) + print("LocalDockerProvider with Environment Variables Test") + print("=" * 60) + print() + + provider = None + + try: + provider = LocalDockerProvider() + + print("Starting container with environment variables...") + base_url = provider.start_container( + "echo-env:latest", env_vars={"DEBUG": "true", "LOG_LEVEL": "info"} + ) + print(f"✓ Started at: {base_url}") + + print("Waiting for ready...") + provider.wait_for_ready(base_url) + print("✓ Ready!") + + print("Testing health...") + response = requests.get(f"{base_url}/health") + assert response.status_code == 200 + print("✓ Health check passed") + + print("\n✓ Environment variables test passed!\n") + return True + + except Exception as e: + print(f"\n❌ Test failed: {e}") + return False + + finally: + if provider is not None: + provider.stop_container() + print("✓ Cleaned up\n") + + +if __name__ == "__main__": + print() + print("🐳 LocalDockerProvider Test Suite") + print() + + results = [] + + # Run basic test + results.append(("Basic End-to-End", 
test_local_docker_provider())) + + # Run custom port test + results.append(("Custom Port", test_provider_with_custom_port())) + + # Run environment variables test + results.append(("Environment Variables", test_provider_with_env_vars())) + + # Summary + print("=" * 60) + print("Test Summary") + print("=" * 60) + for name, passed in results: + status = "✓ PASSED" if passed else "✗ FAILED" + print(f"{name:25} {status}") + print("=" * 60) + + all_passed = all(result for _, result in results) + if all_passed: + print("\n🎉 All tests passed!") + exit(0) + else: + print("\n❌ Some tests failed") + exit(1) diff --git a/src/openenv/core/env_client.py b/src/openenv/core/env_client.py new file mode 100644 index 0000000000000000000000000000000000000000..4ceb344bca20d55d2f9e7ba9aa39595ef61fca30 --- /dev/null +++ b/src/openenv/core/env_client.py @@ -0,0 +1,484 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Environment client for persistent sessions. + +This module provides a WebSocket-based client that maintains a persistent connection +to an environment server, enabling efficient multi-step interactions without +the overhead of HTTP request/response cycles. + +The client is async by default. For synchronous usage, use the `.sync()` method +to get a `SyncEnvClient` wrapper. + +Example (async): + >>> async with GenericEnvClient(base_url="ws://localhost:8000") as env: + ... result = await env.reset() + ... result = await env.step({"code": "print('hello')"}) + +Example (sync wrapper): + >>> env = GenericEnvClient(base_url="ws://localhost:8000").sync() + >>> with env: + ... result = env.reset() + ... result = env.step({"code": "print('hello')"}) +""" + +from __future__ import annotations + +import asyncio +import json +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, Optional, Type, TYPE_CHECKING, TypeVar + +from .client_types import StateT, StepResult +from .containers.runtime import LocalDockerProvider, UVProvider +from .utils import convert_to_ws_url + +if TYPE_CHECKING: + from websockets.asyncio.client import ClientConnection + + from .containers.runtime import ContainerProvider, RuntimeProvider + from .sync_client import SyncEnvClient + +from websockets.asyncio.client import connect as ws_connect + +ActT = TypeVar("ActT") +ObsT = TypeVar("ObsT") +EnvClientT = TypeVar("EnvClientT", bound="EnvClient") + + +class EnvClient(ABC, Generic[ActT, ObsT, StateT]): + """ + Async environment client for persistent sessions. + + This client maintains a persistent WebSocket connection to an environment + server, enabling efficient multi-step interactions. Each client instance + corresponds to a dedicated environment session on the server. + + The client is async by default. For synchronous usage, use the `.sync()` + method to get a `SyncEnvClient` wrapper. + + Features: + - Lower latency for sequential interactions + - Session state is maintained server-side + - Better suited for long-running episodes + - Async by default for modern Python async/await patterns + + Example (async): + >>> from envs.coding_env.client import CodingEnv + >>> + >>> # Connect to a server using async context manager + >>> async with CodingEnv(base_url="ws://localhost:8000") as env: + ... result = await env.reset(seed=42) + ... while not result.done: + ... action = agent.predict(result.observation) + ... 
result = await env.step(action) + + Example (sync wrapper): + >>> env = CodingEnv(base_url="ws://localhost:8000").sync() + >>> with env: + ... result = env.reset(seed=42) + ... result = env.step(action) + """ + + def __init__( + self, + base_url: str, + connect_timeout_s: float = 10.0, + message_timeout_s: float = 60.0, + max_message_size_mb: float = 100.0, + provider: Optional["ContainerProvider | RuntimeProvider"] = None, + mode: Optional[str] = None, + ): + """ + Initialize environment client. + + Args: + base_url: Base URL of the environment server (http:// or ws://). + Will be converted to ws:// if http:// is provided. + connect_timeout_s: Timeout for establishing WebSocket connection + message_timeout_s: Timeout for receiving responses to messages + max_message_size_mb: Maximum WebSocket message size in megabytes. + Default 100MB to handle large observations (screenshots, DOM, etc.) + provider: Optional container/runtime provider for lifecycle management. + Can be a ContainerProvider (Docker) or RuntimeProvider (UV). + mode: Communication mode: 'simulation' for Gym-style API (default) or + 'production' for MCP JSON-RPC protocol. Can also be set via the + OPENENV_CLIENT_MODE environment variable. Constructor parameter + takes precedence over environment variable. Case-insensitive. + """ + # Determine mode (constructor > env var > default) + if mode is None: + mode = os.environ.get("OPENENV_CLIENT_MODE", "simulation") + + # Normalize and validate mode + mode = mode.lower() + if mode not in ("simulation", "production"): + raise ValueError( + f"Invalid mode: '{mode}'. Must be 'simulation' or 'production'. " + f"Set via constructor parameter or OPENENV_CLIENT_MODE environment variable." + ) + + # Store mode (use object.__setattr__ to bypass immutability) + object.__setattr__(self, "_mode", mode) + + # Convert HTTP URL to WebSocket URL + ws_url = convert_to_ws_url(base_url) + + self._ws_url = f"{ws_url}/ws" + self._connect_timeout = connect_timeout_s + self._message_timeout = message_timeout_s + self._max_message_size = int( + max_message_size_mb * 1024 * 1024 + ) # Convert MB to bytes + self._provider = provider + self._ws: Optional[ClientConnection] = None + + def __setattr__(self, name: str, value: Any) -> None: + """Prevent modification of _mode after initialization.""" + if name == "_mode" and hasattr(self, "_mode"): + raise AttributeError("Cannot modify mode after initialization") + super().__setattr__(name, value) + + async def connect(self) -> "EnvClient": + """ + Establish WebSocket connection to the server. 
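As a quick illustration of the mode resolution described in the constructor docstring above, here is a minimal sketch; `EchoEnv` stands in for any concrete `EnvClient` subclass and is an assumption, not a class this diff defines.

```python
import os

from envs.echo_env.client import EchoEnv  # hypothetical concrete EnvClient subclass

# The constructor argument takes precedence over the environment variable,
# and values are normalized case-insensitively.
os.environ["OPENENV_CLIENT_MODE"] = "production"
client = EchoEnv(base_url="ws://localhost:8000", mode="Simulation")
assert client._mode == "simulation"

# With no constructor argument, OPENENV_CLIENT_MODE applies.
client = EchoEnv(base_url="ws://localhost:8000")
assert client._mode == "production"

# Anything other than 'simulation' or 'production' is rejected eagerly.
try:
    EchoEnv(base_url="ws://localhost:8000", mode="debug")
except ValueError as err:
    print(err)
```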
+ + Returns: + self for method chaining + + Raises: + ConnectionError: If connection cannot be established + """ + if self._ws is not None: + return self + + # Bypass proxy for localhost connections + ws_url_lower = self._ws_url.lower() + is_localhost = "localhost" in ws_url_lower or "127.0.0.1" in ws_url_lower + + old_no_proxy = os.environ.get("NO_PROXY") + if is_localhost: + # Set NO_PROXY to bypass proxy for localhost + current_no_proxy = old_no_proxy or "" + if "localhost" not in current_no_proxy.lower(): + os.environ["NO_PROXY"] = ( + f"{current_no_proxy},localhost,127.0.0.1" + if current_no_proxy + else "localhost,127.0.0.1" + ) + + try: + self._ws = await ws_connect( + self._ws_url, + open_timeout=self._connect_timeout, + max_size=self._max_message_size, + ) + except Exception as e: + raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e + finally: + # Restore original NO_PROXY value + if is_localhost: + if old_no_proxy is None: + os.environ.pop("NO_PROXY", None) + else: + os.environ["NO_PROXY"] = old_no_proxy + + return self + + async def disconnect(self) -> None: + """Close the WebSocket connection.""" + if self._ws is not None: + try: + # Send close message + await self._send({"type": "close"}) + except Exception: + pass # Best effort + try: + await self._ws.close() + except Exception: + pass + self._ws = None + + async def _ensure_connected(self) -> None: + """Ensure WebSocket connection is established.""" + if self._ws is None: + await self.connect() + + async def _send(self, message: Dict[str, Any]) -> None: + """Send a message over the WebSocket.""" + await self._ensure_connected() + assert self._ws is not None + await self._ws.send(json.dumps(message)) + + async def _receive(self) -> Dict[str, Any]: + """Receive and parse a message from the WebSocket.""" + assert self._ws is not None + raw = await asyncio.wait_for(self._ws.recv(), timeout=self._message_timeout) + return json.loads(raw) + + async def _send_and_receive(self, message: Dict[str, Any]) -> Dict[str, Any]: + """Send a message and wait for response.""" + await self._send(message) + response = await self._receive() + + # Check for error response + if response.get("type") == "error": + error_data = response.get("data", {}) + raise RuntimeError( + f"Server error: {error_data.get('message', 'Unknown error')} " + f"(code: {error_data.get('code', 'UNKNOWN')})" + ) + + return response + + @classmethod + async def from_docker_image( + cls: Type[EnvClientT], + image: str, + provider: Optional["ContainerProvider"] = None, + **kwargs: Any, + ) -> EnvClientT: + """ + Create an environment client by spinning up a Docker container. + + Args: + image: Docker image name to run (e.g., "coding-env:latest") + provider: Container provider to use (defaults to LocalDockerProvider) + **kwargs: Additional arguments to pass to provider.start_container() + + Returns: + Connected client instance + """ + if provider is None: + provider = LocalDockerProvider() + + # Start container + base_url = provider.start_container(image, **kwargs) + + # Wait for server to be ready + provider.wait_for_ready(base_url) + + # Create and connect client + client = cls(base_url=base_url, provider=provider) + await client.connect() + + return client + + @classmethod + async def from_env( + cls: Type[EnvClientT], + repo_id: str, + *, + use_docker: bool = True, + provider: Optional["ContainerProvider | RuntimeProvider"] = None, + **provider_kwargs: Any, + ) -> EnvClientT: + """ + Create a client from a Hugging Face Space. 
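A usage sketch for `from_docker_image` before the Hub-oriented `from_env` below; the image tag and the `EchoEnv` client are illustrative assumptions, and a local Docker daemon is presumed to be running.

```python
import asyncio

from envs.echo_env.client import EchoEnv  # hypothetical concrete EnvClient subclass


async def main() -> None:
    # Starts a local container via LocalDockerProvider, waits for the
    # server to report ready, then opens the WebSocket session.
    env = await EchoEnv.from_docker_image("echo-env:latest")
    try:
        result = await env.reset()
        print(result.observation)
    finally:
        # Because the client owns the provider here, close() also stops
        # and removes the container it started.
        await env.close()


asyncio.run(main())
```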
+ + Args: + repo_id: Hugging Face space identifier ``{org}/{space}``. + use_docker: When ``True`` (default) pull from the HF registry and + launch via :class:`LocalDockerProvider`. When ``False`` run the + space locally with :class:`UVProvider`. + provider: Optional provider instance to reuse. Must be a + :class:`ContainerProvider` when ``use_docker=True`` and a + :class:`RuntimeProvider` otherwise. + provider_kwargs: Additional keyword arguments forwarded to + either the container provider's ``start_container`` (docker) + or to the ``UVProvider`` constructor/start (uv). When + ``use_docker=False``, the ``project_path`` argument can be + used to override the default git URL + (``git+https://huggingface.co/spaces/{repo_id}``). + + Returns: + Connected client instance + + Examples: + >>> # Pull and run from HF Docker registry + >>> env = await MyEnv.from_env("openenv/echo-env") + >>> + >>> # Run locally with UV (clones the space) + >>> env = await MyEnv.from_env("openenv/echo-env", use_docker=False) + >>> + >>> # Run from a local checkout + >>> env = await MyEnv.from_env( + ... "openenv/echo-env", + ... use_docker=False, + ... project_path="/path/to/local/checkout" + ... ) + """ + # Extract start args that apply to both providers + start_args = {} + for key in ("port", "env_vars", "workers"): + if key in provider_kwargs: + start_args[key] = provider_kwargs.pop(key) + + if use_docker: + # Docker mode: pull from HF registry + docker_provider = provider or LocalDockerProvider() + tag = provider_kwargs.pop("tag", "latest") + image = f"registry.hf.space/{repo_id.replace('/', '-')}:{tag}" + base_url = docker_provider.start_container( + image, **start_args, **provider_kwargs + ) + docker_provider.wait_for_ready(base_url) + + client = cls(base_url=base_url, provider=docker_provider) + await client.connect() + return client + else: + # UV mode: clone and run with uv + if provider is None: + uv_kwargs = dict(provider_kwargs) + project_path = uv_kwargs.pop("project_path", None) + if project_path is None: + project_path = f"git+https://huggingface.co/spaces/{repo_id}" + + provider = UVProvider(project_path=project_path, **uv_kwargs) + else: + if provider_kwargs: + raise ValueError( + "provider_kwargs cannot be used when supplying a provider instance" + ) + + base_url = provider.start(**start_args) + provider.wait_for_ready() + + client = cls(base_url=base_url, provider=provider) + await client.connect() + return client + + @abstractmethod + def _step_payload(self, action: ActT) -> Dict[str, Any]: + """Convert an Action object to the JSON data expected by the env server.""" + raise NotImplementedError + + @abstractmethod + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: + """Convert a JSON response from the env server to StepResult[ObsT].""" + raise NotImplementedError + + @abstractmethod + def _parse_state(self, payload: Dict[str, Any]) -> StateT: + """Convert a JSON response from the state endpoint to a State object.""" + raise NotImplementedError + + async def reset(self, **kwargs: Any) -> StepResult[ObsT]: + """ + Reset the environment with optional parameters. + + Args: + **kwargs: Optional parameters passed to the environment's reset method. 
+ Common parameters include: + - seed: Random seed for reproducibility + - episode_id: Custom episode identifier + + Returns: + StepResult containing initial observation + """ + message = { + "type": "reset", + "data": kwargs, + } + response = await self._send_and_receive(message) + return self._parse_result(response.get("data", {})) + + async def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: + """ + Execute an action in the environment. + + Args: + action: The action to execute + **kwargs: Optional parameters (currently ignored) + + Returns: + StepResult containing observation, reward, and done status + """ + message = { + "type": "step", + "data": self._step_payload(action), + } + response = await self._send_and_receive(message) + return self._parse_result(response.get("data", {})) + + async def state(self) -> StateT: + """ + Get the current environment state from the server. + + Returns: + State object with environment state information + """ + message = {"type": "state"} + response = await self._send_and_receive(message) + return self._parse_state(response.get("data", {})) + + async def close(self) -> None: + """ + Close the WebSocket connection and clean up resources. + + If this client was created via from_docker_image() or from_env(), + this will also stop and remove the associated container/process. + """ + await self.disconnect() + + if self._provider is not None: + # Handle both ContainerProvider and RuntimeProvider + if hasattr(self._provider, "stop_container"): + self._provider.stop_container() + elif hasattr(self._provider, "stop"): + self._provider.stop() + + async def __aenter__(self) -> "EnvClient": + """Enter async context manager, ensuring connection is established.""" + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit async context manager, closing connection.""" + await self.close() + + def __enter__(self) -> "EnvClient": + """Sync context manager entry - raises error suggesting async usage.""" + raise TypeError( + "EnvClient is async by default. Use 'async with' instead of 'with', " + "or call .sync() to get a synchronous wrapper:\n" + " async with client: # async usage\n" + " with client.sync(): # sync wrapper" + ) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Sync context manager exit - should not be reached.""" + pass # pragma: no cover + + def sync(self) -> "SyncEnvClient": + """ + Return a synchronous wrapper around this async client. + + Use this method when you need synchronous access to the environment + without async/await syntax. This is useful for: + - Integration with synchronous codebases + - Interactive/REPL usage + - Stopping async from "infecting" the call stack + + Returns: + SyncEnvClient wrapper that provides synchronous methods + + Example: + >>> # Create async client and get sync wrapper + >>> async_client = GenericEnvClient(base_url="http://localhost:8000") + >>> sync_client = async_client.sync() + >>> + >>> # Use synchronous API + >>> with sync_client: + ... result = sync_client.reset() + ... result = sync_client.step({"code": "print('hello')"}) + """ + from .sync_client import SyncEnvClient + + return SyncEnvClient(self) diff --git a/src/openenv/core/env_server/__init__.py b/src/openenv/core/env_server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0f1f2845f09ec758c1fcedb16dbb771059156b --- /dev/null +++ b/src/openenv/core/env_server/__init__.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Core environment interfaces and types.""" + +from .base_transforms import CompositeTransform, NullTransform +from .exceptions import ( + ConcurrencyConfigurationError, + EnvironmentFactoryError, + OpenEnvError, + SessionCapacityError, + SessionCreationError, + SessionNotFoundError, +) +from .http_server import create_app, create_fastapi_app, HTTPEnvServer +from .interfaces import Environment, Message, ModelTokenizer, Transform + +try: + from .mcp_environment import MCPEnvironment +except ModuleNotFoundError: + MCPEnvironment = None # type: ignore[assignment] + +from .mcp_types import ( + CallToolAction, + CallToolObservation, + JsonRpcError, + # JSON-RPC types + JsonRpcErrorCode, + JsonRpcRequest, + JsonRpcResponse, + ListToolsAction, + ListToolsObservation, + McpMethod, + RESERVED_TOOL_NAMES, + Tool, + ToolError, + ToolErrorType, + WSMCPMessage, + WSMCPResponse, +) +from .route_config import GetEndpointConfig +from .serialization import ( + deserialize_action, + deserialize_action_with_preprocessing, + serialize_observation, +) +from .types import ( + Action, + BaseMessage, + ConcurrencyConfig, + HealthResponse, + HealthStatus, + Observation, + SchemaResponse, + ServerCapacityStatus, + ServerMode, + SessionInfo, + State, + WSCloseMessage, + WSErrorCode, + WSErrorResponse, + WSIncomingMessage, + WSObservationResponse, + WSResetMessage, + WSStateMessage, + WSStateResponse, + WSStepMessage, +) + +try: + from .web_interface import create_web_interface_app, WebInterfaceManager +except ModuleNotFoundError: + create_web_interface_app = None # type: ignore[assignment] + WebInterfaceManager = None # type: ignore[assignment] + +__all__ = [ + # Core interfaces + "Environment", + "Transform", + "Message", + "ModelTokenizer", + # Types + "Action", + "Observation", + "State", + "SchemaResponse", + "HealthResponse", + # Enums + "HealthStatus", + "ServerMode", + "WSErrorCode", + # WebSocket message types + "BaseMessage", + "WSIncomingMessage", + "WSResetMessage", + "WSStepMessage", + "WSStateMessage", + "WSCloseMessage", + "WSObservationResponse", + "WSStateResponse", + "WSErrorResponse", + # Concurrency types + "ConcurrencyConfig", + "ServerCapacityStatus", + "SessionInfo", + # Exceptions + "OpenEnvError", + "ConcurrencyConfigurationError", + "SessionCapacityError", + "SessionNotFoundError", + "SessionCreationError", + "EnvironmentFactoryError", + # Base transforms + "CompositeTransform", + "NullTransform", + # HTTP Server + "HTTPEnvServer", + "create_app", + "create_fastapi_app", + # Web Interface + "create_web_interface_app", + "WebInterfaceManager", + # Serialization utilities + "deserialize_action", + "deserialize_action_with_preprocessing", + "serialize_observation", + # Route configuration + "GetEndpointConfig", + # MCP types + "Tool", + "ToolError", + "ToolErrorType", + "ListToolsAction", + "CallToolAction", + "ListToolsObservation", + "CallToolObservation", + "WSMCPMessage", + "WSMCPResponse", + "RESERVED_TOOL_NAMES", + "MCPEnvironment", + # JSON-RPC types + "JsonRpcErrorCode", + "JsonRpcError", + "JsonRpcRequest", + "JsonRpcResponse", + "McpMethod", +] diff --git a/src/openenv/core/env_server/base_transforms.py b/src/openenv/core/env_server/base_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..ab48ebb48b58962ff56d282713a1d63907b0f390 --- /dev/null +++ b/src/openenv/core/env_server/base_transforms.py 
@@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Base transform implementations for composing environment-specific transforms.""" + +from .interfaces import Transform +from .types import Observation + + +class CompositeTransform(Transform): + """Combines multiple transforms into a single transform.""" + + def __init__(self, transforms: list[Transform]): + self.transforms = transforms + + def __call__(self, observation: Observation) -> Observation: + for transform in self.transforms: + observation = transform(observation) + return observation + + +class NullTransform(Transform): + """Default transform that passes through unchanged.""" + + def __call__(self, observation: Observation) -> Observation: + return observation diff --git a/src/openenv/core/env_server/exceptions.py b/src/openenv/core/env_server/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..5701913e0bcac67e6f84d3861d57c4949665677a --- /dev/null +++ b/src/openenv/core/env_server/exceptions.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Custom exceptions for environment server operations.""" + +from typing import Optional + + +class OpenEnvError(Exception): + """Base exception for all OpenEnv errors.""" + + pass + + +class ConcurrencyConfigurationError(OpenEnvError): + """ + Raised when an environment is misconfigured for concurrent sessions. + + This error is raised during server startup when max_concurrent_envs > 1 + is specified for an environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. + """ + + def __init__( + self, + environment_name: str, + max_concurrent_envs: int, + message: Optional[str] = None, + ): + self.environment_name = environment_name + self.max_concurrent_envs = max_concurrent_envs + + if message is None: + message = ( + f"Environment '{environment_name}' is not marked as SUPPORTS_CONCURRENT_SESSIONS. " + f"Cannot run with max_concurrent_envs={max_concurrent_envs}. " + f"Either set max_concurrent_envs=1 or ensure the environment " + f"properly isolates session state and set SUPPORTS_CONCURRENT_SESSIONS=True." + ) + + super().__init__(message) + + +class SessionCapacityError(OpenEnvError): + """ + Raised when the server cannot accept new sessions due to capacity limits. + + This error is raised when a new WebSocket connection is attempted but + the server has already reached max_concurrent_envs active sessions. + """ + + def __init__( + self, + active_sessions: int, + max_sessions: int, + message: Optional[str] = None, + ): + self.active_sessions = active_sessions + self.max_sessions = max_sessions + + if message is None: + message = ( + f"Server at capacity: {active_sessions}/{max_sessions} sessions active. " + f"Cannot accept new connections." + ) + + super().__init__(message) + + +class SessionNotFoundError(OpenEnvError): + """Raised when attempting to access a session that does not exist.""" + + def __init__(self, session_id: str, message: Optional[str] = None): + self.session_id = session_id + + if message is None: + message = f"Session '{session_id}' not found." 
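To show how the transforms in `base_transforms.py` compose, a small sketch; `TagTransform` is hypothetical, and constructing `Observation()` with no arguments assumes its fields are all optional.

```python
from openenv.core.env_server import (
    CompositeTransform,
    NullTransform,
    Observation,
    Transform,
)


class TagTransform(Transform):  # hypothetical transform for illustration
    def __call__(self, observation: Observation) -> Observation:
        # A real transform would return a modified observation, e.g. via
        # observation.model_copy(update={...}); this one passes through.
        return observation


# Transforms run left to right; NullTransform is the no-op default.
pipeline = CompositeTransform([TagTransform(), NullTransform()])
obs = Observation()
assert pipeline(obs) is obs
```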
+ + super().__init__(message) + + +class SessionCreationError(OpenEnvError): + """Raised when a session cannot be created.""" + + def __init__(self, reason: str, message: Optional[str] = None): + self.reason = reason + + if message is None: + message = f"Failed to create session: {reason}" + + super().__init__(message) + + +class EnvironmentFactoryError(OpenEnvError): + """Raised when the environment factory fails to create an instance.""" + + def __init__(self, factory_name: str, message: Optional[str] = None): + self.factory_name = factory_name + + if message is None: + message = f"Environment factory '{factory_name}' failed to create instance." + + super().__init__(message) diff --git a/src/openenv/core/env_server/gradio_theme.py b/src/openenv/core/env_server/gradio_theme.py new file mode 100644 index 0000000000000000000000000000000000000000..7cebea2284d8d19e41d5954b498bcc3bb7ff39a4 --- /dev/null +++ b/src/openenv/core/env_server/gradio_theme.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unified terminal-style theme for OpenEnv Gradio UI (light/dark).""" + +from __future__ import annotations + +import gradio as gr + +_MONO_FONTS = ( + "JetBrains Mono", + "Fira Code", + "Cascadia Code", + "Consolas", + "ui-monospace", + "monospace", +) + +_CORE_FONT = ( + "Lato", + "Inter", + "Arial", + "Helvetica", + "sans-serif", +) + +_ZERO_RADIUS = gr.themes.Size( + xxs="0px", + xs="0px", + sm="0px", + md="0px", + lg="0px", + xl="0px", + xxl="0px", +) + +_GREEN_HUE = gr.themes.Color( + c50="#e6f4ea", + c100="#ceead6", + c200="#a8dab5", + c300="#6fcc8b", + c400="#3fb950", + c500="#238636", + c600="#1a7f37", + c700="#116329", + c800="#0a4620", + c900="#033a16", + c950="#04200d", +) + +_NEUTRAL_HUE = gr.themes.Color( + c50="#f6f8fa", + c100="#eaeef2", + c200="#d0d7de", + c300="#afb8c1", + c400="#8c959f", + c500="#6e7781", + c600="#57606a", + c700="#424a53", + c800="#32383f", + c900="#24292f", + c950="#1b1f24", +) + +OPENENV_GRADIO_THEME = gr.themes.Base( + primary_hue=_GREEN_HUE, + secondary_hue=_NEUTRAL_HUE, + neutral_hue=_NEUTRAL_HUE, + font=_CORE_FONT, + font_mono=_MONO_FONTS, + radius_size=_ZERO_RADIUS, +).set( + body_background_fill="#ffffff", + background_fill_primary="#ffffff", + background_fill_secondary="#f6f8fa", + block_background_fill="#ffffff", + block_border_color="#ffffff", + block_label_text_color="#57606a", + block_title_text_color="#24292f", + border_color_primary="#d0d7de", + input_background_fill="#ffffff", + input_border_color="#d0d7de", + button_primary_background_fill="#1a7f37", + button_primary_background_fill_hover="#116329", + button_primary_text_color="#ffffff", + button_secondary_background_fill="#f6f8fa", + button_secondary_background_fill_hover="#eaeef2", + button_secondary_text_color="#24292f", + button_secondary_border_color="#d0d7de", + body_background_fill_dark="#0d1117", + background_fill_primary_dark="#0d1117", + background_fill_secondary_dark="#0d1117", + block_background_fill_dark="#0d1117", + block_border_color_dark="#0d1117", + block_label_text_color_dark="#8b949e", + block_title_text_color_dark="#c9d1d9", + border_color_primary_dark="#30363d", + input_background_fill_dark="#0d1117", + input_border_color_dark="#30363d", + button_primary_background_fill_dark="#30363d", + button_primary_background_fill_hover_dark="#484f58", + button_primary_text_color_dark="#c9d1d9", + 
button_secondary_background_fill_dark="#21262d", + button_secondary_background_fill_hover_dark="#30363d", + button_secondary_text_color_dark="#c9d1d9", + button_secondary_border_color_dark="#30363d", +) + +OPENENV_GRADIO_CSS = """ +* { border-radius: 0 !important; } +.col-left { padding: 16px !important; } +.col-right { padding: 16px !important; } +.prose, .markdown-text, .md, +.prose > *, .markdown-text > * { + background: transparent !important; + border: none !important; + box-shadow: none !important; +} +.dark .col-left { + border-left-color: rgba(139, 148, 158, 0.4) !important; +} +.dark .col-right { + border-left-color: rgba(201, 209, 217, 0.3) !important; +} +""" diff --git a/src/openenv/core/env_server/gradio_ui.py b/src/openenv/core/env_server/gradio_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..dc1a630bd1db39588304b42520f08bb45f477e81 --- /dev/null +++ b/src/openenv/core/env_server/gradio_ui.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Gradio-based web UI for OpenEnv environments. + +Replaces the legacy HTML/JavaScript interface when ENABLE_WEB_INTERFACE is set. +Mount at /web via gr.mount_gradio_app() from create_web_interface_app(). +""" + +from __future__ import annotations + +import json +import re +from typing import Any, Dict, List, Optional + +import gradio as gr + +from .types import EnvironmentMetadata + + +def _escape_md(text: str) -> str: + """Escape Markdown special characters in user-controlled content.""" + return re.sub(r"([\\`*_\{\}\[\]()#+\-.!|~>])", r"\\\1", str(text)) + + +def _format_observation(data: Dict[str, Any]) -> str: + """Format reset/step response for Markdown display.""" + lines: List[str] = [] + obs = data.get("observation", {}) + if isinstance(obs, dict): + if obs.get("prompt"): + lines.append(f"**Prompt:**\n\n{_escape_md(obs['prompt'])}\n") + messages = obs.get("messages", []) + if messages: + lines.append("**Messages:**\n") + for msg in messages: + sender = _escape_md(str(msg.get("sender_id", "?"))) + content = _escape_md(str(msg.get("content", ""))) + cat = _escape_md(str(msg.get("category", ""))) + lines.append(f"- `[{cat}]` Player {sender}: {content}") + lines.append("") + reward = data.get("reward") + done = data.get("done") + if reward is not None: + lines.append(f"**Reward:** `{reward}`") + if done is not None: + lines.append(f"**Done:** `{done}`") + return "\n".join(lines) if lines else "*No observation data*" + + +def _readme_section(metadata: Optional[EnvironmentMetadata]) -> str: + """README content for the left panel.""" + if not metadata or not metadata.readme_content: + return "*No README available.*" + return metadata.readme_content + + +def get_gradio_display_title( + metadata: Optional[EnvironmentMetadata], + fallback: str = "OpenEnv Environment", +) -> str: + """Return the title used for the Gradio app (browser tab and Blocks).""" + name = metadata.name if metadata else fallback + return f"OpenEnv Agentic Environment: {name}" + + +def build_gradio_app( + web_manager: Any, + action_fields: List[Dict[str, Any]], + metadata: Optional[EnvironmentMetadata], + is_chat_env: bool, + title: str = "OpenEnv Environment", + quick_start_md: Optional[str] = None, +) -> gr.Blocks: + """ + Build a Gradio Blocks app for the OpenEnv web interface. 
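A minimal sketch of applying the theme and CSS overrides above to a standalone Blocks app, independent of `build_gradio_app`:

```python
import gradio as gr

from openenv.core.env_server.gradio_theme import (
    OPENENV_GRADIO_CSS,
    OPENENV_GRADIO_THEME,
)

# The theme supplies the light/dark palette; the CSS enforces zero radii
# and the transparent markdown panels.
with gr.Blocks(
    theme=OPENENV_GRADIO_THEME, css=OPENENV_GRADIO_CSS, title="OpenEnv demo"
) as demo:
    gr.Markdown("# Playground")

demo.launch()
```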
+ + Args: + web_manager: WebInterfaceManager (reset/step_environment, get_state). + action_fields: Field dicts from _extract_action_fields(action_cls). + metadata: Environment metadata for README/name. + is_chat_env: If True, single message textbox; else form from action_fields. + title: App title (overridden by metadata.name when present; see get_gradio_display_title). + quick_start_md: Optional Quick Start markdown (class names already replaced). + + Returns: + gr.Blocks to mount with gr.mount_gradio_app(app, blocks, path="/web"). + """ + readme_content = _readme_section(metadata) + display_title = get_gradio_display_title(metadata, fallback=title) + + async def reset_env(): + try: + data = await web_manager.reset_environment() + obs_md = _format_observation(data) + return ( + obs_md, + json.dumps(data, indent=2), + "Environment reset successfully.", + ) + except Exception as e: + return ("", "", f"Error: {e}") + + def _step_with_action(action_data: Dict[str, Any]): + async def _run(): + try: + data = await web_manager.step_environment(action_data) + obs_md = _format_observation(data) + return ( + obs_md, + json.dumps(data, indent=2), + "Step complete.", + ) + except Exception as e: + return ("", "", f"Error: {e}") + + return _run + + async def step_chat(message: str): + if not (message or str(message).strip()): + return ("", "", "Please enter an action message.") + action = {"message": str(message).strip()} + return await _step_with_action(action)() + + def get_state_sync(): + try: + data = web_manager.get_state() + return json.dumps(data, indent=2) + except Exception as e: + return f"Error: {e}" + + with gr.Blocks(title=display_title) as demo: + with gr.Row(): + with gr.Column(scale=1, elem_classes="col-left"): + if quick_start_md: + with gr.Accordion("Quick Start", open=True): + gr.Markdown(quick_start_md) + with gr.Accordion("README", open=False): + gr.Markdown(readme_content) + + with gr.Column(scale=2, elem_classes="col-right"): + obs_display = gr.Markdown( + value=("# Playground\n\nClick **Reset** to start a new episode."), + ) + with gr.Group(): + if is_chat_env: + action_input = gr.Textbox( + label="Action message", + placeholder="e.g. 
Enter your message...", + ) + step_inputs = [action_input] + step_fn = step_chat + else: + step_inputs = [] + for field in action_fields: + name = field["name"] + field_type = field.get("type", "text") + label = name.replace("_", " ").title() + placeholder = field.get("placeholder", "") + if field_type == "checkbox": + inp = gr.Checkbox(label=label) + elif field_type == "number": + inp = gr.Number(label=label) + elif field_type == "select": + choices = field.get("choices") or [] + inp = gr.Dropdown( + choices=choices, + label=label, + allow_custom_value=False, + ) + elif field_type in ("textarea", "tensor"): + inp = gr.Textbox( + label=label, + placeholder=placeholder, + lines=3, + ) + else: + inp = gr.Textbox( + label=label, + placeholder=placeholder, + ) + step_inputs.append(inp) + + async def step_form(*values): + if not action_fields: + return await _step_with_action({})() + action_data = {} + for i, field in enumerate(action_fields): + if i >= len(values): + break + name = field["name"] + val = values[i] + if field.get("type") == "checkbox": + action_data[name] = bool(val) + elif val is not None and val != "": + action_data[name] = val + return await _step_with_action(action_data)() + + step_fn = step_form + + with gr.Row(): + step_btn = gr.Button("Step", variant="primary") + reset_btn = gr.Button("Reset", variant="secondary") + state_btn = gr.Button("Get state", variant="secondary") + with gr.Row(): + status = gr.Textbox( + label="Status", + interactive=False, + ) + raw_json = gr.Code( + label="Raw JSON response", + language="json", + interactive=False, + ) + + reset_btn.click( + fn=reset_env, + outputs=[obs_display, raw_json, status], + ) + step_btn.click( + fn=step_fn, + inputs=step_inputs, + outputs=[obs_display, raw_json, status], + ) + if is_chat_env: + action_input.submit( + fn=step_fn, + inputs=step_inputs, + outputs=[obs_display, raw_json, status], + ) + state_btn.click( + fn=get_state_sync, + outputs=[raw_json], + ) + + return demo diff --git a/src/openenv/core/env_server/http_server.py b/src/openenv/core/env_server/http_server.py new file mode 100644 index 0000000000000000000000000000000000000000..658f63ef98bf78d278b8926271c217da23c79a37 --- /dev/null +++ b/src/openenv/core/env_server/http_server.py @@ -0,0 +1,1391 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +HTTP server wrapper for Environment instances. + +This module provides utilities to wrap any Environment subclass and expose it +over HTTP and WebSocket endpoints that EnvClient can consume. 
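The gradio_ui module docstring above mentions mounting at /web; a sketch of that wiring under stated assumptions. The stub manager only mimics the `WebInterfaceManager` interface that `build_gradio_app` calls into; in a real deployment `create_web_interface_app()` prepares these objects.

```python
import gradio as gr
from fastapi import FastAPI

from openenv.core.env_server.gradio_ui import build_gradio_app


class _StubManager:
    """Stand-in for WebInterfaceManager, for illustration only."""

    async def reset_environment(self):
        return {"observation": {}, "reward": None, "done": False}

    async def step_environment(self, action):
        return {"observation": {"echo": action}, "reward": 0.0, "done": False}

    def get_state(self):
        return {"step_count": 0}


app = FastAPI()
blocks = build_gradio_app(
    web_manager=_StubManager(),
    action_fields=[{"name": "message", "type": "text"}],
    metadata=None,
    is_chat_env=True,
)
# Serve the Gradio UI under /web next to the environment's API routes.
app = gr.mount_gradio_app(app, blocks, path="/web")
```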
+""" + +from __future__ import annotations + +import asyncio +import inspect +import json +import os +import time +import uuid +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Callable, Dict, Optional, Type + +from fastapi import ( + Body, + FastAPI, + HTTPException, + Request, + status, + WebSocket, + WebSocketDisconnect, +) +from pydantic import ValidationError + +from .interfaces import Environment +from .mcp_environment import get_server_tools +from .mcp_types import ( + JsonRpcErrorCode, + JsonRpcRequest, + JsonRpcResponse, + McpMethod, + WSMCPMessage, + WSMCPResponse, +) +from .route_config import GetEndpointConfig, register_get_endpoints +from .serialization import deserialize_action, serialize_observation +from .types import ( + Action, + ConcurrencyConfig, + EnvironmentMetadata, + HealthResponse, + HealthStatus, + Observation, + ResetRequest, + ResetResponse, + SchemaResponse, + ServerCapacityStatus, + ServerMode, + SessionInfo, + State, + StepRequest, + StepResponse, + WSCloseMessage, + WSErrorCode, + WSErrorResponse, + WSObservationResponse, + WSResetMessage, + WSStateMessage, + WSStateResponse, + WSStepMessage, +) + + +def _make_json_serializable(obj: Any) -> Any: + """ + Convert an object to a JSON-serializable form. + + Handles Pydantic models, dataclasses, and other common types. + + Args: + obj: The object to convert + + Returns: + A JSON-serializable representation of the object + """ + if obj is None: + return None + if isinstance(obj, (str, int, float, bool)): + return obj + if isinstance(obj, (list, tuple)): + return [_make_json_serializable(item) for item in obj] + if isinstance(obj, dict): + return {k: _make_json_serializable(v) for k, v in obj.items()} + if hasattr(obj, "model_dump"): + # Pydantic model + return obj.model_dump() + if hasattr(obj, "__dict__"): + # Object with __dict__ + return {k: _make_json_serializable(v) for k, v in obj.__dict__.items()} + # Fallback to string representation + return str(obj) + + +from .exceptions import ( + ConcurrencyConfigurationError, + EnvironmentFactoryError, + SessionCapacityError, +) + + +class HTTPEnvServer: + """ + HTTP server wrapper for Environment instances. + + This class wraps an Environment and exposes its reset(), step(), and state + methods as HTTP and WebSocket endpoints compatible with EnvClient. + + The server expects: + - Action deserialization: Converts JSON dict to Action subclass + - Observation serialization: Converts Observation subclass to JSON dict + + Example: + >>> from core.env_server import HTTPEnvServer + >>> from envs.coding_env.server import CodeExecutionEnvironment + >>> from envs.coding_env.models import CodeAction, CodeObservation + >>> + >>> # Pass environment class (factory pattern) + >>> server = HTTPEnvServer( + ... env=CodeExecutionEnvironment, + ... action_cls=CodeAction, + ... observation_cls=CodeObservation, + ... max_concurrent_envs=4, + ... ) + >>> + >>> # Register routes with FastAPI + >>> from fastapi import FastAPI + >>> app = FastAPI() + >>> server.register_routes(app) + """ + + def __init__( + self, + env: Callable[[], Environment], + action_cls: Type[Action], + observation_cls: Type[Observation], + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, + ): + """ + Initialize HTTP server wrapper. + + Args: + env: Environment factory (callable) that creates new instances. + Will be called to create a new environment for each WebSocket session. 
+ action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + max_concurrent_envs: Maximum number of concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. + + Raises: + ValueError: If both max_concurrent_envs and concurrency_config are provided. + ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an + environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. + """ + # Validate that env is callable + if not callable(env): + raise TypeError( + f"env must be a callable (class or factory function), got {type(env)}. " + f"Pass the environment class (e.g., MyEnvironment) not an instance (e.g., MyEnvironment())." + ) + + self._env_factory: Callable[[], Environment] = env + + # Handle concurrency configuration + if max_concurrent_envs is not None and concurrency_config is not None: + raise ValueError( + "Cannot specify both 'max_concurrent_envs' and 'concurrency_config'. " + "Please use only one method to configure concurrency." + ) + + if concurrency_config is not None: + self._concurrency_config = concurrency_config + elif max_concurrent_envs is not None: + self._concurrency_config = ConcurrencyConfig( + max_concurrent_envs=max_concurrent_envs, + session_timeout=None, + ) + else: + # Default configuration + self._concurrency_config = ConcurrencyConfig( + max_concurrent_envs=1, + session_timeout=None, + ) + + self._max_concurrent_envs = self._concurrency_config.max_concurrent_envs + + # Validate concurrency configuration + self._validate_concurrency_safety() + + self.action_cls = action_cls + self.observation_cls = observation_cls + + # Session management for WebSocket connections + self._sessions: Dict[str, Environment] = {} + self._session_executors: Dict[str, ThreadPoolExecutor] = {} + self._session_info: Dict[str, SessionInfo] = {} + self._session_lock = asyncio.Lock() + + # Create thread pool for running sync code in async context + # This is needed for environments using sync libraries (e.g., Playwright) + self._executor = ThreadPoolExecutor(max_workers=32) + + def _validate_concurrency_safety(self) -> None: + """ + Validate that the environment supports the configured concurrency level. + + Raises: + ConcurrencyConfigurationError: If max_concurrent_envs > 1 for an + environment that is not marked as SUPPORTS_CONCURRENT_SESSIONS. + """ + if self._max_concurrent_envs <= 1: + return + + if inspect.isclass(self._env_factory): + env_cls = self._env_factory + else: + _temp_env = self._env_factory() + env_cls = type(_temp_env) + _temp_env.close() + del _temp_env + + if not getattr(env_cls, "SUPPORTS_CONCURRENT_SESSIONS", False): + raise ConcurrencyConfigurationError( + environment_name=env_cls.__name__, + max_concurrent_envs=self._max_concurrent_envs, + ) + + def get_capacity_status(self) -> ServerCapacityStatus: + """ + Get the current capacity status of the server. + + Returns: + ServerCapacityStatus with current session counts and availability. 
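A sketch of the concurrency bookkeeping just described; the echo environment names are placeholders for any Environment subclass with matching Action/Observation models.

```python
from openenv.core.env_server import HTTPEnvServer

from envs.echo_env.models import EchoAction, EchoObservation  # illustrative
from envs.echo_env.server import EchoEnvironment  # illustrative

server = HTTPEnvServer(
    env=EchoEnvironment,  # the factory is the class itself, not an instance
    action_cls=EchoAction,
    observation_cls=EchoObservation,
    max_concurrent_envs=1,  # >1 requires SUPPORTS_CONCURRENT_SESSIONS = True
)

# No WebSocket sessions yet, so the server reports spare capacity.
print(server.get_capacity_status())
print(server.active_sessions, "/", server.max_concurrent_envs)
```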
+ """ + return ServerCapacityStatus.from_counts( + active=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) + + async def _run_sync_in_thread_pool( + self, func: Callable[..., Observation], *args, **kwargs + ) -> Observation: + """Run a synchronous function in the thread pool executor.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self._executor, lambda: func(*args, **kwargs)) + + def _get_valid_kwargs( + self, + sig: inspect.Signature, + kwargs: Dict[str, Any], + skip_params: Optional[set[str]] = None, + ) -> Dict[str, Any]: + """Filter kwargs to only include parameters accepted by the function signature.""" + if skip_params is None: + skip_params = set() + + valid_kwargs = {} + + has_kwargs = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + + for k, v in kwargs.items(): + if k in sig.parameters or has_kwargs: + if k not in skip_params: + valid_kwargs[k] = v + + return valid_kwargs + + async def _create_session(self) -> tuple[str, Environment]: + """ + Create a new WebSocket session with its own environment instance. + + Returns: + Tuple of (session_id, environment) + + Raises: + SessionCapacityError: If max concurrent sessions reached + EnvironmentFactoryError: If the factory fails to create an environment + """ + async with self._session_lock: + if len(self._sessions) >= self._max_concurrent_envs: + raise SessionCapacityError( + active_sessions=len(self._sessions), + max_sessions=self._max_concurrent_envs, + ) + + session_id = str(uuid.uuid4()) + current_time = time.time() + + # Create executor and reserve slot so capacity is not exceeded while + # we create the env outside the lock (avoids blocking other sessions) + executor = ThreadPoolExecutor(max_workers=1) + self._session_executors[session_id] = executor + self._sessions[session_id] = None # placeholder until env is ready + + try: + # Create environment in the executor thread (outside lock) + loop = asyncio.get_event_loop() + env = await loop.run_in_executor(executor, self._env_factory) + except Exception as e: + async with self._session_lock: + executor.shutdown(wait=False) + self._session_executors.pop(session_id, None) + self._sessions.pop(session_id, None) + factory_name = getattr( + self._env_factory, "__name__", str(self._env_factory) + ) + raise EnvironmentFactoryError(factory_name) from e + + async with self._session_lock: + self._sessions[session_id] = env + self._session_info[session_id] = SessionInfo( + session_id=session_id, + created_at=current_time, + last_activity_at=current_time, + step_count=0, + environment_type=type(env).__name__, + ) + + return session_id, env + + async def _destroy_session(self, session_id: str) -> None: + """ + Destroy a WebSocket session and cleanup resources. 
+ + Args: + session_id: The session ID to destroy + """ + async with self._session_lock: + env = self._sessions.pop(session_id, None) + executor = self._session_executors.pop(session_id, None) + self._session_info.pop(session_id, None) + + # Run close() in the same executor where the env was created + # This is required for thread-sensitive libraries like Playwright/greenlet + if env is not None: + if executor is not None: + try: + loop = asyncio.get_event_loop() + await loop.run_in_executor(executor, env.close) + except Exception: + # If executor close fails, try direct close as fallback + try: + env.close() + except Exception: + pass # Best effort cleanup + else: + try: + env.close() + except Exception: + pass # Best effort cleanup + + # Shutdown executor after close is done + if executor is not None: + executor.shutdown(wait=False) + + def _update_session_activity( + self, session_id: str, increment_step: bool = False + ) -> None: + """ + Update session activity timestamp and optionally increment step count. + + Args: + session_id: The session ID to update + increment_step: If True, increment the step count + """ + if session_id in self._session_info: + self._session_info[session_id].last_activity_at = time.time() + if increment_step: + self._session_info[session_id].step_count += 1 + + def get_session_info(self, session_id: str) -> Optional[SessionInfo]: + """ + Get information about a specific session. + + Args: + session_id: The session ID to query + + Returns: + SessionInfo if the session exists, None otherwise + """ + return self._session_info.get(session_id) + + async def _run_in_session_executor( + self, session_id: str, func: Callable[..., Observation], *args, **kwargs + ) -> Observation: + """Run a synchronous function in the session's thread pool executor.""" + executor = self._session_executors.get(session_id, self._executor) + loop = asyncio.get_event_loop() + return await loop.run_in_executor(executor, lambda: func(*args, **kwargs)) + + @property + def active_sessions(self) -> int: + """Return the number of active WebSocket sessions.""" + return len(self._sessions) + + @property + def max_concurrent_envs(self) -> int: + """Return the maximum number of concurrent environments.""" + return self._max_concurrent_envs + + @property + def is_concurrency_safe(self) -> bool: + """Return whether the environment is marked as concurrency safe.""" + import inspect + + if inspect.isclass(self._env_factory): + return getattr(self._env_factory, "SUPPORTS_CONCURRENT_SESSIONS", False) + else: + _temp_env = self._env_factory() + result = getattr(_temp_env, "SUPPORTS_CONCURRENT_SESSIONS", False) + _temp_env.close() + del _temp_env + return result + + @property + def concurrency_config(self) -> ConcurrencyConfig: + """Return the concurrency configuration.""" + return self._concurrency_config + + def register_routes( + self, app: FastAPI, mode: ServerMode | str = ServerMode.SIMULATION + ) -> None: + """ + Register HTTP routes on a FastAPI application. + + Args: + app: FastAPI application instance + mode: Server mode - either SIMULATION or PRODUCTION (or string equivalents). + In production mode, simulation control endpoints (/reset, /step, /state) + are NOT registered. Only safe endpoints (/health, /schema, /metadata, /ws) + are available. Defaults to SIMULATION for backwards compatibility. + + Raises: + ValueError: If mode is not a valid ServerMode or string equivalent. 
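To make the simulation/production split concrete, a sketch of registering routes in each mode (environment classes are again placeholders):

```python
from fastapi import FastAPI

from openenv.core.env_server import HTTPEnvServer, ServerMode
from envs.echo_env.models import EchoAction, EchoObservation  # illustrative
from envs.echo_env.server import EchoEnvironment  # illustrative

server = HTTPEnvServer(
    env=EchoEnvironment,
    action_cls=EchoAction,
    observation_cls=EchoObservation,
)

sim_app = FastAPI()
# Simulation mode registers /reset, /step and /state alongside
# /health, /schema, /metadata and the WebSocket endpoints.
server.register_routes(sim_app, mode=ServerMode.SIMULATION)

prod_app = FastAPI()
# Strings are accepted too; production mode omits the simulation
# control endpoints entirely.
server.register_routes(prod_app, mode="production")
```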
+ """ + # Convert string to ServerMode enum for backwards compatibility + if isinstance(mode, str): + try: + mode = ServerMode(mode.lower()) + except ValueError: + valid_modes = [m.value for m in ServerMode] + raise ValueError( + f"Invalid mode: '{mode}'. Must be one of: {valid_modes}" + ) + + # Helper function to handle reset endpoint + async def reset_handler( + request: ResetRequest = Body(default_factory=ResetRequest), + ) -> ResetResponse: + """Reset endpoint - returns initial observation.""" + _env = self._env_factory() + + try: + kwargs = request.model_dump(exclude_unset=True) + + is_async = _env.reset_async.__func__ is not Environment.reset_async + + if is_async: + sig = inspect.signature(_env.reset_async) + else: + sig = inspect.signature(_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, kwargs) + + if is_async: + observation = await _env.reset_async(**valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool( + _env.reset, **valid_kwargs + ) + return ResetResponse(**serialize_observation(observation)) + finally: + _env.close() + + # Helper function to handle step endpoint + async def step_handler(request: StepRequest) -> StepResponse: + """Step endpoint - executes action and returns observation.""" + action_data = request.action + + try: + action = deserialize_action(action_data, self.action_cls) + except ValidationError as e: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail=e.errors() + ) + + _env = self._env_factory() + + try: + kwargs = request.model_dump(exclude_unset=True, exclude={"action"}) + + is_async = _env.step_async.__func__ is not Environment.step_async + + if is_async: + sig = inspect.signature(_env.step_async) + else: + sig = inspect.signature(_env.step) + valid_kwargs = self._get_valid_kwargs( + sig, kwargs, skip_params={"action"} + ) + + if is_async: + observation = await _env.step_async(action, **valid_kwargs) + else: + observation = await self._run_sync_in_thread_pool( + _env.step, action, **valid_kwargs + ) + + return StepResponse(**serialize_observation(observation)) + finally: + _env.close() + + # Helper function to handle MCP endpoint + async def mcp_handler( + request: JsonRpcRequest, session_env: Optional[Environment] = None + ) -> JsonRpcResponse: + """ + Handle MCP JSON-RPC requests. + + Supports tools/list and tools/call methods in JSON-RPC 2.0 format. 
+ """ + method = request.method + request_id = request.id + + # Use provided session environment or create temporary one + if session_env is not None: + _env = session_env + should_close = False + else: + _env = self._env_factory() + should_close = True + try: + if method == McpMethod.TOOLS_LIST: + # Check if environment is MCP-enabled + if not hasattr(_env, "mcp_client"): + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "Environment does not support MCP", + request_id=request_id, + ) + + # Use async context manager for MCP client + async with _env.mcp_client: + tools = await _env.mcp_client.list_tools() + + return JsonRpcResponse.success( + result={ + "tools": [ + t.model_dump() if hasattr(t, "model_dump") else dict(t) + for t in tools + ] + }, + request_id=request_id, + ) + + elif method == McpMethod.TOOLS_CALL: + params = request.params + tool_name = params.get("name") + arguments = params.get("arguments", {}) + + if not hasattr(_env, "mcp_client"): + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "Environment does not support MCP", + request_id=request_id, + ) + + if not tool_name: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + "Missing 'name' in params", + request_id=request_id, + ) + + # Use async context manager for MCP client + async with _env.mcp_client: + result = await _env.mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + + # Ensure result is JSON serializable + serializable_result = _make_json_serializable(result) + + return JsonRpcResponse.success( + result=serializable_result, + request_id=request_id, + ) + + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.METHOD_NOT_FOUND, + f"Method not found: {method}", + request_id=request_id, + ) + + except Exception as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + request_id=request_id, + ) + finally: + if should_close: + _env.close() + + # Register MCP WebSocket endpoint (available in both production and simulation modes) + @app.websocket("/mcp") + async def mcp_websocket_endpoint(websocket: WebSocket): + """ + WebSocket endpoint for MCP JSON-RPC requests. + + Each WebSocket connection gets its own environment instance for MCP operations. 
+ + Message Protocol: + - Client sends: JSON-RPC 2.0 request (tools/list, tools/call) + - Server responds: JSON-RPC 2.0 response (result or error) + """ + await websocket.accept() + + session_id = None + session_env = None + + try: + # Create session with dedicated environment + session_id, session_env = await self._create_session() + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + jsonrpc_dict = json.loads(raw_message) + jsonrpc_request = JsonRpcRequest(**jsonrpc_dict) + except json.JSONDecodeError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.PARSE_ERROR, + f"Parse error: {e}", + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + except ValidationError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Invalid request: {e}", + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + try: + # Call mcp_handler with session environment + response = await mcp_handler( + jsonrpc_request, session_env=session_env + ) + await websocket.send_text(response.model_dump_json()) + except Exception as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + request_id=jsonrpc_request.id, + ) + await websocket.send_text(error_resp.model_dump_json()) + + except WebSocketDisconnect: + pass + except SessionCapacityError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + data={ + "active_sessions": e.active_sessions, + "max_sessions": e.max_sessions, + }, + ) + await websocket.send_text(error_resp.model_dump_json()) + except EnvironmentFactoryError as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + data={"factory_name": e.factory_name}, + ) + await websocket.send_text(error_resp.model_dump_json()) + except Exception as e: + error_resp = JsonRpcResponse.error_response( + JsonRpcErrorCode.SERVER_ERROR, + str(e), + ) + await websocket.send_text(error_resp.model_dump_json()) + finally: + if session_id: + await self._destroy_session(session_id) + try: + await websocket.close() + except RuntimeError: + pass + + # Register simulation control routes only in simulation mode + if mode == ServerMode.SIMULATION: + + @app.post( + "/reset", + response_model=ResetResponse, + tags=["Environment Control"], + summary="Reset the environment", + description=""" +Reset the environment to its initial state and return the first observation. + +You can optionally provide a seed for reproducibility and an episode_id for tracking. + """, + responses={ + 200: { + "description": "Environment reset successfully", + "content": { + "application/json": { + "example": { + "observation": {"status": "ready", "data": {}}, + "reward": None, + "done": False, + } + } + }, + } + }, + ) + async def reset( + request: ResetRequest = Body(default_factory=ResetRequest), + ) -> ResetResponse: + return await reset_handler(request) + + @app.post( + "/step", + response_model=StepResponse, + tags=["Environment Control"], + summary="Execute an action in the environment", + description=""" +Execute an action in the environment and receive the resulting observation. + +The action must conform to the environment's action schema, which can be +retrieved from the `/schema` endpoint. If the action is invalid, +the endpoint will return HTTP 422 with detailed validation errors. 
+ +The response includes: +- **observation**: The environment's response to the action +- **reward**: Optional reward signal (float or None) +- **done**: Boolean indicating if the episode has terminated + """, + responses={ + 200: { + "description": "Action executed successfully", + "content": { + "application/json": { + "example": { + "observation": {"status": "success", "data": {}}, + "reward": 1.0, + "done": False, + } + } + }, + }, + 422: { + "description": "Validation error - invalid action format or values", + "content": { + "application/json": { + "example": { + "detail": [ + { + "type": "string_too_short", + "loc": ["body", "action", "message"], + "msg": "String should have at least 1 character", + "input": "", + } + ] + } + } + }, + }, + 500: { + "description": "Internal server error during action execution" + }, + }, + ) + async def step(request: StepRequest) -> StepResponse: + return await step_handler(request) + + def get_state_handler() -> State: + _env = self._env_factory() + try: + return _env.state + finally: + _env.close() + + def get_metadata_handler() -> EnvironmentMetadata: + _env = self._env_factory() + try: + return _env.get_metadata() + finally: + _env.close() + + # Build list of GET endpoints based on mode + get_endpoints = [ + GetEndpointConfig( + path="/metadata", + handler=get_metadata_handler, + response_model=EnvironmentMetadata, + tag="Environment Info", + summary="Get environment metadata", + description=""" +Get metadata about this environment. + +Returns information about the environment including name, description, +version, author, and documentation links. + """, + ), + GetEndpointConfig( + path="/health", + handler=lambda: HealthResponse(status=HealthStatus.HEALTHY), + response_model=HealthResponse, + tag="Health", + summary="Health check", + description="Check if the environment server is running and healthy.", + ), + ] + + # Only register /state endpoint in simulation mode + if mode == ServerMode.SIMULATION: + get_endpoints.insert( + 0, + GetEndpointConfig( + path="/state", + handler=get_state_handler, + response_model=State, + tag="State Management", + summary="Get current environment state", + description=""" +Retrieve the current internal state of the environment. + +The structure of the state object is defined by the environment's State model. + """, + ), + ) + + register_get_endpoints(app, get_endpoints) + + # Register combined schema endpoint + @app.get( + "/schema", + response_model=SchemaResponse, + tags=["Schema"], + summary="Get all JSON schemas", + description=""" +Get JSON schemas for actions, observations, and state in a single response. + +Returns a combined schema object containing: +- **action**: JSON schema for actions accepted by this environment +- **observation**: JSON schema for observations returned by this environment +- **state**: JSON schema for environment state objects + +This is more efficient than calling individual schema endpoints and provides +all schema information needed to interact with the environment. 
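In simulation mode these endpoints can be exercised with plain HTTP, in the same style as the provider tests earlier in this diff; the `message` field shown assumes an echo-style Action model.

```python
import requests

base_url = "http://localhost:8000"

# Fetch the action, observation and state schemas in one call.
schemas = requests.get(f"{base_url}/schema").json()
print(schemas["action"])

# Reset, then step with an action matching the action schema.
requests.post(f"{base_url}/reset", json={})
response = requests.post(
    f"{base_url}/step",
    json={"action": {"message": "hello"}},
    headers={"Content-Type": "application/json"},
)
print(response.status_code)  # 422 here would carry detailed validation errors
print(response.json())  # {"observation": ..., "reward": ..., "done": ...}
```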
+ """, + responses={ + 200: { + "description": "Combined schemas retrieved successfully", + "content": { + "application/json": { + "example": { + "action": { + "type": "object", + "properties": {"message": {"type": "string"}}, + }, + "observation": { + "type": "object", + "properties": {"response": {"type": "string"}}, + }, + "state": { + "type": "object", + "properties": {"step_count": {"type": "integer"}}, + }, + } + } + }, + } + }, + ) + async def get_schemas() -> SchemaResponse: + """Return all schemas in one response.""" + return SchemaResponse( + action=self.action_cls.model_json_schema(), + observation=self.observation_cls.model_json_schema(), + state=State.model_json_schema(), + ) + + # Register MCP endpoint for production mode (direct MCP access) + @app.post("/mcp") + async def mcp_endpoint(request_raw: Request) -> Dict[str, Any]: + """ + MCP JSON-RPC endpoint for production mode. + + Bypasses step() overhead and provides direct access to MCP tools. + Supports tools/list and tools/call methods. + """ + # Parse JSON manually to handle parse errors gracefully + try: + body = await request_raw.body() + request_dict = json.loads(body) + request = JsonRpcRequest(**request_dict) + except json.JSONDecodeError: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.PARSE_ERROR + ).model_dump() + except ValidationError as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_REQUEST, + f"Invalid request: {e}", + ).model_dump() + except Exception: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.PARSE_ERROR + ).model_dump() + + method = request.method + params = request.params + request_id = request.id + + # Create a temporary environment for MCP access + _env = self._env_factory() + + try: + # Check if environment supports MCP + if not hasattr(_env, "mcp_client") and not hasattr(_env, "mcp_server"): + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "Environment does not support MCP", + request_id=request_id, + ).model_dump() + + if method == McpMethod.TOOLS_LIST: + # List tools from MCP server + if hasattr(_env, "mcp_client") and _env.mcp_client: + async with _env.mcp_client: + tools = await _env.mcp_client.list_tools() + return JsonRpcResponse.success( + result={ + "tools": [ + t.model_dump() + if hasattr(t, "model_dump") + else dict(t) + for t in tools + ] + }, + request_id=request_id, + ).model_dump() + elif hasattr(_env, "mcp_server") and _env.mcp_server: + # Use server directly + tools = [] + for tool_name, tool in get_server_tools( + _env.mcp_server + ).items(): + tool_dict = { + "name": tool.name, + "description": tool.description or "", + "inputSchema": tool.parameters or {}, + } + tools.append(tool_dict) + return JsonRpcResponse.success( + result={"tools": tools}, + request_id=request_id, + ).model_dump() + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "MCP server not available", + request_id=request_id, + ).model_dump() + + elif method == McpMethod.TOOLS_CALL: + tool_name = params.get("name") + arguments = params.get("arguments", {}) + + if not tool_name: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + "Invalid params - 'name' is required", + request_id=request_id, + ).model_dump() + + # Call tool via MCP + if hasattr(_env, "mcp_client") and _env.mcp_client: + async with _env.mcp_client: + result = await _env.mcp_client.call_tool( + name=tool_name, arguments=arguments + ) + elif hasattr(_env, "mcp_server") and _env.mcp_server: + # Call tool directly on 
FastMCP server + server_tools = get_server_tools(_env.mcp_server) + if tool_name in server_tools: + tool = server_tools[tool_name] + result = tool.fn(**arguments) + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INVALID_PARAMS, + f"Tool not found: {tool_name}", + request_id=request_id, + ).model_dump() + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + "MCP server not available", + request_id=request_id, + ).model_dump() + + # Make result JSON serializable + serializable_result = _make_json_serializable(result) + + return JsonRpcResponse.success( + result=serializable_result, + request_id=request_id, + ).model_dump() + + else: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.METHOD_NOT_FOUND, + f"Method not found: {method}", + request_id=request_id, + ).model_dump() + + except Exception as e: + return JsonRpcResponse.error_response( + JsonRpcErrorCode.INTERNAL_ERROR, + str(e), + request_id=request_id, + ).model_dump() + finally: + _env.close() + + # Register WebSocket endpoint for persistent sessions + @app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket): + """ + WebSocket endpoint for persistent environment sessions. + + Each WebSocket connection gets its own environment instance. + + Message Protocol: + - Client sends: WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage + - Server responds: WSObservationResponse | WSStateResponse | WSErrorResponse + """ + await websocket.accept() + + session_id = None + session_env = None + + try: + # Create session with dedicated environment + session_id, session_env = await self._create_session() + + while True: + # Receive message from client + raw_message = await websocket.receive_text() + + try: + message_dict = json.loads(raw_message) + except json.JSONDecodeError as e: + error_resp = WSErrorResponse( + data={ + "message": f"Invalid JSON: {e}", + "code": WSErrorCode.INVALID_JSON, + } + ) + await websocket.send_text(error_resp.model_dump_json()) + continue + + msg_type = message_dict.get("type", "") + + try: + match msg_type: + case "reset": + msg = WSResetMessage(**message_dict) + + is_async = ( + session_env.reset_async.__func__ + is not Environment.reset_async + ) + + if is_async: + sig = inspect.signature(session_env.reset_async) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await session_env.reset_async( + **valid_kwargs + ) + else: + sig = inspect.signature(session_env.reset) + valid_kwargs = self._get_valid_kwargs(sig, msg.data) + observation = await self._run_in_session_executor( + session_id, session_env.reset, **valid_kwargs + ) + + self._update_session_activity(session_id) + + response = WSObservationResponse( + data=serialize_observation(observation), + ) + + case "step": + msg = WSStepMessage(**message_dict) + action = deserialize_action(msg.data, self.action_cls) + + is_async = ( + session_env.step_async.__func__ + is not Environment.step_async + ) + + if is_async: + observation = await session_env.step_async(action) + else: + observation = await self._run_in_session_executor( + session_id, session_env.step, action + ) + + self._update_session_activity( + session_id, increment_step=True + ) + + response = WSObservationResponse( + data=serialize_observation(observation) + ) + + case "state": + msg = WSStateMessage(**message_dict) + state = session_env.state + if hasattr(state, "model_dump"): + state_data = state.model_dump() + else: + state_data = dict(state) if state else {} + + response = 
WSStateResponse(data=state_data)
+
+                        case "close":
+                            msg = WSCloseMessage(**message_dict)
+                            break
+
+                        case "mcp":
+                            msg = WSMCPMessage(**message_dict)
+                            try:
+                                rpc_request = JsonRpcRequest(**msg.data)
+                            except Exception as e:
+                                rpc_response = JsonRpcResponse.error_response(
+                                    JsonRpcErrorCode.INVALID_REQUEST,
+                                    f"Invalid request: {e}",
+                                )
+                            else:
+                                rpc_response = await mcp_handler(
+                                    rpc_request,
+                                    session_env=session_env,
+                                )
+                            response = WSMCPResponse(data=rpc_response.model_dump())
+
+                        case _:
+                            response = WSErrorResponse(
+                                data={
+                                    "message": f"Unknown message type: {msg_type}",
+                                    "code": WSErrorCode.UNKNOWN_TYPE,
+                                }
+                            )
+
+                    await websocket.send_text(response.model_dump_json())
+
+                except ValidationError as e:
+                    error_resp = WSErrorResponse(
+                        data={
+                            "message": "Invalid message",
+                            "code": WSErrorCode.VALIDATION_ERROR,
+                            "errors": e.errors(),
+                        }
+                    )
+                    await websocket.send_text(error_resp.model_dump_json())
+                except Exception as e:
+                    error_resp = WSErrorResponse(
+                        data={
+                            "message": str(e),
+                            "code": WSErrorCode.EXECUTION_ERROR,
+                        }
+                    )
+                    await websocket.send_text(error_resp.model_dump_json())
+
+        except WebSocketDisconnect:
+            pass
+        except SessionCapacityError as e:
+            error_resp = WSErrorResponse(
+                data={
+                    "message": str(e),
+                    "code": WSErrorCode.CAPACITY_REACHED,
+                    "active_sessions": e.active_sessions,
+                    "max_sessions": e.max_sessions,
+                }
+            )
+            await websocket.send_text(error_resp.model_dump_json())
+        except EnvironmentFactoryError as e:
+            error_resp = WSErrorResponse(
+                data={
+                    "message": str(e),
+                    "code": WSErrorCode.FACTORY_ERROR,
+                    "factory_name": e.factory_name,
+                }
+            )
+            await websocket.send_text(error_resp.model_dump_json())
+        except Exception as e:
+            error_resp = WSErrorResponse(
+                data={"message": str(e), "code": WSErrorCode.SESSION_ERROR}
+            )
+            await websocket.send_text(error_resp.model_dump_json())
+        finally:
+            if session_id:
+                await self._destroy_session(session_id)
+            try:
+                await websocket.close()
+            except RuntimeError:
+                pass
+
+
+def create_app(
+    env: Callable[[], Environment],
+    action_cls: Type[Action],
+    observation_cls: Type[Observation],
+    env_name: Optional[str] = None,
+    max_concurrent_envs: Optional[int] = None,
+    concurrency_config: Optional[ConcurrencyConfig] = None,
+    gradio_builder: Optional[Callable[..., Any]] = None,
+) -> FastAPI:
+    """
+    Create a FastAPI application with or without web interface.
+
+    This function creates a FastAPI app and enables the Gradio web interface
+    (including README integration) when the ENABLE_WEB_INTERFACE environment
+    variable is set to a truthy value; otherwise it returns the plain API-only app.
+
+    Args:
+        env: Environment factory (callable) that creates new instances
+        action_cls: The Action subclass this environment expects
+        observation_cls: The Observation subclass this environment returns
+        env_name: Optional environment name for README loading
+        max_concurrent_envs: Maximum concurrent WebSocket sessions.
+            Mutually exclusive with concurrency_config.
+        concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings.
+            Mutually exclusive with max_concurrent_envs.
+        gradio_builder: Optional callable to build a custom Gradio UI at /web.
+            Signature: (web_manager, action_fields, metadata, is_chat_env, title,
+            quick_start_md) -> gr.Blocks. When None, the default Gradio app is used.
+            See docs/customizing-web-ui.md.
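+
+    Example (a sketch; `MyEnv`, `MyAction`, and `MyObservation` stand in for
+    an environment's own classes):
+
+        app = create_app(
+            env=lambda: MyEnv(),
+            action_cls=MyAction,
+            observation_cls=MyObservation,
+            env_name="my_env",
+        )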
+ + Returns: + FastAPI application instance with or without web interface and README integration + """ + # Check if web interface should be enabled + # This can be controlled via environment variable or build argument + enable_web = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in ( + "true", + "1", + "yes", + ) + + if enable_web: + # Gradio-based web UI (gradio is a core dependency) + from .web_interface import create_web_interface_app + + return create_web_interface_app( + env, + action_cls, + observation_cls, + env_name, + max_concurrent_envs, + concurrency_config, + gradio_builder=gradio_builder, + ) + else: + # Use standard FastAPI app without web interface + return create_fastapi_app( + env, action_cls, observation_cls, max_concurrent_envs, concurrency_config + ) + + +def create_fastapi_app( + env: Callable[[], Environment], + action_cls: Type[Action], + observation_cls: Type[Observation], + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[ConcurrencyConfig] = None, +) -> FastAPI: + """ + Create a FastAPI application with comprehensive documentation. + + Args: + env: Environment factory (callable) that creates new instances + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + max_concurrent_envs: Maximum concurrent WebSocket sessions. + Mutually exclusive with concurrency_config. + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings. + Mutually exclusive with max_concurrent_envs. + + Returns: + FastAPI application instance + """ + try: + from fastapi import FastAPI + except ImportError: + raise ImportError( + "FastAPI is required. Install with: pip install fastapi uvicorn" + ) + + app = FastAPI( + title="OpenEnv Environment HTTP API", + version="1.0.0", + description=""" +# OpenEnv Environment HTTP API + +HTTP API for interacting with OpenEnv environments through a standardized interface. + +## Features + +* **Environment Reset**: Initialize or restart episodes +* **Action Execution**: Send actions and receive observations +* **State Inspection**: Query current environment state +* **Schema Access**: Retrieve JSON schemas for actions and observations + +## Workflow + +1. Call `/reset` to start a new episode and get initial observation +2. Call `/step` repeatedly with actions to interact with environment +3. Episode ends when observation returns `done: true` +4. 
Call `/state` anytime to inspect current environment state + +## Documentation + +* **Swagger UI**: Available at `/docs` +* **ReDoc**: Available at `/redoc` +* **OpenAPI Schema**: Available at `/openapi.json` + """, + openapi_tags=[ + { + "name": "Environment Control", + "description": "Core operations for environment interaction (reset, step)", + }, + { + "name": "State Management", + "description": "Operations for inspecting environment state", + }, + { + "name": "Environment Info", + "description": "Information about the environment", + }, + { + "name": "Schema", + "description": "JSON Schema endpoints for actions, observations, and state", + }, + {"name": "Health", "description": "Service health and status checks"}, + ], + docs_url="/docs", + redoc_url="/redoc", + openapi_url="/openapi.json", + contact={ + "name": "OpenEnv Team", + "url": "https://github.com/meta-pytorch/OpenEnv", + }, + license_info={ + "name": "BSD-3-Clause", + "url": "https://github.com/meta-pytorch/OpenEnv/blob/main/LICENSE", + }, + ) + + server = HTTPEnvServer( + env, + action_cls, + observation_cls, + max_concurrent_envs, + concurrency_config=concurrency_config, + ) + server.register_routes(app) + return app diff --git a/src/openenv/core/env_server/interfaces.py b/src/openenv/core/env_server/interfaces.py new file mode 100644 index 0000000000000000000000000000000000000000..9fa837549aa1e2bf1c439f1d7a52e845a556ae18 --- /dev/null +++ b/src/openenv/core/env_server/interfaces.py @@ -0,0 +1,297 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +from abc import ABC, abstractmethod +from typing import Any, Generic, Optional, Protocol, TYPE_CHECKING, TypeVar + +from typing_extensions import TypedDict + +from .types import Action, EnvironmentMetadata, Observation, State + +if TYPE_CHECKING: + from openenv.core.rubrics import Rubric + +ActT = TypeVar("ActT", bound=Action) +ObsT = TypeVar("ObsT", bound=Observation) +StateT = TypeVar("StateT", bound=State) + + +class Message(TypedDict): + """A message in a conversation. + + Compatible with Huggingface chat template format. + """ + + role: str + content: str + + +class ModelTokenizer(Protocol): + """Protocol for tokenizers that support chat templates. + + This protocol defines the interface that tokenizers must implement + to work with chat-based environments. It's compatible with + Huggingface transformers tokenizers. + """ + + def apply_chat_template( + self, + conversation: list[Message], + tokenize: bool = True, + return_tensors: str | None = None, + **kwargs: Any, + ) -> Any: + """Apply a chat template to format and optionally tokenize a conversation. + + Args: + conversation: List of message dictionaries with 'role' and 'content' + tokenize: Whether to tokenize the output + return_tensors: Format for returned tensors ('pt' for PyTorch) + **kwargs: Additional arguments + + Returns: + Formatted and optionally tokenized conversation + """ + ... + + def decode( + self, token_ids: Any, skip_special_tokens: bool = False, **kwargs: Any + ) -> str: + """Decode token IDs back to text. + + Args: + token_ids: Token IDs to decode + skip_special_tokens: Whether to skip special tokens in output + **kwargs: Additional arguments + + Returns: + Decoded text string + """ + ... + + +class Transform(ABC, Generic[ObsT]): + """Transform observations to add rewards, metrics, or other modifications. 
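+
+    A minimal reward-assigning sketch (the `answer` and `expected` fields are
+    illustrative assumptions on a subclassed observation, not part of the base
+    Observation):
+
+        class ExactMatchReward(Transform):
+            def __call__(self, observation):
+                # Assign a binary reward based on observation contents
+                observation.reward = float(observation.answer == observation.expected)
+                return observation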
+ + Transforms follow the TorchRL pattern where they take an observation + and return a (potentially modified) observation. This allows for + flexible reward computation and observation augmentation. + """ + + @abstractmethod + def __call__(self, observation: ObsT) -> ObsT: + """Transform an observation. + + Args: + observation: The input observation + + Returns: + The transformed observation + """ + pass + + +class Environment(ABC, Generic[ActT, ObsT, StateT]): + """Base class for all environment servers following Gym/Gymnasium API. + + Args: + transform: Optional transform to apply to observations + rubric: Optional rubric for reward computation. When provided, the + rubric's output can be used to set the observation's reward in step(). + + Class Attributes: + SUPPORTS_CONCURRENT_SESSIONS: Whether this environment supports concurrent sessions. + When True, multiple WebSocket connections can each have their own + environment instance (up to max_concurrent_envs). When False (default), + the environment should only be used with a single session at a time. + + Set this to True in your Environment subclass if: + - The environment uses proper session isolation (e.g., unique working dirs) + - No shared mutable state exists between instances + - External resources (databases, APIs) can handle concurrent access + + Attributes: + rubric: Optional rubric for computing rewards. Environments can set this + in __init__ and use it in step() to compute observation rewards. + Training infrastructure can access it for introspection: + for name, r in env.rubric.named_rubrics(): + print(f"{name}: {r.last_score}") + + See RFC 004 for rubric design: rfcs/004-rubrics.md + """ + + # Class-level flag indicating whether this environment supports concurrent sessions + SUPPORTS_CONCURRENT_SESSIONS: bool = False + + # Optional rubric for reward computation + rubric: Optional["Rubric"] + + def __init__( + self, + transform: Optional[Transform[ObsT]] = None, + rubric: Optional["Rubric"] = None, + ): + self.transform = transform + self.rubric = rubric + + @abstractmethod + def reset( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> ObsT: + """Reset the environment and return initial observation.""" + pass + + async def reset_async( + self, + seed: Optional[int] = None, + episode_id: Optional[str] = None, + **kwargs: Any, + ) -> ObsT: + """Async version of reset. Default implementation calls sync reset. + + Override to provide true async implementation. + """ + return self.reset(seed=seed, episode_id=episode_id, **kwargs) + + @abstractmethod + def step( + self, + action: ActT, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> ObsT: + """Take a step in the environment.""" + pass + + async def step_async( + self, + action: ActT, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> ObsT: + """Async version of step. Default implementation calls sync step. + + Override to provide true async implementation. + """ + return self.step(action, timeout_s=timeout_s, **kwargs) + + @property + @abstractmethod + def state(self) -> StateT: + """Get the current environment state.""" + pass + + def get_metadata(self) -> EnvironmentMetadata: + """ + Get metadata about this environment. + + Override this method to provide custom metadata for the environment. + Default implementation returns basic metadata derived from class name. 
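+
+        A minimal override sketch (the values shown are illustrative):
+
+            def get_metadata(self) -> EnvironmentMetadata:
+                return EnvironmentMetadata(
+                    name="my_env",
+                    description="Example environment",
+                    version="0.1.0",
+                )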
+ + Returns: + EnvironmentMetadata with environment information + """ + return EnvironmentMetadata( + name=self.__class__.__name__, + description=f"{self.__class__.__name__} environment", + version="1.0.0", + ) + + def _apply_transform(self, observation: ObsT) -> ObsT: + """Apply transform if one is provided.""" + if self.transform is not None: + return self.transform(observation) + return observation + + def _apply_rubric(self, action: ActT, observation: ObsT) -> float: + """Apply rubric if one is provided. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value from the rubric, or 0.0 if no rubric is set. + + Usage in step(): + def step(self, action: MyAction, ...) -> MyObservation: + # ... execute action and create observation ... + observation.reward = self._apply_rubric(action, observation) + return observation + """ + if self.rubric is not None: + return self.rubric(action, observation) + return 0.0 + + async def _apply_rubric_async(self, action: ActT, observation: ObsT) -> float: + """Apply rubric asynchronously if one is provided. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value from the rubric, or 0.0 if no rubric is set. + + Usage in step_async(): + async def step_async(self, action: MyAction, ...) -> MyObservation: + # ... execute action and create observation ... + observation.reward = await self._apply_rubric_async(action, observation) + return observation + """ + if self.rubric is not None: + result = self.rubric(action, observation) + # If rubric returns a coroutine, await it + if inspect.iscoroutine(result): + return await result + return result + return 0.0 + + def _reset_rubric(self) -> None: + """Reset the rubric state if one is provided. + + Call this in reset() to clear any trajectory state in the rubric. + + Usage in reset(): + def reset(self, ...) -> MyObservation: + self._reset_rubric() + # ... create initial observation ... + return observation + """ + if self.rubric is not None: + self.rubric.reset() + + async def _reset_rubric_async(self) -> None: + """Reset the rubric state asynchronously if one is provided. + + Call this in reset_async() to clear any trajectory state in the rubric. + + Usage in reset_async(): + async def reset_async(self, ...) -> MyObservation: + await self._reset_rubric_async() + # ... create initial observation ... + return observation + """ + if self.rubric is not None: + # Check if rubric has async reset method + if hasattr(self.rubric, "reset_async"): + result = self.rubric.reset_async() + if inspect.iscoroutine(result): + await result + else: + self.rubric.reset() + + def close(self) -> None: + """Clean up resources used by the environment. + + Override this method to implement custom cleanup logic. + Called when the environment is being destroyed or reset. + """ + pass diff --git a/src/openenv/core/env_server/mcp_environment.py b/src/openenv/core/env_server/mcp_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..03f66e37897ec81796d468f3d0590d465deddea1 --- /dev/null +++ b/src/openenv/core/env_server/mcp_environment.py @@ -0,0 +1,624 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +MCP Environment base class for OpenEnv. 
+
+This module provides the MCPEnvironment base class that integrates FastMCP servers
+with OpenEnv's Gym-style Environment interface. It handles MCP tool discovery
+and invocation through the step() API, following RFC 003.
+
+Key features:
+- Automatic routing of ListToolsAction and CallToolAction to MCP server
+- Reserved tool name validation (reset, step, state, close are protected)
+- Timeout handling for tool calls
+- Proper error categorization (tool not found, execution errors, timeouts)
+- Mode-aware tool registration (production vs simulation)
+- Code mode support via get_callables() and execute_code()
+
+Usage:
+    from fastmcp import FastMCP
+    from openenv.core.env_server.mcp_environment import MCPEnvironment
+
+    class MyMCPEnv(MCPEnvironment):
+        def __init__(self):
+            mcp = FastMCP("my-server")
+
+            # Call super().__init__ before registering tools: self.tool()
+            # relies on attributes set up in MCPEnvironment.__init__
+            super().__init__(mcp)
+
+            # Register mode-specific tools
+            @self.tool(mode="production")
+            def my_tool(arg: str) -> str:
+                return f"Production: {arg}"
+
+            @self.tool(mode="simulation")
+            def my_tool(arg: str) -> str:
+                return f"Simulation: {arg}"
+
+        def reset(self, seed=None, episode_id=None, **kwargs):
+            # Reset logic here
+            ...
+
+        def _step_impl(self, action):
+            # Handle non-MCP actions
+            ...
+
+        @property
+        def state(self):
+            # Return current state
+            ...
+"""
+
+import asyncio
+import inspect
+from abc import abstractmethod
+from collections import defaultdict
+from typing import Any, Callable, Dict, Optional
+
+from fastmcp import Client
+from fastmcp.client.client import CallToolResult
+from mcp.types import TextContent
+
+from ..utils import run_async_safely
+from .interfaces import Environment
+from .mcp_types import (
+    CallToolAction,
+    CallToolObservation,
+    ListToolsAction,
+    ListToolsObservation,
+    RESERVED_TOOL_NAMES,
+    Tool,
+    ToolError,
+    ToolErrorType,
+)
+from .types import Action, Observation
+
+
+# Default timeout for MCP tool calls in seconds
+MCP_TOOL_CALL_TIMEOUT = 30.0
+
+# Valid modes for tool registration
+VALID_MODES = {"production", "simulation"}
+
+
+def get_server_tools(mcp_server: Any) -> Dict[str, Any]:
+    """
+    Get tools from a FastMCP server, compatible with both 2.x and 3.x.
+
+    Returns:
+        Dictionary mapping tool names to tool objects.
+    """
+    # FastMCP 2.x: get_tools() returns dict {name: Tool}
+    if hasattr(mcp_server, "get_tools"):
+        result = run_async_safely(mcp_server.get_tools())
+        if isinstance(result, dict):
+            return result
+    # FastMCP 3.x: list_tools() returns list of Tool objects
+    if hasattr(mcp_server, "list_tools"):
+        tools_list = run_async_safely(mcp_server.list_tools())
+        return {t.name: t for t in tools_list}
+    return {}
+
+
+class MCPEnvironment(Environment):
+    """
+    Base class for environments that expose tools via MCP (Model Context Protocol).
+
+    MCPEnvironment bridges FastMCP servers with OpenEnv's Gym-style API, allowing
+    agents to discover and invoke MCP tools through the standard step() interface.
+
+    The class automatically handles:
+    - ListToolsAction: Returns available tools from the MCP server
+    - CallToolAction: Invokes a specific tool with arguments
+
+    All other actions are delegated to the abstract _step_impl() method,
+    which subclasses must implement.
+
+    Args:
+        mcp_server: A FastMCP server instance containing tool definitions.
+            The server's tools will be validated against reserved names.
+        transform: Optional transform to apply to observations (inherited from Environment).
+
+    Raises:
+        ValueError: If any tool in the MCP server uses a reserved name
+            (reset, step, state, close).
+ + Example: + >>> from fastmcp import FastMCP + >>> mcp = FastMCP("calculator") + >>> @mcp.tool() + ... def add(a: int, b: int) -> int: + ... return a + b + >>> env = MyMCPEnvironment(mcp) + >>> obs = env.step(ListToolsAction()) + >>> obs.tools[0].name + 'add' + """ + + def __init__(self, mcp_server: Any, transform: Optional[Any] = None) -> None: + """ + Initialize the MCP environment. + + Args: + mcp_server: A FastMCP server instance with tool definitions. + transform: Optional transform to apply to observations. + + Raises: + ValueError: If any tool uses a reserved name (reset, step, state, close). + """ + super().__init__(transform=transform) + + # Validate tool names before storing + self._validate_tool_names(mcp_server) + + self.mcp_server = mcp_server + self.mcp_client = Client(mcp_server) + + # Track mode-specific tools: {tool_name: {mode: func}} + # mode can be "production", "simulation", or None (available in all modes) + self._mode_tools = defaultdict(dict) + + # Track tool schemas for list_tools: {tool_name: {mode: schema}} + self._mode_tool_schemas = defaultdict(dict) + + @property + def supports_code_mode(self) -> bool: + """Check if this environment supports code mode (execute_code).""" + return True + + def _get_server_tools(self, mcp_server: Any) -> Dict[str, Any]: + """ + Get tools from a FastMCP server, compatible with both 2.x and 3.x. + + Returns: + Dictionary mapping tool names to tool objects. + """ + return get_server_tools(mcp_server) + + def get_callables(self) -> Dict[str, Callable]: + """ + Get callable functions for code mode. + + Returns tool functions as direct Python callables, enabling code mode + where agents write Python code that calls tools directly (no JSON-RPC + overhead). Mode-specific tools are filtered by the current mode. + + Returns: + Dictionary mapping tool names to callables. + """ + callables: Dict[str, Callable] = {} + current_mode = getattr(self, "_mode", None) + + # Extract callables from FastMCP server using public API + for tool_name, tool in self._get_server_tools(self.mcp_server).items(): + if hasattr(tool, "fn") and callable(tool.fn): + callables[tool_name] = tool.fn + + # Add mode-specific tools available in current mode + for tool_name, mode_funcs in self._mode_tools.items(): + if None in mode_funcs: + # Tool available in all modes (already in FastMCP if registered there) + if tool_name not in callables: + callables[tool_name] = mode_funcs[None] + elif current_mode in mode_funcs: + # Tool available in current mode only + callables[tool_name] = mode_funcs[current_mode] + + return callables + + def execute_code(self, code: str) -> Observation: + """ + Execute Python code with tools available as callables. + + This enables the CodeAct pattern where agents write Python code + that calls tools directly as functions, avoiding JSON-RPC overhead. + + Args: + code: Python code to execute. Tools are available as functions + in the execution namespace. Set a variable named 'result' + to capture the return value. + + Returns: + Observation with result in metadata["result"] or error in + metadata["error"]. 
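+
+        Example (a sketch; assumes a tool named `add` was registered on the
+        FastMCP server):
+
+            obs = env.execute_code("result = add(a=1, b=2)")
+            obs.metadata["result"]  # 3 on success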
+ """ + namespace = self.get_callables() + + result_dict: Dict[str, Any] = {} + try: + exec(code, namespace, result_dict) + result = result_dict.get("result") + return Observation(done=False, reward=0.0, metadata={"result": result}) + except SyntaxError as e: + return Observation( + done=False, reward=0.0, metadata={"error": f"Syntax error: {str(e)}"} + ) + except Exception as e: + return Observation(done=False, reward=0.0, metadata={"error": str(e)}) + + def _validate_tool_names(self, mcp_server: Any) -> None: + """ + Validate that no tools use reserved names. + + Reserved names (reset, step, state, close) are protected to maintain + the dual API boundary between infrastructure and agent APIs. + + Args: + mcp_server: The FastMCP server to validate. + + Raises: + ValueError: If any tool uses a reserved name. + """ + tools_dict = self._get_server_tools(mcp_server) + if tools_dict: + tool_names = set(tools_dict.keys()) + conflicts = tool_names & RESERVED_TOOL_NAMES + if conflicts: + raise ValueError( + f"MCP tools cannot use reserved names: {sorted(conflicts)}. " + f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}" + ) + + def tool(self, mode: Optional[str] = None) -> Callable: + """ + Decorator for registering mode-aware tools. + + Args: + mode: Optional mode for the tool ("production" or "simulation"). + If None, tool is available in all modes. + + Returns: + A decorator function for registering tools. + + Raises: + ValueError: If mode is not None, "production", or "simulation". + """ + if mode is not None and mode not in VALID_MODES: + raise ValueError( + f"Invalid mode '{mode}'. Mode must be 'production', 'simulation', or None." + ) + + def decorator(func: Callable) -> Callable: + tool_name = func.__name__ + # Validate tool name is not reserved + if tool_name in RESERVED_TOOL_NAMES: + raise ValueError( + f"Tool name '{tool_name}' is reserved and cannot be used. " + f"Reserved names are: {sorted(RESERVED_TOOL_NAMES)}" + ) + + # If mode is None, register with FastMCP as usual + if mode is None: + decorated_func = self.mcp_server.tool()(func) + self._mode_tools[tool_name][None] = func + return decorated_func + + # For mode-specific tools, don't register with FastMCP + # Instead, track them ourselves + self._mode_tools[tool_name][mode] = func + + # Extract schema information from function signature + sig = inspect.signature(func) + schema = { + "type": "object", + "properties": {}, + "required": [], + } + + for param_name, param in sig.parameters.items(): + # Get type annotation + param_type = param.annotation + json_type = "string" # default + if param_type in (int, "int"): + json_type = "integer" + elif param_type in (float, "float"): + json_type = "number" + elif param_type in (bool, "bool"): + json_type = "boolean" + + schema["properties"][param_name] = {"type": json_type} + + # If no default value, it's required + if param.default == inspect.Parameter.empty: + schema["required"].append(param_name) + + # Store the schema for this mode-specific tool + self._mode_tool_schemas[tool_name][mode] = { + "name": tool_name, + "description": func.__doc__ or "", + "input_schema": schema, + } + + return func + + return decorator + + def step( + self, + action: Action, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> Observation: + """ + Execute an action in the environment. + + This method routes MCP-specific actions (ListToolsAction, CallToolAction) + to the appropriate handlers, while delegating all other actions to + the subclass's _step_impl() method. 
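+
+        Example (a sketch; assumes an `add` tool is registered):
+
+            tools_obs = env.step(ListToolsAction())
+            result_obs = env.step(
+                CallToolAction(tool_name="add", arguments={"a": 1, "b": 2})
+            )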
+ + Args: + action: The action to execute. Can be: + - ListToolsAction: Returns available MCP tools + - CallToolAction: Invokes a specific MCP tool + - Any other Action: Delegated to _step_impl() + timeout_s: Optional timeout in seconds for the action. + Defaults to MCP_TOOL_CALL_TIMEOUT (30s) for MCP actions. + **kwargs: Additional arguments passed to handlers. + + Returns: + Observation appropriate to the action type: + - ListToolsObservation for ListToolsAction + - CallToolObservation for CallToolAction + - Subclass-defined Observation for other actions + """ + if isinstance(action, ListToolsAction): + return self._handle_list_tools() + elif isinstance(action, CallToolAction): + return self._handle_call_tool(action, timeout_s=timeout_s) + else: + return self._step_impl(action, timeout_s=timeout_s, **kwargs) + + def _handle_list_tools(self) -> ListToolsObservation: + """ + Handle a ListToolsAction by querying the MCP server. + + Returns: + ListToolsObservation containing all available tools with their + names, descriptions, and input schemas, filtered by current mode. + """ + try: + # Get current mode + current_mode = getattr(self, "_mode", None) + + # Start with tools from FastMCP server (mode=None tools) + tools_result = run_async_safely(self._async_list_tools()) + + # Build list of Tool objects + tools = [] + + # Add FastMCP tools that are not mode-specific + for tool in tools_result: + if tool.name not in self._mode_tool_schemas: + tools.append( + Tool( + name=tool.name, + description=tool.description or "", + input_schema=tool.inputSchema + if hasattr(tool, "inputSchema") + else {}, + ) + ) + + # Add mode-specific tools available in current mode + for tool_name, mode_schemas in self._mode_tool_schemas.items(): + if None in mode_schemas: + # Tool available in all modes + schema = mode_schemas[None] + tools.append( + Tool( + name=schema["name"], + description=schema["description"], + input_schema=schema["input_schema"], + ) + ) + elif current_mode in mode_schemas: + # Tool available in current mode + schema = mode_schemas[current_mode] + tools.append( + Tool( + name=schema["name"], + description=schema["description"], + input_schema=schema["input_schema"], + ) + ) + + return ListToolsObservation(tools=tools) + + except Exception as e: + # Return an observation with error in metadata + return ListToolsObservation( + tools=[], + metadata={ + "error": str(e), + "error_type": "list_tools_failed", + }, + ) + + async def _async_list_tools(self) -> list: + """ + Async helper to list tools from the MCP client. + + Returns: + List of tool objects from the MCP server. + """ + async with self.mcp_client: + return await self.mcp_client.list_tools() + + def _handle_call_tool( + self, + action: CallToolAction, + timeout_s: Optional[float] = None, + ) -> CallToolObservation: + """ + Handle a CallToolAction by invoking the specified tool. + + Args: + action: The CallToolAction containing tool_name and arguments. + timeout_s: Timeout in seconds. Defaults to MCP_TOOL_CALL_TIMEOUT (30s). + + Returns: + CallToolObservation with the tool's result or an error. + """ + timeout = timeout_s if timeout_s is not None else MCP_TOOL_CALL_TIMEOUT + + # Check if this is a mode-specific tool + tool_name = action.tool_name + current_mode = getattr(self, "_mode", None) + + if tool_name in self._mode_tools: + mode_info = self._mode_tools[tool_name] + + # Check if tool is available in current mode + # Tool is available if: + # 1. It has a None mode (available in all modes), OR + # 2. 
It has an implementation for the current mode + if None in mode_info: + # Use the mode-agnostic version + func = mode_info[None] + elif current_mode in mode_info: + # Use the mode-specific version + func = mode_info[current_mode] + else: + # Tool not available in current mode + return CallToolObservation( + tool_name=tool_name, + result=None, + error=ToolError( + error_type=ToolErrorType.TOOL_NOT_FOUND, + message=f"Tool '{tool_name}' not available in {current_mode} mode", + ), + ) + + # Call the mode-specific function directly + try: + # Check if function is async and await if necessary + if inspect.iscoroutinefunction(func): + result = run_async_safely(func(**action.arguments)) + else: + result = func(**action.arguments) + + # Wrap result in CallToolResult format to match FastMCP behavior + return CallToolObservation( + tool_name=tool_name, + result=CallToolResult( + content=[TextContent(type="text", text=str(result))], + structured_content={"result": result}, + meta=None, + data=result, + is_error=False, + ), + ) + except Exception as e: + return CallToolObservation( + tool_name=tool_name, + result=None, + error=ToolError( + error_type=ToolErrorType.EXECUTION_ERROR, + message=str(e), + ), + ) + + # Not a mode-specific tool, use FastMCP + try: + # Run the async call_tool with timeout + # Use run_async_safely to handle both sync and async contexts + result = run_async_safely( + asyncio.wait_for( + self._async_call_tool(action.tool_name, action.arguments), + timeout=timeout, + ) + ) + + return CallToolObservation( + tool_name=action.tool_name, + result=result, + ) + + except asyncio.TimeoutError: + return CallToolObservation( + tool_name=action.tool_name, + result=None, + error=ToolError( + error_type=ToolErrorType.TIMEOUT, + message=f"Tool '{action.tool_name}' timed out after {timeout} seconds", + ), + ) + + except Exception as e: + error_message = str(e) + + # Determine error type based on the exception + if ( + "not found" in error_message.lower() + or "unknown tool" in error_message.lower() + ): + error_type = ToolErrorType.TOOL_NOT_FOUND + elif ( + "invalid" in error_message.lower() + or "argument" in error_message.lower() + ): + error_type = ToolErrorType.INVALID_ARGS + else: + error_type = ToolErrorType.EXECUTION_ERROR + + return CallToolObservation( + tool_name=action.tool_name, + result=None, + error=ToolError( + error_type=error_type, + message=error_message, + ), + ) + + async def _async_call_tool(self, tool_name: str, arguments: dict) -> Any: + """ + Async helper to call a tool on the MCP server. + + Args: + tool_name: Name of the tool to invoke. + arguments: Dictionary of arguments to pass to the tool. + + Returns: + The result from the tool execution. + """ + async with self.mcp_client: + return await self.mcp_client.call_tool(tool_name, arguments) + + @abstractmethod + def _step_impl( + self, + action: Action, + timeout_s: Optional[float] = None, + **kwargs: Any, + ) -> Observation: + """ + Handle non-MCP actions in the environment. + + Subclasses must implement this method to handle any actions that are + not ListToolsAction or CallToolAction. This is where environment-specific + action processing should occur. + + Args: + action: The action to execute (guaranteed not to be an MCP action). + timeout_s: Optional timeout in seconds. + **kwargs: Additional arguments. + + Returns: + An Observation appropriate for the action. + """ + pass + + def close(self) -> None: + """ + Clean up resources used by the environment. 
+ + This method cleans up the MCP client and any other resources. + Subclasses should call super().close() if they override this method. + """ + # The MCP client uses async context manager, so cleanup happens + # automatically when the context exits. We just clear references. + self.mcp_client = None + self.mcp_server = None diff --git a/src/openenv/core/env_server/mcp_types.py b/src/openenv/core/env_server/mcp_types.py new file mode 100644 index 0000000000000000000000000000000000000000..6aa5b7449e2fa60dea46efc6b0992a6359146b2b --- /dev/null +++ b/src/openenv/core/env_server/mcp_types.py @@ -0,0 +1,321 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +MCP (Model Context Protocol) type definitions for OpenEnv. + +This module defines strongly typed models for MCP tool discovery and invocation, +following RFC 003. These types map MCP's REST-like API (tools/list, tools/call) +to Gym-style action types. + +Key design decisions: +- Tool discovery (list_tools) does NOT require reset() first +- Reserved tool names (reset, step, state, close) are prohibited +- Both step() and WebSocket /mcp paths are supported +""" + +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field + +from .types import Action, BaseMessage, Observation + + +# ============================================================================= +# JSON-RPC 2.0 Types +# ============================================================================= + + +class JsonRpcErrorCode(int, Enum): + """ + Standard JSON-RPC 2.0 error codes. + + See: https://www.jsonrpc.org/specification#error_object + """ + + # Standard JSON-RPC errors + PARSE_ERROR = -32700 # Invalid JSON was received + INVALID_REQUEST = -32600 # JSON is not a valid Request object + METHOD_NOT_FOUND = -32601 # Method does not exist / is not available + INVALID_PARAMS = -32602 # Invalid method parameter(s) + INTERNAL_ERROR = -32603 # Internal JSON-RPC error + + # Server errors (reserved for implementation-defined errors) + SERVER_ERROR = -32000 # Generic server error + + +class McpMethod(str, Enum): + """Supported MCP method names.""" + + TOOLS_LIST = "tools/list" + TOOLS_CALL = "tools/call" + + +class JsonRpcError(BaseModel): + """ + JSON-RPC 2.0 error object. + + See: https://www.jsonrpc.org/specification#error_object + """ + + model_config = ConfigDict(extra="forbid") + + code: int = Field(description="Error code indicating the error type") + message: str = Field(description="Short description of the error") + data: Optional[Any] = Field( + default=None, description="Additional error information" + ) + + @classmethod + def from_code( + cls, code: JsonRpcErrorCode, message: Optional[str] = None, data: Any = None + ) -> "JsonRpcError": + """Create an error from a standard error code.""" + default_messages = { + JsonRpcErrorCode.PARSE_ERROR: "Parse error", + JsonRpcErrorCode.INVALID_REQUEST: "Invalid Request", + JsonRpcErrorCode.METHOD_NOT_FOUND: "Method not found", + JsonRpcErrorCode.INVALID_PARAMS: "Invalid params", + JsonRpcErrorCode.INTERNAL_ERROR: "Internal error", + JsonRpcErrorCode.SERVER_ERROR: "Server error", + } + return cls( + code=code.value, + message=message or default_messages.get(code, "Unknown error"), + data=data, + ) + + +class JsonRpcRequest(BaseModel): + """ + JSON-RPC 2.0 request object. 
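+
+    Example payload (illustrative; `add` is a hypothetical tool):
+
+        {"jsonrpc": "2.0", "method": "tools/call",
+         "params": {"name": "add", "arguments": {"a": 1, "b": 2}}, "id": 1}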
+ + See: https://www.jsonrpc.org/specification#request_object + """ + + model_config = ConfigDict(extra="forbid") + + jsonrpc: Literal["2.0"] = Field(description="JSON-RPC version, must be '2.0'") + method: str = Field(description="Name of the method to be invoked") + params: Dict[str, Any] = Field( + default_factory=dict, description="Parameter values for the method" + ) + id: Optional[Union[str, int]] = Field( + default=None, description="Request identifier established by the client" + ) + + +class JsonRpcResponse(BaseModel): + """ + JSON-RPC 2.0 response object. + + Per JSON-RPC 2.0 spec, a response has either 'result' or 'error', not both. + This model excludes None values during serialization to comply with the spec. + + See: https://www.jsonrpc.org/specification#response_object + """ + + model_config = ConfigDict(extra="forbid") + + jsonrpc: Literal["2.0"] = Field(default="2.0", description="JSON-RPC version") + result: Optional[Any] = Field( + default=None, description="Result of the method invocation" + ) + error: Optional[JsonRpcError] = Field( + default=None, description="Error object if method invocation failed" + ) + id: Optional[Union[str, int]] = Field( + default=None, description="Request identifier from the request" + ) + + def model_dump(self, **kwargs) -> Dict[str, Any]: + """Serialize to dict, excluding result or error when None (JSON-RPC compliance).""" + # Always include jsonrpc and id, but only include result OR error + data: Dict[str, Any] = {"jsonrpc": self.jsonrpc, "id": self.id} + if self.error is not None: + data["error"] = ( + self.error.model_dump() + if hasattr(self.error, "model_dump") + else self.error + ) + else: + # Only include result if there's no error + data["result"] = self.result + return data + + def model_dump_json(self, **kwargs) -> str: + """Serialize to JSON string, excluding result or error when None (JSON-RPC compliance).""" + import json + + return json.dumps(self.model_dump()) + + @classmethod + def success( + cls, result: Any, request_id: Optional[Union[str, int]] = None + ) -> "JsonRpcResponse": + """Create a success response.""" + return cls(result=result, id=request_id) + + @classmethod + def error_response( + cls, + code: JsonRpcErrorCode, + message: Optional[str] = None, + data: Any = None, + request_id: Optional[Union[str, int]] = None, + ) -> "JsonRpcResponse": + """Create an error response from a standard error code.""" + return cls( + error=JsonRpcError.from_code(code, message, data), + id=request_id, + ) + + +# ============================================================================= +# MCP Tool Types +# ============================================================================= + + +class Tool(BaseModel): + """ + Strongly typed MCP tool specification. + + Follows the MCP ToolSpec format for tool discovery. 
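+
+    An illustrative instance (the `add` tool is hypothetical):
+
+        Tool(
+            name="add",
+            description="Add two integers",
+            input_schema={
+                "type": "object",
+                "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
+                "required": ["a", "b"],
+            },
+        )
+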
+ See: https://modelcontextprotocol.io/specification/2025-06-18/server/tools + """ + + model_config = ConfigDict(extra="forbid") + + name: str = Field(description="Unique identifier for the tool") + description: str = Field( + description="Human-readable description of what the tool does" + ) + input_schema: Dict[str, Any] = Field( + description="JSON Schema for the tool's input parameters" + ) + + +class ToolErrorType(str, Enum): + """Types of errors that can occur during tool execution.""" + + EXECUTION_ERROR = "execution_error" # Tool ran but failed + INVALID_ARGS = "invalid_args" # Invalid arguments provided + TRANSPORT_ERROR = "transport_error" # Communication failure + TOOL_NOT_FOUND = "tool_not_found" # Tool doesn't exist + TIMEOUT = "timeout" # Operation timed out + + +class ToolError(BaseModel): + """ + Structured error for tool execution failures. + + This is used for transport/framework errors, NOT for errors returned + by the tool itself (those go in the result field). + """ + + model_config = ConfigDict(extra="forbid") + + error_type: ToolErrorType = Field(description="Category of the error") + message: str = Field(description="Human-readable error message") + + +# --- MCP Actions --- + + +class ListToolsAction(Action): + """ + Request list of available tools from the environment. + + This action triggers MCP's tools/list operation and returns + all available tools with their schemas. + + Note: Does NOT require reset() to be called first. + """ + + type: Literal["list_tools"] = Field( + default="list_tools", description="Action type discriminator" + ) + + +class CallToolAction(Action): + """ + Call a specific tool via MCP. + + This action triggers MCP's tools/call operation with the + specified tool name and arguments. + """ + + type: Literal["call_tool"] = Field( + default="call_tool", description="Action type discriminator" + ) + tool_name: str = Field(description="Name of the tool to call") + arguments: Dict[str, Any] = Field( + default_factory=dict, description="Arguments to pass to the tool" + ) + + +# --- MCP Observations --- + + +class ListToolsObservation(Observation): + """ + Response containing available tools. + + Returned when processing a ListToolsAction. + """ + + tools: List[Tool] = Field(description="List of available tools with their schemas") + + +class CallToolObservation(Observation): + """ + Response from tool execution. + + Contains the tool's result or an error if the call failed. + Tool-specific errors (from the tool itself) are included in the result. + Transport/framework errors use the error field. + """ + + tool_name: str = Field(description="Name of the tool that was called") + result: Any = Field( + default=None, description="Tool-specific result (may include tool errors)" + ) + error: Optional[ToolError] = Field( + default=None, description="Transport/framework error if call failed" + ) + + +# --- WebSocket Message Types for MCP --- + + +class WSMCPMessage(BaseMessage): + """ + WebSocket message for MCP JSON-RPC requests. + + Allows direct MCP access via WebSocket for production inference, + bypassing the step() API. + """ + + type: Literal["mcp"] = Field(default="mcp", description="Message type") + data: Dict[str, Any] = Field(description="JSON-RPC payload (method, params, id)") + + +class WSMCPResponse(BaseModel): + """ + WebSocket response for MCP JSON-RPC. + + Contains the JSON-RPC response from the MCP server. 
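+
+    Example payload (illustrative; the `result` contents depend on the request):
+
+        {"type": "mcp", "data": {"jsonrpc": "2.0", "result": {"tools": []}, "id": 1}}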
+ """ + + model_config = ConfigDict(extra="forbid") + + type: str = Field(default="mcp", description="Response type") + data: Dict[str, Any] = Field(description="JSON-RPC response payload") + + +# Reserved tool names that cannot be used (protects dual API boundary) +RESERVED_TOOL_NAMES = frozenset(["reset", "step", "state", "close"]) diff --git a/src/openenv/core/env_server/route_config.py b/src/openenv/core/env_server/route_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d74a7f202be0731400a6b954dfd37d9012c1f8f7 --- /dev/null +++ b/src/openenv/core/env_server/route_config.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Route configuration utilities for declarative FastAPI route registration. + +This module provides utilities to reduce boilerplate in route registration +by using configuration objects instead of repeated function calls. +""" + +from dataclasses import dataclass +from typing import Callable, List, Type + +from fastapi import FastAPI +from pydantic import BaseModel + + +@dataclass +class GetEndpointConfig: + """Configuration for a simple GET endpoint.""" + + path: str + handler: Callable[[], BaseModel | dict] + response_model: Type[BaseModel] | type[dict] + tag: str + summary: str + description: str + + +def register_get_endpoints(app: FastAPI, configs: List[GetEndpointConfig]) -> None: + """ + Register multiple GET endpoints from configuration. + + Args: + app: FastAPI application instance + configs: List of GET endpoint configurations + """ + for config in configs: + # Capture handler in a closure to avoid non-serializable default parameter + def make_endpoint( + handler: Callable[[], BaseModel | dict], + ) -> Callable[[], BaseModel | dict]: + async def endpoint() -> BaseModel | dict: + return handler() + + return endpoint + + app.get( + config.path, + response_model=config.response_model, + tags=[config.tag], + summary=config.summary, + description=config.description, + )(make_endpoint(config.handler)) diff --git a/src/openenv/core/env_server/serialization.py b/src/openenv/core/env_server/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b50d9aeb873794044e77ee398a7f2b5fca8093 --- /dev/null +++ b/src/openenv/core/env_server/serialization.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared serialization and deserialization utilities for OpenEnv HTTP servers. + +This module provides common utilities for converting between JSON dictionaries +and Pydantic models (Action/Observation) to eliminate code duplication across +HTTP server and web interface implementations. +""" + +from typing import Any, Dict, Type + +from .types import Action, Observation + + +def deserialize_action(action_data: Dict[str, Any], action_cls: Type[Action]) -> Action: + """ + Convert JSON dict to Action instance using Pydantic validation. + + This is a basic deserialization that works for most environments. + For special cases (e.g., tensor fields, custom type conversions), + use deserialize_action_with_preprocessing(). 
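+
+    Example (a sketch; `EchoAction` is a hypothetical Action subclass with a
+    `message` field):
+
+        action = deserialize_action({"message": "hi"}, EchoAction)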
+ + Args: + action_data: Dictionary containing action data + action_cls: The Action subclass to instantiate + + Returns: + Action instance + + Raises: + ValidationError: If action_data is invalid for the action class + + Note: + This uses Pydantic's model_validate() for automatic validation. + """ + return action_cls.model_validate(action_data) + + +def deserialize_action_with_preprocessing( + action_data: Dict[str, Any], action_cls: Type[Action] +) -> Action: + """ + Convert JSON dict to Action instance with preprocessing for special types. + + This version handles common type conversions needed for web interfaces: + - Converting lists/strings to tensors for 'tokens' field + - Converting string action_id to int + - Other custom preprocessing as needed + + Args: + action_data: Dictionary containing action data + action_cls: The Action subclass to instantiate + + Returns: + Action instance + + Raises: + ValidationError: If action_data is invalid for the action class + """ + processed_data = {} + + for key, value in action_data.items(): + if key == "tokens" and isinstance(value, (list, str)): + # Convert list or string to tensor + if isinstance(value, str): + # If it's a string, try to parse it as a list of numbers + try: + import json + + value = json.loads(value) + except Exception: + # If parsing fails, treat as empty list + value = [] + if isinstance(value, list): + try: + import torch # type: ignore + + processed_data[key] = torch.tensor(value, dtype=torch.long) + except ImportError: + # If torch not available, keep as list + processed_data[key] = value + else: + processed_data[key] = value + elif key == "action_id" and isinstance(value, str): + # Convert action_id from string to int + try: + processed_data[key] = int(value) + except ValueError: + # If conversion fails, keep original value + processed_data[key] = value + else: + processed_data[key] = value + + return action_cls.model_validate(processed_data) + + +def serialize_observation(observation: Observation) -> Dict[str, Any]: + """ + Convert Observation instance to JSON-compatible dict using Pydantic. + + Args: + observation: Observation instance + + Returns: + Dictionary compatible with EnvClient._parse_result() + + The format matches what EnvClient expects: + { + "observation": {...}, # Observation fields + "reward": float | None, + "done": bool, + } + """ + # Use Pydantic's model_dump() for serialization + obs_dict = observation.model_dump( + exclude={ + "reward", + "done", + "metadata", + } # Exclude these from observation dict + ) + + # Extract reward and done directly from the observation + reward = observation.reward + done = observation.done + + # Return in EnvClient expected format + return { + "observation": obs_dict, + "reward": reward, + "done": done, + } diff --git a/src/openenv/core/env_server/types.py b/src/openenv/core/env_server/types.py new file mode 100644 index 0000000000000000000000000000000000000000..34a198013442e5000f7fbf75b7f24157b6c04683 --- /dev/null +++ b/src/openenv/core/env_server/types.py @@ -0,0 +1,387 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from enum import Enum +from typing import Annotated, Any, Dict, Literal, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field, model_validator + + +# Type aliases +Scalar = Union[int, float, bool] + + +# ============================================================================= +# Enums for Type Safety +# ============================================================================= + + +class ServerMode(str, Enum): + """Server operation mode.""" + + SIMULATION = "simulation" + PRODUCTION = "production" + + +class HealthStatus(str, Enum): + """Server health status values.""" + + HEALTHY = "healthy" + UNHEALTHY = "unhealthy" + DEGRADED = "degraded" + + +class WSErrorCode(str, Enum): + """WebSocket error codes for structured error handling.""" + + INVALID_JSON = "INVALID_JSON" + UNKNOWN_TYPE = "UNKNOWN_TYPE" + VALIDATION_ERROR = "VALIDATION_ERROR" + EXECUTION_ERROR = "EXECUTION_ERROR" + CAPACITY_REACHED = "CAPACITY_REACHED" + FACTORY_ERROR = "FACTORY_ERROR" + SESSION_ERROR = "SESSION_ERROR" + + +# ============================================================================= +# Core Types +# ============================================================================= + + +class Action(BaseModel): + """Base class for all environment actions. + + All action subclasses should inherit from this base class. + Uses Pydantic for automatic validation and serialization. + """ + + model_config = ConfigDict( + extra="forbid", # Reject unknown fields + validate_assignment=True, # Validate on field assignment + arbitrary_types_allowed=True, # Allow numpy arrays, torch tensors, etc. + ) + + metadata: Dict[str, Any] = Field( + default_factory=dict, description="Additional metadata for the action" + ) + + +class Observation(BaseModel): + """Base class for all environment observations. + + All observation subclasses should inherit from this base class. + Uses Pydantic for automatic validation and serialization. 
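+
+    Example subclass (a sketch):
+
+        class EchoObservation(Observation):
+            response: str = ""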
+ """ + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + arbitrary_types_allowed=True, + ) + + done: bool = Field(default=False, description="Whether the episode has terminated") + reward: bool | int | float | None = Field( + default=None, description="Reward signal from the last action" + ) + metadata: Dict[str, Any] = Field( + default_factory=dict, description="Additional metadata for the observation" + ) + + +class ResetRequest(BaseModel): + """Request model for environment reset.""" + + model_config = ConfigDict( + extra="allow", # Allow extra fields for custom reset parameters + json_schema_extra={"examples": [{"seed": 42, "episode_id": "episode-001"}, {}]}, + ) + + seed: Optional[int] = Field( + default=None, ge=0, description="Random seed for reproducible episodes" + ) + episode_id: Optional[str] = Field( + default=None, max_length=255, description="Custom episode identifier" + ) + + +class ResetResponse(BaseModel): + """Response model for environment reset.""" + + model_config = ConfigDict(extra="forbid") + + observation: Dict[str, Any] = Field( + ..., description="Initial observation from the environment" + ) + reward: Optional[float] = Field( + default=None, description="Initial reward (typically None at reset)" + ) + done: bool = Field( + default=False, description="Whether episode is already done (typically False)" + ) + + +class StepRequest(BaseModel): + """Request model for environment step.""" + + model_config = ConfigDict( + extra="allow", # Allow extra fields for custom step parameters + json_schema_extra={ + "examples": [ + {"action": {"value": 1}, "timeout_s": 30.0}, + {"action": {"value": 1}, "render": True, "verbose": False}, + ] + }, + ) + + action: Dict[str, Any] = Field( + ..., + description="Action to execute, must conform to environment's action schema", + ) + timeout_s: Optional[float] = Field( + default=None, + gt=0, + description="Optional timeout in seconds for action execution", + ) + request_id: Optional[str] = Field( + default=None, + max_length=255, + description="Optional request identifier for tracking", + ) + + +class StepResponse(BaseModel): + """Response model for environment step.""" + + model_config = ConfigDict(extra="forbid") + + observation: Dict[str, Any] = Field( + ..., description="Observation resulting from the action" + ) + reward: Optional[float] = Field( + default=None, description="Reward signal from the action" + ) + done: bool = Field(default=False, description="Whether the episode has terminated") + + +class BaseMessage(BaseModel): + """Base class for WebSocket messages with shared configuration.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + +class State(BaseModel): + """Base class for environment state. + + Represents internal environment state, separate from observations. 
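+
+    Example subclass (a sketch):
+
+        class CounterState(State):
+            total_reward: float = 0.0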
+ """ + + model_config = ConfigDict( + extra="allow", # Allow extra fields for flexibility + validate_assignment=True, + arbitrary_types_allowed=True, + ) + + episode_id: Optional[str] = Field( + default=None, description="Unique identifier for the current episode" + ) + step_count: int = Field( + default=0, + ge=0, # Greater than or equal to 0 + description="Number of steps taken in the current episode", + ) + + +class CodeExecResult(BaseMessage): + """Result of code execution containing stdout, stderr, and exit code.""" + + stdout: str = Field(description="Standard output from code execution") + stderr: str = Field(description="Standard error from code execution") + exit_code: int = Field(description="Exit code from code execution") + + +class EnvironmentMetadata(BaseMessage): + """Metadata about an environment for documentation and UI purposes.""" + + name: str = Field(description="Name of the environment") + description: str = Field(description="Description of what the environment does") + readme_content: Optional[str] = Field( + default=None, description="Content of the README file for the environment" + ) + version: Optional[str] = Field( + default=None, description="Version of the environment" + ) + author: Optional[str] = Field(default=None, description="Author of the environment") + documentation_url: Optional[str] = Field( + default=None, description="URL to the environment's documentation" + ) + + +class SchemaResponse(BaseMessage): + """Response model for the combined schema endpoint.""" + + action: Dict[str, Any] = Field( + description="JSON schema for actions accepted by this environment" + ) + observation: Dict[str, Any] = Field( + description="JSON schema for observations returned by this environment" + ) + state: Dict[str, Any] = Field( + description="JSON schema for environment state objects" + ) + + +class HealthResponse(BaseMessage): + """Response model for health check endpoint.""" + + status: HealthStatus = Field( + default=HealthStatus.HEALTHY, + description="Health status of the environment server", + ) + + +class WSResetMessage(BaseMessage): + """WebSocket message to reset the environment.""" + + type: Literal["reset"] = Field(default="reset", description="Message type") + data: Dict[str, Any] = Field( + default_factory=dict, + description="Optional reset parameters (seed, episode_id, etc.)", + ) + + +class WSStepMessage(BaseMessage): + """WebSocket message to execute a step.""" + + type: Literal["step"] = Field(default="step", description="Message type") + data: Dict[str, Any] = Field( + ..., description="Action data conforming to environment's action schema" + ) + + +class WSStateMessage(BaseMessage): + """WebSocket message to request current state.""" + + type: Literal["state"] = Field(default="state", description="Message type") + + +class WSCloseMessage(BaseMessage): + """WebSocket message to close the session.""" + + type: Literal["close"] = Field(default="close", description="Message type") + + +# Discriminated union for incoming WebSocket messages +# Note: WSMCPMessage is defined in mcp_types.py to avoid circular imports +# The union here covers the core message types; MCP messages are handled separately +WSIncomingMessage = Annotated[ + WSResetMessage | WSStepMessage | WSStateMessage | WSCloseMessage, + Field(discriminator="type"), +] + + +class WSObservationResponse(BaseModel): + """WebSocket response containing an observation.""" + + model_config = ConfigDict(extra="forbid") + + type: Literal["observation"] = Field( + default="observation", 
description="Response type" + ) + data: Dict[str, Any] = Field(description="Observation data") + + +class WSStateResponse(BaseModel): + """WebSocket response containing environment state.""" + + model_config = ConfigDict(extra="forbid") + + type: Literal["state"] = Field(default="state", description="Response type") + data: Dict[str, Any] = Field(description="State data") + + +class WSErrorResponse(BaseModel): + """WebSocket response for errors.""" + + model_config = ConfigDict(extra="forbid") + + type: Literal["error"] = Field(default="error", description="Response type") + data: Dict[str, Any] = Field(description="Error details including message and code") + + +class ConcurrencyConfig(BaseMessage): + """Configuration for concurrent environment sessions.""" + + max_concurrent_envs: int = Field( + default=1, + ge=1, + description="Maximum number of concurrent WebSocket sessions allowed", + ) + session_timeout: Optional[float] = Field( + default=None, + gt=0, + description="Timeout in seconds for inactive sessions. None means no timeout.", + ) + + +class ServerCapacityStatus(BaseMessage): + """Status of server capacity for concurrent sessions.""" + + active_sessions: int = Field( + ge=0, + description="Number of currently active sessions", + ) + max_sessions: int = Field( + ge=1, + description="Maximum number of allowed sessions", + ) + + @model_validator(mode="after") + def check_capacity_bounds(self) -> "ServerCapacityStatus": + if self.active_sessions > self.max_sessions: + raise ValueError( + f"active_sessions ({self.active_sessions}) cannot exceed " + f"max_sessions ({self.max_sessions})" + ) + return self + + @property + def available_slots(self) -> int: + """Number of available session slots.""" + return self.max_sessions - self.active_sessions + + @property + def is_at_capacity(self) -> bool: + """Whether the server has reached maximum capacity.""" + return self.available_slots == 0 + + @classmethod + def from_counts(cls, active: int, max_sessions: int) -> "ServerCapacityStatus": + """Create status from active and max session counts.""" + return cls( + active_sessions=active, + max_sessions=max_sessions, + ) + + +class SessionInfo(BaseMessage): + """Information about an active session.""" + + session_id: str = Field(description="Unique identifier for the session") + created_at: float = Field(description="Unix timestamp when the session was created") + last_activity_at: float = Field( + description="Unix timestamp of the last activity in the session" + ) + step_count: int = Field( + default=0, + ge=0, + description="Number of steps executed in this session", + ) + environment_type: str = Field( + description="Environment type for this session (e.g. `CodingEnv`)" + ) diff --git a/src/openenv/core/env_server/web_interface.py b/src/openenv/core/env_server/web_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..284740eb408b8e2b798037918967b7a50abee72d --- /dev/null +++ b/src/openenv/core/env_server/web_interface.py @@ -0,0 +1,644 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Web interface for OpenEnv environments. + +When ENABLE_WEB_INTERFACE is set, the server exposes a Gradio UI at /web for +reset, step, and state observation. Controlled by the CLI enable_interface +option (e.g. openenv push --enable-interface) or ENABLE_WEB_INTERFACE env var. 
+""" + +from __future__ import annotations + +import asyncio +import json +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, Type + +import gradio as gr +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from pydantic import BaseModel, ConfigDict, Field + +from .gradio_theme import OPENENV_GRADIO_CSS, OPENENV_GRADIO_THEME +from .gradio_ui import build_gradio_app, get_gradio_display_title +from .interfaces import Environment +from .serialization import deserialize_action_with_preprocessing, serialize_observation +from .types import Action, EnvironmentMetadata, Observation, State + +# Quick Start markdown template; placeholders match init suffixes (__ENV_NAME__, __ENV_CLASS_NAME__*). +DEFAULT_QUICK_START_MARKDOWN = """ +### Connect to this environment + +Connect from Python using `__ENV_CLASS_NAME__Env`: + +```python +from __ENV_NAME__ import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Env + +with __ENV_CLASS_NAME__Env.from_env("") as env: + result = await env.step(__ENV_CLASS_NAME__Action(message="...")) +``` + +Or connect directly to a running server: + +```python +env = __ENV_CLASS_NAME__Env(base_url="http://localhost:8000") +``` + +### Contribute to this environment + +Submit improvements via pull request on the Hugging Face Hub. + +```bash +openenv fork --repo-id / +``` + +Then make your changes and submit a pull request: + +```bash +cd +openenv push --create-pr +``` + +For more information, see the [OpenEnv documentation](https://meta-pytorch.org/OpenEnv/). +""" + + +def get_quick_start_markdown( + metadata: Optional[EnvironmentMetadata], + action_cls: Type[Action], + observation_cls: Type[Observation], +) -> str: + """ + Build Quick Start markdown with class names replaced from current env (init-style suffixes). + + Uses the same placeholder names as the init template so that __ENV_CLASS_NAME__Env, + __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation and __ENV_NAME__ are + replaced with the actual class/package names. + """ + import os + + # Prefix from action class (e.g. EchoAction -> Echo) + action_name = getattr(action_cls, "__name__", "Action") + if action_name.endswith("Action"): + prefix = action_name[: -len("Action")] + else: + prefix = action_name.replace("Action", "").strip() or "Env" + + env_client_name = f"{prefix}Env" + obs_name = getattr(observation_cls, "__name__", "Observation") + pkg_name = (metadata.name if metadata else "env").replace(" ", "_").lower() + + space_id = os.environ.get("SPACE_ID", "/") + + content = DEFAULT_QUICK_START_MARKDOWN + content = content.replace("__ENV_CLASS_NAME__Env", env_client_name) + content = content.replace("__ENV_CLASS_NAME__Action", action_name) + content = content.replace("__ENV_CLASS_NAME__Observation", obs_name) + content = content.replace("__ENV_CLASS_NAME__", prefix) + content = content.replace("__ENV_NAME__", pkg_name) + content = content.replace("", space_id) + return content.strip() + + +def load_environment_metadata( + env: Environment, env_name: Optional[str] = None +) -> EnvironmentMetadata: + """ + Load environment metadata including README content. + + Args: + env: The environment instance, class, or factory function. 
+ - If a class: used as a factory, won't call instance methods + - If a function: used as a factory, won't call instance methods + - If an instance: may call get_metadata() if available + env_name: Optional environment name for README file lookup + + Returns: + EnvironmentMetadata with loaded information + """ + import inspect + + # Determine what type of env we received: + # 1. A class (used as factory) - e.g., PythonCodeActEnv + # 2. A function (factory function) - e.g., create_chat_environment + # 3. An actual instance - e.g., SnakeEnvironment() + is_class = inspect.isclass(env) + is_function = inspect.isfunction(env) or inspect.ismethod(env) + is_factory = is_class or is_function + + # Try to get metadata from environment if it's an instance with get_metadata + if not is_factory and hasattr(env, "get_metadata"): + return env.get_metadata() + + # Determine the class name for default metadata + if is_class: + # env is the class itself + class_name = env.__name__ + elif is_function: + # env is a factory function - use its name or derive from env_name + class_name = env_name or env.__name__ + else: + # env is an instance + class_name = env.__class__.__name__ + + # Default metadata + metadata = EnvironmentMetadata( + name=env_name or class_name, + description=f"{class_name} environment", + version="1.0.0", + ) + + # Try to load README from file system + readme_content = _load_readme_from_filesystem(env_name) + if readme_content: + metadata.readme_content = readme_content + + return metadata + + +def _load_readme_from_filesystem(env_name: Optional[str]) -> Optional[str]: + """ + Load README content from the filesystem. + + Tries multiple locations: + 1. Container filesystem: /app/README.md + 2. Local development: src/envs/{env_name}/README.md + 3. Environment variable: ENV_README_PATH + """ + import os + from pathlib import Path + + # Try container filesystem first + container_readme = Path("/app/README.md") + if container_readme.exists(): + try: + return container_readme.read_text(encoding="utf-8") + except Exception: + pass + + # Try environment variable path + custom_path = os.environ.get("ENV_README_PATH") + if custom_path and Path(custom_path).exists(): + try: + return Path(custom_path).read_text(encoding="utf-8") + except Exception: + pass + + # Try local development path + if env_name: + local_readme = Path(f"src/envs/{env_name}/README.md") + if local_readme.exists(): + try: + return local_readme.read_text(encoding="utf-8") + except Exception: + pass + + return None + + +class ActionLog(BaseModel): + """Log entry for an action taken.""" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + timestamp: str = Field(description="Timestamp when action was taken") + action: Dict[str, Any] = Field(description="Action that was taken") + observation: Dict[str, Any] = Field(description="Observation returned from action") + reward: Optional[float] = Field( + default=None, description="Reward received from action" + ) + done: bool = Field(description="Whether the episode is done after this action") + step_count: int = Field(description="Step count when this action was taken") + + +class EpisodeState(BaseModel): + """Current episode state for the web interface.""" + + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + episode_id: Optional[str] = Field(default=None, description="Current episode ID") + step_count: int = Field(description="Current step count in episode") + current_observation: Optional[Dict[str, Any]] = Field( + default=None, 
description="Current observation" + ) + action_logs: List[ActionLog] = Field( + default_factory=list, description="List of action logs" + ) + is_reset: bool = Field( + default=True, description="Whether the episode has been reset" + ) + + +class WebInterfaceManager: + """Manages the web interface for an environment.""" + + MAX_ACTION_LOGS = 1000 + + def __init__( + self, + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + metadata: Optional[EnvironmentMetadata] = None, + ): + import inspect + + # If env is a class or factory function, instantiate it + if inspect.isclass(env) or inspect.isfunction(env): + self.env = env() + else: + self.env = env + self.action_cls = action_cls + self.observation_cls = observation_cls + self.metadata = metadata or EnvironmentMetadata( + name=env.__class__.__name__, + description=f"{env.__class__.__name__} environment", + ) + self.episode_state = EpisodeState( + episode_id=None, + step_count=0, + current_observation=None, + action_logs=[], + ) + self.connected_clients: List[WebSocket] = [] + # Thread pool for running sync code (e.g., Playwright sync API) in async context + self._executor = ThreadPoolExecutor(max_workers=1) + + async def _run_sync_in_thread_pool(self, func, *args, **kwargs): + """Run a synchronous function in the thread pool executor. + + This is needed for environments using sync libraries (e.g., Playwright sync API) + that cannot be called directly from an async context. + """ + loop = asyncio.get_event_loop() + # Use default arguments to capture values at lambda definition time + # to avoid closure issues with late binding + return await loop.run_in_executor( + self._executor, lambda f=func, a=args, kw=kwargs: f(*a, **kw) + ) + + async def connect_websocket(self, websocket: WebSocket): + """Connect a new WebSocket client.""" + await websocket.accept() + self.connected_clients.append(websocket) + + # Send current state to the new client + await self._send_state_update() + + async def disconnect_websocket(self, websocket: WebSocket): + """Disconnect a WebSocket client.""" + if websocket in self.connected_clients: + self.connected_clients.remove(websocket) + + async def _send_state_update(self): + """Send current state to all connected clients.""" + if not self.connected_clients: + return + + state_data = { + "type": "state_update", + "episode_state": self.episode_state.model_dump(), + } + + # Send to all connected clients + disconnected_clients = [] + for client in self.connected_clients: + try: + await client.send_text(json.dumps(state_data)) + except Exception: + disconnected_clients.append(client) + + # Remove disconnected clients + for client in disconnected_clients: + self.connected_clients.remove(client) + + async def reset_environment(self) -> Dict[str, Any]: + """Reset the environment and update state.""" + # Run sync reset in thread pool to avoid blocking event loop + # and to support environments using sync libraries (e.g., Playwright) + observation: Observation = await self._run_sync_in_thread_pool(self.env.reset) + state: State = self.env.state + + # Serialize observation once using shared utility + serialized = serialize_observation(observation) + + # Update episode state + self.episode_state.episode_id = state.episode_id + self.episode_state.step_count = 0 + self.episode_state.current_observation = serialized["observation"] + self.episode_state.action_logs = [] + self.episode_state.is_reset = True + + # Send state update + await self._send_state_update() + + return serialized + + async def 
step_environment(self, action_data: Dict[str, Any]) -> Dict[str, Any]: + """Execute a step in the environment and update state.""" + # Deserialize action with preprocessing for web interface special cases + action: Action = deserialize_action_with_preprocessing( + action_data, self.action_cls + ) + + # Run sync step in thread pool to avoid blocking event loop + # and to support environments using sync libraries (e.g., Playwright) + observation: Observation = await self._run_sync_in_thread_pool( + self.env.step, action + ) + state: State = self.env.state + + # Serialize observation once using shared utility + serialized = serialize_observation(observation) + + # Create action log + action_log = ActionLog( + timestamp=datetime.now().isoformat(), + action=action.model_dump(exclude={"metadata"}), + observation=serialized["observation"], + reward=observation.reward, + done=observation.done, + step_count=state.step_count, + ) + + # Update episode state + self.episode_state.episode_id = state.episode_id + self.episode_state.step_count = state.step_count + self.episode_state.current_observation = serialized["observation"] + self.episode_state.action_logs.append(action_log) + if len(self.episode_state.action_logs) > self.MAX_ACTION_LOGS: + self.episode_state.action_logs = self.episode_state.action_logs[ + -self.MAX_ACTION_LOGS : + ] + self.episode_state.is_reset = False + + # Send state update + await self._send_state_update() + + return serialized + + def get_state(self) -> Dict[str, Any]: + """Get current environment state.""" + state: State = self.env.state + return state.model_dump() + + +def create_web_interface_app( + env: Environment, + action_cls: Type[Action], + observation_cls: Type[Observation], + env_name: Optional[str] = None, + max_concurrent_envs: Optional[int] = None, + concurrency_config: Optional[Any] = None, + gradio_builder: Optional[Callable[..., Any]] = None, +) -> FastAPI: + """ + Create a FastAPI application with web interface for the given environment. + + Args: + env: The Environment instance to serve + action_cls: The Action subclass this environment expects + observation_cls: The Observation subclass this environment returns + env_name: Optional environment name for README loading + max_concurrent_envs: Maximum concurrent WebSocket sessions + concurrency_config: Optional ConcurrencyConfig for advanced concurrency settings + gradio_builder: Optional callable (web_manager, action_fields, metadata, + is_chat_env, title, quick_start_md) -> gr.Blocks to use instead of the + default Gradio UI. Lets envs replace or customize the /web interface. + + Returns: + FastAPI application instance with web interface + """ + from .http_server import create_fastapi_app + + # Create the base environment app + app = create_fastapi_app( + env, action_cls, observation_cls, max_concurrent_envs, concurrency_config + ) + + # Load environment metadata + metadata = load_environment_metadata(env, env_name) + + # Create web interface manager + web_manager = WebInterfaceManager(env, action_cls, observation_cls, metadata) + + # Web API routes first (so they take precedence over Gradio mount at /web) + @app.get("/web/metadata") + async def web_metadata(): + """Get environment metadata.""" + return web_manager.metadata.model_dump() + + @app.websocket("/ws/ui") + async def websocket_ui_endpoint(websocket: WebSocket): + """WebSocket endpoint for web UI real-time updates. + + Note: Uses /ws/ui to avoid conflict with /ws in http_server.py + which is used for concurrent environment sessions. 
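+
+        Example client (illustrative sketch; assumes the ``websockets``
+        package):
+
+            async with websockets.connect("ws://localhost:8000/ws/ui") as ws:
+                update = json.loads(await ws.recv())
+                # -> {"type": "state_update", "episode_state": {...}}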
+ """ + await web_manager.connect_websocket(websocket) + try: + while True: + # Keep connection alive + await websocket.receive_text() + except WebSocketDisconnect: + await web_manager.disconnect_websocket(websocket) + + @app.post("/web/reset") + async def web_reset(): + """Reset endpoint for web interface.""" + return await web_manager.reset_environment() + + @app.post("/web/step") + async def web_step(request: Dict[str, Any]): + """Step endpoint for web interface.""" + # Check if this is a message-based request (chat environment) + if "message" in request: + message = request["message"] + if hasattr(web_manager.env, "message_to_action"): + action = web_manager.env.message_to_action(message) + if hasattr(action, "tokens"): + action_data = {"tokens": action.tokens.tolist()} + else: + action_data = action.model_dump(exclude={"metadata"}) + else: + action_data = {"message": message} + else: + action_data = request.get("action", {}) + + return await web_manager.step_environment(action_data) + + @app.get("/web/state") + async def web_state(): + """State endpoint for web interface.""" + return web_manager.get_state() + + action_fields = _extract_action_fields(action_cls) + is_chat_env = _is_chat_env(action_cls) + quick_start_md = get_quick_start_markdown(metadata, action_cls, observation_cls) + + default_blocks = build_gradio_app( + web_manager, + action_fields, + metadata, + is_chat_env, + title=metadata.name, + quick_start_md=quick_start_md, + ) + if gradio_builder is not None: + custom_blocks = gradio_builder( + web_manager, + action_fields, + metadata, + is_chat_env, + metadata.name, + quick_start_md, + ) + if not isinstance(custom_blocks, gr.Blocks): + raise TypeError( + f"gradio_builder must return a gr.Blocks instance, " + f"got {type(custom_blocks).__name__}" + ) + gradio_blocks = gr.TabbedInterface( + [default_blocks, custom_blocks], + tab_names=["Playground", "Visualization"], + title=get_gradio_display_title(metadata), + ) + else: + gradio_blocks = default_blocks + app = gr.mount_gradio_app( + app, + gradio_blocks, + path="/web", + theme=OPENENV_GRADIO_THEME, + css=OPENENV_GRADIO_CSS, + ) + + return app + + +def _is_chat_env(action_cls: Type[Action]) -> bool: + """Return True if the action class is a chat-style env (tokens field).""" + if hasattr(action_cls, "model_fields"): + for field_name, field_info in action_cls.model_fields.items(): + if ( + field_name == "tokens" + and hasattr(field_info.annotation, "__name__") + and "Tensor" in str(field_info.annotation) + ): + return True + return False + + +def _extract_action_fields(action_cls: Type[Action]) -> List[Dict[str, Any]]: + """Extract enhanced field metadata from Action class for form generation.""" + # Use Pydantic's JSON schema generation for robust metadata extraction + try: + schema = action_cls.model_json_schema() + except AttributeError: + # Fallback for non-Pydantic v2 models or if something goes wrong + return [] + + properties = schema.get("properties", {}) + required_fields = schema.get("required", []) + + action_fields = [] + + for field_name, field_info in properties.items(): + if field_name == "metadata": + continue + + # JSON schema "type" can be a string or list/undefined + # Determine our internal input type + input_type = _determine_input_type_from_schema(field_info, field_name) + + is_required = field_name in required_fields + + action_fields.append( + { + "name": field_name, + "type": input_type, + "required": is_required, + "description": field_info.get("description", ""), + "default_value": 
field_info.get("default"), + "choices": field_info.get("enum"), + "min_value": field_info.get("minimum"), + "max_value": field_info.get("maximum"), + "min_length": field_info.get("minLength"), + "max_length": field_info.get("maxLength"), + "pattern": field_info.get("pattern"), + "placeholder": _generate_placeholder(field_name, field_info), + "help_text": _generate_help_text(field_name, field_info), + } + ) + + return action_fields + + +def _determine_input_type_from_schema( + field_info: Dict[str, Any], field_name: str +) -> str: + """Determine input type from JSON schema for form generation (Gradio UI).""" + schema_type = field_info.get("type") + + # Check for specific tensor field convention + if "tokens" in field_name.lower(): + return "tensor" + + if "enum" in field_info: + return "select" + + if schema_type == "boolean": + return "checkbox" + + if schema_type == "integer" or schema_type == "number": + return "number" + + if schema_type == "string": + # Check if it should be a textarea + if ( + field_info.get("maxLength", 0) > 100 + or "message" in field_name.lower() + or "code" in field_name.lower() + ): + return "textarea" + return "text" + + # Default fallback + return "text" + + +def _generate_placeholder(field_name: str, field_info: Dict[str, Any]) -> str: + """Generate placeholder text.""" + if "message" in field_name.lower(): + return f"Enter {field_name.replace('_', ' ')}..." + elif "code" in field_name.lower(): + return "Enter Python code here..." + elif "tokens" in field_name.lower(): + return "Enter comma-separated token IDs (e.g., 1,2,3,4,5)" + else: + return f"Enter {field_name.replace('_', ' ')}..." + + +def _generate_help_text(field_name: str, field_info: Dict[str, Any]) -> str: + """Generate help text.""" + description = field_info.get("description", "") + if description: + return description + + if "action_id" in field_name.lower(): + return "The action ID to execute in environment" + elif "game_name" in field_name.lower(): + return "Name of game or environment" + elif "tokens" in field_name.lower(): + return "Token IDs as a comma-separated list of integers" + elif "code" in field_name.lower(): + return "Python code to execute in environment" + elif "message" in field_name.lower(): + return "Text message to send" + + return "" diff --git a/src/openenv/core/evals/__init__.py b/src/openenv/core/evals/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52e564a09b5e4976f2cd5a8c1fe1c7848bb47ecb --- /dev/null +++ b/src/openenv/core/evals/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Evaluation harness support for OpenEnv.""" + +from openenv.core.evals.base import EvalHarness +from openenv.core.evals.inspect_harness import InspectAIHarness +from openenv.core.evals.types import EvalConfig, EvalResult + +__all__ = [ + "EvalHarness", + "EvalConfig", + "EvalResult", + "InspectAIHarness", +] diff --git a/src/openenv/core/evals/base.py b/src/openenv/core/evals/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e457d8adb740569ad79143cbf70bc58b05a8cef9 --- /dev/null +++ b/src/openenv/core/evals/base.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Base class for evaluation harnesses.""" + +from abc import ABC, abstractmethod +from typing import Any, Dict + +from openenv.core.evals.types import EvalConfig, EvalResult + + +class EvalHarness(ABC): + """Abstract base class for evaluation harnesses. + + Subclasses implement run() to define evaluation logic. + """ + + @abstractmethod + def run( + self, + harness_version: str, + library_versions: Dict[str, str], + dataset: str, + eval_parameters: Dict[str, Any], + ) -> Dict[str, Any]: + """Run the evaluation and return scores. + + Args: + harness_version: Version of the evaluation harness. + library_versions: Versions of libraries used in the evaluation. + dataset: Name of the dataset to evaluate on. + eval_parameters: Parameters for the evaluation. + + Returns: + Dictionary of scores from the evaluation. + """ + raise NotImplementedError + + def run_from_config(self, config: EvalConfig) -> EvalResult: + """Run evaluation from an EvalConfig and return an EvalResult. + + Args: + config: Configuration for the evaluation. + + Returns: + EvalResult containing the config and scores. + """ + scores = self.run( + harness_version=config.harness_version, + library_versions=config.library_versions, + dataset=config.dataset, + eval_parameters=config.eval_parameters, + ) + return EvalResult(config=config, scores=scores) + + @property + def name(self) -> str: + """Return the name of the harness (class name).""" + return self.__class__.__name__ diff --git a/src/openenv/core/evals/inspect_harness.py b/src/openenv/core/evals/inspect_harness.py new file mode 100644 index 0000000000000000000000000000000000000000..6bf91105db6cf325587623891905e5cbc71c124e --- /dev/null +++ b/src/openenv/core/evals/inspect_harness.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Inspect AI harness integration for OpenEnv. + +Requires the ``inspect-ai`` package: ``pip install 'inspect-ai>=0.3.0'`` +""" + +from __future__ import annotations + +from typing import Any, Dict, Optional + +from openenv.core.evals.base import EvalHarness + + +class InspectAIHarness(EvalHarness): + """Evaluation harness wrapping Inspect AI's ``eval()`` function. + + All ``inspect_ai`` imports are deferred to :meth:`run` so this class is + importable without inspect-ai installed. An ``ImportError`` with a clear + message is raised at call time if the dependency is missing. + + Args: + log_dir: Directory for evaluation log output. Defaults to None + (Inspect AI writes logs to its default location). + + ``eval_parameters`` keys accepted by :meth:`run`: + + +--------------------------+----------+-----------------+-----------------------------------+ + | Key | Type | Default | Purpose | + +==========================+==========+=================+===================================+ + | ``model`` | str | *required* | Model string, e.g. 
"openai/gpt-4o"| + | ``task`` | str|None | ``dataset`` arg | Task file path or task string | + | ``task_args`` | dict | ``{}`` | Arguments to pass to the task | + | ``max_samples`` | int|None | None | Limit samples per task | + | ``temperature`` | float|None| None | Model generation temperature | + | ``max_tokens`` | int|None | None | Max generation tokens | + | ``epochs`` | int|None | None | Number of evaluation epochs | + | ``solver`` | list|None| None | Solver pipeline override | + | ``scorer`` | list|None| None | Scorer override | + | ``model_args`` | dict | ``{}`` | Provider-specific model kwargs | + +--------------------------+----------+-----------------+-----------------------------------+ + """ + + def __init__( + self, + *, + log_dir: Optional[str] = None, + ): + self.log_dir = log_dir + + def run( + self, + harness_version: str, + library_versions: Dict[str, str], + dataset: str, + eval_parameters: Dict[str, Any], + ) -> Dict[str, Any]: + """Run an Inspect AI evaluation. + + Args: + harness_version: Version of inspect-ai being used. + library_versions: Versions of supporting libraries. + dataset: Default task string (used when ``task`` is not specified + in *eval_parameters*). + eval_parameters: See class docstring for accepted keys. + + Returns: + Dictionary mapping metric names to scores. + + Raises: + ImportError: If ``inspect-ai`` is not installed. + ValueError: If ``model`` is missing from *eval_parameters*. + RuntimeError: If the evaluation fails (log status is not "success"). + """ + try: + from inspect_ai import eval as inspect_eval + except ImportError: + raise ImportError( + "inspect-ai is required for InspectAIHarness. " + "Install it with: pip install 'inspect-ai>=0.3.0'" + ) + + # Extract required model parameter + model = eval_parameters.get("model") + if model is None: + raise ValueError( + "eval_parameters must include 'model' " + "(e.g. 'openai/gpt-4o', 'hf/meta-llama/...')." + ) + + # Task: explicit parameter or fall back to dataset + task = eval_parameters.get("task", dataset) + + # Build eval kwargs + eval_kwargs: Dict[str, Any] = {} + + task_args = eval_parameters.get("task_args", {}) + if task_args: + eval_kwargs["task_args"] = task_args + + model_args = eval_parameters.get("model_args", {}) + if model_args: + eval_kwargs["model_args"] = model_args + + for key in ("max_samples", "temperature", "max_tokens", "epochs"): + value = eval_parameters.get(key) + if value is not None: + eval_kwargs[key] = value + + if eval_parameters.get("solver") is not None: + eval_kwargs["solver"] = eval_parameters["solver"] + + if eval_parameters.get("scorer") is not None: + eval_kwargs["scorer"] = eval_parameters["scorer"] + + if self.log_dir is not None: + eval_kwargs["log_dir"] = self.log_dir + + # Run evaluation + logs = inspect_eval(task, model=model, **eval_kwargs) + + # Extract results from the first log + if not logs: + raise RuntimeError( + "Inspect AI evaluation returned no logs. " + "Check that the task and model arguments are valid." + ) + log = logs[0] + if log.status != "success": + raise RuntimeError( + f"Inspect AI evaluation failed with status: {log.status}" + ) + + return self._extract_scores(log) + + def _extract_scores(self, log: Any) -> Dict[str, Any]: + """Parse an EvalLog's results into a flat score dictionary. + + Iterates over ``log.results.scores`` (a list of ``EvalScore``), + flattening each scorer's ``metrics`` dict into a single output dict. + + Args: + log: An ``inspect_ai`` ``EvalLog`` object. 
+ + Returns: + Dictionary mapping metric names to their values. + """ + scores: Dict[str, Any] = {} + if log.results is None: + return scores + + for eval_score in log.results.scores: + for metric_name, metric in eval_score.metrics.items(): + scores[metric_name] = metric.value + + return scores diff --git a/src/openenv/core/evals/types.py b/src/openenv/core/evals/types.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6b14f762624c607c345e5dff1bc77faa5b4b56 --- /dev/null +++ b/src/openenv/core/evals/types.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Pydantic models for eval configuration and results.""" + +from typing import Any, Dict + +from pydantic import BaseModel, ConfigDict, Field + + +class EvalConfig(BaseModel): + """Configuration for running an evaluation.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + harness_name: str = Field(description="Name of the evaluation harness") + harness_version: str = Field(description="Version of the evaluation harness") + library_versions: Dict[str, str] = Field( + description="Versions of libraries used in the evaluation" + ) + dataset: str = Field(description="Name of the dataset to evaluate on") + eval_parameters: Dict[str, Any] = Field(description="Parameters for the evaluation") + + +class EvalResult(BaseModel): + """Result of running an evaluation.""" + + model_config = ConfigDict( + extra="forbid", + validate_assignment=True, + ) + + config: EvalConfig = Field(description="Configuration used for the evaluation") + scores: Dict[str, Any] = Field(description="Scores from the evaluation") diff --git a/src/openenv/core/generic_client.py b/src/openenv/core/generic_client.py new file mode 100644 index 0000000000000000000000000000000000000000..17576862293feeebf68b4a90d6a4a80de369dd34 --- /dev/null +++ b/src/openenv/core/generic_client.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Generic environment client that works with raw dictionaries. + +This module provides a GenericEnvClient that doesn't require installing +environment-specific packages. It's useful for connecting to remote servers +without running any untrusted code locally. +""" + +from typing import Any, Dict + +from .client_types import StepResult +from .env_client import EnvClient + + +class GenericEnvClient(EnvClient[Dict[str, Any], Dict[str, Any], Dict[str, Any]]): + """ + Environment client that works with raw dictionaries instead of typed classes. + + This client doesn't require installing environment-specific packages, making it + ideal for: + - Connecting to remote servers without installing their packages + - Quick prototyping and testing + - Environments where type safety isn't needed + - Security-conscious scenarios where you don't want to run remote code + + The trade-off is that you lose type safety and IDE autocomplete for actions + and observations. Instead of typed objects, you work with plain dictionaries. + + Example: + >>> # Direct connection to a running server (no installation needed) + >>> with GenericEnvClient(base_url="http://localhost:8000") as env: + ... result = env.reset() + ... 
result = env.step({"code": "print('hello')"}) + ... print(result.observation) # Dict[str, Any] + ... print(result.observation.get("output")) + + >>> # From local Docker image + >>> env = GenericEnvClient.from_docker_image("coding-env:latest") + >>> result = env.reset() + >>> result = env.step({"code": "x = 1 + 2"}) + >>> env.close() + + >>> # From HuggingFace Hub (pulls Docker image, no pip install) + >>> env = GenericEnvClient.from_env("user/my-env", use_docker=True) + >>> result = env.reset() + >>> env.close() + + Note: + GenericEnvClient inherits `from_docker_image()` and `from_env()` from + EnvClient, so you can use it with Docker containers and HuggingFace + Spaces without any package installation. + """ + + def _step_payload(self, action: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert action to payload for the server. + + For GenericEnvClient, this handles both raw dictionaries and + typed Action objects (Pydantic models). If a Pydantic model is + passed, it will be converted to a dictionary using model_dump(). + + Args: + action: Action as a dictionary or Pydantic BaseModel + + Returns: + The action as a dictionary for the server + """ + # If it's already a dict, return as-is + if isinstance(action, dict): + return action + + # If it's a Pydantic model (Action subclass), convert to dict + if hasattr(action, "model_dump"): + return action.model_dump() + + # Fallback for other objects with __dict__ + if hasattr(action, "__dict__"): + return vars(action) + + # Last resort: try to convert to dict + return dict(action) + + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Dict[str, Any]]: + """ + Parse server response into a StepResult. + + Extracts the observation, reward, and done fields from the + server response. + + Args: + payload: Response payload from the server + + Returns: + StepResult with observation as a dictionary + """ + return StepResult( + observation=payload.get("observation", {}), + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """ + Parse state response from the server. + + For GenericEnvClient, this returns the payload as-is since + we're working with dictionaries. + + Args: + payload: State payload from the server + + Returns: + The state as a dictionary + """ + return payload + + +class GenericAction(Dict[str, Any]): + """ + A dictionary subclass for creating actions when using GenericEnvClient. + + This provides a semantic wrapper around dictionaries to make code more + readable when working with GenericEnvClient. It behaves exactly like a + dict but signals intent that this is an action for an environment. + + Example: + >>> # Without GenericAction (works fine) + >>> env.step({"code": "print('hello')"}) + + >>> # With GenericAction (more explicit) + >>> action = GenericAction(code="print('hello')") + >>> env.step(action) + + >>> # With multiple fields + >>> action = GenericAction(code="x = 1", timeout=30, metadata={"tag": "test"}) + >>> env.step(action) + + Note: + GenericAction is just a dict with a constructor that accepts keyword + arguments. It's provided for symmetry with typed Action classes and + to make code more readable. + """ + + def __init__(self, **kwargs: Any) -> None: + """ + Create a GenericAction from keyword arguments. 
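+
+        The keyword arguments become the dictionary's items, so the result
+        supports all normal dict operations.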
+ + Args: + **kwargs: Action fields as keyword arguments + + Example: + >>> action = GenericAction(code="print(1)", timeout=30) + >>> action["code"] + 'print(1)' + """ + super().__init__(kwargs) + + def __repr__(self) -> str: + """Return a readable representation.""" + items = ", ".join(f"{k}={v!r}" for k, v in self.items()) + return f"GenericAction({items})" diff --git a/src/openenv/core/llm_client.py b/src/openenv/core/llm_client.py new file mode 100644 index 0000000000000000000000000000000000000000..9df2ff27ae7c2054108ff159b9dec8e4c9dd238c --- /dev/null +++ b/src/openenv/core/llm_client.py @@ -0,0 +1,506 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""LLM client abstraction for calling LLM endpoints. + +Provides a generic RPC abstraction: point it at an endpoint/port, tell it the +protocol, and it works. OpenAI-compatible API is the first implementation, +covering OpenAI, vLLM, TGI, Ollama, HuggingFace Inference API, etc. +Anthropic's native API is supported via ``AnthropicClient``. + +Usage: + client = OpenAIClient("http://localhost", 8000, model="meta-llama/...") + response = await client.complete("What is 2+2?") + + # Or use the factory for hosted APIs: + client = create_llm_client("openai", model="gpt-4", api_key="sk-...") + response = await client.complete_with_tools(messages, tools) +""" + +from __future__ import annotations + +import json +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + +from openai import AsyncOpenAI + + +@dataclass +class ToolCall: + """A single tool/function call returned by the model.""" + + id: str + name: str + args: dict[str, Any] + + +@dataclass +class LLMResponse: + """Normalized response from an LLM, with optional tool calls.""" + + content: str + tool_calls: list[ToolCall] = field(default_factory=list) + + def to_message_dict(self) -> dict[str, Any]: + """Convert to an OpenAI-format assistant message dict.""" + msg: dict[str, Any] = {"role": "assistant", "content": self.content} + if self.tool_calls: + msg["tool_calls"] = [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.name, + "arguments": json.dumps(tc.args), + }, + } + for tc in self.tool_calls + ] + return msg + + +class LLMClient(ABC): + """Abstract base for LLM endpoint clients. + + Subclass and implement ``complete()`` for your protocol. + + Args: + endpoint: The base URL of the LLM service (e.g. "http://localhost"). + port: The port the service listens on. + """ + + def __init__(self, endpoint: str, port: int): + self.endpoint = endpoint + self.port = port + + @abstractmethod + async def complete(self, prompt: str, **kwargs) -> str: + """Send a prompt, return the text response. + + Args: + prompt: The user prompt to send. + **kwargs: Override default parameters (temperature, max_tokens, etc.). + + Returns: + The model's text response. + """ + ... + + async def complete_with_tools( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]], + **kwargs: Any, + ) -> LLMResponse: + """Send messages with tool definitions, return a normalized response. + + Messages use OpenAI-format dicts (``{"role": "...", "content": "..."}``). + Tools use MCP tool definitions; they are converted internally. + + Args: + messages: Conversation history as OpenAI-format message dicts. + tools: MCP tool definitions. 
+ **kwargs: Override default parameters (temperature, max_tokens, etc.). + + Returns: + An ``LLMResponse`` with the model's text and any tool calls. + """ + raise NotImplementedError( + f"{type(self).__name__} does not support tool calling" + ) + + @property + def base_url(self) -> str: + """Construct base URL from endpoint and port.""" + return f"{self.endpoint}:{self.port}" + + +class OpenAIClient(LLMClient): + """Client for OpenAI-compatible APIs. + + Works with: OpenAI, vLLM, TGI, Ollama, HuggingFace Inference API, + or any endpoint that speaks the OpenAI chat completions format. + + Args: + endpoint: The base URL (e.g. "http://localhost"). + port: The port number. + model: Model name to pass to the API. + api_key: API key. Defaults to "not-needed" for local endpoints. + system_prompt: Optional system message prepended to every request. + temperature: Default sampling temperature. + max_tokens: Default max tokens in the response. + """ + + def __init__( + self, + endpoint: str, + port: int, + model: str, + api_key: str | None = None, + system_prompt: str | None = None, + temperature: float = 0.0, + max_tokens: int = 256, + ): + super().__init__(endpoint, port) + self.model = model + self.system_prompt = system_prompt + self.temperature = temperature + self.max_tokens = max_tokens + + self._client = AsyncOpenAI( + base_url=f"{self.base_url}/v1", + api_key=api_key if api_key is not None else "not-needed", + ) + + async def complete(self, prompt: str, **kwargs) -> str: + """Send a chat completion request. + + Args: + prompt: The user message. + **kwargs: Overrides for temperature, max_tokens. + + Returns: + The assistant's response text. + """ + messages = [] + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + messages.append({"role": "user", "content": prompt}) + + response = await self._client.chat.completions.create( + model=self.model, + messages=messages, + temperature=kwargs.get("temperature", self.temperature), + max_tokens=kwargs.get("max_tokens", self.max_tokens), + ) + return response.choices[0].message.content or "" + + async def complete_with_tools( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]], + **kwargs: Any, + ) -> LLMResponse: + create_kwargs: dict[str, Any] = { + "model": self.model, + "messages": messages, + "temperature": kwargs.get("temperature", self.temperature), + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + } + openai_tools = _mcp_tools_to_openai(tools) + if openai_tools: + create_kwargs["tools"] = openai_tools + + response = await self._client.chat.completions.create(**create_kwargs) + msg = response.choices[0].message + + tool_calls = [] + if msg.tool_calls: + for tc in msg.tool_calls: + tool_calls.append( + ToolCall( + id=tc.id, + name=tc.function.name, + args=json.loads(tc.function.arguments), + ) + ) + + return LLMResponse(content=msg.content or "", tool_calls=tool_calls) + + +class AnthropicClient(LLMClient): + """Client for Anthropic's Messages API. + + Requires the ``anthropic`` package (lazy-imported at construction time). + + Args: + endpoint: The base URL (e.g. "https://api.anthropic.com"). + port: The port number. + model: Model name (e.g. "claude-sonnet-4-20250514"). + api_key: Anthropic API key. + system_prompt: Optional system message prepended to every request. + temperature: Default sampling temperature. + max_tokens: Default max tokens in the response. 
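+
+    Example (illustrative; the model name and key are placeholders):
+
+        >>> client = AnthropicClient(
+        ...     "https://api.anthropic.com", 443,
+        ...     model="claude-sonnet-4-20250514", api_key="sk-ant-...",
+        ... )
+        >>> # text = await client.complete("What is 2+2?")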
+ """ + + def __init__( + self, + endpoint: str, + port: int, + model: str, + api_key: str | None = None, + system_prompt: str | None = None, + temperature: float = 0.0, + max_tokens: int = 256, + ): + super().__init__(endpoint, port) + self.model = model + self.system_prompt = system_prompt + self.temperature = temperature + self.max_tokens = max_tokens + + try: + from anthropic import AsyncAnthropic + except ImportError as exc: + raise ImportError( + "AnthropicClient requires the 'anthropic' package. " + "Install it with: pip install anthropic" + ) from exc + + self._client = AsyncAnthropic( + base_url=self.base_url, + api_key=api_key if api_key is not None else "not-needed", + ) + + async def complete(self, prompt: str, **kwargs) -> str: + create_kwargs: dict[str, Any] = { + "model": self.model, + "messages": [{"role": "user", "content": prompt}], + "temperature": kwargs.get("temperature", self.temperature), + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + } + if self.system_prompt: + create_kwargs["system"] = self.system_prompt + + response = await self._client.messages.create(**create_kwargs) + return "".join(block.text for block in response.content if block.type == "text") + + async def complete_with_tools( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]], + **kwargs: Any, + ) -> LLMResponse: + system, anthropic_msgs = _openai_msgs_to_anthropic(messages) + + create_kwargs: dict[str, Any] = { + "model": self.model, + "messages": anthropic_msgs, + "temperature": kwargs.get("temperature", self.temperature), + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + } + system_text = system or self.system_prompt + if system_text: + create_kwargs["system"] = system_text + anthropic_tools = _mcp_tools_to_anthropic(tools) + if anthropic_tools: + create_kwargs["tools"] = anthropic_tools + + response = await self._client.messages.create(**create_kwargs) + + content = "" + tool_calls = [] + for block in response.content: + if block.type == "text": + content += block.text + elif block.type == "tool_use": + tool_calls.append( + ToolCall(id=block.id, name=block.name, args=block.input) + ) + + return LLMResponse(content=content, tool_calls=tool_calls) + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + +_HOSTED_PROVIDERS: dict[str, tuple[str, int, type[LLMClient]]] = { + "openai": ("https://api.openai.com", 443, OpenAIClient), + "anthropic": ("https://api.anthropic.com", 443, AnthropicClient), +} + + +def create_llm_client( + provider: str, + model: str, + api_key: str, + *, + system_prompt: str | None = None, + temperature: float = 0.0, + max_tokens: int = 4096, +) -> LLMClient: + """Create an LLM client for a hosted provider. + + Args: + provider: Provider name ("openai" or "anthropic"). + model: Model identifier. + api_key: API key for the provider. + system_prompt: Optional system message prepended to every request. + temperature: Sampling temperature. + max_tokens: Maximum tokens in the response. + + Returns: + A configured ``LLMClient`` instance. + """ + key = provider.lower() + if key not in _HOSTED_PROVIDERS: + raise ValueError( + f"Unsupported provider: {provider!r}. 
" + f"Supported: {sorted(_HOSTED_PROVIDERS)}" + ) + endpoint, port, cls = _HOSTED_PROVIDERS[key] + return cls( + endpoint, + port, + model, + api_key=api_key, + system_prompt=system_prompt, + temperature=temperature, + max_tokens=max_tokens, + ) + + +# --------------------------------------------------------------------------- +# MCP tool-schema helpers +# --------------------------------------------------------------------------- + + +def _clean_mcp_schema(schema: dict[str, Any]) -> dict[str, Any]: + """Normalize an MCP tool ``inputSchema`` for LLM function-calling APIs.""" + if not isinstance(schema, dict): + return {"type": "object", "properties": {}, "required": []} + + # Shallow copy to avoid mutating the caller's schema dict. + schema = dict(schema) + + if "oneOf" in schema: + for option in schema["oneOf"]: + if isinstance(option, dict) and option.get("type") == "object": + schema = option + break + else: + return {"type": "object", "properties": {}, "required": []} + + if "allOf" in schema: + merged: dict[str, Any] = {"type": "object", "properties": {}, "required": []} + for sub in schema["allOf"]: + if isinstance(sub, dict): + if "properties" in sub: + merged["properties"].update(sub["properties"]) + if "required" in sub: + merged["required"].extend(sub["required"]) + schema = merged + + if "anyOf" in schema: + for option in schema["anyOf"]: + if isinstance(option, dict) and option.get("type") == "object": + schema = option + break + else: + return {"type": "object", "properties": {}, "required": []} + + schema.setdefault("type", "object") + if schema.get("type") == "object" and "properties" not in schema: + schema["properties"] = {} + return schema + + +def _mcp_tools_to_openai( + mcp_tools: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Convert MCP tool definitions to OpenAI function-calling format.""" + result = [] + for tool in mcp_tools: + input_schema = tool.get( + "inputSchema", {"type": "object", "properties": {}, "required": []} + ) + result.append( + { + "type": "function", + "function": { + "name": tool["name"], + "description": tool.get("description", ""), + "parameters": _clean_mcp_schema(input_schema), + }, + } + ) + return result + + +def _mcp_tools_to_anthropic( + mcp_tools: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Convert MCP tool definitions to Anthropic tool format.""" + result = [] + for tool in mcp_tools: + input_schema = tool.get( + "inputSchema", {"type": "object", "properties": {}, "required": []} + ) + result.append( + { + "name": tool["name"], + "description": tool.get("description", ""), + "input_schema": _clean_mcp_schema(input_schema), + } + ) + return result + + +def _openai_msgs_to_anthropic( + messages: list[dict[str, Any]], +) -> tuple[str, list[dict[str, Any]]]: + """Convert OpenAI-format messages to Anthropic format. + + Returns ``(system_text, anthropic_messages)``. System-role messages are + extracted and concatenated; tool-result messages are converted to + Anthropic's ``tool_result`` content blocks inside user turns. 
+ """ + system_parts: list[str] = [] + anthropic_msgs: list[dict[str, Any]] = [] + + for msg in messages: + role = msg["role"] + + if role == "system": + system_parts.append(msg["content"]) + + elif role == "user": + anthropic_msgs.append({"role": "user", "content": msg["content"]}) + + elif role == "assistant": + if msg.get("tool_calls"): + content: list[dict[str, Any]] = [] + if msg.get("content"): + content.append({"type": "text", "text": msg["content"]}) + for tc in msg["tool_calls"]: + args = tc["function"]["arguments"] + if isinstance(args, str): + args = json.loads(args) + content.append( + { + "type": "tool_use", + "id": tc["id"], + "name": tc["function"]["name"], + "input": args, + } + ) + anthropic_msgs.append({"role": "assistant", "content": content}) + else: + anthropic_msgs.append( + {"role": "assistant", "content": msg.get("content", "")} + ) + + elif role == "tool": + tool_result = { + "type": "tool_result", + "tool_use_id": msg["tool_call_id"], + "content": msg["content"], + } + # Anthropic requires tool results in user turns; merge if possible. + if ( + anthropic_msgs + and anthropic_msgs[-1]["role"] == "user" + and isinstance(anthropic_msgs[-1]["content"], list) + ): + anthropic_msgs[-1]["content"].append(tool_result) + else: + anthropic_msgs.append({"role": "user", "content": [tool_result]}) + + system = "\n\n".join(system_parts) + return system, anthropic_msgs diff --git a/src/openenv/core/mcp_client.py b/src/openenv/core/mcp_client.py new file mode 100644 index 0000000000000000000000000000000000000000..edac3529d3a34e798781d86cf4d2495dc9611713 --- /dev/null +++ b/src/openenv/core/mcp_client.py @@ -0,0 +1,377 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +MCP Client classes for tool-calling environments. + +This module provides async client classes for interacting with MCP-enabled environments: +- MCPClientBase: Base class with shared tool discovery +- MCPToolClient: Client for tool-calling style (one tool per step) + +These clients abstract away the MCP protocol details, providing a clean interface +for listing and calling tools on remote environments. All clients are async by default. + +Architecture Overview:: + + ┌─────────────────────────────────────────────────────────┐ + │ HTTPEnvServer │ + ├─────────────────────────────────────────────────────────┤ + │ Simulation Mode (default): │ + │ /ws → OpenEnv protocol (reset/step/state) │ + │ /mcp → MCP JSON-RPC (tools/list, tools/call) │ + │ /reset, /step, /state → HTTP endpoints │ + ├─────────────────────────────────────────────────────────┤ + │ Production Mode (use_production_mode=True): │ + │ /mcp → MCP JSON-RPC (tools/list, tools/call) │ + │ Bypasses step() for direct tool access │ + └─────────────────────────────────────────────────────────┘ + + Client Usage: + MCPToolClient (default) → /ws (step-based, with rewards) + MCPToolClient (production) → /mcp (direct tool access, no rewards) + +Example (async): + >>> from openenv.core.mcp_client import MCPToolClient + >>> + >>> async with MCPToolClient(base_url="http://localhost:8000") as env: + ... # Discover available tools + ... tools = await env.list_tools() + ... print([t.name for t in tools]) + ... + ... # Call a tool + ... result = await env.call_tool("echo_message", message="Hello!") + ... 
print(result) + +Example (sync wrapper): + >>> env = MCPToolClient(base_url="http://localhost:8000").sync() + >>> with env: + ... tools = env.list_tools() + ... result = env.call_tool("echo_message", message="Hello!") +""" + +from typing import Any, Dict, List, Optional + +from .client_types import StepResult +from .env_client import EnvClient +from .env_server.mcp_types import ( + CallToolAction, + CallToolObservation, + ListToolsAction, + ListToolsObservation, + Tool, + ToolError, +) +from .env_server.types import Observation, State + + +class MCPClientBase(EnvClient[Any, Observation, State]): + """ + Base class for MCP clients with tool discovery. + + This class provides the common `list_tools()` method for discovering + available tools from an MCP-enabled environment. Subclasses implement + specific interaction patterns (tool-calling or CodeAct). + + Attributes: + _tools_cache: Cached list of tools (populated on first `list_tools()` call) + """ + + def __init__( + self, + base_url: str, + connect_timeout_s: float = 10.0, + message_timeout_s: float = 60.0, + provider: Optional[Any] = None, + mode: Optional[str] = None, + ): + """ + Initialize MCP client. + + Args: + base_url: Base URL of the environment server (http:// or ws://). + connect_timeout_s: Timeout for establishing WebSocket connection. + message_timeout_s: Timeout for receiving responses to messages. + provider: Optional container/runtime provider for lifecycle management. + mode: Communication mode. Must be 'production' for MCP clients. Defaults to 'production'. + """ + # MCPClientBase defaults to production mode, but allow override for validation + if mode is None: + mode = "production" + + # Validate that mode is production + mode_lower = mode.lower() + if mode_lower != "production": + raise ValueError( + f"MCPToolClient only supports 'production' mode, got '{mode}'. " + f"Use GenericEnvClient for simulation mode." + ) + + super().__init__( + base_url=base_url, + connect_timeout_s=connect_timeout_s, + message_timeout_s=message_timeout_s, + provider=provider, + mode=mode, + ) + self._tools_cache: Optional[List[Tool]] = None + self.use_production_mode = False + + async def list_tools(self, use_cache: bool = True) -> List[Tool]: + """ + Discover available tools from the environment. + + Args: + use_cache: If True, return cached tools if available. + Set to False to force a fresh request. + + Returns: + List of Tool objects with name, description, and input_schema. + + Example: + >>> tools = await env.list_tools() + >>> for tool in tools: + ... 
print(f"{tool.name}: {tool.description}") + """ + if use_cache and self._tools_cache is not None: + return self._tools_cache + + # Use production mode HTTP endpoint if enabled + if self.use_production_mode: + import requests + + # Convert ws:// URL to http:// URL + url = self._ws_url.replace("ws://", "http://").replace("wss://", "https://") + # Remove /ws suffix if present and add /mcp + url = url.rstrip("/ws").rstrip("/") + "/mcp" + + try: + response = requests.post( + url, + json={ + "jsonrpc": "2.0", + "method": "tools/list", + "params": {}, + "id": 1, + }, + ) + data = response.json() + if "result" in data and "tools" in data["result"]: + tools = [ + Tool( + name=t.get("name", ""), + description=t.get("description", ""), + input_schema=t.get( + "input_schema", t.get("inputSchema", {}) + ), + ) + for t in data["result"]["tools"] + ] + self._tools_cache = tools + return tools + except Exception: + # If HTTP request fails, return empty list + pass + return [] + + result = await self.step(ListToolsAction()) + self._tools_cache = result.observation.tools + return self._tools_cache + + def _step_payload(self, action: Any) -> Dict[str, Any]: + """Convert an Action object to the JSON data expected by the env server.""" + if isinstance(action, ListToolsAction): + return {"type": "list_tools"} + elif isinstance(action, CallToolAction): + return { + "type": "call_tool", + "tool_name": action.tool_name, + "arguments": action.arguments, + } + else: + # For unknown actions, try to serialize as dict + if hasattr(action, "model_dump"): + return action.model_dump() + return {"action": str(action)} + + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Observation]: + """Convert a JSON response from the env server to StepResult[Observation].""" + obs_data = payload.get("observation", {}) + + # Check if this is a ListToolsObservation + if "tools" in obs_data: + tools = [ + Tool( + name=t.get("name", ""), + description=t.get("description", ""), + input_schema=t.get("input_schema", t.get("inputSchema", {})), + ) + for t in obs_data.get("tools", []) + ] + observation = ListToolsObservation( + tools=tools, + done=payload.get("done", False), + reward=payload.get("reward"), + metadata=obs_data.get("metadata", {}), + ) + # Check if this is a CallToolObservation + elif "tool_name" in obs_data: + error = None + if obs_data.get("error"): + error = ToolError(**obs_data["error"]) + + observation = CallToolObservation( + tool_name=obs_data.get("tool_name", ""), + result=obs_data.get("result"), + error=error, + done=payload.get("done", False), + reward=payload.get("reward"), + metadata=obs_data.get("metadata", {}), + ) + else: + # Generic observation + observation = Observation( + done=payload.get("done", False), + reward=payload.get("reward"), + metadata=obs_data.get("metadata", {}), + ) + + return StepResult( + observation=observation, + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Dict[str, Any]) -> State: + """Convert a JSON response from the state endpoint to a State object.""" + return State( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + ) + + +class MCPToolClient(MCPClientBase): + """ + Async client for tool-calling style MCP interactions. + + Each step invokes a single tool. Use this for traditional function-calling + agent patterns where the agent decides which tool to call next. 
+ + This client provides convenience methods for tool discovery and invocation: + - `list_tools()`: Get all available tools with their schemas + - `call_tool(name, **kwargs)`: Invoke a tool by name with arguments + + Example (async): + >>> async with MCPToolClient(base_url="http://localhost:8000") as env: + ... # Reset the environment + ... await env.reset() + ... + ... # Discover available tools + ... tools = await env.list_tools() + ... print([t.name for t in tools]) # ['echo_message', 'echo_with_length'] + ... + ... # Call a tool directly + ... result = await env.call_tool("echo_message", message="Hello!") + ... print(result) # "Hello!" + ... + ... # Or use the full action interface + ... from openenv.core.env_server.mcp_types import CallToolAction + ... step_result = await env.step(CallToolAction( + ... tool_name="echo_with_length", + ... arguments={"message": "Test"} + ... )) + ... print(step_result.observation.result) + + Example (sync wrapper): + >>> env = MCPToolClient(base_url="http://localhost:8000").sync() + >>> with env: + ... tools = env.list_tools() + ... result = env.call_tool("echo_message", message="Hello!") + """ + + async def call_tool(self, name: str, **kwargs: Any) -> Any: + """ + Call a tool by name. + + This is a convenience method that creates a CallToolAction, executes it, + and returns the result directly. For more control, use `step()` with + a CallToolAction directly. + + Args: + name: Name of the tool to invoke (must match a tool from `list_tools()`). + **kwargs: Arguments to pass to the tool. Must match the tool's input_schema. + + Returns: + The tool's result. The type depends on the tool being called. + + Raises: + RuntimeError: If the server returns an error response. + + Example: + >>> result = await env.call_tool("add", a=5, b=3) + >>> print(result) # 8 + >>> + >>> result = await env.call_tool("greet", name="Claude") + >>> print(result) # "Hello, Claude!" + """ + action = CallToolAction(tool_name=name, arguments=kwargs) + result = await self.step(action) + obs = result.observation + + # Check for transport/framework errors + if isinstance(obs, CallToolObservation) and obs.error is not None: + raise RuntimeError( + f"Tool '{name}' failed: {obs.error.message} " + f"(type: {obs.error.error_type.value})" + ) + + # Return the result + if isinstance(obs, CallToolObservation): + result = obs.result + # Handle FastMCP CallToolResult objects + # - As object: has .data attribute + # - As dict (from JSON): has "data" key + if hasattr(result, "data"): + return result.data + if isinstance(result, dict) and "data" in result: + return result["data"] + return result + + # Fallback for unexpected observation types + return obs + + async def get_tool(self, name: str) -> Optional[Tool]: + """ + Get a specific tool by name. + + Args: + name: Name of the tool to find. + + Returns: + The Tool object if found, None otherwise. + + Example: + >>> tool = await env.get_tool("echo_message") + >>> if tool: + ... print(tool.description) + ... print(tool.input_schema) + """ + tools = await self.list_tools() + for tool in tools: + if tool.name == name: + return tool + return None + + async def has_tool(self, name: str) -> bool: + """ + Check if a tool exists. + + Args: + name: Name of the tool to check. + + Returns: + True if the tool exists, False otherwise. 
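+
+        Example (illustrative; assumes the server exposes ``echo_message``):
+            >>> if await env.has_tool("echo_message"):
+            ...     print(await env.call_tool("echo_message", message="hi"))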
+ """ + return await self.get_tool(name) is not None diff --git a/src/openenv/core/rubrics/__init__.py b/src/openenv/core/rubrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abe368494b70cfbabb04e86cb2277aa8c838bdf7 --- /dev/null +++ b/src/openenv/core/rubrics/__init__.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Rubrics for reward computation. + +See RFC 004 for full design: rfcs/004-rubrics.md +""" + +from openenv.core.rubrics.base import Rubric +from openenv.core.rubrics.containers import ( + Gate, + RubricDict, + RubricList, + Sequential, + WeightedSum, +) +from openenv.core.rubrics.llm_judge import LLMJudge +from openenv.core.rubrics.trajectory import ( + ExponentialDiscountingTrajectoryRubric, + TrajectoryRubric, +) + +__all__ = [ + # Base + "Rubric", + # Containers + "Sequential", + "Gate", + "WeightedSum", + "RubricList", + "RubricDict", + # Trajectory + "TrajectoryRubric", + "ExponentialDiscountingTrajectoryRubric", + # LLM Judge + "LLMJudge", +] diff --git a/src/openenv/core/rubrics/base.py b/src/openenv/core/rubrics/base.py new file mode 100644 index 0000000000000000000000000000000000000000..38c7a381bc4f40a7bc1dac832902e9e6ac93a282 --- /dev/null +++ b/src/openenv/core/rubrics/base.py @@ -0,0 +1,195 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Base Rubric class for reward computation. + +Rubrics compute rewards from actions and observations. The API is modeled +after PyTorch's nn.Module: users implement forward(), and the framework +handles child registration and hooks. + +See RFC 004 for full design: rfcs/004-rubrics.md +""" + +import inspect +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple + + +class Rubric(ABC): + """Abstract base class for reward computation. + + A Rubric computes a reward signal from an action and observation. + Subclasses implement forward() to define the reward logic. + + Usage: + class MyRubric(Rubric): + def forward(self, action, observation) -> float: + return 1.0 if action.valid else 0.0 + + rubric = MyRubric() + reward = rubric(action, observation) + + Child rubrics are auto-registered when assigned as attributes, + enabling hierarchical composition and introspection. + """ + + _rubric_children: Dict[str, "Rubric"] + _forward_hooks: List[Callable] + _forward_pre_hooks: List[Callable] + last_score: Optional[float] + + def __init__(self): + # Use object.__setattr__ to avoid triggering __setattr__ during init + object.__setattr__(self, "_rubric_children", {}) + object.__setattr__(self, "_forward_hooks", []) + object.__setattr__(self, "_forward_pre_hooks", []) + object.__setattr__(self, "last_score", None) + + def __setattr__(self, name: str, value: Any) -> None: + # Auto-register child rubrics when assigned as attributes + if isinstance(value, Rubric): + self._rubric_children[name] = value + object.__setattr__(self, name, value) + + def __call__(self, action: Any, observation: Any): + """Evaluate the rubric with hooks. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value (typically 0.0 to 1.0). 
+ """ + # Check if forward method is async BEFORE calling it + if inspect.iscoroutinefunction(self.forward): + # Async path - pre-hooks will be called in _call_async + result = self.forward(action, observation) + return self._call_async(action, observation, result) + else: + # Sync path - call pre-hooks BEFORE forward() + for hook in self._forward_pre_hooks: + hook(self, action, observation) + result = self.forward(action, observation) + return self._call_sync(action, observation, result) + + def _call_sync(self, action: Any, observation: Any, result: float) -> float: + """Synchronous call path.""" + self.last_score = result + + # Post-forward hooks + for hook in self._forward_hooks: + hook(self, action, observation, result) + + return result + + async def _call_async(self, action: Any, observation: Any, result_coro) -> float: + """Asynchronous call path.""" + # Pre-forward hooks + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + # Await the forward result + result = await result_coro + self.last_score = result + + # Post-forward hooks + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + + return result + + @abstractmethod + def forward(self, action: Any, observation: Any) -> float: + """Compute the reward. Implement this in subclasses. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Reward value (typically 0.0 to 1.0). + """ + raise NotImplementedError + + def register_forward_hook( + self, hook: Callable[["Rubric", Any, Any, float], None] + ) -> None: + """Register a hook called after forward(). + + Args: + hook: Callable with signature (rubric, action, observation, result). + """ + self._forward_hooks.append(hook) + + def register_forward_pre_hook( + self, hook: Callable[["Rubric", Any, Any], None] + ) -> None: + """Register a hook called before forward(). + + Args: + hook: Callable with signature (rubric, action, observation). + """ + self._forward_pre_hooks.append(hook) + + def children(self) -> Iterator["Rubric"]: + """Iterate over immediate child rubrics.""" + yield from self._rubric_children.values() + + def named_children(self) -> Iterator[Tuple[str, "Rubric"]]: + """Iterate over immediate child rubrics with names.""" + yield from self._rubric_children.items() + + def rubrics(self) -> Iterator["Rubric"]: + """Iterate over all descendant rubrics (depth-first).""" + for child in self._rubric_children.values(): + yield child + yield from child.rubrics() + + def named_rubrics(self, prefix: str = "") -> Iterator[Tuple[str, "Rubric"]]: + """Iterate over all descendant rubrics with dot-separated names.""" + for name, child in self._rubric_children.items(): + full_name = f"{prefix}.{name}" if prefix else name + yield full_name, child + yield from child.named_rubrics(full_name) + + def get_rubric(self, path: str) -> "Rubric": + """Access a nested rubric by dot-separated path. + + Args: + path: Dot-separated path (e.g., "code.syntax"). + + Returns: + The rubric at the specified path. + + Raises: + KeyError: If the path does not exist. 
+ """ + parts = path.split(".") + current = self + for part in parts: + if part not in current._rubric_children: + raise KeyError(f"Rubric path not found: {path}") + current = current._rubric_children[part] + return current + + def reset(self) -> None: + """Reset any internal state. Override in subclasses if needed.""" + pass + + def state_dict(self) -> Dict[str, Any]: + """Serialize rubric configuration for checkpointing.""" + return {} + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load rubric configuration from checkpoint.""" + pass diff --git a/src/openenv/core/rubrics/containers.py b/src/openenv/core/rubrics/containers.py new file mode 100644 index 0000000000000000000000000000000000000000..7a587ee7885efdf71b03d644b54524f1855474d9 --- /dev/null +++ b/src/openenv/core/rubrics/containers.py @@ -0,0 +1,574 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Container rubrics for composing reward computations. + +These containers provide common aggregation patterns for rubrics, +similar to how PyTorch provides nn.Sequential alongside nn.Module. + +See RFC 004 for full design: rfcs/004-rubrics.md +""" + +import asyncio +import inspect +from typing import Any, Dict, Iterator, List, Mapping, Tuple, Union + +from openenv.core.rubrics.base import Rubric + + +def _in_async_context() -> bool: + """Check if we're currently in an async context.""" + try: + asyncio.get_running_loop() + return True + except RuntimeError: + return False + + +class Sequential(Rubric): + """Run rubrics in order, fail-fast on zero. + + Runs child rubrics in order. If any returns 0, stops immediately + and returns 0. This implements hierarchical gating patterns where + syntax checks run before execution checks. + + Usage: + rubric = Sequential( + Gate(Compiles()), + Gate(PassesTests(), threshold=0.5), + WeightedSum([PassesTests(), StyleRubric()], weights=[0.7, 0.3]) + ) + """ + + def __init__(self, *rubrics: Rubric): + """Initialize with rubrics to run in sequence. + + Args: + *rubrics: Rubrics to run in order. Stops and returns 0 if any + child returns 0. + """ + super().__init__() + for i, rubric in enumerate(rubrics): + setattr(self, f"rubric_{i}", rubric) + self._rubric_list = list(rubrics) + + def forward(self, action: Any, observation: Any) -> float: + """Run rubrics in order, return 0 if any returns 0. 
Sync version.""" + result = 1.0 + for rubric in self._rubric_list: + score = rubric(action, observation) + if score == 0.0: + return 0.0 + result = score + return result + + def __call__(self, action: Any, observation: Any): + """Override to choose sync or async path based on children.""" + # Empty case - check if in async context + if not self._rubric_list: + if _in_async_context(): + return self._empty_async(action, observation) + else: + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + result = 1.0 + self.last_score = result + for hook in self._forward_hooks: + hook(self, action, observation, result) + return result + + # Call first rubric to see if it's async + first_result = self._rubric_list[0](action, observation) + if inspect.iscoroutine(first_result): + # At least one child is async, use async path + return self._call_async_detected(action, observation, first_result) + else: + # Continue with sync path + if first_result == 0.0: + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + self.last_score = 0.0 + for hook in self._forward_hooks: + hook(self, action, observation, 0.0) + return 0.0 + + final_result = first_result + for i, rubric in enumerate(self._rubric_list[1:], start=1): + score = rubric(action, observation) + if inspect.iscoroutine(score): + # Found async mid-way, switch to async + # We already called rubric at index i, so pass the coroutine and remaining rubrics + return self._call_async_mid( + action, + observation, + final_result, + score, + self._rubric_list[i + 1 :], + ) + if score == 0.0: + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + self.last_score = 0.0 + for hook in self._forward_hooks: + hook(self, action, observation, 0.0) + return 0.0 + final_result = score + + # All sync - check if in async context + if _in_async_context(): + return self._wrap_sync_result(action, observation, final_result) + else: + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + self.last_score = final_result + for hook in self._forward_hooks: + hook(self, action, observation, final_result) + return final_result + + async def _empty_async(self, action, observation): + """Async path for empty sequential.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + result = 1.0 + self.last_score = result + + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + async def _wrap_sync_result(self, action, observation, result): + """Wrap sync result for async context.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + self.last_score = result + + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + async def _call_async_detected(self, action, observation, first_coro): + """Async path when first child is async.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + result = await first_coro + if result == 0.0: + self.last_score = 0.0 + for hook 
in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return 0.0 + + for rubric in self._rubric_list[1:]: + score = rubric(action, observation) + if inspect.iscoroutine(score): + score = await score + if score == 0.0: + self.last_score = 0.0 + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, 0.0) + else: + hook(self, action, observation, 0.0) + return 0.0 + result = score + + self.last_score = result + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + async def _call_async_mid( + self, action, observation, current_result, first_async_coro, remaining + ): + """Async path when async detected mid-execution.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + # Await the first async rubric (already called) + result = await first_async_coro + if result == 0.0: + self.last_score = 0.0 + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, 0.0) + else: + hook(self, action, observation, 0.0) + return 0.0 + + # Continue with remaining rubrics + for rubric in remaining: + score = rubric(action, observation) + if inspect.iscoroutine(score): + score = await score + if score == 0.0: + self.last_score = 0.0 + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, 0.0) + else: + hook(self, action, observation, 0.0) + return 0.0 + result = score + + self.last_score = result + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + def __len__(self) -> int: + return len(self._rubric_list) + + def __getitem__(self, index: int) -> Rubric: + return self._rubric_list[index] + + +class Gate(Rubric): + """Threshold wrapper - returns 0 if child score is below threshold. + + Useful for hard constraints like "must pass 50% of tests". + + Usage: + rubric = Gate(PassesTests(), threshold=0.5) + # Returns PassesTests() score if >= 0.5, else 0.0 + """ + + def __init__(self, rubric: Rubric, threshold: float = 1.0): + """Initialize with a rubric and threshold. + + Args: + rubric: The rubric to gate. + threshold: Minimum score required. If child returns less than + this, Gate returns 0. Default is 1.0 (must pass completely). + """ + super().__init__() + self.rubric = rubric + self.threshold = threshold + + def forward(self, action: Any, observation: Any) -> float: + """Return child score if >= threshold, else 0. 
Sync version.""" + score = self.rubric(action, observation) + if score < self.threshold: + return 0.0 + return score + + def __call__(self, action: Any, observation: Any): + """Override to handle async child.""" + # Call child + score = self.rubric(action, observation) + + if inspect.iscoroutine(score): + # Child is async + return self._call_async(action, observation, score) + else: + # Child is sync + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + result = 0.0 if score < self.threshold else score + self.last_score = result + for hook in self._forward_hooks: + hook(self, action, observation, result) + return result + + async def _call_async(self, action, observation, score_coro): + """Async path.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + score = await score_coro + result = 0.0 if score < self.threshold else score + self.last_score = result + + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, result) + else: + hook(self, action, observation, result) + return result + + +class WeightedSum(Rubric): + """Weighted combination of child rubrics. + + Standard aggregation pattern for multi-criteria evaluation. + + Usage: + rubric = WeightedSum( + [PassesTests(), StyleRubric()], + weights=[0.7, 0.3] + ) + """ + + def __init__(self, rubrics: List[Rubric], weights: List[float]): + """Initialize with rubrics and weights. + + Args: + rubrics: List of rubrics to combine. + weights: Weight for each rubric. Must sum to 1.0. + + Raises: + ValueError: If lengths don't match or weights don't sum to 1.0. + """ + super().__init__() + if len(rubrics) != len(weights): + raise ValueError( + f"Number of rubrics ({len(rubrics)}) must match " + f"number of weights ({len(weights)})" + ) + if abs(sum(weights) - 1.0) > 1e-6: + raise ValueError(f"Weights must sum to 1.0, got {sum(weights)}") + + for i, rubric in enumerate(rubrics): + setattr(self, f"rubric_{i}", rubric) + self._rubric_list = list(rubrics) + self._weights = list(weights) + + def forward(self, action: Any, observation: Any) -> float: + """Return weighted sum of child scores. 
Sync version.""" + total = 0.0 + for rubric, weight in zip(self._rubric_list, self._weights): + score = rubric(action, observation) + total += score * weight + return total + + def __call__(self, action: Any, observation: Any): + """Override to handle async children with parallel execution.""" + # Call all rubrics + results = [rubric(action, observation) for rubric in self._rubric_list] + + # Check if any are async + has_async = any(inspect.iscoroutine(r) for r in results) + + if has_async: + # Use async path + return self._call_async(action, observation, results) + else: + # Sync path + # Pre-hooks + for hook in self._forward_pre_hooks: + hook(self, action, observation) + total = 0.0 + for score, weight in zip(results, self._weights): + total += score * weight + self.last_score = total + for hook in self._forward_hooks: + hook(self, action, observation, total) + return total + + async def _call_async(self, action, observation, results): + """Async path with parallel execution.""" + for hook in self._forward_pre_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation) + else: + hook(self, action, observation) + + # Separate sync and async results + async_tasks = [] + async_indices = [] + scores = [None] * len(results) + + for i, result in enumerate(results): + if inspect.iscoroutine(result): + async_tasks.append(result) + async_indices.append(i) + else: + scores[i] = result + + # Await all async tasks in parallel + if async_tasks: + async_scores = await asyncio.gather(*async_tasks) + for i, score in zip(async_indices, async_scores): + scores[i] = score + + # Compute weighted sum + total = 0.0 + for score, weight in zip(scores, self._weights): + total += score * weight + + self.last_score = total + + for hook in self._forward_hooks: + if inspect.iscoroutinefunction(hook): + await hook(self, action, observation, total) + else: + hook(self, action, observation, total) + return total + + @property + def weights(self) -> List[float]: + """Get the weights (read-only copy).""" + return list(self._weights) + + +class RubricList(Rubric): + """Container for dynamic lists of rubrics. + + Analogous to nn.ModuleList. Does not define aggregation - use within + a parent rubric that implements custom logic. + + Usage: + class MultiGameRubric(Rubric): + def __init__(self, games: List[str]): + super().__init__() + self.games = RubricList([GameRubric(g) for g in games]) + + def forward(self, action, obs) -> float: + return self.games[obs.game_index](action, obs) + """ + + def __init__(self, rubrics: List[Rubric] = None): + """Initialize with optional list of rubrics. + + Args: + rubrics: Optional list of rubrics to start with. + """ + super().__init__() + self._rubrics: List[Rubric] = [] + if rubrics is not None: + for i, rubric in enumerate(rubrics): + self.append(rubric) + + def forward(self, action: Any, observation: Any) -> float: + """RubricList does not define aggregation - override in parent.""" + raise NotImplementedError( + "RubricList.forward() is not implemented. " + "Use RubricList within a parent rubric that defines aggregation." 
+ ) + + def append(self, rubric: Rubric) -> None: + """Add a rubric to the list.""" + index = len(self._rubrics) + setattr(self, f"rubric_{index}", rubric) + self._rubrics.append(rubric) + + def extend(self, rubrics: List[Rubric]) -> None: + """Add multiple rubrics to the list.""" + for rubric in rubrics: + self.append(rubric) + + def __len__(self) -> int: + return len(self._rubrics) + + def __getitem__(self, index: int) -> Rubric: + return self._rubrics[index] + + def __iter__(self) -> Iterator[Rubric]: + return iter(self._rubrics) + + +class RubricDict(Rubric): + """Container for named rubrics with keyed access. + + Analogous to nn.ModuleDict. Enables keyed access for multi-task + environments where different tasks require different rubrics. + + Usage: + class AtariRubric(Rubric): + def __init__(self): + super().__init__() + self.games = RubricDict({ + "pong": PongRubric(), + "breakout": BreakoutRubric(), + "space_invaders": SpaceInvadersRubric(), + }) + + def forward(self, action, obs) -> float: + return self.games[obs.game_id](action, obs) + + # Access: env.rubric.games["pong"] + """ + + def __init__(self, rubrics: Dict[str, Rubric] = None): + """Initialize with optional dictionary of rubrics. + + Args: + rubrics: Optional dictionary mapping names to rubrics. + """ + super().__init__() + self._rubric_dict: Dict[str, Rubric] = {} + if rubrics is not None: + for name, rubric in rubrics.items(): + self[name] = rubric + + def forward(self, action: Any, observation: Any) -> float: + """RubricDict does not define aggregation - override in parent.""" + raise NotImplementedError( + "RubricDict.forward() is not implemented. " + "Use RubricDict within a parent rubric that defines aggregation." + ) + + def __setitem__(self, key: str, rubric: Rubric) -> None: + """Add a rubric with the given key.""" + setattr(self, key, rubric) + self._rubric_dict[key] = rubric + + def __getitem__(self, key: str) -> Rubric: + """Get rubric by key.""" + return self._rubric_dict[key] + + def __contains__(self, key: str) -> bool: + """Check if key exists.""" + return key in self._rubric_dict + + def __len__(self) -> int: + return len(self._rubric_dict) + + def __iter__(self) -> Iterator[str]: + return iter(self._rubric_dict) + + def keys(self) -> Iterator[str]: + """Iterate over keys.""" + return iter(self._rubric_dict.keys()) + + def values(self) -> Iterator[Rubric]: + """Iterate over rubrics.""" + return iter(self._rubric_dict.values()) + + def items(self) -> Iterator[Tuple[str, Rubric]]: + """Iterate over (key, rubric) pairs.""" + return iter(self._rubric_dict.items()) + + def update(self, rubrics: Union[Dict[str, Rubric], Mapping[str, Rubric]]) -> None: + """Update with rubrics from a dictionary.""" + for name, rubric in rubrics.items(): + self[name] = rubric diff --git a/src/openenv/core/rubrics/llm_judge.py b/src/openenv/core/rubrics/llm_judge.py new file mode 100644 index 0000000000000000000000000000000000000000..4963956eb4a51270c03809f9f0e14f1c66b91958 --- /dev/null +++ b/src/openenv/core/rubrics/llm_judge.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""LLM-as-a-judge rubric for reward computation. + +Uses an LLM endpoint (via LLMClient) to evaluate agent actions/observations. 
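+
+The judge's ``forward`` is a coroutine, so scores must be awaited; the LLM
+response is parsed into a float with a configurable regex.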
+ +Usage: + client = OpenAIClient("http://localhost", 8000, model="meta-llama/...") + judge = LLMJudge( + prompt_template="Rate this code solution:\\n{action}\\n\\nScore (0-1):", + client=client, + ) + score = await judge(action, observation) + +See RFC 004 for full design: rfcs/004-rubrics.md +""" + +import re +from typing import Any, Dict + +from openenv.core.llm_client import LLMClient +from openenv.core.rubrics.base import Rubric + + +class LLMJudge(Rubric): + """Rubric that uses an LLM to evaluate agent actions/observations. + + The prompt template is formatted with ``{action}`` and ``{observation}`` + placeholders. The LLM response is parsed for a numeric score. + + Args: + prompt_template: Template string with {action} and {observation} placeholders. + client: An LLMClient instance for making LLM calls. + score_pattern: Regex to extract the score from the LLM response. + Defaults to matching the first decimal number. + default_score: Score returned when parsing fails. + normalize: If True, clamp extracted score to [0, 1]. + """ + + def __init__( + self, + prompt_template: str, + client: LLMClient, + *, + score_pattern: str | None = None, + default_score: float = 0.0, + normalize: bool = True, + ): + super().__init__() + self.prompt_template = prompt_template + self._client = client + self._score_pattern = re.compile(score_pattern or r"(\d+\.?\d*)") + self.default_score = default_score + self.normalize = normalize + + async def forward(self, action: Any, observation: Any) -> float: + """Evaluate by sending a prompt to the LLM and parsing the score. + + Args: + action: The action taken by the agent. + observation: The resulting observation. + + Returns: + Parsed score from the LLM response. + """ + prompt = self._render_prompt(action, observation) + response = await self._client.complete(prompt) + return self._parse_score(response) + + def _render_prompt(self, action: Any, observation: Any) -> str: + """Format the prompt template with action and observation. + + Override in subclasses for custom prompt construction. + """ + return self.prompt_template.format(action=action, observation=observation) + + def _parse_score(self, response: str) -> float: + """Extract a numeric score from the LLM response. + + Uses the configured regex pattern to find the first match. + Returns default_score if no match is found. 
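+
+        Example (illustrative, with the default pattern and default_score=0.0):
+            >>> judge._parse_score("I'd rate this 0.8 overall.")
+            0.8
+            >>> judge._parse_score("no score given")
+            0.0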
+ """ + match = self._score_pattern.search(response) + if match is None: + return self.default_score + try: + # Use first capture group if present, otherwise full match + text = match.group(1) if match.lastindex else match.group(0) + score = float(text) + except (ValueError, IndexError): + return self.default_score + if self.normalize: + score = max(0.0, min(1.0, score)) + return score + + def state_dict(self) -> Dict[str, Any]: + """Serialize rubric configuration.""" + return { + "prompt_template": self.prompt_template, + "score_pattern": self._score_pattern.pattern, + "default_score": self.default_score, + "normalize": self.normalize, + } + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load rubric configuration from checkpoint.""" + if "prompt_template" in state: + self.prompt_template = state["prompt_template"] + if "score_pattern" in state: + self._score_pattern = re.compile(state["score_pattern"]) + if "default_score" in state: + self.default_score = state["default_score"] + if "normalize" in state: + self.normalize = state["normalize"] diff --git a/src/openenv/core/rubrics/trajectory.py b/src/openenv/core/rubrics/trajectory.py new file mode 100644 index 0000000000000000000000000000000000000000..b3bb9aa9172047a24f89fae1fee6917abb861257 --- /dev/null +++ b/src/openenv/core/rubrics/trajectory.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Trajectory-based rubrics for delayed reward computation. + +These rubrics accumulate trajectory data and compute rewards based on +episode outcomes rather than individual steps. This supports scenarios +where reward signals depend on future events: + +- Terminal games (chess, Go): Win/loss known only at game end +- Plan execution: Plan quality depends on execution success +- Multi-agent games: One player's action quality depends on opponent response + +See RFC 004 "Delayed Rewards" section for design rationale. +""" + +from abc import abstractmethod +from typing import Any, Dict, List, Tuple + +from openenv.core.rubrics.base import Rubric + + +class TrajectoryRubric(Rubric): + """Abstract base for rubrics that score based on full trajectories. + + Subclasses implement: + - score_trajectory(): Compute final score from trajectory + - compute_step_rewards(): Define credit assignment strategy + + The __call__ method accumulates steps and returns rewards according + to the subclass's implementation. + + IMPORTANT: Trajectories are stored in CPU memory to avoid GPU pressure. + Environments with GPU tensors in observations must move them to CPU + before returning from step(). + + Known limitation: Very long episodes (thousands of steps) may consume + significant CPU memory. For such cases, consider streaming rubrics. 
+ + Usage: + class WinLossRubric(TrajectoryRubric): + def score_trajectory(self, trajectory): + _, final_obs = trajectory[-1] + return 1.0 if final_obs.metadata.get('won') else 0.0 + + def compute_step_rewards(self): + # Equal credit to all steps + score = self.score_trajectory(self._trajectory) + return [score] * len(self._trajectory) + + rubric = WinLossRubric() + for action, obs in episode: + reward = rubric(action, obs) # 0.0 until done + step_rewards = rubric.compute_step_rewards() # Credit assignment + """ + + _trajectory: List[Tuple[Any, Any]] + intermediate_reward: float + + def __init__(self, intermediate_reward: float = 0.0): + """Initialize trajectory rubric. + + Args: + intermediate_reward: Value to return for non-terminal steps. + Defaults to 0.0. + """ + super().__init__() + self.intermediate_reward = intermediate_reward + self._trajectory = [] + + def forward(self, action: Any, observation: Any) -> float: + """Accumulate step and return reward. + + Returns intermediate_reward until done, then computes trajectory score. + + Args: + action: The action taken. + observation: The resulting observation. Must have a 'done' attribute. + + Returns: + intermediate_reward if not done, else score_trajectory() result. + """ + self._trajectory.append((action, observation)) + + if getattr(observation, "done", False): + return self.score_trajectory(self._trajectory) + else: + return self.intermediate_reward + + @abstractmethod + def score_trajectory(self, trajectory: List[Tuple[Any, Any]]) -> float: + """Score the complete trajectory. Return 0.0-1.0. + + Called when observation.done=True. + + Args: + trajectory: List of (action, observation) tuples. + + Returns: + Final trajectory score (typically 0.0 to 1.0). + """ + raise NotImplementedError + + @abstractmethod + def compute_step_rewards(self) -> List[float]: + """Compute per-step rewards from the accumulated trajectory. + + Returns: + List of rewards, one per step. Length matches len(trajectory). + + Define your credit assignment strategy here (e.g., discounting, + assigning all credit to specific steps, etc.). + """ + raise NotImplementedError + + def reset(self) -> None: + """Clear accumulated trajectory. Call on env.reset().""" + self._trajectory = [] + + @property + def trajectory(self) -> List[Tuple[Any, Any]]: + """Current trajectory (read-only copy).""" + return list(self._trajectory) + + def state_dict(self) -> Dict[str, Any]: + """Serialize configuration (not trajectory data).""" + return {"intermediate_reward": self.intermediate_reward} + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load configuration from checkpoint.""" + if "intermediate_reward" in state: + self.intermediate_reward = state["intermediate_reward"] + + +class ExponentialDiscountingTrajectoryRubric(TrajectoryRubric): + """TrajectoryRubric with exponential discounting for credit assignment. + + Per-step reward: r_t = gamma^(T-1-t) * R_final + + With gamma=0.99, later steps get higher reward (they're "closer" to the outcome). + With gamma=1.0, all steps get equal reward. + With gamma=0.0, only the final step gets reward. + + This is the standard temporal discounting used in reinforcement learning, + applied retroactively once the episode outcome is known. 
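+
+    Worked example: with gamma=0.9, a 3-step episode, and a final score of
+    1.0, ``compute_step_rewards()`` returns [0.81, 0.9, 1.0], since
+    r_t = 0.9 ** (2 - t) for t in {0, 1, 2}.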
+ + Usage: + class ChessRubric(ExponentialDiscountingTrajectoryRubric): + def score_trajectory(self, trajectory): + _, final_obs = trajectory[-1] + outcome = final_obs.metadata.get('winner') + if outcome == 'agent': return 1.0 + elif outcome == 'opponent': return 0.0 + else: return 0.5 # Draw + + rubric = ChessRubric(gamma=0.99) + reward = rubric(action, obs) # 0.0 until done, then final score + step_rewards = rubric.compute_step_rewards() # Discounted per-step rewards + """ + + gamma: float + + def __init__(self, gamma: float = 0.99, intermediate_reward: float = 0.0): + """Initialize with discount factor. + + Args: + gamma: Discount factor in [0, 1]. Higher values give more credit + to early moves. 0.99 is a common choice. + intermediate_reward: Value to return for non-terminal steps. + """ + super().__init__(intermediate_reward=intermediate_reward) + if not 0.0 <= gamma <= 1.0: + raise ValueError(f"gamma must be in [0, 1], got {gamma}") + self.gamma = gamma + + def compute_step_rewards(self) -> List[float]: + """Apply exponential discounting from final reward. + + Returns: + List of discounted rewards. step_rewards[t] = gamma^(T-1-t) * R_final + where T is the trajectory length and R_final is score_trajectory(). + """ + if not self._trajectory: + return [] + + final_score = self.score_trajectory(self._trajectory) + T = len(self._trajectory) + return [final_score * (self.gamma ** (T - 1 - t)) for t in range(T)] + + def state_dict(self) -> Dict[str, Any]: + """Serialize configuration.""" + state = super().state_dict() + state["gamma"] = self.gamma + return state + + def load_state_dict(self, state: Dict[str, Any]) -> None: + """Load configuration from checkpoint.""" + super().load_state_dict(state) + if "gamma" in state: + self.gamma = state["gamma"] diff --git a/src/openenv/core/sync_client.py b/src/openenv/core/sync_client.py new file mode 100644 index 0000000000000000000000000000000000000000..4c5eb5da6151cea692ae447e1d4caba40a95fdaa --- /dev/null +++ b/src/openenv/core/sync_client.py @@ -0,0 +1,263 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Synchronous wrapper for async EnvClient. + +This module provides a SyncEnvClient that wraps an async EnvClient, +allowing synchronous usage while the underlying client uses async I/O. + +Example: + >>> from openenv.core import GenericEnvClient + >>> + >>> # Create async client and get sync wrapper + >>> async_client = GenericEnvClient(base_url="http://localhost:8000") + >>> sync_client = async_client.sync() + >>> + >>> # Use synchronous API + >>> with sync_client: + ... result = sync_client.reset() + ... result = sync_client.step({"code": "print('hello')"}) +""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +import inspect +import threading +from typing import Any, Dict, Generic, TYPE_CHECKING, TypeVar + +from .client_types import StateT, StepResult + +if TYPE_CHECKING: + from .env_client import EnvClient + +ActT = TypeVar("ActT") +ObsT = TypeVar("ObsT") + + +class SyncEnvClient(Generic[ActT, ObsT, StateT]): + """ + Synchronous wrapper around an async EnvClient. + + This class provides a synchronous interface to an async EnvClient, + making it easier to use in synchronous code or to stop async from + "infecting" the entire call stack. 
+ + The wrapper executes async operations on a dedicated background event loop + so connection state remains bound to a single loop. + + Cleanup note: + For guaranteed resource cleanup, use `with SyncEnvClient(...)` or call + `close()` explicitly. `__del__` is best-effort only and may not run + reliably (for example, during interpreter shutdown). + + Example: + >>> # From an async client + >>> async_client = GenericEnvClient(base_url="http://localhost:8000") + >>> sync_client = async_client.sync() + >>> + >>> # Use synchronous context manager + >>> with sync_client: + ... result = sync_client.reset() + ... result = sync_client.step({"action": "test"}) + + Attributes: + _async: The wrapped async EnvClient instance + """ + + def __init__(self, async_client: "EnvClient[ActT, ObsT, StateT]"): + """ + Initialize sync wrapper around an async client. + + Args: + async_client: The async EnvClient to wrap + """ + self._async = async_client + self._loop: asyncio.AbstractEventLoop | None = None + self._loop_thread: threading.Thread | None = None + self._loop_ready = threading.Event() + self._loop_init_lock = threading.Lock() + self._async_wrapper_cache: Dict[str, Any] = {} + + def _run_loop_forever(self) -> None: + """Run a dedicated event loop for this sync client.""" + loop = asyncio.new_event_loop() + self._loop = loop + asyncio.set_event_loop(loop) + self._loop_ready.set() + loop.run_forever() + loop.close() + + def _ensure_loop(self) -> asyncio.AbstractEventLoop: + """Start background loop thread on first use.""" + if ( + self._loop is not None + and self._loop_thread + and self._loop_thread.is_alive() + ): + return self._loop + + # Protect loop initialization when multiple threads race on first use. + with self._loop_init_lock: + if ( + self._loop is not None + and self._loop_thread + and self._loop_thread.is_alive() + ): + return self._loop + + self._loop_ready.clear() + self._loop_thread = threading.Thread( + target=self._run_loop_forever, + name="openenv-sync-client-loop", + daemon=True, + ) + self._loop_thread.start() + if not self._loop_ready.wait(timeout=5): + raise RuntimeError("Timed out starting sync client event loop") + assert self._loop is not None + return self._loop + + def _run(self, coro: Any) -> Any: + """Run coroutine on dedicated loop and block for result.""" + loop = self._ensure_loop() + future: concurrent.futures.Future[Any] = asyncio.run_coroutine_threadsafe( + coro, loop + ) + return future.result() + + def _stop_loop(self) -> None: + """Stop and join background loop thread.""" + loop = self._loop + thread = self._loop_thread + if loop is None: + return + + if loop.is_running(): + loop.call_soon_threadsafe(loop.stop) + if thread is not None: + thread.join(timeout=5) + + self._loop = None + self._loop_thread = None + + @property + def async_client(self) -> "EnvClient[ActT, ObsT, StateT]": + """Access the underlying async client.""" + return self._async + + def connect(self) -> "SyncEnvClient[ActT, ObsT, StateT]": + """ + Establish connection to the server. + + Returns: + self for method chaining + """ + self._run(self._async.connect()) + return self + + def disconnect(self) -> None: + """Close the connection.""" + self._run(self._async.disconnect()) + + def reset(self, **kwargs: Any) -> StepResult[ObsT]: + """ + Reset the environment. 
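+
+        Blocks until the wrapped async ``reset`` completes on the background
+        event loop.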
+ + Args: + **kwargs: Optional parameters passed to the environment's reset method + + Returns: + StepResult containing initial observation + """ + return self._run(self._async.reset(**kwargs)) + + def step(self, action: ActT, **kwargs: Any) -> StepResult[ObsT]: + """ + Execute an action in the environment. + + Args: + action: The action to execute + **kwargs: Optional parameters + + Returns: + StepResult containing observation, reward, and done status + """ + return self._run(self._async.step(action, **kwargs)) + + def state(self) -> StateT: + """ + Get the current environment state. + + Returns: + State object with environment state information + """ + return self._run(self._async.state()) + + def close(self) -> None: + """Close the connection and clean up resources.""" + try: + self._run(self._async.close()) + finally: + self._stop_loop() + + def __enter__(self) -> "SyncEnvClient[ActT, ObsT, StateT]": + """Enter context manager, establishing connection.""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit context manager, closing connection.""" + self.close() + + def __del__(self) -> None: + """ + Best-effort cleanup for background loop thread. + + Do not rely on this for deterministic cleanup; prefer context-manager + usage or an explicit `close()` call. + """ + try: + self._stop_loop() + except Exception: + pass + + def __getattr__(self, name: str) -> Any: + """ + Delegate unknown attributes to the async client. + + Async methods are wrapped to run on the sync client's dedicated loop. + """ + attr = getattr(self._async, name) + + if inspect.iscoroutinefunction(attr): + cached = self._async_wrapper_cache.get(name) + if cached is not None: + return cached + + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + method = getattr(self._async, name) + return self._run(method(*args, **kwargs)) + + self._async_wrapper_cache[name] = sync_wrapper + return sync_wrapper + + return attr + + # Delegate abstract method implementations to the wrapped client + def _step_payload(self, action: ActT) -> Dict[str, Any]: + """Delegate to async client's _step_payload.""" + return self._async._step_payload(action) + + def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ObsT]: + """Delegate to async client's _parse_result.""" + return self._async._parse_result(payload) + + def _parse_state(self, payload: Dict[str, Any]) -> StateT: + """Delegate to async client's _parse_state.""" + return self._async._parse_state(payload) diff --git a/src/openenv/core/tools/__init__.py b/src/openenv/core/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0193b2619fc14f14152e3276f54aa0d4aed8ca2c --- /dev/null +++ b/src/openenv/core/tools/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Core tools for code execution and other utilities.""" + +from .git_server_client import GitServerClient, RepoInfo + +try: + from .local_python_executor import PyExecutor +except ModuleNotFoundError: + # smolagents is optional for environments that only need Git tooling. 
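+    # Callers must tolerate PyExecutor being None when smolagents is absent.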
+ PyExecutor = None # type: ignore[assignment] + +__all__ = [ + "PyExecutor", + "GitServerClient", + "RepoInfo", +] diff --git a/src/openenv/core/tools/git_server_client.py b/src/openenv/core/tools/git_server_client.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc3379f6b675178cc7aa94914c31f66bc846aed --- /dev/null +++ b/src/openenv/core/tools/git_server_client.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python3 +""" +Git Server Client for connecting to external Gitea instance. + +This module provides a lightweight client for interacting with a shared +Gitea service, optimized for task-based isolation where multiple environment +instances share the same Gitea server but have isolated workspaces. +""" + +import json +import os +import shutil +import subprocess +import time +from dataclasses import dataclass +from pathlib import Path +from urllib.parse import urlparse + + +@dataclass +class RepoInfo: + """Information about a repository.""" + + name: str + url: str + commit: str + clone_url: str + + +class GitServerClient: + """ + Client for connecting to an external Gitea server. + + This client is optimized for task-based isolation where: + - Multiple tasks share the same Gitea instance + - Each task has its own isolated workspace + - Fast reset() via git operations (no server restart) + - Repos are pre-migrated to Gitea once + + Args: + gitea_url: URL of the Gitea server (e.g., "http://gitea:3000") + username: Gitea username for authentication + password: Gitea password for authentication + workspace_dir: Local workspace directory for cloning repos + + Example: + >>> # Connect to shared Gitea (credentials from environment) + >>> import os + >>> client = GitServerClient( + ... gitea_url=os.getenv("GITEA_URL"), + ... username=os.getenv("GITEA_USERNAME"), + ... password=os.getenv("GITEA_PASSWORD") + ... ) + >>> client.wait_for_ready() + >>> # Clone repo to workspace + >>> path = client.clone_to_workspace("my-repo", commit="abc123") + >>> # Fast reset to base state + >>> client.reset_workspace("my-repo", commit="abc123") + """ + + def __init__( + self, + gitea_url: str, + username: str, + password: str, + workspace_dir: str = "/workspace", + ): + """Initialize Git Server Client.""" + self.gitea_url = gitea_url.rstrip("/") + self.username = username + self.password = password + self.workspace_dir = Path(workspace_dir) + self.is_ready = False + + # Parse Gitea URL + parsed = urlparse(self.gitea_url) + self.domain = parsed.hostname or "localhost" + self.port = parsed.port or 3000 + + # Ensure workspace exists + os.makedirs(self.workspace_dir, exist_ok=True) + + # Configure git credentials + self._configure_git() + + def _configure_git(self): + """Configure git credentials for automatic authentication.""" + home_dir = Path.home() + + # Git config + git_config = f"""[user] + name = {self.username} + email = {self.username}@local.env +[init] + defaultBranch = main +[credential] + helper = store +""" + gitconfig_path = home_dir / ".gitconfig" + gitconfig_path.write_text(git_config) + + # Git credentials + git_credentials = ( + f"http://{self.username}:{self.password}@{self.domain}:{self.port}\n" + ) + gitcreds_path = home_dir / ".git-credentials" + gitcreds_path.write_text(git_credentials) + gitcreds_path.chmod(0o600) + + def wait_for_ready(self, timeout: int = 30) -> bool: + """ + Wait for Gitea server to be ready. 
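+
+        Polls the server root with ``curl -sf``, sleeping one second between
+        attempts, until the server responds or ``timeout`` elapses.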
+ + Args: + timeout: Maximum seconds to wait + + Returns: + True if server is ready, False otherwise + """ + start_time = time.time() + while time.time() - start_time < timeout: + try: + result = subprocess.run( + ["curl", "-sf", f"{self.gitea_url}/"], + capture_output=True, + timeout=5, + ) + if result.returncode == 0: + self.is_ready = True + return True + except subprocess.TimeoutExpired: + pass + except Exception: + pass + + time.sleep(1) + + return False + + def list_repositories(self) -> list[dict[str, str]]: + """ + List all repositories in Gitea. + + Returns: + List of repository information dictionaries + """ + if not self.is_ready: + raise RuntimeError("Gitea server is not ready") + + result = subprocess.run( + [ + "curl", + "-s", + f"{self.gitea_url}/api/v1/user/repos", + "-u", + f"{self.username}:{self.password}", + ], + capture_output=True, + text=True, + ) + + if result.returncode != 0: + return [] + + try: + repos = json.loads(result.stdout) + return [ + { + "name": repo["name"], + "full_name": repo["full_name"], + "clone_url": repo["clone_url"], + "description": repo.get("description", ""), + } + for repo in repos + ] + except (json.JSONDecodeError, KeyError): + return [] + + def clone_to_workspace( + self, repo_name: str, target_dir: str | None = None, commit: str = "main" + ) -> str: + """ + Clone a repository to the workspace at a specific commit. + + This creates a fresh clone optimized for task isolation. + + Args: + repo_name: Name of repository to clone + target_dir: Target directory name (defaults to repo_name) + commit: Commit hash or branch to check out + + Returns: + Path to cloned repository + + Raises: + RuntimeError: If clone fails + """ + if not self.is_ready: + raise RuntimeError("Gitea server is not ready") + + target_dir = target_dir or repo_name + target_path = self.workspace_dir / target_dir + + # Remove existing directory if present + if target_path.exists(): + shutil.rmtree(target_path) + + clone_url = f"{self.gitea_url}/{self.username}/{repo_name}.git" + + # Clone repository + result = subprocess.run( + ["git", "clone", clone_url, str(target_path)], + capture_output=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(f"Clone failed: {result.stderr}") + + # Checkout specific commit + if commit != "main": + result = subprocess.run( + ["git", "checkout", commit], + cwd=str(target_path), + capture_output=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(f"Checkout failed: {result.stderr}") + + return str(target_path) + + def reset_workspace(self, repo_name: str, commit: str = "main") -> bool: + """ + Fast reset of workspace to base state (optimized for task resets). + + This is much faster than re-cloning. It: + 1. Checks out the target commit + 2. Resets to that commit (hard) + 3. 
Cleans untracked files
+
+        Args:
+            repo_name: Name of repository (directory in workspace)
+            commit: Commit hash or branch to reset to
+
+        Returns:
+            True if reset successful
+
+        Raises:
+            RuntimeError: If reset fails
+        """
+        repo_path = self.workspace_dir / repo_name
+
+        if not repo_path.exists():
+            raise RuntimeError(f"Repository not found in workspace: {repo_name}")
+
+        # Fetch latest (in case commit is new)
+        subprocess.run(
+            ["git", "fetch", "--all"],
+            cwd=str(repo_path),
+            capture_output=True,
+        )
+
+        # Checkout and hard reset to commit
+        result = subprocess.run(
+            ["git", "checkout", commit],
+            cwd=str(repo_path),
+            capture_output=True,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"Checkout failed: {result.stderr}")
+
+        result = subprocess.run(
+            [
+                "git",
+                "reset",
+                "--hard",
+                f"origin/{commit}" if commit != "main" else commit,
+            ],
+            cwd=str(repo_path),
+            capture_output=True,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            # Try without origin/ prefix
+            result = subprocess.run(
+                ["git", "reset", "--hard", commit],
+                cwd=str(repo_path),
+                capture_output=True,
+                text=True,
+            )
+            if result.returncode != 0:
+                raise RuntimeError(f"Reset failed: {result.stderr}")
+
+        # Clean untracked files and directories
+        subprocess.run(
+            ["git", "clean", "-fdx"],
+            cwd=str(repo_path),
+            capture_output=True,
+        )
+
+        return True
+
+    def execute_git_command(
+        self, command: str, working_dir: str = ""
+    ) -> tuple[int, str, str]:
+        """
+        Execute a git command in the workspace.
+
+        Args:
+            command: Git command to execute (without 'git' prefix)
+            working_dir: Working directory relative to workspace
+
+        Returns:
+            Tuple of (exit_code, stdout, stderr)
+        """
+        work_path = (
+            self.workspace_dir / working_dir if working_dir else self.workspace_dir
+        )
+
+        if not work_path.exists():
+            return (1, "", f"Working directory does not exist: {work_path}")
+
+        # Split with shlex so quoted arguments survive intact; a plain
+        # str.split() would break commands like `commit -m "two words"`.
+        import shlex
+
+        cmd_parts = ["git", *shlex.split(command)]
+
+        result = subprocess.run(
+            cmd_parts,
+            cwd=str(work_path),
+            capture_output=True,
+            text=True,
+        )
+
+        return (result.returncode, result.stdout, result.stderr)
+
+    def get_current_commit(self, repo_name: str) -> str:
+        """
+        Get current commit hash of a workspace repository.
+
+        Args:
+            repo_name: Name of repository in workspace
+
+        Returns:
+            Commit hash
+        """
+        repo_path = self.workspace_dir / repo_name
+
+        if not repo_path.exists():
+            raise RuntimeError(f"Repository not found: {repo_name}")
+
+        result = subprocess.run(
+            ["git", "rev-parse", "HEAD"],
+            cwd=str(repo_path),
+            capture_output=True,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"Failed to get commit: {result.stderr}")
+
+        return result.stdout.strip()
+
+    def workspace_exists(self, repo_name: str) -> bool:
+        """Check if a repository exists in workspace."""
+        return (self.workspace_dir / repo_name).exists()
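+
+
+if __name__ == "__main__":  # pragma: no cover - illustrative only
+    # Minimal usage sketch, assuming a reachable Gitea instance and credentials
+    # in GITEA_URL / GITEA_USERNAME / GITEA_PASSWORD. The defaults below are
+    # placeholders, not documented values, and note that constructing the
+    # client writes git credentials into the current user's home directory.
+    client = GitServerClient(
+        gitea_url=os.getenv("GITEA_URL", "http://localhost:3000"),
+        username=os.getenv("GITEA_USERNAME", "gitea"),
+        password=os.getenv("GITEA_PASSWORD", "gitea"),
+        workspace_dir="/tmp/openenv-workspace",
+    )
+    if client.wait_for_ready(timeout=10):
+        # List the repos that have been migrated to the shared server.
+        for repo in client.list_repositories():
+            print(repo["name"], "->", repo["clone_url"])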
diff --git a/src/openenv/core/tools/local_python_executor.py b/src/openenv/core/tools/local_python_executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb18052b309b3c214bcf0e5c2645416734575fa1
--- /dev/null
+++ b/src/openenv/core/tools/local_python_executor.py
@@ -0,0 +1,157 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Local Python Executor (enhanced).
+
+This module provides a safer wrapper around smolagents.LocalPythonExecutor
+with improved exception handling and a few helpful tools registered with
+the executor to make debugging executed code easier.
+
+Key improvements:
+- Register a few helper utilities via send_tools so user code can use
+  them for reporting (e.g. `format_exc`).
+- More robust extraction of stdout/stderr/exit codes from the executor
+  result object, tolerant to different versions of smolagents.
+- Detailed stderr on unexpected exceptions including full traceback.
+- Structured logging for operational visibility.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import traceback
+
+from smolagents import LocalPythonExecutor
+
+from openenv.core.env_server.types import CodeExecResult
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.NullHandler())
+
+
+class PyExecutor:
+    """Wrapper around smolagents LocalPythonExecutor.
+
+    The wrapper registers a few non-privileged helper tools to the
+    LocalPythonExecutor that can be used by the executed code to
+    format exceptions and to safely stringify results for improved
+    error reporting.
+    """
+
+    def __init__(self, additional_imports: list[str] | None = None):
+        if additional_imports is None:
+            additional_imports = []
+
+        self._executor = LocalPythonExecutor(
+            additional_authorized_imports=additional_imports
+        )
+
+        # Register helpful utilities exposed to the execution environment.
+        # These are intentionally small, read-only helpers.
+        tools = {
+            # Provide a small helper to format the current exception in the
+            # executed context. This is a *string formatting* helper only.
+            "format_exc": traceback.format_exc,
+            # Safe JSON dumps with a fallback for non-serializable objects.
+            "safe_json_dumps": lambda obj: json.dumps(obj, default=lambda o: repr(o)),
+        }
+
+        # `send_tools` is the public API on LocalPythonExecutor to make
+        # helper callables available to the sandboxed runtime. We don't
+        # provide any builtins that could change the environment.
+        try:
+            self._executor.send_tools(tools)
+        except Exception:
+            # If the LocalPythonExecutor implementation doesn't support
+            # send_tools or fails, log and continue; the executor is still usable.
+            logger.debug(
+                "LocalPythonExecutor.send_tools failed; continuing without extra tools",
+                exc_info=True,
+            )
+
+    def run(self, code: str) -> CodeExecResult:
+        """Execute Python code and return a CodeExecResult.
+
+        This method is intentionally defensive: it attempts to extract
+        meaningful stdout/stderr/exit_code information from a variety of
+        possible return shapes that different versions of smolagents
+        may provide.
+        """
+        try:
+            exec_result = self._executor(code)
+
+            # Default values
+            stdout_parts: list[str] = []
+            stderr_parts: list[str] = []
+            exit_code = 0
+
+            # Extract logs/prints
+            try:
+                logs = getattr(exec_result, "logs", None)
+                if logs:
+                    stdout_parts.append(str(logs))
+            except Exception:
+                logger.debug("Failed to read exec_result.logs", exc_info=True)
+
+            # Extract the result / output value
+            try:
+                if hasattr(exec_result, "output"):
+                    out_val = exec_result.output
+                    # If the output is not None, stringify it in a safe way
+                    if out_val is not None:
+                        # Prefer JSON if possible, otherwise repr
+                        try:
+                            stdout_parts.append(json.dumps(out_val))
+                        except Exception:
+                            stdout_parts.append(repr(out_val))
+            except Exception:
+                logger.debug("Failed to read exec_result.output", exc_info=True)
+
+            # Some runtime implementations may put errors on `error` or `exception`
+            try:
+                err = getattr(exec_result, "error", None)
+                if err:
+                    stderr_parts.append(str(err))
+            except Exception:
+                logger.debug("Failed to read exec_result.error", exc_info=True)
+
+            try:
+                ex = getattr(exec_result, "exception", None)
+                if ex:
+                    stderr_parts.append(str(ex))
+            except Exception:
+                logger.debug("Failed to read exec_result.exception", exc_info=True)
+
+            # Determine exit code if provided
+            try:
+                if hasattr(exec_result, "exit_code"):
+                    exit_code = (
+                        int(exec_result.exit_code)
+                        if exec_result.exit_code is not None
+                        else 0
+                    )
+                elif hasattr(exec_result, "success"):
+                    # Some versions use `success` boolean
+                    exit_code = 0 if exec_result.success else 1
+                else:
+                    # Fallback: if there were any stderr parts, treat as non-zero
+                    exit_code = 1 if stderr_parts else 0
+            except Exception:
+                logger.debug("Failed to determine exec_result exit code", exc_info=True)
+                exit_code = 1 if stderr_parts else 0
+
+            # Compose the final stdout/stderr strings. Every collected part is
+            # already a string, so a plain join suffices.
+            stdout = "\n".join(stdout_parts)
+            stderr = "\n".join(stderr_parts)
+
+            return CodeExecResult(stdout=stdout, stderr=stderr, exit_code=exit_code)
+
+        except Exception:
+            # Any unexpected exception from the LocalPythonExecutor is
+            # returned with a full traceback to make debugging easier.
+            tb = traceback.format_exc()
+            logger.exception("LocalPythonExecutor raised an exception during run")
+            return CodeExecResult(stdout="", stderr=tb, exit_code=1)
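+
+
+if __name__ == "__main__":  # pragma: no cover - illustrative only
+    # Minimal usage sketch, assuming smolagents is installed. The executed
+    # snippet and the authorized `math` import are arbitrary examples.
+    executor = PyExecutor(additional_imports=["math"])
+    result = executor.run("import math\nprint(math.sqrt(2))")
+    print("exit_code:", result.exit_code)
+    print("stdout:", result.stdout)
+    print("stderr:", result.stderr)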
diff --git a/src/openenv/core/utils.py b/src/openenv/core/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e86b3ae9c3e6ec0a19cd6f4868e4e3cdfee66bbc
--- /dev/null
+++ b/src/openenv/core/utils.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Utility functions for OpenEnv core."""
+
+import asyncio
+import concurrent.futures
+
+
+def run_async_safely(coro):
+    """
+    Run an async coroutine safely from any context.
+
+    This handles the case where we may already be inside an async event loop
+    (e.g., when called from an async framework). In that case, asyncio.run()
+    would fail, so we use a ThreadPoolExecutor to run in a separate thread.
+
+    Args:
+        coro: The coroutine to run
+
+    Returns:
+        The result of the coroutine
+    """
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        loop = None
+
+    if loop is not None:
+        # Already in async context - run in a thread pool.
+        # Only one task is ever submitted, so a single worker suffices.
+        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
+            future = pool.submit(asyncio.run, coro)
+            return future.result()
+    else:
+        # No async context - use asyncio.run() directly
+        return asyncio.run(coro)
+
+
+def convert_to_ws_url(url: str) -> str:
+    """
+    Convert an HTTP/HTTPS URL to a WS/WSS URL.
+
+    Args:
+        url: The URL to convert.
+
+    Returns:
+        The converted WebSocket URL.
+    """
+    ws_url = url.rstrip("/")
+    if ws_url.startswith("http://"):
+        ws_url = "ws://" + ws_url[7:]
+    elif ws_url.startswith("https://"):
+        ws_url = "wss://" + ws_url[8:]
+    elif not ws_url.startswith("ws://") and not ws_url.startswith("wss://"):
+        ws_url = "ws://" + ws_url
+    return ws_url
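+
+
+if __name__ == "__main__":  # pragma: no cover - illustrative only
+    # Quick sanity checks; the URLs are arbitrary examples.
+    assert convert_to_ws_url("https://example.com/env/") == "wss://example.com/env"
+    assert convert_to_ws_url("localhost:8000") == "ws://localhost:8000"
+
+    async def _demo():
+        return "ok"
+
+    # Works whether or not an event loop is already running in this thread.
+    print(run_async_safely(_demo()))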
diff --git a/src/openenv_core.egg-info/PKG-INFO b/src/openenv_core.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..654a6035265f03dfe71cddb57b22a9385d428012
--- /dev/null
+++ b/src/openenv_core.egg-info/PKG-INFO
@@ -0,0 +1,440 @@
+Metadata-Version: 2.4
+Name: openenv-core
+Version: 0.2.2.dev0
+Summary: A unified framework for reinforcement learning environments
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: fastapi>=0.104.0
+Requires-Dist: pydantic>=2.0.0
+Requires-Dist: uvicorn>=0.24.0
+Requires-Dist: requests>=2.25.0
+Requires-Dist: typer>=0.9.0
+Requires-Dist: rich>=13.0.0
+Requires-Dist: pyyaml>=6.0
+Requires-Dist: huggingface_hub>=0.20.0
+Requires-Dist: openai>=2.7.2
+Requires-Dist: tomli>=2.3.0
+Requires-Dist: tomli-w>=1.2.0
+Requires-Dist: websockets>=15.0.1
+Requires-Dist: fastmcp>=3.0.0
+Requires-Dist: gradio>=4.0.0
+Provides-Extra: core
+Requires-Dist: fastapi>=0.104.0; extra == "core"
+Requires-Dist: pydantic>=2.0.0; extra == "core"
+Requires-Dist: uvicorn>=0.24.0; extra == "core"
+Requires-Dist: requests>=2.25.0; extra == "core"
+Requires-Dist: websockets>=15.0.1; extra == "core"
+Provides-Extra: cli
+Requires-Dist: typer>=0.9.0; extra == "cli"
+Requires-Dist: rich>=13.0.0; extra == "cli"
+Requires-Dist: pyyaml>=6.0; extra == "cli"
+Requires-Dist: huggingface_hub>=0.20.0; extra == "cli"
+Requires-Dist: openai>=2.7.2; extra == "cli"
+Requires-Dist: tomli>=2.3.0; extra == "cli"
+Requires-Dist: tomli-w>=1.2.0; extra == "cli"
+Provides-Extra: docs
+Requires-Dist: sphinx==7.2.6; extra == "docs"
+Requires-Dist: pytorch-sphinx-theme2; extra == "docs"
+Requires-Dist: sphinxcontrib.katex==0.9.10; extra == "docs"
+Requires-Dist: docutils<0.21,>=0.18.1; extra == "docs"
+Requires-Dist: sphinx-design==0.6.1; extra == "docs"
+Requires-Dist: sphinxcontrib-mermaid==1.0.0; extra == "docs"
+Requires-Dist: myst-parser; extra == "docs"
+Requires-Dist: sphinxext-opengraph; extra == "docs"
+Requires-Dist: sphinx-sitemap==2.7.1; extra == "docs"
+Requires-Dist: sphinx-gallery>=0.14.0; extra == "docs"
+Requires-Dist: matplotlib; extra == "docs"
+Requires-Dist: nest-asyncio; extra == "docs"
+Requires-Dist: smolagents; extra == "docs"
+Provides-Extra: all
+Requires-Dist: openenv-core[core]; extra == "all"
+Requires-Dist: openenv-core[cli]; extra == "all"
+Provides-Extra: daytona
+Requires-Dist: daytona>=0.136.0; extra == "daytona"
+Requires-Dist: pyyaml>=6.0; extra == "daytona"
+Provides-Extra: inspect
+Requires-Dist: inspect-ai>=0.3.0; extra == "inspect"
+Dynamic: license-file
+
+# OpenEnv: Agentic Execution Environments
+
+An end-to-end framework for creating, deploying, and using isolated execution environments for agentic RL training, built on simple Gymnasium-style APIs.
+
+[![PyPI](https://img.shields.io/pypi/v/openenv?color=blue)](https://pypi.org/project/openenv/)
+[![Discord](https://img.shields.io/badge/Discord-OpenEnv-7289da?style=flat&logo=discord&logoColor=white)](https://discord.gg/YsTYBh6PD9)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-pytorch/OpenEnv/blob/main/examples/OpenEnv_Tutorial.ipynb)
+[![Docs](https://img.shields.io/badge/Docs-Explore-blue?logo=readthedocs&logoColor=white)](https://meta-pytorch.org/OpenEnv/)
+
+---
+
+**🚀 Featured Example:** Train LLMs to play BlackJack using [torchforge](https://github.com/meta-pytorch/torchforge) (PyTorch's agentic RL framework): [`examples/grpo_blackjack/`](examples/grpo_blackjack/)
+
+**🔥 Zero to Hero Tutorial:** An end-to-end tutorial from our [GPU Mode](tutorial/README.md) lecture and other hackathons.
+
+## Quick Start
+
+Install the OpenEnv core package:
+
+```bash
+pip install openenv-core
+```
+
+Install an environment client (e.g., Echo):
+
+```bash
+pip install git+https://huggingface.co/spaces/openenv/echo_env
+```
+
+Then use the environment:
+
+```python
+import asyncio
+from echo_env import EchoAction, EchoEnv
+
+async def main():
+    # Connect to a running Space (async context manager)
+    async with EchoEnv(base_url="https://openenv-echo-env.hf.space") as client:
+        # Reset the environment
+        result = await client.reset()
+        print(result.observation.echoed_message)  # "Echo environment ready!"
+
+        # Send messages
+        result = await client.step(EchoAction(message="Hello, World!"))
+        print(result.observation.echoed_message)  # "Hello, World!"
+        print(result.reward)  # 1.3 (based on message length)
+
+asyncio.run(main())
+```
+
+**Synchronous usage** is also supported via the `.sync()` wrapper:
+
+```python
+from echo_env import EchoAction, EchoEnv
+
+# Use .sync() for synchronous context manager
+with EchoEnv(base_url="https://openenv-echo-env.hf.space").sync() as client:
+    result = client.reset()
+    result = client.step(EchoAction(message="Hello, World!"))
+    print(result.observation.echoed_message)
+```
+
+For a detailed quick start, check out the [docs page](https://meta-pytorch.org/OpenEnv/quickstart/).
+
+## OpenEnv on Partner Platforms
+
+- [Lightning AI Studio](https://lightning.ai/environments?section=featured)
+- [TRL example](https://huggingface.co/docs/trl/openenv)
+- [Unsloth Google Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/OpenEnv_gpt_oss_(20B)_Reinforcement_Learning_2048_Game.ipynb)
+- [ART example](https://art.openpipe.ai/integrations/openenv-integration)
+- [Oumi example](https://github.com/oumi-ai/oumi/blob/main/notebooks/Oumi%20-%20OpenEnv%20GRPO%20with%20trl.ipynb)
+
+## Overview
+
+OpenEnv provides a standard for interacting with agentic execution environments via simple Gymnasium-style APIs - `step()`, `reset()`, `state()`. Users of agentic execution environments can drive an environment inside an RL training loop using just these calls (a minimal loop sketch follows below).
+
+In addition to making life easier for researchers and RL framework authors, we provide tools for environment creators, making it easier for them to build richer environments and make them available over familiar protocols like HTTP, packaged with canonical technologies like Docker.
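+
+A minimal sketch of such a training-style loop, reusing `EchoEnv` from the Quick Start (the step count, messages, and reward handling here are illustrative, not a prescribed recipe):
+
+```python
+import asyncio
+
+from echo_env import EchoAction, EchoEnv
+
+
+async def rollout(num_steps: int = 3) -> float:
+    """Collect the total reward from a short episode."""
+    total_reward = 0.0
+    async with EchoEnv(base_url="https://openenv-echo-env.hf.space") as client:
+        await client.reset()
+        for i in range(num_steps):
+            result = await client.step(EchoAction(message=f"step {i}"))
+            total_reward += result.reward or 0.0
+            if result.done:
+                break
+    return total_reward
+
+
+print(asyncio.run(rollout()))
+```
+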
Environment creators can use the OpenEnv framework to create environments that are isolated, secure, and easy to deploy and use. + +The OpenEnv CLI (`openenv`) provides commands to initialize new environments and deploy them to Hugging Face Spaces. + +> ⚠️ **Early Development Warning** OpenEnv is currently in an experimental +> stage. You should expect bugs, incomplete features, and APIs that may change +> in future versions. The project welcomes bugfixes, but to make sure things are +> well coordinated you should discuss any significant change before starting the +> work. It's recommended that you signal your intention to contribute in the +> issue tracker, either by filing a new issue or by claiming an existing one. + +### RFCs + +Below is a list of active and historical RFCs for OpenEnv. RFCs are proposals for major changes or features. Please review and contribute! + +- [RFC 001: Baseline API and Interface Specifications](https://github.com/meta-pytorch/OpenEnv/pull/26) +- [RFC 002: Discoverability of environment tools by agents](https://github.com/meta-pytorch/OpenEnv/pull/32) +- [RFC 003: Add MCP (Model Context Protocol) support](https://github.com/meta-pytorch/OpenEnv/pull/224) +- [RFC 004: Add delayed rewards support for trajectory-based scoring](https://github.com/meta-pytorch/OpenEnv/pull/337) +- [RFC 005: Agentic Harness Integration](https://github.com/meta-pytorch/OpenEnv/pull/387) + +## Architecture + +### Component Overview + +``` +┌─────────────────────────────────────────────────────────┐ +│ Client Application │ +│ ┌────────────────┐ ┌──────────────────┐ │ +│ │ EchoEnv │ │ CodingEnv │ │ +│ │ (EnvClient) │ │ (EnvClient) │ │ +│ └────────┬───────┘ └────────┬─────────┘ │ +└───────────┼───────────────────────────────┼─────────────┘ + │ WebSocket │ WebSocket + │ (reset, step, state) │ +┌───────────▼───────────────────────────────▼─────────────┐ +│ Docker Containers (Isolated) │ +│ ┌──────────────────────┐ ┌──────────────────────┐ │ +│ │ FastAPI Server │ │ FastAPI Server │ │ +│ │ EchoEnvironment │ │ PythonCodeActEnv │ │ +│ │ (Environment base) │ │ (Environment base) │ │ +│ └──────────────────────┘ └──────────────────────┘ │ +└─────────────────────────────────────────────────────────┘ +``` + +### Core Components + +#### 1. Web Interface + +OpenEnv includes a built-in web interface for interactive environment exploration and debugging. The web interface provides: + +- **Two-Pane Layout**: HumanAgent interaction on the left, state observation on the right +- **Real-time Updates**: WebSocket-based live updates without page refresh +- **Dynamic Forms**: Automatically generated action forms based on environment Action types +- **Action History**: Complete log of all actions taken and their results + +The web interface is **conditionally enabled** based on environment variables: + +- **Local Development**: Disabled by default for lightweight development +- **Manual Override**: Enable with `ENABLE_WEB_INTERFACE=true` + +To use the web interface: + +```python +from openenv.core.env_server import create_web_interface_app +from your_env.models import YourAction, YourObservation +from your_env.server.your_environment import YourEnvironment + +env = YourEnvironment() +app = create_web_interface_app(env, YourAction, YourObservation) +``` + +When enabled, open `http://localhost:8000/web` in your browser to interact with the environment. + +#### 2. 
Environment (Server-Side)
+Base class for implementing environment logic:
+- **`reset()`**: Initialize a new episode, returns initial `Observation`
+- **`step(action)`**: Execute an `Action`, returns resulting `Observation`
+- **`state()`**: Access episode metadata (`State` with episode_id, step_count, etc.)
+
+#### 3. EnvClient (Client-Side)
+Base class for environment communication:
+- **Async by default**: Use `async with` and `await` for all operations
+- **Sync wrapper**: Call `.sync()` to get a `SyncEnvClient` for synchronous usage
+- Handles WebSocket connections to the environment server
+- Contains a utility to spin up a Docker container locally for the corresponding environment
+- Type-safe action/observation parsing
+
+#### 4. Container Providers
+Manage container deployment:
+- `LocalDockerProvider`: Run containers on the local Docker daemon
+- `KubernetesProvider`: Deploy to K8s clusters (future)
+
+#### 5. Models
+Type-safe data structures:
+- `Action`: Base class for environment actions
+- `Observation`: Base class for environment observations
+- `State`: Episode state tracking
+- `StepResult`: Combines observation, reward, and done flag
+
+## Project Structure
+
+### For Environment Creators
+
+Use the CLI to quickly scaffold a new environment:
+
+```bash
+openenv init my_env
+```
+
+This creates the following structure:
+
+```
+my_env/
+├── .dockerignore            # Docker build exclusions
+├── __init__.py              # Export YourAction, YourObservation, YourEnv
+├── models.py                # Define Action, Observation, State dataclasses
+├── client.py                # Implement YourEnv(EnvClient)
+├── README.md                # Document your environment
+├── openenv.yaml             # Environment manifest
+├── pyproject.toml           # Dependencies and package configuration
+├── outputs/                 # Runtime outputs (logs, evals) - gitignored
+│   ├── logs/
+│   └── evals/
+└── server/
+    ├── your_environment.py  # Implement YourEnvironment(Environment)
+    ├── app.py               # Create FastAPI app
+    ├── requirements.txt     # Dependencies for Docker (can be generated)
+    └── Dockerfile           # Define container image
+```
+
+#### Dependency Management
+
+OpenEnv uses `pyproject.toml` as the primary dependency specification:
+
+- **Environment-level `pyproject.toml`**: Each environment defines its own dependencies
+- **Root-level `pyproject.toml`**: Contains shared core dependencies (fastapi, pydantic, uvicorn)
+- **Server `requirements.txt`**: Can be auto-generated from `pyproject.toml` for Docker builds
+
+**Development Workflow:**
+
+```bash
+# Install environment in editable mode
+cd my_env
+pip install -e .
+
+# Or using uv (faster)
+uv pip install -e .
+
+# Run server locally without Docker
+uv run server --host 0.0.0.0 --port 8000
+```
+
+**Benefits:**
+- ✅ **Client-side extensions**: Modify client classes locally without repo changes
+- ✅ **Better dependency management**: Clear separation between environments
+- ✅ **Flexible workflows**: Use pip, uv, or Docker for different scenarios
+- ✅ **CI/CD ready**: Automated dependency generation and validation
+
+See [`envs/README.md`](envs/README.md) for a complete guide on building environments; a minimal sketch of the scaffolded pieces follows below.
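+
+As a rough sketch, the scaffolded pieces fit together like this (the class names mirror the template layout above; the base-class import paths and signatures are simplified assumptions, not the exact API):
+
+```python
+# models.py - typed action/observation dataclasses
+from dataclasses import dataclass
+
+from openenv.core.env_server.types import Action, Observation  # assumed path
+
+
+@dataclass
+class MyAction(Action):
+    text: str
+
+
+@dataclass
+class MyObservation(Observation):
+    reply: str
+
+
+# server/my_env_environment.py - server-side episode logic
+from openenv.core.env_server.interfaces import Environment  # assumed path
+
+
+class MyEnvironment(Environment):
+    def reset(self) -> MyObservation:
+        return MyObservation(reply="ready")
+
+    def step(self, action: MyAction) -> MyObservation:
+        return MyObservation(reply=action.text.upper())
+```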
+
+### For Environment Users
+
+To use an environment:
+1. Install the client: `pip install git+https://huggingface.co/spaces/openenv/echo_env`
+2. Import: `from echo_env import EchoAction, EchoEnv`
+3. Use the async (recommended) or sync API:
+
+**Async (recommended):**
+```python
+async with EchoEnv(base_url="...") as client:
+    result = await client.reset()
+    result = await client.step(action)
+```
+
+**Sync (via `.sync()` wrapper):**
+```python
+with EchoEnv(base_url="...").sync() as client:
+    result = client.reset()
+    result = client.step(action)
+```
+
+See the example scripts in the `examples/` directory.
+
+## CLI Commands
+
+The OpenEnv CLI provides commands to manage environments:
+
+- **`openenv init <env_name>`** - Initialize a new environment from a template
+- **`openenv push [--repo-id <repo_id>] [--private]`** - Deploy an environment to Hugging Face Spaces
+
+### Quick Start
+
+```bash
+# Create a new environment
+openenv init my_game_env
+
+# Deploy to Hugging Face (will prompt for login if needed)
+cd my_game_env
+openenv push
+```
+
+For detailed options: `openenv init --help` and `openenv push --help`.
+
+## Design Principles
+
+1. **Separation of Concerns**: Clear client-server boundaries
+2. **Type Safety**: Strongly-typed actions, observations, and state
+3. **Container Isolation**: Each environment runs in its own container
+4. **Simple APIs**: Minimal, intuitive interfaces
+
+## Development
+
+### Installation
+
+```bash
+# Clone the repository
+git clone https://github.com/meta-pytorch/OpenEnv.git
+cd OpenEnv
+
+# Install core package in editable mode
+pip install -e .
+# Or using uv (faster)
+uv pip install -e .
+```
+
+### Running Tests
+
+OpenEnv uses a modular dependency structure: the core package is minimal, and each environment has its own dependencies. This means some tests require environment-specific packages.
+
+```bash
+# Install pytest (required for running tests)
+uv pip install pytest
+
+# Run all tests (skips tests requiring uninstalled dependencies)
+PYTHONPATH=src:envs uv run pytest tests/ -v --tb=short
+
+# Run a specific test file
+PYTHONPATH=src:envs uv run pytest tests/envs/test_echo_environment.py -v
+```
+
+**To run environment-specific tests**, install that environment's dependencies:
+
+```bash
+# Example: Install coding_env with dev dependencies (includes smolagents + pytest)
+uv pip install -e "envs/coding_env[dev]"
+
+# Then run coding_env tests
+PYTHONPATH=src:envs uv run pytest tests/envs/test_python_codeact_rewards.py -v
+```
+
+Tests will be automatically skipped if their required dependencies aren't installed.
+
+## Requirements
+
+- Python 3.10+
+- Docker Desktop or Docker Engine
+- FastAPI >= 0.104.0
+- Uvicorn >= 0.24.0
+- Requests >= 2.25.0
+- Environment-specific dependencies (e.g., smolagents for coding_env)
+
+## Supported RL Tools
+
+The goal of this project is to support a broad set of open and closed tools to help standardize the agentic RL community. If you have a project that supports OpenEnv environments, please put up a PR to add your tool name along with a link to your documentation.
+
+### torchforge
+See the GRPO BlackJack training example: [`examples/grpo_blackjack/`](examples/grpo_blackjack/)
+
+### TRL
+See the [TRL example](https://huggingface.co/docs/trl/openenv) on how to integrate OpenEnv environments with GRPO training.
+
+### Unsloth
+See the 2048 game example based on gpt-oss: [Colab notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/OpenEnv_gpt_oss_(20B)_Reinforcement_Learning_2048_Game.ipynb)
+
+### SkyRL
+See the [SkyRL example](https://skyrl.readthedocs.io/en/latest/examples/openenv.html) on how to train on OpenEnv environments with SkyRL.
+ +### ART +See the [ART example](https://art.openpipe.ai/integrations/openenv-integration) on how OpenEnv environments can be used to train models with ART. + +### Oumi +See the [Oumi example](https://github.com/oumi-ai/oumi/blob/main/notebooks/Oumi%20-%20OpenEnv%20GRPO%20with%20trl.ipynb) on how OpenEnv environments can be used to train models with Oumi. + +## Example Environments + +| Environment | Description | +|---|---| +| [Echo Environment](envs/echo_env/README.md) | Echoes back messages with metadata. Ideal for testing HTTP server infrastructure, learning framework basics, and verifying container deployment. | +| [Coding Environment](envs/coding_env/README.md) | Sandboxed Python code execution via smolagents. Captures stdout/stderr/exit codes, supports persistent episode context, and provides detailed error handling. | +| [Chess Environment](envs/chess_env/README.md) | Chess RL environment with configurable opponents and full rules support. | +| [Atari Environment](envs/atari_env/README.md) | Classic Arcade Learning Environment tasks for RL benchmarking. | +| [FinRL Environment](envs/finrl_env/README.md) | Financial market simulations for algorithmic trading experiments. | + +> Browse the full catalog of community environments at [meta-pytorch.org/OpenEnv/environments](https://meta-pytorch.org/OpenEnv/environments/). + +## Community Support & Acknowledgments +This is an open and community-centric project. If you would like to add your name here, please put up a pull request and tag @jspisak for review. Ty!! + +Supporters include: Meta-PyTorch, Hugging Face, [Scaler AI Labs](https://scalerailabs.com), [Patronus AI](https://patronus.ai), [Surge AI](https://surgehq.ai), [LastMile AI](https://www.lastmileai.dev), Unsloth AI, Reflection AI, vLLM, SkyRL (UC-Berkeley), LightningAI, Axolotl AI, Stanford Scaling Intelligence Lab, Mithril, [OpenMined](https://openmined.org/), [Fleet AI](https://fleetai.com), [Halluminate](https://halluminate.ai/), [Turing](https://www.turing.com/), [Scale AI](https://scale.com/) .. + +And we'd also like to acknowledge the team at Farama Foundation as the OpenEnv API was heavily inspired by the work you all have done on Gymnasium. Cheers! 
+ +## License + +BSD 3-Clause License (see [LICENSE](./LICENSE) file) diff --git a/src/openenv_core.egg-info/SOURCES.txt b/src/openenv_core.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..f06cb83904e0d986b1d0010ab0eb4c8ca541e85e --- /dev/null +++ b/src/openenv_core.egg-info/SOURCES.txt @@ -0,0 +1,81 @@ +LICENSE +README.md +pyproject.toml +src/openenv/__init__.py +src/openenv/auto/__init__.py +src/openenv/auto/_discovery.py +src/openenv/auto/auto_action.py +src/openenv/auto/auto_env.py +src/openenv/cli/__init__.py +src/openenv/cli/__main__.py +src/openenv/cli/_cli_utils.py +src/openenv/cli/_validation.py +src/openenv/cli/commands/__init__.py +src/openenv/cli/commands/build.py +src/openenv/cli/commands/fork.py +src/openenv/cli/commands/init.py +src/openenv/cli/commands/push.py +src/openenv/cli/commands/serve.py +src/openenv/cli/commands/skills.py +src/openenv/cli/commands/validate.py +src/openenv/cli/templates/__init__.py +src/openenv/cli/templates/__pycache__/__init__.cpython-311.pyc +src/openenv/cli/templates/__pycache__/__init__.cpython-313.pyc +src/openenv/cli/templates/openenv_env/README.md +src/openenv/cli/templates/openenv_env/__init__.py +src/openenv/cli/templates/openenv_env/client.py +src/openenv/cli/templates/openenv_env/models.py +src/openenv/cli/templates/openenv_env/openenv.yaml +src/openenv/cli/templates/openenv_env/pyproject.toml +src/openenv/cli/templates/openenv_env/server/Dockerfile +src/openenv/cli/templates/openenv_env/server/__ENV_NAME___environment.py +src/openenv/cli/templates/openenv_env/server/__init__.py +src/openenv/cli/templates/openenv_env/server/app.py +src/openenv/cli/templates/openenv_env/server/requirements.txt +src/openenv/core/__init__.py +src/openenv/core/client_types.py +src/openenv/core/env_client.py +src/openenv/core/generic_client.py +src/openenv/core/llm_client.py +src/openenv/core/mcp_client.py +src/openenv/core/sync_client.py +src/openenv/core/utils.py +src/openenv/core/containers/__init__.py +src/openenv/core/containers/test_local_docker_provider.py +src/openenv/core/containers/runtime/__init__.py +src/openenv/core/containers/runtime/daytona_provider.py +src/openenv/core/containers/runtime/providers.py +src/openenv/core/containers/runtime/uv_provider.py +src/openenv/core/env_server/__init__.py +src/openenv/core/env_server/base_transforms.py +src/openenv/core/env_server/exceptions.py +src/openenv/core/env_server/gradio_theme.py +src/openenv/core/env_server/gradio_ui.py +src/openenv/core/env_server/http_server.py +src/openenv/core/env_server/interfaces.py +src/openenv/core/env_server/mcp_environment.py +src/openenv/core/env_server/mcp_types.py +src/openenv/core/env_server/route_config.py +src/openenv/core/env_server/serialization.py +src/openenv/core/env_server/types.py +src/openenv/core/env_server/web_interface.py +src/openenv/core/evals/__init__.py +src/openenv/core/evals/base.py +src/openenv/core/evals/inspect_harness.py +src/openenv/core/evals/types.py +src/openenv/core/rubrics/__init__.py +src/openenv/core/rubrics/base.py +src/openenv/core/rubrics/containers.py +src/openenv/core/rubrics/llm_judge.py +src/openenv/core/rubrics/trajectory.py +src/openenv/core/tools/__init__.py +src/openenv/core/tools/git_server_client.py +src/openenv/core/tools/local_python_executor.py +src/openenv_core/__init__.py +src/openenv_core.egg-info/PKG-INFO +src/openenv_core.egg-info/SOURCES.txt +src/openenv_core.egg-info/dependency_links.txt +src/openenv_core.egg-info/entry_points.txt +src/openenv_core.egg-info/requires.txt 
+src/openenv_core.egg-info/top_level.txt +tests/test_line_endings.py \ No newline at end of file diff --git a/src/openenv_core.egg-info/dependency_links.txt b/src/openenv_core.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/openenv_core.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/openenv_core.egg-info/entry_points.txt b/src/openenv_core.egg-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..a771c213d08379ccc0be28741bcfaccc3f64193a --- /dev/null +++ b/src/openenv_core.egg-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +openenv = openenv.cli.__main__:main diff --git a/src/openenv_core.egg-info/requires.txt b/src/openenv_core.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..caf0c5f6727bc58526e2a96a3f14cc28713c32c6 --- /dev/null +++ b/src/openenv_core.egg-info/requires.txt @@ -0,0 +1,56 @@ +fastapi>=0.104.0 +pydantic>=2.0.0 +uvicorn>=0.24.0 +requests>=2.25.0 +typer>=0.9.0 +rich>=13.0.0 +pyyaml>=6.0 +huggingface_hub>=0.20.0 +openai>=2.7.2 +tomli>=2.3.0 +tomli-w>=1.2.0 +websockets>=15.0.1 +fastmcp>=3.0.0 +gradio>=4.0.0 + +[all] +openenv-core[core] +openenv-core[cli] + +[cli] +typer>=0.9.0 +rich>=13.0.0 +pyyaml>=6.0 +huggingface_hub>=0.20.0 +openai>=2.7.2 +tomli>=2.3.0 +tomli-w>=1.2.0 + +[core] +fastapi>=0.104.0 +pydantic>=2.0.0 +uvicorn>=0.24.0 +requests>=2.25.0 +websockets>=15.0.1 + +[daytona] +daytona>=0.136.0 +pyyaml>=6.0 + +[docs] +sphinx==7.2.6 +pytorch-sphinx-theme2 +sphinxcontrib.katex==0.9.10 +docutils<0.21,>=0.18.1 +sphinx-design==0.6.1 +sphinxcontrib-mermaid==1.0.0 +myst-parser +sphinxext-opengraph +sphinx-sitemap==2.7.1 +sphinx-gallery>=0.14.0 +matplotlib +nest-asyncio +smolagents + +[inspect] +inspect-ai>=0.3.0 diff --git a/src/openenv_core.egg-info/top_level.txt b/src/openenv_core.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..e540ace7b2bb39dfb2445c526e67c391e8249487 --- /dev/null +++ b/src/openenv_core.egg-info/top_level.txt @@ -0,0 +1,2 @@ +openenv +openenv_core diff --git a/src/openenv_core/__init__.py b/src/openenv_core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8cde96644033202a9feaf604065600822d08663d --- /dev/null +++ b/src/openenv_core/__init__.py @@ -0,0 +1,54 @@ +""" +Compatibility shim for the historical ``openenv_core`` package. + +The core runtime now lives under ``openenv.core``. Importing from the old +package path will continue to work but emits a ``DeprecationWarning`` so +downstream users can migrate at their own pace. 
+""" + +from __future__ import annotations + +import importlib +import sys +import warnings +from types import ModuleType +from typing import Dict + +_TARGET_PREFIX = "openenv.core" +_TARGET_MODULE = importlib.import_module(_TARGET_PREFIX) + +warnings.warn( + "openenv_core is deprecated; import from openenv.core instead.", + DeprecationWarning, + stacklevel=2, +) + +__all__ = getattr(_TARGET_MODULE, "__all__", []) + + +def __getattr__(name: str): + return getattr(_TARGET_MODULE, name) + + +def __dir__(): + return sorted(set(dir(_TARGET_MODULE))) + + +def _alias(name: str) -> None: + target = f"{_TARGET_PREFIX}.{name}" + sys.modules[f"{__name__}.{name}"] = importlib.import_module(target) + + +for _child in ( + "client_types", + "containers", + "env_client", + "env_server", + "rubrics", + "tools", + "utils", +): + try: + _alias(_child) + except ModuleNotFoundError: # pragma: no cover - defensive + continue