# SocialAISchool/scripts/LLM_test.py
import argparse
import json
import requests
import time
import warnings
import random  # used by the "random" baseline in generate()
import re  # used for matching completion-model names in generate()
from n_tokens import estimate_price
import pickle
import numpy as np
import torch
from pathlib import Path
# from utils.babyai_utils.baby_agent import load_agent
from utils import *
from textworld_utils.utils import generate_text_obs
from models import *
import subprocess
import os
from matplotlib import pyplot as plt
from gym_minigrid.wrappers import *
from gym_minigrid.window import Window
from datetime import datetime
from imageio import mimsave
def new_episode_marker():
return "New episode.\n"
def success_marker():
return "Success!\n"
def failure_marker():
return "Failure!\n"
def action_query():
return "Act :"
def get_parsed_action(text_action):
"""
Parses the text generated by a model and extracts the action
"""
if "move forward" in text_action:
return "move forward"
elif "done" in text_action:
return "done"
elif "turn left" in text_action:
return "turn left"
elif "turn right" in text_action:
return "turn right"
elif "toggle" in text_action:
return "toggle"
elif "no_op" in text_action:
return "no_op"
else:
warnings.warn(f"Undefined action {text_action}")
return "no_op"
def action_to_prompt_action_text(action):
if np.allclose(action, [int(env.actions.forward), np.nan, np.nan], equal_nan=True):
# 2
text_action = "move forward"
elif np.allclose(action, [int(env.actions.left), np.nan, np.nan], equal_nan=True):
# 0
text_action = "turn left"
elif np.allclose(action, [int(env.actions.right), np.nan, np.nan], equal_nan=True):
# 1
text_action = "turn right"
elif np.allclose(action, [int(env.actions.toggle), np.nan, np.nan], equal_nan=True):
# 3
text_action = "toggle"
elif np.allclose(action, [int(env.actions.done), np.nan, np.nan], equal_nan=True):
# 4
text_action = "done"
elif np.allclose(action, [np.nan, np.nan, np.nan], equal_nan=True):
text_action = "no_op"
else:
warnings.warn(f"Undefined action {action}")
return "no_op"
return f"{action_query()} {text_action}\n"
def text_action_to_action(text_action):
    # text_action = get_parsed_action(text_action)
    if "move forward" == text_action:
        action = [int(env.actions.forward), np.nan, np.nan]
    elif "turn left" == text_action:
        action = [int(env.actions.left), np.nan, np.nan]
    elif "turn right" == text_action:
        action = [int(env.actions.right), np.nan, np.nan]
    elif "toggle" == text_action:
        action = [int(env.actions.toggle), np.nan, np.nan]
    elif "done" == text_action:
        action = [int(env.actions.done), np.nan, np.nan]
    elif "no_op" == text_action:
        action = [np.nan, np.nan, np.nan]
    else:
        # fall back to no_op for unrecognized text, mirroring get_parsed_action
        warnings.warn(f"Undefined action {text_action}")
        action = [np.nan, np.nan, np.nan]
    return action
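# Sketch of the round trip the main loop relies on (and asserts): for every
# supported text action t,
#   action_to_prompt_action_text(text_action_to_action(t)) == f"{action_query()} {t}\n"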
def prompt_preprocessor(llm_prompt):
# remove peer observations
lines = llm_prompt.split("\n")
new_lines = []
for line in lines:
if line.startswith("#"):
continue
elif line.startswith("Conversation"):
continue
elif "peer" in line:
caretaker = True
if caretaker:
# show only the location of the caretaker
# this is very ugly, todo: refactor this
assert "there is a" in line
start_index = line.index('there is a') + 11
new_line = line[:start_index] + 'caretaker'
new_lines.append(new_line)
else:
# no caretaker at all
if line.startswith("Obs :") and "peer" in line:
# remove only the peer descriptions
line = "Obs :"
new_lines.append(line)
else:
assert "peer" in line
elif "Caretaker:" in line:
line = line.replace("Caretaker:", "Caretaker says: ")
new_lines.append(line)
else:
new_lines.append(line)
return "\n".join(new_lines)
# def generate_text_obs(obs, info):
#
# text_observation = obs_to_text(info)
#
# llm_prompt = "Obs : "
# llm_prompt += "".join(text_observation)
#
# # add utterances
# if obs["utterance_history"] != "Conversation: \n":
# utt_hist = obs['utterance_history']
# utt_hist = utt_hist.replace("Conversation: \n","")
# llm_prompt += utt_hist
#
# return llm_prompt
# def obs_to_text(info):
# image, vis_mask = info["image"], info["vis_mask"]
# carrying = info["carrying"]
# agent_pos_vx, agent_pos_vy = info["agent_pos_vx"], info["agent_pos_vy"]
# npc_actions_dict = info["npc_actions_dict"]
#
# # (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
# # State, 0: open, 1: closed, 2: locked
# IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
# IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
#
# list_textual_descriptions = []
#
# if carrying is not None:
# list_textual_descriptions.append("You carry a {} {}".format(carrying.color, carrying.type))
#
# # agent_pos_vx, agent_pos_vy = self.get_view_coords(self.agent_pos[0], self.agent_pos[1])
#
# view_field_dictionary = dict()
#
# for i in range(image.shape[0]):
# for j in range(image.shape[1]):
# if image[i][j][0] != 0 and image[i][j][0] != 1 and image[i][j][0] != 2:
# if i not in view_field_dictionary.keys():
# view_field_dictionary[i] = dict()
# view_field_dictionary[i][j] = image[i][j]
# else:
# view_field_dictionary[i][j] = image[i][j]
#
# # Find the wall if any
# # We describe a wall only if there is no objects between the agent and the wall in straight line
#
# # Find wall in front
# add_wall_descr = False
# if add_wall_descr:
# j = agent_pos_vy - 1
# object_seen = False
# while j >= 0 and not object_seen:
# if image[agent_pos_vx][j][0] != 0 and image[agent_pos_vx][j][0] != 1:
# if image[agent_pos_vx][j][0] == 2:
# list_textual_descriptions.append(
# f"A wall is {agent_pos_vy - j} steps in front of you. \n") # forward
# object_seen = True
# else:
# object_seen = True
# j -= 1
# # Find wall left
# i = agent_pos_vx - 1
# object_seen = False
# while i >= 0 and not object_seen:
# if image[i][agent_pos_vy][0] != 0 and image[i][agent_pos_vy][0] != 1:
# if image[i][agent_pos_vy][0] == 2:
# list_textual_descriptions.append(
# f"A wall is {agent_pos_vx - i} steps to the left. \n") # left
# object_seen = True
# else:
# object_seen = True
# i -= 1
# # Find wall right
# i = agent_pos_vx + 1
# object_seen = False
# while i < image.shape[0] and not object_seen:
# if image[i][agent_pos_vy][0] != 0 and image[i][agent_pos_vy][0] != 1:
# if image[i][agent_pos_vy][0] == 2:
# list_textual_descriptions.append(
# f"A wall is {i - agent_pos_vx} steps to the right. \n") # right
# object_seen = True
# else:
# object_seen = True
# i += 1
#
# # list_textual_descriptions.append("You see the following objects: ")
# # returns the position of seen objects relative to you
# for i in view_field_dictionary.keys():
# for j in view_field_dictionary[i].keys():
# if i != agent_pos_vx or j != agent_pos_vy:
# object = view_field_dictionary[i][j]
#
# # # don't show npc
# # if IDX_TO_OBJECT[object[0]] == "npc":
# # continue
#
# front_dist = agent_pos_vy - j
# left_right_dist = i - agent_pos_vx
#
# loc_descr = ""
# if front_dist == 1 and left_right_dist == 0:
# loc_descr += "Right in front of you "
#
# elif left_right_dist == 1 and front_dist == 0:
# loc_descr += "Just to the right of you"
#
# elif left_right_dist == -1 and front_dist == 0:
# loc_descr += "Just to the left of you"
#
# else:
# front_str = str(front_dist) + " steps in front of you " if front_dist > 0 else ""
#
# loc_descr += front_str
#
# suff = "s" if abs(left_right_dist) > 0 else ""
# and_ = "and" if loc_descr != "" else ""
#
# if left_right_dist < 0:
# left_right_str = f"{and_} {-left_right_dist} step{suff} to the left"
# loc_descr += left_right_str
#
# elif left_right_dist > 0:
# left_right_str = f"{and_} {left_right_dist} step{suff} to the right"
# loc_descr += left_right_str
#
# else:
# left_right_str = ""
# loc_descr += left_right_str
#
# loc_descr += f" there is a "
#
# obj_type = IDX_TO_OBJECT[object[0]]
# if obj_type == "npc":
# IDX_TO_STATE = {0: 'friendly', 1: 'antagonistic'}
#
# description = f"{IDX_TO_STATE[object[2]]} {IDX_TO_COLOR[object[1]]} peer. "
#
# # gaze
# gaze_dir = {
# 0: "towards you",
# 1: "to the left of you",
# 2: "in the same direction as you",
# 3: "to the right of you",
# }
# description += f"It is looking {gaze_dir[object[3]]}. "
#
# # point
# point_dir = {
# 0: "towards you",
# 1: "to the left of you",
# 2: "in the same direction as you",
# 3: "to the right of you",
# }
#
# if object[4] != 255:
# description += f"It is pointing {point_dir[object[4]]}. "
#
# # last action
# last_action = {v: k for k, v in npc_actions_dict.items()}[object[5]]
#
# last_action = {
# "go_forward": "foward",
# "rotate_left": "turn left",
# "rotate_right": "turn right",
# "toggle_action": "toggle",
# "point_stop_point": "stop pointing",
# "point_E": "",
# "point_S": "",
# "point_W": "",
# "point_N": "",
# "stop_point": "stop pointing",
# "no_op": ""
# }[last_action]
#
# if last_action not in ["no_op", ""]:
# description += f"It's last action is {last_action}. "
#
# elif obj_type in ["switch", "apple", "generatorplatform", "marble", "marbletee", "fence"]:
# # todo: this assumes that Switch.no_light == True
# description = f"{IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} "
# assert object[2:].mean() == 0
#
# elif obj_type == "lockablebox":
# IDX_TO_STATE = {0: 'open', 1: 'closed', 2: 'locked'}
# description = f"{IDX_TO_STATE[object[2]]} {IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} "
# assert object[3:].mean() == 0
#
# elif obj_type == "applegenerator":
# IDX_TO_STATE = {1: 'square', 2: 'round'}
# description = f"{IDX_TO_STATE[object[2]]} {IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} "
# assert object[3:].mean() == 0
#
# elif obj_type == "remotedoor":
# IDX_TO_STATE = {0: 'open', 1: 'closed'}
# description = f"{IDX_TO_STATE[object[2]]} {IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} "
# assert object[3:].mean() == 0
#
# elif obj_type == "door":
# IDX_TO_STATE = {0: 'open', 1: 'closed', 2: 'locked'}
# description = f"{IDX_TO_STATE[object[2]]} {IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} "
# assert object[3:].mean() == 0
#
# elif obj_type == "lever":
# IDX_TO_STATE = {1: 'activated', 0: 'unactivated'}
# if object[3] == 255:
# countdown_txt = ""
# else:
# countdown_txt = f"with {object[3]} timesteps left. "
#
# description = f"{IDX_TO_STATE[object[2]]} {IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} {countdown_txt}"
#
# assert object[4:].mean() == 0
# else:
# raise ValueError(f"Undefined object type {obj_type}")
#
# full_destr = loc_descr + description + "\n"
#
# list_textual_descriptions.append(full_destr)
#
# if len(list_textual_descriptions) == 0:
# list_textual_descriptions.append("\n")
#
# return list_textual_descriptions
def plt_2_rgb(env):
# data = np.frombuffer(env.window.fig.canvas.tostring_rgb(), dtype=np.uint8)
# data = data.reshape(env.window.fig.canvas.get_width_height()[::-1] + (3,))
width, height = env.window.fig.get_size_inches() * env.window.fig.get_dpi()
data = np.frombuffer(env.window.fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(int(height), int(width), 3)
return data
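# Note: canvas.tostring_rgb() is deprecated in newer matplotlib versions.
# A possible alternative (untested sketch) goes through the RGBA buffer instead:
#   env.window.fig.canvas.draw()
#   rgba = np.asarray(env.window.fig.canvas.buffer_rgba())
#   data = rgba[..., :3]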
def reset(env):
env.reset()
# a dirty trick just to get obs and info
return env.step([np.nan, np.nan, np.nan])
# return step("no_op")
def generate(text_input, model):
# return "(a) move forward"
if model == "dummy":
print("dummy action forward")
return "move forward"
elif model == "interactive":
return input("Enter action:")
elif model == "random":
print("random agent")
print("PROMPT:")
print(text_input)
return random.choice([
"move forward",
"turn left",
"turn right",
"toggle",
])
elif model in ["gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-4-0613", "gpt-4-0314"]:
while True:
try:
c = openai.ChatCompletion.create(
model=model,
messages=[
# {"role": "system", "content": ""},
# {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
# {"role": "user", "content": "Continue the following text in the most logical way.\n"+text_input}
# {"role": "system", "content":
# "You are an agent and can use the following actions: 'move forward', 'toggle', 'turn left', 'turn right', 'done'."
# # "The caretaker will say the color of the box which you should open. Turn until you find this box and toggle it when it is right in front of it."
# # "Then an apple will appear and you can toggle it to succeed."
# },
{"role": "user", "content": text_input}
],
max_tokens=3,
n=1,
temperature=0.0,
request_timeout=30,
)
break
except Exception as e:
print(e)
print("Pausing")
time.sleep(10)
continue
print("->LLM generation: ", c['choices'][0]['message']['content'])
return c['choices'][0]['message']['content']
elif re.match(r"text-.*-\d{3}", model) or model in ["gpt-3.5-turbo-instruct-0914"]:
while True:
try:
response = openai.Completion.create(
model=model,
prompt=text_input,
# temperature=0.7,
temperature=0.0,
max_tokens=3,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
timeout=30
)
break
except Exception as e:
print(e)
print("Pausing")
time.sleep(10)
continue
choices = response["choices"]
assert len(choices) == 1
return choices[0]["text"].strip().lower()  # strip surrounding whitespace/newlines and lowercase the completion
elif model in ["gpt2_large", "api_bloom"]:
# HF_TOKEN = os.getenv("HF_TOKEN")
if model == "gpt2_large":
API_URL = "https://api-inference.huggingface.co/models/gpt2-large"
elif model == "api_bloom":
API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"
else:
raise ValueError(f"Undefined model {model}.")
headers = {"Authorization": f"Bearer {HF_TOKEN}"}
def query(text_prompt, n_tokens=3):
input = text_prompt
# make n_tokens request and append the output each time - one request generates one token
for _ in range(n_tokens):
# prepare request
payload = {
    "inputs": input,
    "parameters": {
        "do_sample": False,
        'temperature': 0,
        # "max_length": 500, # for gpt2
        # "max_new_tokens": 250 # for gpt2-xl
    },
    # wait_for_model is an API option rather than a generation parameter
    "options": {
        'wait_for_model': True,
    },
}
data = json.dumps(payload)
# request
response = requests.request("POST", API_URL, headers=headers, data=data)
response_json = json.loads(response.content.decode("utf-8"))
if type(response_json) is list and len(response_json) == 1:
# generated_text contains the input + the response
response_full_text = response_json[0]['generated_text']
# we use this as the next input
input = response_full_text
else:
print("Invalid request to huggingface api")
from IPython import embed; embed()
# remove the prompt from the beginning
assert response_full_text.startswith(text_prompt)
response_text = response_full_text[len(text_prompt):]
return response_text
response = query(text_input).strip().lower()
return response
elif model in ["bloom_560m"]:
# from transformers import BloomForCausalLM
# from transformers import BloomTokenizerFast
#
# tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m", cache_dir=".cache/huggingface/")
# model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m", cache_dir=".cache/huggingface/")
inputs = hf_tokenizer(text_input, return_tensors="pt")
# 3 words
result_length = inputs['input_ids'].shape[-1]+3
full_output = hf_tokenizer.decode(hf_model.generate(inputs["input_ids"], max_length=result_length)[0])
assert full_output.startswith(text_input)
response = full_output[len(text_input):]
response = response.strip().lower()
return response
else:
raise ValueError("Unknown model.")
def estimate_tokens_selenium(prompt):
# selenium is used because tiktoken requires python 3.9, which this environment does not provide
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
import time
# Initialize the WebDriver instance
options = webdriver.ChromeOptions()
options.add_argument('headless')
# set up the driver
driver = webdriver.Chrome(options=options)
# Navigate to the website
driver.get('https://platform.openai.com/tokenizer')
text_input = driver.find_element(By.XPATH, '/html/body/div[1]/div[1]/div/div[2]/div[3]/textarea')
text_input.clear()
text_input.send_keys(prompt)
n_tokens = 0
while n_tokens == 0:
time.sleep(1)
# Wait for the response to be loaded
wait = WebDriverWait(driver, 10)
response = wait.until(
EC.presence_of_element_located((By.CSS_SELECTOR, 'div.tokenizer-stat:nth-child(1) > div:nth-child(2)')))
n_tokens = int(response.text.replace(",", ""))
# Close the WebDriver instance
driver.quit()
return n_tokens
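# If tiktoken is available (it needs a newer python than this project targets),
# the selenium round trip could be replaced by counting tokens locally.
# A minimal sketch under that assumption:
#
#   import tiktoken
#   def estimate_tokens_tiktoken(prompt, model="gpt-3.5-turbo"):
#       enc = tiktoken.encoding_for_model(model)
#       return len(enc.encode(prompt))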
def load_in_context_examples(in_context_episodes):
in_context_examples = ""
print(f'Loading {len(in_context_episodes)} examples.')
for episode_data in in_context_episodes:
in_context_examples += new_episode_marker()
for step_i, step_data in enumerate(episode_data):
action = step_data["action"]
info = step_data["info"]
obs = step_data["obs"]
reward = step_data["reward"]
done = step_data["done"]
if step_i == 0:
# step 0 is the initial state of the environment
assert action is None
prompt_action_text = ""
else:
prompt_action_text = action_to_prompt_action_text(action)
text_obs = generate_text_obs(obs, info)
step_text = prompt_preprocessor(prompt_action_text + text_obs)
in_context_examples += step_text
if done:
if reward > 0:
in_context_examples += success_marker()
else:
in_context_examples += failure_marker()
else:
# in all envs the reward is given at the end
# in the initial step the reward is None
assert reward == 0 or (step_i == 0 and reward is None)
print("-------------------------- IN CONTEXT EXAMPLES --------------------------")
print(in_context_examples)
print("-------------------------------------------------------------------------")
return in_context_examples
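# load_in_context_examples expects in_context_episodes (unpickled in __main__ below)
# to be a list of episodes, where each episode is a list of step dicts of the form
#   {"action": ..., "obs": ..., "info": ..., "reward": ..., "done": ...}
# and the first step of every episode is the initial state (action is None,
# reward is None or 0).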
if __name__ == "__main__":
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=False,
help="text-ada-001")
parser.add_argument("--seed", type=int, default=0,
help="Seed of the first episode. The seed for the following episodes will be used in order: seed, seed + 1, ... seed + (n_episodes-1) (default: 0)")
parser.add_argument("--max-steps", type=int, default=15,
help="max num of steps")
parser.add_argument("--shift", type=int, default=0,
help="number of times the environment is reset at the beginning (default: 0)")
parser.add_argument("--argmax", action="store_true", default=False,
help="select the action with highest probability (default: False)")
parser.add_argument("--pause", type=float, default=0.5,
help="pause duration between two consequent actions of the agent (default: 0.5)")
parser.add_argument("--env-name", type=str,
default="SocialAI-AsocialBoxInformationSeekingParamEnv-v1",
# default="SocialAI-ColorBoxesLLMCSParamEnv-v1",
required=False,
help="env name")
parser.add_argument("--in-context-path", type=str,
# old
# default='llm_data/in_context_asocial_box.txt'
# default='llm_data/in_context_color_boxes.txt',
# new
# asocial box
default='llm_data/in_context_examples/in_context_asocialbox_SocialAI-AsocialBoxInformationSeekingParamEnv-v1_2023_07_19_19_28_48/episodes.pkl',
# colorbox
# default='llm_data/in_context_examples/in_context_colorbox_SocialAI-ColorBoxesLLMCSParamEnv-v1_2023_07_20_13_11_54/episodes.pkl',
required=False,
help="path to in context examples")
parser.add_argument("--episodes", type=int, default=10,
help="number of episodes to visualize")
parser.add_argument("--env-args", nargs='*', default=None)
parser.add_argument("--agent_view", default=False, help="draw the agent sees (partially observable view)", action='store_true' )
parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32 )
parser.add_argument("--mask-unobserved", default=False, help="mask cells that are not observed by the agent", action='store_true' )
parser.add_argument("--log", type=str, default="llm_log/episodes_log", help="log from the run", required=False)
parser.add_argument("--feed-full-ep", default=False, help="weather to append the whole episode to the prompt", action='store_true')
parser.add_argument("--last-n", type=int, help="how many last steps to provide in observation (if not feed-full-ep)", default=3)
parser.add_argument("--skip-check", default=False, help="Don't estimate the price.", action="store_true")
args = parser.parse_args()
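# Example invocation (illustrative; the model name and the in-context path are placeholders):
#   python scripts/LLM_test.py \
#       --model text-ada-001 \
#       --env-name SocialAI-AsocialBoxInformationSeekingParamEnv-v1 \
#       --in-context-path llm_data/in_context_examples/<...>/episodes.pkl \
#       --episodes 2 --max-steps 15 --last-n 3 --log llm_log/episodes_log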
# Set seed for all randomness sources
seed(args.seed)
model = args.model
in_context_examples_path = args.in_context_path
# test for paper: remove later
if "asocialbox" in in_context_examples_path:
assert args.env_name == "SocialAI-AsocialBoxInformationSeekingParamEnv-v1"
elif "colorbox" in in_context_examples_path:
assert args.env_name == "SocialAI-ColorBoxesLLMCSParamEnv-v1"
print("env name:", args.env_name)
print("examples:", in_context_examples_path)
print("model:", args.model)
# datetime
now = datetime.now()
datetime_string = now.strftime("%d_%m_%Y_%H:%M:%S")
print(datetime_string)
# log filenames
log_folder = args.log+"_"+datetime_string+"/"
os.mkdir(log_folder)
evaluation_log_filename = log_folder+"evaluation_log.json"
prompt_log_filename = log_folder + "prompt_log.txt"
ep_h_log_filename = log_folder+"episode_history_query.txt"
gif_savename = log_folder + "demo.gif"
env_args = env_args_str_to_dict(args.env_args)
env = make_env(args.env_name, args.seed, env_args)
# env = gym.make(args.env_name, **env_args)
print(f"Environment {args.env_name} and args: {env_args_str_to_dict(args.env_args)}\n")
# Define agent
print("Agent loaded\n")
# prepare models
model_instance = None
if "text" in args.model or "gpt-3" in args.model or "gpt-4" in args.model:
import openai
openai.api_key = os.getenv("OPENAI_API_KEY")
elif args.model in ["gpt2_large", "api_bloom"]:
HF_TOKEN = os.getenv("HF_TOKEN")
elif args.model in ["bloom_560m"]:
from transformers import BloomForCausalLM
from transformers import BloomTokenizerFast
hf_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m", cache_dir=".cache/huggingface/")
hf_model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m", cache_dir=".cache/huggingface/")
elif args.model in ["bloom"]:
from transformers import BloomForCausalLM
from transformers import BloomTokenizerFast
hf_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom", cache_dir=".cache/huggingface/")
hf_model = BloomForCausalLM.from_pretrained("bigscience/bloom", cache_dir=".cache/huggingface/")
model_instance = (hf_tokenizer, hf_model)
with open(in_context_examples_path, "rb") as f:
in_context_episodes = pickle.load(f)
in_context_examples = load_in_context_examples(in_context_episodes)
with open(prompt_log_filename, "a+") as f:
f.write(datetime_string)
with open(ep_h_log_filename, "a+") as f:
f.write(datetime_string)
full_episode_history = args.feed_full_ep
last_n = args.last_n
if full_episode_history:
print("Full episode history.")
else:
print(f"Last {args.last_n} steps.")
if not args.skip_check and args.model not in ["dummy", "random", "interactive"]:
print(f"Estimating price for model {args.model}.")
in_context_n_tokens = estimate_tokens_selenium(in_context_examples)
n_in_context_steps = sum([len(ep) for ep in in_context_episodes])
tokens_per_step = in_context_n_tokens / n_in_context_steps
_, price = estimate_price(
num_of_episodes=args.episodes,
in_context_len=in_context_n_tokens,
tokens_per_step=tokens_per_step,
n_steps=args.max_steps,
last_n=last_n,
model=args.model,
feed_episode_history=full_episode_history
)
input(f"You will spend: {price} dollars. ok?")
# prepare frames list to save to gif
frames = []
assert args.max_steps <= 20
success_rates = []
# episodes start
for episode in range(args.episodes):
print("Episode:", episode)
episode_history_text = new_episode_marker()
success = False
episode_seed = args.seed + episode
env = make_env(args.env_name, episode_seed, env_args)
with open(prompt_log_filename, "a+") as f:
f.write("\n\n")
observations = []
actions = []
for i in range(int(args.max_steps)):
if i == 0:
obs, reward, done, info = reset(env)
prompt_action_text = ""
else:
with open(prompt_log_filename, "a+") as f:
f.write("\nnew prompt: -----------------------------------\n")
f.write(llm_prompt)
# query the model
generation = generate(llm_prompt, args.model)
# parse the action
text_action = get_parsed_action(generation)
# get the raw action
action = text_action_to_action(text_action)
# execute the action
obs, reward, done, info = env.step(action)
prompt_action_text = f"{action_query()} {text_action}\n"
assert action_to_prompt_action_text(action) == prompt_action_text
actions.append(prompt_action_text)
text_obs = generate_text_obs(obs, info)
observations.append(text_obs)
step_text = prompt_preprocessor(prompt_action_text + text_obs)
print("Step text:")
print(step_text)
episode_history_text += step_text # append to history of this episode
if full_episode_history:
# feed full episode history
llm_prompt = in_context_examples + episode_history_text + action_query()
else:
n = min(last_n, len(observations))
obs = observations[-n:]
act = (actions + [action_query()])[-n:]
episode_text = "".join([o+a for o, a in zip(obs, act)])
llm_prompt = in_context_examples + new_episode_marker() + episode_text
llm_prompt = prompt_preprocessor(llm_prompt)
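# At this point llm_prompt always ends with the action query, e.g. (illustrative):
#   <in-context episodes> + "New episode.\n" + ("Act : ...\n" / "Obs : ...\n" pairs) + "Act :"
# so the model's next few tokens are read as the next action.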
# save the image
env.render(mode="human")
rgb_img = plt_2_rgb(env)
frames.append(rgb_img)
if env.current_env.box.blocked and not env.current_env.box.is_open:
# target box is blocked -> apple can't be obtained
# break to save compute
break
if done:
# quadruple last frame to pause between episodes
for i in range(3):
same_img = np.copy(rgb_img)
# toggle a pixel between frames to avoid cropping when going from gif to mp4
same_img[0, 0, 2] = 0 if (i % 2) == 0 else 255
frames.append(same_img)
if reward > 0:
print("Success!")
episode_history_text += success_marker()
success = True
else:
episode_history_text += failure_marker()
with open(ep_h_log_filename, "a+") as f:
f.write("\nnew prompt: -----------------------------------\n")
f.write(episode_history_text)
break
else:
with open(ep_h_log_filename, "a+") as f:
f.write("\nnew prompt: -----------------------------------\n")
f.write(episode_history_text)
print(f"{'Success' if success else 'Failure'}")
success_rates.append(success)
mean_success_rate = np.mean(success_rates)
print("Success rate:", mean_success_rate)
print(f"Saving gif to {gif_savename}.")
mimsave(gif_savename, frames, duration=args.pause)
print("Done.")
log_data_dict = vars(args)
log_data_dict["success_rates"] = success_rates
log_data_dict["mean_success_rate"] = mean_success_rate
print("Evaluation log: ", evaluation_log_filename)
with open(evaluation_log_filename, "w") as f:
f.write(json.dumps(log_data_dict))