Spaces:
Running
Running
import argparse | |
import json | |
import requests | |
import time | |
import warnings | |
from n_tokens import estimate_price | |
import pickle | |
import numpy as np | |
import torch | |
from pathlib import Path | |
# from utils.babyai_utils.baby_agent import load_agent | |
from utils import * | |
from textworld_utils.utils import generate_text_obs | |
from models import * | |
import subprocess | |
import os | |
from matplotlib import pyplot as plt | |
from gym_minigrid.wrappers import * | |
from gym_minigrid.window import Window | |
from datetime import datetime | |
from imageio import mimsave | |
def new_episode_marker():
    """Return the text line that opens a new episode in the prompt."""
    marker = "New episode.\n"
    return marker
def success_marker():
    """Return the text line appended to an episode that ended with positive reward."""
    marker = "Success!\n"
    return marker
def failure_marker():
    """Return the text line appended to an episode that ended without positive reward."""
    marker = "Failure!\n"
    return marker
def action_query():
    """Return the prefix that introduces an action line in the prompt.

    The prompt sent to the model ends with this prefix, so the model's
    continuation is read as the next action.
    """
    prefix = "Act :"
    return prefix
def get_parsed_action(text_action):
    """Map free-form model output onto one of the known action strings.

    Keywords are checked in a fixed priority order, and the first one found
    anywhere in ``text_action`` is returned. For example, ``"done, then turn
    left"`` yields ``"done"``. If no keyword occurs, a warning is emitted and
    ``"no_op"`` is returned.
    """
    priority = ("move forward", "done", "turn left", "turn right", "toggle", "no_op")
    for keyword in priority:
        if keyword in text_action:
            return keyword
    warnings.warn(f"Undefined action {text_action}")
    return "no_op"
def action_to_prompt_action_text(action):
    """Render a raw environment action as the prompt line ``"Act : <text>\\n"``.

    ``action`` is compared, with NaN == NaN, against the 3-element action
    vectors built from the module-level ``env`` (created in the ``__main__``
    section). Movement and interaction actions are ``[index, nan, nan]``, and
    the no-op is all-NaN.

    An action that matches none of them triggers a warning and is rendered as
    ``no_op``. Earlier code returned the bare string ``"no_op"`` here, without
    the ``"Act : "`` prefix and the trailing newline, which corrupted the
    prompt text it was concatenated into.
    """
    candidates = (
        ("move forward", [int(env.actions.forward), np.nan, np.nan]),
        ("turn left", [int(env.actions.left), np.nan, np.nan]),
        ("turn right", [int(env.actions.right), np.nan, np.nan]),
        ("toggle", [int(env.actions.toggle), np.nan, np.nan]),
        ("done", [int(env.actions.done), np.nan, np.nan]),
        ("no_op", [np.nan, np.nan, np.nan]),
    )
    for text_action, reference in candidates:
        if np.allclose(action, reference, equal_nan=True):
            break
    else:
        warnings.warn(f"Undefined action {action}")
        text_action = "no_op"
    return f"{action_query()} {text_action}\n"
def text_action_to_action(text_action):
    """Convert a parsed action string into the raw 3-element env action.

    ``"no_op"`` maps to ``[nan, nan, nan]``. The other known strings map to
    ``[index, nan, nan]``, where ``index`` comes from the module-level
    ``env.actions`` (``env`` is created in the ``__main__`` section).

    Args:
        text_action: one of the strings produced by ``get_parsed_action``.

    Returns:
        A list of three numbers suitable for ``env.step``.

    Raises:
        ValueError: if ``text_action`` is not a known action string.
            Previously this case surfaced as an ``UnboundLocalError``.
    """
    if text_action == "no_op":
        return [np.nan, np.nan, np.nan]
    enum_names = {
        "move forward": "forward",
        "turn left": "left",
        "turn right": "right",
        "toggle": "toggle",
        "done": "done",
    }
    try:
        enum_name = enum_names[text_action]
    except KeyError:
        raise ValueError(f"Undefined text action {text_action!r}") from None
    return [int(getattr(env.actions, enum_name)), np.nan, np.nan]
