SocialAISchool / scripts /create_LLM_examples.py
grg's picture
Cleaned old git history
be5548b
raw
history blame
6.82 kB
#!/usr/bin/env python3
import argparse
from gym_minigrid.window import Window
from utils import *
import gym
import pickle
from datetime import datetime
# All episodes seen in this session: a list of episodes, each a list of
# per-step dicts (see reset() and step() for the keys).
episodes = []
# Recording flag kept in a one-element list so that functions can flip it
# in place (record[0] = ...) without a `global` declaration.
record = [False]
def update_caption_with_recording_indicator():
    """Prefix the window caption with the current recording state.

    Reads the module-level ``record`` flag and the caption currently held by
    the module-level ``window``, and writes the caption back with a
    "Recording ON" / "Recording OFF" header above it.
    """
    separator = "\n------------------\n\n"
    caption = window.caption.get_text()

    # If the caption was not replaced since the previous call (for example
    # when no env.render() happened in between), it still starts with the
    # header written last time. Strip it so headers do not pile up.
    for state in ("ON", "OFF"):
        stale_header = f"Recording {state}{separator}"
        if caption.startswith(stale_header):
            caption = caption[len(stale_header):]
            break

    current_state = "ON" if record[0] else "OFF"
    window.set_caption(f"Recording {current_state}{separator}{caption}")
def redraw(img):
    """Display a frame in the window along with the recording indicator.

    Args:
        img: Frame shown as-is when ``args.agent_view`` is set. Otherwise it
            is ignored and a fresh ``env.render`` frame is shown instead.
    """
    if args.agent_view:
        frame = img
    else:
        frame = env.render(
            'rgb_array',
            tile_size=args.tile_size,
            mask_unobserved=args.mask_unobserved,
        )

    # Add the recording state to the caption before showing the frame.
    update_caption_with_recording_indicator()
    window.show_img(frame)
def start_recording():
    """Switch recording on and flag the latest logged step as recorded."""
    record[0] = True
    print("Recording started")

    # The most recent step of the current episode becomes the first step
    # that passes the "record" filter applied when episodes are saved.
    latest_step = episodes[-1][-1]
    latest_step["record"] = True
def reset():
    """Open a new episode, reset the environment and log the initial state.

    Recording is switched off, the first observation is drawn, and a step
    entry with no action/reward/done is appended to the new episode.
    """
    episodes.append([])

    obs, info = env.reset_with_info()
    record[0] = False
    redraw(obs)

    # Initial entry of the episode: nothing has been acted on yet.
    initial_step = dict(
        action=None,
        info=info,
        obs=obs,
        reward=None,
        done=None,
        record=record[0],
    )
    episodes[-1].append(initial_step)
def _save_recorded_episodes():
    """Pickle the recorded part of every episode to ``output_dir/episodes.pkl``.

    Steps whose "record" flag is false are dropped, and episodes left empty
    by that filter are skipped. The first kept step of each episode is
    written with ``action``, ``reward`` and ``done`` set to ``None`` so that
    it looks like the state right after a reset.
    """
    episodes_to_save = []
    for episode in episodes:
        recorded = [s for s in episode if s["record"]]
        if not recorded:
            continue
        # Patch a copy of the first step: the dicts are shared with the live
        # `episodes` history, which must keep its real action/reward/done.
        first_step = {**recorded[0], "action": None, "reward": None, "done": None}
        episodes_to_save.append([first_step, *recorded[1:]])

    dump_pickle = Path(output_dir) / "episodes.pkl"
    print(f"Saving {len(episodes_to_save)} episodes ({[len(e) for e in episodes_to_save]}) to : {dump_pickle}")
    with open(dump_pickle, 'wb') as f:
        pickle.dump(episodes_to_save, f)


def step(action):
    """Apply one action to the environment, log the transition and save.

    Args:
        action: Either a numpy array, which is passed to ``env.step``
            unchanged, or a scalar primitive action (anything ``int()``
            accepts), which is expanded to ``[int(action), nan, nan]``.

    Side effects:
        Appends a step dict to the current episode, redraws the window,
        resets the environment when the episode is done, and rewrites the
        pickle file of recorded episodes.
    """
    if not isinstance(action, np.ndarray):
        # Scalar primitive action: pad the remaining components with NaN.
        action = [int(action), np.nan, np.nan]
    obs, reward, done, info = env.step(action)

    episodes[-1].append(
        {
            "action": action,
            "info": info,
            "obs": obs,
            "reward": reward,
            "done": done,
            "record": record[0],
        }
    )
    redraw(obs)

    if done:
        print('done!')
        print('Reward=%.2f' % (reward))
        # Reset and add the initial state of the next episode to `episodes`.
        reset()
    else:
        print('\nStep=%s' % (env.step_count))

    # Save after every step so that closing the window never loses data.
    _save_recorded_episodes()
