import gradio as gr
from huggingface_sb3 import load_from_hub
import gym
import os
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
# Only VideoFileClip is used; avoid `from moviepy.editor import *`.
from moviepy.editor import VideoFileClip


def replay(model_id, filename, environment, evaluate):
    """Load a PPO agent from the Hugging Face Hub, record a gameplay video,
    and optionally evaluate the policy.

    Args:
        model_id: Hub repo id of the model (e.g. "ThomasSimonini/ppo-...").
        filename: Checkpoint filename inside that repo (a .zip archive).
        environment: Gym Atari environment id (e.g. "QbertNoFrameskip-v4").
        evaluate: When True, run a 10-episode evaluation and report the score.

    Returns:
        Tuple of (path to the re-encoded .mp4, results summary string).
    """
    checkpoint = load_from_hub(model_id, filename)

    # The agent may have been trained under a different Python version
    # (e.g. 3.8 vs 3.7); overriding these callables avoids pickle errors.
    # They are not needed for inference.
    custom_objects = {
        "learning_rate": 0.0,
        "lr_schedule": lambda _: 0.0,
        "clip_range": lambda _: 0.0,
    }
    model = PPO.load(checkpoint, custom_objects=custom_objects)

    eval_env = make_atari_env(environment, n_envs=1)
    eval_env = VecFrameStack(eval_env, n_stack=4)

    video_folder = "logs/videos/"
    video_length = 1000

    # Record a video starting at the very first step.
    env = VecVideoRecorder(
        eval_env,
        video_folder,
        record_video_trigger=lambda step: step == 0,
        video_length=video_length,
        name_prefix="test",
    )

    obs = env.reset()
    for _ in range(video_length + 1):
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)

    # BUG FIX: evaluation previously ran AFTER env.close() (on a closed env)
    # and the `evaluate` checkbox was ignored. Evaluate while the env is
    # still live, and only when requested.
    if evaluate:
        mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
        results = f"mean_reward={mean_reward:.2f} +/- {std_reward}"
    else:
        results = "Evaluation skipped"

    # Capture the recorded file's path before close() tears the recorder down,
    # then close to flush/finalize the video on disk.
    video_path = env.video_recorder.path
    env.close()

    # Re-encode with moviepy so Gradio's video component can serve it.
    videoclip = VideoFileClip(video_path)
    videoclip.write_videofile("new_filename.mp4")
    return "new_filename.mp4", results


examples = [
    [
        "ThomasSimonini/ppo-QbertNoFrameskip-v4",
        "ppo-QbertNoFrameskip-v4.zip",
        "QbertNoFrameskip-v4",
        True,
    ]
]

iface = gr.Interface(
    fn=replay,
    inputs=[
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Model Id: "),
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Filename: "),
        gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Environment: "),
        gr.inputs.Checkbox(default=False, label="Evaluate?: "),
    ],
    outputs=["video", "text"],
    enable_queue=True,
    examples=examples,
)

iface.launch()