| | |
| | |
| | |
| | |
| | |
| |
|
| | """Server implementation for the generic TextArena environment.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import sys |
| | from typing import Any, Dict, Iterable, List, Optional |
| | from uuid import uuid4 |
| |
|
| | import nltk |
| |
|
| | from openenv.core.env_server.interfaces import Environment |
| |
|
| | try: |
| | |
| | from textarena_env.models import ( |
| | TextArenaAction, |
| | TextArenaMessage, |
| | TextArenaObservation, |
| | TextArenaState, |
| | ) |
| | from textarena_env.rewards import RewardProvider, build_reward_providers |
| | except ImportError: |
| | |
| | from models import ( |
| | TextArenaAction, |
| | TextArenaMessage, |
| | TextArenaObservation, |
| | TextArenaState, |
| | ) |
| | from rewards import RewardProvider, build_reward_providers |
| |
|
| |
|
| | _TEXTARENA_MODULE: Any | None = None |
| | _TEXTARENA_IMPORT_ERROR: Exception | None = None |
| |
|
| |
|
| | def _import_textarena() -> Any: |
| | """Import ``textarena`` lazily and cache the module reference.""" |
| |
|
| | global _TEXTARENA_MODULE, _TEXTARENA_IMPORT_ERROR |
| |
|
| | if _TEXTARENA_MODULE is not None: |
| | return _TEXTARENA_MODULE |
| |
|
| | if _TEXTARENA_IMPORT_ERROR is not None: |
| | raise _TEXTARENA_IMPORT_ERROR |
| |
|
| | if sys.version_info < (3, 10): |
| | _TEXTARENA_IMPORT_ERROR = RuntimeError( |
| | "TextArena environments require Python 3.10 or newer; " |
| | f"current interpreter is {sys.version_info.major}.{sys.version_info.minor}" |
| | ) |
| | raise _TEXTARENA_IMPORT_ERROR |
| |
|
| | try: |
| | import textarena as ta |
| | except Exception as exc: |
| | _TEXTARENA_IMPORT_ERROR = exc |
| | raise |
| |
|
| | _TEXTARENA_MODULE = ta |
| | return ta |
| |
|
| |
|
class TextArenaEnvironment(Environment):
    """Wrap any TextArena game behind the OpenEnv ``Environment`` API.

    The underlying game is created once in ``__init__`` through
    ``textarena.make`` and reused for every episode: ``reset`` starts a new
    episode on it and ``step`` forwards a single action message to it.
    """

    def __init__(
        self,
        env_id: str = "Wordle-v0",
        *,
        num_players: int = 1,
        max_turns: Optional[int] = None,
        download_nltk: bool = True,
        env_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Create the TextArena game and the bookkeeping state.

        Args:
            env_id: TextArena environment id handed to ``textarena.make``.
            num_players: Number of players passed to the game's ``reset``.
            max_turns: Recorded on the instance and on the state object; this
                class does not enforce it itself.
            download_nltk: When true, fetch the NLTK ``words`` and
                ``averaged_perceptron_tagger_eng`` resources before the game
                is created.
            env_kwargs: Extra keyword arguments forwarded to ``textarena.make``.

        Raises:
            RuntimeError: If the interpreter is older than Python 3.10.
            Exception: Whatever importing ``textarena`` raised.
        """
        super().__init__()

        ta = _import_textarena()

        if download_nltk:
            # NOTE(review): presumably required by word-based games such as
            # Wordle — confirm which TextArena games actually need these.
            nltk.download("words", quiet=True)
            nltk.download("averaged_perceptron_tagger_eng", quiet=True)

        self.env_id = env_id
        self.num_players = num_players
        self.max_turns = max_turns
        self._env_kwargs = env_kwargs or {}

        self._ta_env = ta.make(env_id=env_id, **self._env_kwargs)

        self._state = TextArenaState(
            env_id=env_id,
            num_players=num_players,
            max_turns=max_turns,
        )

        self._reward_providers: List[RewardProvider] = build_reward_providers(env_id)
        self._last_reward_signals: Dict[str, float] = {}

    # ------------------------------------------------------------------
    # Environment interface
    # ------------------------------------------------------------------
    def reset(self) -> TextArenaObservation:
        """Start a new episode and return its first observation.

        Returns:
            The initial observation, with ``reward`` set to ``0.0`` and
            ``done`` set to ``False``.
        """
        # Empty ``full_observations`` on every wrapper layer (followed through
        # ``.env``) and on the innermost env before resetting the game.
        # NOTE(review): presumably the wrappers keep observation history
        # across resets and this prevents it leaking between episodes —
        # confirm against TextArena's wrapper implementation.
        layer = self._ta_env
        while True:
            if hasattr(layer, "full_observations"):
                layer.full_observations = {}
            if not hasattr(layer, "env"):
                break
            layer = layer.env

        self._ta_env.reset(num_players=self.num_players)

        for provider in self._reward_providers:
            provider.reset()

        self._state.episode_id = str(uuid4())
        self._state.step_count = 0
        self._state.turn = 0
        self._state.last_reward = 0.0
        self._state.last_info = {}
        self._state.raw_state = self._snapshot_state()
        self._last_reward_signals = {}

        observation = self._build_observation()
        observation.reward = 0.0
        observation.done = False

        return observation

    def step(self, action: TextArenaAction) -> TextArenaObservation:
        """Apply one action to the game and return the resulting observation.

        Args:
            action: The action whose ``message`` is forwarded to the game.

        Returns:
            The observation after the step, carrying ``done``, ``reward`` and
            (when any provider produced them) ``reward_signals`` in both
            ``info`` and ``metadata``.

        Raises:
            TypeError: If ``action`` is not a ``TextArenaAction``.
        """
        if not isinstance(action, TextArenaAction):
            raise TypeError(f"Expected TextArenaAction, received {type(action)!r}")

        done, info = self._ta_env.step(action.message)

        self._state.step_count += 1
        # Prefer the game's own turn counter; otherwise advance ours by one.
        self._state.turn = getattr(self._ta_env.state, "turn", self._state.turn + 1)
        self._state.last_info = info or {}

        observation = self._build_observation()
        observation.done = done

        reward = self._extract_reward()
        observation.reward = reward
        self._state.last_reward = reward

        reward_signals = self._compute_reward_signals(action=action, observation=observation)
        # Must be stored before the snapshot below, which reads it.
        self._last_reward_signals = reward_signals
        if reward_signals:
            observation.info.setdefault("reward_signals", {}).update(reward_signals)
            observation.metadata.setdefault("reward_signals", {}).update(reward_signals)
            self._state.last_info = {
                **(self._state.last_info or {}),
                "reward_signals": reward_signals,
            }
        self._state.raw_state = self._snapshot_state()

        return observation

    @property
    def state(self) -> TextArenaState:
        """The mutable state object updated by ``reset`` and ``step``."""
        return self._state

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _build_observation(self) -> TextArenaObservation:
        """Translate the game's current observation into a ``TextArenaObservation``.

        The prompt is the newline-joined content of all ``PROMPT`` messages.
        When there are none, the first message is used on turn 0, and the
        environment id is used otherwise.
        """
        player_id, messages = self._ta_env.get_observation()
        ta_messages = self._convert_messages(messages)

        ta_state = self._ta_env.state
        current_turn = getattr(ta_state, "turn", 0)

        prompt_lines = [msg.content for msg in ta_messages if msg.category == "PROMPT"]
        if not prompt_lines:
            if current_turn == 0 and ta_messages:
                # No explicit prompt at the start: fall back to the first message.
                prompt_lines = [ta_messages[0].content]
            else:
                # Nothing better available: use the environment id as the prompt.
                prompt_lines = [self.env_id]

        prompt = "\n".join(prompt_lines).strip()

        # Copy so later mutations of ``info`` never touch the game's own dict;
        # tolerate a missing or ``None`` ``step_info``.
        info: Dict[str, Any] = dict(getattr(ta_state, "step_info", None) or {})

        return TextArenaObservation(
            prompt=prompt,
            messages=ta_messages,
            current_player_id=player_id,
            legal_players=self._legal_players(),
            info=info,
            metadata={
                "env_id": self.env_id,
                "turn": current_turn,
                "raw_messages": [
                    {
                        "sender_id": msg.sender_id,
                        "content": msg.content,
                        "category": msg.category,
                    }
                    for msg in ta_messages
                ],
            },
        )

    def _legal_players(self) -> List[int]:
        """Return the sorted non-negative integer keys of the game's ``role_mapping``."""
        role_mapping = getattr(self._ta_env.state, "role_mapping", {}) or {}
        return sorted(pid for pid in role_mapping if isinstance(pid, int) and pid >= 0)

    def _convert_messages(self, messages: Iterable[Any]) -> List[TextArenaMessage]:
        """Normalise raw game messages and merge adjacent ones from the same source.

        Each entry may be a ``(sender, content, category)`` tuple, a
        ``(sender, content)`` tuple (category ``"MESSAGE"``), or anything else
        (stringified, sender ``-1``, category ``"MESSAGE"``). Consecutive
        entries sharing both sender id and category are concatenated, without
        a separator, into a single message.
        """
        # Each group is (sender_id, category_name, [text, ...]) for one run of
        # adjacent entries with identical sender and category.
        groups: List[Any] = []

        for entry in messages:
            if isinstance(entry, tuple) and len(entry) == 3:
                sender, content, category = entry
            elif isinstance(entry, tuple) and len(entry) == 2:
                sender, content = entry
                category = "MESSAGE"
            else:
                sender, content, category = -1, str(entry), "MESSAGE"

            # Enum-like categories expose ``.name``; plain values are stringified.
            category_name = getattr(category, "name", str(category))
            sender_id = int(sender) if isinstance(sender, (int, float)) else -1
            text = str(content)

            if groups and groups[-1][0] == sender_id and groups[-1][1] == category_name:
                groups[-1][2].append(text)
            else:
                groups.append((sender_id, category_name, [text]))

        return [
            TextArenaMessage(
                sender_id=sender_id,
                content="".join(parts),
                category=category_name or "MESSAGE",
            )
            for sender_id, category_name, parts in groups
        ]

    def _extract_reward(self) -> float:
        """Read the scalar reward from the game's ``rewards`` mapping.

        Uses the entry for the state's ``current_player_id`` (default ``0``),
        then the entry for player ``0``, and finally ``0.0`` when ``rewards``
        is not a dict or has neither key.
        """
        rewards = getattr(self._ta_env.state, "rewards", None)
        if not isinstance(rewards, dict):
            return 0.0

        player_id = getattr(self._ta_env.state, "current_player_id", 0)
        if player_id in rewards:
            return float(rewards[player_id])
        if 0 in rewards:
            return float(rewards[0])
        return 0.0

    def _snapshot_state(self) -> Dict[str, Any]:
        """Collect selected attributes of the game state into a plain dict.

        Only ``logs`` is copied (into a new list); the other values are stored
        by reference. The most recent reward signals are included when present.
        """
        state = self._ta_env.state
        snapshot: Dict[str, Any] = {
            "turn": getattr(state, "turn", 0),
            "game_state": getattr(state, "game_state", {}),
            "logs": list(getattr(state, "logs", [])),
            "rewards": getattr(state, "rewards", None),
            "done": getattr(state, "done", False),
            "role_mapping": getattr(state, "role_mapping", {}),
            "game_info": getattr(state, "game_info", {}),
            "step_info": getattr(state, "step_info", {}),
        }
        if self._last_reward_signals:
            snapshot["reward_signals"] = dict(self._last_reward_signals)
        return snapshot

    def _compute_reward_signals(
        self, *, action: TextArenaAction, observation: TextArenaObservation
    ) -> Dict[str, float]:
        """Run every reward provider and merge their signals into one dict.

        Providers are best-effort: one that raises, or whose result cannot be
        converted to ``{str: float}``, is logged and skipped so that a faulty
        provider never aborts the step. A provider returning nothing
        contributes no signals. On key collisions, later providers overwrite
        earlier ones.

        Args:
            action: The action that was just applied.
            observation: The observation produced by that action.

        Returns:
            The merged signal mapping; empty when there are no providers.
        """
        if not self._reward_providers:
            return {}

        # Function-scope import keeps this block self-contained.
        import logging

        logger = logging.getLogger(__name__)

        aggregated: Dict[str, float] = {}
        for provider in self._reward_providers:
            try:
                result = provider.compute(action=action, observation=observation)
                # Convert inside the guarded region so a malformed result is
                # handled like a failure, and never half-applied.
                signals = {key: float(value) for key, value in (result or {}).items()}
            except Exception:
                logger.exception("Reward provider %r failed; skipping its signals", provider)
                continue
            aggregated.update(signals)
        return aggregated
| |
|