# Spaces:
# Runtime error
import os
import tempfile

import gradio as gr
import gym
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
from moviepy.editor import *
def replay(model_id, filename, environment, evaluate, n_eval_episodes=10):
    """Load a PPO Atari agent from the Hugging Face Hub and record a short replay.

    Args:
        model_id: Hub repository id, passed to ``load_from_hub``.
        filename: Name of the saved model file inside that repository.
        environment: Atari environment id, passed to ``make_atari_env``.
        evaluate: When true, also run ``evaluate_policy`` on a separate,
            non-recording environment and print the mean/std reward. The
            Gradio interface only has a video output, so the score is printed
            to the logs rather than returned.
        n_eval_episodes: Number of episodes used when ``evaluate`` is true.

    Returns:
        Path to an MP4 file holding the re-encoded replay. Each call writes
        into its own temporary directory so concurrent requests do not
        overwrite each other's files. The directory is not deleted here,
        because the caller (Gradio) still needs to read the file after return.
    """
    # Load the model.
    checkpoint = load_from_hub(model_id, filename)
    # Overriding these entries avoids pickle errors when the agent was saved
    # under a different Python version than the one loading it (original
    # note: trained with 3.8, loaded under 3.7 on Colab). They are training-
    # only values and are replaced by constants here.
    custom_objects = {
        "learning_rate": 0.0,
        "lr_schedule": lambda _: 0.0,
        "clip_range": lambda _: 0.0,
    }
    model = PPO.load(checkpoint, custom_objects=custom_objects)

    if evaluate:
        # Evaluate on a dedicated env so the episodes are not captured by the
        # video recorder below.
        score_env = VecFrameStack(make_atari_env(environment, n_envs=1), n_stack=4)
        try:
            mean_reward, std_reward = evaluate_policy(
                model, score_env, n_eval_episodes=n_eval_episodes
            )
        finally:
            score_env.close()
        print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

    # Per-call working directory: avoids races between concurrent requests.
    work_dir = tempfile.mkdtemp(prefix="replay-")
    video_folder = os.path.join(work_dir, "videos")
    video_length = 100

    eval_env = make_atari_env(environment, n_envs=1)
    eval_env = VecFrameStack(eval_env, n_stack=4)
    # Record the video starting at the first step.
    env = VecVideoRecorder(
        eval_env,
        video_folder,
        record_video_trigger=lambda step: step == 0,
        video_length=video_length,
        name_prefix="test",
    )
    try:
        obs = env.reset()
        # One extra step so the recorder reaches its video_length limit.
        for _ in range(video_length + 1):
            action, _states = model.predict(obs)
            obs, _rewards, _dones, _info = env.step(action)
    finally:
        # Save the video. Closing the recorder wrapper also closes the wrapped
        # Atari env, even when stepping raised.
        env.close()

    # Re-encode the recording with moviepy and return the new file's path.
    output_path = os.path.join(work_dir, "replay.mp4")
    videoclip = VideoFileClip(env.video_recorder.path)
    try:
        videoclip.write_videofile(output_path)
    finally:
        # Release the reader's file handle / ffmpeg subprocess.
        videoclip.close()
    return output_path
# Gradio UI. The four inputs map positionally onto
# replay(model_id, filename, environment, evaluate); the single "video"
# output displays the MP4 path that replay() returns.
# NOTE(review): gr.inputs.* (with `default=`) is the legacy Gradio input API;
# it is deprecated in newer Gradio releases. Confirm the Space's pinned gradio
# version before upgrading.
iface = gr.Interface(
    fn=replay,
    inputs=[
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Model Id: "),
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Filename: "),
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Environment: "),
        gr.inputs.Checkbox(default=False, label="Evaluate?: "),
    ],
    outputs="video",
)

# Only start the server when run as a script, so importing this module
# (e.g. for tests or Gradio reload mode) does not launch it.
if __name__ == "__main__":
    iface.launch()