def prompt_preprocessor(llm_prompt, show_caretaker=True):
    """Clean up raw prompt text before it is shown to the LLM.

    The prompt is processed line by line:

    * Lines starting with ``#`` or ``Conversation`` are dropped.
    * Lines mentioning a ``peer`` are rewritten.
      - With ``show_caretaker`` (the default, and the only behaviour before
        this flag existed), everything after ``"there is a "`` is replaced
        by ``"caretaker"``, so only the peer's location survives.
      - Without it, the peer is hidden: an ``"Obs :"`` line is reduced to
        just ``"Obs :"`` and any other peer line is dropped.
    * ``"Caretaker:"`` is rewritten to ``"Caretaker says: "``.
    * All other lines are kept unchanged.

    Args:
        llm_prompt: newline-separated prompt text.
        show_caretaker: keep the caretaker's location (True) or remove the
            peer description entirely (False).

    Returns:
        The processed lines, re-joined with newlines.
    """
    anchor = "there is a"
    kept = []
    for line in llm_prompt.split("\n"):
        if line.startswith("#") or line.startswith("Conversation"):
            continue
        if "peer" in line:
            if show_caretaker:
                assert anchor in line
                # +1 keeps the space after "there is a" before "caretaker"
                cut = line.index(anchor) + len(anchor) + 1
                kept.append(line[:cut] + "caretaker")
            elif line.startswith("Obs :"):
                # drop the peer description but keep the observation header
                kept.append("Obs :")
            # any other peer line is dropped when the caretaker is hidden
        elif "Caretaker:" in line:
            kept.append(line.replace("Caretaker:", "Caretaker says: "))
        else:
            kept.append(line)
    return "\n".join(kept)
