sgoodfriend
commited on
Commit
·
7058eb5
1
Parent(s):
2ee8433
PPO playing MicrortsDefeatCoacAIShaped-v3 from https://github.com/sgoodfriend/rl-algo-impls/tree/4706d8dbb99b38e70d080c3de68d0751ea585a2f
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +147 -0
- LICENSE +21 -0
- README.md +206 -0
- benchmark_publish.py +4 -0
- colab/colab_atari1.sh +4 -0
- colab/colab_atari2.sh +4 -0
- colab/colab_basic.sh +4 -0
- colab/colab_benchmark.ipynb +195 -0
- colab/colab_carracing.sh +4 -0
- colab/colab_enjoy.ipynb +198 -0
- colab/colab_pybullet.sh +4 -0
- colab/colab_train.ipynb +200 -0
- compare_runs.py +4 -0
- enjoy.py +4 -0
- environment.yml +12 -0
- huggingface_publish.py +4 -0
- optimize.py +4 -0
- pyproject.toml +88 -0
- replay.meta.json +1 -0
- replay.mp4 +0 -0
- rl_algo_impls/a2c/a2c.py +205 -0
- rl_algo_impls/a2c/optimize.py +77 -0
- rl_algo_impls/benchmark_publish.py +111 -0
- rl_algo_impls/compare_runs.py +199 -0
- rl_algo_impls/dqn/dqn.py +189 -0
- rl_algo_impls/dqn/policy.py +62 -0
- rl_algo_impls/dqn/q_net.py +41 -0
- rl_algo_impls/enjoy.py +35 -0
- rl_algo_impls/huggingface_publish.py +193 -0
- rl_algo_impls/hyperparams/a2c.yml +142 -0
- rl_algo_impls/hyperparams/dqn.yml +130 -0
- rl_algo_impls/hyperparams/ppo.yml +616 -0
- rl_algo_impls/hyperparams/vpg.yml +197 -0
- rl_algo_impls/optimize.py +475 -0
- rl_algo_impls/ppo/ppo.py +385 -0
- rl_algo_impls/publish/markdown_format.py +210 -0
- rl_algo_impls/runner/config.py +203 -0
- rl_algo_impls/runner/evaluate.py +102 -0
- rl_algo_impls/runner/running_utils.py +199 -0
- rl_algo_impls/runner/train.py +161 -0
- rl_algo_impls/shared/actor/__init__.py +2 -0
- rl_algo_impls/shared/actor/actor.py +43 -0
- rl_algo_impls/shared/actor/categorical.py +64 -0
- rl_algo_impls/shared/actor/gaussian.py +61 -0
- rl_algo_impls/shared/actor/gridnet.py +108 -0
- rl_algo_impls/shared/actor/gridnet_decoder.py +79 -0
- rl_algo_impls/shared/actor/make_actor.py +98 -0
- rl_algo_impls/shared/actor/multi_discrete.py +101 -0
- rl_algo_impls/shared/actor/state_dependent_noise.py +199 -0
- rl_algo_impls/shared/algorithm.py +39 -0
.gitignore
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
130 |
+
|
131 |
+
# Logging into tensorboard and wandb
|
132 |
+
runs/*
|
133 |
+
wandb
|
134 |
+
|
135 |
+
# macOS
|
136 |
+
.DS_STORE
|
137 |
+
|
138 |
+
# Local scratch work
|
139 |
+
scratch/*
|
140 |
+
|
141 |
+
# vscode
|
142 |
+
.vscode/
|
143 |
+
|
144 |
+
# Don't bother tracking saved_models or videos
|
145 |
+
saved_models/*
|
146 |
+
downloaded_models/*
|
147 |
+
videos/*
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Scott Goodfriend
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
library_name: rl-algo-impls
|
3 |
+
tags:
|
4 |
+
- MicrortsDefeatCoacAIShaped-v3
|
5 |
+
- ppo
|
6 |
+
- deep-reinforcement-learning
|
7 |
+
- reinforcement-learning
|
8 |
+
model-index:
|
9 |
+
- name: ppo
|
10 |
+
results:
|
11 |
+
- metrics:
|
12 |
+
- type: mean_reward
|
13 |
+
value: 0.69 +/- 0.72
|
14 |
+
name: mean_reward
|
15 |
+
task:
|
16 |
+
type: reinforcement-learning
|
17 |
+
name: reinforcement-learning
|
18 |
+
dataset:
|
19 |
+
name: MicrortsDefeatCoacAIShaped-v3
|
20 |
+
type: MicrortsDefeatCoacAIShaped-v3
|
21 |
+
---
|
22 |
+
# **PPO** Agent playing **MicrortsDefeatCoacAIShaped-v3**
|
23 |
+
|
24 |
+
This is a trained model of a **PPO** agent playing **MicrortsDefeatCoacAIShaped-v3** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
|
25 |
+
|
26 |
+
All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/lf7j0hrv.
|
27 |
+
|
28 |
+
## Training Results
|
29 |
+
|
30 |
+
This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [4706d8d](https://github.com/sgoodfriend/rl-algo-impls/tree/4706d8dbb99b38e70d080c3de68d0751ea585a2f). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
|
31 |
+
|
32 |
+
| algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
|
33 |
+
|:-------|:------------------------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
|
34 |
+
| ppo | MicrortsDefeatCoacAIShaped-v3 | 1 | 0.461538 | 0.88712 | 26 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/arhv1foe) |
|
35 |
+
| ppo | MicrortsDefeatCoacAIShaped-v3 | 2 | 0.461538 | 0.84265 | 26 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/kd89zf31) |
|
36 |
+
| ppo | MicrortsDefeatCoacAIShaped-v3 | 3 | 0.692308 | 0.721602 | 26 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/1ak14nj4) |
|
37 |
+
|
38 |
+
|
39 |
+
### Prerequisites: Weights & Biases (WandB)
|
40 |
+
Training and benchmarking assumes you have a Weights & Biases project to upload runs to.
|
41 |
+
By default training goes to a rl-algo-impls project while benchmarks go to
|
42 |
+
rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
|
43 |
+
models and the model weights are uploaded to WandB.
|
44 |
+
|
45 |
+
Before doing anything below, you'll need to create a wandb account and run `wandb
|
46 |
+
login`.
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
## Usage
|
51 |
+
/sgoodfriend/rl-algo-impls: https://github.com/sgoodfriend/rl-algo-impls
|
52 |
+
|
53 |
+
Note: While the model state dictionary and hyperaparameters are saved, the latest
|
54 |
+
implementation could be sufficiently different to not be able to reproduce similar
|
55 |
+
results. You might need to checkout the commit the agent was trained on:
|
56 |
+
[4706d8d](https://github.com/sgoodfriend/rl-algo-impls/tree/4706d8dbb99b38e70d080c3de68d0751ea585a2f).
|
57 |
+
```
|
58 |
+
# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
|
59 |
+
python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/1ak14nj4
|
60 |
+
```
|
61 |
+
|
62 |
+
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
63 |
+
Colab starting from the
|
64 |
+
[colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
|
65 |
+
notebook.
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
## Training
|
70 |
+
If you want the highest chance to reproduce these results, you'll want to checkout the
|
71 |
+
commit the agent was trained on: [4706d8d](https://github.com/sgoodfriend/rl-algo-impls/tree/4706d8dbb99b38e70d080c3de68d0751ea585a2f). While
|
72 |
+
training is deterministic, different hardware will give different results.
|
73 |
+
|
74 |
+
```
|
75 |
+
python train.py --algo ppo --env MicrortsDefeatCoacAIShaped-v3 --seed 3
|
76 |
+
```
|
77 |
+
|
78 |
+
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
79 |
+
Colab starting from the
|
80 |
+
[colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
|
81 |
+
notebook.
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
## Benchmarking (with Lambda Labs instance)
|
86 |
+
This and other models from https://api.wandb.ai/links/sgoodfriend/lf7j0hrv were generated by running a script on a Lambda
|
87 |
+
Labs instance. In a Lambda Labs instance terminal:
|
88 |
+
```
|
89 |
+
git clone git@github.com:sgoodfriend/rl-algo-impls.git
|
90 |
+
cd rl-algo-impls
|
91 |
+
bash ./lambda_labs/setup.sh
|
92 |
+
wandb login
|
93 |
+
bash ./lambda_labs/benchmark.sh [-a {"ppo a2c dqn vpg"}] [-e ENVS] [-j {6}] [-p {rl-algo-impls-benchmarks}] [-s {"1 2 3"}]
|
94 |
+
```
|
95 |
+
|
96 |
+
### Alternative: Google Colab Pro+
|
97 |
+
As an alternative,
|
98 |
+
[colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
|
99 |
+
can be used. However, this requires a Google Colab Pro+ subscription and running across
|
100 |
+
4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
## Hyperparameters
|
105 |
+
This isn't exactly the format of hyperparams in hyperparams/ppo.yml, but instead the Wandb Run Config. However, it's very
|
106 |
+
close and has some additional data:
|
107 |
+
```
|
108 |
+
additional_keys_to_log:
|
109 |
+
- microrts_stats
|
110 |
+
algo: ppo
|
111 |
+
algo_hyperparams:
|
112 |
+
batch_size: 3072
|
113 |
+
clip_range: 0.1
|
114 |
+
clip_range_decay: none
|
115 |
+
clip_range_vf: 0.1
|
116 |
+
ent_coef: 0.01
|
117 |
+
learning_rate: 0.00025
|
118 |
+
learning_rate_decay: spike
|
119 |
+
max_grad_norm: 0.5
|
120 |
+
n_epochs: 4
|
121 |
+
n_steps: 512
|
122 |
+
ppo2_vf_coef_halving: true
|
123 |
+
vf_coef: 0.5
|
124 |
+
device: auto
|
125 |
+
env: Microrts-selfplay-unet
|
126 |
+
env_hyperparams:
|
127 |
+
env_type: microrts
|
128 |
+
make_kwargs:
|
129 |
+
map_paths:
|
130 |
+
- maps/16x16/basesWorkers16x16.xml
|
131 |
+
max_steps: 2000
|
132 |
+
num_selfplay_envs: 36
|
133 |
+
render_theme: 2
|
134 |
+
reward_weight:
|
135 |
+
- 10
|
136 |
+
- 1
|
137 |
+
- 1
|
138 |
+
- 0.2
|
139 |
+
- 1
|
140 |
+
- 4
|
141 |
+
n_envs: 24
|
142 |
+
self_play_kwargs:
|
143 |
+
num_old_policies: 12
|
144 |
+
save_steps: 200000
|
145 |
+
swap_steps: 10000
|
146 |
+
swap_window_size: 4
|
147 |
+
window: 25
|
148 |
+
env_id: MicrortsDefeatCoacAIShaped-v3
|
149 |
+
eval_hyperparams:
|
150 |
+
deterministic: false
|
151 |
+
env_overrides:
|
152 |
+
bots:
|
153 |
+
coacAI: 2
|
154 |
+
droplet: 2
|
155 |
+
guidedRojoA3N: 2
|
156 |
+
izanagi: 2
|
157 |
+
lightRushAI: 2
|
158 |
+
mixedBot: 2
|
159 |
+
naiveMCTSAI: 2
|
160 |
+
passiveAI: 2
|
161 |
+
randomAI: 2
|
162 |
+
randomBiasedAI: 2
|
163 |
+
rojo: 2
|
164 |
+
tiamat: 2
|
165 |
+
workerRushAI: 2
|
166 |
+
make_kwargs:
|
167 |
+
map_paths:
|
168 |
+
- maps/16x16/basesWorkers16x16.xml
|
169 |
+
max_steps: 4000
|
170 |
+
num_selfplay_envs: 0
|
171 |
+
render_theme: 2
|
172 |
+
reward_weight:
|
173 |
+
- 1
|
174 |
+
- 0
|
175 |
+
- 0
|
176 |
+
- 0
|
177 |
+
- 0
|
178 |
+
- 0
|
179 |
+
n_envs: 26
|
180 |
+
self_play_kwargs: {}
|
181 |
+
max_video_length: 4000
|
182 |
+
n_episodes: 26
|
183 |
+
score_function: mean
|
184 |
+
step_freq: 1000000
|
185 |
+
microrts_reward_decay_callback: false
|
186 |
+
n_timesteps: 300000000
|
187 |
+
policy_hyperparams:
|
188 |
+
activation_fn: relu
|
189 |
+
actor_head_style: unet
|
190 |
+
cnn_flatten_dim: 256
|
191 |
+
cnn_style: microrts
|
192 |
+
v_hidden_sizes:
|
193 |
+
- 256
|
194 |
+
- 128
|
195 |
+
seed: 3
|
196 |
+
use_deterministic_algorithms: true
|
197 |
+
wandb_entity: null
|
198 |
+
wandb_group: null
|
199 |
+
wandb_project_name: rl-algo-impls-benchmarks
|
200 |
+
wandb_tags:
|
201 |
+
- benchmark_4706d8d
|
202 |
+
- host_192-9-146-21
|
203 |
+
- branch_selfplay
|
204 |
+
- v0.0.9
|
205 |
+
|
206 |
+
```
|
benchmark_publish.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from rl_algo_impls.benchmark_publish import benchmark_publish
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
benchmark_publish()
|
colab/colab_atari1.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ALGO="ppo"
|
2 |
+
ENVS="PongNoFrameskip-v4 BreakoutNoFrameskip-v4"
|
3 |
+
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
|
4 |
+
bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
colab/colab_atari2.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ALGO="ppo"
|
2 |
+
ENVS="SpaceInvadersNoFrameskip-v4 QbertNoFrameskip-v4"
|
3 |
+
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
|
4 |
+
bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
colab/colab_basic.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ALGO="ppo"
|
2 |
+
ENVS="CartPole-v1 MountainCar-v0 MountainCarContinuous-v0 Acrobot-v1 LunarLander-v2"
|
3 |
+
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
|
4 |
+
bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
colab/colab_benchmark.ipynb
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"colab": {
|
6 |
+
"provenance": [],
|
7 |
+
"machine_shape": "hm",
|
8 |
+
"authorship_tag": "ABX9TyOGIH7rqgasim3Sz7b1rpoE",
|
9 |
+
"include_colab_link": true
|
10 |
+
},
|
11 |
+
"kernelspec": {
|
12 |
+
"name": "python3",
|
13 |
+
"display_name": "Python 3"
|
14 |
+
},
|
15 |
+
"language_info": {
|
16 |
+
"name": "python"
|
17 |
+
},
|
18 |
+
"gpuClass": "standard",
|
19 |
+
"accelerator": "GPU"
|
20 |
+
},
|
21 |
+
"cells": [
|
22 |
+
{
|
23 |
+
"cell_type": "markdown",
|
24 |
+
"metadata": {
|
25 |
+
"id": "view-in-github",
|
26 |
+
"colab_type": "text"
|
27 |
+
},
|
28 |
+
"source": [
|
29 |
+
"<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/benchmarks/colab_benchmark.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
30 |
+
]
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "markdown",
|
34 |
+
"source": [
|
35 |
+
"# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
|
36 |
+
"## Parameters\n",
|
37 |
+
"\n",
|
38 |
+
"\n",
|
39 |
+
"1. Wandb\n",
|
40 |
+
"\n"
|
41 |
+
],
|
42 |
+
"metadata": {
|
43 |
+
"id": "S-tXDWP8WTLc"
|
44 |
+
}
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"cell_type": "code",
|
48 |
+
"source": [
|
49 |
+
"from getpass import getpass\n",
|
50 |
+
"import os\n",
|
51 |
+
"os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
|
52 |
+
],
|
53 |
+
"metadata": {
|
54 |
+
"id": "1ZtdYgxWNGwZ"
|
55 |
+
},
|
56 |
+
"execution_count": null,
|
57 |
+
"outputs": []
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "markdown",
|
61 |
+
"source": [
|
62 |
+
"## Setup\n",
|
63 |
+
"Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
|
64 |
+
],
|
65 |
+
"metadata": {
|
66 |
+
"id": "bsG35Io0hmKG"
|
67 |
+
}
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "code",
|
71 |
+
"source": [
|
72 |
+
"%%capture\n",
|
73 |
+
"!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
|
74 |
+
],
|
75 |
+
"metadata": {
|
76 |
+
"id": "k5ynTV25hdAf"
|
77 |
+
},
|
78 |
+
"execution_count": null,
|
79 |
+
"outputs": []
|
80 |
+
},
|
81 |
+
{
|
82 |
+
"cell_type": "markdown",
|
83 |
+
"source": [
|
84 |
+
"Installing the correct packages:\n",
|
85 |
+
"\n",
|
86 |
+
"While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
|
87 |
+
],
|
88 |
+
"metadata": {
|
89 |
+
"id": "jKxGok-ElYQ7"
|
90 |
+
}
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"cell_type": "code",
|
94 |
+
"source": [
|
95 |
+
"%%capture\n",
|
96 |
+
"!apt install python-opengl\n",
|
97 |
+
"!apt install ffmpeg\n",
|
98 |
+
"!apt install xvfb\n",
|
99 |
+
"!apt install swig"
|
100 |
+
],
|
101 |
+
"metadata": {
|
102 |
+
"id": "nn6EETTc2Ewf"
|
103 |
+
},
|
104 |
+
"execution_count": null,
|
105 |
+
"outputs": []
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"source": [
|
110 |
+
"%%capture\n",
|
111 |
+
"%cd /content/rl-algo-impls\n",
|
112 |
+
"python -m pip install ."
|
113 |
+
],
|
114 |
+
"metadata": {
|
115 |
+
"id": "AfZh9rH3yQii"
|
116 |
+
},
|
117 |
+
"execution_count": null,
|
118 |
+
"outputs": []
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "markdown",
|
122 |
+
"source": [
|
123 |
+
"## Run Once Per Runtime"
|
124 |
+
],
|
125 |
+
"metadata": {
|
126 |
+
"id": "4o5HOLjc4wq7"
|
127 |
+
}
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"cell_type": "code",
|
131 |
+
"source": [
|
132 |
+
"import wandb\n",
|
133 |
+
"wandb.login()"
|
134 |
+
],
|
135 |
+
"metadata": {
|
136 |
+
"id": "PCXa5tdS2qFX"
|
137 |
+
},
|
138 |
+
"execution_count": null,
|
139 |
+
"outputs": []
|
140 |
+
},
|
141 |
+
{
|
142 |
+
"cell_type": "markdown",
|
143 |
+
"source": [
|
144 |
+
"## Restart Session beteween runs"
|
145 |
+
],
|
146 |
+
"metadata": {
|
147 |
+
"id": "AZBZfSUV43JQ"
|
148 |
+
}
|
149 |
+
},
|
150 |
+
{
|
151 |
+
"cell_type": "code",
|
152 |
+
"source": [
|
153 |
+
"%%capture\n",
|
154 |
+
"from pyvirtualdisplay import Display\n",
|
155 |
+
"\n",
|
156 |
+
"virtual_display = Display(visible=0, size=(1400, 900))\n",
|
157 |
+
"virtual_display.start()"
|
158 |
+
],
|
159 |
+
"metadata": {
|
160 |
+
"id": "VzemeQJP2NO9"
|
161 |
+
},
|
162 |
+
"execution_count": null,
|
163 |
+
"outputs": []
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"cell_type": "markdown",
|
167 |
+
"source": [
|
168 |
+
"The below 5 bash scripts train agents on environments with 3 seeds each:\n",
|
169 |
+
"- colab_basic.sh and colab_pybullet.sh test on a set of basic gym environments and 4 PyBullet environments. Running both together will likely take about 18 hours. This is likely to run into runtime limits for free Colab and Colab Pro, but is fine for Colab Pro+.\n",
|
170 |
+
"- colab_carracing.sh only trains 3 seeds on CarRacing-v0, which takes almost 22 hours on Colab Pro+ on high-RAM, standard GPU.\n",
|
171 |
+
"- colab_atari1.sh and colab_atari2.sh likely need to be run separately because each takes about 19 hours on high-RAM, standard GPU."
|
172 |
+
],
|
173 |
+
"metadata": {
|
174 |
+
"id": "nSHfna0hLlO1"
|
175 |
+
}
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"cell_type": "code",
|
179 |
+
"source": [
|
180 |
+
"%cd /content/rl-algo-impls\n",
|
181 |
+
"os.environ[\"BENCHMARK_MAX_PROCS\"] = str(1) # Can't reliably raise this to 2+, but would make it faster.\n",
|
182 |
+
"!./benchmarks/colab_basic.sh\n",
|
183 |
+
"!./benchmarks/colab_pybullet.sh\n",
|
184 |
+
"# !./benchmarks/colab_carracing.sh\n",
|
185 |
+
"# !./benchmarks/colab_atari1.sh\n",
|
186 |
+
"# !./benchmarks/colab_atari2.sh"
|
187 |
+
],
|
188 |
+
"metadata": {
|
189 |
+
"id": "07aHYFH1zfXa"
|
190 |
+
},
|
191 |
+
"execution_count": null,
|
192 |
+
"outputs": []
|
193 |
+
}
|
194 |
+
]
|
195 |
+
}
|
colab/colab_carracing.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ALGO="ppo"
|
2 |
+
ENVS="CarRacing-v0"
|
3 |
+
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
|
4 |
+
bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
colab/colab_enjoy.ipynb
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"colab": {
|
6 |
+
"provenance": [],
|
7 |
+
"machine_shape": "hm",
|
8 |
+
"authorship_tag": "ABX9TyN6S7kyJKrM5x0OOiN+CgTc",
|
9 |
+
"include_colab_link": true
|
10 |
+
},
|
11 |
+
"kernelspec": {
|
12 |
+
"name": "python3",
|
13 |
+
"display_name": "Python 3"
|
14 |
+
},
|
15 |
+
"language_info": {
|
16 |
+
"name": "python"
|
17 |
+
},
|
18 |
+
"gpuClass": "standard",
|
19 |
+
"accelerator": "GPU"
|
20 |
+
},
|
21 |
+
"cells": [
|
22 |
+
{
|
23 |
+
"cell_type": "markdown",
|
24 |
+
"metadata": {
|
25 |
+
"id": "view-in-github",
|
26 |
+
"colab_type": "text"
|
27 |
+
},
|
28 |
+
"source": [
|
29 |
+
"<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
30 |
+
]
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "markdown",
|
34 |
+
"source": [
|
35 |
+
"# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
|
36 |
+
"## Parameters\n",
|
37 |
+
"\n",
|
38 |
+
"\n",
|
39 |
+
"1. Wandb\n",
|
40 |
+
"\n"
|
41 |
+
],
|
42 |
+
"metadata": {
|
43 |
+
"id": "S-tXDWP8WTLc"
|
44 |
+
}
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"cell_type": "code",
|
48 |
+
"source": [
|
49 |
+
"from getpass import getpass\n",
|
50 |
+
"import os\n",
|
51 |
+
"os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
|
52 |
+
],
|
53 |
+
"metadata": {
|
54 |
+
"id": "1ZtdYgxWNGwZ"
|
55 |
+
},
|
56 |
+
"execution_count": null,
|
57 |
+
"outputs": []
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "markdown",
|
61 |
+
"source": [
|
62 |
+
"2. enjoy.py parameters"
|
63 |
+
],
|
64 |
+
"metadata": {
|
65 |
+
"id": "ao0nAh3MOdN7"
|
66 |
+
}
|
67 |
+
},
|
68 |
+
{
|
69 |
+
"cell_type": "code",
|
70 |
+
"source": [
|
71 |
+
"WANDB_RUN_PATH=\"sgoodfriend/rl-algo-impls-benchmarks/rd0lisee\""
|
72 |
+
],
|
73 |
+
"metadata": {
|
74 |
+
"id": "jKL_NFhVOjSc"
|
75 |
+
},
|
76 |
+
"execution_count": 2,
|
77 |
+
"outputs": []
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"cell_type": "markdown",
|
81 |
+
"source": [
|
82 |
+
"## Setup\n",
|
83 |
+
"Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
|
84 |
+
],
|
85 |
+
"metadata": {
|
86 |
+
"id": "bsG35Io0hmKG"
|
87 |
+
}
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"cell_type": "code",
|
91 |
+
"source": [
|
92 |
+
"%%capture\n",
|
93 |
+
"!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
|
94 |
+
],
|
95 |
+
"metadata": {
|
96 |
+
"id": "k5ynTV25hdAf"
|
97 |
+
},
|
98 |
+
"execution_count": 3,
|
99 |
+
"outputs": []
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"cell_type": "markdown",
|
103 |
+
"source": [
|
104 |
+
"Installing the correct packages:\n",
|
105 |
+
"\n",
|
106 |
+
"While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
|
107 |
+
],
|
108 |
+
"metadata": {
|
109 |
+
"id": "jKxGok-ElYQ7"
|
110 |
+
}
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"cell_type": "code",
|
114 |
+
"source": [
|
115 |
+
"%%capture\n",
|
116 |
+
"!apt install python-opengl\n",
|
117 |
+
"!apt install ffmpeg\n",
|
118 |
+
"!apt install xvfb\n",
|
119 |
+
"!apt install swig"
|
120 |
+
],
|
121 |
+
"metadata": {
|
122 |
+
"id": "nn6EETTc2Ewf"
|
123 |
+
},
|
124 |
+
"execution_count": 4,
|
125 |
+
"outputs": []
|
126 |
+
},
|
127 |
+
{
|
128 |
+
"cell_type": "code",
|
129 |
+
"source": [
|
130 |
+
"%%capture\n",
|
131 |
+
"%cd /content/rl-algo-impls\n",
|
132 |
+
"python -m pip install ."
|
133 |
+
],
|
134 |
+
"metadata": {
|
135 |
+
"id": "AfZh9rH3yQii"
|
136 |
+
},
|
137 |
+
"execution_count": 5,
|
138 |
+
"outputs": []
|
139 |
+
},
|
140 |
+
{
|
141 |
+
"cell_type": "markdown",
|
142 |
+
"source": [
|
143 |
+
"## Run Once Per Runtime"
|
144 |
+
],
|
145 |
+
"metadata": {
|
146 |
+
"id": "4o5HOLjc4wq7"
|
147 |
+
}
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"cell_type": "code",
|
151 |
+
"source": [
|
152 |
+
"import wandb\n",
|
153 |
+
"wandb.login()"
|
154 |
+
],
|
155 |
+
"metadata": {
|
156 |
+
"id": "PCXa5tdS2qFX"
|
157 |
+
},
|
158 |
+
"execution_count": null,
|
159 |
+
"outputs": []
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"cell_type": "markdown",
|
163 |
+
"source": [
|
164 |
+
"## Restart Session beteween runs"
|
165 |
+
],
|
166 |
+
"metadata": {
|
167 |
+
"id": "AZBZfSUV43JQ"
|
168 |
+
}
|
169 |
+
},
|
170 |
+
{
|
171 |
+
"cell_type": "code",
|
172 |
+
"source": [
|
173 |
+
"%%capture\n",
|
174 |
+
"from pyvirtualdisplay import Display\n",
|
175 |
+
"\n",
|
176 |
+
"virtual_display = Display(visible=0, size=(1400, 900))\n",
|
177 |
+
"virtual_display.start()"
|
178 |
+
],
|
179 |
+
"metadata": {
|
180 |
+
"id": "VzemeQJP2NO9"
|
181 |
+
},
|
182 |
+
"execution_count": 7,
|
183 |
+
"outputs": []
|
184 |
+
},
|
185 |
+
{
|
186 |
+
"cell_type": "code",
|
187 |
+
"source": [
|
188 |
+
"%cd /content/rl-algo-impls\n",
|
189 |
+
"!python enjoy.py --wandb-run-path={WANDB_RUN_PATH}"
|
190 |
+
],
|
191 |
+
"metadata": {
|
192 |
+
"id": "07aHYFH1zfXa"
|
193 |
+
},
|
194 |
+
"execution_count": null,
|
195 |
+
"outputs": []
|
196 |
+
}
|
197 |
+
]
|
198 |
+
}
|
colab/colab_pybullet.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ALGO="ppo"
|
2 |
+
ENVS="HalfCheetahBulletEnv-v0 AntBulletEnv-v0 HopperBulletEnv-v0 Walker2DBulletEnv-v0"
|
3 |
+
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
|
4 |
+
bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
colab/colab_train.ipynb
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"colab": {
|
6 |
+
"provenance": [],
|
7 |
+
"machine_shape": "hm",
|
8 |
+
"authorship_tag": "ABX9TyMmemQnx6G7GOnn6XBdjgxY",
|
9 |
+
"include_colab_link": true
|
10 |
+
},
|
11 |
+
"kernelspec": {
|
12 |
+
"name": "python3",
|
13 |
+
"display_name": "Python 3"
|
14 |
+
},
|
15 |
+
"language_info": {
|
16 |
+
"name": "python"
|
17 |
+
},
|
18 |
+
"gpuClass": "standard",
|
19 |
+
"accelerator": "GPU"
|
20 |
+
},
|
21 |
+
"cells": [
|
22 |
+
{
|
23 |
+
"cell_type": "markdown",
|
24 |
+
"metadata": {
|
25 |
+
"id": "view-in-github",
|
26 |
+
"colab_type": "text"
|
27 |
+
},
|
28 |
+
"source": [
|
29 |
+
"<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
30 |
+
]
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "markdown",
|
34 |
+
"source": [
|
35 |
+
"# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
|
36 |
+
"## Parameters\n",
|
37 |
+
"\n",
|
38 |
+
"\n",
|
39 |
+
"1. Wandb\n",
|
40 |
+
"\n"
|
41 |
+
],
|
42 |
+
"metadata": {
|
43 |
+
"id": "S-tXDWP8WTLc"
|
44 |
+
}
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"cell_type": "code",
|
48 |
+
"source": [
|
49 |
+
"from getpass import getpass\n",
|
50 |
+
"import os\n",
|
51 |
+
"os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
|
52 |
+
],
|
53 |
+
"metadata": {
|
54 |
+
"id": "1ZtdYgxWNGwZ"
|
55 |
+
},
|
56 |
+
"execution_count": null,
|
57 |
+
"outputs": []
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "markdown",
|
61 |
+
"source": [
|
62 |
+
"2. train run parameters"
|
63 |
+
],
|
64 |
+
"metadata": {
|
65 |
+
"id": "ao0nAh3MOdN7"
|
66 |
+
}
|
67 |
+
},
|
68 |
+
{
|
69 |
+
"cell_type": "code",
|
70 |
+
"source": [
|
71 |
+
"ALGO = \"ppo\"\n",
|
72 |
+
"ENV = \"CartPole-v1\"\n",
|
73 |
+
"SEED = 1"
|
74 |
+
],
|
75 |
+
"metadata": {
|
76 |
+
"id": "jKL_NFhVOjSc"
|
77 |
+
},
|
78 |
+
"execution_count": null,
|
79 |
+
"outputs": []
|
80 |
+
},
|
81 |
+
{
|
82 |
+
"cell_type": "markdown",
|
83 |
+
"source": [
|
84 |
+
"## Setup\n",
|
85 |
+
"Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
|
86 |
+
],
|
87 |
+
"metadata": {
|
88 |
+
"id": "bsG35Io0hmKG"
|
89 |
+
}
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "code",
|
93 |
+
"source": [
|
94 |
+
"%%capture\n",
|
95 |
+
"!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
|
96 |
+
],
|
97 |
+
"metadata": {
|
98 |
+
"id": "k5ynTV25hdAf"
|
99 |
+
},
|
100 |
+
"execution_count": null,
|
101 |
+
"outputs": []
|
102 |
+
},
|
103 |
+
{
|
104 |
+
"cell_type": "markdown",
|
105 |
+
"source": [
|
106 |
+
"Installing the correct packages:\n",
|
107 |
+
"\n",
|
108 |
+
"While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
|
109 |
+
],
|
110 |
+
"metadata": {
|
111 |
+
"id": "jKxGok-ElYQ7"
|
112 |
+
}
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"cell_type": "code",
|
116 |
+
"source": [
|
117 |
+
"%%capture\n",
|
118 |
+
"!apt install python-opengl\n",
|
119 |
+
"!apt install ffmpeg\n",
|
120 |
+
"!apt install xvfb\n",
|
121 |
+
"!apt install swig"
|
122 |
+
],
|
123 |
+
"metadata": {
|
124 |
+
"id": "nn6EETTc2Ewf"
|
125 |
+
},
|
126 |
+
"execution_count": null,
|
127 |
+
"outputs": []
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"cell_type": "code",
|
131 |
+
"source": [
|
132 |
+
"%%capture\n",
|
133 |
+
"%cd /content/rl-algo-impls\n",
|
134 |
+
"python -m pip install ."
|
135 |
+
],
|
136 |
+
"metadata": {
|
137 |
+
"id": "AfZh9rH3yQii"
|
138 |
+
},
|
139 |
+
"execution_count": null,
|
140 |
+
"outputs": []
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "markdown",
|
144 |
+
"source": [
|
145 |
+
"## Run Once Per Runtime"
|
146 |
+
],
|
147 |
+
"metadata": {
|
148 |
+
"id": "4o5HOLjc4wq7"
|
149 |
+
}
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"cell_type": "code",
|
153 |
+
"source": [
|
154 |
+
"import wandb\n",
|
155 |
+
"wandb.login()"
|
156 |
+
],
|
157 |
+
"metadata": {
|
158 |
+
"id": "PCXa5tdS2qFX"
|
159 |
+
},
|
160 |
+
"execution_count": null,
|
161 |
+
"outputs": []
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"cell_type": "markdown",
|
165 |
+
"source": [
|
166 |
+
"## Restart Session beteween runs"
|
167 |
+
],
|
168 |
+
"metadata": {
|
169 |
+
"id": "AZBZfSUV43JQ"
|
170 |
+
}
|
171 |
+
},
|
172 |
+
{
|
173 |
+
"cell_type": "code",
|
174 |
+
"source": [
|
175 |
+
"%%capture\n",
|
176 |
+
"from pyvirtualdisplay import Display\n",
|
177 |
+
"\n",
|
178 |
+
"virtual_display = Display(visible=0, size=(1400, 900))\n",
|
179 |
+
"virtual_display.start()"
|
180 |
+
],
|
181 |
+
"metadata": {
|
182 |
+
"id": "VzemeQJP2NO9"
|
183 |
+
},
|
184 |
+
"execution_count": null,
|
185 |
+
"outputs": []
|
186 |
+
},
|
187 |
+
{
|
188 |
+
"cell_type": "code",
|
189 |
+
"source": [
|
190 |
+
"%cd /content/rl-algo-impls\n",
|
191 |
+
"!python train.py --algo {ALGO} --env {ENV} --seed {SEED}"
|
192 |
+
],
|
193 |
+
"metadata": {
|
194 |
+
"id": "07aHYFH1zfXa"
|
195 |
+
},
|
196 |
+
"execution_count": null,
|
197 |
+
"outputs": []
|
198 |
+
}
|
199 |
+
]
|
200 |
+
}
|
compare_runs.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from rl_algo_impls.compare_runs import compare_runs
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
compare_runs()
|
enjoy.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from rl_algo_impls.enjoy import enjoy
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
enjoy()
|
environment.yml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: rl_algo_impls
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- conda-forge
|
5 |
+
- nodefaults
|
6 |
+
dependencies:
|
7 |
+
- python>=3.8, <3.10
|
8 |
+
- mamba
|
9 |
+
- pip
|
10 |
+
- pytorch
|
11 |
+
- torchvision
|
12 |
+
- torchaudio
|
huggingface_publish.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from rl_algo_impls.huggingface_publish import huggingface_publish
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
huggingface_publish()
|
optimize.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from rl_algo_impls.optimize import optimize
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
optimize()
|
pyproject.toml
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[project]
|
2 |
+
name = "rl_algo_impls"
|
3 |
+
version = "0.0.9"
|
4 |
+
description = "Implementations of reinforcement learning algorithms"
|
5 |
+
authors = [
|
6 |
+
{name = "Scott Goodfriend", email = "goodfriend.scott@gmail.com"},
|
7 |
+
]
|
8 |
+
license = {file = "LICENSE"}
|
9 |
+
readme = "README.md"
|
10 |
+
requires-python = ">= 3.8"
|
11 |
+
classifiers = [
|
12 |
+
"License :: OSI Approved :: MIT License",
|
13 |
+
"Development Status :: 3 - Alpha",
|
14 |
+
"Programming Language :: Python :: 3.8",
|
15 |
+
"Programming Language :: Python :: 3.9",
|
16 |
+
"Programming Language :: Python :: 3.10",
|
17 |
+
]
|
18 |
+
dependencies = [
|
19 |
+
"cmake",
|
20 |
+
"swig",
|
21 |
+
"scipy",
|
22 |
+
"torch",
|
23 |
+
"torchvision",
|
24 |
+
"tensorboard >= 2.11.2, < 2.12",
|
25 |
+
"AutoROM.accept-rom-license >= 0.4.2, < 0.5",
|
26 |
+
"stable-baselines3[extra] >= 1.7.0, < 1.8",
|
27 |
+
"gym[box2d] >= 0.21.0, < 0.22",
|
28 |
+
"pyglet == 1.5.27",
|
29 |
+
"wandb",
|
30 |
+
"pyvirtualdisplay",
|
31 |
+
"pybullet",
|
32 |
+
"tabulate",
|
33 |
+
"huggingface-hub",
|
34 |
+
"optuna",
|
35 |
+
"dash",
|
36 |
+
"kaleido",
|
37 |
+
"PyYAML",
|
38 |
+
"scikit-learn",
|
39 |
+
]
|
40 |
+
|
41 |
+
[tool.setuptools]
|
42 |
+
packages = ["rl_algo_impls"]
|
43 |
+
|
44 |
+
[project.optional-dependencies]
|
45 |
+
test = [
|
46 |
+
"pytest",
|
47 |
+
"black",
|
48 |
+
"mypy",
|
49 |
+
"flake8",
|
50 |
+
"flake8-bugbear",
|
51 |
+
"isort",
|
52 |
+
]
|
53 |
+
procgen = [
|
54 |
+
"numexpr >= 2.8.4",
|
55 |
+
"gym3",
|
56 |
+
"glfw >= 1.12.0, < 1.13",
|
57 |
+
"procgen; platform_machine=='x86_64'",
|
58 |
+
]
|
59 |
+
microrts-ppo = [
|
60 |
+
"numpy < 1.24.0", # Support for gym-microrts < 0.6.0
|
61 |
+
"gym-microrts == 0.2.0", # Match ppo-implementation-details
|
62 |
+
]
|
63 |
+
microrts-paper = [
|
64 |
+
"numpy < 1.24.0", # Support for gym-microrts < 0.6.0
|
65 |
+
"gym-microrts == 0.3.2",
|
66 |
+
]
|
67 |
+
microrts = [
|
68 |
+
"gym-microrts",
|
69 |
+
]
|
70 |
+
jupyter = [
|
71 |
+
"jupyter",
|
72 |
+
"notebook"
|
73 |
+
]
|
74 |
+
all = [
|
75 |
+
"rl-algo-impls[test]",
|
76 |
+
"rl-algo-impls[procgen]",
|
77 |
+
"rl-algo-impls[microrts]",
|
78 |
+
]
|
79 |
+
|
80 |
+
[project.urls]
|
81 |
+
"Homepage" = "https://github.com/sgoodfriend/rl-algo-impls"
|
82 |
+
|
83 |
+
[build-system]
|
84 |
+
requires = ["setuptools==65.5.0", "setuptools-scm"]
|
85 |
+
build-backend = "setuptools.build_meta"
|
86 |
+
|
87 |
+
[tool.isort]
|
88 |
+
profile = "black"
|
replay.meta.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "640x640", "-pix_fmt", "rgb24", "-framerate", "150", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "150", "/tmp/tmpiof1jwbr/ppo-Microrts-selfplay-unet/replay.mp4"]}, "episode": {"r": 1.0, "l": 794, "t": 10.908112}}
|
replay.mp4
ADDED
Binary file (322 kB). View file
|
|
rl_algo_impls/a2c/a2c.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from time import perf_counter
|
3 |
+
from typing import List, Optional, TypeVar
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch.utils.tensorboard.writer import SummaryWriter
|
10 |
+
|
11 |
+
from rl_algo_impls.shared.algorithm import Algorithm
|
12 |
+
from rl_algo_impls.shared.callbacks import Callback
|
13 |
+
from rl_algo_impls.shared.gae import compute_advantages
|
14 |
+
from rl_algo_impls.shared.policy.actor_critic import ActorCritic
|
15 |
+
from rl_algo_impls.shared.schedule import schedule, update_learning_rate
|
16 |
+
from rl_algo_impls.shared.stats import log_scalars
|
17 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
18 |
+
VecEnv,
|
19 |
+
single_action_space,
|
20 |
+
single_observation_space,
|
21 |
+
)
|
22 |
+
|
23 |
+
A2CSelf = TypeVar("A2CSelf", bound="A2C")
|
24 |
+
|
25 |
+
|
26 |
+
class A2C(Algorithm):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
policy: ActorCritic,
|
30 |
+
env: VecEnv,
|
31 |
+
device: torch.device,
|
32 |
+
tb_writer: SummaryWriter,
|
33 |
+
learning_rate: float = 7e-4,
|
34 |
+
learning_rate_decay: str = "none",
|
35 |
+
n_steps: int = 5,
|
36 |
+
gamma: float = 0.99,
|
37 |
+
gae_lambda: float = 1.0,
|
38 |
+
ent_coef: float = 0.0,
|
39 |
+
ent_coef_decay: str = "none",
|
40 |
+
vf_coef: float = 0.5,
|
41 |
+
max_grad_norm: float = 0.5,
|
42 |
+
rms_prop_eps: float = 1e-5,
|
43 |
+
use_rms_prop: bool = True,
|
44 |
+
sde_sample_freq: int = -1,
|
45 |
+
normalize_advantage: bool = False,
|
46 |
+
) -> None:
|
47 |
+
super().__init__(policy, env, device, tb_writer)
|
48 |
+
self.policy = policy
|
49 |
+
|
50 |
+
self.lr_schedule = schedule(learning_rate_decay, learning_rate)
|
51 |
+
if use_rms_prop:
|
52 |
+
self.optimizer = torch.optim.RMSprop(
|
53 |
+
policy.parameters(), lr=learning_rate, eps=rms_prop_eps
|
54 |
+
)
|
55 |
+
else:
|
56 |
+
self.optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate)
|
57 |
+
|
58 |
+
self.n_steps = n_steps
|
59 |
+
|
60 |
+
self.gamma = gamma
|
61 |
+
self.gae_lambda = gae_lambda
|
62 |
+
|
63 |
+
self.vf_coef = vf_coef
|
64 |
+
self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
|
65 |
+
self.max_grad_norm = max_grad_norm
|
66 |
+
|
67 |
+
self.sde_sample_freq = sde_sample_freq
|
68 |
+
self.normalize_advantage = normalize_advantage
|
69 |
+
|
70 |
+
def learn(
|
71 |
+
self: A2CSelf,
|
72 |
+
train_timesteps: int,
|
73 |
+
callbacks: Optional[List[Callback]] = None,
|
74 |
+
total_timesteps: Optional[int] = None,
|
75 |
+
start_timesteps: int = 0,
|
76 |
+
) -> A2CSelf:
|
77 |
+
if total_timesteps is None:
|
78 |
+
total_timesteps = train_timesteps
|
79 |
+
assert start_timesteps + train_timesteps <= total_timesteps
|
80 |
+
epoch_dim = (self.n_steps, self.env.num_envs)
|
81 |
+
step_dim = (self.env.num_envs,)
|
82 |
+
obs_space = single_observation_space(self.env)
|
83 |
+
act_space = single_action_space(self.env)
|
84 |
+
|
85 |
+
obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)
|
86 |
+
actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)
|
87 |
+
rewards = np.zeros(epoch_dim, dtype=np.float32)
|
88 |
+
episode_starts = np.zeros(epoch_dim, dtype=np.bool8)
|
89 |
+
values = np.zeros(epoch_dim, dtype=np.float32)
|
90 |
+
logprobs = np.zeros(epoch_dim, dtype=np.float32)
|
91 |
+
|
92 |
+
next_obs = self.env.reset()
|
93 |
+
next_episode_starts = np.full(step_dim, True, dtype=np.bool8)
|
94 |
+
|
95 |
+
timesteps_elapsed = start_timesteps
|
96 |
+
while timesteps_elapsed < start_timesteps + train_timesteps:
|
97 |
+
start_time = perf_counter()
|
98 |
+
|
99 |
+
progress = timesteps_elapsed / total_timesteps
|
100 |
+
ent_coef = self.ent_coef_schedule(progress)
|
101 |
+
learning_rate = self.lr_schedule(progress)
|
102 |
+
update_learning_rate(self.optimizer, learning_rate)
|
103 |
+
log_scalars(
|
104 |
+
self.tb_writer,
|
105 |
+
"charts",
|
106 |
+
{
|
107 |
+
"ent_coef": ent_coef,
|
108 |
+
"learning_rate": learning_rate,
|
109 |
+
},
|
110 |
+
timesteps_elapsed,
|
111 |
+
)
|
112 |
+
|
113 |
+
self.policy.eval()
|
114 |
+
self.policy.reset_noise()
|
115 |
+
for s in range(self.n_steps):
|
116 |
+
timesteps_elapsed += self.env.num_envs
|
117 |
+
if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
|
118 |
+
self.policy.reset_noise()
|
119 |
+
|
120 |
+
obs[s] = next_obs
|
121 |
+
episode_starts[s] = next_episode_starts
|
122 |
+
|
123 |
+
actions[s], values[s], logprobs[s], clamped_action = self.policy.step(
|
124 |
+
next_obs
|
125 |
+
)
|
126 |
+
next_obs, rewards[s], next_episode_starts, _ = self.env.step(
|
127 |
+
clamped_action
|
128 |
+
)
|
129 |
+
|
130 |
+
advantages = compute_advantages(
|
131 |
+
rewards,
|
132 |
+
values,
|
133 |
+
episode_starts,
|
134 |
+
next_episode_starts,
|
135 |
+
next_obs,
|
136 |
+
self.policy,
|
137 |
+
self.gamma,
|
138 |
+
self.gae_lambda,
|
139 |
+
)
|
140 |
+
returns = advantages + values
|
141 |
+
|
142 |
+
b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)
|
143 |
+
b_actions = torch.tensor(actions.reshape((-1,) + act_space.shape)).to(
|
144 |
+
self.device
|
145 |
+
)
|
146 |
+
b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
|
147 |
+
b_returns = torch.tensor(returns.reshape(-1)).to(self.device)
|
148 |
+
|
149 |
+
if self.normalize_advantage:
|
150 |
+
b_advantages = (b_advantages - b_advantages.mean()) / (
|
151 |
+
b_advantages.std() + 1e-8
|
152 |
+
)
|
153 |
+
|
154 |
+
self.policy.train()
|
155 |
+
logp_a, entropy, v = self.policy(b_obs, b_actions)
|
156 |
+
|
157 |
+
pi_loss = -(b_advantages * logp_a).mean()
|
158 |
+
value_loss = F.mse_loss(b_returns, v)
|
159 |
+
entropy_loss = -entropy.mean()
|
160 |
+
|
161 |
+
loss = pi_loss + self.vf_coef * value_loss + ent_coef * entropy_loss
|
162 |
+
|
163 |
+
self.optimizer.zero_grad()
|
164 |
+
loss.backward()
|
165 |
+
nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
166 |
+
self.optimizer.step()
|
167 |
+
|
168 |
+
y_pred = values.reshape(-1)
|
169 |
+
y_true = returns.reshape(-1)
|
170 |
+
var_y = np.var(y_true).item()
|
171 |
+
explained_var = (
|
172 |
+
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
|
173 |
+
)
|
174 |
+
|
175 |
+
end_time = perf_counter()
|
176 |
+
rollout_steps = self.n_steps * self.env.num_envs
|
177 |
+
self.tb_writer.add_scalar(
|
178 |
+
"train/steps_per_second",
|
179 |
+
(rollout_steps) / (end_time - start_time),
|
180 |
+
timesteps_elapsed,
|
181 |
+
)
|
182 |
+
|
183 |
+
log_scalars(
|
184 |
+
self.tb_writer,
|
185 |
+
"losses",
|
186 |
+
{
|
187 |
+
"loss": loss.item(),
|
188 |
+
"pi_loss": pi_loss.item(),
|
189 |
+
"v_loss": value_loss.item(),
|
190 |
+
"entropy_loss": entropy_loss.item(),
|
191 |
+
"explained_var": explained_var,
|
192 |
+
},
|
193 |
+
timesteps_elapsed,
|
194 |
+
)
|
195 |
+
|
196 |
+
if callbacks:
|
197 |
+
if not all(
|
198 |
+
c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
|
199 |
+
):
|
200 |
+
logging.info(
|
201 |
+
f"Callback terminated training at {timesteps_elapsed} timesteps"
|
202 |
+
)
|
203 |
+
break
|
204 |
+
|
205 |
+
return self
|
rl_algo_impls/a2c/optimize.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import optuna
|
2 |
+
|
3 |
+
from copy import deepcopy
|
4 |
+
|
5 |
+
from rl_algo_impls.runner.config import Config, Hyperparams, EnvHyperparams
|
6 |
+
from rl_algo_impls.shared.vec_env import make_eval_env
|
7 |
+
from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
|
8 |
+
from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
|
9 |
+
|
10 |
+
|
11 |
+
def sample_params(
|
12 |
+
trial: optuna.Trial,
|
13 |
+
base_hyperparams: Hyperparams,
|
14 |
+
base_config: Config,
|
15 |
+
) -> Hyperparams:
|
16 |
+
hyperparams = deepcopy(base_hyperparams)
|
17 |
+
|
18 |
+
base_env_hyperparams = EnvHyperparams(**hyperparams.env_hyperparams)
|
19 |
+
env = make_eval_env(base_config, base_env_hyperparams, override_n_envs=1)
|
20 |
+
|
21 |
+
# env_hyperparams
|
22 |
+
env_hyperparams = sample_env_hyperparams(trial, hyperparams.env_hyperparams, env)
|
23 |
+
|
24 |
+
# policy_hyperparams
|
25 |
+
policy_hyperparams = sample_on_policy_hyperparams(
|
26 |
+
trial, hyperparams.policy_hyperparams, env
|
27 |
+
)
|
28 |
+
|
29 |
+
# algo_hyperparams
|
30 |
+
algo_hyperparams = hyperparams.algo_hyperparams
|
31 |
+
|
32 |
+
learning_rate = trial.suggest_float("learning_rate", 1e-5, 2e-3, log=True)
|
33 |
+
learning_rate_decay = trial.suggest_categorical(
|
34 |
+
"learning_rate_decay", ["none", "linear"]
|
35 |
+
)
|
36 |
+
n_steps_exp = trial.suggest_int("n_steps_exp", 1, 10)
|
37 |
+
n_steps = 2**n_steps_exp
|
38 |
+
trial.set_user_attr("n_steps", n_steps)
|
39 |
+
gamma = 1.0 - trial.suggest_float("gamma_om", 1e-4, 1e-1, log=True)
|
40 |
+
trial.set_user_attr("gamma", gamma)
|
41 |
+
gae_lambda = 1 - trial.suggest_float("gae_lambda_om", 1e-4, 1e-1)
|
42 |
+
trial.set_user_attr("gae_lambda", gae_lambda)
|
43 |
+
ent_coef = trial.suggest_float("ent_coef", 1e-8, 2.5e-2, log=True)
|
44 |
+
ent_coef_decay = trial.suggest_categorical("ent_coef_decay", ["none", "linear"])
|
45 |
+
vf_coef = trial.suggest_float("vf_coef", 0.1, 0.7)
|
46 |
+
max_grad_norm = trial.suggest_float("max_grad_norm", 1e-1, 1e1, log=True)
|
47 |
+
use_rms_prop = trial.suggest_categorical("use_rms_prop", [True, False])
|
48 |
+
normalize_advantage = trial.suggest_categorical(
|
49 |
+
"normalize_advantage", [True, False]
|
50 |
+
)
|
51 |
+
|
52 |
+
algo_hyperparams.update(
|
53 |
+
{
|
54 |
+
"learning_rate": learning_rate,
|
55 |
+
"learning_rate_decay": learning_rate_decay,
|
56 |
+
"n_steps": n_steps,
|
57 |
+
"gamma": gamma,
|
58 |
+
"gae_lambda": gae_lambda,
|
59 |
+
"ent_coef": ent_coef,
|
60 |
+
"ent_coef_decay": ent_coef_decay,
|
61 |
+
"vf_coef": vf_coef,
|
62 |
+
"max_grad_norm": max_grad_norm,
|
63 |
+
"use_rms_prop": use_rms_prop,
|
64 |
+
"normalize_advantage": normalize_advantage,
|
65 |
+
}
|
66 |
+
)
|
67 |
+
|
68 |
+
if policy_hyperparams.get("use_sde", False):
|
69 |
+
sde_sample_freq = 2 ** trial.suggest_int("sde_sample_freq_exp", 0, n_steps_exp)
|
70 |
+
trial.set_user_attr("sde_sample_freq", sde_sample_freq)
|
71 |
+
algo_hyperparams["sde_sample_freq"] = sde_sample_freq
|
72 |
+
elif "sde_sample_freq" in algo_hyperparams:
|
73 |
+
del algo_hyperparams["sde_sample_freq"]
|
74 |
+
|
75 |
+
env.close()
|
76 |
+
|
77 |
+
return hyperparams
|
rl_algo_impls/benchmark_publish.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import subprocess
|
3 |
+
import wandb
|
4 |
+
import wandb.apis.public
|
5 |
+
|
6 |
+
from collections import defaultdict
|
7 |
+
from multiprocessing.pool import ThreadPool
|
8 |
+
from typing import List, NamedTuple
|
9 |
+
|
10 |
+
|
11 |
+
class RunGroup(NamedTuple):
|
12 |
+
algo: str
|
13 |
+
env_id: str
|
14 |
+
|
15 |
+
|
16 |
+
def benchmark_publish() -> None:
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument(
|
19 |
+
"--wandb-project-name",
|
20 |
+
type=str,
|
21 |
+
default="rl-algo-impls-benchmarks",
|
22 |
+
help="WandB project name to load runs from",
|
23 |
+
)
|
24 |
+
parser.add_argument(
|
25 |
+
"--wandb-entity",
|
26 |
+
type=str,
|
27 |
+
default=None,
|
28 |
+
help="WandB team of project. None uses default entity",
|
29 |
+
)
|
30 |
+
parser.add_argument("--wandb-tags", type=str, nargs="+", help="WandB tags")
|
31 |
+
parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
|
32 |
+
parser.add_argument(
|
33 |
+
"--envs", type=str, nargs="*", help="Optional filter down to these envs"
|
34 |
+
)
|
35 |
+
parser.add_argument(
|
36 |
+
"--exclude-envs",
|
37 |
+
type=str,
|
38 |
+
nargs="*",
|
39 |
+
help="Environments to exclude from publishing",
|
40 |
+
)
|
41 |
+
parser.add_argument(
|
42 |
+
"--huggingface-user",
|
43 |
+
type=str,
|
44 |
+
default=None,
|
45 |
+
help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
|
46 |
+
)
|
47 |
+
parser.add_argument(
|
48 |
+
"--pool-size",
|
49 |
+
type=int,
|
50 |
+
default=3,
|
51 |
+
help="How many publish jobs can run in parallel",
|
52 |
+
)
|
53 |
+
parser.add_argument(
|
54 |
+
"--virtual-display", action="store_true", help="Use headless virtual display"
|
55 |
+
)
|
56 |
+
# parser.set_defaults(
|
57 |
+
# wandb_tags=["benchmark_e47a44c", "host_129-146-2-230"],
|
58 |
+
# wandb_report_url="https://api.wandb.ai/links/sgoodfriend/v4wd7cp5",
|
59 |
+
# envs=[],
|
60 |
+
# exclude_envs=[],
|
61 |
+
# )
|
62 |
+
args = parser.parse_args()
|
63 |
+
print(args)
|
64 |
+
|
65 |
+
api = wandb.Api()
|
66 |
+
all_runs = api.runs(
|
67 |
+
f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
|
68 |
+
)
|
69 |
+
|
70 |
+
required_tags = set(args.wandb_tags)
|
71 |
+
runs: List[wandb.apis.public.Run] = [
|
72 |
+
r
|
73 |
+
for r in all_runs
|
74 |
+
if required_tags.issubset(set(r.config.get("wandb_tags", [])))
|
75 |
+
]
|
76 |
+
|
77 |
+
runs_paths_by_group = defaultdict(list)
|
78 |
+
for r in runs:
|
79 |
+
if r.state != "finished":
|
80 |
+
continue
|
81 |
+
algo = r.config["algo"]
|
82 |
+
env = r.config["env"]
|
83 |
+
if args.envs and env not in args.envs:
|
84 |
+
continue
|
85 |
+
if args.exclude_envs and env in args.exclude_envs:
|
86 |
+
continue
|
87 |
+
run_group = RunGroup(algo, env)
|
88 |
+
runs_paths_by_group[run_group].append("/".join(r.path))
|
89 |
+
|
90 |
+
def run(run_paths: List[str]) -> None:
|
91 |
+
publish_args = ["python", "huggingface_publish.py"]
|
92 |
+
publish_args.append("--wandb-run-paths")
|
93 |
+
publish_args.extend(run_paths)
|
94 |
+
publish_args.append("--wandb-report-url")
|
95 |
+
publish_args.append(args.wandb_report_url)
|
96 |
+
if args.huggingface_user:
|
97 |
+
publish_args.append("--huggingface-user")
|
98 |
+
publish_args.append(args.huggingface_user)
|
99 |
+
if args.virtual_display:
|
100 |
+
publish_args.append("--virtual-display")
|
101 |
+
subprocess.run(publish_args)
|
102 |
+
|
103 |
+
tp = ThreadPool(args.pool_size)
|
104 |
+
for run_paths in runs_paths_by_group.values():
|
105 |
+
tp.apply_async(run, (run_paths,))
|
106 |
+
tp.close()
|
107 |
+
tp.join()
|
108 |
+
|
109 |
+
|
110 |
+
if __name__ == "__main__":
|
111 |
+
benchmark_publish()
|
rl_algo_impls/compare_runs.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import itertools
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
import wandb
|
6 |
+
import wandb.apis.public
|
7 |
+
|
8 |
+
from collections import defaultdict
|
9 |
+
from dataclasses import dataclass
|
10 |
+
from typing import Dict, Iterable, List, TypeVar
|
11 |
+
|
12 |
+
from rl_algo_impls.benchmark_publish import RunGroup
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class Comparison:
|
17 |
+
control_values: List[float]
|
18 |
+
experiment_values: List[float]
|
19 |
+
|
20 |
+
def mean_diff_percentage(self) -> float:
|
21 |
+
return self._diff_percentage(
|
22 |
+
np.mean(self.control_values).item(), np.mean(self.experiment_values).item()
|
23 |
+
)
|
24 |
+
|
25 |
+
def median_diff_percentage(self) -> float:
|
26 |
+
return self._diff_percentage(
|
27 |
+
np.median(self.control_values).item(),
|
28 |
+
np.median(self.experiment_values).item(),
|
29 |
+
)
|
30 |
+
|
31 |
+
def _diff_percentage(self, c: float, e: float) -> float:
|
32 |
+
if c == e:
|
33 |
+
return 0
|
34 |
+
elif c == 0:
|
35 |
+
return float("inf") if e > 0 else float("-inf")
|
36 |
+
return 100 * (e - c) / c
|
37 |
+
|
38 |
+
def score(self) -> float:
|
39 |
+
return (
|
40 |
+
np.sum(
|
41 |
+
np.sign((self.mean_diff_percentage(), self.median_diff_percentage()))
|
42 |
+
).item()
|
43 |
+
/ 2
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
RunGroupRunsSelf = TypeVar("RunGroupRunsSelf", bound="RunGroupRuns")
|
48 |
+
|
49 |
+
|
50 |
+
class RunGroupRuns:
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
run_group: RunGroup,
|
54 |
+
control: List[str],
|
55 |
+
experiment: List[str],
|
56 |
+
summary_stats: List[str] = ["best_eval", "eval", "train_rolling"],
|
57 |
+
summary_metrics: List[str] = ["mean", "result"],
|
58 |
+
) -> None:
|
59 |
+
self.algo = run_group.algo
|
60 |
+
self.env = run_group.env_id
|
61 |
+
self.control = set(control)
|
62 |
+
self.experiment = set(experiment)
|
63 |
+
|
64 |
+
self.summary_stats = summary_stats
|
65 |
+
self.summary_metrics = summary_metrics
|
66 |
+
|
67 |
+
self.control_runs = []
|
68 |
+
self.experiment_runs = []
|
69 |
+
|
70 |
+
def add_run(self, run: wandb.apis.public.Run) -> None:
|
71 |
+
wandb_tags = set(run.config.get("wandb_tags", []))
|
72 |
+
if self.control & wandb_tags:
|
73 |
+
self.control_runs.append(run)
|
74 |
+
elif self.experiment & wandb_tags:
|
75 |
+
self.experiment_runs.append(run)
|
76 |
+
|
77 |
+
def comparisons_by_metric(self) -> Dict[str, Comparison]:
|
78 |
+
c_by_m = {}
|
79 |
+
for metric in (
|
80 |
+
f"{s}/{m}"
|
81 |
+
for s, m in itertools.product(self.summary_stats, self.summary_metrics)
|
82 |
+
):
|
83 |
+
c_by_m[metric] = Comparison(
|
84 |
+
[c.summary[metric] for c in self.control_runs],
|
85 |
+
[e.summary[metric] for e in self.experiment_runs],
|
86 |
+
)
|
87 |
+
return c_by_m
|
88 |
+
|
89 |
+
@staticmethod
|
90 |
+
def data_frame(rows: Iterable[RunGroupRunsSelf]) -> pd.DataFrame:
|
91 |
+
results = defaultdict(list)
|
92 |
+
for r in rows:
|
93 |
+
if not r.control_runs or not r.experiment_runs:
|
94 |
+
continue
|
95 |
+
results["algo"].append(r.algo)
|
96 |
+
results["env"].append(r.env)
|
97 |
+
results["control"].append(r.control)
|
98 |
+
results["expierment"].append(r.experiment)
|
99 |
+
c_by_m = r.comparisons_by_metric()
|
100 |
+
results["score"].append(
|
101 |
+
sum(m.score() for m in c_by_m.values()) / len(c_by_m)
|
102 |
+
)
|
103 |
+
for m, c in c_by_m.items():
|
104 |
+
results[f"{m}_mean"].append(c.mean_diff_percentage())
|
105 |
+
results[f"{m}_median"].append(c.median_diff_percentage())
|
106 |
+
return pd.DataFrame(results)
|
107 |
+
|
108 |
+
|
109 |
+
def compare_runs() -> None:
|
110 |
+
parser = argparse.ArgumentParser()
|
111 |
+
parser.add_argument(
|
112 |
+
"-p",
|
113 |
+
"--wandb-project-name",
|
114 |
+
type=str,
|
115 |
+
default="rl-algo-impls-benchmarks",
|
116 |
+
help="WandB project name to load runs from",
|
117 |
+
)
|
118 |
+
parser.add_argument(
|
119 |
+
"--wandb-entity",
|
120 |
+
type=str,
|
121 |
+
default=None,
|
122 |
+
help="WandB team. None uses default entity",
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"-n",
|
126 |
+
"--wandb-hostname-tag",
|
127 |
+
type=str,
|
128 |
+
nargs="*",
|
129 |
+
help="WandB tags for hostname (i.e. host_192-9-145-26)",
|
130 |
+
)
|
131 |
+
parser.add_argument(
|
132 |
+
"-c",
|
133 |
+
"--wandb-control-tag",
|
134 |
+
type=str,
|
135 |
+
nargs="+",
|
136 |
+
help="WandB tag for control commit (i.e. benchmark_5598ebc)",
|
137 |
+
)
|
138 |
+
parser.add_argument(
|
139 |
+
"-e",
|
140 |
+
"--wandb-experiment-tag",
|
141 |
+
type=str,
|
142 |
+
nargs="+",
|
143 |
+
help="WandB tag for experiment commit (i.e. benchmark_5540e1f)",
|
144 |
+
)
|
145 |
+
parser.add_argument(
|
146 |
+
"--envs",
|
147 |
+
type=str,
|
148 |
+
nargs="*",
|
149 |
+
help="If specified, only compare these envs",
|
150 |
+
)
|
151 |
+
parser.add_argument(
|
152 |
+
"--exclude-envs",
|
153 |
+
type=str,
|
154 |
+
nargs="*",
|
155 |
+
help="Environments to exclude from comparison",
|
156 |
+
)
|
157 |
+
# parser.set_defaults(
|
158 |
+
# wandb_hostname_tag=["host_150-230-44-105", "host_155-248-214-128"],
|
159 |
+
# wandb_control_tag=["benchmark_fbc943f"],
|
160 |
+
# wandb_experiment_tag=["benchmark_f59bf74"],
|
161 |
+
# exclude_envs=[],
|
162 |
+
# )
|
163 |
+
args = parser.parse_args()
|
164 |
+
print(args)
|
165 |
+
|
166 |
+
api = wandb.Api()
|
167 |
+
all_runs = api.runs(
|
168 |
+
path=f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}",
|
169 |
+
order="+created_at",
|
170 |
+
)
|
171 |
+
|
172 |
+
runs_by_run_group: Dict[RunGroup, RunGroupRuns] = {}
|
173 |
+
wandb_hostname_tags = set(args.wandb_hostname_tag)
|
174 |
+
for r in all_runs:
|
175 |
+
if r.state != "finished":
|
176 |
+
continue
|
177 |
+
wandb_tags = set(r.config.get("wandb_tags", []))
|
178 |
+
if not wandb_tags or not wandb_hostname_tags & wandb_tags:
|
179 |
+
continue
|
180 |
+
rg = RunGroup(r.config["algo"], r.config.get("env_id") or r.config["env"])
|
181 |
+
if args.exclude_envs and rg.env_id in args.exclude_envs:
|
182 |
+
continue
|
183 |
+
if args.envs and rg.env_id not in args.envs:
|
184 |
+
continue
|
185 |
+
if rg not in runs_by_run_group:
|
186 |
+
runs_by_run_group[rg] = RunGroupRuns(
|
187 |
+
rg,
|
188 |
+
args.wandb_control_tag,
|
189 |
+
args.wandb_experiment_tag,
|
190 |
+
)
|
191 |
+
runs_by_run_group[rg].add_run(r)
|
192 |
+
df = RunGroupRuns.data_frame(runs_by_run_group.values()).round(decimals=2)
|
193 |
+
print(f"**Total Score: {sum(df.score)}**")
|
194 |
+
df.loc["mean"] = df.mean(numeric_only=True)
|
195 |
+
print(df.to_markdown())
|
196 |
+
|
197 |
+
|
198 |
+
if __name__ == "__main__":
|
199 |
+
compare_runs()
|
rl_algo_impls/dqn/dqn.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import logging
|
3 |
+
import random
|
4 |
+
from collections import deque
|
5 |
+
from typing import List, NamedTuple, Optional, TypeVar
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.optim import Adam
|
12 |
+
from torch.utils.tensorboard.writer import SummaryWriter
|
13 |
+
|
14 |
+
from rl_algo_impls.dqn.policy import DQNPolicy
|
15 |
+
from rl_algo_impls.shared.algorithm import Algorithm
|
16 |
+
from rl_algo_impls.shared.callbacks import Callback
|
17 |
+
from rl_algo_impls.shared.schedule import linear_schedule
|
18 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs
|
19 |
+
|
20 |
+
|
21 |
+
class Transition(NamedTuple):
|
22 |
+
obs: np.ndarray
|
23 |
+
action: np.ndarray
|
24 |
+
reward: float
|
25 |
+
done: bool
|
26 |
+
next_obs: np.ndarray
|
27 |
+
|
28 |
+
|
29 |
+
class Batch(NamedTuple):
|
30 |
+
obs: np.ndarray
|
31 |
+
actions: np.ndarray
|
32 |
+
rewards: np.ndarray
|
33 |
+
dones: np.ndarray
|
34 |
+
next_obs: np.ndarray
|
35 |
+
|
36 |
+
|
37 |
+
class ReplayBuffer:
|
38 |
+
def __init__(self, num_envs: int, maxlen: int) -> None:
|
39 |
+
self.num_envs = num_envs
|
40 |
+
self.buffer = deque(maxlen=maxlen)
|
41 |
+
|
42 |
+
def add(
|
43 |
+
self,
|
44 |
+
obs: VecEnvObs,
|
45 |
+
action: np.ndarray,
|
46 |
+
reward: np.ndarray,
|
47 |
+
done: np.ndarray,
|
48 |
+
next_obs: VecEnvObs,
|
49 |
+
) -> None:
|
50 |
+
assert isinstance(obs, np.ndarray)
|
51 |
+
assert isinstance(next_obs, np.ndarray)
|
52 |
+
for i in range(self.num_envs):
|
53 |
+
self.buffer.append(
|
54 |
+
Transition(obs[i], action[i], reward[i], done[i], next_obs[i])
|
55 |
+
)
|
56 |
+
|
57 |
+
def sample(self, batch_size: int) -> Batch:
|
58 |
+
ts = random.sample(self.buffer, batch_size)
|
59 |
+
return Batch(
|
60 |
+
obs=np.array([t.obs for t in ts]),
|
61 |
+
actions=np.array([t.action for t in ts]),
|
62 |
+
rewards=np.array([t.reward for t in ts]),
|
63 |
+
dones=np.array([t.done for t in ts]),
|
64 |
+
next_obs=np.array([t.next_obs for t in ts]),
|
65 |
+
)
|
66 |
+
|
67 |
+
def __len__(self) -> int:
|
68 |
+
return len(self.buffer)
|
69 |
+
|
70 |
+
|
71 |
+
DQNSelf = TypeVar("DQNSelf", bound="DQN")
|
72 |
+
|
73 |
+
|
74 |
+
class DQN(Algorithm):
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
policy: DQNPolicy,
|
78 |
+
env: VecEnv,
|
79 |
+
device: torch.device,
|
80 |
+
tb_writer: SummaryWriter,
|
81 |
+
learning_rate: float = 1e-4,
|
82 |
+
buffer_size: int = 1_000_000,
|
83 |
+
learning_starts: int = 50_000,
|
84 |
+
batch_size: int = 32,
|
85 |
+
tau: float = 1.0,
|
86 |
+
gamma: float = 0.99,
|
87 |
+
train_freq: int = 4,
|
88 |
+
gradient_steps: int = 1,
|
89 |
+
target_update_interval: int = 10_000,
|
90 |
+
exploration_fraction: float = 0.1,
|
91 |
+
exploration_initial_eps: float = 1.0,
|
92 |
+
exploration_final_eps: float = 0.05,
|
93 |
+
max_grad_norm: float = 10.0,
|
94 |
+
) -> None:
|
95 |
+
super().__init__(policy, env, device, tb_writer)
|
96 |
+
self.policy = policy
|
97 |
+
|
98 |
+
self.optimizer = Adam(self.policy.q_net.parameters(), lr=learning_rate)
|
99 |
+
|
100 |
+
self.target_q_net = copy.deepcopy(self.policy.q_net).to(self.device)
|
101 |
+
self.target_q_net.train(False)
|
102 |
+
self.tau = tau
|
103 |
+
self.target_update_interval = target_update_interval
|
104 |
+
|
105 |
+
self.replay_buffer = ReplayBuffer(self.env.num_envs, buffer_size)
|
106 |
+
self.batch_size = batch_size
|
107 |
+
|
108 |
+
self.learning_starts = learning_starts
|
109 |
+
self.train_freq = train_freq
|
110 |
+
self.gradient_steps = gradient_steps
|
111 |
+
|
112 |
+
self.gamma = gamma
|
113 |
+
self.exploration_eps_schedule = linear_schedule(
|
114 |
+
exploration_initial_eps,
|
115 |
+
exploration_final_eps,
|
116 |
+
end_fraction=exploration_fraction,
|
117 |
+
)
|
118 |
+
|
119 |
+
self.max_grad_norm = max_grad_norm
|
120 |
+
|
121 |
+
def learn(
|
122 |
+
self: DQNSelf, total_timesteps: int, callbacks: Optional[List[Callback]] = None
|
123 |
+
) -> DQNSelf:
|
124 |
+
self.policy.train(True)
|
125 |
+
obs = self.env.reset()
|
126 |
+
obs = self._collect_rollout(self.learning_starts, obs, 1)
|
127 |
+
learning_steps = total_timesteps - self.learning_starts
|
128 |
+
timesteps_elapsed = 0
|
129 |
+
steps_since_target_update = 0
|
130 |
+
while timesteps_elapsed < learning_steps:
|
131 |
+
progress = timesteps_elapsed / learning_steps
|
132 |
+
eps = self.exploration_eps_schedule(progress)
|
133 |
+
obs = self._collect_rollout(self.train_freq, obs, eps)
|
134 |
+
rollout_steps = self.train_freq
|
135 |
+
timesteps_elapsed += rollout_steps
|
136 |
+
for _ in range(
|
137 |
+
self.gradient_steps if self.gradient_steps > 0 else self.train_freq
|
138 |
+
):
|
139 |
+
self.train()
|
140 |
+
steps_since_target_update += rollout_steps
|
141 |
+
if steps_since_target_update >= self.target_update_interval:
|
142 |
+
self._update_target()
|
143 |
+
steps_since_target_update = 0
|
144 |
+
if callbacks:
|
145 |
+
if not all(
|
146 |
+
c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
|
147 |
+
):
|
148 |
+
logging.info(
|
149 |
+
f"Callback terminated training at {timesteps_elapsed} timesteps"
|
150 |
+
)
|
151 |
+
break
|
152 |
+
return self
|
153 |
+
|
154 |
+
def train(self) -> None:
|
155 |
+
if len(self.replay_buffer) < self.batch_size:
|
156 |
+
return
|
157 |
+
o, a, r, d, next_o = self.replay_buffer.sample(self.batch_size)
|
158 |
+
o = torch.as_tensor(o, device=self.device)
|
159 |
+
a = torch.as_tensor(a, device=self.device).unsqueeze(1)
|
160 |
+
r = torch.as_tensor(r, dtype=torch.float32, device=self.device)
|
161 |
+
d = torch.as_tensor(d, dtype=torch.long, device=self.device)
|
162 |
+
next_o = torch.as_tensor(next_o, device=self.device)
|
163 |
+
|
164 |
+
with torch.no_grad():
|
165 |
+
target = r + (1 - d) * self.gamma * self.target_q_net(next_o).max(1).values
|
166 |
+
current = self.policy.q_net(o).gather(dim=1, index=a).squeeze(1)
|
167 |
+
loss = F.smooth_l1_loss(current, target)
|
168 |
+
|
169 |
+
self.optimizer.zero_grad()
|
170 |
+
loss.backward()
|
171 |
+
if self.max_grad_norm:
|
172 |
+
nn.utils.clip_grad_norm_(self.policy.q_net.parameters(), self.max_grad_norm)
|
173 |
+
self.optimizer.step()
|
174 |
+
|
175 |
+
def _collect_rollout(self, timesteps: int, obs: VecEnvObs, eps: float) -> VecEnvObs:
|
176 |
+
for _ in range(0, timesteps, self.env.num_envs):
|
177 |
+
action = self.policy.act(obs, eps, deterministic=False)
|
178 |
+
next_obs, reward, done, _ = self.env.step(action)
|
179 |
+
self.replay_buffer.add(obs, action, reward, done, next_obs)
|
180 |
+
obs = next_obs
|
181 |
+
return obs
|
182 |
+
|
183 |
+
def _update_target(self) -> None:
|
184 |
+
for target_param, param in zip(
|
185 |
+
self.target_q_net.parameters(), self.policy.q_net.parameters()
|
186 |
+
):
|
187 |
+
target_param.data.copy_(
|
188 |
+
self.tau * param.data + (1 - self.tau) * target_param.data
|
189 |
+
)
|
rl_algo_impls/dqn/policy.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Optional, Sequence, TypeVar
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from rl_algo_impls.dqn.q_net import QNetwork
|
8 |
+
from rl_algo_impls.shared.policy.policy import Policy
|
9 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
10 |
+
VecEnv,
|
11 |
+
VecEnvObs,
|
12 |
+
single_action_space,
|
13 |
+
single_observation_space,
|
14 |
+
)
|
15 |
+
|
16 |
+
DQNPolicySelf = TypeVar("DQNPolicySelf", bound="DQNPolicy")
|
17 |
+
|
18 |
+
|
19 |
+
class DQNPolicy(Policy):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
env: VecEnv,
|
23 |
+
hidden_sizes: Sequence[int] = [],
|
24 |
+
cnn_flatten_dim: int = 512,
|
25 |
+
cnn_style: str = "nature",
|
26 |
+
cnn_layers_init_orthogonal: Optional[bool] = None,
|
27 |
+
impala_channels: Sequence[int] = (16, 32, 32),
|
28 |
+
**kwargs,
|
29 |
+
) -> None:
|
30 |
+
super().__init__(env, **kwargs)
|
31 |
+
self.q_net = QNetwork(
|
32 |
+
single_observation_space(env),
|
33 |
+
single_action_space(env),
|
34 |
+
hidden_sizes,
|
35 |
+
cnn_flatten_dim=cnn_flatten_dim,
|
36 |
+
cnn_style=cnn_style,
|
37 |
+
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
38 |
+
impala_channels=impala_channels,
|
39 |
+
)
|
40 |
+
|
41 |
+
def act(
|
42 |
+
self,
|
43 |
+
obs: VecEnvObs,
|
44 |
+
eps: float = 0,
|
45 |
+
deterministic: bool = True,
|
46 |
+
action_masks: Optional[np.ndarray] = None,
|
47 |
+
) -> np.ndarray:
|
48 |
+
assert eps == 0 if deterministic else eps >= 0
|
49 |
+
assert (
|
50 |
+
action_masks is None
|
51 |
+
), f"action_masks not currently supported in {self.__class__.__name__}"
|
52 |
+
if not deterministic and np.random.random() < eps:
|
53 |
+
return np.array(
|
54 |
+
[
|
55 |
+
single_action_space(self.env).sample()
|
56 |
+
for _ in range(self.env.num_envs)
|
57 |
+
]
|
58 |
+
)
|
59 |
+
else:
|
60 |
+
o = self._as_tensor(obs)
|
61 |
+
with torch.no_grad():
|
62 |
+
return self.q_net(o).argmax(axis=1).cpu().numpy()
|
rl_algo_impls/dqn/q_net.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Sequence, Type
|
2 |
+
|
3 |
+
import gym
|
4 |
+
import torch as th
|
5 |
+
import torch.nn as nn
|
6 |
+
from gym.spaces import Discrete
|
7 |
+
|
8 |
+
from rl_algo_impls.shared.encoder import Encoder
|
9 |
+
from rl_algo_impls.shared.module.utils import mlp
|
10 |
+
|
11 |
+
|
12 |
+
class QNetwork(nn.Module):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
observation_space: gym.Space,
|
16 |
+
action_space: gym.Space,
|
17 |
+
hidden_sizes: Sequence[int] = [],
|
18 |
+
activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
|
19 |
+
cnn_flatten_dim: int = 512,
|
20 |
+
cnn_style: str = "nature",
|
21 |
+
cnn_layers_init_orthogonal: Optional[bool] = None,
|
22 |
+
impala_channels: Sequence[int] = (16, 32, 32),
|
23 |
+
) -> None:
|
24 |
+
super().__init__()
|
25 |
+
assert isinstance(action_space, Discrete)
|
26 |
+
self._feature_extractor = Encoder(
|
27 |
+
observation_space,
|
28 |
+
activation,
|
29 |
+
cnn_flatten_dim=cnn_flatten_dim,
|
30 |
+
cnn_style=cnn_style,
|
31 |
+
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
32 |
+
impala_channels=impala_channels,
|
33 |
+
)
|
34 |
+
layer_sizes = (
|
35 |
+
(self._feature_extractor.out_dim,) + tuple(hidden_sizes) + (action_space.n,)
|
36 |
+
)
|
37 |
+
self._fc = mlp(layer_sizes, activation)
|
38 |
+
|
39 |
+
def forward(self, obs: th.Tensor) -> th.Tensor:
|
40 |
+
x = self._feature_extractor(obs)
|
41 |
+
return self._fc(x)
|
rl_algo_impls/enjoy.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
|
2 |
+
import os
|
3 |
+
|
4 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
5 |
+
|
6 |
+
from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
|
7 |
+
from rl_algo_impls.runner.running_utils import base_parser
|
8 |
+
|
9 |
+
|
10 |
+
def enjoy() -> None:
|
11 |
+
parser = base_parser(multiple=False)
|
12 |
+
parser.add_argument("--render", default=True, type=bool)
|
13 |
+
parser.add_argument("--best", default=True, type=bool)
|
14 |
+
parser.add_argument("--n_envs", default=1, type=int)
|
15 |
+
parser.add_argument("--n_episodes", default=3, type=int)
|
16 |
+
parser.add_argument("--deterministic-eval", default=None, type=bool)
|
17 |
+
parser.add_argument(
|
18 |
+
"--no-print-returns", action="store_true", help="Limit printing"
|
19 |
+
)
|
20 |
+
# wandb-run-path overrides base RunArgs
|
21 |
+
parser.add_argument("--wandb-run-path", default=None, type=str)
|
22 |
+
parser.set_defaults(
|
23 |
+
algo=["ppo"],
|
24 |
+
wandb_run_path="sgoodfriend/rl-algo-impls/m5c1t7g5",
|
25 |
+
)
|
26 |
+
args = parser.parse_args()
|
27 |
+
args.algo = args.algo[0]
|
28 |
+
args.env = args.env[0]
|
29 |
+
args = EvalArgs(**vars(args))
|
30 |
+
|
31 |
+
evaluate_model(args, os.getcwd())
|
32 |
+
|
33 |
+
|
34 |
+
if __name__ == "__main__":
|
35 |
+
enjoy()
|
rl_algo_impls/huggingface_publish.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import shutil
|
7 |
+
import subprocess
|
8 |
+
import tempfile
|
9 |
+
from typing import List, Optional
|
10 |
+
|
11 |
+
import requests
|
12 |
+
import wandb.apis.public
|
13 |
+
from huggingface_hub.hf_api import HfApi, upload_folder
|
14 |
+
from huggingface_hub.repocard import metadata_save
|
15 |
+
from pyvirtualdisplay.display import Display
|
16 |
+
|
17 |
+
import wandb
|
18 |
+
from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
|
19 |
+
from rl_algo_impls.runner.config import EnvHyperparams
|
20 |
+
from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
|
21 |
+
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
|
22 |
+
from rl_algo_impls.shared.vec_env import make_eval_env
|
23 |
+
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
|
24 |
+
|
25 |
+
|
26 |
+
def publish(
|
27 |
+
wandb_run_paths: List[str],
|
28 |
+
wandb_report_url: str,
|
29 |
+
huggingface_user: Optional[str] = None,
|
30 |
+
huggingface_token: Optional[str] = None,
|
31 |
+
virtual_display: bool = False,
|
32 |
+
) -> None:
|
33 |
+
if virtual_display:
|
34 |
+
display = Display(visible=False, size=(1400, 900))
|
35 |
+
display.start()
|
36 |
+
|
37 |
+
api = wandb.Api()
|
38 |
+
runs = [api.run(rp) for rp in wandb_run_paths]
|
39 |
+
algo = runs[0].config["algo"]
|
40 |
+
hyperparam_id = runs[0].config["env"]
|
41 |
+
evaluations = [
|
42 |
+
evaluate_model(
|
43 |
+
EvalArgs(
|
44 |
+
algo,
|
45 |
+
hyperparam_id,
|
46 |
+
seed=r.config.get("seed", None),
|
47 |
+
render=False,
|
48 |
+
best=True,
|
49 |
+
n_envs=None,
|
50 |
+
n_episodes=10,
|
51 |
+
no_print_returns=True,
|
52 |
+
wandb_run_path="/".join(r.path),
|
53 |
+
),
|
54 |
+
os.getcwd(),
|
55 |
+
)
|
56 |
+
for r in runs
|
57 |
+
]
|
58 |
+
run_metadata = requests.get(runs[0].file("wandb-metadata.json").url).json()
|
59 |
+
table_data = list(EvalTableData(r, e) for r, e in zip(runs, evaluations))
|
60 |
+
best_eval = sorted(
|
61 |
+
table_data, key=lambda d: d.evaluation.stats.score, reverse=True
|
62 |
+
)[0]
|
63 |
+
|
64 |
+
with tempfile.TemporaryDirectory() as tmpdirname:
|
65 |
+
_, (policy, stats, config) = best_eval
|
66 |
+
|
67 |
+
repo_name = config.model_name(include_seed=False)
|
68 |
+
repo_dir_path = os.path.join(tmpdirname, repo_name)
|
69 |
+
# Locally clone this repo to a temp directory
|
70 |
+
subprocess.run(["git", "clone", ".", repo_dir_path])
|
71 |
+
shutil.rmtree(os.path.join(repo_dir_path, ".git"))
|
72 |
+
model_path = config.model_dir_path(best=True, downloaded=True)
|
73 |
+
shutil.copytree(
|
74 |
+
model_path,
|
75 |
+
os.path.join(
|
76 |
+
repo_dir_path, "saved_models", config.model_dir_name(best=True)
|
77 |
+
),
|
78 |
+
)
|
79 |
+
|
80 |
+
github_url = "https://github.com/sgoodfriend/rl-algo-impls"
|
81 |
+
commit_hash = run_metadata.get("git", {}).get("commit", None)
|
82 |
+
env_id = runs[0].config.get("env_id") or runs[0].config["env"]
|
83 |
+
card_text = model_card_text(
|
84 |
+
algo,
|
85 |
+
env_id,
|
86 |
+
github_url,
|
87 |
+
commit_hash,
|
88 |
+
wandb_report_url,
|
89 |
+
table_data,
|
90 |
+
best_eval,
|
91 |
+
)
|
92 |
+
readme_filepath = os.path.join(repo_dir_path, "README.md")
|
93 |
+
os.remove(readme_filepath)
|
94 |
+
with open(readme_filepath, "w") as f:
|
95 |
+
f.write(card_text)
|
96 |
+
|
97 |
+
metadata = {
|
98 |
+
"library_name": "rl-algo-impls",
|
99 |
+
"tags": [
|
100 |
+
env_id,
|
101 |
+
algo,
|
102 |
+
"deep-reinforcement-learning",
|
103 |
+
"reinforcement-learning",
|
104 |
+
],
|
105 |
+
"model-index": [
|
106 |
+
{
|
107 |
+
"name": algo,
|
108 |
+
"results": [
|
109 |
+
{
|
110 |
+
"metrics": [
|
111 |
+
{
|
112 |
+
"type": "mean_reward",
|
113 |
+
"value": str(stats.score),
|
114 |
+
"name": "mean_reward",
|
115 |
+
}
|
116 |
+
],
|
117 |
+
"task": {
|
118 |
+
"type": "reinforcement-learning",
|
119 |
+
"name": "reinforcement-learning",
|
120 |
+
},
|
121 |
+
"dataset": {
|
122 |
+
"name": env_id,
|
123 |
+
"type": env_id,
|
124 |
+
},
|
125 |
+
}
|
126 |
+
],
|
127 |
+
}
|
128 |
+
],
|
129 |
+
}
|
130 |
+
metadata_save(readme_filepath, metadata)
|
131 |
+
|
132 |
+
video_env = VecEpisodeRecorder(
|
133 |
+
make_eval_env(
|
134 |
+
config,
|
135 |
+
EnvHyperparams(**config.env_hyperparams),
|
136 |
+
override_n_envs=1,
|
137 |
+
normalize_load_path=model_path,
|
138 |
+
),
|
139 |
+
os.path.join(repo_dir_path, "replay"),
|
140 |
+
max_video_length=3600,
|
141 |
+
)
|
142 |
+
evaluate(
|
143 |
+
video_env,
|
144 |
+
policy,
|
145 |
+
1,
|
146 |
+
deterministic=config.eval_hyperparams.get("deterministic", True),
|
147 |
+
)
|
148 |
+
|
149 |
+
api = HfApi()
|
150 |
+
huggingface_user = huggingface_user or api.whoami()["name"]
|
151 |
+
huggingface_repo = f"{huggingface_user}/{repo_name}"
|
152 |
+
api.create_repo(
|
153 |
+
token=huggingface_token,
|
154 |
+
repo_id=huggingface_repo,
|
155 |
+
private=False,
|
156 |
+
exist_ok=True,
|
157 |
+
)
|
158 |
+
repo_url = upload_folder(
|
159 |
+
repo_id=huggingface_repo,
|
160 |
+
folder_path=repo_dir_path,
|
161 |
+
path_in_repo="",
|
162 |
+
commit_message=f"{algo.upper()} playing {env_id} from {github_url}/tree/{commit_hash}",
|
163 |
+
token=huggingface_token,
|
164 |
+
delete_patterns="*",
|
165 |
+
)
|
166 |
+
print(f"Pushed model to the hub: {repo_url}")
|
167 |
+
|
168 |
+
|
169 |
+
def huggingface_publish():
|
170 |
+
parser = argparse.ArgumentParser()
|
171 |
+
parser.add_argument(
|
172 |
+
"--wandb-run-paths",
|
173 |
+
type=str,
|
174 |
+
nargs="+",
|
175 |
+
help="Run paths of the form entity/project/run_id",
|
176 |
+
)
|
177 |
+
parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
|
178 |
+
parser.add_argument(
|
179 |
+
"--huggingface-user",
|
180 |
+
type=str,
|
181 |
+
help="Huggingface user or team to upload model cards",
|
182 |
+
default=None,
|
183 |
+
)
|
184 |
+
parser.add_argument(
|
185 |
+
"--virtual-display", action="store_true", help="Use headless virtual display"
|
186 |
+
)
|
187 |
+
args = parser.parse_args()
|
188 |
+
print(args)
|
189 |
+
publish(**vars(args))
|
190 |
+
|
191 |
+
|
192 |
+
if __name__ == "__main__":
|
193 |
+
huggingface_publish()
|
rl_algo_impls/hyperparams/a2c.yml
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CartPole-v1: &cartpole-defaults
|
2 |
+
n_timesteps: !!float 5e5
|
3 |
+
env_hyperparams:
|
4 |
+
n_envs: 8
|
5 |
+
|
6 |
+
CartPole-v0:
|
7 |
+
<<: *cartpole-defaults
|
8 |
+
|
9 |
+
MountainCar-v0:
|
10 |
+
n_timesteps: !!float 1e6
|
11 |
+
env_hyperparams:
|
12 |
+
n_envs: 16
|
13 |
+
normalize: true
|
14 |
+
|
15 |
+
MountainCarContinuous-v0:
|
16 |
+
n_timesteps: !!float 1e5
|
17 |
+
env_hyperparams:
|
18 |
+
n_envs: 4
|
19 |
+
normalize: true
|
20 |
+
# policy_hyperparams:
|
21 |
+
# use_sde: true
|
22 |
+
# log_std_init: 0.0
|
23 |
+
# init_layers_orthogonal: false
|
24 |
+
algo_hyperparams:
|
25 |
+
n_steps: 100
|
26 |
+
sde_sample_freq: 16
|
27 |
+
|
28 |
+
Acrobot-v1:
|
29 |
+
n_timesteps: !!float 5e5
|
30 |
+
env_hyperparams:
|
31 |
+
normalize: true
|
32 |
+
n_envs: 16
|
33 |
+
|
34 |
+
# Tuned
|
35 |
+
LunarLander-v2:
|
36 |
+
device: cpu
|
37 |
+
n_timesteps: !!float 1e6
|
38 |
+
env_hyperparams:
|
39 |
+
n_envs: 4
|
40 |
+
normalize: true
|
41 |
+
algo_hyperparams:
|
42 |
+
n_steps: 2
|
43 |
+
gamma: 0.9955517404308908
|
44 |
+
gae_lambda: 0.9875340918797773
|
45 |
+
learning_rate: 0.0013814130817068916
|
46 |
+
learning_rate_decay: linear
|
47 |
+
ent_coef: !!float 3.388369146384422e-7
|
48 |
+
ent_coef_decay: none
|
49 |
+
max_grad_norm: 3.33982095073364
|
50 |
+
normalize_advantage: true
|
51 |
+
vf_coef: 0.1667838310548184
|
52 |
+
|
53 |
+
BipedalWalker-v3:
|
54 |
+
n_timesteps: !!float 5e6
|
55 |
+
env_hyperparams:
|
56 |
+
n_envs: 16
|
57 |
+
normalize: true
|
58 |
+
policy_hyperparams:
|
59 |
+
use_sde: true
|
60 |
+
log_std_init: -2
|
61 |
+
init_layers_orthogonal: false
|
62 |
+
algo_hyperparams:
|
63 |
+
ent_coef: 0
|
64 |
+
max_grad_norm: 0.5
|
65 |
+
n_steps: 8
|
66 |
+
gae_lambda: 0.9
|
67 |
+
vf_coef: 0.4
|
68 |
+
gamma: 0.99
|
69 |
+
learning_rate: !!float 9.6e-4
|
70 |
+
learning_rate_decay: linear
|
71 |
+
|
72 |
+
HalfCheetahBulletEnv-v0: &pybullet-defaults
|
73 |
+
n_timesteps: !!float 2e6
|
74 |
+
env_hyperparams:
|
75 |
+
n_envs: 4
|
76 |
+
normalize: true
|
77 |
+
policy_hyperparams:
|
78 |
+
use_sde: true
|
79 |
+
log_std_init: -2
|
80 |
+
init_layers_orthogonal: false
|
81 |
+
algo_hyperparams: &pybullet-algo-defaults
|
82 |
+
n_steps: 8
|
83 |
+
ent_coef: 0
|
84 |
+
max_grad_norm: 0.5
|
85 |
+
gae_lambda: 0.9
|
86 |
+
gamma: 0.99
|
87 |
+
vf_coef: 0.4
|
88 |
+
learning_rate: !!float 9.6e-4
|
89 |
+
learning_rate_decay: linear
|
90 |
+
|
91 |
+
AntBulletEnv-v0:
|
92 |
+
<<: *pybullet-defaults
|
93 |
+
|
94 |
+
Walker2DBulletEnv-v0:
|
95 |
+
<<: *pybullet-defaults
|
96 |
+
|
97 |
+
HopperBulletEnv-v0:
|
98 |
+
<<: *pybullet-defaults
|
99 |
+
|
100 |
+
# Tuned
|
101 |
+
CarRacing-v0:
|
102 |
+
n_timesteps: !!float 4e6
|
103 |
+
env_hyperparams:
|
104 |
+
n_envs: 16
|
105 |
+
frame_stack: 4
|
106 |
+
normalize: true
|
107 |
+
normalize_kwargs:
|
108 |
+
norm_obs: false
|
109 |
+
norm_reward: true
|
110 |
+
policy_hyperparams:
|
111 |
+
use_sde: false
|
112 |
+
log_std_init: -1.3502584927786276
|
113 |
+
init_layers_orthogonal: true
|
114 |
+
activation_fn: tanh
|
115 |
+
share_features_extractor: false
|
116 |
+
cnn_flatten_dim: 256
|
117 |
+
hidden_sizes: [256]
|
118 |
+
algo_hyperparams:
|
119 |
+
n_steps: 16
|
120 |
+
learning_rate: 0.000025630993245026736
|
121 |
+
learning_rate_decay: linear
|
122 |
+
gamma: 0.99957617037542
|
123 |
+
gae_lambda: 0.949455676599436
|
124 |
+
ent_coef: !!float 1.707983205298309e-7
|
125 |
+
vf_coef: 0.10428178193833336
|
126 |
+
max_grad_norm: 0.5406643389792273
|
127 |
+
normalize_advantage: true
|
128 |
+
use_rms_prop: false
|
129 |
+
|
130 |
+
_atari: &atari-defaults
|
131 |
+
n_timesteps: !!float 1e7
|
132 |
+
env_hyperparams: &atari-env-defaults
|
133 |
+
n_envs: 16
|
134 |
+
frame_stack: 4
|
135 |
+
no_reward_timeout_steps: 1000
|
136 |
+
no_reward_fire_steps: 500
|
137 |
+
vec_env_class: async
|
138 |
+
policy_hyperparams: &atari-policy-defaults
|
139 |
+
activation_fn: relu
|
140 |
+
algo_hyperparams:
|
141 |
+
ent_coef: 0.01
|
142 |
+
vf_coef: 0.25
|
rl_algo_impls/hyperparams/dqn.yml
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CartPole-v1: &cartpole-defaults
|
2 |
+
n_timesteps: !!float 5e4
|
3 |
+
env_hyperparams:
|
4 |
+
rolling_length: 50
|
5 |
+
policy_hyperparams:
|
6 |
+
hidden_sizes: [256, 256]
|
7 |
+
algo_hyperparams:
|
8 |
+
learning_rate: !!float 2.3e-3
|
9 |
+
batch_size: 64
|
10 |
+
buffer_size: 100000
|
11 |
+
learning_starts: 1000
|
12 |
+
gamma: 0.99
|
13 |
+
target_update_interval: 10
|
14 |
+
train_freq: 256
|
15 |
+
gradient_steps: 128
|
16 |
+
exploration_fraction: 0.16
|
17 |
+
exploration_final_eps: 0.04
|
18 |
+
eval_hyperparams:
|
19 |
+
step_freq: !!float 1e4
|
20 |
+
|
21 |
+
CartPole-v0:
|
22 |
+
<<: *cartpole-defaults
|
23 |
+
n_timesteps: !!float 4e4
|
24 |
+
|
25 |
+
MountainCar-v0:
|
26 |
+
n_timesteps: !!float 1.2e5
|
27 |
+
env_hyperparams:
|
28 |
+
rolling_length: 50
|
29 |
+
policy_hyperparams:
|
30 |
+
hidden_sizes: [256, 256]
|
31 |
+
algo_hyperparams:
|
32 |
+
learning_rate: !!float 4e-3
|
33 |
+
batch_size: 128
|
34 |
+
buffer_size: 10000
|
35 |
+
learning_starts: 1000
|
36 |
+
gamma: 0.98
|
37 |
+
target_update_interval: 600
|
38 |
+
train_freq: 16
|
39 |
+
gradient_steps: 8
|
40 |
+
exploration_fraction: 0.2
|
41 |
+
exploration_final_eps: 0.07
|
42 |
+
|
43 |
+
Acrobot-v1:
|
44 |
+
n_timesteps: !!float 1e5
|
45 |
+
env_hyperparams:
|
46 |
+
rolling_length: 50
|
47 |
+
policy_hyperparams:
|
48 |
+
hidden_sizes: [256, 256]
|
49 |
+
algo_hyperparams:
|
50 |
+
learning_rate: !!float 6.3e-4
|
51 |
+
batch_size: 128
|
52 |
+
buffer_size: 50000
|
53 |
+
learning_starts: 0
|
54 |
+
gamma: 0.99
|
55 |
+
target_update_interval: 250
|
56 |
+
train_freq: 4
|
57 |
+
gradient_steps: -1
|
58 |
+
exploration_fraction: 0.12
|
59 |
+
exploration_final_eps: 0.1
|
60 |
+
|
61 |
+
LunarLander-v2:
|
62 |
+
n_timesteps: !!float 5e5
|
63 |
+
env_hyperparams:
|
64 |
+
rolling_length: 50
|
65 |
+
policy_hyperparams:
|
66 |
+
hidden_sizes: [256, 256]
|
67 |
+
algo_hyperparams:
|
68 |
+
learning_rate: !!float 1e-4
|
69 |
+
batch_size: 256
|
70 |
+
buffer_size: 100000
|
71 |
+
learning_starts: 10000
|
72 |
+
gamma: 0.99
|
73 |
+
target_update_interval: 250
|
74 |
+
train_freq: 8
|
75 |
+
gradient_steps: -1
|
76 |
+
exploration_fraction: 0.12
|
77 |
+
exploration_final_eps: 0.1
|
78 |
+
max_grad_norm: 0.5
|
79 |
+
eval_hyperparams:
|
80 |
+
step_freq: 25_000
|
81 |
+
|
82 |
+
_atari: &atari-defaults
|
83 |
+
n_timesteps: !!float 1e7
|
84 |
+
env_hyperparams:
|
85 |
+
frame_stack: 4
|
86 |
+
no_reward_timeout_steps: 1_000
|
87 |
+
no_reward_fire_steps: 500
|
88 |
+
n_envs: 8
|
89 |
+
vec_env_class: async
|
90 |
+
algo_hyperparams:
|
91 |
+
buffer_size: 100000
|
92 |
+
learning_rate: !!float 1e-4
|
93 |
+
batch_size: 32
|
94 |
+
learning_starts: 100000
|
95 |
+
target_update_interval: 1000
|
96 |
+
train_freq: 8
|
97 |
+
gradient_steps: 2
|
98 |
+
exploration_fraction: 0.1
|
99 |
+
exploration_final_eps: 0.01
|
100 |
+
eval_hyperparams:
|
101 |
+
deterministic: false
|
102 |
+
|
103 |
+
PongNoFrameskip-v4:
|
104 |
+
<<: *atari-defaults
|
105 |
+
n_timesteps: !!float 2.5e6
|
106 |
+
|
107 |
+
_impala-atari: &impala-atari-defaults
|
108 |
+
<<: *atari-defaults
|
109 |
+
policy_hyperparams:
|
110 |
+
cnn_style: impala
|
111 |
+
cnn_flatten_dim: 256
|
112 |
+
init_layers_orthogonal: true
|
113 |
+
cnn_layers_init_orthogonal: false
|
114 |
+
|
115 |
+
impala-PongNoFrameskip-v4:
|
116 |
+
<<: *impala-atari-defaults
|
117 |
+
env_id: PongNoFrameskip-v4
|
118 |
+
n_timesteps: !!float 2.5e6
|
119 |
+
|
120 |
+
impala-BreakoutNoFrameskip-v4:
|
121 |
+
<<: *impala-atari-defaults
|
122 |
+
env_id: BreakoutNoFrameskip-v4
|
123 |
+
|
124 |
+
impala-SpaceInvadersNoFrameskip-v4:
|
125 |
+
<<: *impala-atari-defaults
|
126 |
+
env_id: SpaceInvadersNoFrameskip-v4
|
127 |
+
|
128 |
+
impala-QbertNoFrameskip-v4:
|
129 |
+
<<: *impala-atari-defaults
|
130 |
+
env_id: QbertNoFrameskip-v4
|
rl_algo_impls/hyperparams/ppo.yml
ADDED
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CartPole-v1: &cartpole-defaults
|
2 |
+
n_timesteps: !!float 1e5
|
3 |
+
env_hyperparams:
|
4 |
+
n_envs: 8
|
5 |
+
algo_hyperparams:
|
6 |
+
n_steps: 32
|
7 |
+
batch_size: 256
|
8 |
+
n_epochs: 20
|
9 |
+
gae_lambda: 0.8
|
10 |
+
gamma: 0.98
|
11 |
+
ent_coef: 0.0
|
12 |
+
learning_rate: 0.001
|
13 |
+
learning_rate_decay: linear
|
14 |
+
clip_range: 0.2
|
15 |
+
clip_range_decay: linear
|
16 |
+
eval_hyperparams:
|
17 |
+
step_freq: !!float 2.5e4
|
18 |
+
|
19 |
+
CartPole-v0:
|
20 |
+
<<: *cartpole-defaults
|
21 |
+
n_timesteps: !!float 5e4
|
22 |
+
|
23 |
+
MountainCar-v0:
|
24 |
+
n_timesteps: !!float 1e6
|
25 |
+
env_hyperparams:
|
26 |
+
normalize: true
|
27 |
+
n_envs: 16
|
28 |
+
algo_hyperparams:
|
29 |
+
n_steps: 16
|
30 |
+
n_epochs: 4
|
31 |
+
gae_lambda: 0.98
|
32 |
+
gamma: 0.99
|
33 |
+
ent_coef: 0.0
|
34 |
+
|
35 |
+
MountainCarContinuous-v0:
|
36 |
+
n_timesteps: !!float 1e5
|
37 |
+
env_hyperparams:
|
38 |
+
normalize: true
|
39 |
+
n_envs: 4
|
40 |
+
# policy_hyperparams:
|
41 |
+
# init_layers_orthogonal: false
|
42 |
+
# log_std_init: -3.29
|
43 |
+
# use_sde: true
|
44 |
+
algo_hyperparams:
|
45 |
+
n_steps: 512
|
46 |
+
batch_size: 256
|
47 |
+
n_epochs: 10
|
48 |
+
learning_rate: !!float 7.77e-5
|
49 |
+
ent_coef: 0.01 # 0.00429
|
50 |
+
ent_coef_decay: linear
|
51 |
+
clip_range: 0.1
|
52 |
+
gae_lambda: 0.9
|
53 |
+
max_grad_norm: 5
|
54 |
+
vf_coef: 0.19
|
55 |
+
eval_hyperparams:
|
56 |
+
step_freq: 5000
|
57 |
+
|
58 |
+
Acrobot-v1:
|
59 |
+
n_timesteps: !!float 1e6
|
60 |
+
env_hyperparams:
|
61 |
+
n_envs: 16
|
62 |
+
normalize: true
|
63 |
+
algo_hyperparams:
|
64 |
+
n_steps: 256
|
65 |
+
n_epochs: 4
|
66 |
+
gae_lambda: 0.94
|
67 |
+
gamma: 0.99
|
68 |
+
ent_coef: 0.0
|
69 |
+
|
70 |
+
LunarLander-v2:
|
71 |
+
n_timesteps: !!float 4e6
|
72 |
+
env_hyperparams:
|
73 |
+
n_envs: 16
|
74 |
+
algo_hyperparams:
|
75 |
+
n_steps: 1024
|
76 |
+
batch_size: 64
|
77 |
+
n_epochs: 4
|
78 |
+
gae_lambda: 0.98
|
79 |
+
gamma: 0.999
|
80 |
+
learning_rate: !!float 5e-4
|
81 |
+
learning_rate_decay: linear
|
82 |
+
clip_range: 0.2
|
83 |
+
clip_range_decay: linear
|
84 |
+
ent_coef: 0.01
|
85 |
+
normalize_advantage: false
|
86 |
+
|
87 |
+
BipedalWalker-v3:
|
88 |
+
n_timesteps: !!float 10e6
|
89 |
+
env_hyperparams:
|
90 |
+
n_envs: 16
|
91 |
+
normalize: true
|
92 |
+
algo_hyperparams:
|
93 |
+
n_steps: 2048
|
94 |
+
batch_size: 64
|
95 |
+
gae_lambda: 0.95
|
96 |
+
gamma: 0.99
|
97 |
+
n_epochs: 10
|
98 |
+
ent_coef: 0.001
|
99 |
+
learning_rate: !!float 2.5e-4
|
100 |
+
learning_rate_decay: linear
|
101 |
+
clip_range: 0.2
|
102 |
+
clip_range_decay: linear
|
103 |
+
|
104 |
+
CarRacing-v0: &carracing-defaults
|
105 |
+
n_timesteps: !!float 4e6
|
106 |
+
env_hyperparams:
|
107 |
+
n_envs: 8
|
108 |
+
frame_stack: 4
|
109 |
+
policy_hyperparams: &carracing-policy-defaults
|
110 |
+
use_sde: true
|
111 |
+
log_std_init: -2
|
112 |
+
init_layers_orthogonal: false
|
113 |
+
activation_fn: relu
|
114 |
+
share_features_extractor: false
|
115 |
+
cnn_flatten_dim: 256
|
116 |
+
hidden_sizes: [256]
|
117 |
+
algo_hyperparams:
|
118 |
+
n_steps: 512
|
119 |
+
batch_size: 128
|
120 |
+
n_epochs: 10
|
121 |
+
learning_rate: !!float 1e-4
|
122 |
+
learning_rate_decay: linear
|
123 |
+
gamma: 0.99
|
124 |
+
gae_lambda: 0.95
|
125 |
+
ent_coef: 0.0
|
126 |
+
sde_sample_freq: 4
|
127 |
+
max_grad_norm: 0.5
|
128 |
+
vf_coef: 0.5
|
129 |
+
clip_range: 0.2
|
130 |
+
|
131 |
+
impala-CarRacing-v0:
|
132 |
+
<<: *carracing-defaults
|
133 |
+
env_id: CarRacing-v0
|
134 |
+
policy_hyperparams:
|
135 |
+
<<: *carracing-policy-defaults
|
136 |
+
cnn_style: impala
|
137 |
+
init_layers_orthogonal: true
|
138 |
+
cnn_layers_init_orthogonal: false
|
139 |
+
hidden_sizes: []
|
140 |
+
|
141 |
+
# BreakoutNoFrameskip-v4
|
142 |
+
# PongNoFrameskip-v4
|
143 |
+
# SpaceInvadersNoFrameskip-v4
|
144 |
+
# QbertNoFrameskip-v4
|
145 |
+
_atari: &atari-defaults
|
146 |
+
n_timesteps: !!float 1e7
|
147 |
+
env_hyperparams: &atari-env-defaults
|
148 |
+
n_envs: 8
|
149 |
+
frame_stack: 4
|
150 |
+
no_reward_timeout_steps: 1000
|
151 |
+
no_reward_fire_steps: 500
|
152 |
+
vec_env_class: async
|
153 |
+
policy_hyperparams: &atari-policy-defaults
|
154 |
+
activation_fn: relu
|
155 |
+
algo_hyperparams: &atari-algo-defaults
|
156 |
+
n_steps: 128
|
157 |
+
batch_size: 256
|
158 |
+
n_epochs: 4
|
159 |
+
learning_rate: !!float 2.5e-4
|
160 |
+
learning_rate_decay: linear
|
161 |
+
clip_range: 0.1
|
162 |
+
clip_range_decay: linear
|
163 |
+
vf_coef: 0.5
|
164 |
+
ent_coef: 0.01
|
165 |
+
eval_hyperparams:
|
166 |
+
deterministic: false
|
167 |
+
|
168 |
+
_norm-rewards-atari: &norm-rewards-atari-default
|
169 |
+
<<: *atari-defaults
|
170 |
+
env_hyperparams:
|
171 |
+
<<: *atari-env-defaults
|
172 |
+
clip_atari_rewards: false
|
173 |
+
normalize: true
|
174 |
+
normalize_kwargs:
|
175 |
+
norm_obs: false
|
176 |
+
norm_reward: true
|
177 |
+
|
178 |
+
norm-rewards-BreakoutNoFrameskip-v4:
|
179 |
+
<<: *norm-rewards-atari-default
|
180 |
+
env_id: BreakoutNoFrameskip-v4
|
181 |
+
|
182 |
+
debug-PongNoFrameskip-v4:
|
183 |
+
<<: *atari-defaults
|
184 |
+
device: cpu
|
185 |
+
env_id: PongNoFrameskip-v4
|
186 |
+
env_hyperparams:
|
187 |
+
<<: *atari-env-defaults
|
188 |
+
vec_env_class: sync
|
189 |
+
|
190 |
+
_impala-atari: &impala-atari-defaults
|
191 |
+
<<: *atari-defaults
|
192 |
+
policy_hyperparams:
|
193 |
+
<<: *atari-policy-defaults
|
194 |
+
cnn_style: impala
|
195 |
+
cnn_flatten_dim: 256
|
196 |
+
init_layers_orthogonal: true
|
197 |
+
cnn_layers_init_orthogonal: false
|
198 |
+
|
199 |
+
impala-PongNoFrameskip-v4:
|
200 |
+
<<: *impala-atari-defaults
|
201 |
+
env_id: PongNoFrameskip-v4
|
202 |
+
|
203 |
+
impala-BreakoutNoFrameskip-v4:
|
204 |
+
<<: *impala-atari-defaults
|
205 |
+
env_id: BreakoutNoFrameskip-v4
|
206 |
+
|
207 |
+
impala-SpaceInvadersNoFrameskip-v4:
|
208 |
+
<<: *impala-atari-defaults
|
209 |
+
env_id: SpaceInvadersNoFrameskip-v4
|
210 |
+
|
211 |
+
impala-QbertNoFrameskip-v4:
|
212 |
+
<<: *impala-atari-defaults
|
213 |
+
env_id: QbertNoFrameskip-v4
|
214 |
+
|
215 |
+
_microrts: µrts-defaults
|
216 |
+
<<: *atari-defaults
|
217 |
+
n_timesteps: !!float 2e6
|
218 |
+
env_hyperparams: µrts-env-defaults
|
219 |
+
n_envs: 8
|
220 |
+
vec_env_class: sync
|
221 |
+
mask_actions: true
|
222 |
+
policy_hyperparams: µrts-policy-defaults
|
223 |
+
<<: *atari-policy-defaults
|
224 |
+
cnn_style: microrts
|
225 |
+
cnn_flatten_dim: 128
|
226 |
+
algo_hyperparams: µrts-algo-defaults
|
227 |
+
<<: *atari-algo-defaults
|
228 |
+
clip_range_decay: none
|
229 |
+
clip_range_vf: 0.1
|
230 |
+
ppo2_vf_coef_halving: true
|
231 |
+
eval_hyperparams: µrts-eval-defaults
|
232 |
+
deterministic: false # Good idea because MultiCategorical mode isn't great
|
233 |
+
|
234 |
+
_no-mask-microrts: &no-mask-microrts-defaults
|
235 |
+
<<: *microrts-defaults
|
236 |
+
env_hyperparams:
|
237 |
+
<<: *microrts-env-defaults
|
238 |
+
mask_actions: false
|
239 |
+
|
240 |
+
MicrortsMining-v1-NoMask:
|
241 |
+
<<: *no-mask-microrts-defaults
|
242 |
+
env_id: MicrortsMining-v1
|
243 |
+
|
244 |
+
MicrortsAttackShapedReward-v1-NoMask:
|
245 |
+
<<: *no-mask-microrts-defaults
|
246 |
+
env_id: MicrortsAttackShapedReward-v1
|
247 |
+
|
248 |
+
MicrortsRandomEnemyShapedReward3-v1-NoMask:
|
249 |
+
<<: *no-mask-microrts-defaults
|
250 |
+
env_id: MicrortsRandomEnemyShapedReward3-v1
|
251 |
+
|
252 |
+
_microrts_ai: µrts-ai-defaults
|
253 |
+
<<: *microrts-defaults
|
254 |
+
n_timesteps: !!float 100e6
|
255 |
+
additional_keys_to_log: ["microrts_stats"]
|
256 |
+
env_hyperparams: µrts-ai-env-defaults
|
257 |
+
n_envs: 24
|
258 |
+
env_type: microrts
|
259 |
+
make_kwargs: µrts-ai-env-make-kwargs-defaults
|
260 |
+
num_selfplay_envs: 0
|
261 |
+
max_steps: 2000
|
262 |
+
render_theme: 2
|
263 |
+
map_paths: [maps/16x16/basesWorkers16x16.xml]
|
264 |
+
reward_weight: [10.0, 1.0, 1.0, 0.2, 1.0, 4.0]
|
265 |
+
policy_hyperparams: µrts-ai-policy-defaults
|
266 |
+
<<: *microrts-policy-defaults
|
267 |
+
cnn_flatten_dim: 256
|
268 |
+
actor_head_style: gridnet
|
269 |
+
algo_hyperparams: µrts-ai-algo-defaults
|
270 |
+
<<: *microrts-algo-defaults
|
271 |
+
learning_rate: !!float 2.5e-4
|
272 |
+
learning_rate_decay: linear
|
273 |
+
n_steps: 512
|
274 |
+
batch_size: 3072
|
275 |
+
n_epochs: 4
|
276 |
+
ent_coef: 0.01
|
277 |
+
vf_coef: 0.5
|
278 |
+
max_grad_norm: 0.5
|
279 |
+
clip_range: 0.1
|
280 |
+
clip_range_vf: 0.1
|
281 |
+
eval_hyperparams: µrts-ai-eval-defaults
|
282 |
+
<<: *microrts-eval-defaults
|
283 |
+
score_function: mean
|
284 |
+
max_video_length: 4000
|
285 |
+
env_overrides: µrts-ai-eval-env-overrides
|
286 |
+
make_kwargs:
|
287 |
+
<<: *microrts-ai-env-make-kwargs-defaults
|
288 |
+
max_steps: 4000
|
289 |
+
reward_weight: [1.0, 0, 0, 0, 0, 0]
|
290 |
+
|
291 |
+
MicrortsAttackPassiveEnemySparseReward-v3:
|
292 |
+
<<: *microrts-ai-defaults
|
293 |
+
n_timesteps: !!float 2e6
|
294 |
+
env_id: MicrortsAttackPassiveEnemySparseReward-v3 # Workaround to keep model name simple
|
295 |
+
env_hyperparams:
|
296 |
+
<<: *microrts-ai-env-defaults
|
297 |
+
bots:
|
298 |
+
passiveAI: 24
|
299 |
+
|
300 |
+
MicrortsDefeatRandomEnemySparseReward-v3: µrts-random-ai-defaults
|
301 |
+
<<: *microrts-ai-defaults
|
302 |
+
n_timesteps: !!float 2e6
|
303 |
+
env_id: MicrortsDefeatRandomEnemySparseReward-v3 # Workaround to keep model name simple
|
304 |
+
env_hyperparams:
|
305 |
+
<<: *microrts-ai-env-defaults
|
306 |
+
bots:
|
307 |
+
randomBiasedAI: 24
|
308 |
+
|
309 |
+
enc-dec-MicrortsDefeatRandomEnemySparseReward-v3:
|
310 |
+
<<: *microrts-random-ai-defaults
|
311 |
+
policy_hyperparams:
|
312 |
+
<<: *microrts-ai-policy-defaults
|
313 |
+
cnn_style: gridnet_encoder
|
314 |
+
actor_head_style: gridnet_decoder
|
315 |
+
v_hidden_sizes: [128]
|
316 |
+
|
317 |
+
unet-MicrortsDefeatRandomEnemySparseReward-v3:
|
318 |
+
<<: *microrts-random-ai-defaults
|
319 |
+
# device: cpu
|
320 |
+
policy_hyperparams:
|
321 |
+
<<: *microrts-ai-policy-defaults
|
322 |
+
actor_head_style: unet
|
323 |
+
v_hidden_sizes: [256, 128]
|
324 |
+
algo_hyperparams:
|
325 |
+
<<: *microrts-ai-algo-defaults
|
326 |
+
learning_rate: !!float 2.5e-4
|
327 |
+
learning_rate_decay: spike
|
328 |
+
|
329 |
+
MicrortsDefeatCoacAIShaped-v3: µrts-coacai-defaults
|
330 |
+
<<: *microrts-ai-defaults
|
331 |
+
env_id: MicrortsDefeatCoacAIShaped-v3 # Workaround to keep model name simple
|
332 |
+
n_timesteps: !!float 300e6
|
333 |
+
env_hyperparams: µrts-coacai-env-defaults
|
334 |
+
<<: *microrts-ai-env-defaults
|
335 |
+
bots:
|
336 |
+
coacAI: 24
|
337 |
+
eval_hyperparams: µrts-coacai-eval-defaults
|
338 |
+
<<: *microrts-ai-eval-defaults
|
339 |
+
step_freq: !!float 1e6
|
340 |
+
n_episodes: 26
|
341 |
+
env_overrides: µrts-coacai-eval-env-overrides
|
342 |
+
<<: *microrts-ai-eval-env-overrides
|
343 |
+
n_envs: 26
|
344 |
+
bots:
|
345 |
+
coacAI: 2
|
346 |
+
randomBiasedAI: 2
|
347 |
+
randomAI: 2
|
348 |
+
passiveAI: 2
|
349 |
+
workerRushAI: 2
|
350 |
+
lightRushAI: 2
|
351 |
+
naiveMCTSAI: 2
|
352 |
+
mixedBot: 2
|
353 |
+
rojo: 2
|
354 |
+
izanagi: 2
|
355 |
+
tiamat: 2
|
356 |
+
droplet: 2
|
357 |
+
guidedRojoA3N: 2
|
358 |
+
|
359 |
+
MicrortsDefeatCoacAIShaped-v3-diverseBots: µrts-diverse-defaults
|
360 |
+
<<: *microrts-coacai-defaults
|
361 |
+
env_hyperparams:
|
362 |
+
<<: *microrts-coacai-env-defaults
|
363 |
+
bots:
|
364 |
+
coacAI: 18
|
365 |
+
randomBiasedAI: 2
|
366 |
+
lightRushAI: 2
|
367 |
+
workerRushAI: 2
|
368 |
+
|
369 |
+
enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
|
370 |
+
µrts-env-dec-diverse-defaults
|
371 |
+
<<: *microrts-diverse-defaults
|
372 |
+
policy_hyperparams:
|
373 |
+
<<: *microrts-ai-policy-defaults
|
374 |
+
cnn_style: gridnet_encoder
|
375 |
+
actor_head_style: gridnet_decoder
|
376 |
+
v_hidden_sizes: [128]
|
377 |
+
|
378 |
+
debug-enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
|
379 |
+
<<: *microrts-env-dec-diverse-defaults
|
380 |
+
n_timesteps: !!float 1e6
|
381 |
+
|
382 |
+
unet-MicrortsDefeatCoacAIShaped-v3-diverseBots: µrts-unet-defaults
|
383 |
+
<<: *microrts-diverse-defaults
|
384 |
+
policy_hyperparams:
|
385 |
+
<<: *microrts-ai-policy-defaults
|
386 |
+
actor_head_style: unet
|
387 |
+
v_hidden_sizes: [256, 128]
|
388 |
+
algo_hyperparams: µrts-unet-algo-defaults
|
389 |
+
<<: *microrts-ai-algo-defaults
|
390 |
+
learning_rate: !!float 2.5e-4
|
391 |
+
learning_rate_decay: spike
|
392 |
+
|
393 |
+
Microrts-selfplay-unet: µrts-selfplay-defaults
|
394 |
+
<<: *microrts-unet-defaults
|
395 |
+
env_hyperparams: µrts-selfplay-env-defaults
|
396 |
+
<<: *microrts-ai-env-defaults
|
397 |
+
make_kwargs: µrts-selfplay-env-make-kwargs-defaults
|
398 |
+
<<: *microrts-ai-env-make-kwargs-defaults
|
399 |
+
num_selfplay_envs: 36
|
400 |
+
self_play_kwargs:
|
401 |
+
num_old_policies: 12
|
402 |
+
save_steps: 200000
|
403 |
+
swap_steps: 10000
|
404 |
+
swap_window_size: 4
|
405 |
+
window: 25
|
406 |
+
eval_hyperparams: µrts-selfplay-eval-defaults
|
407 |
+
<<: *microrts-coacai-eval-defaults
|
408 |
+
env_overrides: µrts-selfplay-eval-env-overrides
|
409 |
+
<<: *microrts-coacai-eval-env-overrides
|
410 |
+
self_play_kwargs: {}
|
411 |
+
|
412 |
+
Microrts-selfplay-unet-winloss: µrts-selfplay-winloss-defaults
|
413 |
+
<<: *microrts-selfplay-defaults
|
414 |
+
env_hyperparams:
|
415 |
+
<<: *microrts-selfplay-env-defaults
|
416 |
+
make_kwargs:
|
417 |
+
<<: *microrts-selfplay-env-make-kwargs-defaults
|
418 |
+
reward_weight: [1.0, 0, 0, 0, 0, 0]
|
419 |
+
algo_hyperparams: µrts-selfplay-winloss-algo-defaults
|
420 |
+
<<: *microrts-unet-algo-defaults
|
421 |
+
gamma: 0.999
|
422 |
+
|
423 |
+
Microrts-selfplay-unet-decay:
|
424 |
+
<<: *microrts-selfplay-winloss-defaults
|
425 |
+
microrts_reward_decay_callback: true
|
426 |
+
algo_hyperparams:
|
427 |
+
<<: *microrts-selfplay-winloss-algo-defaults
|
428 |
+
gamma_end: 0.999
|
429 |
+
|
430 |
+
Microrts-selfplay-unet-debug: µrts-selfplay-debug-defaults
|
431 |
+
<<: *microrts-selfplay-defaults
|
432 |
+
eval_hyperparams:
|
433 |
+
<<: *microrts-selfplay-eval-defaults
|
434 |
+
step_freq: !!float 1e5
|
435 |
+
env_overrides:
|
436 |
+
<<: *microrts-selfplay-eval-env-overrides
|
437 |
+
n_envs: 24
|
438 |
+
bots:
|
439 |
+
coacAI: 12
|
440 |
+
randomBiasedAI: 4
|
441 |
+
workerRushAI: 4
|
442 |
+
lightRushAI: 4
|
443 |
+
|
444 |
+
Microrts-selfplay-unet-debug-mps:
|
445 |
+
<<: *microrts-selfplay-debug-defaults
|
446 |
+
device: mps
|
447 |
+
|
448 |
+
HalfCheetahBulletEnv-v0: &pybullet-defaults
|
449 |
+
n_timesteps: !!float 2e6
|
450 |
+
env_hyperparams: &pybullet-env-defaults
|
451 |
+
n_envs: 16
|
452 |
+
normalize: true
|
453 |
+
policy_hyperparams: &pybullet-policy-defaults
|
454 |
+
pi_hidden_sizes: [256, 256]
|
455 |
+
v_hidden_sizes: [256, 256]
|
456 |
+
activation_fn: relu
|
457 |
+
algo_hyperparams: &pybullet-algo-defaults
|
458 |
+
n_steps: 512
|
459 |
+
batch_size: 128
|
460 |
+
n_epochs: 20
|
461 |
+
gamma: 0.99
|
462 |
+
gae_lambda: 0.9
|
463 |
+
ent_coef: 0.0
|
464 |
+
max_grad_norm: 0.5
|
465 |
+
vf_coef: 0.5
|
466 |
+
learning_rate: !!float 3e-5
|
467 |
+
clip_range: 0.4
|
468 |
+
|
469 |
+
AntBulletEnv-v0:
|
470 |
+
<<: *pybullet-defaults
|
471 |
+
policy_hyperparams:
|
472 |
+
<<: *pybullet-policy-defaults
|
473 |
+
algo_hyperparams:
|
474 |
+
<<: *pybullet-algo-defaults
|
475 |
+
|
476 |
+
Walker2DBulletEnv-v0:
|
477 |
+
<<: *pybullet-defaults
|
478 |
+
algo_hyperparams:
|
479 |
+
<<: *pybullet-algo-defaults
|
480 |
+
clip_range_decay: linear
|
481 |
+
|
482 |
+
HopperBulletEnv-v0:
|
483 |
+
<<: *pybullet-defaults
|
484 |
+
algo_hyperparams:
|
485 |
+
<<: *pybullet-algo-defaults
|
486 |
+
clip_range_decay: linear
|
487 |
+
|
488 |
+
HumanoidBulletEnv-v0:
|
489 |
+
<<: *pybullet-defaults
|
490 |
+
n_timesteps: !!float 1e7
|
491 |
+
env_hyperparams:
|
492 |
+
<<: *pybullet-env-defaults
|
493 |
+
n_envs: 8
|
494 |
+
policy_hyperparams:
|
495 |
+
<<: *pybullet-policy-defaults
|
496 |
+
# log_std_init: -1
|
497 |
+
algo_hyperparams:
|
498 |
+
<<: *pybullet-algo-defaults
|
499 |
+
n_steps: 2048
|
500 |
+
batch_size: 64
|
501 |
+
n_epochs: 10
|
502 |
+
gae_lambda: 0.95
|
503 |
+
learning_rate: !!float 2.5e-4
|
504 |
+
clip_range: 0.2
|
505 |
+
|
506 |
+
_procgen: &procgen-defaults
|
507 |
+
env_hyperparams: &procgen-env-defaults
|
508 |
+
env_type: procgen
|
509 |
+
n_envs: 64
|
510 |
+
# grayscale: false
|
511 |
+
# frame_stack: 4
|
512 |
+
normalize: true # procgen only normalizes reward
|
513 |
+
make_kwargs: &procgen-make-kwargs-defaults
|
514 |
+
num_threads: 8
|
515 |
+
policy_hyperparams: &procgen-policy-defaults
|
516 |
+
activation_fn: relu
|
517 |
+
cnn_style: impala
|
518 |
+
cnn_flatten_dim: 256
|
519 |
+
init_layers_orthogonal: true
|
520 |
+
cnn_layers_init_orthogonal: false
|
521 |
+
algo_hyperparams: &procgen-algo-defaults
|
522 |
+
gamma: 0.999
|
523 |
+
gae_lambda: 0.95
|
524 |
+
n_steps: 256
|
525 |
+
batch_size: 2048
|
526 |
+
n_epochs: 3
|
527 |
+
ent_coef: 0.01
|
528 |
+
clip_range: 0.2
|
529 |
+
# clip_range_decay: linear
|
530 |
+
clip_range_vf: 0.2
|
531 |
+
learning_rate: !!float 5e-4
|
532 |
+
# learning_rate_decay: linear
|
533 |
+
vf_coef: 0.5
|
534 |
+
eval_hyperparams: &procgen-eval-defaults
|
535 |
+
ignore_first_episode: true
|
536 |
+
# deterministic: false
|
537 |
+
step_freq: !!float 1e5
|
538 |
+
|
539 |
+
_procgen-easy: &procgen-easy-defaults
|
540 |
+
<<: *procgen-defaults
|
541 |
+
n_timesteps: !!float 25e6
|
542 |
+
env_hyperparams: &procgen-easy-env-defaults
|
543 |
+
<<: *procgen-env-defaults
|
544 |
+
make_kwargs:
|
545 |
+
<<: *procgen-make-kwargs-defaults
|
546 |
+
distribution_mode: easy
|
547 |
+
|
548 |
+
procgen-coinrun-easy: &coinrun-easy-defaults
|
549 |
+
<<: *procgen-easy-defaults
|
550 |
+
env_id: coinrun
|
551 |
+
|
552 |
+
debug-procgen-coinrun:
|
553 |
+
<<: *coinrun-easy-defaults
|
554 |
+
device: cpu
|
555 |
+
|
556 |
+
procgen-starpilot-easy:
|
557 |
+
<<: *procgen-easy-defaults
|
558 |
+
env_id: starpilot
|
559 |
+
|
560 |
+
procgen-bossfight-easy:
|
561 |
+
<<: *procgen-easy-defaults
|
562 |
+
env_id: bossfight
|
563 |
+
|
564 |
+
procgen-bigfish-easy:
|
565 |
+
<<: *procgen-easy-defaults
|
566 |
+
env_id: bigfish
|
567 |
+
|
568 |
+
_procgen-hard: &procgen-hard-defaults
|
569 |
+
<<: *procgen-defaults
|
570 |
+
n_timesteps: !!float 200e6
|
571 |
+
env_hyperparams: &procgen-hard-env-defaults
|
572 |
+
<<: *procgen-env-defaults
|
573 |
+
n_envs: 256
|
574 |
+
make_kwargs:
|
575 |
+
<<: *procgen-make-kwargs-defaults
|
576 |
+
distribution_mode: hard
|
577 |
+
algo_hyperparams: &procgen-hard-algo-defaults
|
578 |
+
<<: *procgen-algo-defaults
|
579 |
+
batch_size: 8192
|
580 |
+
clip_range_decay: linear
|
581 |
+
learning_rate_decay: linear
|
582 |
+
eval_hyperparams:
|
583 |
+
<<: *procgen-eval-defaults
|
584 |
+
step_freq: !!float 5e5
|
585 |
+
|
586 |
+
procgen-starpilot-hard: &procgen-starpilot-hard-defaults
|
587 |
+
<<: *procgen-hard-defaults
|
588 |
+
env_id: starpilot
|
589 |
+
|
590 |
+
procgen-starpilot-hard-2xIMPALA:
|
591 |
+
<<: *procgen-starpilot-hard-defaults
|
592 |
+
policy_hyperparams:
|
593 |
+
<<: *procgen-policy-defaults
|
594 |
+
impala_channels: [32, 64, 64]
|
595 |
+
algo_hyperparams:
|
596 |
+
<<: *procgen-hard-algo-defaults
|
597 |
+
learning_rate: !!float 3.3e-4
|
598 |
+
|
599 |
+
procgen-starpilot-hard-2xIMPALA-fat:
|
600 |
+
<<: *procgen-starpilot-hard-defaults
|
601 |
+
policy_hyperparams:
|
602 |
+
<<: *procgen-policy-defaults
|
603 |
+
impala_channels: [32, 64, 64]
|
604 |
+
cnn_flatten_dim: 512
|
605 |
+
algo_hyperparams:
|
606 |
+
<<: *procgen-hard-algo-defaults
|
607 |
+
learning_rate: !!float 2.5e-4
|
608 |
+
|
609 |
+
procgen-starpilot-hard-4xIMPALA:
|
610 |
+
<<: *procgen-starpilot-hard-defaults
|
611 |
+
policy_hyperparams:
|
612 |
+
<<: *procgen-policy-defaults
|
613 |
+
impala_channels: [64, 128, 128]
|
614 |
+
algo_hyperparams:
|
615 |
+
<<: *procgen-hard-algo-defaults
|
616 |
+
learning_rate: !!float 2.1e-4
|
rl_algo_impls/hyperparams/vpg.yml
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CartPole-v1: &cartpole-defaults
|
2 |
+
n_timesteps: !!float 4e5
|
3 |
+
algo_hyperparams:
|
4 |
+
n_steps: 4096
|
5 |
+
pi_lr: 0.01
|
6 |
+
gamma: 0.99
|
7 |
+
gae_lambda: 1
|
8 |
+
val_lr: 0.01
|
9 |
+
train_v_iters: 80
|
10 |
+
eval_hyperparams:
|
11 |
+
step_freq: !!float 2.5e4
|
12 |
+
|
13 |
+
CartPole-v0:
|
14 |
+
<<: *cartpole-defaults
|
15 |
+
n_timesteps: !!float 1e5
|
16 |
+
algo_hyperparams:
|
17 |
+
n_steps: 1024
|
18 |
+
pi_lr: 0.01
|
19 |
+
gamma: 0.99
|
20 |
+
gae_lambda: 1
|
21 |
+
val_lr: 0.01
|
22 |
+
train_v_iters: 80
|
23 |
+
|
24 |
+
MountainCar-v0:
|
25 |
+
n_timesteps: !!float 1e6
|
26 |
+
env_hyperparams:
|
27 |
+
normalize: true
|
28 |
+
n_envs: 16
|
29 |
+
algo_hyperparams:
|
30 |
+
n_steps: 200
|
31 |
+
pi_lr: 0.005
|
32 |
+
gamma: 0.99
|
33 |
+
gae_lambda: 0.97
|
34 |
+
val_lr: 0.01
|
35 |
+
train_v_iters: 80
|
36 |
+
max_grad_norm: 0.5
|
37 |
+
|
38 |
+
MountainCarContinuous-v0:
|
39 |
+
n_timesteps: !!float 3e5
|
40 |
+
env_hyperparams:
|
41 |
+
normalize: true
|
42 |
+
n_envs: 4
|
43 |
+
# policy_hyperparams:
|
44 |
+
# init_layers_orthogonal: false
|
45 |
+
# log_std_init: -3.29
|
46 |
+
# use_sde: true
|
47 |
+
algo_hyperparams:
|
48 |
+
n_steps: 1000
|
49 |
+
pi_lr: !!float 5e-4
|
50 |
+
gamma: 0.99
|
51 |
+
gae_lambda: 0.9
|
52 |
+
val_lr: !!float 1e-3
|
53 |
+
train_v_iters: 80
|
54 |
+
max_grad_norm: 5
|
55 |
+
eval_hyperparams:
|
56 |
+
step_freq: 5000
|
57 |
+
|
58 |
+
Acrobot-v1:
|
59 |
+
n_timesteps: !!float 2e5
|
60 |
+
algo_hyperparams:
|
61 |
+
n_steps: 2048
|
62 |
+
pi_lr: 0.005
|
63 |
+
gamma: 0.99
|
64 |
+
gae_lambda: 0.97
|
65 |
+
val_lr: 0.01
|
66 |
+
train_v_iters: 80
|
67 |
+
max_grad_norm: 0.5
|
68 |
+
|
69 |
+
LunarLander-v2:
|
70 |
+
n_timesteps: !!float 4e6
|
71 |
+
policy_hyperparams:
|
72 |
+
hidden_sizes: [256, 256]
|
73 |
+
algo_hyperparams:
|
74 |
+
n_steps: 2048
|
75 |
+
pi_lr: 0.0001
|
76 |
+
gamma: 0.999
|
77 |
+
gae_lambda: 0.97
|
78 |
+
val_lr: 0.0001
|
79 |
+
train_v_iters: 80
|
80 |
+
max_grad_norm: 0.5
|
81 |
+
eval_hyperparams:
|
82 |
+
deterministic: false
|
83 |
+
|
84 |
+
BipedalWalker-v3:
|
85 |
+
n_timesteps: !!float 10e6
|
86 |
+
env_hyperparams:
|
87 |
+
n_envs: 16
|
88 |
+
normalize: true
|
89 |
+
policy_hyperparams:
|
90 |
+
hidden_sizes: [256, 256]
|
91 |
+
algo_hyperparams:
|
92 |
+
n_steps: 1600
|
93 |
+
gae_lambda: 0.95
|
94 |
+
gamma: 0.99
|
95 |
+
pi_lr: !!float 1e-4
|
96 |
+
val_lr: !!float 1e-4
|
97 |
+
train_v_iters: 80
|
98 |
+
max_grad_norm: 0.5
|
99 |
+
eval_hyperparams:
|
100 |
+
deterministic: false
|
101 |
+
|
102 |
+
CarRacing-v0:
|
103 |
+
n_timesteps: !!float 4e6
|
104 |
+
env_hyperparams:
|
105 |
+
frame_stack: 4
|
106 |
+
n_envs: 4
|
107 |
+
vec_env_class: sync
|
108 |
+
policy_hyperparams:
|
109 |
+
use_sde: true
|
110 |
+
log_std_init: -2
|
111 |
+
init_layers_orthogonal: false
|
112 |
+
activation_fn: relu
|
113 |
+
cnn_flatten_dim: 256
|
114 |
+
hidden_sizes: [256]
|
115 |
+
algo_hyperparams:
|
116 |
+
n_steps: 1000
|
117 |
+
pi_lr: !!float 5e-5
|
118 |
+
gamma: 0.99
|
119 |
+
gae_lambda: 0.95
|
120 |
+
val_lr: !!float 1e-4
|
121 |
+
train_v_iters: 40
|
122 |
+
max_grad_norm: 0.5
|
123 |
+
sde_sample_freq: 4
|
124 |
+
|
125 |
+
HalfCheetahBulletEnv-v0: &pybullet-defaults
|
126 |
+
n_timesteps: !!float 2e6
|
127 |
+
env_hyperparams: &pybullet-env-defaults
|
128 |
+
normalize: true
|
129 |
+
policy_hyperparams: &pybullet-policy-defaults
|
130 |
+
hidden_sizes: [256, 256]
|
131 |
+
algo_hyperparams: &pybullet-algo-defaults
|
132 |
+
n_steps: 4000
|
133 |
+
pi_lr: !!float 3e-4
|
134 |
+
gamma: 0.99
|
135 |
+
gae_lambda: 0.97
|
136 |
+
val_lr: !!float 1e-3
|
137 |
+
train_v_iters: 80
|
138 |
+
max_grad_norm: 0.5
|
139 |
+
|
140 |
+
AntBulletEnv-v0:
|
141 |
+
<<: *pybullet-defaults
|
142 |
+
policy_hyperparams:
|
143 |
+
<<: *pybullet-policy-defaults
|
144 |
+
hidden_sizes: [400, 300]
|
145 |
+
algo_hyperparams:
|
146 |
+
<<: *pybullet-algo-defaults
|
147 |
+
pi_lr: !!float 7e-4
|
148 |
+
val_lr: !!float 7e-3
|
149 |
+
|
150 |
+
HopperBulletEnv-v0:
|
151 |
+
<<: *pybullet-defaults
|
152 |
+
|
153 |
+
Walker2DBulletEnv-v0:
|
154 |
+
<<: *pybullet-defaults
|
155 |
+
|
156 |
+
FrozenLake-v1:
|
157 |
+
n_timesteps: !!float 8e5
|
158 |
+
env_params:
|
159 |
+
make_kwargs:
|
160 |
+
map_name: 8x8
|
161 |
+
is_slippery: true
|
162 |
+
policy_hyperparams:
|
163 |
+
hidden_sizes: [64]
|
164 |
+
algo_hyperparams:
|
165 |
+
n_steps: 2048
|
166 |
+
pi_lr: 0.01
|
167 |
+
gamma: 0.99
|
168 |
+
gae_lambda: 0.98
|
169 |
+
val_lr: 0.01
|
170 |
+
train_v_iters: 80
|
171 |
+
max_grad_norm: 0.5
|
172 |
+
eval_hyperparams:
|
173 |
+
step_freq: !!float 5e4
|
174 |
+
n_episodes: 10
|
175 |
+
save_best: true
|
176 |
+
|
177 |
+
_atari: &atari-defaults
|
178 |
+
n_timesteps: !!float 10e6
|
179 |
+
env_hyperparams:
|
180 |
+
n_envs: 2
|
181 |
+
frame_stack: 4
|
182 |
+
no_reward_timeout_steps: 1000
|
183 |
+
no_reward_fire_steps: 500
|
184 |
+
vec_env_class: async
|
185 |
+
policy_hyperparams:
|
186 |
+
activation_fn: relu
|
187 |
+
algo_hyperparams:
|
188 |
+
n_steps: 3072
|
189 |
+
pi_lr: !!float 5e-5
|
190 |
+
gamma: 0.99
|
191 |
+
gae_lambda: 0.95
|
192 |
+
val_lr: !!float 1e-4
|
193 |
+
train_v_iters: 80
|
194 |
+
max_grad_norm: 0.5
|
195 |
+
ent_coef: 0.01
|
196 |
+
eval_hyperparams:
|
197 |
+
deterministic: false
|
rl_algo_impls/optimize.py
ADDED
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import gc
|
3 |
+
import inspect
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
from dataclasses import asdict, dataclass
|
7 |
+
from typing import Callable, List, NamedTuple, Optional, Sequence, Union
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import optuna
|
11 |
+
import torch
|
12 |
+
from optuna.pruners import HyperbandPruner
|
13 |
+
from optuna.samplers import TPESampler
|
14 |
+
from optuna.visualization import plot_optimization_history, plot_param_importances
|
15 |
+
from torch.utils.tensorboard.writer import SummaryWriter
|
16 |
+
|
17 |
+
import wandb
|
18 |
+
from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
|
19 |
+
from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
|
20 |
+
from rl_algo_impls.runner.running_utils import (
|
21 |
+
ALGOS,
|
22 |
+
base_parser,
|
23 |
+
get_device,
|
24 |
+
hparam_dict,
|
25 |
+
load_hyperparams,
|
26 |
+
make_policy,
|
27 |
+
set_seeds,
|
28 |
+
)
|
29 |
+
from rl_algo_impls.shared.callbacks import Callback
|
30 |
+
from rl_algo_impls.shared.callbacks.microrts_reward_decay_callback import (
|
31 |
+
MicrortsRewardDecayCallback,
|
32 |
+
)
|
33 |
+
from rl_algo_impls.shared.callbacks.optimize_callback import (
|
34 |
+
Evaluation,
|
35 |
+
OptimizeCallback,
|
36 |
+
evaluation,
|
37 |
+
)
|
38 |
+
from rl_algo_impls.shared.callbacks.self_play_callback import SelfPlayCallback
|
39 |
+
from rl_algo_impls.shared.stats import EpisodesStats
|
40 |
+
from rl_algo_impls.shared.vec_env import make_env, make_eval_env
|
41 |
+
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
|
42 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import find_wrapper
|
43 |
+
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class StudyArgs:
|
47 |
+
load_study: bool
|
48 |
+
study_name: Optional[str] = None
|
49 |
+
storage_path: Optional[str] = None
|
50 |
+
n_trials: int = 100
|
51 |
+
n_jobs: int = 1
|
52 |
+
n_evaluations: int = 4
|
53 |
+
n_eval_envs: int = 8
|
54 |
+
n_eval_episodes: int = 16
|
55 |
+
timeout: Union[int, float, None] = None
|
56 |
+
wandb_project_name: Optional[str] = None
|
57 |
+
wandb_entity: Optional[str] = None
|
58 |
+
wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
|
59 |
+
wandb_group: Optional[str] = None
|
60 |
+
virtual_display: bool = False
|
61 |
+
|
62 |
+
|
63 |
+
class Args(NamedTuple):
|
64 |
+
train_args: Sequence[RunArgs]
|
65 |
+
study_args: StudyArgs
|
66 |
+
|
67 |
+
|
68 |
+
def parse_args() -> Args:
|
69 |
+
parser = base_parser()
|
70 |
+
parser.add_argument(
|
71 |
+
"--load-study",
|
72 |
+
action="store_true",
|
73 |
+
help="Load a preexisting study, useful for parallelization",
|
74 |
+
)
|
75 |
+
parser.add_argument("--study-name", type=str, help="Optuna study name")
|
76 |
+
parser.add_argument(
|
77 |
+
"--storage-path",
|
78 |
+
type=str,
|
79 |
+
help="Path of database for Optuna to persist to",
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--wandb-project-name",
|
83 |
+
type=str,
|
84 |
+
default="rl-algo-impls-tuning",
|
85 |
+
help="WandB project name to upload tuning data to. If none, won't upload",
|
86 |
+
)
|
87 |
+
parser.add_argument(
|
88 |
+
"--wandb-entity",
|
89 |
+
type=str,
|
90 |
+
help="WandB team. None uses the default entity",
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--wandb-tags", type=str, nargs="*", help="WandB tags to add to run"
|
94 |
+
)
|
95 |
+
parser.add_argument(
|
96 |
+
"--wandb-group", type=str, help="WandB group to group trials under"
|
97 |
+
)
|
98 |
+
parser.add_argument(
|
99 |
+
"--n-trials", type=int, default=100, help="Maximum number of trials"
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
"--n-jobs", type=int, default=1, help="Number of jobs to run in parallel"
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"--n-evaluations",
|
106 |
+
type=int,
|
107 |
+
default=4,
|
108 |
+
help="Number of evaluations during the training",
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"--n-eval-envs",
|
112 |
+
type=int,
|
113 |
+
default=8,
|
114 |
+
help="Number of envs in vectorized eval environment",
|
115 |
+
)
|
116 |
+
parser.add_argument(
|
117 |
+
"--n-eval-episodes",
|
118 |
+
type=int,
|
119 |
+
default=16,
|
120 |
+
help="Number of episodes to complete for evaluation",
|
121 |
+
)
|
122 |
+
parser.add_argument("--timeout", type=int, help="Seconds to timeout optimization")
|
123 |
+
parser.add_argument(
|
124 |
+
"--virtual-display", action="store_true", help="Use headless virtual display"
|
125 |
+
)
|
126 |
+
# parser.set_defaults(
|
127 |
+
# algo=["a2c"],
|
128 |
+
# env=["CartPole-v1"],
|
129 |
+
# seed=[100, 200, 300],
|
130 |
+
# n_trials=5,
|
131 |
+
# virtual_display=True,
|
132 |
+
# )
|
133 |
+
train_dict, study_dict = {}, {}
|
134 |
+
for k, v in vars(parser.parse_args()).items():
|
135 |
+
if k in inspect.signature(StudyArgs).parameters:
|
136 |
+
study_dict[k] = v
|
137 |
+
else:
|
138 |
+
train_dict[k] = v
|
139 |
+
|
140 |
+
study_args = StudyArgs(**study_dict)
|
141 |
+
# Hyperparameter tuning across algos and envs not supported
|
142 |
+
assert len(train_dict["algo"]) == 1
|
143 |
+
assert len(train_dict["env"]) == 1
|
144 |
+
train_args = RunArgs.expand_from_dict(train_dict)
|
145 |
+
|
146 |
+
if not all((study_args.study_name, study_args.storage_path)):
|
147 |
+
hyperparams = load_hyperparams(train_args[0].algo, train_args[0].env)
|
148 |
+
config = Config(train_args[0], hyperparams, os.getcwd())
|
149 |
+
if study_args.study_name is None:
|
150 |
+
study_args.study_name = config.run_name(include_seed=False)
|
151 |
+
if study_args.storage_path is None:
|
152 |
+
study_args.storage_path = (
|
153 |
+
f"sqlite:///{os.path.join(config.runs_dir, 'tuning.db')}"
|
154 |
+
)
|
155 |
+
# Default set group name to study name
|
156 |
+
study_args.wandb_group = study_args.wandb_group or study_args.study_name
|
157 |
+
|
158 |
+
return Args(train_args, study_args)
|
159 |
+
|
160 |
+
|
161 |
+
def objective_fn(
|
162 |
+
args: Sequence[RunArgs], study_args: StudyArgs
|
163 |
+
) -> Callable[[optuna.Trial], float]:
|
164 |
+
def objective(trial: optuna.Trial) -> float:
|
165 |
+
if len(args) == 1:
|
166 |
+
return simple_optimize(trial, args[0], study_args)
|
167 |
+
else:
|
168 |
+
return stepwise_optimize(trial, args, study_args)
|
169 |
+
|
170 |
+
return objective
|
171 |
+
|
172 |
+
|
173 |
+
def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -> float:
|
174 |
+
base_hyperparams = load_hyperparams(args.algo, args.env)
|
175 |
+
base_config = Config(args, base_hyperparams, os.getcwd())
|
176 |
+
if args.algo == "a2c":
|
177 |
+
hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
|
178 |
+
else:
|
179 |
+
raise ValueError(f"Optimizing {args.algo} isn't supported")
|
180 |
+
config = Config(args, hyperparams, os.getcwd())
|
181 |
+
|
182 |
+
wandb_enabled = bool(study_args.wandb_project_name)
|
183 |
+
if wandb_enabled:
|
184 |
+
wandb.init(
|
185 |
+
project=study_args.wandb_project_name,
|
186 |
+
entity=study_args.wandb_entity,
|
187 |
+
config=asdict(hyperparams),
|
188 |
+
name=f"{config.model_name()}-{str(trial.number)}",
|
189 |
+
tags=study_args.wandb_tags,
|
190 |
+
group=study_args.wandb_group,
|
191 |
+
sync_tensorboard=True,
|
192 |
+
monitor_gym=True,
|
193 |
+
save_code=True,
|
194 |
+
reinit=True,
|
195 |
+
)
|
196 |
+
wandb.config.update(args)
|
197 |
+
|
198 |
+
tb_writer = SummaryWriter(config.tensorboard_summary_path)
|
199 |
+
set_seeds(args.seed, args.use_deterministic_algorithms)
|
200 |
+
|
201 |
+
env = make_env(
|
202 |
+
config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
|
203 |
+
)
|
204 |
+
device = get_device(config, env)
|
205 |
+
policy_factory = lambda: make_policy(
|
206 |
+
args.algo, env, device, **config.policy_hyperparams
|
207 |
+
)
|
208 |
+
policy = policy_factory()
|
209 |
+
algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
|
210 |
+
|
211 |
+
eval_env = make_eval_env(
|
212 |
+
config,
|
213 |
+
EnvHyperparams(**config.env_hyperparams),
|
214 |
+
override_n_envs=study_args.n_eval_envs,
|
215 |
+
)
|
216 |
+
optimize_callback = OptimizeCallback(
|
217 |
+
policy,
|
218 |
+
eval_env,
|
219 |
+
trial,
|
220 |
+
tb_writer,
|
221 |
+
step_freq=config.n_timesteps // study_args.n_evaluations,
|
222 |
+
n_episodes=study_args.n_eval_episodes,
|
223 |
+
deterministic=config.eval_hyperparams.get("deterministic", True),
|
224 |
+
)
|
225 |
+
callbacks: List[Callback] = [optimize_callback]
|
226 |
+
if config.hyperparams.microrts_reward_decay_callback:
|
227 |
+
callbacks.append(MicrortsRewardDecayCallback(config, env))
|
228 |
+
selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
|
229 |
+
if selfPlayWrapper:
|
230 |
+
callbacks.append(SelfPlayCallback(policy, policy_factory, selfPlayWrapper))
|
231 |
+
try:
|
232 |
+
algo.learn(config.n_timesteps, callbacks=callbacks)
|
233 |
+
|
234 |
+
if not optimize_callback.is_pruned:
|
235 |
+
optimize_callback.evaluate()
|
236 |
+
if not optimize_callback.is_pruned:
|
237 |
+
policy.save(config.model_dir_path(best=False))
|
238 |
+
|
239 |
+
eval_stat: EpisodesStats = callback.last_eval_stat # type: ignore
|
240 |
+
train_stat: EpisodesStats = callback.last_train_stat # type: ignore
|
241 |
+
|
242 |
+
tb_writer.add_hparams(
|
243 |
+
hparam_dict(hyperparams, vars(args)),
|
244 |
+
{
|
245 |
+
"hparam/last_mean": eval_stat.score.mean,
|
246 |
+
"hparam/last_result": eval_stat.score.mean - eval_stat.score.std,
|
247 |
+
"hparam/train_mean": train_stat.score.mean,
|
248 |
+
"hparam/train_result": train_stat.score.mean - train_stat.score.std,
|
249 |
+
"hparam/score": optimize_callback.last_score,
|
250 |
+
"hparam/is_pruned": optimize_callback.is_pruned,
|
251 |
+
},
|
252 |
+
None,
|
253 |
+
config.run_name(),
|
254 |
+
)
|
255 |
+
tb_writer.close()
|
256 |
+
|
257 |
+
if wandb_enabled:
|
258 |
+
wandb.run.summary["state"] = ( # type: ignore
|
259 |
+
"Pruned" if optimize_callback.is_pruned else "Complete"
|
260 |
+
)
|
261 |
+
wandb.finish(quiet=True)
|
262 |
+
|
263 |
+
if optimize_callback.is_pruned:
|
264 |
+
raise optuna.exceptions.TrialPruned()
|
265 |
+
|
266 |
+
return optimize_callback.last_score
|
267 |
+
except AssertionError as e:
|
268 |
+
logging.warning(e)
|
269 |
+
return np.nan
|
270 |
+
finally:
|
271 |
+
env.close()
|
272 |
+
eval_env.close()
|
273 |
+
gc.collect()
|
274 |
+
torch.cuda.empty_cache()
|
275 |
+
|
276 |
+
|
277 |
+
def stepwise_optimize(
|
278 |
+
trial: optuna.Trial, args: Sequence[RunArgs], study_args: StudyArgs
|
279 |
+
) -> float:
|
280 |
+
algo = args[0].algo
|
281 |
+
env_id = args[0].env
|
282 |
+
base_hyperparams = load_hyperparams(algo, env_id)
|
283 |
+
base_config = Config(args[0], base_hyperparams, os.getcwd())
|
284 |
+
if algo == "a2c":
|
285 |
+
hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
|
286 |
+
else:
|
287 |
+
raise ValueError(f"Optimizing {algo} isn't supported")
|
288 |
+
|
289 |
+
wandb_enabled = bool(study_args.wandb_project_name)
|
290 |
+
if wandb_enabled:
|
291 |
+
wandb.init(
|
292 |
+
project=study_args.wandb_project_name,
|
293 |
+
entity=study_args.wandb_entity,
|
294 |
+
config=asdict(hyperparams),
|
295 |
+
name=f"{str(trial.number)}-S{base_config.seed()}",
|
296 |
+
tags=study_args.wandb_tags,
|
297 |
+
group=study_args.wandb_group,
|
298 |
+
save_code=True,
|
299 |
+
reinit=True,
|
300 |
+
)
|
301 |
+
|
302 |
+
score = -np.inf
|
303 |
+
|
304 |
+
for i in range(study_args.n_evaluations):
|
305 |
+
evaluations: List[Evaluation] = []
|
306 |
+
|
307 |
+
for arg in args:
|
308 |
+
config = Config(arg, hyperparams, os.getcwd())
|
309 |
+
|
310 |
+
tb_writer = SummaryWriter(config.tensorboard_summary_path)
|
311 |
+
set_seeds(arg.seed, arg.use_deterministic_algorithms)
|
312 |
+
|
313 |
+
env = make_env(
|
314 |
+
config,
|
315 |
+
EnvHyperparams(**config.env_hyperparams),
|
316 |
+
normalize_load_path=config.model_dir_path() if i > 0 else None,
|
317 |
+
tb_writer=tb_writer,
|
318 |
+
)
|
319 |
+
device = get_device(config, env)
|
320 |
+
policy_factory = lambda: make_policy(
|
321 |
+
arg.algo, env, device, **config.policy_hyperparams
|
322 |
+
)
|
323 |
+
policy = policy_factory()
|
324 |
+
if i > 0:
|
325 |
+
policy.load(config.model_dir_path())
|
326 |
+
algo = ALGOS[arg.algo](
|
327 |
+
policy, env, device, tb_writer, **config.algo_hyperparams
|
328 |
+
)
|
329 |
+
|
330 |
+
eval_env = make_eval_env(
|
331 |
+
config,
|
332 |
+
EnvHyperparams(**config.env_hyperparams),
|
333 |
+
normalize_load_path=config.model_dir_path() if i > 0 else None,
|
334 |
+
override_n_envs=study_args.n_eval_envs,
|
335 |
+
)
|
336 |
+
|
337 |
+
start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations)
|
338 |
+
train_timesteps = (
|
339 |
+
int((i + 1) * config.n_timesteps / study_args.n_evaluations)
|
340 |
+
- start_timesteps
|
341 |
+
)
|
342 |
+
|
343 |
+
callbacks = []
|
344 |
+
if config.hyperparams.microrts_reward_decay_callback:
|
345 |
+
callbacks.append(
|
346 |
+
MicrortsRewardDecayCallback(
|
347 |
+
config, env, start_timesteps=start_timesteps
|
348 |
+
)
|
349 |
+
)
|
350 |
+
selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
|
351 |
+
if selfPlayWrapper:
|
352 |
+
callbacks.append(
|
353 |
+
SelfPlayCallback(policy, policy_factory, selfPlayWrapper)
|
354 |
+
)
|
355 |
+
try:
|
356 |
+
algo.learn(
|
357 |
+
train_timesteps,
|
358 |
+
callbacks=callbacks,
|
359 |
+
total_timesteps=config.n_timesteps,
|
360 |
+
start_timesteps=start_timesteps,
|
361 |
+
)
|
362 |
+
|
363 |
+
evaluations.append(
|
364 |
+
evaluation(
|
365 |
+
policy,
|
366 |
+
eval_env,
|
367 |
+
tb_writer,
|
368 |
+
study_args.n_eval_episodes,
|
369 |
+
config.eval_hyperparams.get("deterministic", True),
|
370 |
+
start_timesteps + train_timesteps,
|
371 |
+
)
|
372 |
+
)
|
373 |
+
|
374 |
+
policy.save(config.model_dir_path())
|
375 |
+
|
376 |
+
tb_writer.close()
|
377 |
+
|
378 |
+
except AssertionError as e:
|
379 |
+
logging.warning(e)
|
380 |
+
if wandb_enabled:
|
381 |
+
wandb_finish("Error")
|
382 |
+
return np.nan
|
383 |
+
finally:
|
384 |
+
env.close()
|
385 |
+
eval_env.close()
|
386 |
+
gc.collect()
|
387 |
+
torch.cuda.empty_cache()
|
388 |
+
|
389 |
+
d = {}
|
390 |
+
for idx, e in enumerate(evaluations):
|
391 |
+
d[f"{idx}/eval_mean"] = e.eval_stat.score.mean
|
392 |
+
d[f"{idx}/train_mean"] = e.train_stat.score.mean
|
393 |
+
d[f"{idx}/score"] = e.score
|
394 |
+
d["eval"] = np.mean([e.eval_stat.score.mean for e in evaluations]).item()
|
395 |
+
d["train"] = np.mean([e.train_stat.score.mean for e in evaluations]).item()
|
396 |
+
score = np.mean([e.score for e in evaluations]).item()
|
397 |
+
d["score"] = score
|
398 |
+
|
399 |
+
step = i + 1
|
400 |
+
wandb.log(d, step=step)
|
401 |
+
|
402 |
+
print(f"Trial #{trial.number} Step {step} Score: {round(score, 2)}")
|
403 |
+
trial.report(score, step)
|
404 |
+
if trial.should_prune():
|
405 |
+
if wandb_enabled:
|
406 |
+
wandb_finish("Pruned")
|
407 |
+
raise optuna.exceptions.TrialPruned()
|
408 |
+
|
409 |
+
if wandb_enabled:
|
410 |
+
wandb_finish("Complete")
|
411 |
+
return score
|
412 |
+
|
413 |
+
|
414 |
+
def wandb_finish(state: str) -> None:
|
415 |
+
wandb.run.summary["state"] = state # type: ignore
|
416 |
+
wandb.finish(quiet=True)
|
417 |
+
|
418 |
+
|
419 |
+
def optimize() -> None:
|
420 |
+
from pyvirtualdisplay.display import Display
|
421 |
+
|
422 |
+
train_args, study_args = parse_args()
|
423 |
+
if study_args.virtual_display:
|
424 |
+
virtual_display = Display(visible=False, size=(1400, 900))
|
425 |
+
virtual_display.start()
|
426 |
+
|
427 |
+
sampler = TPESampler(**TPESampler.hyperopt_parameters())
|
428 |
+
pruner = HyperbandPruner()
|
429 |
+
if study_args.load_study:
|
430 |
+
assert study_args.study_name
|
431 |
+
assert study_args.storage_path
|
432 |
+
study = optuna.load_study(
|
433 |
+
study_name=study_args.study_name,
|
434 |
+
storage=study_args.storage_path,
|
435 |
+
sampler=sampler,
|
436 |
+
pruner=pruner,
|
437 |
+
)
|
438 |
+
else:
|
439 |
+
study = optuna.create_study(
|
440 |
+
study_name=study_args.study_name,
|
441 |
+
storage=study_args.storage_path,
|
442 |
+
sampler=sampler,
|
443 |
+
pruner=pruner,
|
444 |
+
direction="maximize",
|
445 |
+
)
|
446 |
+
|
447 |
+
try:
|
448 |
+
study.optimize(
|
449 |
+
objective_fn(train_args, study_args),
|
450 |
+
n_trials=study_args.n_trials,
|
451 |
+
n_jobs=study_args.n_jobs,
|
452 |
+
timeout=study_args.timeout,
|
453 |
+
)
|
454 |
+
except KeyboardInterrupt:
|
455 |
+
pass
|
456 |
+
|
457 |
+
best = study.best_trial
|
458 |
+
print(f"Best Trial Value: {best.value}")
|
459 |
+
print("Attributes:")
|
460 |
+
for key, value in list(best.params.items()) + list(best.user_attrs.items()):
|
461 |
+
print(f" {key}: {value}")
|
462 |
+
|
463 |
+
df = study.trials_dataframe()
|
464 |
+
df = df[df.state == "COMPLETE"].sort_values(by=["value"], ascending=False)
|
465 |
+
print(df.to_markdown(index=False))
|
466 |
+
|
467 |
+
fig1 = plot_optimization_history(study)
|
468 |
+
fig1.write_image("opt_history.png")
|
469 |
+
|
470 |
+
fig2 = plot_param_importances(study)
|
471 |
+
fig2.write_image("param_importances.png")
|
472 |
+
|
473 |
+
|
474 |
+
if __name__ == "__main__":
|
475 |
+
optimize()
|
rl_algo_impls/ppo/ppo.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from dataclasses import asdict, dataclass
|
3 |
+
from time import perf_counter
|
4 |
+
from typing import List, NamedTuple, Optional, TypeVar
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch.optim import Adam
|
10 |
+
from torch.utils.tensorboard.writer import SummaryWriter
|
11 |
+
|
12 |
+
from rl_algo_impls.shared.algorithm import Algorithm
|
13 |
+
from rl_algo_impls.shared.callbacks import Callback
|
14 |
+
from rl_algo_impls.shared.gae import compute_advantages
|
15 |
+
from rl_algo_impls.shared.policy.actor_critic import ActorCritic
|
16 |
+
from rl_algo_impls.shared.schedule import (
|
17 |
+
constant_schedule,
|
18 |
+
linear_schedule,
|
19 |
+
schedule,
|
20 |
+
update_learning_rate,
|
21 |
+
)
|
22 |
+
from rl_algo_impls.shared.stats import log_scalars
|
23 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
24 |
+
VecEnv,
|
25 |
+
single_action_space,
|
26 |
+
single_observation_space,
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
class TrainStepStats(NamedTuple):
|
31 |
+
loss: float
|
32 |
+
pi_loss: float
|
33 |
+
v_loss: float
|
34 |
+
entropy_loss: float
|
35 |
+
approx_kl: float
|
36 |
+
clipped_frac: float
|
37 |
+
val_clipped_frac: float
|
38 |
+
|
39 |
+
|
40 |
+
@dataclass
|
41 |
+
class TrainStats:
|
42 |
+
loss: float
|
43 |
+
pi_loss: float
|
44 |
+
v_loss: float
|
45 |
+
entropy_loss: float
|
46 |
+
approx_kl: float
|
47 |
+
clipped_frac: float
|
48 |
+
val_clipped_frac: float
|
49 |
+
explained_var: float
|
50 |
+
|
51 |
+
def __init__(self, step_stats: List[TrainStepStats], explained_var: float) -> None:
|
52 |
+
self.loss = np.mean([s.loss for s in step_stats]).item()
|
53 |
+
self.pi_loss = np.mean([s.pi_loss for s in step_stats]).item()
|
54 |
+
self.v_loss = np.mean([s.v_loss for s in step_stats]).item()
|
55 |
+
self.entropy_loss = np.mean([s.entropy_loss for s in step_stats]).item()
|
56 |
+
self.approx_kl = np.mean([s.approx_kl for s in step_stats]).item()
|
57 |
+
self.clipped_frac = np.mean([s.clipped_frac for s in step_stats]).item()
|
58 |
+
self.val_clipped_frac = np.mean([s.val_clipped_frac for s in step_stats]).item()
|
59 |
+
self.explained_var = explained_var
|
60 |
+
|
61 |
+
def write_to_tensorboard(self, tb_writer: SummaryWriter, global_step: int) -> None:
|
62 |
+
for name, value in asdict(self).items():
|
63 |
+
tb_writer.add_scalar(f"losses/{name}", value, global_step=global_step)
|
64 |
+
|
65 |
+
def __repr__(self) -> str:
|
66 |
+
return " | ".join(
|
67 |
+
[
|
68 |
+
f"Loss: {round(self.loss, 2)}",
|
69 |
+
f"Pi L: {round(self.pi_loss, 2)}",
|
70 |
+
f"V L: {round(self.v_loss, 2)}",
|
71 |
+
f"E L: {round(self.entropy_loss, 2)}",
|
72 |
+
f"Apx KL Div: {round(self.approx_kl, 2)}",
|
73 |
+
f"Clip Frac: {round(self.clipped_frac, 2)}",
|
74 |
+
f"Val Clip Frac: {round(self.val_clipped_frac, 2)}",
|
75 |
+
]
|
76 |
+
)
|
77 |
+
|
78 |
+
|
79 |
+
PPOSelf = TypeVar("PPOSelf", bound="PPO")
|
80 |
+
|
81 |
+
|
82 |
+
class PPO(Algorithm):
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
policy: ActorCritic,
|
86 |
+
env: VecEnv,
|
87 |
+
device: torch.device,
|
88 |
+
tb_writer: SummaryWriter,
|
89 |
+
learning_rate: float = 3e-4,
|
90 |
+
learning_rate_decay: str = "none",
|
91 |
+
n_steps: int = 2048,
|
92 |
+
batch_size: int = 64,
|
93 |
+
n_epochs: int = 10,
|
94 |
+
gamma: float = 0.99,
|
95 |
+
gae_lambda: float = 0.95,
|
96 |
+
clip_range: float = 0.2,
|
97 |
+
clip_range_decay: str = "none",
|
98 |
+
clip_range_vf: Optional[float] = None,
|
99 |
+
clip_range_vf_decay: str = "none",
|
100 |
+
normalize_advantage: bool = True,
|
101 |
+
ent_coef: float = 0.0,
|
102 |
+
ent_coef_decay: str = "none",
|
103 |
+
vf_coef: float = 0.5,
|
104 |
+
ppo2_vf_coef_halving: bool = False,
|
105 |
+
max_grad_norm: float = 0.5,
|
106 |
+
sde_sample_freq: int = -1,
|
107 |
+
update_advantage_between_epochs: bool = True,
|
108 |
+
update_returns_between_epochs: bool = False,
|
109 |
+
gamma_end: Optional[float] = None,
|
110 |
+
) -> None:
|
111 |
+
super().__init__(policy, env, device, tb_writer)
|
112 |
+
self.policy = policy
|
113 |
+
self.get_action_mask = getattr(env, "get_action_mask")
|
114 |
+
|
115 |
+
self.gamma_schedule = (
|
116 |
+
linear_schedule(gamma, gamma_end)
|
117 |
+
if gamma_end is not None
|
118 |
+
else constant_schedule(gamma)
|
119 |
+
)
|
120 |
+
self.gae_lambda = gae_lambda
|
121 |
+
self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
|
122 |
+
self.lr_schedule = schedule(learning_rate_decay, learning_rate)
|
123 |
+
self.max_grad_norm = max_grad_norm
|
124 |
+
self.clip_range_schedule = schedule(clip_range_decay, clip_range)
|
125 |
+
self.clip_range_vf_schedule = None
|
126 |
+
if clip_range_vf:
|
127 |
+
self.clip_range_vf_schedule = schedule(clip_range_vf_decay, clip_range_vf)
|
128 |
+
|
129 |
+
if normalize_advantage:
|
130 |
+
assert (
|
131 |
+
env.num_envs * n_steps > 1 and batch_size > 1
|
132 |
+
), f"Each minibatch must be larger than 1 to support normalization"
|
133 |
+
self.normalize_advantage = normalize_advantage
|
134 |
+
|
135 |
+
self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
|
136 |
+
self.vf_coef = vf_coef
|
137 |
+
self.ppo2_vf_coef_halving = ppo2_vf_coef_halving
|
138 |
+
|
139 |
+
self.n_steps = n_steps
|
140 |
+
self.batch_size = batch_size
|
141 |
+
self.n_epochs = n_epochs
|
142 |
+
self.sde_sample_freq = sde_sample_freq
|
143 |
+
|
144 |
+
self.update_advantage_between_epochs = update_advantage_between_epochs
|
145 |
+
self.update_returns_between_epochs = update_returns_between_epochs
|
146 |
+
|
147 |
+
def learn(
|
148 |
+
self: PPOSelf,
|
149 |
+
train_timesteps: int,
|
150 |
+
callbacks: Optional[List[Callback]] = None,
|
151 |
+
total_timesteps: Optional[int] = None,
|
152 |
+
start_timesteps: int = 0,
|
153 |
+
) -> PPOSelf:
|
154 |
+
if total_timesteps is None:
|
155 |
+
total_timesteps = train_timesteps
|
156 |
+
assert start_timesteps + train_timesteps <= total_timesteps
|
157 |
+
|
158 |
+
epoch_dim = (self.n_steps, self.env.num_envs)
|
159 |
+
step_dim = (self.env.num_envs,)
|
160 |
+
obs_space = single_observation_space(self.env)
|
161 |
+
act_space = single_action_space(self.env)
|
162 |
+
act_shape = self.policy.action_shape
|
163 |
+
|
164 |
+
next_obs = self.env.reset()
|
165 |
+
next_action_masks = self.get_action_mask() if self.get_action_mask else None
|
166 |
+
next_episode_starts = np.full(step_dim, True, dtype=np.bool_)
|
167 |
+
|
168 |
+
obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype) # type: ignore
|
169 |
+
actions = np.zeros(epoch_dim + act_shape, dtype=act_space.dtype) # type: ignore
|
170 |
+
rewards = np.zeros(epoch_dim, dtype=np.float32)
|
171 |
+
episode_starts = np.zeros(epoch_dim, dtype=np.bool_)
|
172 |
+
values = np.zeros(epoch_dim, dtype=np.float32)
|
173 |
+
logprobs = np.zeros(epoch_dim, dtype=np.float32)
|
174 |
+
action_masks = (
|
175 |
+
np.zeros(
|
176 |
+
(self.n_steps,) + next_action_masks.shape, dtype=next_action_masks.dtype
|
177 |
+
)
|
178 |
+
if next_action_masks is not None
|
179 |
+
else None
|
180 |
+
)
|
181 |
+
|
182 |
+
timesteps_elapsed = start_timesteps
|
183 |
+
while timesteps_elapsed < start_timesteps + train_timesteps:
|
184 |
+
start_time = perf_counter()
|
185 |
+
|
186 |
+
progress = timesteps_elapsed / total_timesteps
|
187 |
+
ent_coef = self.ent_coef_schedule(progress)
|
188 |
+
learning_rate = self.lr_schedule(progress)
|
189 |
+
update_learning_rate(self.optimizer, learning_rate)
|
190 |
+
pi_clip = self.clip_range_schedule(progress)
|
191 |
+
gamma = self.gamma_schedule(progress)
|
192 |
+
chart_scalars = {
|
193 |
+
"learning_rate": self.optimizer.param_groups[0]["lr"],
|
194 |
+
"ent_coef": ent_coef,
|
195 |
+
"pi_clip": pi_clip,
|
196 |
+
"gamma": gamma,
|
197 |
+
}
|
198 |
+
if self.clip_range_vf_schedule:
|
199 |
+
v_clip = self.clip_range_vf_schedule(progress)
|
200 |
+
chart_scalars["v_clip"] = v_clip
|
201 |
+
else:
|
202 |
+
v_clip = None
|
203 |
+
log_scalars(self.tb_writer, "charts", chart_scalars, timesteps_elapsed)
|
204 |
+
|
205 |
+
self.policy.eval()
|
206 |
+
self.policy.reset_noise()
|
207 |
+
for s in range(self.n_steps):
|
208 |
+
timesteps_elapsed += self.env.num_envs
|
209 |
+
if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
|
210 |
+
self.policy.reset_noise()
|
211 |
+
|
212 |
+
obs[s] = next_obs
|
213 |
+
episode_starts[s] = next_episode_starts
|
214 |
+
if action_masks is not None:
|
215 |
+
action_masks[s] = next_action_masks
|
216 |
+
|
217 |
+
(
|
218 |
+
actions[s],
|
219 |
+
values[s],
|
220 |
+
logprobs[s],
|
221 |
+
clamped_action,
|
222 |
+
) = self.policy.step(next_obs, action_masks=next_action_masks)
|
223 |
+
next_obs, rewards[s], next_episode_starts, _ = self.env.step(
|
224 |
+
clamped_action
|
225 |
+
)
|
226 |
+
next_action_masks = (
|
227 |
+
self.get_action_mask() if self.get_action_mask else None
|
228 |
+
)
|
229 |
+
|
230 |
+
self.policy.train()
|
231 |
+
|
232 |
+
b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device) # type: ignore
|
233 |
+
b_actions = torch.tensor(actions.reshape((-1,) + act_shape)).to( # type: ignore
|
234 |
+
self.device
|
235 |
+
)
|
236 |
+
b_logprobs = torch.tensor(logprobs.reshape(-1)).to(self.device)
|
237 |
+
b_action_masks = (
|
238 |
+
torch.tensor(action_masks.reshape((-1,) + next_action_masks.shape[1:])).to( # type: ignore
|
239 |
+
self.device
|
240 |
+
)
|
241 |
+
if action_masks is not None
|
242 |
+
else None
|
243 |
+
)
|
244 |
+
|
245 |
+
y_pred = values.reshape(-1)
|
246 |
+
b_values = torch.tensor(y_pred).to(self.device)
|
247 |
+
|
248 |
+
step_stats = []
|
249 |
+
# Define variables that will definitely be set through the first epoch
|
250 |
+
advantages: np.ndarray = None # type: ignore
|
251 |
+
b_advantages: torch.Tensor = None # type: ignore
|
252 |
+
y_true: np.ndarray = None # type: ignore
|
253 |
+
b_returns: torch.Tensor = None # type: ignore
|
254 |
+
for e in range(self.n_epochs):
|
255 |
+
if e == 0 or self.update_advantage_between_epochs:
|
256 |
+
advantages = compute_advantages(
|
257 |
+
rewards,
|
258 |
+
values,
|
259 |
+
episode_starts,
|
260 |
+
next_episode_starts,
|
261 |
+
next_obs,
|
262 |
+
self.policy,
|
263 |
+
gamma,
|
264 |
+
self.gae_lambda,
|
265 |
+
)
|
266 |
+
b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
|
267 |
+
if e == 0 or self.update_returns_between_epochs:
|
268 |
+
returns = advantages + values
|
269 |
+
y_true = returns.reshape(-1)
|
270 |
+
b_returns = torch.tensor(y_true).to(self.device)
|
271 |
+
|
272 |
+
b_idxs = torch.randperm(len(b_obs))
|
273 |
+
# Only record last epoch's stats
|
274 |
+
step_stats.clear()
|
275 |
+
for i in range(0, len(b_obs), self.batch_size):
|
276 |
+
self.policy.reset_noise(self.batch_size)
|
277 |
+
|
278 |
+
mb_idxs = b_idxs[i : i + self.batch_size]
|
279 |
+
|
280 |
+
mb_obs = b_obs[mb_idxs]
|
281 |
+
mb_actions = b_actions[mb_idxs]
|
282 |
+
mb_values = b_values[mb_idxs]
|
283 |
+
mb_logprobs = b_logprobs[mb_idxs]
|
284 |
+
mb_action_masks = (
|
285 |
+
b_action_masks[mb_idxs] if b_action_masks is not None else None
|
286 |
+
)
|
287 |
+
|
288 |
+
mb_adv = b_advantages[mb_idxs]
|
289 |
+
if self.normalize_advantage:
|
290 |
+
mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)
|
291 |
+
mb_returns = b_returns[mb_idxs]
|
292 |
+
|
293 |
+
new_logprobs, entropy, new_values = self.policy(
|
294 |
+
mb_obs, mb_actions, action_masks=mb_action_masks
|
295 |
+
)
|
296 |
+
|
297 |
+
logratio = new_logprobs - mb_logprobs
|
298 |
+
ratio = torch.exp(logratio)
|
299 |
+
clipped_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
|
300 |
+
pi_loss = torch.max(-ratio * mb_adv, -clipped_ratio * mb_adv).mean()
|
301 |
+
|
302 |
+
v_loss_unclipped = (new_values - mb_returns) ** 2
|
303 |
+
if v_clip:
|
304 |
+
v_loss_clipped = (
|
305 |
+
mb_values
|
306 |
+
+ torch.clamp(new_values - mb_values, -v_clip, v_clip)
|
307 |
+
- mb_returns
|
308 |
+
) ** 2
|
309 |
+
v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
|
310 |
+
else:
|
311 |
+
v_loss = v_loss_unclipped.mean()
|
312 |
+
|
313 |
+
if self.ppo2_vf_coef_halving:
|
314 |
+
v_loss *= 0.5
|
315 |
+
|
316 |
+
entropy_loss = -entropy.mean()
|
317 |
+
|
318 |
+
loss = pi_loss + ent_coef * entropy_loss + self.vf_coef * v_loss
|
319 |
+
|
320 |
+
self.optimizer.zero_grad()
|
321 |
+
loss.backward()
|
322 |
+
nn.utils.clip_grad_norm_(
|
323 |
+
self.policy.parameters(), self.max_grad_norm
|
324 |
+
)
|
325 |
+
self.optimizer.step()
|
326 |
+
|
327 |
+
with torch.no_grad():
|
328 |
+
approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
|
329 |
+
clipped_frac = (
|
330 |
+
((ratio - 1).abs() > pi_clip)
|
331 |
+
.float()
|
332 |
+
.mean()
|
333 |
+
.cpu()
|
334 |
+
.numpy()
|
335 |
+
.item()
|
336 |
+
)
|
337 |
+
val_clipped_frac = (
|
338 |
+
((new_values - mb_values).abs() > v_clip)
|
339 |
+
.float()
|
340 |
+
.mean()
|
341 |
+
.cpu()
|
342 |
+
.numpy()
|
343 |
+
.item()
|
344 |
+
if v_clip
|
345 |
+
else 0
|
346 |
+
)
|
347 |
+
|
348 |
+
step_stats.append(
|
349 |
+
TrainStepStats(
|
350 |
+
loss.item(),
|
351 |
+
pi_loss.item(),
|
352 |
+
v_loss.item(),
|
353 |
+
entropy_loss.item(),
|
354 |
+
approx_kl,
|
355 |
+
clipped_frac,
|
356 |
+
val_clipped_frac,
|
357 |
+
)
|
358 |
+
)
|
359 |
+
|
360 |
+
var_y = np.var(y_true).item()
|
361 |
+
explained_var = (
|
362 |
+
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
|
363 |
+
)
|
364 |
+
TrainStats(step_stats, explained_var).write_to_tensorboard(
|
365 |
+
self.tb_writer, timesteps_elapsed
|
366 |
+
)
|
367 |
+
|
368 |
+
end_time = perf_counter()
|
369 |
+
rollout_steps = self.n_steps * self.env.num_envs
|
370 |
+
self.tb_writer.add_scalar(
|
371 |
+
"train/steps_per_second",
|
372 |
+
rollout_steps / (end_time - start_time),
|
373 |
+
timesteps_elapsed,
|
374 |
+
)
|
375 |
+
|
376 |
+
if callbacks:
|
377 |
+
if not all(
|
378 |
+
c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
|
379 |
+
):
|
380 |
+
logging.info(
|
381 |
+
f"Callback terminated training at {timesteps_elapsed} timesteps"
|
382 |
+
)
|
383 |
+
break
|
384 |
+
|
385 |
+
return self
|
rl_algo_impls/publish/markdown_format.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pandas as pd
|
3 |
+
import wandb.apis.public
|
4 |
+
import yaml
|
5 |
+
|
6 |
+
from collections import defaultdict
|
7 |
+
from dataclasses import dataclass, asdict
|
8 |
+
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, TypeVar
|
9 |
+
from urllib.parse import urlparse
|
10 |
+
|
11 |
+
from rl_algo_impls.runner.evaluate import Evaluation
|
12 |
+
|
13 |
+
EvaluationRowSelf = TypeVar("EvaluationRowSelf", bound="EvaluationRow")
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class EvaluationRow:
|
18 |
+
algo: str
|
19 |
+
env: str
|
20 |
+
seed: Optional[int]
|
21 |
+
reward_mean: float
|
22 |
+
reward_std: float
|
23 |
+
eval_episodes: int
|
24 |
+
best: str
|
25 |
+
wandb_url: str
|
26 |
+
|
27 |
+
@staticmethod
|
28 |
+
def data_frame(rows: List[EvaluationRowSelf]) -> pd.DataFrame:
|
29 |
+
results = defaultdict(list)
|
30 |
+
for r in rows:
|
31 |
+
for k, v in asdict(r).items():
|
32 |
+
results[k].append(v)
|
33 |
+
return pd.DataFrame(results)
|
34 |
+
|
35 |
+
|
36 |
+
class EvalTableData(NamedTuple):
|
37 |
+
run: wandb.apis.public.Run
|
38 |
+
evaluation: Evaluation
|
39 |
+
|
40 |
+
|
41 |
+
def evaluation_table(table_data: Iterable[EvalTableData]) -> str:
|
42 |
+
best_stats = sorted(
|
43 |
+
[d.evaluation.stats for d in table_data], key=lambda r: r.score, reverse=True
|
44 |
+
)[0]
|
45 |
+
table_data = sorted(table_data, key=lambda d: d.evaluation.config.seed() or 0)
|
46 |
+
rows = [
|
47 |
+
EvaluationRow(
|
48 |
+
config.algo,
|
49 |
+
config.env_id,
|
50 |
+
config.seed(),
|
51 |
+
stats.score.mean,
|
52 |
+
stats.score.std,
|
53 |
+
len(stats),
|
54 |
+
"*" if stats == best_stats else "",
|
55 |
+
f"[wandb]({r.url})",
|
56 |
+
)
|
57 |
+
for (r, (_, stats, config)) in table_data
|
58 |
+
]
|
59 |
+
df = EvaluationRow.data_frame(rows)
|
60 |
+
return df.to_markdown(index=False)
|
61 |
+
|
62 |
+
|
63 |
+
def github_project_link(github_url: str) -> str:
|
64 |
+
return f"[{urlparse(github_url).path}]({github_url})"
|
65 |
+
|
66 |
+
|
67 |
+
def header_section(algo: str, env: str, github_url: str, wandb_report_url: str) -> str:
|
68 |
+
algo_caps = algo.upper()
|
69 |
+
lines = [
|
70 |
+
f"# **{algo_caps}** Agent playing **{env}**",
|
71 |
+
f"This is a trained model of a **{algo_caps}** agent playing **{env}** using "
|
72 |
+
f"the {github_project_link(github_url)} repo.",
|
73 |
+
f"All models trained at this commit can be found at {wandb_report_url}.",
|
74 |
+
]
|
75 |
+
return "\n\n".join(lines)
|
76 |
+
|
77 |
+
|
78 |
+
def github_tree_link(github_url: str, commit_hash: Optional[str]) -> str:
|
79 |
+
if not commit_hash:
|
80 |
+
return github_project_link(github_url)
|
81 |
+
return f"[{commit_hash[:7]}]({github_url}/tree/{commit_hash})"
|
82 |
+
|
83 |
+
|
84 |
+
def results_section(
|
85 |
+
table_data: List[EvalTableData], algo: str, github_url: str, commit_hash: str
|
86 |
+
) -> str:
|
87 |
+
# type: ignore
|
88 |
+
lines = [
|
89 |
+
"## Training Results",
|
90 |
+
f"This model was trained from {len(table_data)} trainings of **{algo.upper()}** "
|
91 |
+
+ "agents using different initial seeds. "
|
92 |
+
+ f"These agents were trained by checking out "
|
93 |
+
+ f"{github_tree_link(github_url, commit_hash)}. "
|
94 |
+
+ "The best and last models were kept from each training. "
|
95 |
+
+ "This submission has loaded the best models from each training, reevaluates "
|
96 |
+
+ "them, and selects the best model from these latest evaluations (mean - std).",
|
97 |
+
]
|
98 |
+
lines.append(evaluation_table(table_data))
|
99 |
+
return "\n\n".join(lines)
|
100 |
+
|
101 |
+
|
102 |
+
def prerequisites_section() -> str:
|
103 |
+
return """
|
104 |
+
### Prerequisites: Weights & Biases (WandB)
|
105 |
+
Training and benchmarking assumes you have a Weights & Biases project to upload runs to.
|
106 |
+
By default training goes to a rl-algo-impls project while benchmarks go to
|
107 |
+
rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
|
108 |
+
models and the model weights are uploaded to WandB.
|
109 |
+
|
110 |
+
Before doing anything below, you'll need to create a wandb account and run `wandb
|
111 |
+
login`.
|
112 |
+
"""
|
113 |
+
|
114 |
+
|
115 |
+
def usage_section(github_url: str, run_path: str, commit_hash: str) -> str:
|
116 |
+
return f"""
|
117 |
+
## Usage
|
118 |
+
{urlparse(github_url).path}: {github_url}
|
119 |
+
|
120 |
+
Note: While the model state dictionary and hyperaparameters are saved, the latest
|
121 |
+
implementation could be sufficiently different to not be able to reproduce similar
|
122 |
+
results. You might need to checkout the commit the agent was trained on:
|
123 |
+
{github_tree_link(github_url, commit_hash)}.
|
124 |
+
```
|
125 |
+
# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
|
126 |
+
python enjoy.py --wandb-run-path={run_path}
|
127 |
+
```
|
128 |
+
|
129 |
+
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
130 |
+
Colab starting from the
|
131 |
+
[colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
|
132 |
+
notebook.
|
133 |
+
"""
|
134 |
+
|
135 |
+
|
136 |
+
def training_setion(
|
137 |
+
github_url: str, commit_hash: str, algo: str, env: str, seed: Optional[int]
|
138 |
+
) -> str:
|
139 |
+
return f"""
|
140 |
+
## Training
|
141 |
+
If you want the highest chance to reproduce these results, you'll want to checkout the
|
142 |
+
commit the agent was trained on: {github_tree_link(github_url, commit_hash)}. While
|
143 |
+
training is deterministic, different hardware will give different results.
|
144 |
+
|
145 |
+
```
|
146 |
+
python train.py --algo {algo} --env {env} {'--seed ' + str(seed) if seed is not None else ''}
|
147 |
+
```
|
148 |
+
|
149 |
+
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
150 |
+
Colab starting from the
|
151 |
+
[colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
|
152 |
+
notebook.
|
153 |
+
"""
|
154 |
+
|
155 |
+
|
156 |
+
def benchmarking_section(report_url: str) -> str:
|
157 |
+
return f"""
|
158 |
+
## Benchmarking (with Lambda Labs instance)
|
159 |
+
This and other models from {report_url} were generated by running a script on a Lambda
|
160 |
+
Labs instance. In a Lambda Labs instance terminal:
|
161 |
+
```
|
162 |
+
git clone git@github.com:sgoodfriend/rl-algo-impls.git
|
163 |
+
cd rl-algo-impls
|
164 |
+
bash ./lambda_labs/setup.sh
|
165 |
+
wandb login
|
166 |
+
bash ./lambda_labs/benchmark.sh [-a {{"ppo a2c dqn vpg"}}] [-e ENVS] [-j {{6}}] [-p {{rl-algo-impls-benchmarks}}] [-s {{"1 2 3"}}]
|
167 |
+
```
|
168 |
+
|
169 |
+
### Alternative: Google Colab Pro+
|
170 |
+
As an alternative,
|
171 |
+
[colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
|
172 |
+
can be used. However, this requires a Google Colab Pro+ subscription and running across
|
173 |
+
4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
|
174 |
+
"""
|
175 |
+
|
176 |
+
|
177 |
+
def hyperparams_section(run_config: Dict[str, Any]) -> str:
|
178 |
+
return f"""
|
179 |
+
## Hyperparameters
|
180 |
+
This isn't exactly the format of hyperparams in {os.path.join("hyperparams",
|
181 |
+
run_config["algo"] + ".yml")}, but instead the Wandb Run Config. However, it's very
|
182 |
+
close and has some additional data:
|
183 |
+
```
|
184 |
+
{yaml.dump(run_config)}
|
185 |
+
```
|
186 |
+
"""
|
187 |
+
|
188 |
+
|
189 |
+
def model_card_text(
|
190 |
+
algo: str,
|
191 |
+
env: str,
|
192 |
+
github_url: str,
|
193 |
+
commit_hash: str,
|
194 |
+
wandb_report_url: str,
|
195 |
+
table_data: List[EvalTableData],
|
196 |
+
best_eval: EvalTableData,
|
197 |
+
) -> str:
|
198 |
+
run, (_, _, config) = best_eval
|
199 |
+
run_path = "/".join(run.path)
|
200 |
+
return "\n\n".join(
|
201 |
+
[
|
202 |
+
header_section(algo, env, github_url, wandb_report_url),
|
203 |
+
results_section(table_data, algo, github_url, commit_hash),
|
204 |
+
prerequisites_section(),
|
205 |
+
usage_section(github_url, run_path, commit_hash),
|
206 |
+
training_setion(github_url, commit_hash, algo, env, config.seed()),
|
207 |
+
benchmarking_section(wandb_report_url),
|
208 |
+
hyperparams_section(run.config),
|
209 |
+
]
|
210 |
+
)
|
rl_algo_impls/runner/config.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import inspect
|
3 |
+
import itertools
|
4 |
+
import os
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from datetime import datetime
|
7 |
+
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
|
8 |
+
|
9 |
+
RunArgsSelf = TypeVar("RunArgsSelf", bound="RunArgs")
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class RunArgs:
|
14 |
+
algo: str
|
15 |
+
env: str
|
16 |
+
seed: Optional[int] = None
|
17 |
+
use_deterministic_algorithms: bool = True
|
18 |
+
|
19 |
+
@classmethod
|
20 |
+
def expand_from_dict(
|
21 |
+
cls: Type[RunArgsSelf], d: Dict[str, Any]
|
22 |
+
) -> List[RunArgsSelf]:
|
23 |
+
maybe_listify = lambda v: [v] if isinstance(v, str) or isinstance(v, int) else v
|
24 |
+
algos = maybe_listify(d["algo"])
|
25 |
+
envs = maybe_listify(d["env"])
|
26 |
+
seeds = maybe_listify(d["seed"])
|
27 |
+
args = []
|
28 |
+
for algo, env, seed in itertools.product(algos, envs, seeds):
|
29 |
+
_d = d.copy()
|
30 |
+
_d.update({"algo": algo, "env": env, "seed": seed})
|
31 |
+
args.append(cls(**_d))
|
32 |
+
return args
|
33 |
+
|
34 |
+
|
35 |
+
@dataclass
|
36 |
+
class EnvHyperparams:
|
37 |
+
env_type: str = "gymvec"
|
38 |
+
n_envs: int = 1
|
39 |
+
frame_stack: int = 1
|
40 |
+
make_kwargs: Optional[Dict[str, Any]] = None
|
41 |
+
no_reward_timeout_steps: Optional[int] = None
|
42 |
+
no_reward_fire_steps: Optional[int] = None
|
43 |
+
vec_env_class: str = "sync"
|
44 |
+
normalize: bool = False
|
45 |
+
normalize_kwargs: Optional[Dict[str, Any]] = None
|
46 |
+
rolling_length: int = 100
|
47 |
+
train_record_video: bool = False
|
48 |
+
video_step_interval: Union[int, float] = 1_000_000
|
49 |
+
initial_steps_to_truncate: Optional[int] = None
|
50 |
+
clip_atari_rewards: bool = True
|
51 |
+
normalize_type: Optional[str] = None
|
52 |
+
mask_actions: bool = False
|
53 |
+
bots: Optional[Dict[str, int]] = None
|
54 |
+
self_play_kwargs: Optional[Dict[str, Any]] = None
|
55 |
+
|
56 |
+
|
57 |
+
HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
|
58 |
+
|
59 |
+
|
60 |
+
@dataclass
|
61 |
+
class Hyperparams:
|
62 |
+
device: str = "auto"
|
63 |
+
n_timesteps: Union[int, float] = 100_000
|
64 |
+
env_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
65 |
+
policy_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
66 |
+
algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
67 |
+
eval_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
68 |
+
env_id: Optional[str] = None
|
69 |
+
additional_keys_to_log: List[str] = dataclasses.field(default_factory=list)
|
70 |
+
microrts_reward_decay_callback: bool = False
|
71 |
+
|
72 |
+
@classmethod
|
73 |
+
def from_dict_with_extra_fields(
|
74 |
+
cls: Type[HyperparamsSelf], d: Dict[str, Any]
|
75 |
+
) -> HyperparamsSelf:
|
76 |
+
return cls(
|
77 |
+
**{k: v for k, v in d.items() if k in inspect.signature(cls).parameters}
|
78 |
+
)
|
79 |
+
|
80 |
+
|
81 |
+
@dataclass
|
82 |
+
class Config:
|
83 |
+
args: RunArgs
|
84 |
+
hyperparams: Hyperparams
|
85 |
+
root_dir: str
|
86 |
+
run_id: str = datetime.now().isoformat()
|
87 |
+
|
88 |
+
def seed(self, training: bool = True) -> Optional[int]:
|
89 |
+
seed = self.args.seed
|
90 |
+
if training or seed is None:
|
91 |
+
return seed
|
92 |
+
return seed + self.env_hyperparams.get("n_envs", 1)
|
93 |
+
|
94 |
+
@property
|
95 |
+
def device(self) -> str:
|
96 |
+
return self.hyperparams.device
|
97 |
+
|
98 |
+
@property
|
99 |
+
def n_timesteps(self) -> int:
|
100 |
+
return int(self.hyperparams.n_timesteps)
|
101 |
+
|
102 |
+
@property
|
103 |
+
def env_hyperparams(self) -> Dict[str, Any]:
|
104 |
+
return self.hyperparams.env_hyperparams
|
105 |
+
|
106 |
+
@property
|
107 |
+
def policy_hyperparams(self) -> Dict[str, Any]:
|
108 |
+
return self.hyperparams.policy_hyperparams
|
109 |
+
|
110 |
+
@property
|
111 |
+
def algo_hyperparams(self) -> Dict[str, Any]:
|
112 |
+
return self.hyperparams.algo_hyperparams
|
113 |
+
|
114 |
+
@property
|
115 |
+
def eval_hyperparams(self) -> Dict[str, Any]:
|
116 |
+
return self.hyperparams.eval_hyperparams
|
117 |
+
|
118 |
+
def eval_callback_params(self) -> Dict[str, Any]:
|
119 |
+
eval_hyperparams = self.eval_hyperparams.copy()
|
120 |
+
if "env_overrides" in eval_hyperparams:
|
121 |
+
del eval_hyperparams["env_overrides"]
|
122 |
+
return eval_hyperparams
|
123 |
+
|
124 |
+
@property
|
125 |
+
def algo(self) -> str:
|
126 |
+
return self.args.algo
|
127 |
+
|
128 |
+
@property
|
129 |
+
def env_id(self) -> str:
|
130 |
+
return self.hyperparams.env_id or self.args.env
|
131 |
+
|
132 |
+
@property
|
133 |
+
def additional_keys_to_log(self) -> List[str]:
|
134 |
+
return self.hyperparams.additional_keys_to_log
|
135 |
+
|
136 |
+
def model_name(self, include_seed: bool = True) -> str:
|
137 |
+
# Use arg env name instead of environment name
|
138 |
+
parts = [self.algo, self.args.env]
|
139 |
+
if include_seed and self.args.seed is not None:
|
140 |
+
parts.append(f"S{self.args.seed}")
|
141 |
+
|
142 |
+
# Assume that the custom arg name already has the necessary information
|
143 |
+
if not self.hyperparams.env_id:
|
144 |
+
make_kwargs = self.env_hyperparams.get("make_kwargs", {})
|
145 |
+
if make_kwargs:
|
146 |
+
for k, v in make_kwargs.items():
|
147 |
+
if type(v) == bool and v:
|
148 |
+
parts.append(k)
|
149 |
+
elif type(v) == int and v:
|
150 |
+
parts.append(f"{k}{v}")
|
151 |
+
else:
|
152 |
+
parts.append(str(v))
|
153 |
+
|
154 |
+
return "-".join(parts)
|
155 |
+
|
156 |
+
def run_name(self, include_seed: bool = True) -> str:
|
157 |
+
parts = [self.model_name(include_seed=include_seed), self.run_id]
|
158 |
+
return "-".join(parts)
|
159 |
+
|
160 |
+
@property
|
161 |
+
def saved_models_dir(self) -> str:
|
162 |
+
return os.path.join(self.root_dir, "saved_models")
|
163 |
+
|
164 |
+
@property
|
165 |
+
def downloaded_models_dir(self) -> str:
|
166 |
+
return os.path.join(self.root_dir, "downloaded_models")
|
167 |
+
|
168 |
+
def model_dir_name(
|
169 |
+
self,
|
170 |
+
best: bool = False,
|
171 |
+
extension: str = "",
|
172 |
+
) -> str:
|
173 |
+
return self.model_name() + ("-best" if best else "") + extension
|
174 |
+
|
175 |
+
def model_dir_path(self, best: bool = False, downloaded: bool = False) -> str:
|
176 |
+
return os.path.join(
|
177 |
+
self.saved_models_dir if not downloaded else self.downloaded_models_dir,
|
178 |
+
self.model_dir_name(best=best),
|
179 |
+
)
|
180 |
+
|
181 |
+
@property
|
182 |
+
def runs_dir(self) -> str:
|
183 |
+
return os.path.join(self.root_dir, "runs")
|
184 |
+
|
185 |
+
@property
|
186 |
+
def tensorboard_summary_path(self) -> str:
|
187 |
+
return os.path.join(self.runs_dir, self.run_name())
|
188 |
+
|
189 |
+
@property
|
190 |
+
def logs_path(self) -> str:
|
191 |
+
return os.path.join(self.runs_dir, f"log.yml")
|
192 |
+
|
193 |
+
@property
|
194 |
+
def videos_dir(self) -> str:
|
195 |
+
return os.path.join(self.root_dir, "videos")
|
196 |
+
|
197 |
+
@property
|
198 |
+
def video_prefix(self) -> str:
|
199 |
+
return os.path.join(self.videos_dir, self.model_name())
|
200 |
+
|
201 |
+
@property
|
202 |
+
def best_videos_dir(self) -> str:
|
203 |
+
return os.path.join(self.videos_dir, f"{self.model_name()}-best")
|
rl_algo_impls/runner/evaluate.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import NamedTuple, Optional
|
5 |
+
|
6 |
+
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
|
7 |
+
from rl_algo_impls.runner.running_utils import (
|
8 |
+
get_device,
|
9 |
+
load_hyperparams,
|
10 |
+
make_policy,
|
11 |
+
set_seeds,
|
12 |
+
)
|
13 |
+
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
|
14 |
+
from rl_algo_impls.shared.policy.policy import Policy
|
15 |
+
from rl_algo_impls.shared.stats import EpisodesStats
|
16 |
+
from rl_algo_impls.shared.vec_env import make_eval_env
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class EvalArgs(RunArgs):
|
21 |
+
render: bool = True
|
22 |
+
best: bool = True
|
23 |
+
n_envs: Optional[int] = 1
|
24 |
+
n_episodes: int = 3
|
25 |
+
deterministic_eval: Optional[bool] = None
|
26 |
+
no_print_returns: bool = False
|
27 |
+
wandb_run_path: Optional[str] = None
|
28 |
+
|
29 |
+
|
30 |
+
class Evaluation(NamedTuple):
|
31 |
+
policy: Policy
|
32 |
+
stats: EpisodesStats
|
33 |
+
config: Config
|
34 |
+
|
35 |
+
|
36 |
+
def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
|
37 |
+
if args.wandb_run_path:
|
38 |
+
import wandb
|
39 |
+
|
40 |
+
api = wandb.Api()
|
41 |
+
run = api.run(args.wandb_run_path)
|
42 |
+
params = run.config
|
43 |
+
|
44 |
+
args.algo = params["algo"]
|
45 |
+
args.env = params["env"]
|
46 |
+
args.seed = params.get("seed", None)
|
47 |
+
args.use_deterministic_algorithms = params.get(
|
48 |
+
"use_deterministic_algorithms", True
|
49 |
+
)
|
50 |
+
|
51 |
+
config = Config(args, Hyperparams.from_dict_with_extra_fields(params), root_dir)
|
52 |
+
model_path = config.model_dir_path(best=args.best, downloaded=True)
|
53 |
+
|
54 |
+
model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
|
55 |
+
run.file(model_archive_name).download()
|
56 |
+
if os.path.isdir(model_path):
|
57 |
+
shutil.rmtree(model_path)
|
58 |
+
shutil.unpack_archive(model_archive_name, model_path)
|
59 |
+
os.remove(model_archive_name)
|
60 |
+
else:
|
61 |
+
hyperparams = load_hyperparams(args.algo, args.env)
|
62 |
+
|
63 |
+
config = Config(args, hyperparams, root_dir)
|
64 |
+
model_path = config.model_dir_path(best=args.best)
|
65 |
+
|
66 |
+
print(args)
|
67 |
+
|
68 |
+
set_seeds(args.seed, args.use_deterministic_algorithms)
|
69 |
+
|
70 |
+
env = make_eval_env(
|
71 |
+
config,
|
72 |
+
EnvHyperparams(**config.env_hyperparams),
|
73 |
+
override_n_envs=args.n_envs,
|
74 |
+
render=args.render,
|
75 |
+
normalize_load_path=model_path,
|
76 |
+
)
|
77 |
+
device = get_device(config, env)
|
78 |
+
policy = make_policy(
|
79 |
+
args.algo,
|
80 |
+
env,
|
81 |
+
device,
|
82 |
+
load_path=model_path,
|
83 |
+
**config.policy_hyperparams,
|
84 |
+
).eval()
|
85 |
+
|
86 |
+
deterministic = (
|
87 |
+
args.deterministic_eval
|
88 |
+
if args.deterministic_eval is not None
|
89 |
+
else config.eval_hyperparams.get("deterministic", True)
|
90 |
+
)
|
91 |
+
return Evaluation(
|
92 |
+
policy,
|
93 |
+
evaluate(
|
94 |
+
env,
|
95 |
+
policy,
|
96 |
+
args.n_episodes,
|
97 |
+
render=args.render,
|
98 |
+
deterministic=deterministic,
|
99 |
+
print_returns=not args.no_print_returns,
|
100 |
+
),
|
101 |
+
config,
|
102 |
+
)
|
rl_algo_impls/runner/running_utils.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
from dataclasses import asdict
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Dict, Optional, Type, Union
|
8 |
+
|
9 |
+
import gym
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.backends.cudnn
|
14 |
+
import yaml
|
15 |
+
from gym.spaces import Box, Discrete
|
16 |
+
from torch.utils.tensorboard.writer import SummaryWriter
|
17 |
+
|
18 |
+
from rl_algo_impls.a2c.a2c import A2C
|
19 |
+
from rl_algo_impls.dqn.dqn import DQN
|
20 |
+
from rl_algo_impls.dqn.policy import DQNPolicy
|
21 |
+
from rl_algo_impls.ppo.ppo import PPO
|
22 |
+
from rl_algo_impls.runner.config import Config, Hyperparams
|
23 |
+
from rl_algo_impls.shared.algorithm import Algorithm
|
24 |
+
from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
|
25 |
+
from rl_algo_impls.shared.policy.actor_critic import ActorCritic
|
26 |
+
from rl_algo_impls.shared.policy.policy import Policy
|
27 |
+
from rl_algo_impls.shared.vec_env.utils import import_for_env_id, is_microrts
|
28 |
+
from rl_algo_impls.vpg.policy import VPGActorCritic
|
29 |
+
from rl_algo_impls.vpg.vpg import VanillaPolicyGradient
|
30 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, single_observation_space
|
31 |
+
|
32 |
+
ALGOS: Dict[str, Type[Algorithm]] = {
|
33 |
+
"dqn": DQN,
|
34 |
+
"vpg": VanillaPolicyGradient,
|
35 |
+
"ppo": PPO,
|
36 |
+
"a2c": A2C,
|
37 |
+
}
|
38 |
+
POLICIES: Dict[str, Type[Policy]] = {
|
39 |
+
"dqn": DQNPolicy,
|
40 |
+
"vpg": VPGActorCritic,
|
41 |
+
"ppo": ActorCritic,
|
42 |
+
"a2c": ActorCritic,
|
43 |
+
}
|
44 |
+
|
45 |
+
HYPERPARAMS_PATH = "hyperparams"
|
46 |
+
|
47 |
+
|
48 |
+
def base_parser(multiple: bool = True) -> argparse.ArgumentParser:
|
49 |
+
parser = argparse.ArgumentParser()
|
50 |
+
parser.add_argument(
|
51 |
+
"--algo",
|
52 |
+
default=["dqn"],
|
53 |
+
type=str,
|
54 |
+
choices=list(ALGOS.keys()),
|
55 |
+
nargs="+" if multiple else 1,
|
56 |
+
help="Abbreviation(s) of algorithm(s)",
|
57 |
+
)
|
58 |
+
parser.add_argument(
|
59 |
+
"--env",
|
60 |
+
default=["CartPole-v1"],
|
61 |
+
type=str,
|
62 |
+
nargs="+" if multiple else 1,
|
63 |
+
help="Name of environment(s) in gym",
|
64 |
+
)
|
65 |
+
parser.add_argument(
|
66 |
+
"--seed",
|
67 |
+
default=[1],
|
68 |
+
type=int,
|
69 |
+
nargs="*" if multiple else "?",
|
70 |
+
help="Seeds to run experiment. Unset will do one run with no set seed",
|
71 |
+
)
|
72 |
+
return parser
|
73 |
+
|
74 |
+
|
75 |
+
def load_hyperparams(algo: str, env_id: str) -> Hyperparams:
|
76 |
+
root_path = Path(__file__).parent.parent
|
77 |
+
hyperparams_path = os.path.join(root_path, HYPERPARAMS_PATH, f"{algo}.yml")
|
78 |
+
with open(hyperparams_path, "r") as f:
|
79 |
+
hyperparams_dict = yaml.safe_load(f)
|
80 |
+
|
81 |
+
if env_id in hyperparams_dict:
|
82 |
+
return Hyperparams(**hyperparams_dict[env_id])
|
83 |
+
|
84 |
+
import_for_env_id(env_id)
|
85 |
+
spec = gym.spec(env_id)
|
86 |
+
entry_point_name = str(spec.entry_point) # type: ignore
|
87 |
+
if "AtariEnv" in entry_point_name and "_atari" in hyperparams_dict:
|
88 |
+
return Hyperparams(**hyperparams_dict["_atari"])
|
89 |
+
elif "gym_microrts" in entry_point_name and "_microrts" in hyperparams_dict:
|
90 |
+
return Hyperparams(**hyperparams_dict["_microrts"])
|
91 |
+
else:
|
92 |
+
raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
|
93 |
+
|
94 |
+
|
95 |
+
def get_device(config: Config, env: VecEnv) -> torch.device:
|
96 |
+
device = config.device
|
97 |
+
# cuda by default
|
98 |
+
if device == "auto":
|
99 |
+
device = "cuda"
|
100 |
+
# Apple MPS is a second choice (sometimes)
|
101 |
+
if device == "cuda" and not torch.cuda.is_available():
|
102 |
+
device = "mps"
|
103 |
+
# If no MPS, fallback to cpu
|
104 |
+
if device == "mps" and not torch.backends.mps.is_available():
|
105 |
+
device = "cpu"
|
106 |
+
# Simple environments like Discreet and 1-D Boxes might also be better
|
107 |
+
# served with the CPU.
|
108 |
+
if device == "mps":
|
109 |
+
obs_space = single_observation_space(env)
|
110 |
+
if isinstance(obs_space, Discrete):
|
111 |
+
device = "cpu"
|
112 |
+
elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
|
113 |
+
device = "cpu"
|
114 |
+
if is_microrts(config):
|
115 |
+
device = "cpu"
|
116 |
+
print(f"Device: {device}")
|
117 |
+
return torch.device(device)
|
118 |
+
|
119 |
+
|
120 |
+
def set_seeds(seed: Optional[int], use_deterministic_algorithms: bool) -> None:
|
121 |
+
if seed is None:
|
122 |
+
return
|
123 |
+
random.seed(seed)
|
124 |
+
np.random.seed(seed)
|
125 |
+
torch.manual_seed(seed)
|
126 |
+
torch.backends.cudnn.benchmark = False
|
127 |
+
torch.use_deterministic_algorithms(use_deterministic_algorithms)
|
128 |
+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
129 |
+
# Stop warning and it would introduce stochasticity if I was using TF
|
130 |
+
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
131 |
+
|
132 |
+
|
133 |
+
def make_policy(
|
134 |
+
algo: str,
|
135 |
+
env: VecEnv,
|
136 |
+
device: torch.device,
|
137 |
+
load_path: Optional[str] = None,
|
138 |
+
**kwargs,
|
139 |
+
) -> Policy:
|
140 |
+
policy = POLICIES[algo](env, **kwargs).to(device)
|
141 |
+
if load_path:
|
142 |
+
policy.load(load_path)
|
143 |
+
return policy
|
144 |
+
|
145 |
+
|
146 |
+
def plot_eval_callback(callback: EvalCallback, tb_writer: SummaryWriter, run_name: str):
|
147 |
+
figure = plt.figure()
|
148 |
+
cumulative_steps = [
|
149 |
+
(idx + 1) * callback.step_freq for idx in range(len(callback.stats))
|
150 |
+
]
|
151 |
+
plt.plot(
|
152 |
+
cumulative_steps,
|
153 |
+
[s.score.mean for s in callback.stats],
|
154 |
+
"b-",
|
155 |
+
label="mean",
|
156 |
+
)
|
157 |
+
plt.plot(
|
158 |
+
cumulative_steps,
|
159 |
+
[s.score.mean - s.score.std for s in callback.stats],
|
160 |
+
"g--",
|
161 |
+
label="mean-std",
|
162 |
+
)
|
163 |
+
plt.fill_between(
|
164 |
+
cumulative_steps,
|
165 |
+
[s.score.min for s in callback.stats], # type: ignore
|
166 |
+
[s.score.max for s in callback.stats], # type: ignore
|
167 |
+
facecolor="cyan",
|
168 |
+
label="range",
|
169 |
+
)
|
170 |
+
plt.xlabel("Steps")
|
171 |
+
plt.ylabel("Score")
|
172 |
+
plt.legend()
|
173 |
+
plt.title(f"Eval {run_name}")
|
174 |
+
tb_writer.add_figure("eval", figure)
|
175 |
+
|
176 |
+
|
177 |
+
Scalar = Union[bool, str, float, int, None]
|
178 |
+
|
179 |
+
|
180 |
+
def hparam_dict(
|
181 |
+
hyperparams: Hyperparams, args: Dict[str, Union[Scalar, list]]
|
182 |
+
) -> Dict[str, Scalar]:
|
183 |
+
flattened = args.copy()
|
184 |
+
for k, v in flattened.items():
|
185 |
+
if isinstance(v, list):
|
186 |
+
flattened[k] = json.dumps(v)
|
187 |
+
for k, v in asdict(hyperparams).items():
|
188 |
+
if isinstance(v, dict):
|
189 |
+
for sk, sv in v.items():
|
190 |
+
key = f"{k}/{sk}"
|
191 |
+
if isinstance(sv, dict) or isinstance(sv, list):
|
192 |
+
flattened[key] = str(sv)
|
193 |
+
else:
|
194 |
+
flattened[key] = sv
|
195 |
+
elif isinstance(v, list):
|
196 |
+
flattened[k] = json.dumps(v)
|
197 |
+
else:
|
198 |
+
flattened[k] = v # type: ignore
|
199 |
+
return flattened # type: ignore
|
rl_algo_impls/runner/train.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
|
2 |
+
import os
|
3 |
+
|
4 |
+
from rl_algo_impls.shared.callbacks import Callback
|
5 |
+
from rl_algo_impls.shared.callbacks.self_play_callback import SelfPlayCallback
|
6 |
+
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
|
7 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import find_wrapper
|
8 |
+
|
9 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
10 |
+
|
11 |
+
import dataclasses
|
12 |
+
import shutil
|
13 |
+
from dataclasses import asdict, dataclass
|
14 |
+
from typing import Any, Dict, List, Optional, Sequence
|
15 |
+
|
16 |
+
import yaml
|
17 |
+
from torch.utils.tensorboard.writer import SummaryWriter
|
18 |
+
|
19 |
+
import wandb
|
20 |
+
from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
|
21 |
+
from rl_algo_impls.runner.running_utils import (
|
22 |
+
ALGOS,
|
23 |
+
get_device,
|
24 |
+
hparam_dict,
|
25 |
+
load_hyperparams,
|
26 |
+
make_policy,
|
27 |
+
plot_eval_callback,
|
28 |
+
set_seeds,
|
29 |
+
)
|
30 |
+
from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
|
31 |
+
from rl_algo_impls.shared.callbacks.microrts_reward_decay_callback import (
|
32 |
+
MicrortsRewardDecayCallback,
|
33 |
+
)
|
34 |
+
from rl_algo_impls.shared.stats import EpisodesStats
|
35 |
+
from rl_algo_impls.shared.vec_env import make_env, make_eval_env
|
36 |
+
|
37 |
+
|
38 |
+
@dataclass
|
39 |
+
class TrainArgs(RunArgs):
|
40 |
+
wandb_project_name: Optional[str] = None
|
41 |
+
wandb_entity: Optional[str] = None
|
42 |
+
wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
|
43 |
+
wandb_group: Optional[str] = None
|
44 |
+
|
45 |
+
|
46 |
+
def train(args: TrainArgs):
|
47 |
+
print(args)
|
48 |
+
hyperparams = load_hyperparams(args.algo, args.env)
|
49 |
+
print(hyperparams)
|
50 |
+
config = Config(args, hyperparams, os.getcwd())
|
51 |
+
|
52 |
+
wandb_enabled = args.wandb_project_name
|
53 |
+
if wandb_enabled:
|
54 |
+
wandb.tensorboard.patch(
|
55 |
+
root_logdir=config.tensorboard_summary_path, pytorch=True
|
56 |
+
)
|
57 |
+
wandb.init(
|
58 |
+
project=args.wandb_project_name,
|
59 |
+
entity=args.wandb_entity,
|
60 |
+
config=asdict(hyperparams),
|
61 |
+
name=config.run_name(),
|
62 |
+
monitor_gym=True,
|
63 |
+
save_code=True,
|
64 |
+
tags=args.wandb_tags,
|
65 |
+
group=args.wandb_group,
|
66 |
+
)
|
67 |
+
wandb.config.update(args)
|
68 |
+
|
69 |
+
tb_writer = SummaryWriter(config.tensorboard_summary_path)
|
70 |
+
|
71 |
+
set_seeds(args.seed, args.use_deterministic_algorithms)
|
72 |
+
|
73 |
+
env = make_env(
|
74 |
+
config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
|
75 |
+
)
|
76 |
+
device = get_device(config, env)
|
77 |
+
policy_factory = lambda: make_policy(
|
78 |
+
args.algo, env, device, **config.policy_hyperparams
|
79 |
+
)
|
80 |
+
policy = policy_factory()
|
81 |
+
algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
|
82 |
+
|
83 |
+
num_parameters = policy.num_parameters()
|
84 |
+
num_trainable_parameters = policy.num_trainable_parameters()
|
85 |
+
if wandb_enabled:
|
86 |
+
wandb.run.summary["num_parameters"] = num_parameters # type: ignore
|
87 |
+
wandb.run.summary["num_trainable_parameters"] = num_trainable_parameters # type: ignore
|
88 |
+
else:
|
89 |
+
print(
|
90 |
+
f"num_parameters = {num_parameters} ; "
|
91 |
+
f"num_trainable_parameters = {num_trainable_parameters}"
|
92 |
+
)
|
93 |
+
|
94 |
+
eval_env = make_eval_env(config, EnvHyperparams(**config.env_hyperparams))
|
95 |
+
record_best_videos = config.eval_hyperparams.get("record_best_videos", True)
|
96 |
+
eval_callback = EvalCallback(
|
97 |
+
policy,
|
98 |
+
eval_env,
|
99 |
+
tb_writer,
|
100 |
+
best_model_path=config.model_dir_path(best=True),
|
101 |
+
**config.eval_callback_params(),
|
102 |
+
video_env=make_eval_env(
|
103 |
+
config, EnvHyperparams(**config.env_hyperparams), override_n_envs=1
|
104 |
+
)
|
105 |
+
if record_best_videos
|
106 |
+
else None,
|
107 |
+
best_video_dir=config.best_videos_dir,
|
108 |
+
additional_keys_to_log=config.additional_keys_to_log,
|
109 |
+
)
|
110 |
+
callbacks: List[Callback] = [eval_callback]
|
111 |
+
if config.hyperparams.microrts_reward_decay_callback:
|
112 |
+
callbacks.append(MicrortsRewardDecayCallback(config, env))
|
113 |
+
selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
|
114 |
+
if selfPlayWrapper:
|
115 |
+
callbacks.append(SelfPlayCallback(policy, policy_factory, selfPlayWrapper))
|
116 |
+
algo.learn(config.n_timesteps, callbacks=callbacks)
|
117 |
+
|
118 |
+
policy.save(config.model_dir_path(best=False))
|
119 |
+
|
120 |
+
eval_stats = eval_callback.evaluate(n_episodes=10, print_returns=True)
|
121 |
+
|
122 |
+
plot_eval_callback(eval_callback, tb_writer, config.run_name())
|
123 |
+
|
124 |
+
log_dict: Dict[str, Any] = {
|
125 |
+
"eval": eval_stats._asdict(),
|
126 |
+
}
|
127 |
+
if eval_callback.best:
|
128 |
+
log_dict["best_eval"] = eval_callback.best._asdict()
|
129 |
+
log_dict.update(asdict(hyperparams))
|
130 |
+
log_dict.update(vars(args))
|
131 |
+
with open(config.logs_path, "a") as f:
|
132 |
+
yaml.dump({config.run_name(): log_dict}, f)
|
133 |
+
|
134 |
+
best_eval_stats: EpisodesStats = eval_callback.best # type: ignore
|
135 |
+
tb_writer.add_hparams(
|
136 |
+
hparam_dict(hyperparams, vars(args)),
|
137 |
+
{
|
138 |
+
"hparam/best_mean": best_eval_stats.score.mean,
|
139 |
+
"hparam/best_result": best_eval_stats.score.mean
|
140 |
+
- best_eval_stats.score.std,
|
141 |
+
"hparam/last_mean": eval_stats.score.mean,
|
142 |
+
"hparam/last_result": eval_stats.score.mean - eval_stats.score.std,
|
143 |
+
},
|
144 |
+
None,
|
145 |
+
config.run_name(),
|
146 |
+
)
|
147 |
+
|
148 |
+
tb_writer.close()
|
149 |
+
|
150 |
+
if wandb_enabled:
|
151 |
+
shutil.make_archive(
|
152 |
+
os.path.join(wandb.run.dir, config.model_dir_name()),
|
153 |
+
"zip",
|
154 |
+
config.model_dir_path(),
|
155 |
+
)
|
156 |
+
shutil.make_archive(
|
157 |
+
os.path.join(wandb.run.dir, config.model_dir_name(best=True)),
|
158 |
+
"zip",
|
159 |
+
config.model_dir_path(best=True),
|
160 |
+
)
|
161 |
+
wandb.finish()
|
rl_algo_impls/shared/actor/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward
|
2 |
+
from rl_algo_impls.shared.actor.make_actor import actor_head
|
rl_algo_impls/shared/actor/actor.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import NamedTuple, Optional, Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.distributions import Distribution
|
8 |
+
|
9 |
+
|
10 |
+
class PiForward(NamedTuple):
|
11 |
+
pi: Distribution
|
12 |
+
logp_a: Optional[torch.Tensor]
|
13 |
+
entropy: Optional[torch.Tensor]
|
14 |
+
|
15 |
+
|
16 |
+
class Actor(nn.Module, ABC):
|
17 |
+
@abstractmethod
|
18 |
+
def forward(
|
19 |
+
self,
|
20 |
+
obs: torch.Tensor,
|
21 |
+
actions: Optional[torch.Tensor] = None,
|
22 |
+
action_masks: Optional[torch.Tensor] = None,
|
23 |
+
) -> PiForward:
|
24 |
+
...
|
25 |
+
|
26 |
+
def sample_weights(self, batch_size: int = 1) -> None:
|
27 |
+
pass
|
28 |
+
|
29 |
+
@property
|
30 |
+
@abstractmethod
|
31 |
+
def action_shape(self) -> Tuple[int, ...]:
|
32 |
+
...
|
33 |
+
|
34 |
+
|
35 |
+
def pi_forward(
|
36 |
+
distribution: Distribution, actions: Optional[torch.Tensor] = None
|
37 |
+
) -> PiForward:
|
38 |
+
logp_a = None
|
39 |
+
entropy = None
|
40 |
+
if actions is not None:
|
41 |
+
logp_a = distribution.log_prob(actions)
|
42 |
+
entropy = distribution.entropy()
|
43 |
+
return PiForward(distribution, logp_a, entropy)
|
rl_algo_impls/shared/actor/categorical.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Type
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.distributions import Categorical
|
6 |
+
|
7 |
+
from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward
|
8 |
+
from rl_algo_impls.shared.module.utils import mlp
|
9 |
+
|
10 |
+
|
11 |
+
class MaskedCategorical(Categorical):
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
probs=None,
|
15 |
+
logits=None,
|
16 |
+
validate_args=None,
|
17 |
+
mask: Optional[torch.Tensor] = None,
|
18 |
+
):
|
19 |
+
if mask is not None:
|
20 |
+
assert logits is not None, "mask requires logits and not probs"
|
21 |
+
logits = torch.where(mask, logits, -1e8)
|
22 |
+
self.mask = mask
|
23 |
+
super().__init__(probs, logits, validate_args)
|
24 |
+
|
25 |
+
def entropy(self) -> torch.Tensor:
|
26 |
+
if self.mask is None:
|
27 |
+
return super().entropy()
|
28 |
+
# If mask set, then use approximation for entropy
|
29 |
+
p_log_p = self.logits * self.probs # type: ignore
|
30 |
+
masked = torch.where(self.mask, p_log_p, 0)
|
31 |
+
return -masked.sum(-1)
|
32 |
+
|
33 |
+
|
34 |
+
class CategoricalActorHead(Actor):
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
act_dim: int,
|
38 |
+
in_dim: int,
|
39 |
+
hidden_sizes: Tuple[int, ...] = (32,),
|
40 |
+
activation: Type[nn.Module] = nn.Tanh,
|
41 |
+
init_layers_orthogonal: bool = True,
|
42 |
+
) -> None:
|
43 |
+
super().__init__()
|
44 |
+
layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
|
45 |
+
self._fc = mlp(
|
46 |
+
layer_sizes,
|
47 |
+
activation,
|
48 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
49 |
+
final_layer_gain=0.01,
|
50 |
+
)
|
51 |
+
|
52 |
+
def forward(
|
53 |
+
self,
|
54 |
+
obs: torch.Tensor,
|
55 |
+
actions: Optional[torch.Tensor] = None,
|
56 |
+
action_masks: Optional[torch.Tensor] = None,
|
57 |
+
) -> PiForward:
|
58 |
+
logits = self._fc(obs)
|
59 |
+
pi = MaskedCategorical(logits=logits, mask=action_masks)
|
60 |
+
return pi_forward(pi, actions)
|
61 |
+
|
62 |
+
@property
|
63 |
+
def action_shape(self) -> Tuple[int, ...]:
|
64 |
+
return ()
|
rl_algo_impls/shared/actor/gaussian.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Type
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.distributions import Distribution, Normal
|
6 |
+
|
7 |
+
from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward
|
8 |
+
from rl_algo_impls.shared.module.utils import mlp
|
9 |
+
|
10 |
+
|
11 |
+
class GaussianDistribution(Normal):
|
12 |
+
def log_prob(self, a: torch.Tensor) -> torch.Tensor:
|
13 |
+
return super().log_prob(a).sum(axis=-1)
|
14 |
+
|
15 |
+
def sample(self) -> torch.Tensor:
|
16 |
+
return self.rsample()
|
17 |
+
|
18 |
+
|
19 |
+
class GaussianActorHead(Actor):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
act_dim: int,
|
23 |
+
in_dim: int,
|
24 |
+
hidden_sizes: Tuple[int, ...] = (32,),
|
25 |
+
activation: Type[nn.Module] = nn.Tanh,
|
26 |
+
init_layers_orthogonal: bool = True,
|
27 |
+
log_std_init: float = -0.5,
|
28 |
+
) -> None:
|
29 |
+
super().__init__()
|
30 |
+
self.act_dim = act_dim
|
31 |
+
layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
|
32 |
+
self.mu_net = mlp(
|
33 |
+
layer_sizes,
|
34 |
+
activation,
|
35 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
36 |
+
final_layer_gain=0.01,
|
37 |
+
)
|
38 |
+
self.log_std = nn.Parameter(
|
39 |
+
torch.ones(act_dim, dtype=torch.float32) * log_std_init
|
40 |
+
)
|
41 |
+
|
42 |
+
def _distribution(self, obs: torch.Tensor) -> Distribution:
|
43 |
+
mu = self.mu_net(obs)
|
44 |
+
std = torch.exp(self.log_std)
|
45 |
+
return GaussianDistribution(mu, std)
|
46 |
+
|
47 |
+
def forward(
|
48 |
+
self,
|
49 |
+
obs: torch.Tensor,
|
50 |
+
actions: Optional[torch.Tensor] = None,
|
51 |
+
action_masks: Optional[torch.Tensor] = None,
|
52 |
+
) -> PiForward:
|
53 |
+
assert (
|
54 |
+
not action_masks
|
55 |
+
), f"{self.__class__.__name__} does not support action_masks"
|
56 |
+
pi = self._distribution(obs)
|
57 |
+
return pi_forward(pi, actions)
|
58 |
+
|
59 |
+
@property
|
60 |
+
def action_shape(self) -> Tuple[int, ...]:
|
61 |
+
return (self.act_dim,)
|
rl_algo_impls/shared/actor/gridnet.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional, Tuple, Type
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from numpy.typing import NDArray
|
7 |
+
from torch.distributions import Distribution, constraints
|
8 |
+
|
9 |
+
from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward
|
10 |
+
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
|
11 |
+
from rl_algo_impls.shared.encoder import EncoderOutDim
|
12 |
+
from rl_algo_impls.shared.module.utils import mlp
|
13 |
+
|
14 |
+
|
15 |
+
class GridnetDistribution(Distribution):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
map_size: int,
|
19 |
+
action_vec: NDArray[np.int64],
|
20 |
+
logits: torch.Tensor,
|
21 |
+
masks: torch.Tensor,
|
22 |
+
validate_args: Optional[bool] = None,
|
23 |
+
) -> None:
|
24 |
+
self.map_size = map_size
|
25 |
+
self.action_vec = action_vec
|
26 |
+
|
27 |
+
masks = masks.view(-1, masks.shape[-1])
|
28 |
+
split_masks = torch.split(masks, action_vec.tolist(), dim=1)
|
29 |
+
|
30 |
+
grid_logits = logits.reshape(-1, action_vec.sum())
|
31 |
+
split_logits = torch.split(grid_logits, action_vec.tolist(), dim=1)
|
32 |
+
self.categoricals = [
|
33 |
+
MaskedCategorical(logits=lg, validate_args=validate_args, mask=m)
|
34 |
+
for lg, m in zip(split_logits, split_masks)
|
35 |
+
]
|
36 |
+
|
37 |
+
batch_shape = logits.size()[:-1] if logits.ndimension() > 1 else torch.Size()
|
38 |
+
super().__init__(batch_shape=batch_shape, validate_args=validate_args)
|
39 |
+
|
40 |
+
def log_prob(self, action: torch.Tensor) -> torch.Tensor:
|
41 |
+
prob_stack = torch.stack(
|
42 |
+
[
|
43 |
+
c.log_prob(a)
|
44 |
+
for a, c in zip(action.view(-1, action.shape[-1]).T, self.categoricals)
|
45 |
+
],
|
46 |
+
dim=-1,
|
47 |
+
)
|
48 |
+
logprob = prob_stack.view(-1, self.map_size, len(self.action_vec))
|
49 |
+
return logprob.sum(dim=(1, 2))
|
50 |
+
|
51 |
+
def entropy(self) -> torch.Tensor:
|
52 |
+
ent = torch.stack([c.entropy() for c in self.categoricals], dim=-1)
|
53 |
+
ent = ent.view(-1, self.map_size, len(self.action_vec))
|
54 |
+
return ent.sum(dim=(1, 2))
|
55 |
+
|
56 |
+
def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
|
57 |
+
s = torch.stack([c.sample(sample_shape) for c in self.categoricals], dim=-1)
|
58 |
+
return s.view(-1, self.map_size, len(self.action_vec))
|
59 |
+
|
60 |
+
@property
|
61 |
+
def mode(self) -> torch.Tensor:
|
62 |
+
m = torch.stack([c.mode for c in self.categoricals], dim=-1)
|
63 |
+
return m.view(-1, self.map_size, len(self.action_vec))
|
64 |
+
|
65 |
+
@property
|
66 |
+
def arg_constraints(self) -> Dict[str, constraints.Constraint]:
|
67 |
+
# Constraints handled by child distributions in dist
|
68 |
+
return {}
|
69 |
+
|
70 |
+
|
71 |
+
class GridnetActorHead(Actor):
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
map_size: int,
|
75 |
+
action_vec: NDArray[np.int64],
|
76 |
+
in_dim: EncoderOutDim,
|
77 |
+
hidden_sizes: Tuple[int, ...] = (32,),
|
78 |
+
activation: Type[nn.Module] = nn.ReLU,
|
79 |
+
init_layers_orthogonal: bool = True,
|
80 |
+
) -> None:
|
81 |
+
super().__init__()
|
82 |
+
self.map_size = map_size
|
83 |
+
self.action_vec = action_vec
|
84 |
+
assert isinstance(in_dim, int)
|
85 |
+
layer_sizes = (in_dim,) + hidden_sizes + (map_size * action_vec.sum(),)
|
86 |
+
self._fc = mlp(
|
87 |
+
layer_sizes,
|
88 |
+
activation,
|
89 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
90 |
+
final_layer_gain=0.01,
|
91 |
+
)
|
92 |
+
|
93 |
+
def forward(
|
94 |
+
self,
|
95 |
+
obs: torch.Tensor,
|
96 |
+
actions: Optional[torch.Tensor] = None,
|
97 |
+
action_masks: Optional[torch.Tensor] = None,
|
98 |
+
) -> PiForward:
|
99 |
+
assert (
|
100 |
+
action_masks is not None
|
101 |
+
), f"No mask case unhandled in {self.__class__.__name__}"
|
102 |
+
logits = self._fc(obs)
|
103 |
+
pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
|
104 |
+
return pi_forward(pi, actions)
|
105 |
+
|
106 |
+
@property
|
107 |
+
def action_shape(self) -> Tuple[int, ...]:
|
108 |
+
return (self.map_size, len(self.action_vec))
|
rl_algo_impls/shared/actor/gridnet_decoder.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Type
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from numpy.typing import NDArray
|
7 |
+
|
8 |
+
from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward
|
9 |
+
from rl_algo_impls.shared.actor.gridnet import GridnetDistribution
|
10 |
+
from rl_algo_impls.shared.encoder import EncoderOutDim
|
11 |
+
from rl_algo_impls.shared.module.utils import layer_init
|
12 |
+
|
13 |
+
|
14 |
+
class Transpose(nn.Module):
|
15 |
+
def __init__(self, permutation: Tuple[int, ...]) -> None:
|
16 |
+
super().__init__()
|
17 |
+
self.permutation = permutation
|
18 |
+
|
19 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
20 |
+
return x.permute(self.permutation)
|
21 |
+
|
22 |
+
|
23 |
+
class GridnetDecoder(Actor):
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
map_size: int,
|
27 |
+
action_vec: NDArray[np.int64],
|
28 |
+
in_dim: EncoderOutDim,
|
29 |
+
activation: Type[nn.Module] = nn.ReLU,
|
30 |
+
init_layers_orthogonal: bool = True,
|
31 |
+
) -> None:
|
32 |
+
super().__init__()
|
33 |
+
self.map_size = map_size
|
34 |
+
self.action_vec = action_vec
|
35 |
+
assert isinstance(in_dim, tuple)
|
36 |
+
self.deconv = nn.Sequential(
|
37 |
+
layer_init(
|
38 |
+
nn.ConvTranspose2d(
|
39 |
+
in_dim[0], 128, 3, stride=2, padding=1, output_padding=1
|
40 |
+
),
|
41 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
42 |
+
),
|
43 |
+
activation(),
|
44 |
+
layer_init(
|
45 |
+
nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
|
46 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
47 |
+
),
|
48 |
+
activation(),
|
49 |
+
layer_init(
|
50 |
+
nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
|
51 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
52 |
+
),
|
53 |
+
activation(),
|
54 |
+
layer_init(
|
55 |
+
nn.ConvTranspose2d(
|
56 |
+
32, action_vec.sum(), 3, stride=2, padding=1, output_padding=1
|
57 |
+
),
|
58 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
59 |
+
std=0.01,
|
60 |
+
),
|
61 |
+
Transpose((0, 2, 3, 1)),
|
62 |
+
)
|
63 |
+
|
64 |
+
def forward(
|
65 |
+
self,
|
66 |
+
obs: torch.Tensor,
|
67 |
+
actions: Optional[torch.Tensor] = None,
|
68 |
+
action_masks: Optional[torch.Tensor] = None,
|
69 |
+
) -> PiForward:
|
70 |
+
assert (
|
71 |
+
action_masks is not None
|
72 |
+
), f"No mask case unhandled in {self.__class__.__name__}"
|
73 |
+
logits = self.deconv(obs)
|
74 |
+
pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
|
75 |
+
return pi_forward(pi, actions)
|
76 |
+
|
77 |
+
@property
|
78 |
+
def action_shape(self) -> Tuple[int, ...]:
|
79 |
+
return (self.map_size, len(self.action_vec))
|
rl_algo_impls/shared/actor/make_actor.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Type
|
2 |
+
|
3 |
+
import gym
|
4 |
+
import torch.nn as nn
|
5 |
+
from gym.spaces import Box, Discrete, MultiDiscrete
|
6 |
+
|
7 |
+
from rl_algo_impls.shared.actor.actor import Actor
|
8 |
+
from rl_algo_impls.shared.actor.categorical import CategoricalActorHead
|
9 |
+
from rl_algo_impls.shared.actor.gaussian import GaussianActorHead
|
10 |
+
from rl_algo_impls.shared.actor.gridnet import GridnetActorHead
|
11 |
+
from rl_algo_impls.shared.actor.gridnet_decoder import GridnetDecoder
|
12 |
+
from rl_algo_impls.shared.actor.multi_discrete import MultiDiscreteActorHead
|
13 |
+
from rl_algo_impls.shared.actor.state_dependent_noise import (
|
14 |
+
StateDependentNoiseActorHead,
|
15 |
+
)
|
16 |
+
from rl_algo_impls.shared.encoder import EncoderOutDim
|
17 |
+
|
18 |
+
|
19 |
+
def actor_head(
|
20 |
+
action_space: gym.Space,
|
21 |
+
in_dim: EncoderOutDim,
|
22 |
+
hidden_sizes: Tuple[int, ...],
|
23 |
+
init_layers_orthogonal: bool,
|
24 |
+
activation: Type[nn.Module],
|
25 |
+
log_std_init: float = -0.5,
|
26 |
+
use_sde: bool = False,
|
27 |
+
full_std: bool = True,
|
28 |
+
squash_output: bool = False,
|
29 |
+
actor_head_style: str = "single",
|
30 |
+
action_plane_space: Optional[bool] = None,
|
31 |
+
) -> Actor:
|
32 |
+
assert not use_sde or isinstance(
|
33 |
+
action_space, Box
|
34 |
+
), "use_sde only valid if Box action_space"
|
35 |
+
assert not squash_output or use_sde, "squash_output only valid if use_sde"
|
36 |
+
if isinstance(action_space, Discrete):
|
37 |
+
assert isinstance(in_dim, int)
|
38 |
+
return CategoricalActorHead(
|
39 |
+
action_space.n, # type: ignore
|
40 |
+
in_dim=in_dim,
|
41 |
+
hidden_sizes=hidden_sizes,
|
42 |
+
activation=activation,
|
43 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
44 |
+
)
|
45 |
+
elif isinstance(action_space, Box):
|
46 |
+
assert isinstance(in_dim, int)
|
47 |
+
if use_sde:
|
48 |
+
return StateDependentNoiseActorHead(
|
49 |
+
action_space.shape[0], # type: ignore
|
50 |
+
in_dim=in_dim,
|
51 |
+
hidden_sizes=hidden_sizes,
|
52 |
+
activation=activation,
|
53 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
54 |
+
log_std_init=log_std_init,
|
55 |
+
full_std=full_std,
|
56 |
+
squash_output=squash_output,
|
57 |
+
)
|
58 |
+
else:
|
59 |
+
return GaussianActorHead(
|
60 |
+
action_space.shape[0], # type: ignore
|
61 |
+
in_dim=in_dim,
|
62 |
+
hidden_sizes=hidden_sizes,
|
63 |
+
activation=activation,
|
64 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
65 |
+
log_std_init=log_std_init,
|
66 |
+
)
|
67 |
+
elif isinstance(action_space, MultiDiscrete):
|
68 |
+
if actor_head_style == "single":
|
69 |
+
return MultiDiscreteActorHead(
|
70 |
+
action_space.nvec, # type: ignore
|
71 |
+
in_dim=in_dim,
|
72 |
+
hidden_sizes=hidden_sizes,
|
73 |
+
activation=activation,
|
74 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
75 |
+
)
|
76 |
+
elif actor_head_style == "gridnet":
|
77 |
+
assert isinstance(action_plane_space, MultiDiscrete)
|
78 |
+
return GridnetActorHead(
|
79 |
+
len(action_space.nvec) // len(action_plane_space.nvec), # type: ignore
|
80 |
+
action_plane_space.nvec, # type: ignore
|
81 |
+
in_dim=in_dim,
|
82 |
+
hidden_sizes=hidden_sizes,
|
83 |
+
activation=activation,
|
84 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
85 |
+
)
|
86 |
+
elif actor_head_style == "gridnet_decoder":
|
87 |
+
assert isinstance(action_plane_space, MultiDiscrete)
|
88 |
+
return GridnetDecoder(
|
89 |
+
len(action_space.nvec) // len(action_plane_space.nvec), # type: ignore
|
90 |
+
action_plane_space.nvec, # type: ignore
|
91 |
+
in_dim=in_dim,
|
92 |
+
activation=activation,
|
93 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
94 |
+
)
|
95 |
+
else:
|
96 |
+
raise ValueError(f"Doesn't support actor_head_style {actor_head_style}")
|
97 |
+
else:
|
98 |
+
raise ValueError(f"Unsupported action space: {action_space}")
|
rl_algo_impls/shared/actor/multi_discrete.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional, Tuple, Type
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from numpy.typing import NDArray
|
7 |
+
from torch.distributions import Distribution, constraints
|
8 |
+
|
9 |
+
from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward
|
10 |
+
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
|
11 |
+
from rl_algo_impls.shared.encoder import EncoderOutDim
|
12 |
+
from rl_algo_impls.shared.module.utils import mlp
|
13 |
+
|
14 |
+
|
15 |
+
class MultiCategorical(Distribution):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
nvec: NDArray[np.int64],
|
19 |
+
probs=None,
|
20 |
+
logits=None,
|
21 |
+
validate_args=None,
|
22 |
+
masks: Optional[torch.Tensor] = None,
|
23 |
+
):
|
24 |
+
# Either probs or logits should be set
|
25 |
+
assert (probs is None) != (logits is None)
|
26 |
+
masks_split = (
|
27 |
+
torch.split(masks, nvec.tolist(), dim=1)
|
28 |
+
if masks is not None
|
29 |
+
else [None] * len(nvec)
|
30 |
+
)
|
31 |
+
if probs:
|
32 |
+
self.dists = [
|
33 |
+
MaskedCategorical(probs=p, validate_args=validate_args, mask=m)
|
34 |
+
for p, m in zip(torch.split(probs, nvec.tolist(), dim=1), masks_split)
|
35 |
+
]
|
36 |
+
param = probs
|
37 |
+
else:
|
38 |
+
assert logits is not None
|
39 |
+
self.dists = [
|
40 |
+
MaskedCategorical(logits=lg, validate_args=validate_args, mask=m)
|
41 |
+
for lg, m in zip(torch.split(logits, nvec.tolist(), dim=1), masks_split)
|
42 |
+
]
|
43 |
+
param = logits
|
44 |
+
batch_shape = param.size()[:-1] if param.ndimension() > 1 else torch.Size()
|
45 |
+
super().__init__(batch_shape=batch_shape, validate_args=validate_args)
|
46 |
+
|
47 |
+
def log_prob(self, action: torch.Tensor) -> torch.Tensor:
|
48 |
+
prob_stack = torch.stack(
|
49 |
+
[c.log_prob(a) for a, c in zip(action.T, self.dists)], dim=-1
|
50 |
+
)
|
51 |
+
return prob_stack.sum(dim=-1)
|
52 |
+
|
53 |
+
def entropy(self) -> torch.Tensor:
|
54 |
+
return torch.stack([c.entropy() for c in self.dists], dim=-1).sum(dim=-1)
|
55 |
+
|
56 |
+
def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
|
57 |
+
return torch.stack([c.sample(sample_shape) for c in self.dists], dim=-1)
|
58 |
+
|
59 |
+
@property
|
60 |
+
def mode(self) -> torch.Tensor:
|
61 |
+
return torch.stack([c.mode for c in self.dists], dim=-1)
|
62 |
+
|
63 |
+
@property
|
64 |
+
def arg_constraints(self) -> Dict[str, constraints.Constraint]:
|
65 |
+
# Constraints handled by child distributions in dist
|
66 |
+
return {}
|
67 |
+
|
68 |
+
|
69 |
+
class MultiDiscreteActorHead(Actor):
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
nvec: NDArray[np.int64],
|
73 |
+
in_dim: EncoderOutDim,
|
74 |
+
hidden_sizes: Tuple[int, ...] = (32,),
|
75 |
+
activation: Type[nn.Module] = nn.ReLU,
|
76 |
+
init_layers_orthogonal: bool = True,
|
77 |
+
) -> None:
|
78 |
+
super().__init__()
|
79 |
+
self.nvec = nvec
|
80 |
+
assert isinstance(in_dim, int)
|
81 |
+
layer_sizes = (in_dim,) + hidden_sizes + (nvec.sum(),)
|
82 |
+
self._fc = mlp(
|
83 |
+
layer_sizes,
|
84 |
+
activation,
|
85 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
86 |
+
final_layer_gain=0.01,
|
87 |
+
)
|
88 |
+
|
89 |
+
def forward(
|
90 |
+
self,
|
91 |
+
obs: torch.Tensor,
|
92 |
+
actions: Optional[torch.Tensor] = None,
|
93 |
+
action_masks: Optional[torch.Tensor] = None,
|
94 |
+
) -> PiForward:
|
95 |
+
logits = self._fc(obs)
|
96 |
+
pi = MultiCategorical(self.nvec, logits=logits, masks=action_masks)
|
97 |
+
return pi_forward(pi, actions)
|
98 |
+
|
99 |
+
@property
|
100 |
+
def action_shape(self) -> Tuple[int, ...]:
|
101 |
+
return (len(self.nvec),)
|
rl_algo_impls/shared/actor/state_dependent_noise.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Type, TypeVar, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.distributions import Distribution, Normal
|
6 |
+
|
7 |
+
from rl_algo_impls.shared.actor.actor import Actor, PiForward
|
8 |
+
from rl_algo_impls.shared.module.utils import mlp
|
9 |
+
|
10 |
+
|
11 |
+
class TanhBijector:
|
12 |
+
def __init__(self, epsilon: float = 1e-6) -> None:
|
13 |
+
self.epsilon = epsilon
|
14 |
+
|
15 |
+
@staticmethod
|
16 |
+
def forward(x: torch.Tensor) -> torch.Tensor:
|
17 |
+
return torch.tanh(x)
|
18 |
+
|
19 |
+
@staticmethod
|
20 |
+
def inverse(y: torch.Tensor) -> torch.Tensor:
|
21 |
+
eps = torch.finfo(y.dtype).eps
|
22 |
+
clamped_y = y.clamp(min=-1.0 + eps, max=1.0 - eps)
|
23 |
+
return torch.atanh(clamped_y)
|
24 |
+
|
25 |
+
def log_prob_correction(self, x: torch.Tensor) -> torch.Tensor:
|
26 |
+
return torch.log(1.0 - torch.tanh(x) ** 2 + self.epsilon)
|
27 |
+
|
28 |
+
|
29 |
+
def sum_independent_dims(tensor: torch.Tensor) -> torch.Tensor:
|
30 |
+
if len(tensor.shape) > 1:
|
31 |
+
return tensor.sum(dim=1)
|
32 |
+
return tensor.sum()
|
33 |
+
|
34 |
+
|
35 |
+
class StateDependentNoiseDistribution(Normal):
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
loc,
|
39 |
+
scale,
|
40 |
+
latent_sde: torch.Tensor,
|
41 |
+
exploration_mat: torch.Tensor,
|
42 |
+
exploration_matrices: torch.Tensor,
|
43 |
+
bijector: Optional[TanhBijector] = None,
|
44 |
+
validate_args=None,
|
45 |
+
):
|
46 |
+
super().__init__(loc, scale, validate_args)
|
47 |
+
self.latent_sde = latent_sde
|
48 |
+
self.exploration_mat = exploration_mat
|
49 |
+
self.exploration_matrices = exploration_matrices
|
50 |
+
self.bijector = bijector
|
51 |
+
|
52 |
+
def log_prob(self, a: torch.Tensor) -> torch.Tensor:
|
53 |
+
gaussian_a = self.bijector.inverse(a) if self.bijector else a
|
54 |
+
log_prob = sum_independent_dims(super().log_prob(gaussian_a))
|
55 |
+
if self.bijector:
|
56 |
+
log_prob -= torch.sum(self.bijector.log_prob_correction(gaussian_a), dim=1)
|
57 |
+
return log_prob
|
58 |
+
|
59 |
+
def sample(self) -> torch.Tensor:
|
60 |
+
noise = self._get_noise()
|
61 |
+
actions = self.mean + noise
|
62 |
+
return self.bijector.forward(actions) if self.bijector else actions
|
63 |
+
|
64 |
+
def _get_noise(self) -> torch.Tensor:
|
65 |
+
if len(self.latent_sde) == 1 or len(self.latent_sde) != len(
|
66 |
+
self.exploration_matrices
|
67 |
+
):
|
68 |
+
return torch.mm(self.latent_sde, self.exploration_mat)
|
69 |
+
# (batch_size, n_features) -> (batch_size, 1, n_features)
|
70 |
+
latent_sde = self.latent_sde.unsqueeze(dim=1)
|
71 |
+
# (batch_size, 1, n_actions)
|
72 |
+
noise = torch.bmm(latent_sde, self.exploration_matrices)
|
73 |
+
return noise.squeeze(dim=1)
|
74 |
+
|
75 |
+
@property
|
76 |
+
def mode(self) -> torch.Tensor:
|
77 |
+
mean = super().mode
|
78 |
+
return self.bijector.forward(mean) if self.bijector else mean
|
79 |
+
|
80 |
+
|
81 |
+
StateDependentNoiseActorHeadSelf = TypeVar(
|
82 |
+
"StateDependentNoiseActorHeadSelf", bound="StateDependentNoiseActorHead"
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
class StateDependentNoiseActorHead(Actor):
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
act_dim: int,
|
90 |
+
in_dim: int,
|
91 |
+
hidden_sizes: Tuple[int, ...] = (32,),
|
92 |
+
activation: Type[nn.Module] = nn.Tanh,
|
93 |
+
init_layers_orthogonal: bool = True,
|
94 |
+
log_std_init: float = -0.5,
|
95 |
+
full_std: bool = True,
|
96 |
+
squash_output: bool = False,
|
97 |
+
learn_std: bool = False,
|
98 |
+
) -> None:
|
99 |
+
super().__init__()
|
100 |
+
self.act_dim = act_dim
|
101 |
+
layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
|
102 |
+
if len(layer_sizes) == 2:
|
103 |
+
self.latent_net = nn.Identity()
|
104 |
+
elif len(layer_sizes) > 2:
|
105 |
+
self.latent_net = mlp(
|
106 |
+
layer_sizes[:-1],
|
107 |
+
activation,
|
108 |
+
output_activation=activation,
|
109 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
110 |
+
)
|
111 |
+
self.mu_net = mlp(
|
112 |
+
layer_sizes[-2:],
|
113 |
+
activation,
|
114 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
115 |
+
final_layer_gain=0.01,
|
116 |
+
)
|
117 |
+
self.full_std = full_std
|
118 |
+
std_dim = (layer_sizes[-2], act_dim if self.full_std else 1)
|
119 |
+
self.log_std = nn.Parameter(
|
120 |
+
torch.ones(std_dim, dtype=torch.float32) * log_std_init
|
121 |
+
)
|
122 |
+
self.bijector = TanhBijector() if squash_output else None
|
123 |
+
self.learn_std = learn_std
|
124 |
+
self.device = None
|
125 |
+
|
126 |
+
self.exploration_mat = None
|
127 |
+
self.exploration_matrices = None
|
128 |
+
self.sample_weights()
|
129 |
+
|
130 |
+
def to(
|
131 |
+
self: StateDependentNoiseActorHeadSelf,
|
132 |
+
device: Optional[torch.device] = None,
|
133 |
+
dtype: Optional[Union[torch.dtype, str]] = None,
|
134 |
+
non_blocking: bool = False,
|
135 |
+
) -> StateDependentNoiseActorHeadSelf:
|
136 |
+
super().to(device, dtype, non_blocking)
|
137 |
+
self.device = device
|
138 |
+
return self
|
139 |
+
|
140 |
+
def _distribution(self, obs: torch.Tensor) -> Distribution:
|
141 |
+
latent = self.latent_net(obs)
|
142 |
+
mu = self.mu_net(latent)
|
143 |
+
latent_sde = latent if self.learn_std else latent.detach()
|
144 |
+
variance = torch.mm(latent_sde**2, self._get_std() ** 2)
|
145 |
+
assert self.exploration_mat is not None
|
146 |
+
assert self.exploration_matrices is not None
|
147 |
+
return StateDependentNoiseDistribution(
|
148 |
+
mu,
|
149 |
+
torch.sqrt(variance + 1e-6),
|
150 |
+
latent_sde,
|
151 |
+
self.exploration_mat,
|
152 |
+
self.exploration_matrices,
|
153 |
+
self.bijector,
|
154 |
+
)
|
155 |
+
|
156 |
+
def _get_std(self) -> torch.Tensor:
|
157 |
+
std = torch.exp(self.log_std)
|
158 |
+
if self.full_std:
|
159 |
+
return std
|
160 |
+
ones = torch.ones(self.log_std.shape[0], self.act_dim)
|
161 |
+
if self.device:
|
162 |
+
ones = ones.to(self.device)
|
163 |
+
return ones * std
|
164 |
+
|
165 |
+
def forward(
|
166 |
+
self,
|
167 |
+
obs: torch.Tensor,
|
168 |
+
actions: Optional[torch.Tensor] = None,
|
169 |
+
action_masks: Optional[torch.Tensor] = None,
|
170 |
+
) -> PiForward:
|
171 |
+
assert (
|
172 |
+
not action_masks
|
173 |
+
), f"{self.__class__.__name__} does not support action_masks"
|
174 |
+
pi = self._distribution(obs)
|
175 |
+
return pi_forward(pi, actions)
|
176 |
+
|
177 |
+
def sample_weights(self, batch_size: int = 1) -> None:
|
178 |
+
std = self._get_std()
|
179 |
+
weights_dist = Normal(torch.zeros_like(std), std)
|
180 |
+
# Reparametrization trick to pass gradients
|
181 |
+
self.exploration_mat = weights_dist.rsample()
|
182 |
+
self.exploration_matrices = weights_dist.rsample(torch.Size((batch_size,)))
|
183 |
+
|
184 |
+
@property
|
185 |
+
def action_shape(self) -> Tuple[int, ...]:
|
186 |
+
return (self.act_dim,)
|
187 |
+
|
188 |
+
|
189 |
+
def pi_forward(
|
190 |
+
distribution: Distribution, actions: Optional[torch.Tensor] = None
|
191 |
+
) -> PiForward:
|
192 |
+
logp_a = None
|
193 |
+
entropy = None
|
194 |
+
if actions is not None:
|
195 |
+
logp_a = distribution.log_prob(actions)
|
196 |
+
entropy = (
|
197 |
+
-logp_a if self.bijector else sum_independent_dims(distribution.entropy())
|
198 |
+
)
|
199 |
+
return PiForward(distribution, logp_a, entropy)
|
rl_algo_impls/shared/algorithm.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import List, Optional, TypeVar
|
3 |
+
|
4 |
+
import gym
|
5 |
+
import torch
|
6 |
+
from torch.utils.tensorboard.writer import SummaryWriter
|
7 |
+
|
8 |
+
from rl_algo_impls.shared.callbacks import Callback
|
9 |
+
from rl_algo_impls.shared.policy.policy import Policy
|
10 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
|
11 |
+
|
12 |
+
AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")
|
13 |
+
|
14 |
+
|
15 |
+
class Algorithm(ABC):
|
16 |
+
@abstractmethod
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
policy: Policy,
|
20 |
+
env: VecEnv,
|
21 |
+
device: torch.device,
|
22 |
+
tb_writer: SummaryWriter,
|
23 |
+
**kwargs,
|
24 |
+
) -> None:
|
25 |
+
super().__init__()
|
26 |
+
self.policy = policy
|
27 |
+
self.env = env
|
28 |
+
self.device = device
|
29 |
+
self.tb_writer = tb_writer
|
30 |
+
|
31 |
+
@abstractmethod
|
32 |
+
def learn(
|
33 |
+
self: AlgorithmSelf,
|
34 |
+
train_timesteps: int,
|
35 |
+
callbacks: Optional[List[Callback]] = None,
|
36 |
+
total_timesteps: Optional[int] = None,
|
37 |
+
start_timesteps: int = 0,
|
38 |
+
) -> AlgorithmSelf:
|
39 |
+
...
|