# Yiqiao Jin
# Initial Commit
# bdafe83
from typing import List, Union
from ..agent import SIGNAL_END_OF_CONVERSATION, Moderator
from ..config import AgentConfig, EnvironmentConfig
from ..message import Message, MessagePool
from .base import Environment, TimeStep
class Conversation(Environment):
    """
    Turn-based fully observable conversation environment.

    Next speaker order is either parallel or round-robin. The conversation
    state is kept in ``self.message_pool`` together with two counters: the
    current turn number and the index (into ``player_names``) of the player
    returned by ``get_next_player``.
    """

    type_name = "conversation"

    def __init__(self, player_names: List[str], parallel: bool = False, **kwargs):
        """
        Args:
            player_names: names of the players; ``get_next_player`` indexes
                into this list
            parallel: stored on the instance and used by ``step`` to decide
                when the turn counter advances
            **kwargs: forwarded unchanged to the base ``Environment``
        """
        super().__init__(player_names=player_names, parallel=parallel, **kwargs)

        self.parallel = parallel

        # The "state" of the environment is maintained by the message pool
        self.message_pool = MessagePool()

        self._current_turn = 0
        self._next_player_index = 0

    def reset(self):
        """Reset both counters and the message pool, and return the initial timestep.

        Returns:
            A non-terminal ``TimeStep`` with an empty observation and zero rewards.
        """
        self._current_turn = 0
        self._next_player_index = 0
        self.message_pool.reset()

        init_timestep = TimeStep(
            observation=[], reward=self.get_zero_rewards(), terminal=False
        )
        return init_timestep

    @property
    def phase_index(self):
        """Return the value stored by the ``phase_index`` setter."""
        # NOTE(review): ``_phase_index`` is not assigned anywhere in this class
        # other than the setter below; reading this before it has been set
        # raises AttributeError unless the base class initialises it — confirm.
        return self._phase_index

    @phase_index.setter
    def phase_index(self, value):
        self._phase_index = value

    def to_config(self) -> EnvironmentConfig:
        """Build an ``EnvironmentConfig`` from the type name, player names and parallel flag."""
        return EnvironmentConfig(
            env_type=self.type_name,
            player_names=self.player_names,
            parallel=self.parallel,
        )

    def print(self):
        """Delegate to the message pool's ``print``."""
        self.message_pool.print()

    def get_next_player(self) -> str:
        """Get the next player."""
        return self.player_names[self._next_player_index]

    def get_observation(self, player_name=None) -> List[Message]:
        """Get observation for the player.

        Args:
            player_name: when ``None``, every message in the pool is returned;
                otherwise only the messages the pool reports as visible to
                that player at the current turn.
        """
        if player_name is None:
            return self.message_pool.get_all_messages()
        else:
            return self.message_pool.get_visible_messages(
                player_name, turn=self._current_turn
            )

    def is_terminal(self) -> bool:
        """Check if the conversation is over.

        Returns:
            ``True`` if the last message in the pool starts with
            ``SIGNAL_END_OF_CONVERSATION``; ``False`` otherwise, including
            when the pool reports no last message.
        """
        last_message = self.message_pool.last_message
        # Guard against an empty pool (e.g. right after ``reset``); assumes
        # ``MessagePool.last_message`` is None in that case — TODO confirm.
        if last_message is None:
            return False
        # If the last message is the signal, then the conversation is over.
        # Returning the comparison result directly guarantees a real bool
        # (never an implicit None) for the non-terminal case.
        return last_message.content.startswith(SIGNAL_END_OF_CONVERSATION)

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take

        Returns:
            A ``TimeStep`` carrying all messages in the pool, zero rewards and
            the result of ``is_terminal``.
        """
        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        # Update the counters. In non-parallel mode the turn advances on every
        # step; in parallel mode it advances only on the step taken while
        # ``_next_player_index`` is still 0 (the check runs before the index
        # is advanced on the next line).
        # NOTE(review): ModeratedConversation.step performs the same check
        # *after* advancing the index — confirm which ordering is intended.
        if not self.parallel or self._next_player_index == 0:
            self._current_turn += 1
        self._next_player_index = (self._next_player_index + 1) % self.num_players

        timestep = TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=self.is_terminal(),
        )  # Return all the messages
        return timestep
class ModeratedConversation(Conversation):
    """
    Turn-based fully observable conversation environment.

    Next speaker order is either parallel or round-robin.
    Moderator is a special agent that can see all messages and can decide whether the conversation is over.
    """

    type_name = "moderated_conversation"

    def __init__(
        self,
        player_names: List[str],
        moderator: Union[Moderator, AgentConfig],
        parallel: bool = False,
        moderator_visibility="all",
        moderator_period=None,
        **kwargs,
    ):
        """
        Args:
            player_names: names of the players
            moderator: a ``Moderator`` instance, or an ``AgentConfig`` from
                which one is built via ``Moderator.from_config``
            parallel: forwarded to ``Conversation``; also selects the default
                ``moderator_period``
            moderator_visibility: passed as ``visible_to`` on every moderator
                message
            moderator_period: ``"turn"`` makes the moderator respond after
                every player step, ``"round"`` only on the step after which
                the next-player index wraps back to 0. When ``None`` it
                defaults to ``"round"`` if ``parallel`` else ``"turn"``.
                With any other value ``step`` never invokes the moderator.
            **kwargs: forwarded unchanged to ``Conversation``

        Raises:
            ValueError: if ``moderator`` is neither an ``AgentConfig`` nor a
                ``Moderator``.
        """
        super().__init__(player_names=player_names, parallel=parallel, **kwargs)

        if isinstance(moderator, AgentConfig):
            moderator_config = moderator
            moderator = Moderator.from_config(moderator_config)
        elif not isinstance(moderator, Moderator):
            raise ValueError(
                "moderator must be either an AgentConfig or a Moderator instance."
            )

        self.moderator = moderator
        self.moderator_visibility = moderator_visibility

        if moderator_period is None:
            if parallel:
                self.moderator_period = "round"
            else:
                self.moderator_period = "turn"
        else:
            self.moderator_period = moderator_period

    def to_config(self) -> EnvironmentConfig:
        """Build an ``EnvironmentConfig`` that also carries the moderator settings."""
        # This environment contains some special config arguments that need to be handled specially
        return EnvironmentConfig(
            env_type=self.type_name,
            player_names=self.player_names,
            parallel=self.parallel,
            moderator=self.moderator.to_config(),
            moderator_visibility=self.moderator_visibility,
            moderator_period=self.moderator_period,
        )

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take

        Returns:
            A ``TimeStep`` carrying all messages in the pool, zero rewards and
            a boolean ``terminal`` flag.
        """
        message = Message(
            agent_name=player_name, content=action, turn=self._current_turn
        )
        self.message_pool.append_message(message)

        # Round-robin order for the next player
        self._next_player_index = (self._next_player_index + 1) % self.num_players

        if self.moderator_period == "turn" or (
            self.moderator_period == "round" and self._next_player_index == 0
        ):
            # Moderator's turn
            moderator_history = self.message_pool.get_all_messages()
            moderator_response = self.moderator(moderator_history)
            moderator_message = Message(
                agent_name=self.moderator.name,
                content=moderator_response,
                turn=self._current_turn,
                visible_to=self.moderator_visibility,
            )
            self.message_pool.append_message(moderator_message)
            # ``is_terminal`` may yield None rather than False for the
            # non-terminal case, so coerce to keep ``terminal`` a real bool.
            # Short-circuiting is preserved: ``self.is_terminal()`` is only
            # consulted when the moderator does not end the conversation.
            terminal = bool(
                self.moderator.is_terminal(moderator_history) or self.is_terminal()
            )
        else:
            terminal = bool(self.is_terminal())

        # Update the counters: every step in non-parallel mode, otherwise only
        # once the next-player index has wrapped back to 0.
        if not self.parallel or self._next_player_index == 0:
            self._current_turn += 1

        timestep = TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=terminal,
        )  # Return all the messages
        return timestep