# def generate_text_obs(obs, info): | |
# | |
# text_observation = obs_to_text(info) | |
# | |
# llm_prompt = "Obs : " | |
# llm_prompt += "".join(text_observation) | |
# | |
# # add utterances | |
# if obs["utterance_history"] != "Conversation: \n": | |
# utt_hist = obs['utterance_history'] | |
# utt_hist = utt_hist.replace("Conversation: \n","") | |
# llm_prompt += utt_hist | |
# | |
# return llm_prompt | |
# def obs_to_text(info): | |
# image, vis_mask = info["image"], info["vis_mask"] | |
# carrying = info["carrying"] | |
# agent_pos_vx, agent_pos_vy = info["agent_pos_vx"], info["agent_pos_vy"] | |
# npc_actions_dict = info["npc_actions_dict"] | |
# | |
# # (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state) | |
# # State, 0: open, 1: closed, 2: locked | |
# IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys())) | |
# IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys())) | |
# | |
# list_textual_descriptions = [] | |
# | |
# if carrying is not None: | |
# list_textual_descriptions.append("You carry a {} {}".format(carrying.color, carrying.type)) | |
# | |
# # agent_pos_vx, agent_pos_vy = self.get_view_coords(self.agent_pos[0], self.agent_pos[1]) | |
# | |
# view_field_dictionary = dict() | |
# | |
# for i in range(image.shape[0]): | |
# for j in range(image.shape[1]): | |
# if image[i][j][0] != 0 and image[i][j][0] != 1 and image[i][j][0] != 2: | |
# if i not in view_field_dictionary.keys(): | |
# view_field_dictionary[i] = dict() | |
# view_field_dictionary[i][j] = image[i][j] | |
# else: | |
# view_field_dictionary[i][j] = image[i][j] | |
# | |
# # Find the wall if any | |
# # We describe a wall only if there is no objects between the agent and the wall in straight line | |
# | |
# # Find wall in front | |
# add_wall_descr = False | |
# if add_wall_descr: | |
# j = agent_pos_vy - 1 | |
# object_seen = False | |
# while j >= 0 and not object_seen: | |
# if image[agent_pos_vx][j][0] != 0 and image[agent_pos_vx][j][0] != 1: | |
# if image[agent_pos_vx][j][0] == 2: | |
# list_textual_descriptions.append( | |
# f"A wall is {agent_pos_vy - j} steps in front of you. \n") # forward | |
# object_seen = True | |
# else: | |
# object_seen = True | |
# j -= 1 | |
# # Find wall left | |
# i = agent_pos_vx - 1 | |
# object_seen = False | |
# while i >= 0 and not object_seen: | |
# if image[i][agent_pos_vy][0] != 0 and image[i][agent_pos_vy][0] != 1: | |
# if image[i][agent_pos_vy][0] == 2: | |
# list_textual_descriptions.append( | |
# f"A wall is {agent_pos_vx - i} steps to the left. \n") # left | |
# object_seen = True | |
# else: | |
# object_seen = True | |
# i -= 1 | |
# # Find wall right | |
# i = agent_pos_vx + 1 | |
# object_seen = False | |
# while i < image.shape[0] and not object_seen: | |
# if image[i][agent_pos_vy][0] != 0 and image[i][agent_pos_vy][0] != 1: | |
# if image[i][agent_pos_vy][0] == 2: | |
# list_textual_descriptions.append( | |
# f"A wall is {i - agent_pos_vx} steps to the right. \n") # right | |
# object_seen = True | |
# else: | |
# object_seen = True | |
# i += 1 | |
# | |
# # list_textual_descriptions.append("You see the following objects: ") | |
# # returns the position of seen objects relative to you | |
# for i in view_field_dictionary.keys(): | |
# for j in view_field_dictionary[i].keys(): | |
# if i != agent_pos_vx or j != agent_pos_vy: | |
# object = view_field_dictionary[i][j] | |
# | |
# # # don't show npc | |
# # if IDX_TO_OBJECT[object[0]] == "npc": | |
# # continue | |
# | |
# front_dist = agent_pos_vy - j | |
# left_right_dist = i - agent_pos_vx | |
# | |
# loc_descr = "" | |
# if front_dist == 1 and left_right_dist == 0: | |
# loc_descr += "Right in front of you " | |
# | |
# elif left_right_dist == 1 and front_dist == 0: | |
# loc_descr += "Just to the right of you" | |
# | |
# elif left_right_dist == -1 and front_dist == 0: | |
# loc_descr += "Just to the left of you" | |
# | |
# else: | |
# front_str = str(front_dist) + " steps in front of you " if front_dist > 0 else "" | |
# | |
# loc_descr += front_str | |
# | |
# suff = "s" if abs(left_right_dist) > 0 else "" | |
# and_ = "and" if loc_descr != "" else "" | |
# | |
# if left_right_dist < 0: | |
# left_right_str = f"{and_} {-left_right_dist} step{suff} to the left" | |
# loc_descr += left_right_str | |
# | |
# elif left_right_dist > 0: | |
# left_right_str = f"{and_} {left_right_dist} step{suff} to the right" | |
# loc_descr += left_right_str | |
# | |
# else: | |
# left_right_str = "" | |
# loc_descr += left_right_str | |
# | |
# loc_descr += f" there is a " | |
# | |
# obj_type = IDX_TO_OBJECT[object[0]] | |
# if obj_type == "npc": | |
# IDX_TO_STATE = {0: 'friendly', 1: 'antagonistic'} | |
# | |
# description = f"{IDX_TO_STATE[object[2]]} {IDX_TO_COLOR[object[1]]} peer. " | |
# | |
# # gaze | |
# gaze_dir = { | |
# 0: "towards you", | |
# 1: "to the left of you", | |
# 2: "in the same direction as you", | |
# 3: "to the right of you", | |
# } | |
# description += f"It is looking {gaze_dir[object[3]]}. " | |
# | |
# # point | |
# point_dir = { | |
# 0: "towards you", | |
# 1: "to the left of you", | |
# 2: "in the same direction as you", | |
# 3: "to the right of you", | |
# } | |
# | |
# if object[4] != 255: | |
# description += f"It is pointing {point_dir[object[4]]}. " | |
# | |
# # last action | |
# last_action = {v: k for k, v in npc_actions_dict.items()}[object[5]] | |
# | |
# last_action = { | |
# "go_forward": "foward", | |
# "rotate_left": "turn left", | |
# "rotate_right": "turn right", | |
# "toggle_action": "toggle", | |
# "point_stop_point": "stop pointing", | |
# "point_E": "", | |
# "point_S": "", | |
# "point_W": "", | |
# "point_N": "", | |
# "stop_point": "stop pointing", | |
# "no_op": "" | |
# }[last_action] | |
# | |
# if last_action not in ["no_op", ""]: | |
# description += f"It's last action is {last_action}. " | |
# | |
# elif obj_type in ["switch", "apple", "generatorplatform", "marble", "marbletee", "fence"]: | |
# # todo: this assumes that Switch.no_light == True | |
# description = f"{IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} " | |
# assert object[2:].mean() == 0 | |
# | |
# elif obj_type == "lockablebox": | |
# IDX_TO_STATE = {0: 'open', 1: 'closed', 2: 'locked'} | |
# description = f"{IDX_TO_STATE[object[2]]} {IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} " | |
# assert object[3:].mean() == 0 | |
# | |
# elif obj_type == "applegenerator": | |
# IDX_TO_STATE = {1: 'square', 2: 'round'} | |
# description = f"{IDX_TO_STATE[object[2]]} {IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} " | |
# assert object[3:].mean() == 0 | |
# | |
# elif obj_type == "remotedoor": | |
# IDX_TO_STATE = {0: 'open', 1: 'closed'} | |
# description = f"{IDX_TO_STATE[object[2]]} {IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} " | |
# assert object[3:].mean() == 0 | |
# | |
# elif obj_type == "door": | |
# IDX_TO_STATE = {0: 'open', 1: 'closed', 2: 'locked'} | |
# description = f"{IDX_TO_STATE[object[2]]} {IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} " | |
# assert object[3:].mean() == 0 | |
# | |
# elif obj_type == "lever": | |
# IDX_TO_STATE = {1: 'activated', 0: 'unactivated'} | |
# if object[3] == 255: | |
# countdown_txt = "" | |
# else: | |
# countdown_txt = f"with {object[3]} timesteps left. " | |
# | |
# description = f"{IDX_TO_STATE[object[2]]} {IDX_TO_COLOR[object[1]]} {IDX_TO_OBJECT[object[0]]} {countdown_txt}" | |
# | |
# assert object[4:].mean() == 0 | |
# else: | |
# raise ValueError(f"Undefined object type {obj_type}") | |
# | |
# full_destr = loc_descr + description + "\n" | |
# | |
# list_textual_descriptions.append(full_destr) | |
# | |
# if len(list_textual_descriptions) == 0: | |
# list_textual_descriptions.append("\n") | |
# | |
# return list_textual_descriptions | |
def plt_2_rgb(env):
    """Grab the current pixels of ``env.window.fig`` as an RGB image.

    The figure is presumably already drawn; the visible caller invokes
    ``env.render(mode="human")`` right before this.

    Args:
        env: environment exposing a matplotlib figure as ``env.window.fig``.

    Returns:
        A writable ``uint8`` array of shape ``(height, width, 3)``.
    """
    fig = env.window.fig
    # Pixel size = inches * dpi, truncated with int() as the Agg canvas does.
    width, height = fig.get_size_inches() * fig.get_dpi()
    # np.fromstring's binary mode is deprecated; np.frombuffer replaces it.
    # frombuffer returns a read-only view of the bytes, so copy to keep
    # returning a writable array as before.
    rgb = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    return rgb.reshape(int(height), int(width), 3).copy()
