ThomasSimonini HF staff committed on
Commit
a19d046
1 Parent(s): e19559e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import gradio as gr
3
+ import time
4
+
5
+ from huggingface_sb3 import load_from_hub
6
+
7
+ from stable_baselines3 import PPO
8
+ from stable_baselines3.common.env_util import make_atari_env
9
+ from stable_baselines3.common.vec_env import VecFrameStack
10
+
11
+ from stable_baselines3.common.env_util import make_atari_env
12
+
13
+ # The loading functions below were adapted from Edward Beeching's code
14
def load_env(env_name):
    """Create the environment the agent is replayed in.

    Builds a single (``n_envs=1``) vectorized Atari environment for
    ``env_name`` with ``make_atari_env`` and wraps it in a
    ``VecFrameStack`` that stacks 4 frames (``n_stack=4``).

    Args:
        env_name: Atari environment id, e.g. ``"PongNoFrameskip-v4"``.

    Returns:
        The frame-stacked vectorized environment.
    """
    return VecFrameStack(make_atari_env(env_name, n_envs=1), n_stack=4)
18
+
19
def load_model(env_name):
    """Fetch the PPO checkpoint for ``env_name`` from the Hub and load it.

    The checkpoint file ``ppo-{env_name}.zip`` is requested from the
    ``ThomasSimonini/ppo-{env_name}`` repository through ``load_from_hub``
    and then handed to ``PPO.load``.

    Args:
        env_name: Atari environment id used to build the repo and file names.

    Returns:
        The loaded ``PPO`` model.
    """
    repo_id = f"ThomasSimonini/ppo-{env_name}"
    filename = f"ppo-{env_name}.zip"
    checkpoint_path = load_from_hub(repo_id, filename)

    # Replace the training-only hyper-parameters stored in the checkpoint with
    # zero-valued stand-ins.
    # NOTE(review): presumably this avoids deserializing the pickled schedules
    # (which can fail across Python versions) since only inference is needed —
    # confirm against the SB3 `custom_objects` documentation.
    overrides = dict(
        learning_rate=0.0,
        lr_schedule=lambda _: 0.0,
        clip_range=lambda _: 0.0,
    )

    return PPO.load(checkpoint_path, custom_objects=overrides)
34
+
35
def replay(env_name, time_sleep, num_episodes):
    """Play the trained agent and stream the rendered frames.

    This is a generator: it yields one RGB frame per environment step so the
    caller can display the game as it is being played.

    Args:
        env_name: Atari environment id; selects both the environment and the
            checkpoint to load.
        time_sleep: Delay, in seconds, inserted after every step.
        num_episodes: Number of episodes to play. UI sliders may deliver this
            as a float, so it is truncated to an int before use.

    Yields:
        The frame returned by ``env.render(mode="rgb_array")`` for each step.
    """
    env = load_env(env_name)
    try:
        model = load_model(env_name)
        # range() rejects floats, and slider widgets can hand us e.g. 5.0.
        for _ in range(int(num_episodes)):
            obs = env.reset()
            done = False
            while not done:
                frame = env.render(mode="rgb_array")
                action, _states = model.predict(obs)
                # `action` is already batched (one entry per sub-environment),
                # so it is passed to the vectorized env as-is rather than
                # being wrapped in another list.
                obs, _reward, dones, _info = env.step(action)
                # Only one sub-environment exists (n_envs=1 in load_env).
                done = bool(dones[0])
                time.sleep(time_sleep)
                yield frame
    finally:
        # Also runs when the consumer abandons the generator mid-episode,
        # so the emulator is always released.
        env.close()
47
+
48
+
49
# Gradio UI: pick a game, a per-frame delay and a number of episodes, then
# watch the frames streamed by `replay`.
demo = gr.Interface(
    replay,
    [
        gr.Dropdown(
            [
                "SpaceInvadersNoFrameskip-v4",
                "PongNoFrameskip-v4",
                "SeaquestNoFrameskip-v4",
                "QbertNoFrameskip-v4",
            ]
        ),
        gr.Slider(0.01, 1, value=0.1),
        # step=1 keeps the episode count a whole number.
        gr.Slider(1, 20, value=5, step=1),
    ],
    gr.Image(),
    title="Watch Agents playing Atari games 🤖",
    description="Select an environment to watch a Hugging Face's trained deep reinforcement learning agent.",
    article="time_sleep is the time delay between each frame (0.1 by default).",
)

# `replay` is a generator, which Gradio only streams when queuing is enabled.
# The queue must be enabled on the Interface *before* launching: launch() does
# not return the Interface, and with debug=True it blocks the main thread.
demo.queue().launch(debug=True)