ewanlee's picture
Synced repo using 'sync_with_huggingface' Github Action
a2afd48
raw
history blame
11.4 kB
# This file contains functions for interacting with the ChatGPT model
import openai
from .gpt import gpt
from loguru import logger
from .parser import DISPARSERS, CONPARSERS
from langchain.output_parsers import PydanticOutputParser
from langchain.output_parsers import OutputFixingParser
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from memory.env_history import EnvironmentHistory
import tiktoken
import json
import re
from .utils import run_chain, get_completion, get_chat
from gym.spaces import Discrete
class RandomAct:
    """Baseline policy that ignores all descriptions and samples a random action."""

    def __init__(self, action_space):
        # gym action space to draw samples from
        self.action_space = action_space

    def act(self, state_description, action_description, env_info, game_description=None, goal_description=None):
        """Sample an action uniformly; returns the same 6-tuple shape as the LLM agents."""
        sampled = self.action_space.sample()
        # +1: discrete actions appear to be 1-indexed downstream
        # (NaiveAct uses default_action = 1 for Discrete spaces).
        if isinstance(self.action_space, Discrete):
            sampled += 1
        return sampled, '', '', '', 0, 0
class NaiveAct(gpt):
    """LLM-backed agent that prompts a chat model for actions, with optional
    few-shot examples and distilled trajectory memory."""

    def __init__(self, openai_key, action_space, args, prompts, distiller, temperature=0.0, max_tokens=2048, logger=None):
        """Set up the agent: tokenizer, few-shot examples, output parser, memory.

        Args:
            openai_key: API key forwarded to the `gpt` base class.
            action_space: gym action space (Discrete or Box-like with .shape).
            args: experiment namespace; fields read here: gpt_version,
                prompt_level, prompt_path, action_desc_dict, use_short_mem,
                short_mem_num, api_type.
            prompts: object holding the task prompt templates.
            distiller: summarizer used to turn trajectories into memory entries.
            temperature: sampling temperature for the LLM.
            max_tokens: completion token limit for the LLM.
            logger: optional pre-configured loguru sink; created lazily in
                act() when None.
        """
        self.action_space = action_space
        self.temperature = temperature
        self.action_desc_dict = args.action_desc_dict
        self.args = args
        self.prompts = prompts
        self.max_tokens = max_tokens
        self.prompt_level = args.prompt_level
        # The Azure deployment name "gpt-35-turbo" is unknown to tiktoken;
        # map it to the canonical OpenAI model name before the encoding lookup.
        if args.gpt_version == "gpt-35-turbo":
            model = "gpt-3.5-turbo"
        else:
            model = args.gpt_version
        self.encoding = tiktoken.encoding_for_model(model)
        super().__init__(args, openai_key)
        self.distiller = distiller
        # Must run after self.encoding is set: levels 2/4 count tokens while
        # summarizing the example trajectories.
        self.fewshot_example_initialization(args.prompt_level, args.prompt_path, distiller = self.distiller)
        if isinstance(self.action_space, Discrete):
            # Discrete actions look 1-based in this project (see RandomAct.act's +1).
            self.default_action = 1
        else:
            # One zero per continuous action dimension.
            self.default_action = [0 for ind in range(self.action_space.shape[0])]
        self.parser = self._parser_initialization()
        self.irr_game_description = ''
        self.memory = []
        self.env_history = EnvironmentHistory()
        self.first_call = True
        self.logger = logger
        # Levels 2 and 4 seed memory with the summaries produced by
        # fewshot_example_initialization above (overwrites the [] just set).
        if self.prompt_level in [2, 4]:
            self.memory = self.summarized_fewshot_example
        if args.use_short_mem == 1:
            self.use_short_mem = True
            self.mem_num = self.args.short_mem_num
        else:
            self.use_short_mem = False
            self.mem_num = 0
