"""Evaluate an LLM (OpenAI / HuggingFace / bloom / random) as an agent in SocialAI envs.

The script builds a text prompt from in-context examples plus the (preprocessed)
observation history, queries the chosen model for the next action, steps the
environment, and logs prompts, episode histories, a success-rate json, and a gif.

Example invocations:
# python -m scripts.LLM_test --gif test_GPT_boxes --episodes 1 --max-steps 8 --model text-davinci-003 --env-args size 6 --env-name SocialAI-ColorBoxesLLMCSParamEnv-v1 --in-context-path llm_data/in_context_color_boxes.txt
# python -m scripts.LLM_test --gif test_GPT_asoc --episodes 1 --max-steps 8 --model text-ada-001 --env-args size 6 --env-name SocialAI-AsocialBoxInformationSeekingParamEnv-v1 --in-context-path llm_data/in_context_asocial_box.txt --feed-full-ep
# python -m scripts.LLM_test --gif test_GPT_boxes --episodes 1 --max-steps 8 --model bloom_560m --env-args size 6 --env-name SocialAI-ColorBoxesLLMCSParamEnv-v1 --in-context-path llm_data/in_context_color_boxes.txt
# python -m scripts.LLM_test --gif test_GPT_asoc --episodes 1 --max-steps 8 --model bloom_560m --env-args size 6 --env-name SocialAI-AsocialBoxInformationSeekingParamEnv-v1 --in-context-path llm_data/in_context_asocial_box.txt --feed-full-ep

## bloom 560m
# boxes
# python -m scripts.LLM_test --log llm_log/bloom_560m_boxes_no_hist --gif evaluation --episodes 20 --max-steps 10 --model bloom_560m --env-args size 6 --env-name SocialAI-ColorBoxesLLMCSParamEnv-v1 --in-context-path llm_data/in_context_color_boxes.txt
# asocial
# python -m scripts.LLM_test --log llm_log/bloom_560m_asocial_no_hist --gif evaluation --episodes 20 --max-steps 10 --model bloom_560m --env-args size 6 --env-name SocialAI-AsocialBoxInformationSeekingParamEnv-v1 --in-context-path llm_data/in_context_asocial_box.txt
# random
# python -m scripts.LLM_test --log llm_log/random_boxes --gif evaluation --episodes 20 --max-steps 10 --model random --env-args size 6 --env-name SocialAI-ColorBoxesLLMCSParamEnv-v1 --in-context-path llm_data/in_context_color_boxes.txt
"""
import argparse
import json
import requests
import time
import warnings
from n_tokens import estimate_price

import numpy as np
import torch
from pathlib import Path
from utils.babyai_utils.baby_agent import load_agent
from utils import *
from models import *
import subprocess
import os

from matplotlib import pyplot as plt
from gym_minigrid.wrappers import *
from gym_minigrid.window import Window
from datetime import datetime
from imageio import mimsave


def prompt_preprocessor(llm_prompt):
    """Clean a raw prompt: drop comment/Conversation lines and hide peer details.

    Lines mentioning a "peer" are rewritten so that only the caretaker's
    location is shown (not its full description).
    """
    lines = llm_prompt.split("\n")
    new_lines = []
    for line in lines:
        if line.startswith("#"):
            continue

        elif line.startswith("Conversation"):
            continue

        elif "peer" in line:
            caretaker = True  # toggle: True -> show caretaker location only
            if caretaker:
                # show only the location of the caretaker
                # this is very ugly, todo: refactor this
                assert "there is a" in line
                start_index = line.index('there is a') + 11
                new_line = line[:start_index] + 'caretaker'
                new_lines.append(new_line)
            else:
                # no caretaker at all
                if line.startswith("Obs :") and "peer" in line:
                    # remove only the peer descriptions
                    line = "Obs :"
                    new_lines.append(line)
                else:
                    assert "peer" in line

        elif "Caretaker:" in line:
            # line = line.replace("Caretaker:", "Caretaker says: '") + "'"
            new_lines.append(line)

        else:
            new_lines.append(line)

    return "\n".join(new_lines)


# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=False,
                    help="text-ada-001")
parser.add_argument("--seed", type=int, default=0,
                    help="Seed of the first episode. The seed for the following episodes will be used in order: "
                         "seed, seed + 1, ... seed + (n_episodes-1) (default: 0)")
parser.add_argument("--max-steps", type=int, default=5,
                    help="max num of steps")
parser.add_argument("--shift", type=int, default=0,
                    help="number of times the environment is reset at the beginning (default: 0)")
parser.add_argument("--argmax", action="store_true", default=False,
                    help="select the action with highest probability (default: False)")
parser.add_argument("--pause", type=float, default=0.5,
                    help="pause duration between two consequent actions of the agent (default: 0.5)")
parser.add_argument("--env-name", type=str,
                    # default="SocialAI-ELangColorBoxesTestInformationSeekingParamEnv-v1",
                    # default="SocialAI-AsocialBoxInformationSeekingParamEnv-v1",
                    default="SocialAI-ColorBoxesLLMCSParamEnv-v1",
                    required=False,
                    help="env name")
parser.add_argument("--in-context-path", type=str,
                    # default='llm_data/short_in_context_boxes.txt'
                    # default='llm_data/in_context_asocial_box.txt'
                    default='llm_data/in_context_color_boxes.txt',
                    required=False,
                    help="path to in context examples")
parser.add_argument("--gif", type=str, default="visualization",
                    help="store output as gif with the given filename", required=False)
parser.add_argument("--episodes", type=int, default=1,
                    help="number of episodes to visualize")
parser.add_argument("--env-args", nargs='*', default=None)
parser.add_argument("--agent_view", default=False,
                    help="draw the agent sees (partially observable view)", action='store_true')
parser.add_argument("--tile_size", type=int,
                    help="size at which to render tiles", default=32)
parser.add_argument("--mask-unobserved", default=False,
                    help="mask cells that are not observed by the agent", action='store_true')
parser.add_argument("--log", type=str, default="llm_log/episodes_log",
                    help="log from the run", required=False)
parser.add_argument("--feed-full-ep", default=False,
                    help="whether to append the whole episode to the prompt", action='store_true')
parser.add_argument("--skip-check", default=False,
                    help="Don't estimate the price.", action="store_true")

args = parser.parse_args()

# Set seed for all randomness sources
seed(args.seed)

model = args.model
in_context_examples_path = args.in_context_path

print("env name:", args.env_name)
print("examples:", in_context_examples_path)
print("model:", args.model)

# datetime tag used to make the log folder unique per run
now = datetime.now()
datetime_string = now.strftime("%d_%m_%Y_%H:%M:%S")
print(datetime_string)

# log filenames
log_folder = args.log + "_" + datetime_string + "/"
os.mkdir(log_folder)
evaluation_log_filename = log_folder + "evaluation_log.json"
prompt_log_filename = log_folder + "prompt_log.txt"
ep_h_log_filename = log_folder + "episode_history_query.txt"
gif_savename = log_folder + args.gif + ".gif"

assert "viz" not in gif_savename  # don't use viz anymore

env_args = env_args_str_to_dict(args.env_args)
env = make_env(args.env_name, args.seed, env_args)
# env = gym.make(args.env_name, **env_args)
print(f"Environment {args.env_name} and args: {env_args_str_to_dict(args.env_args)}\n")

# Define agent
print("Agent loaded\n")

# prepare models: set up API keys or load local HF weights depending on --model
if args.model in ["text-davinci-003", "text-ada-001", "gpt-3.5-turbo-0301"]:
    import openai
    openai.api_key = os.getenv("OPENAI_API_KEY")

elif args.model in ["gpt2_large", "api_bloom"]:
    HF_TOKEN = os.getenv("HF_TOKEN")

elif args.model in ["bloom_560m"]:
    from transformers import BloomForCausalLM
    from transformers import BloomTokenizerFast

    hf_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m", cache_dir=".cache/huggingface/")
    hf_model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m", cache_dir=".cache/huggingface/")

