MiniGridEnv / client.py
yashu2000's picture
Upload folder using huggingface_hub
a03a89b verified
"""Typed OpenEnv client for MiniGridEnv."""
from __future__ import annotations
from typing import Any, Dict
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
try:
from .env.models import MiniGridAction, MiniGridObservation, MiniGridState
except ImportError:
from env.models import MiniGridAction, MiniGridObservation, MiniGridState
class MiniGridEnvClient(EnvClient[MiniGridAction, MiniGridObservation, MiniGridState]):
    """WebSocket client for interacting with a MiniGridEnv server.

    Implements the serialization hooks of ``EnvClient``: building the step
    payload from a :class:`MiniGridAction`, and parsing step results and
    state snapshots from payload dicts. Parsing is lenient: missing fields
    fall back to defaults instead of raising.
    """

    @staticmethod
    def _section_from_payload(payload: Any, key: str) -> Dict[str, Any]:
        """Return the sub-dict stored under *key*, or the payload itself.

        The data may arrive nested (``{key: {...}, ...}``) or flat (fields
        directly at the top level); both shapes are accepted. A payload that
        is not a dict at all yields an empty dict.
        """
        # Type-check before calling ``.get`` so a non-dict payload degrades
        # to defaults instead of raising AttributeError.
        if not isinstance(payload, dict):
            return {}
        section = payload.get(key)
        return section if isinstance(section, dict) else payload

    def _step_payload(self, action: MiniGridAction) -> Dict[str, Any]:
        """Serialize *action* into the dict sent for a step.

        ``command`` is always included; ``thought`` is included only when it
        is truthy, so ``None`` and empty strings are omitted.
        """
        payload: Dict[str, Any] = {"command": action.command}
        if action.thought:
            payload["thought"] = action.thought
        return payload

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[MiniGridObservation]:
        """Build a :class:`StepResult` from a result payload.

        Args:
            payload: Either ``{"observation": {...}, "reward": ..., "done": ...}``
                or a flat dict holding the observation fields directly.

        Returns:
            A ``StepResult`` wrapping a :class:`MiniGridObservation`. Top-level
            ``reward``/``done`` take precedence over the same keys inside the
            observation; the resolved values are copied onto both the
            observation and the result.
        """
        top = payload if isinstance(payload, dict) else {}
        obs_data = self._section_from_payload(payload, "observation")

        done = bool(top.get("done", obs_data.get("done", False)))
        reward = top.get("reward", obs_data.get("reward"))

        observation = MiniGridObservation(
            text=obs_data.get("text", ""),
            mission=obs_data.get("mission", ""),
            step_idx=obs_data.get("step_idx", 0),
            steps_remaining=obs_data.get("steps_remaining", 0),
            max_steps=obs_data.get("max_steps", 1),
            # ``or`` also maps an explicit JSON null to an empty container.
            history=obs_data.get("history") or [],
            level_name=obs_data.get("level_name", ""),
            last_action=obs_data.get("last_action"),
            action_success=obs_data.get("action_success"),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata") or {},
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> MiniGridState:
        """Build a :class:`MiniGridState` from a state payload.

        Args:
            payload: Either ``{"state": {...}}`` or a flat dict holding the
                state fields directly.

        Returns:
            A ``MiniGridState``; any field absent from the payload gets the
            default shown below.
        """
        state_data = self._section_from_payload(payload, "state")
        return MiniGridState(
            episode_id=state_data.get("episode_id"),
            step_count=state_data.get("step_count", 0),
            level_name=state_data.get("level_name", ""),
            level_difficulty=state_data.get("level_difficulty", 0),
            completed=state_data.get("completed", False),
            truncated=state_data.get("truncated", False),
            total_reward=state_data.get("total_reward", 0.0),
            steps_taken=state_data.get("steps_taken", 0),
            optimal_steps=state_data.get("optimal_steps"),
            efficiency_ratio=state_data.get("efficiency_ratio"),
            valid_actions=state_data.get("valid_actions", 0),
            invalid_actions=state_data.get("invalid_actions", 0),
            # ``or`` also maps an explicit JSON null to an empty dict.
            action_distribution=state_data.get("action_distribution") or {},
        )
# Alias: ``MiniGridEnv`` refers to the same class object as ``MiniGridEnvClient``.
MiniGridEnv = MiniGridEnvClient