# SocialAISchool/scripts/visualize.py
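# Visualize a trained agent in a SocialAI environment and optionally save the rollout as a gif.
# Example invocation (module path and environment id below are illustrative, not verified):
#   python -m scripts.visualize --model <model_name> --env-name <env_id> --gif model_id --episodes 5
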
import argparse
import json
import subprocess
import time
from pathlib import Path

import numpy as np
import torch

from utils.babyai_utils.baby_agent import load_agent
from utils.env import make_env
from utils.other import seed
from utils.storage import get_model_dir, get_status
from models import *

# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True,
                    help="name of the trained model (REQUIRED)")
parser.add_argument("--seed", type=int, default=0,
                    help="random seed (default: 0)")
parser.add_argument("--max-steps", type=int, default=None,
                    help="maximum number of steps per episode (default: no limit)")
parser.add_argument("--shift", type=int, default=0,
                    help="number of times the environment is reset at the beginning (default: 0)")
parser.add_argument("--argmax", action="store_true", default=False,
                    help="select the action with the highest probability (default: False)")
parser.add_argument("--pause", type=float, default=0.5,
                    help="pause duration between two consecutive actions of the agent (default: 0.5)")
parser.add_argument("--env-name", type=str, default=None, required=True,
                    help="name of the environment (REQUIRED)")
parser.add_argument("--gif", type=str, default=None, required=True,
                    help="store the output as a gif with the given filename; "
                         "pass 'model_id' to derive the filename from the model name and seed (REQUIRED)")
parser.add_argument("--episodes", type=int, default=10,
                    help="number of episodes to visualize (default: 10)")
args = parser.parse_args()

# Set seed for all randomness sources
seed(args.seed)

save = args.gif
if save:
    savename = args.gif
    if savename == "model_id":
        # Derive the gif filename from the model name and the seed
        savename = args.model.replace('storage/', '')
        savename = savename.replace('/', '_')
        savename += '_{}'.format(args.seed)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}\n")
# Load environment
if str(args.model).startswith("./storage/"):
    args.model = args.model.replace("./storage/", "")
if str(args.model).startswith("storage/"):
    args.model = args.model.replace("storage/", "")

with open(Path("./storage") / args.model / "config.json") as f:
    conf = json.load(f)

if args.env_name is None:
    # Load env_args from the saved config, falling back to the training status.
    # (--env-name is required above, so this branch is only reached if that constraint is relaxed.)
    env_args = {}
    if "env_args" not in conf:
        env_args = get_status(get_model_dir(args.model), None)['env_args']
    else:
        env_args = conf["env_args"]
    env = make_env(args.env_name, args.seed, env_args=env_args)
else:
    env_name = args.env_name
    env = make_env(args.env_name, args.seed)

for _ in range(args.shift):
    env.reset()

print("Environment loaded\n")
# Define agent
model_dir = get_model_dir(args.model)
num_frames = None
agent = load_agent(env, model_dir, args.argmax, num_frames)
print("Agent loaded\n")
# Run the agent
if save:
    from imageio import mimsave
    old_frames = []
    frames = []

# Create a window to view the environment
env.render(mode='human')


def plt_2_rgb(env):
    # Grab the current frame of the environment's matplotlib window as an RGB array
    data = np.frombuffer(env.window.fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(env.window.fig.canvas.get_width_height()[::-1] + (3,))
    return data

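# Note: plt_2_rgb assumes the human-mode renderer exposes a matplotlib figure via
# env.window.fig (as in MiniGrid-style viewer windows). If that assumption does not
# hold, env.render(mode='rgb_array') may be a usable alternative frame source.
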
for episode in range(args.episodes):
    print("episode:", episode)
    obs = env.reset()
    env.render(mode='human')
    if save:
        frames.append(plt_2_rgb(env))

    i = 0
    while True:
        i += 1
        action = agent.get_action(obs)
        obs, reward, done, _ = env.step(action)
        agent.analyze_feedback(reward, done)
        env.render(mode='human')

        if save:
            img = plt_2_rgb(env)
            frames.append(img)
            if done:
                # quadruple the last frame to pause between episodes
                for j in range(3):
                    same_img = np.copy(img)
                    # toggle a pixel between frames to avoid cropping when converting the gif to mp4
                    same_img[0, 0, 2] = 0 if (j % 2) == 0 else 255
                    frames.append(same_img)

        if done or env.window.closed:
            break

        if args.max_steps is not None and i > args.max_steps:
            break

    if env.window.closed:
        break

if save:
    print(f"Saving to {savename}.gif ", end="")
    mimsave(savename + '.gif', frames, duration=args.pause)

    # Optionally reduce the gif size with gifsicle (left disabled):
    # bashCommand = "gifsicle -O3 --colors 32 -o {}.gif {}.gif".format(savename, savename)
    # process = subprocess.run(bashCommand.split(), stdout=subprocess.PIPE)

    print("Done.")