Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Code Fixer Environment Client.""" | |
| import asyncio | |
| import inspect | |
| import logging | |
| from typing import Dict | |
| from openenv.core import EnvClient | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_server.types import State | |
| from rl_code_fix_env.models import CodeFixerAction, CodeFixerObservation | |
# Module-level logger shared by the client below (reconnect warnings, diagnostics).
log = logging.getLogger(name=__name__)
class CodeFixerEnv(
    EnvClient[CodeFixerAction, CodeFixerObservation, State]
):
    """
    Client for the Code Fixer Environment.

    This client maintains a persistent WebSocket connection to the environment server,
    enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the server.

    Base-class calls may return coroutines; this wrapper drives them to
    completion on a private event loop so ``reset``/``step``/``close`` can be
    called synchronously. ``reset`` and ``step`` reconnect and retry once when
    a failure looks like a dropped or timed-out WebSocket.

    Example:
        >>> # Connect to a running server
        >>> with CodeFixerEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.code)
        ...
        ...     result = client.step(CodeFixerAction(type="run_tests"))
        ...     print(result.observation.test_score)

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = CodeFixerEnv.from_docker_image("code_fixer-env:latest")
        >>> try:
        ...     result = client.reset()
        ...     result = client.step(CodeFixerAction(type="run_tests"))
        ... finally:
        ...     client.close()
    """

    # Lower-cased substrings that mark an error (message or class name) as a
    # dropped / timed-out WebSocket worth one reconnect-and-retry.
    _RECONNECT_MARKERS = (
        "1011",
        "1006",
        "keepalive",
        "timed out",
        "closed",
        "close frame",
        "connection closed",
        "connectionclosed",
        "websocket",
    )

    # Exception types treated as reconnectable regardless of their message;
    # needed because e.g. str(TimeoutError()) is the empty string.
    _RECONNECT_EXC_TYPES = (TimeoutError, asyncio.TimeoutError, ConnectionError)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Private event loop used to run coroutines returned by the base class.
        self._loop = asyncio.new_event_loop()
        # Store init args for reconnection
        self._init_args = args
        self._init_kwargs = kwargs

    def _run_sync(self, result):
        """
        Run coroutine results on this client's dedicated event loop.

        Args:
            result: Return value of a base-class call; a coroutine or a plain value.

        Returns:
            The coroutine's result, or ``result`` unchanged if it is not a coroutine.

        Raises:
            RuntimeError: If ``result`` is a coroutine and the loop is already closed.
        """
        if not inspect.iscoroutine(result):
            return result
        if self._loop.is_closed():
            # Close the coroutine so it does not emit a "never awaited"
            # warning, then fail exactly as run_until_complete would.
            result.close()
            raise RuntimeError("Event loop is closed")
        return self._loop.run_until_complete(result)

    def _reconnect(self) -> None:
        """
        Tear down the dead event loop and WebSocket connection, then
        re-initialise so the next call works cleanly.

        Called automatically by reset() and step() when
        ``_is_reconnectable_ws_error`` classifies a failure as a dropped or
        timed-out connection.
        """
        log.warning("[CodeFixerEnv] WebSocket timed out reconnecting...")
        # Best-effort close of the old connection: it is usually already dead,
        # so failures are expected here and only logged at debug level.
        try:
            self._run_sync(super().close())
        except Exception:
            log.debug(
                "[CodeFixerEnv] Ignoring error while closing stale connection",
                exc_info=True,
            )
        if not self._loop.is_closed():
            self._loop.close()
        # Re-initialise: fresh loop + fresh base-class state
        self._loop = asyncio.new_event_loop()
        super().__init__(*self._init_args, **self._init_kwargs)
        log.warning("[CodeFixerEnv] Reconnected successfully.")

    @classmethod
    def _is_reconnectable_ws_error(cls, exc: Exception) -> bool:
        """
        Heuristically decide whether ``exc`` looks like a dropped WebSocket.

        Args:
            exc: The exception raised by a base-class call.

        Returns:
            True if ``exc`` is one of ``_RECONNECT_EXC_TYPES``, or if its
            lower-cased class name or message contains any of
            ``_RECONNECT_MARKERS``; False otherwise.
        """
        if isinstance(exc, cls._RECONNECT_EXC_TYPES):
            return True
        # Include the class name so terse/empty messages from exceptions such
        # as ConnectionClosed* still match.
        text = f"{type(exc).__name__} {exc}".lower()
        return any(marker in text for marker in cls._RECONNECT_MARKERS)

    def reset(self, **kwargs):
        """
        Reset the environment; auto-reconnects once if the WebSocket died.

        Args:
            **kwargs: Optional keyword arguments forwarded unchanged to the
                base-class ``reset``.

        Raises:
            Exception: Any non-reconnectable error, or the error from the single
                retry after reconnecting.
        """
        try:
            return self._run_sync(super().reset(**kwargs))
        except Exception as exc:
            if not self._is_reconnectable_ws_error(exc):
                raise
            self._reconnect()
            return self._run_sync(super().reset(**kwargs))  # one retry

    def step(self, action: CodeFixerAction, **kwargs):
        """
        Execute a step; auto-reconnects once if the WebSocket died.

        Args:
            action: The action to send to the server.
            **kwargs: Optional keyword arguments forwarded unchanged to the
                base-class ``step``.

        Raises:
            Exception: Any non-reconnectable error, or the error from the single
                retry after reconnecting.
        """
        try:
            return self._run_sync(super().step(action, **kwargs))
        except Exception as exc:
            if not self._is_reconnectable_ws_error(exc):
                raise
            self._reconnect()
            # NOTE(review): after reconnecting, the server presumably allocates
            # a fresh session (see class docstring), so this retried step may
            # run against an environment that was never reset. Confirm whether
            # the server tolerates this.
            return self._run_sync(super().step(action, **kwargs))  # one retry

    def close(self):
        """
        Close client resources and the dedicated event loop safely.

        Idempotent: once the loop has been closed, further calls do nothing.
        """
        if self._loop.is_closed():
            return
        try:
            self._run_sync(super().close())
        finally:
            self._loop.close()

    def _step_payload(self, action: CodeFixerAction) -> Dict:
        """
        Convert CodeFixerAction to JSON payload for step message.

        Args:
            action: CodeFixerAction instance

        Returns:
            Dictionary representation suitable for JSON encoding
        """
        return {
            "type": action.type,
            "payload": action.payload,
        }

    def _parse_result(self, payload: Dict) -> StepResult[CodeFixerObservation]:
        """
        Parse server response into StepResult[CodeFixerObservation].

        Observation-level ``done``/``reward`` fall back to the top-level
        payload values when the observation omits them. A missing or null
        ``observation`` or ``test_score`` is treated as empty / 0.0.

        Args:
            payload: JSON response data from server

        Returns:
            StepResult with CodeFixerObservation
        """
        # `or {}` also covers an explicit "observation": null from the server.
        obs_data = payload.get("observation") or {}
        raw_score = obs_data.get("test_score")
        observation = CodeFixerObservation(
            code=obs_data.get("code", ""),
            logs=obs_data.get("logs"),
            # float(None) would raise TypeError; default a null score to 0.0.
            test_score=float(raw_score) if raw_score is not None else 0.0,
            total_tests=obs_data.get("total_tests", 1),
            steps=obs_data.get("steps", 0),
            done=obs_data.get("done", payload.get("done", False)),
            reward=obs_data.get("reward", payload.get("reward")),
        )
        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse server response into State object.

        Args:
            payload: JSON response from state request

        Returns:
            State object with episode_id and step_count
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )