import csv
import json
import logging
from typing import Union

from agentreview.arena import Arena, TooManyInvalidActions
from agentreview.role_descriptions import get_reviewer_description
from agentreview.utility.utils import format_metareviews
from .agent import Player
from .config import ArenaConfig
from .environments import TimeStep, load_environment
from .paper_review_player import PaperExtractorPlayer, AreaChair, Reviewer

logger = logging.getLogger(__name__)

# Column order shared by the CSV and JSON exports in ``save_history``.
_HISTORY_FIELDS = [
    "agent_name",
    "content",
    "turn",
    "timestamp",
    "visible_to",
    "msg_type",
]


class PaperReviewArena(Arena):
    """Arena for the paper review environment.

    Builds the role-specific players (paper extractor, area chair, reviewers)
    from a config, drives the environment one player action at a time, and
    exports the message history.
    """

    # PaperReviewArena.from_config
    @classmethod
    def from_config(cls, config: Union[str, ArenaConfig]):
        """Create an arena from a config.

        Args:
            config: Either an ``ArenaConfig`` or a path that
                ``ArenaConfig.load`` can read.

        Returns:
            A new arena holding one player per entry of ``config.players`` and
            the environment loaded from ``config.environment``.

        Raises:
            AssertionError: If two players share the same name.

        Note:
            The config is mutated: ``global_prompt`` is copied into every
            player config and ``player_names`` is written into the
            environment config.
        """
        # If config is a path, load the config
        if isinstance(config, str):
            config = ArenaConfig.load(config)

        global_prompt = config.get("global_prompt", None)

        # Create the players; the player class is selected by name prefix.
        players = []
        for player_config in config.players:
            # Add the global prompt to the player config
            if global_prompt is not None:
                player_config["global_prompt"] = global_prompt

            if player_config['name'].startswith("Paper Extractor"):
                player = PaperExtractorPlayer.from_config(player_config)
            elif player_config['name'].startswith("AC"):
                player = AreaChair.from_config(player_config)
            elif player_config['name'].startswith("Reviewer"):
                player = Reviewer.from_config(player_config)
            else:
                player = Player.from_config(player_config)

            players.append(player)

        # Check that the player names are unique
        player_names = [player.name for player in players]
        assert len(player_names) == len(
            set(player_names)
        ), f"Player names must be unique, current players: {[','.join(player_names)]}"

        # Create the environment
        config.environment[
            "player_names"
        ] = player_names  # add the player names to the environment config
        env = load_environment(config.environment)

        return cls(players, env, global_prompt=global_prompt)

    def _refresh_role_desc_for_phase(self, player) -> None:
        """Apply the phase-dependent update to ``player.role_desc``.

        * Phase 3, reviewers only: the role description is replaced with the
          ``reviewer_ac_discussion`` description built from this reviewer's
          entry in the experiment setting.
        * Phase 5, any player: the formatted metareviews are appended to the
          role description.

        NOTE(review): the phase-5 branch appends, so calling ``step`` more
        than once for the same player while the environment stays in phase 5
        appends the metareviews again -- confirm that phase 5 takes exactly
        one step per player.
        """
        if self.environment.phase_index == 3 and player.name.startswith("Reviewer"):
            logger.info("Update reviewers' role_desc for Phase 3 (reviewer_ac_discussion)")
            # The number in "Reviewer <k>" is 1-based, while the reviewer
            # settings list is 0-based.
            reviewer_index = int(player.name.split("Reviewer ")[1])
            player.role_desc = get_reviewer_description(
                phase="reviewer_ac_discussion",
                **self.environment.experiment_setting["players"]['Reviewer'][reviewer_index - 1],
            )
        elif self.environment.phase_index == 5:  # Phase 5 AC Makes Decisions
            player.role_desc += format_metareviews(
                self.environment.metareviews, self.environment.paper_ids
            )

    # PaperReviewArena.step()
    def step(self) -> TimeStep:
        """Take a step in the game: one player takes an action and the environment updates.

        The next player is asked for an action up to
        ``self.invalid_actions_retry`` times; the first action accepted by
        ``environment.check_action`` is applied to the environment.

        Returns:
            The ``TimeStep`` produced by the environment for the accepted
            action.

        Raises:
            TooManyInvalidActions: If no attempt produced a valid action.
        """
        # Disabled phase cut-off checks, kept for reference:
        # if self.environment.phase_index > 4 and self.args.task == "paper_review":
        #     logger.info("Finishing the simulation for Phase I - IV. Please run `python run_paper_decision_cli.py ` for "
        #                 "Phase V. (AC makes decisions).")
        #     return
        #
        # elif self.environment.phase_index > 5 and self.args.task == "paper_decision":
        #     logger.info("Finishing the simulation for Phase V. (AC makes decisions).")
        #     return

        player_name = self.environment.get_next_player()
        player = self.name_to_player[player_name]  # get the player object
        observation = self.environment.get_observation(
            player_name
        )  # get the observation for the player

        # Update the role description once per step, not once per retry:
        # the phase-5 update appends to role_desc, so repeating it on every
        # retry would duplicate the metareviews in the prompt.
        self._refresh_role_desc_for_phase(player)

        timestep = None

        # try to take an action for a few times
        for _ in range(self.invalid_actions_retry):
            action = player(observation)  # take an action

            if self.environment.check_action(action, player_name):  # action is valid
                timestep = self.environment.step(
                    player_name, action
                )  # update the environment
                break

            # action is invalid: log it and ask the player again
            logger.warning(f"{player_name} made an invalid action {action}")

        if timestep is None:
            # the player made invalid actions for too many times, terminate the game
            warning_msg = f"{player_name} has made invalid actions for {self.invalid_actions_retry} times. Terminating the game."
            logger.warning(warning_msg)
            raise TooManyInvalidActions(warning_msg)

        return timestep

    @staticmethod
    def _message_record(message) -> dict:
        """Return one message as a dict keyed in ``_HISTORY_FIELDS`` order.

        The timestamp is converted with ``str``; all other attributes are
        passed through unchanged.
        """
        return {
            "agent_name": message.agent_name,
            "content": message.content,
            "turn": message.turn,
            "timestamp": str(message.timestamp),
            "visible_to": message.visible_to,
            "msg_type": message.msg_type,
        }

    def save_history(self, path: str):
        """
        Save the history of the game to a file.
        Supports csv and json formats.

        The format is chosen from the suffix of ``path``:

        * ``.csv``: a header row followed by one row per message.
        * ``.json``: an object with the environment's ``experiment_setting``
          and the list of messages.

        Files are written as UTF-8.

        Args:
            path: Destination file path ending in ``.csv`` or ``.json``.

        Raises:
            ValueError: If ``path`` has neither suffix.
        """
        # Called without a player name; presumably this yields every message
        # regardless of visibility -- verify against the environment.
        messages = self.environment.get_observation()

        if path.endswith(".csv"):
            message_rows = [
                list(self._message_record(message).values()) for message in messages
            ]
            # newline="" is required by the csv module so that it controls
            # line endings itself (otherwise Windows gets blank rows).
            with open(path, "w", newline="", encoding="utf-8") as f:
                writer = csv.writer(f)
                writer.writerow(_HISTORY_FIELDS)
                writer.writerows(message_rows)
        elif path.endswith(".json"):
            message_rows = [self._message_record(message) for message in messages]
            with open(path, "w", encoding="utf-8") as f:
                json.dump({
                    "experiment_setting": self.environment.experiment_setting,
                    "messages": message_rows,
                }, f, indent=2)
        else:
            raise ValueError("Invalid file format")