def reset(env):
    """Reset ``env`` and return the transition of one all-NaN step.

    The return value of ``env.reset()`` is discarded. Instead, an all-NaN
    action (the encoding this file uses for ``no_op``) is stepped, and its
    ``env.step`` result is returned, so the caller gets an observation
    together with the ``info`` dict.
    """
    env.reset()
    no_op_action = [np.nan] * 3
    return env.step(no_op_action)
# Hugging Face hosted inference endpoints, keyed by the --model name.
_HF_API_URLS = {
    "gpt2_large": "https://api-inference.huggingface.co/models/gpt2-large",
    "api_bloom": "https://api-inference.huggingface.co/models/bigscience/bloom",
}


def _query_openai_chat(text_input, model):
    """Query an OpenAI chat model (greedy, 3 tokens), retrying every 10 s on any error."""
    while True:
        try:
            completion = openai.ChatCompletion.create(
                model=model,
                messages=[{"role": "user", "content": text_input}],
                max_tokens=3,
                n=1,
                temperature=0.0,
                request_timeout=30,
            )
            break
        except Exception as e:
            print(e)
            print("Pausing")
            time.sleep(10)
    content = completion['choices'][0]['message']['content']
    print("->LLM generation: ", content)
    return content


def _query_openai_completion(text_input, model):
    """Query an OpenAI completion model (greedy, 3 tokens), retrying every 10 s on any error."""
    while True:
        try:
            response = openai.Completion.create(
                model=model,
                prompt=text_input,
                temperature=0.0,
                max_tokens=3,
                top_p=1,
                frequency_penalty=0,
                presence_penalty=0,
                timeout=30,
            )
            break
        except Exception as e:
            print(e)
            print("Pausing")
            time.sleep(10)
    choices = response["choices"]
    assert len(choices) == 1
    return choices[0]["text"].strip().lower()  # remove newline from the end