def num_tokens_from_string(self,string: str) -> int:
"""Returns the number of tokens in a text string."""
num_tokens = len(self.encoding.encode(string))
return num_tokens
def update_mem(self,):
traj = "Firstly, the description and the goal of the task will be provided. Please pay close attention to comprehend the information presented below.\n"
traj += "Task Description: " + self.game_description + '\n'
traj += "Goal Description: " + self.goal_description + '\n'
traj += self.action_description
traj += "Below is the historical data for this round of the game, which includes the state and corresponding action for each step.\n"
traj += str(self.env_history)
# print(traj)
self._update_mem(traj)
def _update_mem(self, traj):
my_reflection = self.distiller.generate(traj, self.memory)
self.memory.append(my_reflection)
self.env_history.reset()
def clear_mem(self):
self.update_mem()
self.pre_memory = []
self.post_memory = []
self.is_first = True
self.env_history.reset()
def _parser_initialization(self):
if isinstance(self.action_space, Discrete):
PARSERS = DISPARSERS
num_action = self.action_space.n
else:
PARSERS = CONPARSERS
num_action = self.action_space.shape[0]
if self.args.api_type == "azure":
autofixing_chat = AzureChatOpenAI(
openai_api_type=openai.api_type,
openai_api_version=openai.api_version,
openai_api_base=openai.api_base,
openai_api_key=openai.api_key,
deployment_name=self.args.gpt_version,
temperature=self.temperature,
max_tokens=self.max_tokens
)
elif self.args.api_type == "openai":
autofixing_chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key,model=self.args.gpt_version)
parser = PydanticOutputParser(pydantic_object=PARSERS[num_action])
autofixing_parser = OutputFixingParser.from_llm(
llm=autofixing_chat, parser=parser)
return autofixing_parser
    def fewshot_example_initialization(self, level, path=None, distiller=None):
        """Load few-shot examples / trajectory summaries according to prompt level.

        Levels 1 and 3 use only task-irrelevant examples; level 5 uses expert
        few-shot examples (plus optional expert knowledge); other levels
        (2 and 4) read recorded trajectories from ``{path}_l{level}.json`` and
        distill them into ``self.summarized_fewshot_example``.
        """
        self.fewshot_example = []
        self.irr_few_shot_examples = []
        self.prompt_level = level
        self.expert_knowledge = None
        if level in [1,3]:
            self.irr_few_shot_examples = self.prompts.TASK_IRRELEVANT_PROMPTS
        elif level == 5:
            if hasattr(self.prompts, "expert_prompt"):
                self.expert_knowledge = self.prompts.expert_prompt
            self.fewshot_example = self.prompts.PERCEPTRON_BASIC_FS_EXAMPLES
        else:
            self.irr_few_shot_examples = self.prompts.TASK_IRRELEVANT_PROMPTS
            json_file = f'{path}_l{level}.json'
            with open(json_file, 'r') as infile:
                data = json.load(infile)
            max_step_num = 0
            # Find how many transitions of a trajectory fit within the query
            # token budget; that count caps the distiller's per-trajectory steps.
            for traj in data:
                traj_text = traj[0]['game_description']
                traj_text += traj[0]['goal_description']
                for i, transition in enumerate(traj):
                    traj_text += transition['observation']
                    traj_text += f"> {transition['action']}"
                    traj_text += f"{transition.get('reward','')}\n"
                    one_traj_token = self.num_tokens_from_string(traj_text)
                    # Stop once the rendered text exceeds the budget: i+1 steps fit.
                    if one_traj_token > self.args.max_query_tokens:
                        max_step_num = i+1
                        break
                # NOTE(review): max_step_num is rewritten on every trajectory,
                # so the last trajectory processed wins — confirm intended.
                traj_text += f"Your performance is: {transition['cum_reward']}"
            # No trajectory exceeded the budget: fall back to the episode cap.
            if not max_step_num:
                max_step_num = self.args.max_episode_len
            self.summarized_fewshot_example = self.distiller.generate_from_file(json_file,max_step_num=max_step_num)
