hzxwonder
update
2d75a44
raw
history blame
8.56 kB
import openai
from .misc import history_to_str
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.prompts.chat import (
PromptTemplate,
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain import LLMChain
from langchain.callbacks import FileCallbackHandler
from langchain.callbacks import get_openai_callback
from .act import NaiveAct
from memory.env_history import EnvironmentHistory
import tiktoken
from .utils import run_chain
from loguru import logger
class EXE(NaiveAct):
    """Decider that carries distilled memory across trials of a game.

    Between trials the distiller turns the finished trajectory into a summary
    (appended to ``post_memory``), derives an ``insight`` from all summaries,
    and produces a ``suggestion`` for the next trial (appended to
    ``pre_memory``). The current insight and the latest suggestion are injected
    into the system prompt on every ``act`` call.
    """

    def __init__(self, action_space, args, prompts, distiller, temperature=0., max_tokens=None, logger=None, fixed_suggestion=None, fixed_insight=None):
        super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
        # Suggestions generated before each trial; _read_mem uses the latest one.
        self.pre_memory = []
        # Per-trial summaries (or the preloaded self.memory when that is set).
        self.post_memory = []
        # True until the first _update_mem call has run without self.memory.
        self.is_first = True
        self.num_trails = args.num_trails
        self.game_description = args.game_description
        self.goal_description = args.goal_description
        self.action_description = args.action_description
        self.action_desc_dict = args.action_desc_dict
        # Number of recent history entries shown to the model in act().
        self.mem_num = args.short_mem_num
        # Optional overrides that replace the distiller's suggestion / insight.
        self.fixed_suggestion = fixed_suggestion
        self.fixed_insight = fixed_insight
        # Initialise before _update_mem so the insight it computes (or the
        # fixed_insight override) is not clobbered afterwards.
        self.insight = ""
        self._update_mem(None)

    def num_tokens_from_string(self, string: str) -> int:
        """Returns the number of tokens in a text string."""
        # self.encoding is expected to be set up by NaiveAct (not visible here).
        num_tokens = len(self.encoding.encode(string))
        return num_tokens

    def update_mem(self):
        """Distil the trial just played into memory for the next trial.

        The trajectory text passed to the distiller is the game, goal and
        action descriptions followed by the stringified environment history.
        """
        traj = (
            self.game_description
            + self.goal_description
            + self.action_description
            + str(self.env_history)
        )
        self._update_mem(traj)

    def clear_mem(self):
        """Run a final memory update, then reset memory and history.

        The suggestion/summary produced by ``update_mem`` are discarded by the
        resets below; ``self.insight`` is left as update_mem set it.
        NOTE(review): confirm keeping the last insight after a reset is intended.
        """
        self.update_mem()
        self.pre_memory = []
        self.post_memory = []
        self.is_first = True
        self.env_history.reset()

    def _update_mem(self, traj):
        """Refresh insight and append a new suggestion, then reset history.

        Args:
            traj: trajectory text of the finished trial, or ``None`` on the
                initial call from ``__init__`` (only used when a summary is
                generated, i.e. not on the first call and without self.memory).
        """
        if self.memory:
            # Preloaded memory (presumably provided by NaiveAct) replaces summaries.
            self.post_memory = self.memory
            self.insight = self.distiller.generate_insight(self.post_memory)
        elif not self.is_first:
            summary = self.distiller.generate_summary(traj, self.post_memory)
            self.post_memory.append(summary)
            self.insight = self.distiller.generate_insight(self.post_memory)
        else:
            # First call: nothing has been played yet, so there is no insight.
            self.is_first = False
            self.insight = ""
        suggestion = self.distiller.generate_suggestion(self.game_description, self.goal_description, self.action_description, self.pre_memory, self.post_memory, self.insight, self.num_trails)
        if self.fixed_suggestion:
            suggestion = self.fixed_suggestion
        if self.fixed_insight:
            self.insight = self.fixed_insight
        self.pre_memory.append(suggestion)
        self.env_history.reset()

    def _read_mem(self):
        """Render the current insight and latest suggestion as prompt text.

        Returns an empty suggestion section when ``pre_memory`` is empty
        (e.g. right after ``clear_mem``) instead of raising IndexError.
        """
        insight_str = ""
        if self.insight:
            insight_str += "The insights of the game are listed below: "
            insight_str += f"{self.insight}\n"
        suggestion_str = ""
        if self.pre_memory:
            suggestion_str = "The suggestions are listed below:" + self.pre_memory[-1]
        return insight_str + suggestion_str

    def act(
        self,
        state_description,
        action_description,
        env_info,
        game_description,
        goal_description,
        logfile=None,
    ):
        """Ask the LLM for the next action given the current observation.

        Args:
            state_description: the new observation; recorded in env_history.
            action_description: description of the valid actions.
            env_info: dict; its optional 'history' entry is logged.
            game_description: game description inserted into the prompt.
            goal_description: goal description inserted into the prompt.
            logfile: optional path for the loguru sink and LangChain callback.

        Returns:
            Tuple ``(action, text_prompt, response, total_tokens, total_cost)``
            where ``text_prompt`` is the system and human messages, formatted
            with the same inputs that were sent to the model.

        Raises:
            ValueError: if ``self.args.api_type`` is neither "azure" nor "openai".
        """
        self.game_description = game_description
        self.goal_description = goal_description
        self.env_history.add("observation", state_description)

        # --- chat model -------------------------------------------------
        if self.args.api_type == "azure":
            chat = AzureChatOpenAI(
                openai_api_type=openai.api_type,
                openai_api_version=openai.api_version,
                openai_api_base=openai.api_base,
                openai_api_key=openai.api_key,
                deployment_name=self.args.gpt_version,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )
        elif self.args.api_type == "openai":
            chat = ChatOpenAI(
                temperature=self.temperature,
                openai_api_key=openai.api_key,
                model=self.args.gpt_version,
                max_tokens=self.max_tokens,
            )
        else:
            raise ValueError(f"Unsupported api_type: {self.args.api_type!r}")

        # --- prompt -----------------------------------------------------
        reply_format_description = \
            "Your response should choose an optimal action from valid action list, and terminated with following format: "
        template = "Now you are completing a task."
        template += "You need to carefully understand the description of the game. "
        if self.irr_few_shot_examples:
            template += "Here are some examples of how you should completing a task."
            for examples in self.irr_few_shot_examples:
                template += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
        # NOTE(review): placement of this separator relative to the few-shot
        # branch was reconstructed; confirm it should be unconditional.
        template += "\n\nNow you are in the task.\n"
        template += " {game_description}\n{action_description}\n{goal_description}"
        template += "You are observing something and " \
            "you need to choose the optimal action acoordingly."
        template += 'Response and interact using the format: {reply_format_description}{format_instructions}\n'
        # Memory text is LLM-generated; escape braces so the f-string prompt
        # template does not treat them as input variables.
        template += self._read_mem().replace("{", "{{").replace("}", "}}")
        system_message_prompt = SystemMessagePromptTemplate.from_template(template)
        short_memory_template = HumanMessagePromptTemplate.from_template("{history}\nNext is the observation that the agent gets:\n{state_description}Please select an optimal action to gain higher rewards based on the current state and history. The action description is below: {action_description}. Please think step by step.")
        chat_prompt = ChatPromptTemplate.from_messages(
            [system_message_prompt, short_memory_template])

        # --- logging ----------------------------------------------------
        if not self.logger and logfile and self.first_call:
            # loguru's logger.add returns an integer sink id, not a logger:
            # keep the id separately and log through the loguru logger.
            self._log_sink_id = logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
            self.logger = logger
            self.first_call = False
        log = self.logger or logger
        callbacks = [FileCallbackHandler(logfile)] if logfile else []

        # --- query ------------------------------------------------------
        # Built once so the returned text_prompt matches what was sent, even
        # after _add_history_after_action updates env_history.
        prompt_inputs = dict(
            game_description=game_description,
            goal_description=goal_description,
            action_description=action_description,
            state_description=self.env_history.get_last_history(),
            history=self.env_history.get_histories(self.mem_num),
            format_instructions=self.parser.get_format_instructions(),
            reply_format_description=reply_format_description,
        )
        total_tokens, total_cost = 0, 0
        max_think_times = 1
        for _ in range(max_think_times):
            chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=callbacks, verbose=False)
            with get_openai_callback() as cb:
                response = run_chain(chain, max_token=self.max_tokens, **prompt_inputs)
                total_tokens += cb.total_tokens
                total_cost += cb.total_cost
            action = self.parser.parse(response).action
        self._add_history_after_action(action)

        log.info(f'The GPT response is: {response}.')
        log.info(f'The optimal action is: {action}.')
        if self.pre_memory:
            log.info(f'The suggestion is: {self.pre_memory[-1]}.')
        if self.post_memory:
            log.info(f'The summary is: {self.post_memory[-1]}.')
        if env_info.get('history'):
            log.info(f'History: {history_to_str(env_info["history"])}')

        messages = chat_prompt.format_messages(**prompt_inputs)
        text_prompt = f'{messages[0].content}\n{messages[1].content}'
        return action, text_prompt, response, total_tokens, total_cost