File size: 2,579 Bytes
164ef55
d812b5e
 
 
 
 
 
 
 
 
 
 
064e386
164ef55
9e97c8e
 
aa21d8d
d812b5e
 
 
 
 
 
 
 
 
 
 
 
5d7151b
d812b5e
 
ae8d4fa
d812b5e
 
 
 
 
a53beac
d812b5e
4a80994
d812b5e
 
5871174
d812b5e
 
c044350
ae8d4fa
 
 
69db694
e3e7f84
f8a8b07
9e97c8e
 
 
ae8d4fa
9e97c8e
bf692e3
d812b5e
164ef55
90f838f
aa21d8d
90f838f
 
d38fa4c
90f838f
 
 
bf692e3
164ef55
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
from huggingface_sb3 import load_from_hub

import gym
import os 
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize

from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

from moviepy.editor import *

def replay(model_id, filename, environment, evaluate):
	"""Load a PPO agent from the Hugging Face Hub, record a gameplay video,
	and optionally evaluate the policy.

	Args:
		model_id: Hub repo id of the model (e.g. "ThomasSimonini/ppo-QbertNoFrameskip-v4").
		filename: Checkpoint filename inside that repo (a .zip saved by SB3).
		environment: Gym Atari environment id the agent was trained on.
		evaluate: When True, run a 10-episode evaluation and include the
			mean/std reward in the returned results string.

	Returns:
		Tuple of (path to the re-encoded mp4, results string — empty when
		`evaluate` is False).
	"""
	# Download the checkpoint from the Hub.
	checkpoint = load_from_hub(model_id, filename)

	# The agent may have been pickled under a different Python version
	# (trained with 3.8, replayed on 3.7); overriding these pickled
	# callables avoids unpickling errors.
	custom_objects = {
		"learning_rate": 0.0,
		"lr_schedule": lambda _: 0.0,
		"clip_range": lambda _: 0.0,
	}
	model = PPO.load(checkpoint, custom_objects=custom_objects)

	# Atari preprocessing + 4-frame stacking, matching the training setup.
	eval_env = make_atari_env(environment, n_envs=1)
	eval_env = VecFrameStack(eval_env, n_stack=4)

	video_folder = 'logs/videos/'
	video_length = 1000

	# Record the video starting at the very first step.
	env = VecVideoRecorder(eval_env, video_folder,
	                       record_video_trigger=lambda x: x == 0,
	                       video_length=video_length,
	                       name_prefix="test")

	obs = env.reset()
	for _ in range(video_length + 1):
		action, _states = model.predict(obs)
		obs, rewards, dones, info = env.step(action)

	# FIX: the original called evaluate_policy() and read
	# env.video_recorder.path AFTER env.close(), i.e. on an
	# already-finalized environment, and ignored the `evaluate` flag.
	# Evaluate (when requested) and capture the recording path first.
	results = ""
	if evaluate:
		mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
		results = f"mean_reward={mean_reward:.2f} +/- {std_reward}"
		print(results)

	video_path = env.video_recorder.path
	env.close()  # flushes and saves the recording to disk

	# Re-encode to a stable output filename for the Gradio video component.
	videoclip = VideoFileClip(video_path)
	videoclip.write_videofile("new_filename.mp4")
	return 'new_filename.mp4', results

# Pre-filled example shown in the Gradio UI.
examples = [
    [
        "ThomasSimonini/ppo-QbertNoFrameskip-v4",
        "ppo-QbertNoFrameskip-v4.zip",
        "QbertNoFrameskip-v4",
        True,
    ]
]

# Three text fields (model id / filename / environment) plus an
# evaluate toggle, matching the replay() signature in order.
input_widgets = [
    gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Model Id: "),
    gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Filename: "),
    gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Environment: "),
    gr.inputs.Checkbox(default=False, label="Evaluate?: "),
]

iface = gr.Interface(
    fn=replay,
    inputs=input_widgets,
    outputs=["video", "text"],
    enable_queue=True,
    examples=examples,
)
iface.launch()