Spaces:
Runtime error
Runtime error
import cv2 | |
import gradio as gr | |
import time | |
from huggingface_sb3 import load_from_hub | |
from stable_baselines3 import PPO | |
from stable_baselines3.common.env_util import make_atari_env | |
from stable_baselines3.common.vec_env import VecFrameStack | |
from stable_baselines3.common.env_util import make_atari_env | |
# Loading functions were taken from Edward Beeching's code
def load_env(env_name):
    """Create the environment the agent plays in.

    A single environment (``n_envs=1``) is built with ``make_atari_env`` and
    then wrapped in ``VecFrameStack`` with a stack size of four.

    Args:
        env_name: Environment id forwarded to ``make_atari_env``.

    Returns:
        The ``VecFrameStack``-wrapped environment.
    """
    return VecFrameStack(make_atari_env(env_name, n_envs=1), n_stack=4)
def load_model(env_name):
    """Fetch the PPO checkpoint for ``env_name`` from the Hub and load it.

    The checkpoint is looked up in the ``ThomasSimonini/ppo-<env_name>``
    repository under the file name ``ppo-<env_name>.zip``.

    Args:
        env_name: Environment id used to build the repository and file names.

    Returns:
        The model returned by ``PPO.load``.
    """
    repo_id = f"ThomasSimonini/ppo-{env_name}"
    filename = f"ppo-{env_name}.zip"
    checkpoint_path = load_from_hub(repo_id, filename)

    # Values handed to PPO.load as custom_objects. Presumably they replace
    # the training-time schedules stored in the checkpoint, which are not
    # needed for replay -- confirm against the stable-baselines3 docs.
    overrides = {
        "learning_rate": 0.0,
        "lr_schedule": lambda _: 0.0,
        "clip_range": lambda _: 0.0,
    }
    return PPO.load(checkpoint_path, custom_objects=overrides)
def replay(env_name, time_sleep, num_episodes):
    """Yield rendered frames of the trained agent playing ``env_name``.

    This is a generator: each iteration renders the current frame, lets the
    model pick an action, steps the environment, sleeps for ``time_sleep``
    seconds and then yields the frame rendered before the step.

    Args:
        env_name: Environment id; selects both the environment and the
            checkpoint to load.
        time_sleep: Delay in seconds between two yielded frames.
        num_episodes: Number of episodes to play. Converted with ``int()``
            because UI sliders may deliver a float such as ``5.0``, which
            ``range()`` rejects.

    Yields:
        The value returned by ``env.render(mode="rgb_array")`` for each step.
    """
    # Convert up front so an invalid value fails before any resource is built.
    episode_count = int(num_episodes)

    env = load_env(env_name)
    try:
        model = load_model(env_name)
        for _ in range(episode_count):
            obs = env.reset()
            done = False
            while not done:
                frame = env.render(mode="rgb_array")
                action, _states = model.predict(obs)
                # The observation comes from a vectorized env, so predict()
                # already returns one action per env; pass it through as-is
                # instead of wrapping it in another list.
                obs, _reward, done, _info = env.step(action)
                time.sleep(time_sleep)
                yield frame
    finally:
        # Runs on normal exhaustion, on errors, and when the consumer closes
        # the generator early, so the environment is always released.
        env.close()
# Gradio UI: pick a game, a per-frame delay and an episode count, then watch
# the frames streamed by the `replay` generator.
demo = gr.Interface(
    replay,
    [
        gr.Dropdown(
            [
                "SpaceInvadersNoFrameskip-v4",
                "PongNoFrameskip-v4",
                "SeaquestNoFrameskip-v4",
                "QbertNoFrameskip-v4",
            ]
        ),
        gr.Slider(0.01, 1, value=0.1),
        # step=1 keeps the episode count a whole number.
        gr.Slider(1, 20, value=5, step=1),
    ],
    gr.Image(),
    title="Watch Agents playing Atari games 🤖",
    description="Select an environment to watch a Hugging Face's trained deep reinforcement learning agent.",
    article="time_sleep is the time delay between each frame (0.1 by default).",
)

# The queue has to be enabled on the Interface *before* launching: queue()
# returns the Interface itself, whereas launch() does not, so the original
# `.launch(...).queue()` chain called queue() on the wrong object and only
# after the app had already started. The deprecated `enable_queue` launch
# argument is replaced by this explicit queue() call.
demo.queue().launch(debug=True)