Spaces:
Runtime error
Runtime error
File size: 2,579 Bytes
164ef55 d812b5e 064e386 164ef55 9e97c8e aa21d8d d812b5e 5d7151b d812b5e ae8d4fa d812b5e a53beac d812b5e 4a80994 d812b5e 5871174 d812b5e c044350 ae8d4fa 69db694 e3e7f84 f8a8b07 9e97c8e ae8d4fa 9e97c8e bf692e3 d812b5e 164ef55 90f838f aa21d8d 90f838f d38fa4c 90f838f bf692e3 164ef55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import gradio as gr
from huggingface_sb3 import load_from_hub
import gym
import os
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
from moviepy.editor import *
def replay(model_id, filename, environment, evaluate):
    """Record a rollout video of a Hub-hosted PPO agent and optionally evaluate it.

    Args:
        model_id: Hugging Face Hub repository id, passed to ``load_from_hub``.
        filename: Name of the checkpoint file inside that repository.
        environment: Environment id, passed to ``make_atari_env``.
        evaluate: When truthy, run ``evaluate_policy`` for 10 episodes and
            report the mean/std reward; otherwise evaluation is skipped.

    Returns:
        A ``(video_path, results)`` tuple: the path of the re-encoded video
        (``"new_filename.mp4"``) and a text summary of the evaluation.
    """
    # Load the model checkpoint from the Hub.
    checkpoint = load_from_hub(model_id, filename)

    # The agent was trained with Python 3.8 while this may run on 3.7;
    # replacing these pickled objects avoids pickle errors at load time.
    custom_objects = {
        "learning_rate": 0.0,
        "lr_schedule": lambda _: 0.0,
        "clip_range": lambda _: 0.0,
    }
    model = PPO.load(checkpoint, custom_objects=custom_objects)

    eval_env = make_atari_env(environment, n_envs=1)
    eval_env = VecFrameStack(eval_env, n_stack=4)

    video_folder = "logs/videos/"
    video_length = 1000

    # Record the video starting at the first step.
    env = VecVideoRecorder(
        eval_env,
        video_folder,
        record_video_trigger=lambda step: step == 0,
        video_length=video_length,
        name_prefix="test",
    )

    results = "Evaluation skipped."
    try:
        obs = env.reset()
        for _ in range(video_length + 1):
            action, _states = model.predict(obs)
            obs, _rewards, _dones, _infos = env.step(action)

        # Remember where the rollout was written before doing anything else
        # with the environments.
        video_path = env.video_recorder.path

        if evaluate:
            # Evaluate while the environment is still open, and on the
            # un-recorded env so evaluation steps do not pass through the
            # video recorder wrapper.
            mean_reward, std_reward = evaluate_policy(
                model, eval_env, n_eval_episodes=10
            )
            results = f"mean_reward={mean_reward:.2f} +/- {std_reward}"
            print(results)
    finally:
        # Closing the recorder finalizes the video and releases the wrapped
        # environment, even if the rollout or evaluation failed.
        env.close()

    # Re-encode the recording to a fixed filename for the Gradio video output.
    videoclip = VideoFileClip(video_path)
    try:
        videoclip.write_videofile("new_filename.mp4")
    finally:
        videoclip.close()

    return "new_filename.mp4", results
# Pre-filled example row; values follow the order of the interface inputs
# below (model id, filename, environment, evaluate flag).
examples = [
    [
        "ThomasSimonini/ppo-QbertNoFrameskip-v4",
        "ppo-QbertNoFrameskip-v4.zip",
        "QbertNoFrameskip-v4",
        True,
    ]
]

# Gradio UI: three text inputs and a checkbox feed `replay`, which returns a
# video path and a text summary.
iface = gr.Interface(
    fn=replay,
    inputs=[
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Model Id: "),
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Filename: "),
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Environment: "),
        gr.inputs.Checkbox(default=False, label="Evaluate?: "),
    ],
    outputs=["video", "text"],
    enable_queue=True,
    examples=examples,
)
iface.launch()