vwxyzjn committed
Commit 205d55f
1 Parent(s): 5a61c9d

pushing model

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  sebulba_ppo_envpool_impala_atari_wrapper.cleanrl_model filter=lfs diff=lfs merge=lfs -text
+ videos/Enduro-v5__sebulba_ppo_envpool_impala_atari_wrapper__1__041f1f0f-98c5-41f3-a73d-19d8967acb56-eval/0.mp4 filter=lfs diff=lfs merge=lfs -text
+ replay.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -16,7 +16,7 @@ model-index:
  type: Enduro-v5
  metrics:
  - type: mean_reward
- value: 0.00 +/- 0.00
+ value: 2360.30 +/- 28.22
  name: mean_reward
  verified: false
  ---
@@ -46,7 +46,7 @@ curl -OL https://huggingface.co/cleanrl/Enduro-v5-sebulba_ppo_envpool_impala_ata
  curl -OL https://huggingface.co/cleanrl/Enduro-v5-sebulba_ppo_envpool_impala_atari_wrapper-seed1/raw/main/pyproject.toml
  curl -OL https://huggingface.co/cleanrl/Enduro-v5-sebulba_ppo_envpool_impala_atari_wrapper-seed1/raw/main/poetry.lock
  poetry install --all-extras
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_impala_atari_wrapper --actor-device-ids 0 --learner-device-ids 1 2 3 4 --params-queue-timeout 0.02 --track --save-model --upload-model --hf-entity cleanrl --env-id Enduro-v5 --seed 1
+ python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_impala_atari_wrapper --actor-device-ids 0 --learner-device-ids 1 2 3 4 --track --save-model --upload-model --hf-entity cleanrl --env-id Enduro-v5 --seed 1
  ```

  # Hyperparameters
@@ -75,7 +75,6 @@ python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_impala_atari_wrappe
  'num_minibatches': 4,
  'num_steps': 128,
  'num_updates': 6103,
- 'params_queue_timeout': 0.02,
  'profile': False,
  'save_model': True,
  'seed': 1,
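
The `'num_updates': 6103` entry above follows from the batch-size arithmetic in `parse_args()` (`batch_size = num_envs * num_steps`, then `num_updates = total_timesteps // batch_size`). A quick check of that arithmetic, assuming the script's default `num_envs=64` and `total_timesteps=50_000_000` (neither value is shown in this hunk):

```python
# Hypothetical sanity check of the num_updates arithmetic from parse_args();
# num_envs and total_timesteps are assumed defaults, not values shown in this diff.
num_envs, num_steps, total_timesteps = 64, 128, 50_000_000
batch_size = int(num_envs * num_steps)       # 8192
num_updates = total_timesteps // batch_size  # 50_000_000 // 8192 = 6103
print(batch_size, num_updates)               # -> 8192 6103
```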
 
events.out.tfevents.1675632419.ip-26-0-136-145.1447874.0 → events.out.tfevents.1675915696.ip-26-0-141-11.220473.0 RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:96231d4de1de01916fb76297943edb7fd6c53c2348723f2e4bbb686f92f9f6c6
- size 9358634
+ oid sha256:334ee3e0ae81b95888676bd1638b046be4a90062198c22e743597c23211804d7
+ size 8912675
replay.mp4 CHANGED
Binary files a/replay.mp4 and b/replay.mp4 differ
 
sebulba_ppo_envpool.py CHANGED
@@ -1,89 +1,24 @@
  """
- 0. multi-threaded actor
- python sebulba_ppo_envpool.py --actor-device-ids 0 --num-actor-threads 2 --learner-device-ids 1 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
- python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
-
- 🔥 core settings:
-
- * test throughput
- * python sebulba_ppo_envpool.py --exp-name sebula_thpt_a0_l1_timeout --actor-device-ids 0 --learner-device-ids 1 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
- * python sebulba_ppo_envpool.py --exp-name sebula_thpt_a0_l12_timeout --actor-device-ids 0 --learner-device-ids 1 2 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
+ * 🥼 Test throughput (see docs):
+ * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
+ * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
  * this will help us diagnose the throughput issue
- * python sebulba_ppo_envpool.py --exp-name sebula_thpt_a0_l1 --actor-device-ids 0 --learner-device-ids 1 --profile --total-timesteps 500000 --track
- * python sebulba_ppo_envpool.py --exp-name sebula_thpt_a0_l12 --actor-device-ids 0 --learner-device-ids 1 2 --profile --total-timesteps 500000 --track
- * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 --num-actor-threads 2 --track
- * Best performance so far
- * python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l01_rollout_is_faster --actor-device-ids 0 --learner-device-ids 0 1 --total-timesteps 500000 --track
- * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 --params-queue-timeout 0.02 --track
-
- # 1. rollout is faster than training
-
- ## throughput
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_thpt_rollout_is_faster --actor-device-ids 0 --learner-device-ids 1 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
-
- ## shared: actor on GPU0 and learner on GPU0
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_1gpu_rollout_is_faster --actor-device-ids 0 --learner-device-ids 0 --total-timesteps 500000 --track
-
- ## separate: actor on GPU0 and learner on GPU1
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l1_rollout_is_faster --actor-device-ids 0 --learner-device-ids 1 --total-timesteps 500000 --track
-
- ## shared: actor on GPU0 and learner on GPU0,1
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l01_rollout_is_faster --actor-device-ids 0 --learner-device-ids 0 1 --total-timesteps 500000 --track
-
- ## separate: actor on GPU0 and learner on GPU1,2
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l12_rollout_is_faster --actor-device-ids 0 --learner-device-ids 1 2 --total-timesteps 500000 --track
-
-
- # 1.1 rollout is faster than training w/ timeout
-
- ## shared: actor on GPU0 and learner on GPU0
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_1gpu_rollout_is_faster_timeout --actor-device-ids 0 --learner-device-ids 0 --params-queue-timeout 0.02 --total-timesteps 500000 --track
-
- ## separate: actor on GPU0 and learner on GPU1
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l1_rollout_is_faster_timeout --actor-device-ids 0 --learner-device-ids 1 --params-queue-timeout 0.02 --total-timesteps 500000 --track
-
- ## shared: actor on GPU0 and learner on GPU0,1
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l01_rollout_is_faster_timeout --actor-device-ids 0 --learner-device-ids 0 1 --params-queue-timeout 0.02 --total-timesteps 500000 --track
-
- ## separate: actor on GPU0 and learner on GPU1,2
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l12_rollout_is_faster_timeout --actor-device-ids 0 --learner-device-ids 1 2 --params-queue-timeout 0.02 --total-timesteps 500000 --track
-
- # 1.2. rollout is much faster than training w/ timeout
-
- ## throughput
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_thpt_rollout_is_much_faster_timeout --actor-device-ids 0 --learner-device-ids 1 --update-epochs 8 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
-
- ## shared: actor on GPU0 and learner on GPU0,1
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l01_rollout_is_much_faster_timeout --actor-device-ids 0 --learner-device-ids 0 1 --update-epochs 8 --params-queue-timeout 0.02 --total-timesteps 500000 --track
-
- ## separate: actor on GPU0 and learner on GPU1,2
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l12_rollout_is_much_faster_timeout --actor-device-ids 0 --learner-device-ids 1 2 --update-epochs 8 --params-queue-timeout 0.02 --total-timesteps 500000 --track
-
- # 2. training is faster than rollout
-
- ## throughput
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_thpt_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 1 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
-
- ## shared: actor on GPU0 and learner on GPU0
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_1gpu_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 0 --total-timesteps 500000 --track
-
- ## separate: actor on GPU0 and learner on GPU1
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l1_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 1 --total-timesteps 500000 --track
-
- ## shared: actor on GPU0 and learner on GPU0,1
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l01_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 0 1 --total-timesteps 500000 --track
-
- ## separate: actor on GPU0 and learner on GPU1,2
- python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l12_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 1 2 --total-timesteps 500000 --track
-
+ * python sebulba_ppo_envpool.py --actor-device-ids 0 --num-actor-threads 2 --learner-device-ids 1 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
+ * 🔥 Best performance so far (more GPUs -> faster)
+ * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 0 --track
+ * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 0 1 --track
+ * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 --num-envs 60 --async-batch-size 20 --track
+ * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 --track
+ * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 5 6 --num-envs 60 --async-batch-size 20 --track
+ * (this actually doesn't work that well) python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 5 6 7 --num-envs 70 --async-batch-size 35 --track
  """
  # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_async_jax_scan_impalanet_machadopy
- # https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
  import argparse
  import os
  import random
  import time
  import uuid
+ import warnings
  from collections import deque
  from distutils.util import strtobool
  from functools import partial
@@ -182,23 +117,18 @@ def parse_args():
  help="whether to call block_until_ready() for profiling")
  parser.add_argument("--test-actor-learner-throughput", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
  help="whether to test actor-learner throughput by removing the actor-learner communication")
- parser.add_argument("--params-queue-timeout", type=float, default=None,
- help="the timeout for the `params_queue.get()` operation in the actor thread to pull params;" + \
- "by default it's `None`; if you set a timeout, it will likely make the actor run faster but will introduce some side effects," + \
- "such as the actor will not be able to pull the latest params from the learner and will use the old params instead")
  args = parser.parse_args()
  args.batch_size = int(args.num_envs * args.num_steps)
  args.minibatch_size = int(args.batch_size // args.num_minibatches)
  args.num_updates = args.total_timesteps // args.batch_size
  args.async_update = int(args.num_envs / args.async_batch_size)
  assert len(args.actor_device_ids) == 1, "only 1 actor_device_ids is supported now"
+ if args.num_actor_threads > 1:
+ warnings.warn("⚠️ !!!! `num_actor_threads` > 1 is not tested with learning; see docs for detail")
  # fmt: on
  return args


- LEARNER_WARMUP_TIME = 10 # seconds
-
-
  def make_env(env_id, seed, num_envs, async_batch_size=1, num_threads=None, thread_affinity_offset=-1):
  def thunk():
  envs = envpool.make(
@@ -394,8 +324,14 @@ def rollout(
  rollout_time = deque(maxlen=10)
  data_transfer_time = deque(maxlen=10)
  rollout_queue_put_time = deque(maxlen=10)
- params_timeout_count = 0
+ actor_policy_version = 0
  for update in range(1, args.num_updates + 2):
+ # NOTE: This is a major difference from the sync version:
+ # at the end of the rollout phase, the sync version will have the next observation
+ # ready for the value bootstrap, but the async version will not have it.
+ # for this reason we do `num_steps + 1`` to get the extra states for value bootstrapping.
+ # but note that the extra states are not used for the loss computation in the next iteration,
+ # while the sync version will use the extra state for the loss computation.
  update_time_start = time.time()
  obs = []
  dones = []
@@ -411,21 +347,15 @@ def rollout(
  storage_time = 0
  env_send_time = 0

- # NOTE: This is a major difference from the sync version:
- # at the end of the rollout phase, the sync version will have the next observation
- # ready for the value bootstrap, but the async version will not have it.
- # for this reason we do `num_steps + 1`` to get the extra states for value bootstrapping.
- # but note that the extra states are not used for the loss computation in the next iteration,
- # while the sync version will use the extra state for the loss computation.
+ # NOTE: `update != 2` is actually IMPORTANT it allows us to start running policy collection
+ # concurrently with the learning process. It also ensures the actor's policy version is only 1 step
+ # behind the learner's policy version
  params_queue_get_time_start = time.time()
- try:
- params = params_queue.get(timeout=args.params_queue_timeout)
- except queue.Empty:
- # print("params_queue.get timeout triggered")
- params_timeout_count += 1
+ if update != 2:
+ params = params_queue.get()
+ actor_policy_version += 1
  params_queue_get_time.append(time.time() - params_queue_get_time_start)
  writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
- writer.add_scalar("stats/params_queue_timeout_count", params_timeout_count, global_step)
  rollout_time_start = time.time()
  for _ in range(
  args.async_update, (args.num_steps + 1) * args.async_update
@@ -496,6 +426,7 @@ def rollout(
  )
  payload = (
  global_step,
+ actor_policy_version,
  update,
  jnp.array_split(b_obs, len(learner_devices)),
  jnp.array_split(b_actions, len(learner_devices)),
@@ -513,9 +444,6 @@ def rollout(
  rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start)
  writer.add_scalar("stats/rollout_queue_put_time", np.mean(rollout_queue_put_time), global_step)

- if update == 1 or update == 2 or update == 3:
- time.sleep(LEARNER_WARMUP_TIME) # makes sure the actor does to fill the rollout_queue at the get go
-
  writer.add_scalar(
  "charts/SPS_update",
  int(
@@ -709,6 +637,7 @@ if __name__ == "__main__":
  monitor_gym=True,
  save_code=True,
  )
+ print(devices)
  writer = SummaryWriter(f"runs/{run_name}")
  writer.add_text(
  "hyperparameters",
@@ -762,7 +691,7 @@ if __name__ == "__main__":
  static_broadcasted_argnums=(6),
  )

- rollout_queue = queue.Queue(maxsize=2)
+ rollout_queue = queue.Queue(maxsize=1)
  params_queues = []
  num_cpus = mp.cpu_count()
  fair_num_cpus = num_cpus // len(args.actor_device_ids)
@@ -771,14 +700,10 @@ if __name__ == "__main__":
  def add_scalar(self, arg0, arg1, arg3):
  pass

- # lock = threading.Lock()
- # AgentParamsStore = namedtuple("AgentParamsStore", ["params", "version"])
- # agent_params_store = AgentParamsStore(agent_state.params, 0)
-
  dummy_writer = DummyWriter()
  for d_idx, d_id in enumerate(args.actor_device_ids):
  for j in range(args.num_actor_threads):
- params_queue = queue.Queue(maxsize=2)
+ params_queue = queue.Queue(maxsize=1)
  params_queue.put(jax.device_put(flax.jax_utils.unreplicate(agent_state.params), devices[d_id]))
  threading.Thread(
  target=rollout,
@@ -797,12 +722,21 @@ if __name__ == "__main__":
  params_queues.append(params_queue)

  rollout_queue_get_time = deque(maxlen=10)
- learner_update = 0
+ learner_policy_version = 0
  while True:
- learner_update += 1
- if learner_update == 1 or not args.test_actor_learner_throughput:
+ learner_policy_version += 1
+ if learner_policy_version == 1 or not args.test_actor_learner_throughput:
  rollout_queue_get_time_start = time.time()
- global_step, update, b_obs, b_actions, b_logprobs, b_advantages, b_returns = rollout_queue.get()
+ (
+ global_step,
+ actor_policy_version,
+ update,
+ b_obs,
+ b_actions,
+ b_logprobs,
+ b_advantages,
+ b_returns,
+ ) = rollout_queue.get()
  rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
  writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)

@@ -817,7 +751,7 @@ if __name__ == "__main__":
  envs.single_action_space.n,
  key,
  )
- if learner_update == 1 or not args.test_actor_learner_throughput:
+ if learner_policy_version == 1 or not args.test_actor_learner_throughput:
  for d_idx, d_id in enumerate(args.actor_device_ids):
  for j in range(args.num_actor_threads):
  params_queues[d_idx * args.num_actor_threads + j].put(
@@ -828,7 +762,10 @@ if __name__ == "__main__":
  writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
  writer.add_scalar("stats/rollout_queue_size", rollout_queue.qsize(), global_step)
  writer.add_scalar("stats/params_queue_size", params_queue.qsize(), global_step)
- print(global_step, update, rollout_queue.qsize(), f"training time: {time.time() - training_time_start}s")
+ print(
+ global_step,
+ f"actor_policy_version={actor_policy_version}, actor_update={update}, learner_policy_version={learner_policy_version}, training time: {time.time() - training_time_start}s",
+ )

  # TRY NOT TO MODIFY: record rewards for plotting purposes
  writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"][0].item(), global_step)
@@ -837,7 +774,7 @@ if __name__ == "__main__":
  writer.add_scalar("losses/entropy", entropy_loss[-1, -1, -1].item(), global_step)
  writer.add_scalar("losses/approx_kl", approx_kl[-1, -1, -1].item(), global_step)
  writer.add_scalar("losses/loss", loss[-1, -1, -1].item(), global_step)
- if update > args.num_updates:
+ if update >= args.num_updates:
  break

  if args.save_model:
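
The NOTE added to the rollout loop (the `num_steps + 1` comment) is the key to the async bootstrapping: the actor gathers one extra step so the final value estimate can close off the advantage computation, and that extra state is not reused for the loss. Below is a generic NumPy sketch of that idea, not the script's actual JAX advantage code; all names are illustrative:

```python
# Generic GAE sketch: `values` has length T + 1, and values[T] is used only to
# bootstrap the tail of the advantage recursion -- mirroring why the async rollout
# collects num_steps + 1 steps. Illustrative only; not the training script's code.
import numpy as np

def gae(rewards, values, dones, gamma=0.99, lam=0.95):
    T = len(rewards)                      # num_steps
    advantages = np.zeros(T)
    lastgaelam = 0.0
    for t in reversed(range(T)):
        nextnonterminal = 1.0 - dones[t + 1]
        delta = rewards[t] + gamma * values[t + 1] * nextnonterminal - values[t]
        lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
        advantages[t] = lastgaelam
    return advantages

rewards = np.array([0.0, 1.0, 0.0])       # T = 3 collected steps
values  = np.array([0.5, 0.6, 0.4, 0.3])  # T + 1 value estimates; the last one is the bootstrap
dones   = np.zeros(4)                     # termination flags, aligned with values
print(gae(rewards, values, dones))        # advantages for the T collected steps only
```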
 
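Taken together, the changes above replace the old timeout and warm-up machinery with a simpler handoff: `params_queue` and `rollout_queue` both get `maxsize=1`, the learner publishes params once per iteration, and the actor skips exactly one blocking `params_queue.get()` (at `update == 2`) so rollout 2 can overlap the first learner update while staying at most one policy version behind. A self-contained toy of that pattern, using only threads and queues with hypothetical names (no JAX, not the training script itself):

```python
# Toy sketch of the actor/learner handoff: maxsize-1 queues plus one skipped
# params_queue.get() keep the actor at most one policy version behind the learner.
# All names are illustrative; this is not sebulba_ppo_envpool.py.
import queue
import threading

NUM_UPDATES = 5
params_queue = queue.Queue(maxsize=1)   # learner -> actor (latest params)
rollout_queue = queue.Queue(maxsize=1)  # actor -> learner (latest rollout)

def actor():
    actor_policy_version = 0
    for update in range(1, NUM_UPDATES + 2):
        if update != 2:                  # skip once so rollout 2 overlaps learner update 1
            params = params_queue.get()
            actor_policy_version += 1
        rollout = f"rollout {update} collected with policy v{actor_policy_version}"
        rollout_queue.put((actor_policy_version, update, rollout))

def learner():
    params_queue.put("params v1")        # initial params handed to the actor
    learner_policy_version = 1
    while True:
        actor_policy_version, update, rollout = rollout_queue.get()
        print(f"learner v{learner_policy_version} trains on {rollout}")
        learner_policy_version += 1
        params_queue.put(f"params v{learner_policy_version}")
        if update >= NUM_UPDATES:
            break

t = threading.Thread(target=actor)
t.start()
learner()
t.join()
```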
sebulba_ppo_envpool_impala_atari_wrapper.cleanrl_model CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:6ab7d7a2c4eee663199cccdbed35d3b925cc7bb183bd82b2f416a6645be026e8
- size 4369110
+ oid sha256:1241c474effc0c32294ea8a5a92151274451b8b4f4bf1b833c2bc0ec7f7b0780
+ size 4369080
videos/Enduro-v5__sebulba_ppo_envpool_impala_atari_wrapper__1__041f1f0f-98c5-41f3-a73d-19d8967acb56-eval/0.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1642fa75fb20b0b56ac991e4834ab66724d2c8ed8823cd9b26234d856d0370bc
+ size 3660933
videos/Enduro-v5__sebulba_ppo_envpool_impala_atari_wrapper__1__701dcf31-d5b4-4d80-bfe6-a77f88357693-eval/0.mp4 DELETED
Binary file (320 kB)