import argparse
import json
import random  # used by the "random" baseline in generate()
import re      # used to match completion-model names in generate()
import requests
import time
import warnings
from n_tokens import estimate_price
import pickle
import numpy as np
import torch
from pathlib import Path
# from utils.babyai_utils.baby_agent import load_agent
from utils import *
from textworld_utils.utils import generate_text_obs
from models import *
import subprocess
import os
from matplotlib import pyplot as plt
from gym_minigrid.wrappers import *
from gym_minigrid.window import Window
from datetime import datetime
from imageio import mimsave


def new_episode_marker():
    return "New episode.\n"


def success_marker():
    return "Success!\n"


def failure_marker():
    return "Failure!\n"


def action_query():
    return "Act :"


def get_parsed_action(text_action):
    """
    Parses the text generated by a model and extracts the action.
    """
    if "move forward" in text_action:
        return "move forward"
    elif "done" in text_action:
        return "done"
    elif "turn left" in text_action:
        return "turn left"
    elif "turn right" in text_action:
        return "turn right"
    elif "toggle" in text_action:
        return "toggle"
    elif "no_op" in text_action:
        return "no_op"
    else:
        warnings.warn(f"Undefined action {text_action}")
        return "no_op"


def action_to_prompt_action_text(action):
    # `env` is the global environment created in __main__
    if np.allclose(action, [int(env.actions.forward), np.nan, np.nan], equal_nan=True):  # 2
        text_action = "move forward"
    elif np.allclose(action, [int(env.actions.left), np.nan, np.nan], equal_nan=True):  # 0
        text_action = "turn left"
    elif np.allclose(action, [int(env.actions.right), np.nan, np.nan], equal_nan=True):  # 1
        text_action = "turn right"
    elif np.allclose(action, [int(env.actions.toggle), np.nan, np.nan], equal_nan=True):  # 3
        text_action = "toggle"
    elif np.allclose(action, [int(env.actions.done), np.nan, np.nan], equal_nan=True):  # 4
        text_action = "done"
    elif np.allclose(action, [np.nan, np.nan, np.nan], equal_nan=True):
        text_action = "no_op"
    else:
        warnings.warn(f"Undefined action {action}")
        return "no_op"

    return f"{action_query()} {text_action}\n"


def text_action_to_action(text_action):
    # text_action = get_parsed_action(text_action)
    if "move forward" == text_action:
        action = [int(env.actions.forward), np.nan, np.nan]
    elif "turn left" == text_action:
        action = [int(env.actions.left), np.nan, np.nan]
    elif "turn right" == text_action:
        action = [int(env.actions.right), np.nan, np.nan]
    elif "toggle" == text_action:
        action = [int(env.actions.toggle), np.nan, np.nan]
    elif "done" == text_action:
        action = [int(env.actions.done), np.nan, np.nan]
    elif "no_op" == text_action:
        action = [np.nan, np.nan, np.nan]
    else:
        # fall back to a no_op for unrecognized actions (mirrors get_parsed_action)
        warnings.warn(f"Undefined action {text_action}")
        action = [np.nan, np.nan, np.nan]

    return action


def prompt_preprocessor(llm_prompt):
    # remove peer observations
    lines = llm_prompt.split("\n")
    new_lines = []
    for line in lines:
        if line.startswith("#"):
            continue

        elif line.startswith("Conversation"):
            continue

        elif "peer" in line:
            caretaker = True
            if caretaker:
                # show only the location of the caretaker
                # this is very ugly, todo: refactor this
                assert "there is a" in line
                start_index = line.index('there is a') + 11
                new_line = line[:start_index] + 'caretaker'

                new_lines.append(new_line)

            else:
                # no caretaker at all
                if line.startswith("Obs :") and "peer" in line:
                    # remove only the peer descriptions
                    line = "Obs :"
                    new_lines.append(line)
                else:
                    assert "peer" in line

        elif "Caretaker:" in line:
            line = line.replace("Caretaker:", "Caretaker says: ")
            new_lines.append(line)

        else:
            new_lines.append(line)

    return "\n".join(new_lines)
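
# Illustrative line-level rewrites performed by prompt_preprocessor (the example lines
# below are made up for documentation, not taken from a real rollout):
#
#   "3 steps in front of you there is a friendly red peer. It is looking towards you."
#       -> "3 steps in front of you there is a caretaker"
#   "Caretaker: open the red box."
#       -> "Caretaker says:  open the red box."   (the replacement introduces a double space)
#   lines starting with "#" or "Conversation" are dropped entirely.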

def plt_2_rgb(env):
    # grab the rendered matplotlib figure as an RGB array
    width, height = env.window.fig.get_size_inches() * env.window.fig.get_dpi()
    # np.frombuffer replaces the deprecated np.fromstring
    data = np.frombuffer(env.window.fig.canvas.tostring_rgb(), dtype='uint8').reshape(int(height), int(width), 3)
    return data


def reset(env):
    env.reset()
    # a dirty trick just to get obs and info:
    # actions are 3-element lists [primitive_action_id, nan, nan] (see text_action_to_action),
    # and an all-nan triple is treated as a no_op by the environment
    return env.step([np.nan, np.nan, np.nan])  # return step("no_op")

def generate(text_input, model):
    # return "(a) move forward"

    if model == "dummy":
        print("dummy action forward")
        return "move forward"

    elif model == "interactive":
        return input("Enter action:")

    elif model == "random":
        print("random agent")
        print("PROMPT:")
        print(text_input)
        return random.choice([
            "move forward",
            "turn left",
            "turn right",
            "toggle",
        ])

    elif model in ["gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-4-0613", "gpt-4-0314"]:
        while True:
            try:
                c = openai.ChatCompletion.create(
                    model=model,
                    messages=[
                        # {"role": "system", "content": ""},
                        # {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
                        # {"role": "user", "content": "Continue the following text in the most logical way.\n"+text_input}
                        # {"role": "system", "content":
                        #     "You are an agent and can use the following actions: 'move forward', 'toggle', 'turn left', 'turn right', 'done'."
                        #     # "The caretaker will say the color of the box which you should open. Turn until you find this box and toggle it when it is right in front of it."
                        #     # "Then an apple will appear and you can toggle it to succeed."
                        # },
                        {"role": "user", "content": text_input}
                    ],
                    max_tokens=3,
                    n=1,
                    temperature=0.0,
                    request_timeout=30,
                )
                break
            except Exception as e:
                print(e)
                print("Pausing")
                time.sleep(10)
                continue

        print("->LLM generation: ", c['choices'][0]['message']['content'])
        return c['choices'][0]['message']['content']

    elif re.match(r"text-.*-\d{3}", model) or model in ["gpt-3.5-turbo-instruct-0914"]:
        while True:
            try:
                response = openai.Completion.create(
                    model=model,
                    prompt=text_input,
                    # temperature=0.7,
                    temperature=0.0,
                    max_tokens=3,
                    top_p=1,
                    frequency_penalty=0,
                    presence_penalty=0,
                    timeout=30
                )
                break
            except Exception as e:
                print(e)
                print("Pausing")
                time.sleep(10)
                continue

        choices = response["choices"]
        assert len(choices) == 1
        return choices[0]["text"].strip().lower()  # remove newline from the end

    elif model in ["gpt2_large", "api_bloom"]:
        # HF_TOKEN = os.getenv("HF_TOKEN")
        if model == "gpt2_large":
            API_URL = "https://api-inference.huggingface.co/models/gpt2-large"
        elif model == "api_bloom":
            API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"
        else:
            raise ValueError(f"Undefined model {model}.")

        headers = {"Authorization": f"Bearer {HF_TOKEN}"}

        def query(text_prompt, n_tokens=3):
            input = text_prompt

            # make n_tokens requests and append the output each time - one request generates one token
            for _ in range(n_tokens):
                # prepare request
                payload = {
                    "inputs": input,
                    "parameters": {
                        "do_sample": False,
                        'temperature': 0,
                        'wait_for_model': True,
                        # "max_length": 500,  # for gpt2
                        # "max_new_tokens": 250  # for gpt2-xl
                    },
                }
                data = json.dumps(payload)

                # request
                response = requests.request("POST", API_URL, headers=headers, data=data)
                response_json = json.loads(response.content.decode("utf-8"))

                if type(response_json) is list and len(response_json) == 1:
                    # generated_text contains the input + the response
                    response_full_text = response_json[0]['generated_text']

                    # we use this as the next input
                    input = response_full_text

                else:
                    print("Invalid request to huggingface api")
                    from IPython import embed; embed()

            # remove the prompt from the beginning
            assert response_full_text.startswith(text_prompt)
            response_text = response_full_text[len(text_prompt):]

            return response_text

        response = query(text_input).strip().lower()
        return response

    elif model in ["bloom_560m"]:
        # from transformers import BloomForCausalLM
        # from transformers import BloomTokenizerFast
        #
        # tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m", cache_dir=".cache/huggingface/")
        # model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m", cache_dir=".cache/huggingface/")

        inputs = hf_tokenizer(text_input, return_tensors="pt")
        # 3 words
        result_length = inputs['input_ids'].shape[-1] + 3
        full_output = hf_tokenizer.decode(hf_model.generate(inputs["input_ids"], max_length=result_length)[0])
        assert full_output.startswith(text_input)
        response = full_output[len(text_input):]
        response = response.strip().lower()
        return response

    else:
        raise ValueError("Unknown model.")


def estimate_tokens_selenium(prompt):
    # selenium is used because python3.9 is needed for tiktoken
    from selenium import webdriver
    from selenium.webdriver.common.by import By
    from selenium.webdriver.support.ui import WebDriverWait
    from selenium.webdriver.support import expected_conditions as EC
    import time

    # Initialize the WebDriver instance
    options = webdriver.ChromeOptions()
    options.add_argument('headless')

    # set up the driver
    driver = webdriver.Chrome(options=options)

    # Navigate to the website
    driver.get('https://platform.openai.com/tokenizer')

    text_input = driver.find_element(By.XPATH, '/html/body/div[1]/div[1]/div/div[2]/div[3]/textarea')
    text_input.clear()
    text_input.send_keys(prompt)

    n_tokens = 0
    while n_tokens == 0:
        time.sleep(1)
        # Wait for the response to be loaded
        wait = WebDriverWait(driver, 10)
        response = wait.until(
            EC.presence_of_element_located((By.CSS_SELECTOR, 'div.tokenizer-stat:nth-child(1) > div:nth-child(2)')))
        n_tokens = int(response.text.replace(",", ""))

    # Close the WebDriver instance
    driver.quit()

    return n_tokens
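
# Untested local alternative to the selenium-based estimate above: a minimal sketch
# assuming the interpreter is new enough for tiktoken and the package is installed.
# The helper below is illustrative only and is not called anywhere in this script.
def estimate_tokens_tiktoken(prompt, model_name="gpt-3.5-turbo"):
    import tiktoken
    try:
        encoding = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # model name not in tiktoken's registry -> fall back to a generic encoding
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(prompt))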

def load_in_context_examples(in_context_episodes):
    in_context_examples = ""

    print(f'Loading {len(in_context_episodes)} examples.')

    for episode_data in in_context_episodes:

        in_context_examples += new_episode_marker()

        for step_i, step_data in enumerate(episode_data):
            action = step_data["action"]
            info = step_data["info"]
            obs = step_data["obs"]
            reward = step_data["reward"]
            done = step_data["done"]

            if step_i == 0:
                # step 0 is the initial state of the environment
                assert action is None
                prompt_action_text = ""
            else:
                prompt_action_text = action_to_prompt_action_text(action)

            text_obs = generate_text_obs(obs, info)
            step_text = prompt_preprocessor(prompt_action_text + text_obs)

            in_context_examples += step_text

            if done:
                if reward > 0:
                    in_context_examples += success_marker()
                else:
                    in_context_examples += failure_marker()
            else:
                # in all envs the reward is given at the end
                # in the initial step the reward is None
                assert reward == 0 or (step_i == 0 and reward is None)

    print("-------------------------- IN CONTEXT EXAMPLES --------------------------")
    print(in_context_examples)
    print("-------------------------------------------------------------------------")

    return in_context_examples
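
# The pickle referenced by --in-context-path is expected to hold data shaped roughly as
# sketched below. This is inferred from how load_in_context_examples() reads it, not from
# inspecting the actual files; field values are illustrative placeholders only:
#
# in_context_episodes = [
#     [  # one episode = a list of step dicts
#         {"action": None, "obs": ..., "info": ..., "reward": None, "done": False},          # initial step
#         {"action": [2, nan, nan], "obs": ..., "info": ..., "reward": 0, "done": False},
#         ...
#         {"action": [3, nan, nan], "obs": ..., "info": ..., "reward": 1, "done": True},     # terminal step
#     ],
#     ...
# ]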

if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=False,
                        help="text-ada-001")
    parser.add_argument("--seed", type=int, default=0,
                        help="Seed of the first episode. The seeds for the following episodes will be used in order: "
                             "seed, seed + 1, ... seed + (n_episodes-1) (default: 0)")
    parser.add_argument("--max-steps", type=int, default=15,
                        help="max num of steps")
    parser.add_argument("--shift", type=int, default=0,
                        help="number of times the environment is reset at the beginning (default: 0)")
    parser.add_argument("--argmax", action="store_true", default=False,
                        help="select the action with highest probability (default: False)")
    parser.add_argument("--pause", type=float, default=0.5,
                        help="pause duration between two consecutive actions of the agent (default: 0.5)")
    parser.add_argument("--env-name", type=str,
                        default="SocialAI-AsocialBoxInformationSeekingParamEnv-v1",
                        # default="SocialAI-ColorBoxesLLMCSParamEnv-v1",
                        required=False,
                        help="env name")
    parser.add_argument("--in-context-path", type=str,
                        # old
                        # default='llm_data/in_context_asocial_box.txt'
                        # default='llm_data/in_context_color_boxes.txt',
                        # new
                        # asocial box
                        default='llm_data/in_context_examples/in_context_asocialbox_SocialAI-AsocialBoxInformationSeekingParamEnv-v1_2023_07_19_19_28_48/episodes.pkl',
                        # colorbox
                        # default='llm_data/in_context_examples/in_context_colorbox_SocialAI-ColorBoxesLLMCSParamEnv-v1_2023_07_20_13_11_54/episodes.pkl',
                        required=False,
                        help="path to in context examples")
    parser.add_argument("--episodes", type=int, default=10,
                        help="number of episodes to visualize")
    parser.add_argument("--env-args", nargs='*', default=None)
    parser.add_argument("--agent_view", default=False,
                        help="draw what the agent sees (partially observable view)",
                        action='store_true')
    parser.add_argument("--tile_size", type=int,
                        help="size at which to render tiles", default=32)
    parser.add_argument("--mask-unobserved", default=False,
                        help="mask cells that are not observed by the agent",
                        action='store_true')
    parser.add_argument("--log", type=str, default="llm_log/episodes_log",
                        help="log from the run", required=False)
    parser.add_argument("--feed-full-ep", default=False,
                        help="whether to append the whole episode to the prompt",
                        action='store_true')
    parser.add_argument("--last-n", type=int, default=3,
                        help="how many last steps to provide in observation (if not feed-full-ep)")
    parser.add_argument("--skip-check", default=False,
                        help="Don't estimate the price.", action="store_true")

    args = parser.parse_args()

    # Set seed for all randomness sources
    seed(args.seed)

    model = args.model

    in_context_examples_path = args.in_context_path

    # test for paper: remove later
    if "asocialbox" in in_context_examples_path:
        assert args.env_name == "SocialAI-AsocialBoxInformationSeekingParamEnv-v1"
    elif "colorbox" in in_context_examples_path:
        assert args.env_name == "SocialAI-ColorBoxesLLMCSParamEnv-v1"

    print("env name:", args.env_name)
    print("examples:", in_context_examples_path)
    print("model:", args.model)

    # datetime
    now = datetime.now()
    datetime_string = now.strftime("%d_%m_%Y_%H:%M:%S")
    print(datetime_string)

    # log filenames
    log_folder = args.log + "_" + datetime_string + "/"
    os.mkdir(log_folder)
    evaluation_log_filename = log_folder + "evaluation_log.json"
    prompt_log_filename = log_folder + "prompt_log.txt"
    ep_h_log_filename = log_folder + "episode_history_query.txt"

    gif_savename = log_folder + "demo.gif"

    env_args = env_args_str_to_dict(args.env_args)
    env = make_env(args.env_name, args.seed, env_args)
    # env = gym.make(args.env_name, **env_args)
    print(f"Environment {args.env_name} and args: {env_args_str_to_dict(args.env_args)}\n")

    # Define agent
    print("Agent loaded\n")

    # prepare models
    model_instance = None

    if "text" in args.model or "gpt-3" in args.model or "gpt-4" in args.model:
        import openai
        openai.api_key = os.getenv("OPENAI_API_KEY")

    elif args.model in ["gpt2_large", "api_bloom"]:
        HF_TOKEN = os.getenv("HF_TOKEN")

    elif args.model in ["bloom_560m"]:
        from transformers import BloomForCausalLM
        from transformers import BloomTokenizerFast

        hf_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m", cache_dir=".cache/huggingface/")
        hf_model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m", cache_dir=".cache/huggingface/")

    elif args.model in ["bloom"]:
        from transformers import BloomForCausalLM
        from transformers import BloomTokenizerFast

        hf_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom", cache_dir=".cache/huggingface/")
        hf_model = BloomForCausalLM.from_pretrained("bigscience/bloom", cache_dir=".cache/huggingface/")

        model_instance = (hf_tokenizer, hf_model)

    with open(in_context_examples_path, "rb") as f:
        in_context_episodes = pickle.load(f)

    in_context_examples = load_in_context_examples(in_context_episodes)

    with open(prompt_log_filename, "a+") as f:
        f.write(datetime_string)
    with open(ep_h_log_filename, "a+") as f:
        f.write(datetime_string)

    full_episode_history = args.feed_full_ep
    last_n = args.last_n

    if full_episode_history:
        print("Full episode history.")
    else:
        print(f"Last {args.last_n} steps.")

    if not args.skip_check and args.model not in ["dummy", "random", "interactive"]:
        print(f"Estimating price for model {args.model}.")
        in_context_n_tokens = estimate_tokens_selenium(in_context_examples)

        n_in_context_steps = sum([len(ep) for ep in in_context_episodes])
        tokens_per_step = in_context_n_tokens / n_in_context_steps

        _, price = estimate_price(
            num_of_episodes=args.episodes,
            in_context_len=in_context_n_tokens,
            tokens_per_step=tokens_per_step,
            n_steps=args.max_steps,
            last_n=last_n,
            model=args.model,
            feed_episode_history=full_episode_history
        )
        input(f"You will spend: {price} dollars. ok?")

    # prepare frames list to save to gif
    frames = []

    assert args.max_steps <= 20

    success_rates = []

    # episodes start
    for episode in range(args.episodes):
        print("Episode:", episode)

        episode_history_text = new_episode_marker()
        success = False

        episode_seed = args.seed + episode
        env = make_env(args.env_name, episode_seed, env_args)

        with open(prompt_log_filename, "a+") as f:
            f.write("\n\n")

        observations = []
        actions = []
        for i in range(int(args.max_steps)):
            if i == 0:
                obs, reward, done, info = reset(env)
                prompt_action_text = ""
            else:
                with open(prompt_log_filename, "a+") as f:
                    f.write("\nnew prompt: -----------------------------------\n")
                    f.write(llm_prompt)

                # query the model
                generation = generate(llm_prompt, args.model)

                # parse the action
                text_action = get_parsed_action(generation)

                # get the raw action
                action = text_action_to_action(text_action)

                # execute the action
                obs, reward, done, info = env.step(action)

                prompt_action_text = f"{action_query()} {text_action}\n"
                assert action_to_prompt_action_text(action) == prompt_action_text

            actions.append(prompt_action_text)

            text_obs = generate_text_obs(obs, info)
            observations.append(text_obs)

            step_text = prompt_preprocessor(prompt_action_text + text_obs)
            print("Step text:")
            print(step_text)

            episode_history_text += step_text  # append to history of this episode

            if full_episode_history:
                # feed the full episode history
                llm_prompt = in_context_examples + episode_history_text + action_query()
            else:
                # feed only the last n steps
                n = min(last_n, len(observations))
                last_obs = observations[-n:]
                last_act = (actions + [action_query()])[-n:]
                episode_text = "".join([o + a for o, a in zip(last_obs, last_act)])

                llm_prompt = in_context_examples + new_episode_marker() + episode_text

            llm_prompt = prompt_preprocessor(llm_prompt)

            # save the image
            env.render(mode="human")
            rgb_img = plt_2_rgb(env)
            frames.append(rgb_img)

            if env.current_env.box.blocked and not env.current_env.box.is_open:
                # target box is blocked -> apple can't be obtained
                # break to save compute
                break

            if done:
                # quadruple the last frame to pause between episodes
                for j in range(3):
                    same_img = np.copy(rgb_img)
                    # toggle a pixel between frames to avoid cropping when going from gif to mp4
                    same_img[0, 0, 2] = 0 if (j % 2) == 0 else 255
                    frames.append(same_img)

                if reward > 0:
                    print("Success!")
                    episode_history_text += success_marker()
                    success = True
                else:
                    episode_history_text += failure_marker()

                with open(ep_h_log_filename, "a+") as f:
                    f.write("\nnew prompt: -----------------------------------\n")
                    f.write(episode_history_text)

                break

            else:
                with open(ep_h_log_filename, "a+") as f:
                    f.write("\nnew prompt: -----------------------------------\n")
                    f.write(episode_history_text)

        print(f"{'Success' if success else 'Failure'}")
        success_rates.append(success)

    mean_success_rate = np.mean(success_rates)
    print("Success rate:", mean_success_rate)

    print(f"Saving gif to {gif_savename}.")
    mimsave(gif_savename, frames, duration=args.pause)
    print("Done.")

    log_data_dict = vars(args)
    log_data_dict["success_rates"] = success_rates
    log_data_dict["mean_success_rate"] = mean_success_rate

    print("Evaluation log: ", evaluation_log_filename)
    with open(evaluation_log_filename, "w") as f:
        f.write(json.dumps(log_data_dict))