elif args.model in ["bloom"]:
    from transformers import BloomForCausalLM
    from transformers import BloomTokenizerFast

    hf_tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom", cache_dir=".cache/huggingface/")
    hf_model = BloomForCausalLM.from_pretrained("bigscience/bloom", cache_dir=".cache/huggingface/")


def plt_2_rgb(env):
    """Grab the current matplotlib canvas of the env window as an (H, W, 3) uint8 array."""
    # data = np.frombuffer(env.window.fig.canvas.tostring_rgb(), dtype=np.uint8)
    # data = data.reshape(env.window.fig.canvas.get_width_height()[::-1] + (3,))
    width, height = env.window.fig.get_size_inches() * env.window.fig.get_dpi()
    # np.fromstring is deprecated (removed in NumPy 2.0) -> use np.frombuffer
    data = np.frombuffer(env.window.fig.canvas.tostring_rgb(), dtype='uint8').reshape(int(height), int(width), 3)
    return data


def generate(text_input, model):
    """Query `model` with `text_input` and return the generated action text (lowercased)."""
    # return "(a) move forward"
    if model == "dummy":
        print("dummy action forward")
        return "move forward"

    elif model == "random":
        print("random agent")
        return random.choice([
            "move forward",
            "turn left",
            "turn right",
            "toggle",
        ])

    elif model in ["gpt-3.5-turbo-0301"]:
        # retry forever on API errors, pausing between attempts
        while True:
            try:
                c = openai.ChatCompletion.create(
                    model=model,
                    messages=[
                        # {"role": "system", "content": ""},
                        # {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
                        # {"role": "user", "content": "Continue the following text in the most logical way.\n"+text_input}
                        {"role": "user", "content": text_input}
                    ],
                    max_tokens=3,
                    n=1,
                    temperature=0,
                    request_timeout=30,
                )
                break
            except Exception as e:
                print(e)
                print("Pausing")
                time.sleep(10)
                continue
        print("generation: ", c['choices'][0]['message']['content'])
        return c['choices'][0]['message']['content']

    elif model in ["text-davinci-003", "text-ada-001"]:
        while True:
            try:
                response = openai.Completion.create(
                    model=model,
                    prompt=text_input,
                    # temperature=0.7,
                    temperature=0.0,
                    max_tokens=3,
                    top_p=1,
                    frequency_penalty=0,
                    presence_penalty=0,
                    timeout=30
                )
                break
            except Exception as e:
                print(e)
                print("Pausing")
                time.sleep(10)
                continue

        choices = response["choices"]
        assert len(choices) == 1
        return choices[0]["text"].strip().lower()  # remove newline from the end

    elif model in ["gpt2_large", "api_bloom"]:
        # HF_TOKEN = os.getenv("HF_TOKEN")
        if model == "gpt2_large":
            API_URL = "https://api-inference.huggingface.co/models/gpt2-large"
        elif model == "api_bloom":
            API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"
        else:
            raise ValueError(f"Undefined model {model}.")

        headers = {"Authorization": f"Bearer {HF_TOKEN}"}

        def query(text_prompt, n_tokens=3):
            """Request `n_tokens` greedily, one API call per token, and return only the new text."""
            prompt_so_far = text_prompt  # renamed from `input` to avoid shadowing the builtin
            # make n_tokens request and append the output each time - one request generates one token
            for _ in range(n_tokens):
                # prepare request
                payload = {
                    "inputs": prompt_so_far,
                    "parameters": {
                        "do_sample": False,
                        'temperature': 0,
                        'wait_for_model': True,
                        # "max_length": 500,  # for gpt2
                        # "max_new_tokens": 250  # fot gpt2-xl
                    },
                }
                data = json.dumps(payload)

                # request
                response = requests.request("POST", API_URL, headers=headers, data=data)
                response_json = json.loads(response.content.decode("utf-8"))

                if type(response_json) is list and len(response_json) == 1:
                    # generated_text contains the input + the response
                    response_full_text = response_json[0]['generated_text']
                    # we use this as the next input
                    prompt_so_far = response_full_text
                else:
                    print("Invalid request to huggingface api")
                    from IPython import embed; embed()

            # remove the prompt from the beginning
            assert response_full_text.startswith(text_prompt)
            response_text = response_full_text[len(text_prompt):]
            return response_text

        response = query(text_input).strip().lower()
        return response

    elif model in ["bloom_560m"]:
        # uses hf_tokenizer / hf_model loaded at module level during "prepare models"
        inputs = hf_tokenizer(text_input, return_tensors="pt")
        # generate 3 more tokens than the prompt length
        result_length = inputs['input_ids'].shape[-1] + 3
        full_output = hf_tokenizer.decode(hf_model.generate(inputs["input_ids"], max_length=result_length)[0])
        # full output contains the prompt; strip it to keep only the continuation
        assert full_output.startswith(text_input)
        response = full_output[len(text_input):]
        response = response.strip().lower()
        return response

    else:
        raise ValueError("Unknown model.")


