"""Arena: coordinates the players and the game environment for one game."""
import csv
import json
import logging
import uuid
from typing import Dict, List, Optional, Union

from .agent import Player
from .backends import Human
from .config import ArenaConfig
from .environments import Environment, TimeStep, load_environment
class TooManyInvalidActions(Exception):
    """Signals that a player used up all of its retries on invalid actions."""
class Arena:
    """Utility class that manages the game environment and players.

    An arena owns a list of players and one environment. Each call to
    :meth:`step` asks the environment whose turn it is, lets that player act
    on its observation, and feeds a valid action back into the environment.
    """

    # Column order used when exporting the message history.
    _HISTORY_FIELDS = (
        "agent_name",
        "content",
        "turn",
        "timestamp",
        "visible_to",
        "msg_type",
    )

    def __init__(
        self,
        players: List[Player],
        environment: Environment,
        args,
        global_prompt: Optional[str] = None,
    ):
        """Store the players and environment, then reset the environment.

        Args:
            players: Players taking part in the game.
            environment: Environment the players act in; ``reset()`` is
                called on it immediately.
            args: Opaque value stored on ``self.args``; it is not used
                anywhere else in this class.
            global_prompt: Optional prompt stored on the arena and written
                back out by :meth:`to_config`.
        """
        self.players = players
        self.environment = environment
        self.global_prompt = global_prompt
        self.current_timestep = environment.reset()
        self.uuid = uuid.uuid4()  # unique id for this game
        self.invalid_actions_retry = 5  # attempts a player gets per step
        self.args = args

    # NOTE(review): kept as a plain method because nothing in this file uses
    # it as an attribute; confirm against callers whether it was meant to be
    # a property like ``name_to_player``.
    def num_players(self):
        """Return ``num_players`` as reported by the environment."""
        return self.environment.num_players

    @property
    def name_to_player(self) -> Dict[str, Player]:
        """Mapping from player name to player object.

        Exposed as a property because :meth:`step` and :meth:`next_is_human`
        subscript it directly. The dict is rebuilt on every access.
        """
        return {player.name: player for player in self.players}

    def reset(self) -> TimeStep:
        """Reset the environment and every player, and assign a fresh uuid.

        Returns:
            The timestep returned by the environment's ``reset()``.
        """
        self.current_timestep = self.environment.reset()
        for player in self.players:
            player.reset()
        self.uuid = uuid.uuid4()
        return self.current_timestep

    def step(self) -> TimeStep:
        """Take a step in the game: one player acts and the environment updates.

        The next player is asked for an action up to
        ``self.invalid_actions_retry`` times; the first action accepted by
        ``environment.check_action`` is applied.

        Returns:
            The timestep returned by ``environment.step`` for the accepted
            action.

        Raises:
            TooManyInvalidActions: If none of the attempts produced a valid
                action.
        """
        player_name = self.environment.get_next_player()
        player = self.name_to_player[player_name]
        observation = self.environment.get_observation(player_name)

        for _ in range(self.invalid_actions_retry):
            action = player(observation)
            if self.environment.check_action(action, player_name):
                return self.environment.step(player_name, action)
            logging.warning(f"{player_name} made an invalid action {action}")

        # Every attempt was rejected: terminate the game.
        warning_msg = f"{player_name} has made invalid actions for {self.invalid_actions_retry} times. Terminating the game."
        logging.warning(warning_msg)
        raise TooManyInvalidActions(warning_msg)

    def next_is_human(self):
        """Return True if the next player's backend is a ``Human`` instance."""
        player_name = self.environment.get_next_player()
        player = self.name_to_player[player_name]
        return isinstance(player.backend, Human)

    def run(self, num_steps: int = 1):
        """Run the game for at most ``num_steps`` steps.

        Stops early as soon as a step returns a terminal timestep.
        """
        for _ in range(num_steps):
            timestep = self.step()
            if timestep.terminal:
                break

    @classmethod
    def from_config(cls, config: Union[str, ArenaConfig], args=None):
        """Create an arena from a config.

        Args:
            config: An ``ArenaConfig``, or a path string that is loaded with
                ``ArenaConfig.load``.
            args: Forwarded to the constructor's ``args`` parameter.
                Defaults to ``None`` so existing one-argument calls keep
                working.

        Returns:
            A new instance of ``cls``.

        Raises:
            AssertionError: If two players share the same name.
        """
        if isinstance(config, str):
            config = ArenaConfig.load(config)

        global_prompt = config.get("global_prompt", None)

        players = []
        for player_config in config.players:
            # Propagate the arena-wide prompt into each player's config.
            if global_prompt is not None:
                player_config["global_prompt"] = global_prompt
            players.append(Player.from_config(player_config))

        player_names = [player.name for player in players]
        assert len(player_names) == len(
            set(player_names)
        ), "Player names must be unique"

        # The environment config needs the player names before it is loaded.
        config.environment["player_names"] = player_names
        env = load_environment(config.environment)

        return cls(players, env, args=args, global_prompt=global_prompt)

    def to_config(self) -> ArenaConfig:
        """Convert the arena to an ``ArenaConfig``."""
        return ArenaConfig(
            players=[player.to_config() for player in self.players],
            environment=self.environment.to_config(),
            global_prompt=self.global_prompt,
        )

    def launch_cli(self, max_steps: int = None, interactive: bool = True):
        """Launch the command line interface for this arena."""
        # Imported here rather than at module level, as in the original code.
        from agentreview.ui.cli import ArenaCLI

        cli = ArenaCLI(self)
        cli.launch(max_steps=max_steps, interactive=interactive)

    def save_config(self, path: str):
        """Save the arena's config to ``path``."""
        config = self.to_config()
        config.save(path)

    def save_history(self, path: str):
        """Save the history of the game to a file.

        Supports csv and json formats, selected by the suffix of ``path``.

        Raises:
            ValueError: If ``path`` ends with neither ``.csv`` nor ``.json``.
        """
        # Called without a player name; presumably this yields the full
        # message history -- confirm against the Environment implementation.
        messages = self.environment.get_observation()

        records = [
            {
                "agent_name": message.agent_name,
                "content": message.content,
                "turn": message.turn,
                "timestamp": str(message.timestamp),
                "visible_to": message.visible_to,
                "msg_type": message.msg_type,
            }
            for message in messages
        ]

        if path.endswith(".csv"):
            # newline="" is required by the csv module so that it controls
            # line endings itself (otherwise blank rows appear on Windows).
            with open(path, "w", newline="", encoding="utf-8") as f:
                writer = csv.writer(f)
                writer.writerow(self._HISTORY_FIELDS)
                writer.writerows(
                    [record[field] for field in self._HISTORY_FIELDS]
                    for record in records
                )
        elif path.endswith(".json"):
            with open(path, "w", encoding="utf-8") as f:
                json.dump(records, f, indent=2)
        else:
            raise ValueError("Invalid file format")