Spaces:
Sleeping
Sleeping
File size: 4,625 Bytes
be5548b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import argparse
import json
import time
import numpy as np
import torch
from pathlib import Path
from utils.babyai_utils.baby_agent import load_agent
from utils.env import make_env
from utils.other import seed
from utils.storage import get_model_dir
from utils.storage import get_status
from models import *
import subprocess
# Parse arguments
# Each entry is (flag, keyword arguments handed to add_argument). Options are
# registered in the order listed here.
_CLI_OPTIONS = (
    ("--model", dict(required=True,
                     help="name of the trained model (REQUIRED)")),
    ("--seed", dict(type=int, default=0,
                    help="random seed (default: 0)")),
    ("--max-steps", dict(type=int, default=None,
                         help="max num of steps")),
    ("--shift", dict(type=int, default=0,
                     help="number of times the environment is reset at the beginning (default: 0)")),
    ("--argmax", dict(action="store_true", default=False,
                      help="select the action with highest probability (default: False)")),
    ("--pause", dict(type=float, default=0.5,
                     help="pause duration between two consequent actions of the agent (default: 0.5)")),
    ("--env-name", dict(type=str, default=None, required=True,
                        help="env name")),
    ("--gif", dict(type=str, default=None,
                   help="store output as gif with the given filename", required=True)),
    ("--episodes", dict(type=int, default=10,
                        help="number of episodes to visualize")),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()
# Set seed for all randomness sources
seed(args.seed)

# Work out the base name of the output GIF. `save` doubles as the
# "record frames" switch for the rest of the script.
save = args.gif
if save:
    if args.gif == "model_id":
        # Special value: build the name from the model path and the seed.
        # NOTE(review): the original indentation was lost; the seed suffix is
        # assumed to apply only to this "model_id" case -- confirm.
        model_tag = args.model.replace('storage/', '').replace('/', '_')
        savename = model_tag + '_{}'.format(args.seed)
    else:
        savename = args.gif
# Set device: CUDA when torch reports it as available, CPU otherwise
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Device: {device}\n")
# Load environment
# Reduce the model argument to a path relative to ./storage. Only a *leading*
# prefix is removed: str.replace() would also delete any later "storage/"
# occurrence inside the name (e.g. "storage/my_storage/run" -> "my_run").
for _prefix in ("./storage/", "storage/"):
    if str(args.model).startswith(_prefix):
        args.model = args.model[len(_prefix):]

# The saved run configuration is read from ./storage/<model>/config.json
with open(Path("./storage") / args.model / "config.json") as f:
    conf = json.load(f)

if args.env_name is None:
    # No environment name on the command line: take the environment
    # arguments from the saved config, or from get_status() when the config
    # has no "env_args" entry.
    # NOTE(review): --env-name is declared required, so this branch looks
    # unreachable, and make_env would be given None as the name -- confirm
    # where the environment name should come from in this case.
    if "env_args" in conf:
        env_args = conf["env_args"]
    else:
        env_args = get_status(get_model_dir(args.model), None)['env_args']
    env = make_env(args.env_name, args.seed, env_args=env_args)
else:
    env_name = args.env_name
    env = make_env(args.env_name, args.seed)

# Reset the environment --shift times before the run starts
for _ in range(args.shift):
    env.reset()
print("Environment loaded\n")
# Define agent: loaded from the model's directory as resolved by get_model_dir
num_frames = None
model_dir = get_model_dir(args.model)
agent = load_agent(env, model_dir, args.argmax, num_frames)
print("Agent loaded\n")

# Run the agent
if save:
    # imageio is only needed when a GIF is written, so it is imported here
    from imageio import mimsave
    old_frames, frames = [], []

# Create a window to view the environment
env.render(mode='human')
def plt_2_rgb(env):
    """Return the current contents of the environment window as an RGB array.

    Reads the pixel buffer of ``env.window.fig.canvas`` and returns it as a
    ``uint8`` array of shape ``(height, width, 3)``.

    ``tostring_rgb`` is used when the canvas provides it. Matplotlib
    deprecated that method in 3.8 and removed it in 3.10, so canvases without
    it are read through ``buffer_rgba`` instead, dropping the alpha channel.
    """
    canvas = env.window.fig.canvas
    if hasattr(canvas, "tostring_rgb"):
        width, height = canvas.get_width_height()
        data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
        return data.reshape((height, width, 3))
    # Copy the RGB channels out of the RGBA buffer so the returned frame does
    # not share memory with a canvas buffer that later draws may overwrite.
    return np.asarray(canvas.buffer_rgba())[..., :3].copy()
# Run the requested number of episodes, rendering every step and, when a GIF
# was requested, recording each rendered frame.
for episode in range(args.episodes):
    print("episode:", episode)
    obs = env.reset()
    env.render(mode='human')
    if save:
        frames.append(plt_2_rgb(env))

    i = 0  # number of steps taken in the current episode
    while True:
        i += 1
        action = agent.get_action(obs)
        obs, reward, done, _ = env.step(action)
        agent.analyze_feedback(reward, done)
        env.render(mode='human')

        if save:
            img = plt_2_rgb(env)
            frames.append(img)
            if done:
                # quadruple last frame to pause between episodes
                # (a separate loop variable keeps the step counter `i` intact)
                for repeat in range(3):
                    same_img = np.copy(img)
                    # toggle a pixel between frames to avoid cropping when going from gif to mp4
                    same_img[0, 0, 2] = 0 if (repeat % 2) == 0 else 255
                    frames.append(same_img)

        if done or env.window.closed:
            break

        # Stop after exactly --max-steps steps; comparing with `>` here let
        # one extra step through before the episode was cut off.
        if args.max_steps is not None and i >= args.max_steps:
            break

    # Closing the window ends the whole run, not just the current episode
    if env.window.closed:
        break

if save:
    print(f"Saving to {savename} ", end="")
    # --pause is passed to imageio as the frame duration
    mimsave(savename + '.gif', frames, duration=args.pause)
    # The GIF can be shrunk afterwards with an external tool, e.g.
    # gifsicle -O3 --colors 32 -o NAME.gif NAME.gif
    print("Done.")
|