|
import gradio as gr |
|
import os |
|
from moviepy.editor import * |
|
|
|
def replay(option): |
|
path = "" |
|
|
|
if (option == "LunarLander-v2 ππ©βπ"): |
|
path = "./LunarLander-v2.mp4" |
|
elif(option == "CartPole-v1 πΉοΈ"): |
|
path = "./CartPole-v1.mp4" |
|
elif(option == "Atari Space Invaders πΎ"): |
|
path = "./SpaceInvadersNoFrameskip-v4.mp4" |
|
|
|
|
|
|
|
videoclip = VideoFileClip(path) |
|
videoclip.write_videofile("new_filename.mp4") |
|
return 'new_filename.mp4' |
|
|
|
# Build the demo UI: one dropdown selecting a pre-trained agent, wired to
# `replay`, whose return value is shown in a "video" output component.
iface = gr.Interface(
    replay,
    [
        # NOTE(review): these labels must stay identical to the labels that
        # `replay` recognizes. Their emoji suffixes look like UTF-8 emoji
        # mis-decoded as ISO-8859-7; restore them together with `replay`.
        gr.inputs.Dropdown(["Atari Space Invaders πΎ", "CartPole-v1 πΉοΈ", "LunarLander-v2 ππ©βπ"]),
    ],
    "video",
    # Hugging Face emoji (U+1F917); the previous text was its mis-decoded bytes.
    title='Stable Baselines 3 with 🤗',
    description='',
    # A <ul> cannot be nested inside a <p> in HTML, so the list intro is its
    # own closed paragraph rather than an unclosed <p> wrapping the list.
    article='''<div>
<p style="text-align: center">This version of the RL library allows you to load models directly from the Hugging Face Hub</p>
<p style="text-align: center"> Select the trained agent you want to watch perform.
These models are from <a href="https://github.com/araffin/rl-baselines-zoo">Stable Baseline Zoo</a></p>
<p>There are currently 3 models:</p>
<ul>
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4">PPO SpaceInvadersNoFrameskip-v4</a></li>
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-LunarLander-v2">PPO LunarLander-v2</a></li>
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-CartPole-v1">PPO CartPole-v1</a></li>
</ul>
</div>''',
)

iface.launch()
|
|
|
""" |
|
TODO: Next version with live video generation |
|
import gradio as gr |
|
import os |
|
|
|
from Recorder import Recorder |
|
|
|
from stable_baselines3 import PPO |
|
|
|
|
|
#The Agent plays and we generate the video |
|
def replay(option): |
|
video_path = "" |
|
# Get the correct model |
|
if (option == "LunarLander-v2 ππ©βπ"): |
|
env_name = "Lunar Lander v2" |
|
agent_name = "PPO" |
|
print("TEST") |
|
hf_model_filename = "LunarLander-v2" |
|
hf_model_id = "ThomasSimonini/stable-baselines3-ppo-LunarLander-v2" |
|
video_path = replay_gym(hf_model_filename, hf_model_id) |
|
elif(option == "CartPole-v1 πΉοΈ"): |
|
hf_model_filename = "CartPole-v1" |
|
hf_model_id = "ThomasSimonini/stable-baselines3-ppo-CartPole-v1" |
|
video_path = replay_gym(hf_model_filename, hf_model_id) |
|
elif(option == "Atari Space Invaders πΎ"): |
|
hf_model_filename = "SpaceInvadersNoFrameskip-v4" |
|
hf_model_id = "ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4" |
|
video_path = replay_atari(hf_model_filename, hf_model_id) |
|
#video_path = "./SpaceInvadersNoFrameskip-v4.mp4" |
|
|
|
return video_path |
|
|
|
|
|
def replay_gym(hf_model_filename, hf_model_id): |
|
import gym |
|
from stable_baselines3.common.evaluation import evaluate_policy |
|
|
|
|
|
model = PPO.load_from_huggingface(hf_model_id,hf_model_filename) |
|
|
|
eval_env = gym.make(hf_model_filename) |
|
|
|
directory = './video' |
|
env = Recorder(eval_env, directory) |
|
|
|
obs = env.reset() |
|
done = False |
|
while not done: |
|
action, _state = model.predict(obs) |
|
obs, reward, done, info = env.step(action) |
|
clip = env.play() |
|
return clip |
|
|
|
|
|
def replay_atari(hf_model_filename, hf_model_id): |
|
os.system("python -m atari_py.import_roms \"content/atari_roms\"") |
|
import gym |
|
from stable_baselines3.common.env_util import make_atari_env |
|
from stable_baselines3.common.vec_env import VecFrameStack |
|
|
|
from stable_baselines3.common.evaluation import evaluate_policy |
|
|
|
model = PPO.load_from_huggingface(hf_model_id, hf_model_filename) |
|
|
|
|
|
eval_env = make_atari_env(hf_model_filename, n_envs=1, seed=0) |
|
eval_env = VecFrameStack(eval_env, n_stack=4) |
|
|
|
model = PPO.load_from_huggingface(hf_model_id, hf_model_filename) |
|
|
|
import gym |
|
directory = './video' |
|
env = Recorder(eval_env, directory) |
|
|
|
obs = env.reset() |
|
done = False |
|
while not done: |
|
action, _state = model.predict(obs) |
|
obs, reward, done, info = env.step(action) |
|
clip = env.play() |
|
return clip |
|
|
|
|
|
|
|
iface = gr.Interface( |
|
replay, |
|
[ |
|
gr.inputs.Dropdown(["Atari Space Invaders πΎ", "CartPole-v1 πΉοΈ", "LunarLander-v2 ππ©βπ"]), |
|
], |
|
"video", |
|
title = 'Stable Baselines 3 with π€', |
|
description = '', |
|
article = |
|
'''<div> |
|
<p style="text-align: center">This version of the RL library allows you to load models directly from the Hugging Face Hub</p> |
|
<p style="text-align: center"> Select the trained agent you want to watch perform. We record your agent playing. |
|
<p style="text-align: center"> Don't forget to <b>click on clear between each record.</b> </p> |
|
These models are from <a href="https://github.com/araffin/rl-baselines-zoo">Stable Baseline Zoo</a></p> |
|
<p> |
|
There are currently 3 models: |
|
<ul> |
|
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4">PPO SpaceInvadersNoFrameskip-v4</a></li> |
|
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-LunarLander-v2">PPO LunarLander-v2</a></li> |
|
<li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-CartPole-v1">PPO CartPole-v1</a></li> |
|
</ul> |
|
</div>''' |
|
) |
|
|
|
|
|
iface.launch() |
|
""" |
|
|
|
|