Yiqiao Jin
Initial Commit
bdafe83
raw
history blame
6.92 kB
import csv
import json
import logging
import uuid
from typing import Dict, List, Optional, Union

from .agent import Player
from .backends import Human
from .config import ArenaConfig
from .environments import Environment, TimeStep, load_environment
class TooManyInvalidActions(Exception):
    """Raised by ``Arena.step`` when a player exhausts its retries with invalid actions."""
class Arena:
    """Utility class that manages the game environment and players.

    An ``Arena`` holds a list of players and an environment. Each call to
    :meth:`step` asks the environment which player moves next, queries that
    player for an action, validates it with the environment, and applies it.
    """

    def __init__(
        self,
        players: List[Player],
        environment: Environment,
        args=None,
        global_prompt: Optional[str] = None,
        invalid_actions_retry: int = 5,
    ):
        """Create the arena and reset the environment.

        Args:
            players: Players taking part in the game.
            environment: Environment the players act in; it is reset here.
            args: Opaque run arguments, stored as ``self.args``. Not read by any
                method of this class. Optional so that :meth:`from_config`
                can build an arena without them.
            global_prompt: Prompt shared by all players; echoed back by
                :meth:`to_config`.
            invalid_actions_retry: How many actions a player may attempt in a
                single :meth:`step` before ``TooManyInvalidActions`` is raised.
        """
        self.players = players
        self.environment = environment
        self.global_prompt = global_prompt
        self.current_timestep = environment.reset()
        self.uuid = uuid.uuid4()  # unique id for this game instance
        self.invalid_actions_retry = invalid_actions_retry
        self.args = args

    @property
    def num_players(self):
        """Number of players, as reported by the environment."""
        return self.environment.num_players

    @property
    def name_to_player(self) -> Dict[str, Player]:
        """Map from player name to player object (rebuilt on every access)."""
        return {player.name: player for player in self.players}

    def reset(self) -> TimeStep:
        """Reset the environment and every player, and assign a fresh game uuid.

        Returns:
            The timestep returned by ``environment.reset()``.
        """
        self.current_timestep = self.environment.reset()
        for player in self.players:
            player.reset()
        self.uuid = uuid.uuid4()
        return self.current_timestep

    def step(self) -> TimeStep:
        """Take a step in the game: one player takes an action and the environment updates.

        The next player is queried up to ``invalid_actions_retry`` times with the
        same observation; the first action accepted by
        ``environment.check_action`` is applied via ``environment.step``.

        Returns:
            The timestep produced by ``environment.step``.

        Raises:
            TooManyInvalidActions: if no valid action was produced within the
                allowed number of attempts.
        """
        player_name = self.environment.get_next_player()
        player = self.name_to_player[player_name]
        observation = self.environment.get_observation(player_name)

        timestep = None
        for _ in range(self.invalid_actions_retry):
            action = player(observation)
            if self.environment.check_action(action, player_name):
                timestep = self.environment.step(player_name, action)
                break
            logging.warning("%s made an invalid action %s", player_name, action)

        if timestep is None:
            # Every attempt was rejected: abort the game rather than loop forever.
            warning_msg = f"{player_name} has made invalid actions for {self.invalid_actions_retry} times. Terminating the game."
            logging.warning(warning_msg)
            raise TooManyInvalidActions(warning_msg)
        return timestep

    def next_is_human(self):
        """Return True if the next player to act uses a ``Human`` backend."""
        player_name = self.environment.get_next_player()
        player = self.name_to_player[player_name]
        return isinstance(player.backend, Human)

    def run(self, num_steps: int = 1):
        """Run the game for at most ``num_steps`` steps, stopping early on a terminal timestep."""
        for _ in range(num_steps):
            timestep = self.step()
            if timestep.terminal:
                break

    @classmethod
    def from_config(cls, config: Union[str, ArenaConfig], args=None):
        """Create an arena from a config.

        Args:
            config: An ``ArenaConfig``, or a path that ``ArenaConfig.load``
                can read.
            args: Optional run arguments forwarded to the constructor.

        Raises:
            ValueError: if two players share the same name.

        NOTE: the passed config object is mutated: each player config receives
        ``global_prompt`` (when set) and the environment config receives
        ``player_names``.
        """
        if isinstance(config, str):
            config = ArenaConfig.load(config)

        global_prompt = config.get("global_prompt", None)

        players = []
        for player_config in config.players:
            # Propagate the arena-wide global prompt into each player's config.
            if global_prompt is not None:
                player_config["global_prompt"] = global_prompt
            players.append(Player.from_config(player_config))

        # Names are used as dictionary keys (see name_to_player), so they must
        # be unique. Explicit raise instead of assert so it survives `python -O`.
        player_names = [player.name for player in players]
        if len(player_names) != len(set(player_names)):
            raise ValueError("Player names must be unique")

        # The environment needs to know the player names.
        config.environment["player_names"] = player_names
        env = load_environment(config.environment)
        return cls(players, env, args=args, global_prompt=global_prompt)

    def to_config(self) -> ArenaConfig:
        """Convert the arena (players, environment, global prompt) to an ``ArenaConfig``."""
        return ArenaConfig(
            players=[player.to_config() for player in self.players],
            environment=self.environment.to_config(),
            global_prompt=self.global_prompt,
        )

    def launch_cli(self, max_steps: Optional[int] = None, interactive: bool = True):
        """Launch the command line interface for this arena."""
        # Imported lazily so the UI module is only loaded when the CLI is used.
        from agentreview.ui.cli import ArenaCLI

        cli = ArenaCLI(self)
        cli.launch(max_steps=max_steps, interactive=interactive)

    def save_config(self, path: str):
        """Save the arena config (see :meth:`to_config`) to ``path``."""
        config = self.to_config()
        config.save(path)

    @staticmethod
    def _message_to_row(message) -> dict:
        """Flatten one message into a dict keyed by the history column names."""
        return {
            "agent_name": message.agent_name,
            "content": message.content,
            "turn": message.turn,
            "timestamp": str(message.timestamp),
            "visible_to": message.visible_to,
            "msg_type": message.msg_type,
        }

    def save_history(self, path: str):
        """
        Save the history of the game to a file.
        Supports csv and json formats, selected by the extension of ``path``.

        Raises:
            ValueError: if ``path`` ends in neither ``.csv`` nor ``.json``.
        """
        # Validate the format up front so no work is done for a bad path.
        if not path.endswith((".csv", ".json")):
            raise ValueError("Invalid file format")

        # Called without a player name; presumably returns all messages — confirm
        # against Environment.get_observation.
        messages = self.environment.get_observation()
        message_rows = [self._message_to_row(message) for message in messages]

        if path.endswith(".csv"):
            header = [
                "agent_name",
                "content",
                "turn",
                "timestamp",
                "visible_to",
                "msg_type",
            ]
            # newline="" is required by the csv module to avoid blank rows on
            # Windows; explicit UTF-8 so non-ASCII content never fails to encode.
            with open(path, "w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=header)
                writer.writeheader()
                writer.writerows(message_rows)
        else:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(message_rows, f, indent=2)