Spaces:
Running
Running
#!/usr/bin/env python3 | |
import argparse | |
from gym_minigrid.window import Window | |
from utils import * | |
import gym | |
import pickle | |
from datetime import datetime | |
# Every episode of this session, oldest first. Each episode is a list of
# per-step dicts appended by reset() and step().
episodes: list = []
# Recording flag, held in a one-element list so the functions below can flip
# it in place (record[0] = ...) without a `global` statement.
record: list = [False]
def update_caption_with_recording_indicator():
    """Prefix the window caption with a "Recording ON/OFF" banner.

    Reads the module-level ``record`` flag and the current text of
    ``window.caption``, then writes the banner followed by that text back
    through ``window.set_caption``.

    A banner already sitting at the start of the caption (left by an earlier
    call) is removed first, so repeated calls never stack banners when the
    caption has not been rewritten in between.
    """
    divider = "\n------------------\n\n"
    known_banners = tuple(
        f"Recording {state}{divider}" for state in ("ON", "OFF")
    )

    caption_text = window.caption.get_text()
    # str.startswith accepts a tuple, so both banner variants are checked at once.
    while caption_text.startswith(known_banners):
        caption_text = caption_text.split(divider, 1)[1]

    banner = f"Recording {'ON' if record[0] else 'OFF'}{divider}"
    window.set_caption(banner + caption_text)
def redraw(img):
    """Show the current frame in the window, with the recording banner.

    Args:
        img: Image displayed as-is when ``--agent_view`` is set (callers in
            this file pass the latest observation). Otherwise it is ignored
            and a fresh frame from ``env.render`` is displayed instead.
    """
    frame = img if args.agent_view else env.render(
        'rgb_array',
        tile_size=args.tile_size,
        mask_unobserved=args.mask_unobserved,
    )

    # Add the recording indicator to the caption before showing the frame.
    update_caption_with_recording_indicator()

    window.show_img(frame)
def start_recording():
    """Turn recording on and flag the most recent step as recorded.

    Besides setting the ``record`` flag for future steps, the last entry of
    the current episode gets ``"record": True`` so that it is kept when
    episodes are later filtered on that key.
    """
    record[0] = True
    print("Recording started")

    latest_step = episodes[-1][-1]
    latest_step["record"] = True
def reset():
    """Begin a new episode and log its initial state.

    Appends an empty episode to ``episodes``, resets the environment via
    ``env.reset_with_info()``, switches recording off, redraws the window and
    stores the initial observation as the episode's first entry (with no
    action, reward or done flag).
    """
    current_episode = []
    episodes.append(current_episode)

    obs, info = env.reset_with_info()
    record[0] = False
    redraw(obs)

    initial_step = dict(
        action=None,
        info=info,
        obs=obs,
        reward=None,
        done=None,
        record=record[0],
    )
    current_episode.append(initial_step)
def _save_recorded_episodes():
    """Pickle the recorded part of every episode to ``<output_dir>/episodes.pkl``.

    Only steps whose ``"record"`` flag is set are kept, and episodes with no
    recorded step are dropped. The first kept step of each episode is written
    with ``action``, ``reward`` and ``done`` set to ``None`` so that it looks
    like the state right after a reset.

    That adjustment is made on a copy of the step, so the live history in
    ``episodes`` is left untouched.
    """
    episodes_to_save = []
    for episode in episodes:
        recorded_steps = [s for s in episode if s["record"]]
        if not recorded_steps:
            continue
        # {**d, ...} keeps the original key order while overriding the values.
        first_step = {
            **recorded_steps[0],
            "action": None,
            "reward": None,
            "done": None,
        }
        episodes_to_save.append([first_step, *recorded_steps[1:]])

    dump_pickle = Path(output_dir) / "episodes.pkl"
    print(f"Saving {len(episodes_to_save)} episodes ({[len(e) for e in episodes_to_save]}) to : {dump_pickle}")
    with open(dump_pickle, 'wb') as f:
        pickle.dump(episodes_to_save, f)


def step(action):
    """Apply one action, log the transition, redraw, and save the recording.

    Args:
        action: Either a numpy array, which is passed to ``env.step``
            unchanged, or a single primitive action, which is expanded to
            ``[int(action), nan, nan]`` first (presumably the same three-slot
            format as the array actions; confirm against the environment).

    The transition is appended to the current episode together with the
    current ``record`` flag. When the environment reports ``done``, a new
    episode is started through ``reset()``. The recorded episodes are written
    to disk after every step.
    """
    if not isinstance(action, np.ndarray):
        action = [int(action), np.nan, np.nan]
    obs, reward, done, info = env.step(action)

    episodes[-1].append(
        {
            "action": action,
            "info": info,
            "obs": obs,
            "reward": reward,
            "done": done,
            "record": record[0],
        }
    )

    redraw(obs)

    if done:
        print('done!')
        print('Reward=%.2f' % (reward))
        # Reset and add the initial state of the next episode to `episodes`.
        reset()
    else:
        print('\nStep=%s' % (env.step_count))

    _save_recorded_episodes()