def response(self, state_description, action_description, env_info, game_description=None, goal_description=None, fewshot_examples=None):
if env_info['future_summary']:
prompt = f"{game_description}\n{goal_description}\n{fewshot_examples}\n{state_description}\n{env_info['future_summary']}\n{action_description} "
else:
prompt = f"{game_description}\n{goal_description}\n{fewshot_examples}\nCurrent {state_description}\n{action_description} "
prompt += "Please select an action based on the current game state and the information you get. You must select the appropriate action from the given action descriptions and cannot refrain from taking action or performing any prohibited actions. Your Action is: "
print(f"prompt is {prompt}")
# res = get_chat(prompt, self.args.api_type, self.args.gpt_version, self.temperature, self.max_tokens)
res = get_chat(prompt, api_type=self.args.api_type, model=self.args.gpt_version, engine=self.args.gpt_version, temperature=self.temperature, max_tokens=self.max_tokens)
# openai.ChatCompletion.create(
# engine=self.args.gpt_version,
# # model=self.args.gpt_version,
# prompt=prompt,
# temperature=self.temperature,
# max_tokens=self.max_tokens,
# )
return prompt, res
def _add_history_before_action(self, game_description, goal_description, state_description):
self.game_description = game_description
self.goal_description = goal_description
self.env_history.add("observation", state_description)
# limit the token used, or it may exceed the max token
if len(self.env_history):
one_history_token = self.num_tokens_from_string(self.env_history.get_one_history())
self.env_history.set_history(self.args.max_query_tokens // one_history_token)
    def act(self, state_description, action_description, env_info, game_description=None, goal_description=None, logfile=None):
        """Build the memory-augmented prompt, query the LLM, and parse an action.

        Returns:
            (action, prompt, raw response, 0, 0) — the trailing zeros mirror the
            6-field shape minus one field used by other agents (see RandomAct).
        """
        self._add_history_before_action(game_description, goal_description, state_description)
        # NOTE(review): asking_round and res are never read below — dead locals.
        asking_round = 0
        res = None
        action = None
        prompt = None
        # Lazily attach a loguru sink on first call when none was injected.
        if not self.logger:
            logger.remove()
            self.logger = logger.add(logfile, colorize=True, enqueue=True)
        # Prompt level 5: expert few-shot Q/A examples.
        if self.args.prompt_level == 5:
            my_mem = ""
            if self.fewshot_example:
                my_mem += "Here are some examples of how you should complete a task."
                for examples in self.fewshot_example:
                    my_mem += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
            my_mem += '\nNow you are in the task.\n'
        # Levels 2/3/4: distilled trajectory summaries with a provenance preamble.
        elif self.args.prompt_level in [2,3,4]:
            my_mem = ""
            if self.prompt_level == 2:
                my_mem += 'I have collected a few trajectories from a random policy, and the summaries are listed below.'
            elif self.prompt_level == 3:
                my_mem += 'I have collected a few trajectories before, and the summaries are listed below.'
            elif self.prompt_level == 4:
                my_mem += 'I have collected a few trajectories from an expert policy, and the summaries are listed below.'
            my_mem += self._read_mem()
        else:
            my_mem = ""
        # Optionally append the most recent environment transitions.
        if self.use_short_mem:
            if len(self.env_history) > 1:
                my_mem += '\nSubsequently, I will offer pertinent guidance or information about the task. Please utilize this instruction to accomplish the given task effectively.'
                my_mem += f"\nBelow are the latest {min(self.mem_num, len(self.env_history))} historical data entries:\n"
                my_mem += f"{self.env_history.get_histories(self.mem_num)}"
        prompt, response = self.response(state_description, action_description, env_info, game_description, goal_description, my_mem)
        action_str = response
        # NOTE(review): 'anwser' typo is in a runtime string; left unchanged here.
        print(f'my anwser is {action_str}')
        # OutputFixingParser retries malformed outputs via the LLM before failing.
        action = self.parser.parse(response).action
        self._add_history_after_action(action)
        self.logger.info(f'The GPT response is: {response}.')
        self.logger.info(f'The optimal action is: {action}.')
        if env_info.get('history'):
            # NOTE(review): history_to_str is not imported or defined in this
            # file's visible code — this branch would raise NameError; confirm
            # where it is meant to come from.
            self.logger.info(f'History: {history_to_str(env_info["history"])}')
        return action, prompt, response, 0, 0
def _read_mem(self, ):
memory = self.memory
mem_str = ""
if len(memory) > 5:
memory = memory[-5:]
if len(memory) > 0:
mem_str += '\nYour memory for the task below:'
for i, m in enumerate(memory):
mem_str += f'\nTrial {i}:\n{m.strip()}'
return mem_str
    def _add_history_after_action(self, action):
        """Append the chosen action to the per-episode environment history."""
        self.env_history.add('action', action)