def get_parsed_action(text_action):
    """Map free-form generated text to one of the five canonical action strings.

    Unrecognized text falls back to "no_op" with a warning.
    """
    if "move forward" in text_action:
        return "move forward"
    elif "turn left" in text_action:
        return "turn left"
    elif "turn right" in text_action:
        return "turn right"
    elif "toggle" in text_action:
        return "toggle"
    elif "no_op" in text_action:
        return "no_op"
    else:
        warnings.warn(f"Undefined action {text_action}")
        return "no_op"


def step(text_action):
    """Parse `text_action`, execute it in the global env, and return (obs, reward, done, info)."""
    text_action = get_parsed_action(text_action)

    # nan entries are the unused (speech) components of the multi-discrete action
    if "move forward" == text_action:
        action = [int(env.actions.forward), np.nan, np.nan]
    elif "turn left" == text_action:
        action = [int(env.actions.left), np.nan, np.nan]
    elif "turn right" == text_action:
        action = [int(env.actions.right), np.nan, np.nan]
    elif "toggle" == text_action:
        action = [int(env.actions.toggle), np.nan, np.nan]
    elif "no_op" == text_action:
        action = [np.nan, np.nan, np.nan]

    obs, reward, done, info = env.step(action)
    return obs, reward, done, info


def reset(env):
    """Reset the env and return (obs, reward, done, info) from an immediate no_op step."""
    env.reset()
    # a dirty trick just to get obs and info
    return step("no_op")


def generate_text_obs(obs, info):
    """Build the textual observation: env descriptions plus any utterance history."""
    llm_prompt = "Obs : "
    llm_prompt += "".join(info["descriptions"])
    if obs["utterance_history"] != "Conversation: \n":
        utt_hist = obs['utterance_history']
        utt_hist = utt_hist.replace("Conversation: \n", "")
        llm_prompt += utt_hist
    return llm_prompt


def action_query():
    """Return the suffix that asks the model for its next action."""
    # llm_prompt = ""
    # llm_prompt += "Your possible actions are:\n"
    # llm_prompt += "(a) move forward\n"
    # llm_prompt += "(b) turn left\n"
    # llm_prompt += "(c) turn right\n"
    # llm_prompt += "(d) toggle\n"
    # llm_prompt += "(e) no_op\n"
    # llm_prompt += "Your next action is: ("
    llm_prompt = "Act :"
    return llm_prompt


# load context examples
with open(in_context_examples_path, "r") as f:
    in_context_examples = f.read()

with open(prompt_log_filename, "a+") as f:
    f.write(datetime_string)
with open(ep_h_log_filename, "a+") as f:
    f.write(datetime_string)

feed_episode_history = args.feed_full_ep

# token-count estimates used only for the price check below
# asoc
in_context_n_tokens = 800
ep_obs_len = 50 * 3

# color
in_context_n_tokens = 1434
# ep_obs_len = 70

# feed only current obs
if feed_episode_history:
    ep_obs_len = 50
