Spaces:
Running
Running
File size: 6,802 Bytes
bdafe83 53709ed bdafe83 a06e98d bdafe83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
import csv
import json
import logging
from typing import Union
from agentreview.arena import Arena, TooManyInvalidActions
from agentreview.role_descriptions import get_reviewer_description
from agentreview.utility.utils import format_metareviews
from .agent import Player
from .config import ArenaConfig
from .environments import TimeStep, load_environment
from .paper_review_player import PaperExtractorPlayer, AreaChair, Reviewer
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class PaperReviewArena(Arena):
    """Arena for the paper review environment.

    Players are built from the config according to their name prefix
    ("Paper Extractor", "AC", "Reviewer"; anything else becomes a generic
    ``Player``). ``step`` adjusts a player's ``role_desc`` depending on the
    environment's current phase before the player acts.
    """

    # Name prefix -> player class. Checked in order; the first matching prefix wins.
    _PLAYER_CLASSES_BY_PREFIX = (
        ("Paper Extractor", PaperExtractorPlayer),
        ("AC", AreaChair),
        ("Reviewer", Reviewer),
    )

    # Columns written by ``save_history`` (CSV header order / JSON message keys).
    _HISTORY_FIELDS = (
        "agent_name",
        "content",
        "turn",
        "timestamp",
        "visible_to",
        "msg_type",
    )

    @classmethod
    def from_config(cls, config: Union[str, ArenaConfig]):
        """Create an arena from a config.

        Args:
            config: An ``ArenaConfig``, or a path to a config file that is
                loaded with ``ArenaConfig.load``.

        Returns:
            A new arena holding the configured players and environment.

        Raises:
            AssertionError: If two players share the same name (only when
                assertions are enabled).

        Note:
            The player configs and ``config.environment`` are mutated in place
            (``global_prompt`` / ``player_names`` are added to them).
        """
        # If config is a path, load the config
        if isinstance(config, str):
            config = ArenaConfig.load(config)

        global_prompt = config.get("global_prompt", None)

        players = []
        for player_config in config.players:
            # Propagate the arena-wide global prompt to every player config.
            if global_prompt is not None:
                player_config["global_prompt"] = global_prompt
            player_cls = cls._player_class_for(player_config["name"])
            players.append(player_cls.from_config(player_config))

        # Check that the player names are unique
        player_names = [player.name for player in players]
        assert len(player_names) == len(set(player_names)), (
            f"Player names must be unique, current players: [{', '.join(player_names)}]"
        )

        # The environment needs to know which players take part.
        config.environment["player_names"] = player_names
        env = load_environment(config.environment)
        return cls(players, env, global_prompt=global_prompt)

    @classmethod
    def _player_class_for(cls, name: str) -> type:
        """Return the player class for ``name`` by prefix, defaulting to ``Player``."""
        for prefix, player_cls in cls._PLAYER_CLASSES_BY_PREFIX:
            if name.startswith(prefix):
                return player_cls
        return Player

    def step(self) -> TimeStep:
        """Take a step in the game: one player takes an action and the environment updates.

        Before the player acts, its ``role_desc`` is adjusted for the current
        phase (see ``_update_role_desc_for_phase``). The player is then asked
        for an action up to ``self.invalid_actions_retry`` times until
        ``self.environment.check_action`` accepts one.

        Returns:
            The ``TimeStep`` returned by ``self.environment.step``.

        Raises:
            TooManyInvalidActions: If no valid action was produced within the
                allowed number of attempts.
        """
        player_name = self.environment.get_next_player()
        player = self.name_to_player[player_name]
        observation = self.environment.get_observation(player_name)

        # The phase cannot change while retrying (the environment only advances
        # on a valid action), so update the role description once, up front.
        self._update_role_desc_for_phase(player)

        timestep = None
        for _ in range(self.invalid_actions_retry):
            action = player(observation)
            if self.environment.check_action(action, player_name):
                timestep = self.environment.step(player_name, action)
                break
            logger.warning(f"{player_name} made an invalid action {action}")

        if timestep is None:
            # The player made invalid actions too many times: terminate the game.
            warning_msg = (
                f"{player_name} has made invalid actions for {self.invalid_actions_retry} times. "
                "Terminating the game."
            )
            logger.warning(warning_msg)
            raise TooManyInvalidActions(warning_msg)
        return timestep

    def _update_role_desc_for_phase(self, player) -> None:
        """Adjust ``player.role_desc`` for the environment's current phase.

        * Phase 3: reviewers get a description regenerated for the
          ``reviewer_ac_discussion`` phase from their experiment setting.
        * Phase 5 (AC makes decisions): the formatted meta-reviews are appended
          once. ``step`` runs repeatedly during this phase, and re-appending
          would duplicate them in the prompt.
        """
        phase_index = self.environment.phase_index
        if phase_index == 3 and player.name.startswith("Reviewer"):
            logger.info("Update reviewers' role_desc for Phase 3 (reviewer_ac_discussion)")
            # Names look like "Reviewer <k>" with k starting at 1, hence the -1 below.
            reviewer_index = int(player.name.split("Reviewer ")[1])
            reviewer_setting = self.environment.experiment_setting["players"]["Reviewer"][
                reviewer_index - 1
            ]
            player.role_desc = get_reviewer_description(
                phase="reviewer_ac_discussion", **reviewer_setting
            )
        elif phase_index == 5:  # Phase 5 AC Makes Decisions
            metareviews = format_metareviews(
                self.environment.metareviews, self.environment.paper_ids
            )
            if metareviews not in player.role_desc:
                player.role_desc += metareviews

    @staticmethod
    def _message_to_row(message) -> dict:
        """Convert a message into a serializable dict keyed by ``_HISTORY_FIELDS``."""
        return {
            "agent_name": message.agent_name,
            "content": message.content,
            "turn": message.turn,
            "timestamp": str(message.timestamp),
            "visible_to": message.visible_to,
            "msg_type": message.msg_type,
        }

    def save_history(self, path: str):
        """
        Save the history of the game to a file.

        The format is chosen from the extension of ``path``:

        * ``.csv``: a header row followed by one row per message.
        * ``.json``: an object with ``experiment_setting`` and ``messages``.

        Raises:
            ValueError: If ``path`` ends with neither ``.csv`` nor ``.json``.
        """
        if not path.endswith((".csv", ".json")):
            raise ValueError("Invalid file format")

        messages = self.environment.get_observation()
        message_rows = [self._message_to_row(message) for message in messages]

        if path.endswith(".csv"):
            # newline="" is required by the csv module; without it extra blank
            # lines appear on platforms that translate "\n".
            with open(path, "w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=self._HISTORY_FIELDS)
                writer.writeheader()
                writer.writerows(message_rows)
        else:
            with open(path, "w", encoding="utf-8") as f:
                json.dump(
                    {
                        "experiment_setting": self.environment.experiment_setting,
                        "messages": message_rows,
                    },
                    f,
                    indent=2,
                )
|