Spaces:
Running
Running
import csv | |
import json | |
import logging | |
from typing import Union | |
from agentreview.arena import Arena, TooManyInvalidActions | |
from agentreview.role_descriptions import get_reviewer_description | |
from agentreview.utility.utils import format_metareviews | |
from .agent import Player | |
from .config import ArenaConfig | |
from .environments import TimeStep, load_environment | |
from .paper_review_player import PaperExtractorPlayer, AreaChair, Reviewer | |
logger = logging.getLogger(__name__) | |
class PaperReviewArena(Arena): | |
"""Arena for the paper review environment. | |
""" | |
# PaperReviewArena.from_config | |
def from_config(cls, config: Union[str, ArenaConfig]): | |
"""Create an arena from a config.""" | |
# If config is a path, load the config | |
if isinstance(config, str): | |
config = ArenaConfig.load(config) | |
global_prompt = config.get("global_prompt", None) | |
# Create the players | |
players = [] | |
for player_config in config.players: | |
# Add public_prompt to the player config | |
if global_prompt is not None: | |
player_config["global_prompt"] = global_prompt | |
if player_config['name'].startswith("Paper Extractor"): | |
player = PaperExtractorPlayer.from_config(player_config) | |
elif player_config['name'].startswith("AC"): | |
player = AreaChair.from_config(player_config) | |
elif player_config['name'].startswith("Reviewer"): | |
player = Reviewer.from_config(player_config) | |
else: | |
player = Player.from_config(player_config) | |
players.append(player) | |
# Check that the player names are unique | |
player_names = [player.name for player in players] | |
assert len(player_names) == len( | |
set(player_names) | |
), f"Player names must be unique, current players: {[','.join(player_names)]}" | |
# Create the environment | |
config.environment[ | |
"player_names" | |
] = player_names # add the player names to the environment config | |
env = load_environment(config.environment) | |
return cls(players, env, global_prompt=global_prompt) | |
# PaperReviewArena.step() | |
def step(self) -> TimeStep: | |
"""Take a step in the game: one player takes an action and the environment updates.""" | |
# if self.environment.phase_index > 4 and self.args.task == "paper_review": | |
# logger.info("Finishing the simulation for Phase I - IV. Please run `python run_paper_decision_cli.py ` for " | |
# "Phase V. (AC makes decisions).") | |
# return | |
# | |
# elif self.environment.phase_index > 5 and self.args.task == "paper_decision": | |
# logger.info("Finishing the simulation for Phase V. (AC makes decisions).") | |
# return | |
player_name = self.environment.get_next_player() | |
player = self.name_to_player[player_name] # get the player object | |
observation = self.environment.get_observation( | |
player_name | |
) # get the observation for the player | |
timestep = None | |
# try to take an action for a few times | |
for i in range(self.invalid_actions_retry): | |
# Update reviewer description for rebuttal | |
if self.environment.phase_index == 3 and player.name.startswith("Reviewer"): | |
logging.info("Update reviewers' role_desc for Phase 3 (reviewer_ac_discussion)") | |
reviewer_index = int(player.name.split("Reviewer ")[1]) | |
# reviewer_index starts from 1, so we need to subtract 1 to get the index of the reviewer in the list | |
player.role_desc = get_reviewer_description(phase="reviewer_ac_discussion", | |
**self.environment.experiment_setting["players"][ | |
'Reviewer'][reviewer_index - 1]) | |
elif self.environment.phase_index == 5: # Phase 5 AC Makes Decisions | |
player.role_desc += format_metareviews(self.environment.metareviews, self.environment.paper_ids) | |
action = player(observation) # take an action | |
if self.environment.check_action(action, player_name): # action is valid | |
timestep = self.environment.step( | |
player_name, action | |
) # update the environment | |
break | |
else: # action is invalid | |
logging.warning(f"{player_name} made an invalid action {action}") | |
continue | |
if ( | |
timestep is None | |
): # if the player made invalid actions for too many times, terminate the game | |
warning_msg = f"{player_name} has made invalid actions for {self.invalid_actions_retry} times. Terminating the game." | |
logging.warning(warning_msg) | |
raise TooManyInvalidActions(warning_msg) | |
return timestep | |
def save_history(self, path: str): | |
""" | |
Save the history of the game to a file. | |
Supports csv and json formats. | |
""" | |
messages = self.environment.get_observation() | |
message_rows = [] | |
if path.endswith(".csv"): | |
header = [ | |
"agent_name", | |
"content", | |
"turn", | |
"timestamp", | |
"visible_to", | |
"msg_type", | |
] | |
for message in messages: | |
message_row = [ | |
message.agent_name, | |
message.content, | |
message.turn, | |
str(message.timestamp), | |
message.visible_to, | |
message.msg_type, | |
] | |
message_rows.append(message_row) | |
with open(path, "w") as f: | |
writer = csv.writer(f) | |
writer.writerow(header) | |
writer.writerows(message_rows) | |
elif path.endswith(".json"): | |
for message in messages: | |
message_row = { | |
"agent_name": message.agent_name, | |
"content": message.content, | |
"turn": message.turn, | |
"timestamp": str(message.timestamp), | |
"visible_to": message.visible_to, | |
"msg_type": message.msg_type, | |
} | |
message_rows.append(message_row) | |
with open(path, "w") as f: | |
json.dump({ | |
"experiment_setting": self.environment.experiment_setting, | |
"messages": message_rows, | |
}, f, indent=2) | |
else: | |
raise ValueError("Invalid file format") | |