ThomasSimonini HF staff committed on
Commit
d812b5e
1 Parent(s): 0c427d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -1
app.py CHANGED
@@ -1,8 +1,48 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
2
 
3
 
4
  def replay(model_id, filename, environment, evaluate):
5
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  iface = gr.Interface(fn=replay, inputs=[
8
  gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Model Id: "),
 
1
  import gradio as gr
2
+ from huggingface_sb3 import load_from_hub
3
+
4
+ import gym
5
+ import os
6
+ from stable_baselines3 import PPO
7
+ from stable_baselines3.common.vec_env import VecNormalize
8
+
9
+ from stable_baselines3.common.env_util import make_atari_env
10
+ from stable_baselines3.common.vec_env import VecFrameStack
11
+
12
+ from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
13
 
14
 
15
  def replay(model_id, filename, environment, evaluate):
16
+ # Load the model
17
+ checkpoint = load_from_hub(model_id, filename)
18
+
19
+ # Because we using 3.7 on Colab and this agent was trained with 3.8 to avoid Pickle errors:
20
+ custom_objects = {
21
+ "learning_rate": 0.0,
22
+ "lr_schedule": lambda _: 0.0,
23
+ "clip_range": lambda _: 0.0,}
24
+
25
+ model= PPO.load(checkpoint, custom_objects=custom_objects)
26
+
27
+ eval_env = make_atari_env(environment, n_envs=1)
28
+ eval_env = VecFrameStack(env, n_stack=4)
29
+
30
+ video_folder = 'logs/videos/'
31
+ video_length = 100
32
+
33
+
34
+ # Record the video starting at the first step
35
+ env = VecVideoRecorder(eval_env, video_folder,
36
+ record_video_trigger=lambda x: x == 0, video_length=video_length,
37
+ name_prefix=f"random-agent-{env_id}")
38
+
39
+ env.reset()
40
+ for _ in range(video_length + 1):
41
+ action, _states = model.predict(obs)
42
+ obs, rewards, dones, info = env.step(action)
43
+ # Save the video
44
+ env.close()
45
+
46
 
47
  iface = gr.Interface(fn=replay, inputs=[
48
  gr.inputs.Textbox(lines=1, placeholder=None, default="", label="Model Id: "),