pushing model

- .gitattributes +2 -0
- README.md +2 -3
- events.out.tfevents.1675632419.ip-26-0-136-145.1447874.0 → events.out.tfevents.1675915696.ip-26-0-141-11.220473.0 +2 -2
- replay.mp4 +0 -0
- sebulba_ppo_envpool.py +50 -113
- sebulba_ppo_envpool_impala_atari_wrapper.cleanrl_model +2 -2
- videos/Enduro-v5__sebulba_ppo_envpool_impala_atari_wrapper__1__041f1f0f-98c5-41f3-a73d-19d8967acb56-eval/0.mp4 +3 -0
- videos/Enduro-v5__sebulba_ppo_envpool_impala_atari_wrapper__1__701dcf31-d5b4-4d80-bfe6-a77f88357693-eval/0.mp4 +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 sebulba_ppo_envpool_impala_atari_wrapper.cleanrl_model filter=lfs diff=lfs merge=lfs -text
+videos/Enduro-v5__sebulba_ppo_envpool_impala_atari_wrapper__1__041f1f0f-98c5-41f3-a73d-19d8967acb56-eval/0.mp4 filter=lfs diff=lfs merge=lfs -text
+replay.mp4 filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -16,7 +16,7 @@ model-index:
       type: Enduro-v5
     metrics:
     - type: mean_reward
-      value:
+      value: 2360.30 +/- 28.22
       name: mean_reward
       verified: false
 ---
@@ -46,7 +46,7 @@ curl -OL https://huggingface.co/cleanrl/Enduro-v5-sebulba_ppo_envpool_impala_ata
 curl -OL https://huggingface.co/cleanrl/Enduro-v5-sebulba_ppo_envpool_impala_atari_wrapper-seed1/raw/main/pyproject.toml
 curl -OL https://huggingface.co/cleanrl/Enduro-v5-sebulba_ppo_envpool_impala_atari_wrapper-seed1/raw/main/poetry.lock
 poetry install --all-extras
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_impala_atari_wrapper --actor-device-ids 0 --learner-device-ids 1 2 3 4 --
+python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_impala_atari_wrapper --actor-device-ids 0 --learner-device-ids 1 2 3 4 --track --save-model --upload-model --hf-entity cleanrl --env-id Enduro-v5 --seed 1
 ```

 # Hyperparameters
@@ -75,7 +75,6 @@ python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_impala_atari_wrappe
  'num_minibatches': 4,
  'num_steps': 128,
  'num_updates': 6103,
- 'params_queue_timeout': 0.02,
  'profile': False,
  'save_model': True,
  'seed': 1,
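For context on the updated `value:` line above: the model card's `mean_reward` appears to be a mean and standard deviation over the returns of the post-training evaluation episodes. The snippet below only illustrates that summary format; the episode returns are placeholders, not the data behind 2360.30 +/- 28.22.

```python
import numpy as np

# Placeholder evaluation returns; the real card value comes from this model's own eval episodes.
episodic_returns = np.array([2330.0, 2385.0, 2366.0])

print(f"mean_reward={episodic_returns.mean():.2f} +/- {episodic_returns.std():.2f}")
```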
events.out.tfevents.1675632419.ip-26-0-136-145.1447874.0 → events.out.tfevents.1675915696.ip-26-0-141-11.220473.0
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:334ee3e0ae81b95888676bd1638b046be4a90062198c22e743597c23211804d7
+size 8912675
replay.mp4
CHANGED
Binary files a/replay.mp4 and b/replay.mp4 differ
sebulba_ppo_envpool.py
CHANGED
@@ -1,89 +1,24 @@
 """
-
-python sebulba_ppo_envpool.py --actor-device-ids 0 --
-python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1
-
-🔥 core settings:
-
-* test throughput
-    * python sebulba_ppo_envpool.py --exp-name sebula_thpt_a0_l1_timeout --actor-device-ids 0 --learner-device-ids 1 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
-    * python sebulba_ppo_envpool.py --exp-name sebula_thpt_a0_l12_timeout --actor-device-ids 0 --learner-device-ids 1 2 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
+* 🥼 Test throughput (see docs):
+    * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
+    * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
     * this will help us diagnose the throughput issue
-* python sebulba_ppo_envpool.py --
-
-* python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids
-*
-* python sebulba_ppo_envpool.py --
-* python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 --
-
-
-
-## throughput
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_thpt_rollout_is_faster --actor-device-ids 0 --learner-device-ids 1 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
-
-## shared: actor on GPU0 and learner on GPU0
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_1gpu_rollout_is_faster --actor-device-ids 0 --learner-device-ids 0 --total-timesteps 500000 --track
-
-## separate: actor on GPU0 and learner on GPU1
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l1_rollout_is_faster --actor-device-ids 0 --learner-device-ids 1 --total-timesteps 500000 --track
-
-## shared: actor on GPU0 and learner on GPU0,1
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l01_rollout_is_faster --actor-device-ids 0 --learner-device-ids 0 1 --total-timesteps 500000 --track
-
-## separate: actor on GPU0 and learner on GPU1,2
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l12_rollout_is_faster --actor-device-ids 0 --learner-device-ids 1 2 --total-timesteps 500000 --track
-
-
-# 1.1 rollout is faster than training w/ timeout
-
-## shared: actor on GPU0 and learner on GPU0
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_1gpu_rollout_is_faster_timeout --actor-device-ids 0 --learner-device-ids 0 --params-queue-timeout 0.02 --total-timesteps 500000 --track
-
-## separate: actor on GPU0 and learner on GPU1
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l1_rollout_is_faster_timeout --actor-device-ids 0 --learner-device-ids 1 --params-queue-timeout 0.02 --total-timesteps 500000 --track
-
-## shared: actor on GPU0 and learner on GPU0,1
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l01_rollout_is_faster_timeout --actor-device-ids 0 --learner-device-ids 0 1 --params-queue-timeout 0.02 --total-timesteps 500000 --track
-
-## separate: actor on GPU0 and learner on GPU1,2
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l12_rollout_is_faster_timeout --actor-device-ids 0 --learner-device-ids 1 2 --params-queue-timeout 0.02 --total-timesteps 500000 --track
-
-# 1.2. rollout is much faster than training w/ timeout
-
-## throughput
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_thpt_rollout_is_much_faster_timeout --actor-device-ids 0 --learner-device-ids 1 --update-epochs 8 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
-
-## shared: actor on GPU0 and learner on GPU0,1
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l01_rollout_is_much_faster_timeout --actor-device-ids 0 --learner-device-ids 0 1 --update-epochs 8 --params-queue-timeout 0.02 --total-timesteps 500000 --track
-
-## separate: actor on GPU0 and learner on GPU1,2
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l12_rollout_is_much_faster_timeout --actor-device-ids 0 --learner-device-ids 1 2 --update-epochs 8 --params-queue-timeout 0.02 --total-timesteps 500000 --track
-
-# 2. training is faster than rollout
-
-## throughput
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_thpt_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 1 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
-
-## shared: actor on GPU0 and learner on GPU0
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_1gpu_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 0 --total-timesteps 500000 --track
-
-## separate: actor on GPU0 and learner on GPU1
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l1_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 1 --total-timesteps 500000 --track
-
-## shared: actor on GPU0 and learner on GPU0,1
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l01_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 0 1 --total-timesteps 500000 --track
-
-## separate: actor on GPU0 and learner on GPU1,2
-python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l12_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 1 2 --total-timesteps 500000 --track
-
+    * python sebulba_ppo_envpool.py --actor-device-ids 0 --num-actor-threads 2 --learner-device-ids 1 --profile --test-actor-learner-throughput --total-timesteps 500000 --track
+* 🔥 Best performance so far (more GPUs -> faster)
+    * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 0 --track
+    * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 0 1 --track
+    * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 --num-envs 60 --async-batch-size 20 --track
+    * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 --track
+    * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 5 6 --num-envs 60 --async-batch-size 20 --track
+    * (this actually doesn't work that well) python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 5 6 7 --num-envs 70 --async-batch-size 35 --track
 """
 # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_async_jax_scan_impalanet_machadopy
-# https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
 import argparse
 import os
 import random
 import time
 import uuid
+import warnings
 from collections import deque
 from distutils.util import strtobool
 from functools import partial
@@ -182,23 +117,18 @@ def parse_args():
         help="whether to call block_until_ready() for profiling")
     parser.add_argument("--test-actor-learner-throughput", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
         help="whether to test actor-learner throughput by removing the actor-learner communication")
-    parser.add_argument("--params-queue-timeout", type=float, default=None,
-        help="the timeout for the `params_queue.get()` operation in the actor thread to pull params;" + \
-            "by default it's `None`; if you set a timeout, it will likely make the actor run faster but will introduce some side effects," + \
-            "such as the actor will not be able to pull the latest params from the learner and will use the old params instead")
     args = parser.parse_args()
     args.batch_size = int(args.num_envs * args.num_steps)
     args.minibatch_size = int(args.batch_size // args.num_minibatches)
     args.num_updates = args.total_timesteps // args.batch_size
     args.async_update = int(args.num_envs / args.async_batch_size)
     assert len(args.actor_device_ids) == 1, "only 1 actor_device_ids is supported now"
+    if args.num_actor_threads > 1:
+        warnings.warn("⚠️ !!!! `num_actor_threads` > 1 is not tested with learning; see docs for detail")
     # fmt: on
     return args


-LEARNER_WARMUP_TIME = 10  # seconds
-
-
 def make_env(env_id, seed, num_envs, async_batch_size=1, num_threads=None, thread_affinity_offset=-1):
     def thunk():
         envs = envpool.make(
@@ -394,8 +324,14 @@ def rollout(
     rollout_time = deque(maxlen=10)
     data_transfer_time = deque(maxlen=10)
     rollout_queue_put_time = deque(maxlen=10)
-
+    actor_policy_version = 0
     for update in range(1, args.num_updates + 2):
+        # NOTE: This is a major difference from the sync version:
+        # at the end of the rollout phase, the sync version will have the next observation
+        # ready for the value bootstrap, but the async version will not have it.
+        # for this reason we do `num_steps + 1`` to get the extra states for value bootstrapping.
+        # but note that the extra states are not used for the loss computation in the next iteration,
+        # while the sync version will use the extra state for the loss computation.
         update_time_start = time.time()
         obs = []
         dones = []
@@ -411,21 +347,15 @@ def rollout(
         storage_time = 0
         env_send_time = 0

-        # NOTE:
-        #
-        #
-        # for this reason we do `num_steps + 1`` to get the extra states for value bootstrapping.
-        # but note that the extra states are not used for the loss computation in the next iteration,
-        # while the sync version will use the extra state for the loss computation.
+        # NOTE: `update != 2` is actually IMPORTANT — it allows us to start running policy collection
+        # concurrently with the learning process. It also ensures the actor's policy version is only 1 step
+        # behind the learner's policy version
         params_queue_get_time_start = time.time()
-
-        params = params_queue.get(
-
-            # print("params_queue.get timeout triggered")
-            params_timeout_count += 1
+        if update != 2:
+            params = params_queue.get()
+            actor_policy_version += 1
         params_queue_get_time.append(time.time() - params_queue_get_time_start)
         writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
-        writer.add_scalar("stats/params_queue_timeout_count", params_timeout_count, global_step)
         rollout_time_start = time.time()
         for _ in range(
             args.async_update, (args.num_steps + 1) * args.async_update
@@ -496,6 +426,7 @@ def rollout(
         )
         payload = (
             global_step,
+            actor_policy_version,
             update,
             jnp.array_split(b_obs, len(learner_devices)),
             jnp.array_split(b_actions, len(learner_devices)),
@@ -513,9 +444,6 @@ def rollout(
         rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start)
         writer.add_scalar("stats/rollout_queue_put_time", np.mean(rollout_queue_put_time), global_step)

-        if update == 1 or update == 2 or update == 3:
-            time.sleep(LEARNER_WARMUP_TIME)  # makes sure the actor does to fill the rollout_queue at the get go
-
         writer.add_scalar(
             "charts/SPS_update",
             int(
@@ -709,6 +637,7 @@ if __name__ == "__main__":
             monitor_gym=True,
             save_code=True,
         )
+    print(devices)
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
@@ -762,7 +691,7 @@ if __name__ == "__main__":
        static_broadcasted_argnums=(6),
    )

-    rollout_queue = queue.Queue(maxsize=
+    rollout_queue = queue.Queue(maxsize=1)
     params_queues = []
     num_cpus = mp.cpu_count()
     fair_num_cpus = num_cpus // len(args.actor_device_ids)
@@ -771,14 +700,10 @@ if __name__ == "__main__":
         def add_scalar(self, arg0, arg1, arg3):
             pass

-    # lock = threading.Lock()
-    # AgentParamsStore = namedtuple("AgentParamsStore", ["params", "version"])
-    # agent_params_store = AgentParamsStore(agent_state.params, 0)
-
     dummy_writer = DummyWriter()
     for d_idx, d_id in enumerate(args.actor_device_ids):
         for j in range(args.num_actor_threads):
-            params_queue = queue.Queue(maxsize=
+            params_queue = queue.Queue(maxsize=1)
             params_queue.put(jax.device_put(flax.jax_utils.unreplicate(agent_state.params), devices[d_id]))
             threading.Thread(
                 target=rollout,
@@ -797,12 +722,21 @@ if __name__ == "__main__":
            params_queues.append(params_queue)

    rollout_queue_get_time = deque(maxlen=10)
-
+    learner_policy_version = 0
    while True:
-
-        if
+        learner_policy_version += 1
+        if learner_policy_version == 1 or not args.test_actor_learner_throughput:
            rollout_queue_get_time_start = time.time()
-
+            (
+                global_step,
+                actor_policy_version,
+                update,
+                b_obs,
+                b_actions,
+                b_logprobs,
+                b_advantages,
+                b_returns,
+            ) = rollout_queue.get()
            rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
            writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)

@@ -817,7 +751,7 @@ if __name__ == "__main__":
            envs.single_action_space.n,
            key,
        )
-        if
+        if learner_policy_version == 1 or not args.test_actor_learner_throughput:
            for d_idx, d_id in enumerate(args.actor_device_ids):
                for j in range(args.num_actor_threads):
                    params_queues[d_idx * args.num_actor_threads + j].put(
@@ -828,7 +762,10 @@ if __name__ == "__main__":
        writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
        writer.add_scalar("stats/rollout_queue_size", rollout_queue.qsize(), global_step)
        writer.add_scalar("stats/params_queue_size", params_queue.qsize(), global_step)
-        print(
+        print(
+            global_step,
+            f"actor_policy_version={actor_policy_version}, actor_update={update}, learner_policy_version={learner_policy_version}, training time: {time.time() - training_time_start}s",
+        )

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"][0].item(), global_step)
@@ -837,7 +774,7 @@ if __name__ == "__main__":
        writer.add_scalar("losses/entropy", entropy_loss[-1, -1, -1].item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl[-1, -1, -1].item(), global_step)
        writer.add_scalar("losses/loss", loss[-1, -1, -1].item(), global_step)
-        if update
+        if update >= args.num_updates:
            break

    if args.save_model:
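The NOTE that moved into the rollout loop explains why the actor steps the environments for `num_steps + 1` iterations: the extra step supplies the value estimate needed to bootstrap the last transition. The snippet below is only a generic illustration of that bootstrapping (standard GAE with placeholder arrays, not the script's actual buffers):

```python
import numpy as np

num_steps, gamma, gae_lambda = 4, 0.99, 0.95
rewards = np.zeros(num_steps)
dones = np.zeros(num_steps + 1)   # includes the extra (num_steps + 1)-th step
values = np.zeros(num_steps + 1)  # the last entry is only used for bootstrapping

advantages = np.zeros(num_steps)
lastgaelam = 0.0
for t in reversed(range(num_steps)):
    nextnonterminal = 1.0 - dones[t + 1]
    # the extra value values[num_steps] enters only through this bootstrap term
    delta = rewards[t] + gamma * values[t + 1] * nextnonterminal - values[t]
    lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
    advantages[t] = lastgaelam
returns = advantages + values[:-1]
```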
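Taken together, the changes to sebulba_ppo_envpool.py drop the timeout and warmup coordination (`--params-queue-timeout`, `LEARNER_WARMUP_TIME`, the `time.sleep` on the first updates) in favor of blocking size-1 queues plus explicit `actor_policy_version` and `learner_policy_version` counters, with the `update != 2` skip letting the actor collect one rollout ahead of the learner. Below is a minimal, self-contained sketch of that handoff pattern; the thread bodies and names are illustrative stand-ins, not the script's actual rollout and learner code.

```python
import queue
import threading

NUM_UPDATES = 5
params_queue = queue.Queue(maxsize=1)   # learner -> actor: latest params
rollout_queue = queue.Queue(maxsize=1)  # actor -> learner: latest rollout

def actor():
    actor_policy_version = 0
    for update in range(1, NUM_UPDATES + 2):
        # Skipping the get on update 2 lets the actor collect a second rollout while the
        # learner is still training on the first one; after that the actor stays one
        # policy version behind the learner.
        if update != 2:
            params = params_queue.get()
            actor_policy_version += 1
        rollout = f"rollout_{update}"  # stand-in for the real trajectory batch
        rollout_queue.put((actor_policy_version, update, rollout))

def learner():
    learner_policy_version = 0
    params_queue.put("params_v0")  # initial params so the actor can start immediately
    while True:
        actor_policy_version, update, rollout = rollout_queue.get()
        learner_policy_version += 1
        print(f"train on {rollout}: actor v{actor_policy_version}, learner v{learner_policy_version}")
        params_queue.put(f"params_v{learner_policy_version}")
        if update >= NUM_UPDATES:
            break

t = threading.Thread(target=actor)
t.start()
learner()
t.join()
```

Running the sketch shows the actor settling one policy version behind the learner, which matches what the new NOTE promises and what the added `print(...)` line in the main loop reports for the real run.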
sebulba_ppo_envpool_impala_atari_wrapper.cleanrl_model
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:1241c474effc0c32294ea8a5a92151274451b8b4f4bf1b833c2bc0ec7f7b0780
+size 4369080
videos/Enduro-v5__sebulba_ppo_envpool_impala_atari_wrapper__1__041f1f0f-98c5-41f3-a73d-19d8967acb56-eval/0.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1642fa75fb20b0b56ac991e4834ab66724d2c8ed8823cd9b26234d856d0370bc
+size 3660933

videos/Enduro-v5__sebulba_ppo_envpool_impala_atari_wrapper__1__701dcf31-d5b4-4d80-bfe6-a77f88357693-eval/0.mp4
DELETED
Binary file (320 kB)