# Text-Gym-Agents / distillers / traj_prompt_summarizer.py
# hzxwonder
# update
# c640769
import ast
import json
import random

from loguru import logger

from deciders.utils import get_completion
class TrajPromptSummarizer():
    """Turn recorded trajectories into textual summaries via ``get_completion``.

    Each trajectory is rendered as plain text, wrapped in a few-shot prompt and
    sent to the completion backend. Summaries of earlier trajectories are fed
    back into later prompts as "plans from past attempts".
    """

    def __init__(self, args=None, logfile=None):
        """Load the few-shot examples and optionally attach a log sink.

        Args:
            args: Object whose ``api_type`` and ``gpt_version`` attributes are
                read by :meth:`generate`. With the default ``None``,
                :meth:`generate` fails with ``AttributeError``.
            logfile: If truthy, passed to ``logger.add`` as an extra sink that
                only receives messages containing ``'[Reflexion Memory]'``.

        Raises:
            OSError: If the few-shot examples file cannot be opened.
        """
        self.args = args
        # NOTE: the path is relative to the current working directory, so the
        # process has to be started from the project root.
        with open("./distillers/traj_summary_few_shot_examples.txt", 'r', encoding="utf-8") as f:
            self.FEW_SHOT_EXAMPLES = f.read()
        if logfile:
            # Existing sinks are left in place (no logger.remove()).
            logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' in x['message'])

    @staticmethod
    def _format_action(action):
        """Return the value to print on the ``Action:`` line of a transition.

        If ``str(action)`` parses as a non-empty Python list literal, its first
        element is converted to ``float`` and decremented by one (presumably a
        1-based to 0-based index shift -- confirm against the trajectory
        writer). Anything else is returned unchanged.
        """
        # SECURITY: the action comes from a trajectory file. The previous code
        # ran eval() on it, which executes arbitrary expressions and crashes on
        # plain strings such as "left". literal_eval only accepts literals.
        try:
            parsed = ast.literal_eval(str(action))
        except (ValueError, TypeError, SyntaxError):
            return action
        if isinstance(parsed, list) and parsed:
            return float(parsed[0]) - 1
        return action

    def _build_trajectory_text(self, traj, max_step_num):
        """Render the last ``max_step_num`` transitions of ``traj`` as text."""
        steps = traj[-max_step_num:]
        # The descriptions are always taken from the very first transition,
        # even when that transition falls outside the kept window.
        parts = [traj[0]['game_description'], traj[0]['goal_description']]
        for transition in steps:
            parts.append(transition['observation'])
            parts.append(f"Action: {self._format_action(transition['action'])}")
            parts.append(f"Reward: {transition['reward']}")
        # Reported once per trajectory, from the last included step.
        parts.append(f"Your performance is: {steps[-1]['cum_reward']}")
        return '\n'.join(parts) + '\n'

    def generate_from_file(self, file_path, max_step_num=200):
        """Summarize every trajectory stored in a JSON file.

        The file must contain a list of trajectories; each trajectory is a list
        of transition dicts with the keys ``observation``, ``action``,
        ``reward`` and ``cum_reward``. The first transition must also carry
        ``game_description`` and ``goal_description``.

        Args:
            file_path: Path of the JSON file to read.
            max_step_num: Number of trailing transitions kept per trajectory.
                Because the slice is ``traj[-max_step_num:]``, a value of ``0``
                keeps the whole trajectory.

        Returns:
            list: One summary per trajectory, in file order.

        Raises:
            IndexError: If a trajectory is empty.
            KeyError: If a transition lacks one of the required keys.
        """
        with open(file_path, 'r', encoding="utf-8") as infile:
            data = json.load(infile)
        mem = []
        for traj in data:
            traj_text = self._build_trajectory_text(traj, max_step_num)
            # Earlier summaries are passed along so later prompts can use them.
            reflection = self.generate(traj_text, mem, max_len_mem=5)
            mem.append(reflection)
        return mem

    def _generate_summary_query(self, traj, memory):
        """Build the summary prompt for one trajectory.

        Args:
            traj: Trajectory text to summarize.
            memory: Earlier summaries; when non-empty they are listed as
                ``Trial #i`` lines, numbered from 0 within the given sequence.

        Returns:
            str: The complete prompt, ending in ``'Summary:'``.
        """
        query = (
            "You will be given the history of a past experience in which you were placed in an environment and given a task to complete. Summarize your trajectory and reasoning the relation between your policy and the obtained result. Here are two examples:\n"
            f"{self.FEW_SHOT_EXAMPLES}\n"
            f"{traj}"
        )
        if memory:
            query += '\n\nPlans from past attempts:\n'
            query += ''.join(f'Trial #{i}: {m}\n' for i, m in enumerate(memory))
        query += '\n\nSummary:'
        return query

    def generate(self, traj, memory, max_len_mem=5):
        """Request a summary of ``traj`` from the completion backend.

        Args:
            traj: Trajectory text to summarize.
            memory: Earlier summaries, oldest first.
            max_len_mem: Maximum number of most recent summaries included in
                the prompt.

        Returns:
            The value returned by ``get_completion`` for the built prompt.
        """
        # memory[-0:] would be the whole list, so a zero (or negative-free)
        # limit has to be handled explicitly to mean "include no memory".
        recent = memory[-max_len_mem:] if max_len_mem != 0 else []
        reflection_query = self._generate_summary_query(traj, recent)
        reflection = get_completion(reflection_query, api_type=self.args.api_type, engine=self.args.gpt_version)
        logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
        logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
        return reflection