def key_handler(event):
    """Translate a key press into a recorder, window or environment action.

    Args:
        event: Key event; only its ``key`` attribute is read.

    Unmapped keys are printed and otherwise ignored.
    """
    key = event.key
    print('pressed', key)

    # Keys that control the tool itself rather than the agent.
    if key == 'r':
        start_recording()
        return
    if key == 'escape':
        window.close()
        return
    if key == 's':
        reset()
        return

    # Key -> attribute name on env.actions (looked up only when pressed).
    primitive_actions = {
        'left': 'left',
        'right': 'right',
        'up': 'forward',
        't': 'speak',
        ' ': 'toggle',  # Spacebar
        '9': 'pickup',
        '0': 'drop',
        'enter': 'done',
    }
    # Key -> second and third components of an array action whose first
    # component is NaN.
    # NOTE(review): presumably these index the agent's utterance; confirm
    # against the environment's action space.
    array_actions = {
        '1': (0, 0),
        '2': (0, 1),
        '3': (1, 0),
        '4': (1, 1),
        '5': (2, 2),
        '6': (1, 2),
        '7': (2, 1),
        '8': (1, 3),
        'p': (3, 3),
    }

    if key in ('tab', 'shift'):
        # All-NaN action array.
        step(np.array([np.nan, np.nan, np.nan]))
    elif key in primitive_actions:
        step(getattr(env.actions, primitive_actions[key]))
    elif key in array_actions:
        second, third = array_actions[key]
        step(np.array([np.nan, second, third]))
if __name__ == "__main__":
    # Script entry point: parse the CLI, build the environment and output
    # directory, then hand control to the window's blocking event loop.
    # `args`, `env`, `output_dir` and `window` are deliberately module-level
    # because the functions above read them as globals.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--env",
        help="gym environment to load",
        # default="SocialAI-AsocialBoxInformationSeekingParamEnv-v1",
        # default="SocialAI-ColorBoxesLLMCSParamEnv-v1",
        default="SocialAI-ColorLLMCSParamEnv-v1",
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="random seed to generate the environment with",
        default=-1
    )
    parser.add_argument(
        "--tile_size",
        type=int,
        help="size at which to render tiles",
        default=32
    )
    parser.add_argument(
        '--agent_view',
        default=False,
        help="draw the agent sees (partially observable view)",
        action='store_true'
    )
    parser.add_argument(
        '--mask-unobserved',
        default=False,
        help="mask cells that are not observed by the agent",
        action='store_true'
    )
    parser.add_argument(
        '--save-dir',
        default="./llm_data/in_context_examples/",
        help="directory where to save episodes",
    )
    parser.add_argument(
        '--load',
        default=None,
        help="Load in context examples to append to",
    )
    parser.add_argument(
        '--name',
        default="in_context",
        help="additional name tag for the episodes",
    )
    parser.add_argument(
        '--draw-tree',
        action="store_true",
        help="Draw the sampling treee",
    )
    # Put all env related arguments after --env-args,
    # e.g. --env-args nb_foo 1 is_bar True
    parser.add_argument("--env-args", nargs='*', default=None)

    args = parser.parse_args()

    env = gym.make(args.env, **env_args_str_to_dict(args.env_args))

    # Every run writes into its own timestamped sub-directory of --save-dir.
    timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    output_dir = Path(args.save_dir) / f"{args.name}_{args.env}_{timestamp}"
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.load:
        # NOTE: unpickling can execute arbitrary code; only load files that
        # were produced by this script or come from a trusted source.
        with open(args.load, 'rb') as f:
            episodes = pickle.load(f)

    if args.draw_tree:
        # Draw the tree into the run's output directory. The file name must
        # not start with "/": pathlib drops everything left of an absolute
        # segment, which would send the file to the filesystem root.
        env.parameter_tree.draw_tree(
            filename=output_dir / f"{args.env}_raw_tree",
            ignore_labels=["Num_of_colors"],
        )

    # A negative seed (the default) leaves the environment unseeded.
    if args.seed >= 0:
        env.seed(args.seed)

    window = Window('gym_minigrid - ' + args.env, figsize=(6, 4))
    window.reg_key_handler(key_handler)
    env.window = window

    reset()

    # # a trick to make the first image appear right away
    # # this action is not saved
    # obs, _, _, _ = env.step(np.array([np.nan, np.nan, np.nan]))
    # redraw(obs)

    # Blocking event loop
    window.show(block=True)