# Keys that step with an all-NaN action array.
_NAN_ACTION_KEYS = frozenset({'tab', 'shift'})

# Keys that step with the array [nan, a, b]; the value is the (a, b) pair.
# The meaning of the two integers is defined by the environment.
_ARRAY_ACTION_KEYS = {
    '1': (0, 0),
    '2': (0, 1),
    '3': (1, 0),
    '4': (1, 1),
    '5': (2, 2),
    '6': (1, 2),
    '7': (2, 1),
    '8': (1, 3),
    'p': (3, 3),
}

# Keys that step with a member of `env.actions`; the value is the member name.
# Names are resolved only when the key is pressed.
_PRIMITIVE_ACTION_KEYS = {
    'left': 'left',
    'right': 'right',
    'up': 'forward',
    't': 'speak',
    ' ': 'toggle',  # spacebar
    '9': 'pickup',
    '0': 'drop',
    'enter': 'done',
}


def key_handler(event):
    """Dispatch a key press from the window to the matching command.

    ``r`` starts recording, ``escape`` closes the window and ``s`` resets the
    episode. Every other bound key performs one environment step. Keys with no
    binding are ignored.

    Args:
        event: Key event supplied by the window; only ``event.key`` is read.
    """
    key = event.key
    print('pressed', key)

    if key == 'r':
        start_recording()
    elif key == 'escape':
        window.close()
    elif key == 's':
        reset()
    elif key in _NAN_ACTION_KEYS:
        step(np.array([np.nan, np.nan, np.nan]))
    elif key in _ARRAY_ACTION_KEYS:
        second, third = _ARRAY_ACTION_KEYS[key]
        step(np.array([np.nan, second, third]))
    elif key in _PRIMITIVE_ACTION_KEYS:
        step(getattr(env.actions, _PRIMITIVE_ACTION_KEYS[key]))
if __name__ == "__main__":
    # This is deliberately a flat script body rather than a main() function:
    # `args`, `env`, `window` and `output_dir` (and `episodes` when --load is
    # given) must be module-level names, because the functions above read
    # them as globals.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--env",
        help="gym environment to load",
        # default="SocialAI-AsocialBoxInformationSeekingParamEnv-v1",
        # default="SocialAI-ColorBoxesLLMCSParamEnv-v1",
        default="SocialAI-ColorLLMCSParamEnv-v1",
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="random seed to generate the environment with",
        default=-1
    )
    parser.add_argument(
        "--tile_size",
        type=int,
        help="size at which to render tiles",
        default=32
    )
    parser.add_argument(
        '--agent_view',
        default=False,
        help="draw the agent sees (partially observable view)",
        action='store_true'
    )
    parser.add_argument(
        '--mask-unobserved',
        default=False,
        help="mask cells that are not observed by the agent",
        action='store_true'
    )
    parser.add_argument(
        '--save-dir',
        default="./llm_data/in_context_examples/",
        help="file where to save episodes",
    )
    parser.add_argument(
        '--load',
        default=None,
        help="Load in context examples to append to",
    )
    parser.add_argument(
        '--name',
        default="in_context",
        help="additional name tag for the episodes",
    )
    parser.add_argument(
        '--draw-tree',
        action="store_true",
        help="Draw the sampling treee",
    )
    # Put all env related arguments after --env-args,
    # e.g. --env-args nb_foo 1 is_bar True
    parser.add_argument("--env-args", nargs='*', default=None)

    args = parser.parse_args()

    env = gym.make(args.env, **env_args_str_to_dict(args.env_args))

    # Each run writes into its own timestamped directory under --save-dir.
    timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    output_dir = Path(args.save_dir) / f"{args.name}_{args.env}_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)

    if args.load:
        # NOTE(security): unpickling can execute arbitrary code; only pass
        # --load a file that this tool (or another trusted source) produced.
        with open(args.load, 'rb') as f:
            episodes = pickle.load(f)

    if args.draw_tree:
        # No leading "/" in the file name: joining a Path with an absolute
        # component discards everything before it, which would put the file
        # at the filesystem root instead of inside output_dir.
        env.parameter_tree.draw_tree(
            filename=output_dir / f"{args.env}_raw_tree",
            ignore_labels=["Num_of_colors"],
        )

    # A negative seed (the default) leaves the environment unseeded.
    if args.seed >= 0:
        env.seed(args.seed)

    window = Window('gym_minigrid - ' + args.env, figsize=(6, 4))
    window.reg_key_handler(key_handler)
    env.window = window

    reset()

    # # a trick to make the first image appear right away
    # # this action is not saved
    # obs, _, _, _ = env.step(np.array([np.nan, np.nan, np.nan]))
    # redraw(obs)

    # Blocking event loop
    window.show(block=True)