ThomasSimonini HF staff commited on
Commit
197c70e
1 Parent(s): 167b87e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -10,6 +10,8 @@ from stable_baselines3.common.vec_env import VecFrameStack
10
 
11
  from stable_baselines3.common.env_util import make_atari_env
12
 
 
 
13
  # Loading functions were taken from Edward Beeching code
14
  def load_env(env_name):
15
  env = make_atari_env(env_name, n_envs=1)
@@ -32,19 +34,23 @@ def load_model(env_name):
32
 
33
  return model
34
 
35
- def replay(env_name, time_sleep): #, num_episodes):
36
  env = load_env(env_name)
37
  model = load_model(env_name)
38
  #for i in range(num_episodes):
39
  obs = env.reset()
40
  done = False
 
41
  while not done:
42
- frame = env.render(mode="rgb_array")
43
- action, _states = model.predict(obs)
44
- obs, reward, done, info = env.step([action])
45
- time.sleep(time_sleep)
46
- yield frame
47
-
 
 
 
48
 
49
  demo = gr.Interface(
50
  replay,
@@ -53,6 +59,7 @@ demo = gr.Interface(
53
  "SeaquestNoFrameskip-v4",
54
  "QbertNoFrameskip-v4",
55
  ]),
 
56
  gr.Slider(0.01, 1, value=0.05),
57
  #gr.Slider(1, 20, value=5)
58
  ],
 
10
 
11
  from stable_baselines3.common.env_util import make_atari_env
12
 
13
+ max_steps = 5000 # Let's try with 5000 steps.
14
+
15
  # Loading functions were taken from Edward Beeching code
16
  def load_env(env_name):
17
  env = make_atari_env(env_name, n_envs=1)
 
34
 
35
  return model
36
 
37
+ def replay(env_name, max_steps, time_sleep):
38
  env = load_env(env_name)
39
  model = load_model(env_name)
40
  #for i in range(num_episodes):
41
  obs = env.reset()
42
  done = False
43
+ i = 0
44
  while not done:
45
+ i++
46
+ if i < max_steps:
47
+ frame = env.render(mode="rgb_array")
48
+ action, _states = model.predict(obs)
49
+ obs, reward, done, info = env.step([action])
50
+ time.sleep(time_sleep)
51
+ yield frame
52
+ else:
53
+ break
54
 
55
  demo = gr.Interface(
56
  replay,
 
59
  "SeaquestNoFrameskip-v4",
60
  "QbertNoFrameskip-v4",
61
  ]),
62
+ gr.Slider(100, 10000, value=500),
63
  gr.Slider(0.01, 1, value=0.05),
64
  #gr.Slider(1, 20, value=5)
65
  ],