# SocialAISchool / scripts/visualize.py
# Last commit by grg: be5548b ("Cleaned old git history"), 4.63 kB.
import argparse
import json
import time
import numpy as np
import torch
from pathlib import Path
from utils.babyai_utils.baby_agent import load_agent
from utils.env import make_env
from utils.other import seed
from utils.storage import get_model_dir
from utils.storage import get_status
from models import *
import subprocess
# Parse arguments
# Command-line interface. NOTE(review): --env-name and --gif are declared
# required=True, so the later `args.env_name is None` fallback and the
# `if save:` guards never see their "unset" case — confirm whether either was
# meant to be optional.
parser = argparse.ArgumentParser()

# --- what to run -----------------------------------------------------------
parser.add_argument(
    "--model",
    required=True,
    help="name of the trained model (REQUIRED)",
)
parser.add_argument(
    "--seed",
    default=0,
    type=int,
    help="random seed (default: 0)",
)
parser.add_argument(
    "--max-steps",
    default=None,
    type=int,
    help="max num of steps",
)
parser.add_argument(
    "--shift",
    default=0,
    type=int,
    help="number of times the environment is reset at the beginning (default: 0)",
)
parser.add_argument(
    "--argmax",
    default=False,
    action="store_true",
    help="select the action with highest probability (default: False)",
)
parser.add_argument(
    "--pause",
    default=0.5,
    type=float,
    help="pause duration between two consequent actions of the agent (default: 0.5)",
)

# --- environment and output ------------------------------------------------
parser.add_argument(
    "--env-name",
    required=True,
    default=None,
    type=str,
    help="env name",
)
# The special value "model_id" makes the script derive the gif name from
# --model and --seed.
parser.add_argument(
    "--gif",
    required=True,
    default=None,
    type=str,
    help="store output as gif with the given filename",
)
parser.add_argument(
    "--episodes",
    default=10,
    type=int,
    help="number of episodes to visualize",
)

args = parser.parse_args()
# Set seed for all randomness sources
seed(args.seed)

# Output gif name (without extension). The special value "model_id" derives the
# name from the model path and the seed, e.g. "storage/a/b" with seed 0 -> "a_b_0".
save = args.gif
if save:
    savename = args.gif
    if savename == "model_id":
        model_id = args.model
        # Strip only a *leading* "./" and "storage/". str.replace would also drop
        # inner occurrences and turned "./storage/..." into a hidden "._..." file.
        for prefix in ("./", "storage/"):
            if model_id.startswith(prefix):
                model_id = model_id[len(prefix):]
        savename = model_id.replace('/', '_') + '_{}'.format(args.seed)
# Set device
# Use CUDA when torch can see a GPU, otherwise fall back to the CPU.
# NOTE(review): within this script `device` is only printed; it is not passed
# to load_agent — confirm whether anything else relies on it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}\n")
# Load environment
# Make args.model relative to ./storage. Only a leading prefix is removed;
# str.replace would also have deleted "storage/" occurring deeper in the path.
for prefix in ("./storage/", "storage/"):
    if args.model.startswith(prefix):
        args.model = args.model[len(prefix):]

if args.env_name is None:
    # Recover env_args from the model's config.json, falling back to the saved
    # training status. The config is only read here because only this branch uses it.
    with open(Path("./storage") / args.model / "config.json", encoding="utf-8") as f:
        conf = json.load(f)
    if "env_args" in conf:
        env_args = conf["env_args"]
    else:
        env_args = get_status(get_model_dir(args.model), None)['env_args']
    # NOTE(review): args.env_name is None on this path, so make_env receives None
    # as the env name. The branch is currently unreachable (--env-name is
    # required=True); confirm how the env name should be recovered before relaxing that.
    env = make_env(args.env_name, args.seed, env_args=env_args)
else:
    env_name = args.env_name
    env = make_env(env_name, args.seed)

# Optionally skip ahead by resetting the environment --shift times.
for _ in range(args.shift):
    env.reset()
print("Environment loaded\n")
# Define agent
# Restore the trained agent for `env` from the model's directory; args.argmax
# is forwarded as-is (per its --argmax help: pick the most probable action).
# NOTE(review): num_frames is passed to load_agent as None; its meaning is
# defined inside load_agent — confirm.
num_frames = None
model_dir = get_model_dir(args.model)
agent = load_agent(env, model_dir, args.argmax, num_frames)
print("Agent loaded\n")
# Run the agent
if save:
    # Imported only when saving, so imageio is not needed for on-screen viewing.
    from imageio import mimsave
    # RGB frames accumulated over all episodes and written as one gif at the end.
    frames = []

# Create a window to view the environment
env.render(mode='human')
def plt_2_rgb(env):
    """Return the current image of ``env.window``'s matplotlib figure as RGB.

    Reads the figure canvas's RGBA pixel buffer and drops the alpha channel.
    Assumes an Agg-based canvas (``buffer_rgba`` is an Agg API; the previous
    ``tostring_rgb`` had the same requirement).

    Args:
        env: environment whose ``window.fig.canvas`` has already been drawn
            (e.g. by ``env.render(mode='human')``).

    Returns:
        A new ``(height, width, 3)`` array that does not share memory with the canvas.
    """
    canvas = env.window.fig.canvas
    # tostring_rgb() is deprecated (Matplotlib 3.8) and removed (3.10). The shape of
    # buffer_rgba() comes from the pixel buffer itself, so it cannot disagree with
    # get_width_height() (which may report logical rather than device pixels).
    rgba = np.asarray(canvas.buffer_rgba())
    # Drop alpha and copy explicitly: the canvas reuses its buffer on the next draw,
    # and every collected frame must stay independent.
    return rgba[..., :3].copy()
# Play args.episodes episodes, rendering every step and (when saving) recording frames.
for episode in range(args.episodes):
    print("episode:", episode)
    obs = env.reset()
    env.render(mode='human')
    if save:
        frames.append(plt_2_rgb(env))

    # Run at most args.max_steps steps (the old `i > max_steps` check allowed
    # max_steps + 1); with no limit, run until the episode ends or the window closes.
    num_steps = 0
    while args.max_steps is None or num_steps < args.max_steps:
        action = agent.get_action(obs)
        obs, reward, done, _ = env.step(action)
        num_steps += 1
        agent.analyze_feedback(reward, done)
        env.render(mode='human')

        if save:
            img = plt_2_rgb(env)
            frames.append(img)
            if done:
                # Append the last frame 3 more times (4 in total) to pause between episodes.
                for repeat in range(3):
                    same_img = np.copy(img)
                    # Toggle one pixel so the copies differ; per the original note this
                    # avoids them being cropped when the gif is converted to mp4.
                    same_img[0, 0, 2] = 0 if (repeat % 2) == 0 else 255
                    frames.append(same_img)

        if done or env.window.closed:
            break

    # Stop visualizing entirely once the user closes the window.
    if env.window.closed:
        break
# Write all collected frames to <savename>.gif.
if save:
    gif_path = savename + '.gif'
    print(f"Saving to {gif_path} ", end="")
    # NOTE(review): the unit of imageio's GIF `duration` depends on the installed
    # imageio version/plugin (seconds in the legacy writer, milliseconds in newer
    # Pillow-based writers) — confirm that --pause is interpreted as intended.
    mimsave(gif_path, frames, duration=args.pause)
    # The gif can optionally be shrunk afterwards with gifsicle, e.g.
    # `gifsicle -O3 --colors 32 -o <gif> <gif>`.
    print("Done.")