Thomas Simonini committed on
Commit
1c71ebe
•
1 Parent(s): 870aa65

Added work in progress generate video live

Browse files
Files changed (1) hide show
  1. app.py +110 -12
app.py CHANGED
@@ -18,20 +18,116 @@ def replay(option):
18
  videoclip.write_videofile("new_filename.mp4")
19
  return 'new_filename.mp4'
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  """
22
  TODO: Next version with live video generation
23
- def replay_classical(hf_model_filename, hf_model_id):
24
- import gym
25
- from stable_baselines3 import PPO
26
- from stable_baselines3.common.evaluation import evaluate_policy
27
 
28
- model = PPO.load_from_huggingface(hf_model_id,hf_model_filename)
29
 
30
- eval_env = gym.make(option)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
 
33
  def replay_atari(hf_model_filename, hf_model_id):
34
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  iface = gr.Interface(
37
  replay,
@@ -43,19 +139,21 @@ iface = gr.Interface(
43
  description = '',
44
  article =
45
  '''<div>
46
- <p style="text-align: center">This version of the RL library allows you to load models directly from the Hugging Face Hub</p>
47
- <p style="text-align: center"> Select the trained agent you want to watch perform.
 
48
  These models are from <a href="https://github.com/araffin/rl-baselines-zoo">Stable Baseline Zoo</a></p>
49
  <p>
50
  There are currently 3 models:
51
  <ul>
52
- <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4">PPO SpaceInvadersNoFrameskip-v4</a></li>
53
- <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-LunarLander-v2">PPO LunarLander-v2</a></li>
54
- <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-CartPole-v1">PPO CartPole-v1</a></li>
55
  </ul>
56
  </div>'''
57
  )
58
 
59
 
60
  iface.launch()
 
61
 
18
  videoclip.write_videofile("new_filename.mp4")
19
  return 'new_filename.mp4'
20
 
21
+ iface = gr.Interface(
22
+ replay,
23
+ [
24
+ gr.inputs.Dropdown(["Atari Space Invaders 👾", "CartPole-v1 🕹️", "LunarLander-v2 🚀👩‍🚀"]),
25
+ ],
26
+ "video",
27
+ title = 'Stable Baselines 3 with 🤗',
28
+ description = '',
29
+ article =
30
+ '''<div>
31
+ <p style="text-align: center">This version of the RL library allows you to load models directly from the Hugging Face Hub</p>
32
+ <p style="text-align: center"> Select the trained agent you want to watch perform.
33
+ These models are from <a href="https://github.com/araffin/rl-baselines-zoo">Stable Baseline Zoo</a></p>
34
+ <p>
35
+ There are currently 3 models:
36
+ <ul>
37
+ <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4">PPO SpaceInvadersNoFrameskip-v4</a></li>
38
+ <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-LunarLander-v2">PPO LunarLander-v2</a></li>
39
+ <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-CartPole-v1">PPO CartPole-v1</a></li>
40
+ </ul>
41
+ </div>'''
42
+ )
43
+
44
+
45
+ iface.launch()
46
+
47
  """
48
  TODO: Next version with live video generation
49
+ import gradio as gr
50
+ import os
 
 
51
 
52
+ from Recorder import Recorder
53
 
54
+ from stable_baselines3 import PPO
55
+
56
+
57
+ #The Agent plays and we generate the video
58
+ def replay(option):
59
+ video_path = ""
60
+ # Get the correct model
61
+ if (option == "LunarLander-v2 🚀👩‍🚀"):
62
+ env_name = "Lunar Lander v2"
63
+ agent_name = "PPO"
64
+ print("TEST")
65
+ hf_model_filename = "LunarLander-v2"
66
+ hf_model_id = "ThomasSimonini/stable-baselines3-ppo-LunarLander-v2"
67
+ video_path = replay_gym(hf_model_filename, hf_model_id)
68
+ elif(option == "CartPole-v1 🕹️"):
69
+ hf_model_filename = "CartPole-v1"
70
+ hf_model_id = "ThomasSimonini/stable-baselines3-ppo-CartPole-v1"
71
+ video_path = replay_gym(hf_model_filename, hf_model_id)
72
+ elif(option == "Atari Space Invaders 👾"):
73
+ hf_model_filename = "SpaceInvadersNoFrameskip-v4"
74
+ hf_model_id = "ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4"
75
+ video_path = replay_atari(hf_model_filename, hf_model_id)
76
+ #video_path = "./SpaceInvadersNoFrameskip-v4.mp4"
77
+
78
+ return video_path
79
+
80
+
81
+ def replay_gym(hf_model_filename, hf_model_id):
82
+ import gym
83
+ from stable_baselines3.common.evaluation import evaluate_policy
84
+
85
+
86
+ model = PPO.load_from_huggingface(hf_model_id,hf_model_filename)
87
+
88
+ eval_env = gym.make(hf_model_filename)
89
+
90
+ directory = './video'
91
+ env = Recorder(eval_env, directory)
92
+
93
+ obs = env.reset()
94
+ done = False
95
+ while not done:
96
+ action, _state = model.predict(obs)
97
+ obs, reward, done, info = env.step(action)
98
+ clip = env.play()
99
+ return clip
100
 
101
 
102
  def replay_atari(hf_model_filename, hf_model_id):
103
+ os.system("python -m atari_py.import_roms \"content/atari_roms\"")
104
+ import gym
105
+ from stable_baselines3.common.env_util import make_atari_env
106
+ from stable_baselines3.common.vec_env import VecFrameStack
107
+
108
+ from stable_baselines3.common.evaluation import evaluate_policy
109
+
110
+ model = PPO.load_from_huggingface(hf_model_id, hf_model_filename)
111
+
112
+
113
+ eval_env = make_atari_env(hf_model_filename, n_envs=1, seed=0)
114
+ eval_env = VecFrameStack(eval_env, n_stack=4)
115
+
116
+ model = PPO.load_from_huggingface(hf_model_id, hf_model_filename)
117
+
118
+ import gym
119
+ directory = './video'
120
+ env = Recorder(eval_env, directory)
121
+
122
+ obs = env.reset()
123
+ done = False
124
+ while not done:
125
+ action, _state = model.predict(obs)
126
+ obs, reward, done, info = env.step(action)
127
+ clip = env.play()
128
+ return clip
129
+
130
+
131
 
132
  iface = gr.Interface(
133
  replay,
139
  description = '',
140
  article =
141
  '''<div>
142
+ <p style="text-align: center">This version of the RL library allows you to load models directly from the Hugging Face Hub</p>
143
+ <p style="text-align: center"> Select the trained agent you want to watch perform. We record your agent playing.
144
+ <p style="text-align: center"> Don't forget to <b>click on clear between each record.</b> </p>
145
  These models are from <a href="https://github.com/araffin/rl-baselines-zoo">Stable Baseline Zoo</a></p>
146
  <p>
147
  There are currently 3 models:
148
  <ul>
149
+ <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4">PPO SpaceInvadersNoFrameskip-v4</a></li>
150
+ <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-LunarLander-v2">PPO LunarLander-v2</a></li>
151
+ <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-CartPole-v1">PPO CartPole-v1</a></li>
152
  </ul>
153
  </div>'''
154
  )
155
 
156
 
157
  iface.launch()
158
+ """
159