import json
import logging
import os.path as osp
from typing import List

from agentreview.environments import Conversation
from agentreview.utility.utils import get_rebuttal_dir
from .base import TimeStep
from ..message import Message
from ..paper_review_message import PaperReviewMessagePool

logger = logging.getLogger(__name__)


class PaperReview(Conversation):
    """
    Discussion between reviewers, an author and an area chair (AC) about a paper.

    The review process is split into numbered phases (see the `phases` property):

        0. paper_extraction: the "Paper Extractor" speaks; no LLM-based agents are called.
        1. reviewer_write_reviews: each reviewer writes a review.
        2. author_reviewer_discussion: the author speaks once per reviewer.
        3. reviewer_ac_discussion: the AC speaks first, followed by every reviewer.
        4. ac_write_metareviews: the AC speaks once.
        5. ac_makes_decisions: the AC speaks once. `step` marks the episode as
           terminal at the end of phase 4, so this phase is not reached here.
    """

    type_name = "paper_review"

    # Index of the `ac_write_metareviews` phase. It is both the phase at which a
    # cached BASELINE discussion is resumed and the last phase run by `step`.
    _METAREVIEW_PHASE_INDEX = 4

    def __init__(self, player_names: List[str], paper_id: int, paper_decision: str, experiment_setting: dict, args,
                 parallel: bool = False, **kwargs):
        """
        Args:
            player_names (List[str]): names of all the players taking part in the discussion
            paper_id (int): the id of the paper, such as 917
            paper_decision (str): the decision of the paper, such as "Accept: notable-top-25%"
            experiment_setting (dict): experiment configuration; the optional key
                'player_to_test' is read here and the whole dict is handed to the message pool
            args: run arguments; `experiment_name` is read here, and `model_name`,
                `conference` and `num_reviewers_per_paper` are read when loading cached history
            parallel (bool): forwarded to the parent initializer and stored on the instance
            **kwargs: forwarded to the parent initializer; the optional key "task" is stored
        """
        # `super(Conversation, self)` starts the MRO lookup *after* `Conversation`, so
        # `Conversation.__init__` itself is intentionally skipped and the initializer of
        # its parent class runs instead.
        # NOTE(review): presumably this avoids building Conversation's default message
        # pool and relies on `reset()` to set `_current_turn` / `_next_player_index` -- confirm.
        super(Conversation, self).__init__(player_names=player_names, parallel=parallel, **kwargs)

        self.args = args
        self.paper_id = paper_id
        self.paper_decision = paper_decision
        self.parallel = parallel
        self.experiment_setting = experiment_setting
        self.player_to_test = experiment_setting.get('player_to_test', None)
        self.task = kwargs.get("task")
        self.experiment_name = args.experiment_name

        # The "state" of the environment is maintained by the message pool
        self.message_pool = PaperReviewMessagePool(experiment_setting)

        self.phase_index = 0
        self._phases = None

    @property
    def phases(self):
        """
        Mapping from phase index to a dict with the phase 'name' and its 'speaking_order'.

        The table is built on first access and cached in `self._phases`. Reviewers are
        always named "Reviewer 1" ... "Reviewer N", where N is the number of player
        names that start with "Reviewer".
        """
        if self._phases is None:
            num_reviewers = sum(1 for name in self.player_names if name.startswith("Reviewer"))
            reviewer_names = [f"Reviewer {i}" for i in range(1, num_reviewers + 1)]

            self._phases = {
                # In phase 0, no LLM-based agents are called.
                0: {
                    "name": "paper_extraction",
                    'speaking_order': ["Paper Extractor"],
                },
                1: {
                    "name": 'reviewer_write_reviews',
                    'speaking_order': reviewer_names
                },
                # The author responds to each reviewer's review
                2: {
                    'name': 'author_reviewer_discussion',
                    'speaking_order': ["Author"] * num_reviewers,
                },
                3: {
                    'name': 'reviewer_ac_discussion',
                    'speaking_order': ["AC"] + reviewer_names,
                },
                4: {
                    'name': 'ac_write_metareviews',
                    'speaking_order': ["AC"]
                },
                5: {
                    'name': 'ac_makes_decisions',
                    'speaking_order': ["AC"]
                },
            }

        return self._phases

    @phases.setter
    def phases(self, value):
        """Replace the cached phase table with `value`."""
        self._phases = value

    def reset(self):
        """Rewind to phase 0 and delegate the rest of the reset to the parent class."""
        self._current_phase = "review"
        self.phase_index = 0
        return super().reset()

    def load_message_history_from_cache(self):
        """
        Pre-populate the message pool from the stored BASELINE discussion of this paper.

        Nothing happens unless the environment is still in phase 0. Otherwise the
        BASELINE messages are replayed into the message pool, skipping messages from
        the "Paper Extractor" and stopping right before the second message from the
        "AC". The environment then jumps to the `ac_write_metareviews` phase.

        Raises:
            AssertionError: if the number of distinct reviewers in the cached
                discussion differs from `args.num_reviewers_per_paper`.
        """
        if self.phase_index != 0:
            return

        print("Loading message history from BASELINE experiment")

        full_paper_discussion_path = get_rebuttal_dir(paper_id=self.paper_id,
                                                      experiment_name="BASELINE",
                                                      model_name=self.args.model_name,
                                                      conference=self.args.conference)

        # Use a context manager so that the file handle is always closed.
        discussion_path = osp.join(full_paper_discussion_path, f"{self.paper_id}.json")
        with open(discussion_path, 'r', encoding='utf-8') as f:
            messages = json.load(f)['messages']

        num_messages_from_AC = 0

        for msg in messages:
            agent_name = msg['agent_name']

            # We have already extracted contents from the paper.
            if agent_name == "Paper Extractor":
                continue

            if agent_name == "AC":
                # Encountering the 2nd message from the AC. Stop loading messages.
                if num_messages_from_AC == 1:
                    break
                num_messages_from_AC += 1

            self.message_pool.append_message(Message(**msg))

        num_unique_reviewers = len(
            {msg['agent_name'] for msg in messages if msg['agent_name'].startswith("Reviewer")})

        assert num_unique_reviewers == self.args.num_reviewers_per_paper, (
            f"Expected {self.args.num_reviewers_per_paper} reviewers in the cached discussion, "
            f"found {num_unique_reviewers}")

        self.phase_index = self._METAREVIEW_PHASE_INDEX

    def step(self, player_name: str, action: str) -> TimeStep:
        """
        Step function that is called by the arena.

        The action is recorded as a message, then the speaker pointer advances. When
        the last speaker of a phase has spoken, the pointer wraps to the first speaker,
        the turn counter is incremented and the environment either moves on to the next
        phase or, at the end of the `ac_write_metareviews` phase, is marked as terminal.

        Args:
            player_name: the name of the player that takes the action
            action: the action that the agents wants to take

        Returns:
            A TimeStep holding all messages as the observation, zero rewards and the
            terminal flag.
        """
        message = Message(
            agent_name=player_name,
            content=action,
            turn=self._current_turn
        )
        self.message_pool.append_message(message)

        speaking_order = self.phases[self.phase_index]["speaking_order"]

        logger.info("Phase %s: %s | Player %s: %s",
                    self.phase_index, self.phases[self.phase_index]['name'],
                    self._next_player_index, speaking_order[self._next_player_index])

        terminal = self.is_terminal()

        if self._next_player_index == len(speaking_order) - 1:
            # Reached the end of the speaking order. Move to the next phase.
            self._next_player_index = 0

            if self.phase_index == self._METAREVIEW_PHASE_INDEX:
                terminal = True
                logger.info(
                    "Finishing the simulation for Phase I - IV. Please run `python run_paper_decision_cli.py ` for "
                    "Phase V. (AC makes decisions).")
            else:
                print(f"Phase {self.phase_index}: end of the speaking order. Move to Phase ({self.phase_index + 1}).")
                self.phase_index += 1

            self._current_turn += 1
        else:
            self._next_player_index += 1

        # Return all the messages
        return TimeStep(
            observation=self.get_observation(),
            reward=self.get_zero_rewards(),
            terminal=terminal,
        )

    def get_next_player(self) -> str:
        """Get the next player in the current phase."""
        speaking_order = self.phases[self.phase_index]["speaking_order"]
        return speaking_order[self._next_player_index]

    def get_observation(self, player_name=None) -> List[Message]:
        """
        Get observation for the player.

        Without a `player_name`, every message in the pool is returned; otherwise only
        the messages the pool reports as visible to that player in the current phase.
        """
        if player_name is None:
            return self.message_pool.get_all_messages()

        return self.message_pool.get_visible_messages_for_paper_review(
            player_name,
            phase_index=self.phase_index,
            next_player_idx=self._next_player_index,
            player_names=self.player_names
        )

    def get_messages_from_player(self, player_name: str) -> List[str]:
        """Get the messages that the message pool holds for `player_name`."""
        return self.message_pool.get_messages_from_player(player_name)