# atari_agents / app.py
# Author: ThomasSimonini (Hugging Face Space)
# Commit 82410fa: Remove max_steps as user configurable
# Standard library
import time

# Third-party
import cv2
import gradio as gr
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env  # deduped: was imported twice
from stable_baselines3.common.vec_env import VecFrameStack

# Upper bound on replay length (see replay(); kept for reference).
max_steps = 5000  # Let's try with 5000 steps.
# Loading functions were taken from Edward Beeching code
def load_env(env_name):
    """Build a single-instance Atari VecEnv with 4-frame stacking.

    Loading functions were taken from Edward Beeching code.
    """
    base_env = make_atari_env(env_name, n_envs=1)
    stacked_env = VecFrameStack(base_env, n_stack=4)
    return stacked_env
def load_model(env_name):
    """Download and load the pretrained PPO agent for *env_name* from the Hub.

    The schedule-valued hyperparameters are overridden with constant zeros so
    that checkpoints saved under a different SB3 version still deserialize.
    """
    # Inference-only: zero out learning-rate / clip-range schedules.
    overrides = {
        "learning_rate": 0.0,
        "lr_schedule": lambda _: 0.0,
        "clip_range": lambda _: 0.0,
    }
    repo_id = f"ThomasSimonini/ppo-{env_name}"
    filename = f"ppo-{env_name}.zip"
    checkpoint = load_from_hub(repo_id, filename)
    return PPO.load(checkpoint, custom_objects=overrides)
def replay(env_name, time_sleep, max_steps=500):
    """Yield rendered frames of a pretrained agent playing *env_name*.

    Args:
        env_name: Atari env id (e.g. "PongNoFrameskip-v4").
        time_sleep: delay in seconds between yielded frames.
        max_steps: cap on the number of environment steps (was a hard-coded
            local that shadowed the module-level constant; now a default
            parameter so the cap is explicit and overridable).

    Yields:
        RGB frames (numpy arrays) for Gradio to display.
    """
    env = load_env(env_name)
    model = load_model(env_name)
    obs = env.reset()
    for _ in range(max_steps):
        frame = env.render(mode="rgb_array")
        action, _states = model.predict(obs)
        # model.predict on a VecEnv observation already returns a batched
        # action array, so pass it through directly (the original wrapped it
        # in an extra list, producing a (1, 1)-shaped action).
        obs, reward, done, info = env.step(action)
        time.sleep(time_sleep)
        yield frame
        # VecEnv returns `done` as an array, one flag per env; n_envs == 1.
        if done[0]:
            break
# Build the Gradio UI. Note: `.queue()` must be chained on the Interface
# *before* `.launch()` — the original called `.launch().queue()`, which
# invokes `.queue()` on launch()'s return value (not the Interface) and
# never enables queuing, which generator-based fns require.
demo = gr.Interface(
    replay,
    [
        gr.Dropdown(
            [
                "SpaceInvadersNoFrameskip-v4",
                "PongNoFrameskip-v4",
                "SeaquestNoFrameskip-v4",
                "QbertNoFrameskip-v4",
            ]
        ),
        # gr.Slider(100, 10000, value=500),
        gr.Slider(0.01, 1, value=0.05),
        # gr.Slider(1, 20, value=5)
    ],
    gr.Image(),
    title="Watch Agents playing Atari games 🤖",
    description="Select an environment to watch a Hugging Face's trained deep reinforcement learning agent.",
    article="time_sleep is the time delay between each frame (0.05 by default).",
)
demo.queue().launch()