Spaces:
Sleeping
Sleeping
import argparse | |
import json | |
import time | |
import numpy as np | |
import torch | |
from pathlib import Path | |
from utils.babyai_utils.baby_agent import load_agent | |
from utils.env import make_env | |
from utils.other import seed | |
from utils.storage import get_model_dir | |
from utils.storage import get_status | |
from models import * | |
import subprocess | |
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True,
                    help="name of the trained model (REQUIRED)")
parser.add_argument("--seed", type=int, default=0,
                    help="random seed (default: 0)")
parser.add_argument("--max-steps", type=int, default=None,
                    help="maximum number of agent steps per episode (default: no limit)")
parser.add_argument("--shift", type=int, default=0,
                    help="number of times the environment is reset at the beginning (default: 0)")
parser.add_argument("--argmax", action="store_true", default=False,
                    help="select the action with highest probability (default: False)")
# --pause is only consumed as the per-frame duration of the saved GIF.
parser.add_argument("--pause", type=float, default=0.5,
                    help="duration of each frame in the saved GIF (default: 0.5)")
parser.add_argument("--env-name", type=str, default=None, required=True,
                    help="env name")
parser.add_argument("--gif", type=str, default=None, required=True,
                    help="store output as gif with the given filename; "
                         "pass 'model_id' to derive the name from --model and --seed")
parser.add_argument("--episodes", type=int, default=10,
                    help="number of episodes to visualize (default: 10)")
args = parser.parse_args()
# Set seed for all randomness sources
seed(args.seed)

# GIF output name. The special value "model_id" builds it from the model path
# and the seed, e.g. "./storage/a/b" with seed 0 -> "a_b_0".
save = args.gif
if save:
    savename = args.gif
    if savename == "model_id":
        # Strip only *leading* "./" and "storage/" prefixes (str.replace would
        # also remove matches deeper in the path) plus any trailing "/", so
        # "./storage/a/b/" does not turn into "._a_b_".
        model_id = args.model.rstrip('/')
        for prefix in ('./', 'storage/'):
            if model_id.startswith(prefix):
                model_id = model_id[len(prefix):]
        savename = '{}_{}'.format(model_id.replace('/', '_'), args.seed)
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}\n")

# Load environment
# Make args.model relative to ./storage. Only a leading prefix is removed;
# str.replace would also strip matching text later in the path.
for prefix in ("./storage/", "storage/"):
    if str(args.model).startswith(prefix):
        args.model = str(args.model)[len(prefix):]
with open(Path("./storage") / args.model / "config.json", encoding="utf-8") as f:
    conf = json.load(f)
if args.env_name is None:
    # NOTE(review): unreachable while --env-name is declared required=True; if
    # the flag ever becomes optional, make_env would receive None as the name.
    # env_args come from config.json, falling back to the saved training status.
    if "env_args" in conf:
        env_args = conf["env_args"]
    else:
        env_args = get_status(get_model_dir(args.model), None)['env_args']
    env = make_env(args.env_name, args.seed, env_args=env_args)
else:
    env = make_env(args.env_name, args.seed)
# Advance the environment by --shift resets before visualizing.
for _ in range(args.shift):
    env.reset()
print("Environment loaded\n")
# Build the agent from the model directory resolved for args.model.
num_frames = None  # forwarded unchanged to load_agent
model_dir = get_model_dir(args.model)
agent = load_agent(env, model_dir, args.argmax, num_frames)
print("Agent loaded\n")

# Run the agent
if save:
    # imageio is imported only when a GIF is going to be written.
    from imageio import mimsave
    old_frames, frames = [], []

# Open the interactive view of the environment before the first episode.
env.render(mode='human')
def plt_2_rgb(env):
    """Return the current frame of ``env``'s matplotlib window as an RGB array.

    Reads the canvas at ``env.window.fig.canvas``, which must provide
    ``buffer_rgba()`` (Agg-based canvases do), and drops the alpha channel.

    Returns:
        A new ``uint8`` array of shape ``(height, width, 3)``. It does not
        alias the canvas buffer, so successive frames stay independent.
    """
    canvas = env.window.fig.canvas
    # buffer_rgba() replaces tostring_rgb(), deprecated in matplotlib 3.8 and
    # later removed. Its array already has the true pixel shape, whereas
    # get_width_height() reports logical size and mis-reshapes on HiDPI screens.
    rgba = np.asarray(canvas.buffer_rgba())
    # .copy() detaches the frame from the renderer's reused buffer.
    return rgba[..., :3].copy()
for episode in range(args.episodes):
    print("episode:", episode)
    obs = env.reset()
    env.render(mode='human')
    if save:
        frames.append(plt_2_rgb(env))
    # The step limit is checked before each step, so --max-steps N runs at
    # most N steps per episode (N = 0 records only the initial frame).
    steps = 0
    while args.max_steps is None or steps < args.max_steps:
        steps += 1
        action = agent.get_action(obs)
        obs, reward, done, _ = env.step(action)
        agent.analyze_feedback(reward, done)
        env.render(mode='human')
        if save:
            img = plt_2_rgb(env)
            frames.append(img)
            if done:
                # quadruple last frame (3 extra copies) to pause between episodes
                for copy_idx in range(3):
                    same_img = np.copy(img)
                    # toggle a pixel between frames to avoid cropping when going from gif to mp4
                    same_img[0, 0, 2] = 0 if copy_idx % 2 == 0 else 255
                    frames.append(same_img)
        if done or env.window.closed:
            break
    if env.window.closed:
        break
if save:
    # Append ".gif" only when missing, so `--gif out.gif` writes out.gif
    # rather than out.gif.gif.
    gif_path = savename if savename.lower().endswith('.gif') else savename + '.gif'
    print(f"Saving to {gif_path} ", end="")
    # NOTE(review): imageio's GIF writer has treated `duration` as seconds in
    # older releases and as milliseconds in newer ones; confirm against the
    # installed version so --pause has the intended effect.
    mimsave(gif_path, frames, duration=args.pause)
    # The file can be shrunk afterwards with e.g.
    # `gifsicle -O3 --colors 32 -o out.gif out.gif` (not run here).
    print("Done.")