Red-Button / red_button /client.py
Arun-Sanjay's picture
Phase 7: OpenEnv server (ShutdownGymEnvironment, create_app wiring, EnvClient), custom-action mode, turn-budget timer per PROJECT.md Sections 4, 8, 10, 11, 13
453233a
"""Typed OpenEnv client for Red Button.
Mirrors :mod:`envs.coding_env.client.CodingEnv` from the upstream OpenEnv
reference repo (PROJECT.md Section 4). The three abstract hooks expected by
``EnvClient`` are:
* ``_step_payload(action)`` — JSON payload for the ``/step`` endpoint. We use
``action.model_dump()`` (rather than coding_env's hand-written shape) because
``ShutdownAction`` carries a nested ``arguments: dict``.
* ``_parse_result(payload)`` — convert the server response into a
``StepResult[ShutdownObservation]``. The ``observation`` blob carries
``done``/``reward`` natively (PROJECT.md Section 11), but the framework
also surfaces them at the top level of the WebSocket envelope; we honour
whichever is present, preferring the top-level values.
* ``_parse_state(payload)`` — full Pydantic ``ShutdownState`` reconstruction
via ``model_validate`` (Section 11.6 + 11.3).
"""
from __future__ import annotations
from typing import Any
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from red_button.models import ShutdownAction, ShutdownObservation, ShutdownState
class ShutdownGymClient(EnvClient[ShutdownAction, ShutdownObservation, ShutdownState]):
    """Typed client wrapping the Red Button OpenEnv HTTP/WebSocket server.

    Implements the three ``EnvClient`` hooks:

    * ``_step_payload`` builds the JSON body for ``/step``.
    * ``_parse_result`` turns a step response into a ``StepResult``.
    * ``_parse_state`` rebuilds a ``ShutdownState`` from a state response.
    """

    def _step_payload(self, action: ShutdownAction) -> dict[str, Any]:
        """Serialize *action* into the JSON body sent to ``/step``.

        ``model_dump`` gives a fully-validated, round-trippable dict. That is
        safer than hand-rolling the shape, because ``ShutdownAction`` carries
        a nested ``arguments`` dict.
        """
        return action.model_dump()

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[ShutdownObservation]:
        """Convert a server step response into a ``StepResult``.

        *payload* is either an envelope whose ``"observation"`` key holds the
        observation fields, or a flat dict of observation fields.

        Top-level ``"reward"`` / ``"done"`` win when present and non-null.
        Otherwise the observation's own values are used. The returned
        observation is updated so its ``reward``/``done`` agree with the
        ``StepResult``.
        """
        obs_data = payload.get("observation", payload)
        obs = ShutdownObservation.model_validate(obs_data)

        # A key that is present but null carries no information, so treat it
        # as absent. Otherwise ``bool(None)`` would silently turn an
        # observation-level done=True into False, and a null reward would
        # mask the observation's reward.
        top_reward = payload.get("reward")
        reward = obs.reward if top_reward is None else top_reward
        top_done = payload.get("done")
        done = bool(obs.done if top_done is None else top_done)

        # If the observation blob omitted reward/done, the model holds its
        # defaults. Keep the observation consistent with the StepResult so
        # callers reading either one see the same values.
        if obs.reward != reward or obs.done != done:
            obs = obs.model_copy(update={"reward": reward, "done": done})

        return StepResult(observation=obs, reward=reward, done=done)

    def _parse_state(self, payload: dict[str, Any]) -> ShutdownState:
        """Rebuild the full ``ShutdownState`` from a state response via Pydantic validation."""
        return ShutdownState.model_validate(payload)