| |
| |
| |
| |
| |
|
|
| """Slipstream Governance Environment Client.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import Dict |
|
|
| try: |
| from openenv.core.client_types import StepResult |
| from openenv.core.env_client import EnvClient |
| from .models import SlipstreamAction, SlipstreamObservation, SlipstreamState |
| except ImportError: |
| from openenv.core.client_types import StepResult |
| from openenv.core.env_client import EnvClient |
| from models import SlipstreamAction, SlipstreamObservation, SlipstreamState |
|
|
|
|
class SlipstreamGovEnv(EnvClient[SlipstreamAction, SlipstreamObservation, SlipstreamState]):
    """Client for the SlipstreamGov OpenEnv environment.

    Converts between the typed Slipstream models and the plain-dict payloads
    exchanged with the environment server. The ``_step_payload``,
    ``_parse_result`` and ``_parse_state`` hooks are presumably invoked by the
    ``EnvClient`` base class (not visible here -- confirm against openenv).

    Payload values may be an explicit JSON ``null`` (``None`` in Python) as
    well as absent. Fields with a natural empty/zero value therefore fall back
    to that value in both cases, so the model constructors never receive
    ``None`` for them.
    """

    def _step_payload(self, action: SlipstreamAction) -> Dict:
        """Serialize *action* into the request body for a step call.

        Only the action's ``message`` attribute is sent.
        """
        return {"message": action.message}

    def _parse_result(self, payload: Dict) -> StepResult[SlipstreamObservation]:
        """Build a ``StepResult`` from a step response payload.

        Args:
            payload: Response dict. Observation fields are read from its
                ``"observation"`` sub-dict (missing or null -> empty dict).
                ``"reward"`` and ``"done"`` are read from the top level and
                copied onto both the observation and the returned result.

        Returns:
            A ``StepResult`` wrapping the parsed ``SlipstreamObservation``.
        """
        obs_data = payload.get("observation") or {}

        # Read once so the observation and the StepResult can never disagree.
        reward = payload.get("reward")  # may legitimately be None
        # `or` (rather than a .get default) also maps an explicit null to False.
        done = payload.get("done") or False

        observation = SlipstreamObservation(
            task_prompt=obs_data.get("task_prompt"),
            parsed_slip=obs_data.get("parsed_slip"),
            expected_anchor=obs_data.get("expected_anchor"),
            predicted_anchor=obs_data.get("predicted_anchor"),
            # Guard against null the same way as the list/dict fields below.
            arg_overlap=obs_data.get("arg_overlap") or 0.0,
            violations=obs_data.get("violations") or [],
            metrics=obs_data.get("metrics") or {},
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata") or {},
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> SlipstreamState:
        """Build a ``SlipstreamState`` from a flat state payload.

        ``episode_id`` and ``scenario_id`` are passed through as-is (possibly
        ``None``). ``step_count`` and ``attack`` fall back to ``0`` and
        ``False`` when missing or null.
        """
        return SlipstreamState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count") or 0,
            scenario_id=payload.get("scenario_id"),
            attack=payload.get("attack") or False,
        )
|
|