from typing import List, Union

from ..agent import SIGNAL_END_OF_CONVERSATION, Moderator
from ..config import AgentConfig, EnvironmentConfig
from ..message import Message, MessagePool
from .base import Environment, TimeStep


class Conversation(Environment):
    """
    Turn-based fully observable conversation environment.

    Next speaker order is either parallel or round-robin.
    """

    type_name = "conversation"

    def __init__(self, player_names: List[str], parallel: bool = False, **kwargs):
        """
        Create the environment.

        Args:
            player_names: names of the players, in speaking order
            parallel: whether the speaker order is parallel instead of round-robin
            **kwargs: forwarded unchanged to the base ``Environment`` constructor
        """
        super().__init__(player_names=player_names, parallel=parallel, **kwargs)

        self.parallel = parallel

        # The "state" of the environment is maintained by the message pool
        self.message_pool = MessagePool()

        self._current_turn = 0
        self._next_player_index = 0

    def reset(self) -> TimeStep:
        """
        Reset the turn counter, the next-player index and the message pool.

        Returns:
            The initial timestep: empty observation, zero rewards, not terminal.
        """
        self._current_turn = 0
        self._next_player_index = 0
        self.message_pool.reset()

        init_timestep = TimeStep(
            observation=[], reward=self.get_zero_rewards(), terminal=False
        )
        return init_timestep

    # NOTE(review): ``_phase_index`` is not assigned in this class's ``__init__``
    # or ``reset``; reading ``phase_index`` before it has been set raises
    # AttributeError unless the base class defines it -- confirm.
    @property
    def phase_index(self):
        """Return the value stored in ``_phase_index``."""
        return self._phase_index

    @phase_index.setter
    def phase_index(self, value):
        """Store ``value`` in ``_phase_index``."""
        self._phase_index = value

    def to_config(self) -> EnvironmentConfig:
        """Build an ``EnvironmentConfig`` from this environment's settings."""
        return EnvironmentConfig(
            env_type=self.type_name,
            player_names=self.player_names,
            parallel=self.parallel,
        )

    def print(self):
        """Print the message pool."""
        self.message_pool.print()

    def get_next_player(self) -> str:
        """Get the next player."""
        return self.player_names[self._next_player_index]

    def get_observation(self, player_name=None) -> List[Message]:
        """
        Get observation for the player.

        Args:
            player_name: the player to observe as; when ``None`` every message
                in the pool is returned

        Returns:
            All messages when ``player_name`` is ``None``; otherwise whatever
            the message pool reports as visible to that player at the current
            turn.
        """
        if player_name is None:
            return self.message_pool.get_all_messages()
        else:
            return self.message_pool.get_visible_messages(
                player_name, turn=self._current_turn
            )

    def is_terminal(self) -> bool:
        """
        Check if the conversation is over.

        Returns:
            ``True`` when the content of the last message starts with
            ``SIGNAL_END_OF_CONVERSATION``, ``False`` otherwise. Always a real
            ``bool`` so the value can be stored in ``TimeStep.terminal`` as-is.
        """
        last_message = self.message_pool.last_message
        # Defensive guard: presumably the pool yields None when it holds no
        # messages (TODO confirm); nothing has been said, so not terminal.
        if last_message is None:
            return False
        # If the last message is the signal, then the conversation is over
        return last_message.content.startswith(SIGNAL_END_OF_CONVERSATION)

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take

        Returns:
            A timestep holding all messages, zero rewards and the terminal flag.
        """
        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        # Update the counters
        # NOTE(review): the turn check reads the index of the player who just
        # acted, whereas ModeratedConversation.step checks it after advancing
        # the index -- confirm the difference is intended for parallel mode.
        if not self.parallel or self._next_player_index == 0:
            self._current_turn += 1
        self._next_player_index = (self._next_player_index + 1) % self.num_players

        timestep = TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=self.is_terminal(),
        )  # Return all the messages
        return timestep


class ModeratedConversation(Conversation):
    """
    Turn-based fully observable conversation environment.

    Next speaker order is either parallel or round-robin.
    Moderator is a special agent that can see all messages and can decide
    whether the conversation is over.
    """

    type_name = "moderated_conversation"

    def __init__(
        self,
        player_names: List[str],
        moderator: Union[Moderator, AgentConfig],
        parallel: bool = False,
        moderator_visibility="all",
        moderator_period=None,
        **kwargs,
    ):
        """
        Create the moderated environment.

        Args:
            player_names: names of the players, in speaking order
            moderator: a ``Moderator`` instance, or an ``AgentConfig`` from
                which one is built via ``Moderator.from_config``
            parallel: whether the speaker order is parallel instead of round-robin
            moderator_visibility: passed as ``visible_to`` on every moderator
                message
            moderator_period: ``"turn"`` makes the moderator speak after every
                player action, ``"round"`` only once the next-player index has
                wrapped back to 0; when ``None`` it defaults to ``"round"`` if
                ``parallel`` else ``"turn"``
            **kwargs: forwarded unchanged to ``Conversation.__init__``

        Raises:
            ValueError: if ``moderator`` is neither an ``AgentConfig`` nor a
                ``Moderator``.
        """
        super().__init__(player_names=player_names, parallel=parallel, **kwargs)

        if isinstance(moderator, AgentConfig):
            moderator_config = moderator
            moderator = Moderator.from_config(moderator_config)
        elif not isinstance(moderator, Moderator):
            raise ValueError(
                "moderator must be either an AgentConfig or a Moderator instance."
            )

        self.moderator = moderator
        self.moderator_visibility = moderator_visibility

        if moderator_period is None:
            if parallel:
                self.moderator_period = "round"
            else:
                self.moderator_period = "turn"
        else:
            self.moderator_period = moderator_period

    def to_config(self) -> EnvironmentConfig:
        """Build an ``EnvironmentConfig`` including the moderator settings."""
        # This environment contains some special config arguments that need to
        # be handled specially
        return EnvironmentConfig(
            env_type=self.type_name,
            player_names=self.player_names,
            parallel=self.parallel,
            moderator=self.moderator.to_config(),
            moderator_visibility=self.moderator_visibility,
            moderator_period=self.moderator_period,
        )

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take

        Returns:
            A timestep holding all messages, zero rewards and the terminal flag.
        """
        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        # Round-robin order for the next player
        self._next_player_index = (self._next_player_index + 1) % self.num_players

        if self.moderator_period == "turn" or (
            self.moderator_period == "round" and self._next_player_index == 0
        ):
            # Moderator's turn
            # The history is captured before the moderator's own reply is
            # appended, so the moderator's terminal check does not see it.
            moderator_history = self.message_pool.get_all_messages()
            moderator_response = self.moderator(moderator_history)
            moderator_message = Message(
                agent_name=self.moderator.name,
                content=moderator_response,
                turn=self._current_turn,
                visible_to=self.moderator_visibility,
            )
            self.message_pool.append_message(moderator_message)
            terminal = (
                self.moderator.is_terminal(moderator_history) or self.is_terminal()
            )
        else:
            terminal = self.is_terminal()

        # Update the counters
        if not self.parallel or self._next_player_index == 0:
            self._current_turn += 1

        timestep = TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=terminal,
        )  # Return all the messages
        return timestep