import ast
import json
import random

from loguru import logger

from deciders.utils import get_completion


class TrajPromptSummarizer:
    """Turn recorded trajectories into textual summaries via ``get_completion``.

    Each trajectory is rendered as plain text (descriptions, then one
    observation/action/reward triple per step), wrapped in a prompt that
    contains few-shot examples and the summaries of earlier trajectories,
    and passed to ``get_completion``.
    """

    def __init__(self, args=None, logfile=None):
        """Load the few-shot examples and optionally attach a log file.

        Args:
            args: Object exposing ``api_type`` and ``gpt_version``; both are
                read by :meth:`generate`, so ``None`` only works as long as
                :meth:`generate` is never called.
            logfile: Optional loguru sink. Only records whose message
                contains ``'[Reflexion Memory]'`` are written to it.

        Raises:
            OSError: If the few-shot example file cannot be read. The path
                is relative to the current working directory.
        """
        self.args = args
        with open(
            "./distillers/traj_summary_few_shot_examples.txt",
            "r",
            encoding="utf-8",
        ) as f:
            self.FEW_SHOT_EXAMPLES = f.read()
        if logfile:
            # NOTE: every instance created with a logfile registers one more
            # sink; existing sinks are deliberately left in place.
            logger.add(
                logfile,
                colorize=True,
                enqueue=True,
                filter=lambda x: '[Reflexion Memory]' in x['message'],
            )

    @staticmethod
    def _format_action(action):
        """Return the action value that is written into the trajectory text.

        A non-empty list (either a real list or its string literal, e.g.
        ``"[2]"``) is reduced to ``float(first_element) - 1``. Anything else
        is returned unchanged.

        ``ast.literal_eval`` is used instead of ``eval`` so that content of
        the trajectory file can never execute code, and so that plain-string
        actions such as ``"left"`` no longer raise ``NameError``.
        """
        parsed = action
        if isinstance(action, str):
            try:
                parsed = ast.literal_eval(action)
            except (ValueError, SyntaxError):
                # Not a Python literal: treat it as an opaque action label.
                return action
        if isinstance(parsed, list) and parsed:
            return float(parsed[0]) - 1
        return action

    def generate_from_file(self, file_path, max_step_num=200):
        """Summarize every trajectory stored in a JSON file.

        The file must contain a list of trajectories, each a list of
        transition dicts. The first transition supplies
        ``game_description`` and ``goal_description``; every transition
        supplies ``observation``, ``action`` and ``reward``; the last one
        supplies ``cum_reward``.

        Args:
            file_path: Path of the JSON trajectory file.
            max_step_num: Only the last ``max_step_num`` transitions of each
                trajectory are rendered.

        Returns:
            list: One summary per non-empty trajectory, in file order. Each
            summary is generated with up to five preceding summaries as
            context.
        """
        mem = []
        with open(file_path, 'r', encoding="utf-8") as infile:
            data = json.load(infile)

        for traj in data:
            if not traj:
                # Nothing to describe; indexing traj[0] would fail below.
                logger.warning('Skipping empty trajectory.')
                continue

            first_step = traj[0]
            lines = [
                first_step['game_description'],
                first_step['goal_description'],
            ]
            for transition in traj[-max_step_num:]:
                lines.append(transition['observation'])
                lines.append(
                    f"Action: {self._format_action(transition['action'])}"
                )
                lines.append(f"Reward: {transition['reward']}")
            # The final transition carries the cumulative reward. Reading it
            # from traj[-1] avoids depending on the loop variable, which is
            # unbound (or stale) when the slice above is empty.
            lines.append(f"Your performance is: {traj[-1]['cum_reward']}")
            traj_text = '\n'.join(lines) + '\n'

            reflection = self.generate(traj_text, mem, max_len_mem=5)
            mem.append(reflection)
        return mem

    def _generate_summary_query(self, traj, memory):
        """Build the summarization prompt for one trajectory.

        Args:
            traj: Trajectory rendered as text.
            memory: Earlier summaries to list under
                "Plans from past attempts"; the section is omitted when
                this is empty.

        Returns:
            str: The complete prompt, ending in ``Summary:``.
        """
        # NOTE(review): the spacing between the few-shot examples and the
        # trajectory is kept exactly as found in the file; confirm whether
        # blank lines were intended there.
        query = (
            "You will be given the history of a past experience in which you were placed in an environment and given a task to complete. Summarize your trajectory and reasoning the relation between your policy and the obtained result. \n"
            f"Here are two examples: {self.FEW_SHOT_EXAMPLES} {traj}"
        )
        if memory:
            query += '\n\nPlans from past attempts:\n'
            query += ''.join(
                f'Trial #{i}: {m}\n' for i, m in enumerate(memory)
            )
        query += '\n\nSummary:'
        return query

    def generate(self, traj, memory, max_len_mem=5):
        """Request a summary of ``traj`` and log the prompt and the response.

        Args:
            traj: Trajectory rendered as text.
            memory: Earlier summaries, oldest first.
            max_len_mem: Maximum number of most recent summaries included in
                the prompt. A non-positive value includes none (a plain
                ``memory[-0:]`` slice would include all of them).

        Returns:
            The value returned by ``get_completion`` for the prompt.
        """
        recent_memory = memory[-max_len_mem:] if max_len_mem > 0 else []
        reflection_query = self._generate_summary_query(traj, recent_memory)
        reflection = get_completion(
            reflection_query,
            api_type=self.args.api_type,
            engine=self.args.gpt_version,
        )
        logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
        logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
        return reflection