def _query_hf_inference_api(text_prompt, model, n_tokens=3):
    """Generate ``n_tokens`` tokens via the HF Inference API, one request per token.

    Each response's ``generated_text`` (prompt + continuation) is fed back as
    the next input. Returns only the continuation past ``text_prompt``.

    Raises:
        RuntimeError: if the API returns anything other than a one-element list.
    """
    api_url = _HF_API_URLS[model]
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    current_text = text_prompt
    for _ in range(n_tokens):
        payload = {
            "inputs": current_text,
            "parameters": {"do_sample": False, "temperature": 0},
            # wait_for_model is an API *option*; inside "parameters" it was ignored.
            "options": {"wait_for_model": True},
        }
        response = requests.request("POST", api_url, headers=headers, data=json.dumps(payload))
        response_json = json.loads(response.content.decode("utf-8"))
        if not (isinstance(response_json, list) and len(response_json) == 1):
            # Previously this dropped into an IPython shell and then crashed on
            # an unbound variable; fail loudly with the API's answer instead.
            raise RuntimeError(f"Invalid request to huggingface api: {response_json}")
        # generated_text contains the input + the response
        current_text = response_json[0]["generated_text"]
    assert current_text.startswith(text_prompt)
    return current_text[len(text_prompt):]


def _query_local_hf(text_input, n_new_tokens=3):
    """Greedy-decode ``n_new_tokens`` with the module-level ``hf_tokenizer``/``hf_model``."""
    inputs = hf_tokenizer(text_input, return_tensors="pt")
    result_length = inputs["input_ids"].shape[-1] + n_new_tokens
    full_output = hf_tokenizer.decode(
        hf_model.generate(inputs["input_ids"], max_length=result_length)[0]
    )
    assert full_output.startswith(text_input)
    return full_output[len(text_input):].strip().lower()


def generate(text_input, model):
    """Ask ``model`` to continue ``text_input`` and return the raw generated text.

    Supported ``model`` values:

    * ``"dummy"``: always moves forward.
    * ``"interactive"``: reads the action from stdin.
    * ``"random"``: picks a random movement/toggle action.
    * OpenAI chat and completion models.
    * Hugging Face hosted models (``"gpt2_large"``, ``"api_bloom"``).
    * Local Bloom models (``"bloom_560m"``, ``"bloom"``). These rely on the
      module-level ``hf_tokenizer``/``hf_model`` that ``__main__`` loads for
      them. ``"bloom"`` used to be loaded but then rejected here as unknown.

    Raises:
        ValueError: for an unrecognised ``model``.
    """
    # Imported locally: both were previously only reachable through star imports.
    import random
    import re

    if model == "dummy":
        print("dummy action forward")
        return "move forward"
    if model == "interactive":
        return input("Enter action:")
    if model == "random":
        print("random agent")
        print("PROMPT:")
        print(text_input)
        return random.choice([
            "move forward",
            "turn left",
            "turn right",
            "toggle",
        ])
    if model in ["gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-4-0613", "gpt-4-0314"]:
        return _query_openai_chat(text_input, model)
    if re.match(r"text-.*-\d{3}", model) or model in ["gpt-3.5-turbo-instruct-0914"]:
        return _query_openai_completion(text_input, model)
    if model in _HF_API_URLS:
        return _query_hf_inference_api(text_input, model).strip().lower()
    if model in ["bloom_560m", "bloom"]:
        return _query_local_hf(text_input)
    raise ValueError("Unknown model.")
