# SB3_Atari / app.py
# Last commit bf692e3 by ThomasSimonini: "Update app.py"
import gradio as gr
from huggingface_sb3 import load_from_hub
import gym
import os
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
from moviepy.editor import *
def _make_env(environment):
    """Return a single-env Atari VecEnv with the last 4 frames stacked.

    NOTE(review): assumes the checkpoint was trained with the same
    preprocessing (make_atari_env + 4-frame stack) -- confirm per model.
    """
    env = make_atari_env(environment, n_envs=1)
    return VecFrameStack(env, n_stack=4)


def _record_video(model, environment, video_length):
    """Roll out `model` for `video_length` steps while recording; return the raw video path."""
    video_folder = 'logs/videos/'
    # Record the video starting at the first step only.
    env = VecVideoRecorder(
        _make_env(environment),
        video_folder,
        record_video_trigger=lambda step: step == 0,
        video_length=video_length,
        name_prefix="test",
    )
    try:
        obs = env.reset()
        for _ in range(video_length + 1):
            action, _ = model.predict(obs)
            obs, _, _, _ = env.step(action)
    finally:
        # Closing the recorder saves the video file and releases the Atari env,
        # even if the rollout raised.
        env.close()
    return env.video_recorder.path


def _evaluate(model, environment, n_eval_episodes):
    """Run evaluate_policy on a fresh env and return a "mean_reward=... +/- ..." summary."""
    # A dedicated env is used because the recording env has already been closed.
    eval_env = _make_env(environment)
    try:
        mean_reward, std_reward = evaluate_policy(
            model, eval_env, n_eval_episodes=n_eval_episodes
        )
    finally:
        eval_env.close()
    return f"mean_reward={mean_reward:.2f} +/- {std_reward}"


def replay(model_id, filename, environment, evaluate, *, video_length=1000, n_eval_episodes=10):
    """Download a PPO Atari agent from the Hub, record a replay video, optionally evaluate it.

    Args:
        model_id: Hugging Face Hub repo id, e.g. "ThomasSimonini/ppo-QbertNoFrameskip-v4".
        filename: checkpoint file inside that repo, e.g. "ppo-QbertNoFrameskip-v4.zip".
        environment: Atari env id passed to make_atari_env, e.g. "QbertNoFrameskip-v4".
        evaluate: when truthy, also run evaluate_policy and report the mean reward.
        video_length: number of environment steps to record (keyword-only).
        n_eval_episodes: number of episodes used for evaluation (keyword-only).

    Returns:
        Tuple of (path of the re-encoded mp4, evaluation summary text, or a
        short note when evaluation was not requested).
    """
    # Load the model
    checkpoint = load_from_hub(model_id, filename)
    # The agent was pickled under Python 3.8. Replacing these training-only
    # objects avoids unpickling errors when running on Python 3.7 (e.g. Colab).
    custom_objects = {
        "learning_rate": 0.0,
        "lr_schedule": lambda _: 0.0,
        "clip_range": lambda _: 0.0,
    }
    model = PPO.load(checkpoint, custom_objects=custom_objects)

    raw_video_path = _record_video(model, environment, video_length)

    # Honor the "Evaluate?" checkbox instead of always evaluating.
    if evaluate:
        results = _evaluate(model, environment, n_eval_episodes)
        print(results)
    else:
        results = "Evaluation skipped."

    output_path = "new_filename.mp4"
    # Re-encode the recorder's output with moviepy (presumably so the Gradio
    # video player can play it -- confirm), closing the clip's reader afterwards.
    videoclip = VideoFileClip(raw_video_path)
    try:
        videoclip.write_videofile(output_path)
    finally:
        videoclip.close()
    return output_path, results
# Pre-filled example row: clicking it populates all four inputs below.
examples = [
    [
        "ThomasSimonini/ppo-QbertNoFrameskip-v4",
        "ppo-QbertNoFrameskip-v4.zip",
        "QbertNoFrameskip-v4",
        True,
    ]
]

# Inputs map positionally onto replay(model_id, filename, environment, evaluate).
iface = gr.Interface(
    fn=replay,
    inputs=[
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Model Id: "),
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Filename: "),
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Environment: "),
        gr.inputs.Checkbox(default=False, label="Evaluate?: "),
    ],
    # replay() returns (video file path, evaluation summary text).
    outputs=["video", "text"],
    # Queue requests: recording and evaluating can exceed the default request timeout.
    enable_queue=True,
    examples=examples,
)

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    iface.launch()