else:
    # last_n = 1
    # last_n = 2
    last_n = 3
    ep_obs_len = 50 * last_n

_, price = estimate_price(
    num_of_episodes=args.episodes,
    in_context_len=in_context_n_tokens,
    ep_obs_len=ep_obs_len,
    n_steps=args.max_steps,
    model=args.model,
    feed_episode_history=feed_episode_history,
)
if not args.skip_check:
    input(f"You will spend: {price} dollars. (in context: {in_context_n_tokens} obs: {ep_obs_len}), ok?")

# prepare frames list to save to gif
frames = []

assert args.max_steps <= 20

success_rates = []
# episodes start
for episode in range(args.episodes):
    print("Episode:", episode)
    new_episode_text = "New episode.\n"
    episode_history_text = new_episode_text

    success = False
    episode_seed = args.seed + episode
    env = make_env(args.env_name, episode_seed, env_args)

    with open(prompt_log_filename, "a+") as f:
        f.write("\n\n")

    observations = []
    actions = []
    for i in range(int(args.max_steps)):
        if i == 0:
            # first step: only reset, no model query
            obs, reward, done, info = reset(env)
            action_text = ""
        else:
            with open(prompt_log_filename, "a+") as f:
                f.write("\nnew prompt: -----------------------------------\n")
                f.write(llm_prompt)

            text_action = generate(llm_prompt, args.model)
            obs, reward, done, info = step(text_action)

            action_text = f"Act : {get_parsed_action(text_action)}\n"
            actions.append(action_text)
            print(action_text)

        text_obs = generate_text_obs(obs, info)
        observations.append(text_obs)
        print(prompt_preprocessor(text_obs))

        # feed the full episode history
        episode_history_text += prompt_preprocessor(action_text + text_obs)  # append to history of this episode

        if feed_episode_history:
            # feed full episode history
            llm_prompt = in_context_examples + episode_history_text + action_query()
        else:
            # feed only a window of the last `last_n` (obs, action) pairs
            n = min(last_n, len(observations))
            obs_window = observations[-n:]
            act_window = (actions + [action_query()])[-n:]
            episode_text = "".join([o + a for o, a in zip(obs_window, act_window)])

            llm_prompt = in_context_examples + new_episode_text + episode_text

        llm_prompt = prompt_preprocessor(llm_prompt)

        # save the image
        env.render(mode="human")
        rgb_img = plt_2_rgb(env)
        frames.append(rgb_img)

        if env.current_env.box.blocked and not env.current_env.box.is_open:
            # target box is blocked -> apple can't be obtained
            # break to save compute
            break

        if done:
            # quadruple last frame to pause between episodes
            for pause_i in range(3):
                same_img = np.copy(rgb_img)
                # toggle a pixel between frames to avoid cropping when going from gif to mp4
                same_img[0, 0, 2] = 0 if (pause_i % 2) == 0 else 255
                frames.append(same_img)

            if reward > 0:
                print("Success!")
                episode_history_text += "Success!\n"
                success = True
            else:
                episode_history_text += "Failure!\n"

            with open(ep_h_log_filename, "a+") as f:
                f.write("\nnew prompt: -----------------------------------\n")
                f.write(episode_history_text)
            break
    else:
        # loop ran out of steps without done/blocked: still log the episode history
        with open(ep_h_log_filename, "a+") as f:
            f.write("\nnew prompt: -----------------------------------\n")
            f.write(episode_history_text)

    print(f"{'Success' if success else 'Failure'}")
    success_rates.append(success)

mean_success_rate = np.mean(success_rates)
print("Success rate:", mean_success_rate)

print(f"Saving gif to {gif_savename}.")
mimsave(gif_savename, frames, duration=args.pause)
print("Done.")

log_data_dict = vars(args)
log_data_dict["success_rates"] = success_rates
log_data_dict["mean_success_rate"] = mean_success_rate

print("Evaluation log: ", evaluation_log_filename)
with open(evaluation_log_filename, "w") as f:
    f.write(json.dumps(log_data_dict))