def estimate_tokens_selenium(prompt):
    """Count the tokens in ``prompt`` using OpenAI's web tokenizer page.

    Selenium is used instead of tiktoken because, per the original author,
    tiktoken needs Python 3.9. A headless Chrome types ``prompt`` into the
    tokenizer text area and polls the displayed count until it is non-zero.

    Returns:
        The number of tokens reported by the page, or 0 for an empty prompt
        (in which case no browser is launched).
    """
    if not prompt:
        # The displayed count would presumably stay at 0 for empty input,
        # so the polling loop below would never exit.
        return 0

    from selenium import webdriver
    from selenium.webdriver.common.by import By
    from selenium.webdriver.support.ui import WebDriverWait
    from selenium.webdriver.support import expected_conditions as EC

    options = webdriver.ChromeOptions()
    options.add_argument('headless')
    driver = webdriver.Chrome(options=options)
    try:
        driver.get('https://platform.openai.com/tokenizer')
        text_input = driver.find_element(By.XPATH, '/html/body/div[1]/div[1]/div/div[2]/div[3]/textarea')
        text_input.clear()
        text_input.send_keys(prompt)

        wait = WebDriverWait(driver, 10)
        n_tokens = 0
        while n_tokens == 0:
            time.sleep(1)
            # Wait for the token counter to be present, then read it.
            response = wait.until(
                EC.presence_of_element_located((By.CSS_SELECTOR, 'div.tokenizer-stat:nth-child(1) > div:nth-child(2)')))
            n_tokens = int(response.text.replace(",", ""))
    finally:
        # Always release the browser, even if a lookup above raised
        # (e.g. after the page layout changed).
        driver.quit()
    return n_tokens
def load_in_context_examples(in_context_episodes):
    """Render recorded episodes as the few-shot text that prefixes every prompt.

    Each episode is a sequence of step dicts with the keys ``action``,
    ``info``, ``obs``, ``reward`` and ``done``. Step 0 is the initial state:
    its action must be None and it contributes no action line.

    For every step, the text is the "Act : ..." line followed by the
    observation text, passed through ``prompt_preprocessor``. A terminal step
    is followed by the success or failure marker, depending on its reward.
    The assembled text is printed and returned.
    """
    print(f'Loading {len(in_context_episodes)} examples.')
    chunks = []
    for episode_steps in in_context_episodes:
        chunks.append(new_episode_marker())
        for idx, step in enumerate(episode_steps):
            act = step["action"]
            step_info = step["info"]
            step_obs = step["obs"]
            rew = step["reward"]
            finished = step["done"]

            if idx != 0:
                action_line = action_to_prompt_action_text(act)
            else:
                # step 0 is the initial state of the environment
                assert act is None
                action_line = ""
            chunks.append(prompt_preprocessor(action_line + generate_text_obs(step_obs, step_info)))

            if not finished:
                # reward only arrives at the end; the initial step has None
                assert rew == 0 or (idx == 0 and rew is None)
            elif rew > 0:
                chunks.append(success_marker())
            else:
                chunks.append(failure_marker())

    rendered = "".join(chunks)
    print("-------------------------- IN CONTEXT EXAMPLES --------------------------")
    print(rendered)
    print("-------------------------------------------------------------------------")
    return rendered
