Spaces:
Running
Running
from typing import List, Union | |
from ..agent import SIGNAL_END_OF_CONVERSATION, Moderator | |
from ..config import AgentConfig, EnvironmentConfig | |
from ..message import Message, MessagePool | |
from .base import Environment, TimeStep | |
class Conversation(Environment):
    """
    Turn-based fully observable conversation environment.

    Next speaker order is either parallel or round-robin. The conversation
    itself is stored in a message pool; this class only tracks the current
    turn number and the index of the player who speaks next.
    """

    type_name = "conversation"

    def __init__(self, player_names: List[str], parallel: bool = False, **kwargs):
        """
        Create the environment.

        Args:
            player_names: names of the players, in speaking order
            parallel: if True, the turn counter only advances once per full
                round of players instead of after every message
            **kwargs: forwarded unchanged to the base ``Environment``
        """
        super().__init__(player_names=player_names, parallel=parallel, **kwargs)
        self.parallel = parallel

        # The "state" of the environment is maintained by the message pool
        self.message_pool = MessagePool()

        self._current_turn = 0
        self._next_player_index = 0
        # Backing field for the ``phase_index`` property, so that reading the
        # property before it is assigned does not raise AttributeError.
        # NOTE(review): 0 is chosen to match the other counters -- confirm.
        self._phase_index = 0

    def reset(self):
        """
        Restart the conversation.

        Zeroes the turn and next-player counters, resets the message pool and
        returns an initial, non-terminal ``TimeStep`` with an empty
        observation and zero rewards.
        """
        self._current_turn = 0
        self._next_player_index = 0
        self.message_pool.reset()

        return TimeStep(
            observation=[], reward=self.get_zero_rewards(), terminal=False
        )

    @property
    def phase_index(self):
        """Phase index stored on this environment."""
        return self._phase_index

    @phase_index.setter
    def phase_index(self, value):
        self._phase_index = value

    def to_config(self) -> EnvironmentConfig:
        """Serialize the constructor arguments into an ``EnvironmentConfig``."""
        return EnvironmentConfig(
            env_type=self.type_name,
            player_names=self.player_names,
            parallel=self.parallel,
        )

    def print(self):
        """Delegate printing to the message pool."""
        self.message_pool.print()

    def get_next_player(self) -> str:
        """Get the name of the player who is expected to act next."""
        return self.player_names[self._next_player_index]

    def get_observation(self, player_name=None) -> List[Message]:
        """
        Get observation for the player.

        Args:
            player_name: the player to observe for; when None, every message
                in the pool is returned

        Returns:
            All messages when no player is given, otherwise whatever the
            message pool reports as visible to that player at the current turn.
        """
        if player_name is None:
            return self.message_pool.get_all_messages()
        return self.message_pool.get_visible_messages(
            player_name, turn=self._current_turn
        )

    def is_terminal(self) -> bool:
        """
        Check if the conversation is over.

        The conversation is over when the content of the last message starts
        with the end-of-conversation signal. Always returns a real bool.
        """
        last_message = self.message_pool.last_message
        # Guard against an empty pool.
        # NOTE(review): assumes the pool reports "no message yet" as None -- confirm.
        if last_message is None:
            return False
        return bool(last_message.content.startswith(SIGNAL_END_OF_CONVERSATION))

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take

        Returns:
            A ``TimeStep`` holding all messages, zero rewards and the terminal flag.
        """
        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        # Round-robin order for the next player. This must happen BEFORE the
        # turn check below so that "index == 0" means "the round just
        # finished" (same ordering as ModeratedConversation.step).
        self._next_player_index = (self._next_player_index + 1) % self.num_players

        # Sequential mode: every message is its own turn.
        # Parallel mode: the turn advances only after the last player of the
        # round has spoken, so every message of a round carries the same turn.
        if not self.parallel or self._next_player_index == 0:
            self._current_turn += 1

        return TimeStep(
            observation=self.get_observation(),  # Return all the messages
            reward=self.get_zero_rewards(),
            terminal=self.is_terminal(),
        )
class ModeratedConversation(Conversation):
    """
    Turn-based fully observable conversation environment.

    Next speaker order is either parallel or round-robin.
    Moderator is a special agent that can see all messages and can decide
    whether the conversation is over.
    """

    type_name = "moderated_conversation"

    def __init__(
        self,
        player_names: List[str],
        moderator: Union[Moderator, AgentConfig],
        parallel: bool = False,
        moderator_visibility="all",
        moderator_period=None,
        **kwargs,
    ):
        """
        Create the environment.

        Args:
            player_names: names of the players, in speaking order
            moderator: a ``Moderator`` instance, or an ``AgentConfig`` from
                which one is built via ``Moderator.from_config``
            parallel: forwarded to ``Conversation``
            moderator_visibility: passed as ``visible_to`` on every moderator
                message
            moderator_period: when the moderator speaks; ``"turn"`` means
                after every player message, ``"round"`` means after the last
                player of a round. Defaults to ``"round"`` when ``parallel``
                is True and ``"turn"`` otherwise.
            **kwargs: forwarded unchanged to ``Conversation``

        Raises:
            ValueError: if ``moderator`` is neither an ``AgentConfig`` nor a
                ``Moderator``
        """
        super().__init__(player_names=player_names, parallel=parallel, **kwargs)

        if isinstance(moderator, AgentConfig):
            moderator = Moderator.from_config(moderator)
        elif not isinstance(moderator, Moderator):
            raise ValueError(
                "moderator must be either an AgentConfig or a Moderator instance."
            )

        self.moderator = moderator
        self.moderator_visibility = moderator_visibility

        if moderator_period is None:
            moderator_period = "round" if parallel else "turn"
        self.moderator_period = moderator_period

    def to_config(self) -> EnvironmentConfig:
        """Serialize the constructor arguments, including the moderator's own config."""
        # This environment contains some special config arguments that need
        # to be handled specially
        return EnvironmentConfig(
            env_type=self.type_name,
            player_names=self.player_names,
            parallel=self.parallel,
            moderator=self.moderator.to_config(),
            moderator_visibility=self.moderator_visibility,
            moderator_period=self.moderator_period,
        )

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take

        Returns:
            A ``TimeStep`` holding all messages, zero rewards and a boolean
            terminal flag.
        """
        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        # Round-robin order for the next player
        self._next_player_index = (self._next_player_index + 1) % self.num_players
        round_finished = self._next_player_index == 0

        moderator_speaks = self.moderator_period == "turn" or (
            self.moderator_period == "round" and round_finished
        )
        if moderator_speaks:
            # Moderator's turn: it is given every message in the pool
            moderator_history = self.message_pool.get_all_messages()
            moderator_response = self.moderator(moderator_history)
            moderator_message = Message(
                agent_name=self.moderator.name,
                content=moderator_response,
                turn=self._current_turn,
                visible_to=self.moderator_visibility,
            )
            self.message_pool.append_message(moderator_message)
            # The moderator is asked first; the end-of-conversation signal
            # check only runs if the moderator did not end the conversation.
            terminal = (
                self.moderator.is_terminal(moderator_history) or self.is_terminal()
            )
        else:
            terminal = self.is_terminal()

        # Update the counters
        if not self.parallel or round_finished:
            self._current_turn += 1

        return TimeStep(
            observation=self.get_observation(),  # Return all the messages
            reward=self.get_zero_rewards(),
            # is_terminal() may yield None rather than False; normalise so
            # TimeStep always receives a real bool.
            terminal=bool(terminal),
        )