if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser()
    # Required: the model-setup section below does substring checks on it,
    # which crashed with a TypeError when --model was omitted (None).
    parser.add_argument("--model", required=True,
                        help="model name, e.g. text-ada-001, gpt-4-0613, bloom_560m, dummy, random, interactive")
    parser.add_argument("--seed", type=int, default=0,
                        help="Seed of the first episode. The seed for the following episodes will be used in order: seed, seed + 1, ... seed + (n_episodes-1) (default: 0)")
    parser.add_argument("--max-steps", type=int, default=15,
                        help="max num of steps")
    parser.add_argument("--shift", type=int, default=0,
                        help="number of times the environment is reset at the beginning (default: 0)")
    parser.add_argument("--argmax", action="store_true", default=False,
                        help="select the action with highest probability (default: False)")
    parser.add_argument("--pause", type=float, default=0.5,
                        help="pause duration between two consequent actions of the agent (default: 0.5)")
    parser.add_argument("--env-name", type=str,
                        default="SocialAI-AsocialBoxInformationSeekingParamEnv-v1",
                        # default="SocialAI-ColorBoxesLLMCSParamEnv-v1",
                        required=False,
                        help="env name")
    parser.add_argument("--in-context-path", type=str,
                        # asocial box
                        default='llm_data/in_context_examples/in_context_asocialbox_SocialAI-AsocialBoxInformationSeekingParamEnv-v1_2023_07_19_19_28_48/episodes.pkl',
                        # colorbox
                        # default='llm_data/in_context_examples/in_context_colorbox_SocialAI-ColorBoxesLLMCSParamEnv-v1_2023_07_20_13_11_54/episodes.pkl',
                        required=False,
                        help="path to in context examples")
    parser.add_argument("--episodes", type=int, default=10,
                        help="number of episodes to visualize")
    parser.add_argument("--env-args", nargs='*', default=None)
    parser.add_argument("--agent_view", default=False, help="draw the agent sees (partially observable view)", action='store_true')
    parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32)
    parser.add_argument("--mask-unobserved", default=False, help="mask cells that are not observed by the agent", action='store_true')
    parser.add_argument("--log", type=str, default="llm_log/episodes_log", help="log from the run", required=False)
    parser.add_argument("--feed-full-ep", default=False, help="weather to append the whole episode to the prompt", action='store_true')
    parser.add_argument("--last-n", type=int, help="how many last steps to provide in observation (if not feed-full-ep)", default=3)
    parser.add_argument("--skip-check", default=False, help="Don't estimate the price.", action="store_true")
    args = parser.parse_args()

    # Set seed for all randomness sources
    seed(args.seed)

    model = args.model
    in_context_examples_path = args.in_context_path

    # Sanity check that the in-context examples match the environment
    # (marked "test for paper: remove later" by the original author).
    if "asocialbox" in in_context_examples_path:
        assert args.env_name == "SocialAI-AsocialBoxInformationSeekingParamEnv-v1"
    elif "colorbox" in in_context_examples_path:
        assert args.env_name == "SocialAI-ColorBoxesLLMCSParamEnv-v1"

    print("env name:", args.env_name)
    print("examples:", in_context_examples_path)
    print("model:", args.model)

    # Timestamp used in the log folder name and written into the log files.
    now = datetime.now()
    datetime_string = now.strftime("%d_%m_%Y_%H:%M:%S")
    print(datetime_string)

    # log filenames
    log_folder = args.log + "_" + datetime_string + "/"
    # makedirs (not mkdir) so the parent of --log (e.g. the default
    # "llm_log/") is created when it does not exist yet.
    os.makedirs(log_folder)
    evaluation_log_filename = log_folder + "evaluation_log.json"
    prompt_log_filename = log_folder + "prompt_log.txt"
    ep_h_log_filename = log_folder + "episode_history_query.txt"
    gif_savename = log_folder + "demo.gif"

    # NOTE: `env` is a module-level global; text_action_to_action and
    # action_to_prompt_action_text read it.
    env_args = env_args_str_to_dict(args.env_args)
    env = make_env(args.env_name, args.seed, env_args)
    print(f"Environment {args.env_name} and args: {env_args}\n")

    # Define agent
    print("Agent loaded\n")

    # Prepare models. The names bound here (openai, HF_TOKEN, hf_tokenizer,
    # hf_model) are module globals used by generate().
    model_instance = None
    if "text" in args.model or "gpt-3" in args.model or "gpt-4" in args.model:
        import openai
        openai.api_key = os.getenv("OPENAI_API_KEY")
    elif args.model in ["gpt2_large", "api_bloom"]:
        HF_TOKEN = os.getenv("HF_TOKEN")
    elif args.model in ["bloom_560m"]:
        from transformers import BloomForCausalLM
        from transformers import BloomTokenizerFast
        hf_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m", cache_dir=".cache/huggingface/")
        hf_model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m", cache_dir=".cache/huggingface/")
    elif args.model in ["bloom"]:
        from transformers import BloomForCausalLM
        from transformers import BloomTokenizerFast
        hf_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom", cache_dir=".cache/huggingface/")
        hf_model = BloomForCausalLM.from_pretrained("bigscience/bloom", cache_dir=".cache/huggingface/")
        model_instance = (hf_tokenizer, hf_model)

    # NOTE(review): pickle.load executes arbitrary code; only load trusted files.
    with open(in_context_examples_path, "rb") as f:
        in_context_episodes = pickle.load(f)

    in_context_examples = load_in_context_examples(in_context_episodes)

    with open(prompt_log_filename, "a+") as f:
        f.write(datetime_string)
    with open(ep_h_log_filename, "a+") as f:
        f.write(datetime_string)

    full_episode_history = args.feed_full_ep
    last_n = args.last_n

    if full_episode_history:
        print("Full episode history.")
    else:
        print(f"Last {args.last_n} steps.")

    if not args.skip_check and args.model not in ["dummy", "random", "interactive"]:
        print(f"Estimating price for model {args.model}.")
        in_context_n_tokens = estimate_tokens_selenium(in_context_examples)
        n_in_context_steps = sum(len(ep) for ep in in_context_episodes)
        tokens_per_step = in_context_n_tokens / n_in_context_steps
        _, price = estimate_price(
            num_of_episodes=args.episodes,
            in_context_len=in_context_n_tokens,
            tokens_per_step=tokens_per_step,
            n_steps=args.max_steps,
            last_n=last_n,
            model=args.model,
            feed_episode_history=full_episode_history
        )
        input(f"You will spend: {price} dollars. ok?")

    # prepare frames list to save to gif
    frames = []

    assert args.max_steps <= 20

    success_rates = []
    # episodes start
    for episode in range(args.episodes):
        print("Episode:", episode)
        episode_history_text = new_episode_marker()
        success = False
        episode_seed = args.seed + episode
        env = make_env(args.env_name, episode_seed, env_args)

        with open(prompt_log_filename, "a+") as f:
            f.write("\n\n")

        observations = []
        actions = []
        for i in range(int(args.max_steps)):
            if i == 0:
                obs, reward, done, info = reset(env)
                prompt_action_text = ""
            else:
                with open(prompt_log_filename, "a+") as f:
                    f.write("\nnew prompt: -----------------------------------\n")
                    f.write(llm_prompt)

                # query the model
                generation = generate(llm_prompt, args.model)
                # parse the action
                text_action = get_parsed_action(generation)
                # get the raw action
                action = text_action_to_action(text_action)
                # execute the action
                obs, reward, done, info = env.step(action)

                prompt_action_text = f"{action_query()} {text_action}\n"
                assert action_to_prompt_action_text(action) == prompt_action_text

            actions.append(prompt_action_text)
            text_obs = generate_text_obs(obs, info)
            observations.append(text_obs)

            step_text = prompt_preprocessor(prompt_action_text + text_obs)
            print("Step text:")
            print(step_text)

            episode_history_text += step_text  # append to history of this episode

            if full_episode_history:
                # feed full episode history
                llm_prompt = in_context_examples + episode_history_text + action_query()
            else:
                # Only the last n observations. Each is followed by the action
                # taken after it; the newest one is followed by the open
                # "Act :" query. Distinct names avoid clobbering `obs`.
                n = min(last_n, len(observations))
                recent_obs = observations[-n:]
                recent_act = (actions + [action_query()])[-n:]
                episode_text = "".join([o + a for o, a in zip(recent_obs, recent_act)])
                llm_prompt = in_context_examples + new_episode_marker() + episode_text

            llm_prompt = prompt_preprocessor(llm_prompt)

            # save the image
            env.render(mode="human")
            rgb_img = plt_2_rgb(env)
            frames.append(rgb_img)

            if env.current_env.box.blocked and not env.current_env.box.is_open:
                # target box is blocked -> apple can't be obtained
                # break to save compute
                break

            if done:
                # repeat the last frame to pause between episodes
                for k in range(3):
                    same_img = np.copy(rgb_img)
                    # toggle a pixel between frames to avoid cropping when going from gif to mp4
                    same_img[0, 0, 2] = 0 if (k % 2) == 0 else 255
                    frames.append(same_img)

                if reward > 0:
                    print("Success!")
                    episode_history_text += success_marker()
                    success = True
                else:
                    episode_history_text += failure_marker()

                with open(ep_h_log_filename, "a+") as f:
                    f.write("\nnew prompt: -----------------------------------\n")
                    f.write(episode_history_text)

                break
            else:
                with open(ep_h_log_filename, "a+") as f:
                    f.write("\nnew prompt: -----------------------------------\n")
                    f.write(episode_history_text)

        print(f"{'Success' if success else 'Failure'}")
        success_rates.append(success)

    mean_success_rate = np.mean(success_rates)
    print("Success rate:", mean_success_rate)

    print(f"Saving gif to {gif_savename}.")
    mimsave(gif_savename, frames, duration=args.pause)

    print("Done.")

    log_data_dict = vars(args)
    log_data_dict["success_rates"] = success_rates
    log_data_dict["mean_success_rate"] = mean_success_rate

    print("Evaluation log: ", evaluation_log_filename)
    with open(evaluation_log_filename, "w") as f:
        f.write(json.dumps(log_data_dict))