sgoodfriend commited on
Commit
7058eb5
·
1 Parent(s): 2ee8433

PPO playing MicrortsDefeatCoacAIShaped-v3 from https://github.com/sgoodfriend/rl-algo-impls/tree/4706d8dbb99b38e70d080c3de68d0751ea585a2f

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +147 -0
  2. LICENSE +21 -0
  3. README.md +206 -0
  4. benchmark_publish.py +4 -0
  5. colab/colab_atari1.sh +4 -0
  6. colab/colab_atari2.sh +4 -0
  7. colab/colab_basic.sh +4 -0
  8. colab/colab_benchmark.ipynb +195 -0
  9. colab/colab_carracing.sh +4 -0
  10. colab/colab_enjoy.ipynb +198 -0
  11. colab/colab_pybullet.sh +4 -0
  12. colab/colab_train.ipynb +200 -0
  13. compare_runs.py +4 -0
  14. enjoy.py +4 -0
  15. environment.yml +12 -0
  16. huggingface_publish.py +4 -0
  17. optimize.py +4 -0
  18. pyproject.toml +88 -0
  19. replay.meta.json +1 -0
  20. replay.mp4 +0 -0
  21. rl_algo_impls/a2c/a2c.py +205 -0
  22. rl_algo_impls/a2c/optimize.py +77 -0
  23. rl_algo_impls/benchmark_publish.py +111 -0
  24. rl_algo_impls/compare_runs.py +199 -0
  25. rl_algo_impls/dqn/dqn.py +189 -0
  26. rl_algo_impls/dqn/policy.py +62 -0
  27. rl_algo_impls/dqn/q_net.py +41 -0
  28. rl_algo_impls/enjoy.py +35 -0
  29. rl_algo_impls/huggingface_publish.py +193 -0
  30. rl_algo_impls/hyperparams/a2c.yml +142 -0
  31. rl_algo_impls/hyperparams/dqn.yml +130 -0
  32. rl_algo_impls/hyperparams/ppo.yml +616 -0
  33. rl_algo_impls/hyperparams/vpg.yml +197 -0
  34. rl_algo_impls/optimize.py +475 -0
  35. rl_algo_impls/ppo/ppo.py +385 -0
  36. rl_algo_impls/publish/markdown_format.py +210 -0
  37. rl_algo_impls/runner/config.py +203 -0
  38. rl_algo_impls/runner/evaluate.py +102 -0
  39. rl_algo_impls/runner/running_utils.py +199 -0
  40. rl_algo_impls/runner/train.py +161 -0
  41. rl_algo_impls/shared/actor/__init__.py +2 -0
  42. rl_algo_impls/shared/actor/actor.py +43 -0
  43. rl_algo_impls/shared/actor/categorical.py +64 -0
  44. rl_algo_impls/shared/actor/gaussian.py +61 -0
  45. rl_algo_impls/shared/actor/gridnet.py +108 -0
  46. rl_algo_impls/shared/actor/gridnet_decoder.py +79 -0
  47. rl_algo_impls/shared/actor/make_actor.py +98 -0
  48. rl_algo_impls/shared/actor/multi_discrete.py +101 -0
  49. rl_algo_impls/shared/actor/state_dependent_noise.py +199 -0
  50. rl_algo_impls/shared/algorithm.py +39 -0
.gitignore ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ # Logging into tensorboard and wandb
132
+ runs/*
133
+ wandb
134
+
135
+ # macOS
136
+ .DS_STORE
137
+
138
+ # Local scratch work
139
+ scratch/*
140
+
141
+ # vscode
142
+ .vscode/
143
+
144
+ # Don't bother tracking saved_models or videos
145
+ saved_models/*
146
+ downloaded_models/*
147
+ videos/*
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Scott Goodfriend
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: rl-algo-impls
3
+ tags:
4
+ - MicrortsDefeatCoacAIShaped-v3
5
+ - ppo
6
+ - deep-reinforcement-learning
7
+ - reinforcement-learning
8
+ model-index:
9
+ - name: ppo
10
+ results:
11
+ - metrics:
12
+ - type: mean_reward
13
+ value: 0.69 +/- 0.72
14
+ name: mean_reward
15
+ task:
16
+ type: reinforcement-learning
17
+ name: reinforcement-learning
18
+ dataset:
19
+ name: MicrortsDefeatCoacAIShaped-v3
20
+ type: MicrortsDefeatCoacAIShaped-v3
21
+ ---
22
+ # **PPO** Agent playing **MicrortsDefeatCoacAIShaped-v3**
23
+
24
+ This is a trained model of a **PPO** agent playing **MicrortsDefeatCoacAIShaped-v3** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
25
+
26
+ All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/lf7j0hrv.
27
+
28
+ ## Training Results
29
+
30
+ This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [4706d8d](https://github.com/sgoodfriend/rl-algo-impls/tree/4706d8dbb99b38e70d080c3de68d0751ea585a2f). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
31
+
32
+ | algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
33
+ |:-------|:------------------------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
34
+ | ppo | MicrortsDefeatCoacAIShaped-v3 | 1 | 0.461538 | 0.88712 | 26 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/arhv1foe) |
35
+ | ppo | MicrortsDefeatCoacAIShaped-v3 | 2 | 0.461538 | 0.84265 | 26 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/kd89zf31) |
36
+ | ppo | MicrortsDefeatCoacAIShaped-v3 | 3 | 0.692308 | 0.721602 | 26 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/1ak14nj4) |
37
+
38
+
39
+ ### Prerequisites: Weights & Biases (WandB)
40
+ Training and benchmarking assumes you have a Weights & Biases project to upload runs to.
41
+ By default training goes to a rl-algo-impls project while benchmarks go to
42
+ rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
43
+ models and the model weights are uploaded to WandB.
44
+
45
+ Before doing anything below, you'll need to create a wandb account and run `wandb
46
+ login`.
47
+
48
+
49
+
50
+ ## Usage
51
+ /sgoodfriend/rl-algo-impls: https://github.com/sgoodfriend/rl-algo-impls
52
+
53
+ Note: While the model state dictionary and hyperaparameters are saved, the latest
54
+ implementation could be sufficiently different to not be able to reproduce similar
55
+ results. You might need to checkout the commit the agent was trained on:
56
+ [4706d8d](https://github.com/sgoodfriend/rl-algo-impls/tree/4706d8dbb99b38e70d080c3de68d0751ea585a2f).
57
+ ```
58
+ # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
59
+ python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/1ak14nj4
60
+ ```
61
+
62
+ Setup hasn't been completely worked out yet, so you might be best served by using Google
63
+ Colab starting from the
64
+ [colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
65
+ notebook.
66
+
67
+
68
+
69
+ ## Training
70
+ If you want the highest chance to reproduce these results, you'll want to checkout the
71
+ commit the agent was trained on: [4706d8d](https://github.com/sgoodfriend/rl-algo-impls/tree/4706d8dbb99b38e70d080c3de68d0751ea585a2f). While
72
+ training is deterministic, different hardware will give different results.
73
+
74
+ ```
75
+ python train.py --algo ppo --env MicrortsDefeatCoacAIShaped-v3 --seed 3
76
+ ```
77
+
78
+ Setup hasn't been completely worked out yet, so you might be best served by using Google
79
+ Colab starting from the
80
+ [colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
81
+ notebook.
82
+
83
+
84
+
85
+ ## Benchmarking (with Lambda Labs instance)
86
+ This and other models from https://api.wandb.ai/links/sgoodfriend/lf7j0hrv were generated by running a script on a Lambda
87
+ Labs instance. In a Lambda Labs instance terminal:
88
+ ```
89
+ git clone git@github.com:sgoodfriend/rl-algo-impls.git
90
+ cd rl-algo-impls
91
+ bash ./lambda_labs/setup.sh
92
+ wandb login
93
+ bash ./lambda_labs/benchmark.sh [-a {"ppo a2c dqn vpg"}] [-e ENVS] [-j {6}] [-p {rl-algo-impls-benchmarks}] [-s {"1 2 3"}]
94
+ ```
95
+
96
+ ### Alternative: Google Colab Pro+
97
+ As an alternative,
98
+ [colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
99
+ can be used. However, this requires a Google Colab Pro+ subscription and running across
100
+ 4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
101
+
102
+
103
+
104
+ ## Hyperparameters
105
+ This isn't exactly the format of hyperparams in hyperparams/ppo.yml, but instead the Wandb Run Config. However, it's very
106
+ close and has some additional data:
107
+ ```
108
+ additional_keys_to_log:
109
+ - microrts_stats
110
+ algo: ppo
111
+ algo_hyperparams:
112
+ batch_size: 3072
113
+ clip_range: 0.1
114
+ clip_range_decay: none
115
+ clip_range_vf: 0.1
116
+ ent_coef: 0.01
117
+ learning_rate: 0.00025
118
+ learning_rate_decay: spike
119
+ max_grad_norm: 0.5
120
+ n_epochs: 4
121
+ n_steps: 512
122
+ ppo2_vf_coef_halving: true
123
+ vf_coef: 0.5
124
+ device: auto
125
+ env: Microrts-selfplay-unet
126
+ env_hyperparams:
127
+ env_type: microrts
128
+ make_kwargs:
129
+ map_paths:
130
+ - maps/16x16/basesWorkers16x16.xml
131
+ max_steps: 2000
132
+ num_selfplay_envs: 36
133
+ render_theme: 2
134
+ reward_weight:
135
+ - 10
136
+ - 1
137
+ - 1
138
+ - 0.2
139
+ - 1
140
+ - 4
141
+ n_envs: 24
142
+ self_play_kwargs:
143
+ num_old_policies: 12
144
+ save_steps: 200000
145
+ swap_steps: 10000
146
+ swap_window_size: 4
147
+ window: 25
148
+ env_id: MicrortsDefeatCoacAIShaped-v3
149
+ eval_hyperparams:
150
+ deterministic: false
151
+ env_overrides:
152
+ bots:
153
+ coacAI: 2
154
+ droplet: 2
155
+ guidedRojoA3N: 2
156
+ izanagi: 2
157
+ lightRushAI: 2
158
+ mixedBot: 2
159
+ naiveMCTSAI: 2
160
+ passiveAI: 2
161
+ randomAI: 2
162
+ randomBiasedAI: 2
163
+ rojo: 2
164
+ tiamat: 2
165
+ workerRushAI: 2
166
+ make_kwargs:
167
+ map_paths:
168
+ - maps/16x16/basesWorkers16x16.xml
169
+ max_steps: 4000
170
+ num_selfplay_envs: 0
171
+ render_theme: 2
172
+ reward_weight:
173
+ - 1
174
+ - 0
175
+ - 0
176
+ - 0
177
+ - 0
178
+ - 0
179
+ n_envs: 26
180
+ self_play_kwargs: {}
181
+ max_video_length: 4000
182
+ n_episodes: 26
183
+ score_function: mean
184
+ step_freq: 1000000
185
+ microrts_reward_decay_callback: false
186
+ n_timesteps: 300000000
187
+ policy_hyperparams:
188
+ activation_fn: relu
189
+ actor_head_style: unet
190
+ cnn_flatten_dim: 256
191
+ cnn_style: microrts
192
+ v_hidden_sizes:
193
+ - 256
194
+ - 128
195
+ seed: 3
196
+ use_deterministic_algorithms: true
197
+ wandb_entity: null
198
+ wandb_group: null
199
+ wandb_project_name: rl-algo-impls-benchmarks
200
+ wandb_tags:
201
+ - benchmark_4706d8d
202
+ - host_192-9-146-21
203
+ - branch_selfplay
204
+ - v0.0.9
205
+
206
+ ```
benchmark_publish.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from rl_algo_impls.benchmark_publish import benchmark_publish
2
+
3
+ if __name__ == "__main__":
4
+ benchmark_publish()
colab/colab_atari1.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ALGO="ppo"
2
+ ENVS="PongNoFrameskip-v4 BreakoutNoFrameskip-v4"
3
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
4
+ bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
colab/colab_atari2.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ALGO="ppo"
2
+ ENVS="SpaceInvadersNoFrameskip-v4 QbertNoFrameskip-v4"
3
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
4
+ bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
colab/colab_basic.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ALGO="ppo"
2
+ ENVS="CartPole-v1 MountainCar-v0 MountainCarContinuous-v0 Acrobot-v1 LunarLander-v2"
3
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
4
+ bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
colab/colab_benchmark.ipynb ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "machine_shape": "hm",
8
+ "authorship_tag": "ABX9TyOGIH7rqgasim3Sz7b1rpoE",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ },
18
+ "gpuClass": "standard",
19
+ "accelerator": "GPU"
20
+ },
21
+ "cells": [
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "view-in-github",
26
+ "colab_type": "text"
27
+ },
28
+ "source": [
29
+ "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/benchmarks/colab_benchmark.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "source": [
35
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
+ "## Parameters\n",
37
+ "\n",
38
+ "\n",
39
+ "1. Wandb\n",
40
+ "\n"
41
+ ],
42
+ "metadata": {
43
+ "id": "S-tXDWP8WTLc"
44
+ }
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "source": [
49
+ "from getpass import getpass\n",
50
+ "import os\n",
51
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
+ ],
53
+ "metadata": {
54
+ "id": "1ZtdYgxWNGwZ"
55
+ },
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "source": [
62
+ "## Setup\n",
63
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
64
+ ],
65
+ "metadata": {
66
+ "id": "bsG35Io0hmKG"
67
+ }
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "source": [
72
+ "%%capture\n",
73
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
74
+ ],
75
+ "metadata": {
76
+ "id": "k5ynTV25hdAf"
77
+ },
78
+ "execution_count": null,
79
+ "outputs": []
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "source": [
84
+ "Installing the correct packages:\n",
85
+ "\n",
86
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
87
+ ],
88
+ "metadata": {
89
+ "id": "jKxGok-ElYQ7"
90
+ }
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "source": [
95
+ "%%capture\n",
96
+ "!apt install python-opengl\n",
97
+ "!apt install ffmpeg\n",
98
+ "!apt install xvfb\n",
99
+ "!apt install swig"
100
+ ],
101
+ "metadata": {
102
+ "id": "nn6EETTc2Ewf"
103
+ },
104
+ "execution_count": null,
105
+ "outputs": []
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "source": [
110
+ "%%capture\n",
111
+ "%cd /content/rl-algo-impls\n",
112
+ "python -m pip install ."
113
+ ],
114
+ "metadata": {
115
+ "id": "AfZh9rH3yQii"
116
+ },
117
+ "execution_count": null,
118
+ "outputs": []
119
+ },
120
+ {
121
+ "cell_type": "markdown",
122
+ "source": [
123
+ "## Run Once Per Runtime"
124
+ ],
125
+ "metadata": {
126
+ "id": "4o5HOLjc4wq7"
127
+ }
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "source": [
132
+ "import wandb\n",
133
+ "wandb.login()"
134
+ ],
135
+ "metadata": {
136
+ "id": "PCXa5tdS2qFX"
137
+ },
138
+ "execution_count": null,
139
+ "outputs": []
140
+ },
141
+ {
142
+ "cell_type": "markdown",
143
+ "source": [
144
+ "## Restart Session beteween runs"
145
+ ],
146
+ "metadata": {
147
+ "id": "AZBZfSUV43JQ"
148
+ }
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "source": [
153
+ "%%capture\n",
154
+ "from pyvirtualdisplay import Display\n",
155
+ "\n",
156
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
157
+ "virtual_display.start()"
158
+ ],
159
+ "metadata": {
160
+ "id": "VzemeQJP2NO9"
161
+ },
162
+ "execution_count": null,
163
+ "outputs": []
164
+ },
165
+ {
166
+ "cell_type": "markdown",
167
+ "source": [
168
+ "The below 5 bash scripts train agents on environments with 3 seeds each:\n",
169
+ "- colab_basic.sh and colab_pybullet.sh test on a set of basic gym environments and 4 PyBullet environments. Running both together will likely take about 18 hours. This is likely to run into runtime limits for free Colab and Colab Pro, but is fine for Colab Pro+.\n",
170
+ "- colab_carracing.sh only trains 3 seeds on CarRacing-v0, which takes almost 22 hours on Colab Pro+ on high-RAM, standard GPU.\n",
171
+ "- colab_atari1.sh and colab_atari2.sh likely need to be run separately because each takes about 19 hours on high-RAM, standard GPU."
172
+ ],
173
+ "metadata": {
174
+ "id": "nSHfna0hLlO1"
175
+ }
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "source": [
180
+ "%cd /content/rl-algo-impls\n",
181
+ "os.environ[\"BENCHMARK_MAX_PROCS\"] = str(1) # Can't reliably raise this to 2+, but would make it faster.\n",
182
+ "!./benchmarks/colab_basic.sh\n",
183
+ "!./benchmarks/colab_pybullet.sh\n",
184
+ "# !./benchmarks/colab_carracing.sh\n",
185
+ "# !./benchmarks/colab_atari1.sh\n",
186
+ "# !./benchmarks/colab_atari2.sh"
187
+ ],
188
+ "metadata": {
189
+ "id": "07aHYFH1zfXa"
190
+ },
191
+ "execution_count": null,
192
+ "outputs": []
193
+ }
194
+ ]
195
+ }
colab/colab_carracing.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ALGO="ppo"
2
+ ENVS="CarRacing-v0"
3
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
4
+ bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
colab/colab_enjoy.ipynb ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "machine_shape": "hm",
8
+ "authorship_tag": "ABX9TyN6S7kyJKrM5x0OOiN+CgTc",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ },
18
+ "gpuClass": "standard",
19
+ "accelerator": "GPU"
20
+ },
21
+ "cells": [
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "view-in-github",
26
+ "colab_type": "text"
27
+ },
28
+ "source": [
29
+ "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "source": [
35
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
+ "## Parameters\n",
37
+ "\n",
38
+ "\n",
39
+ "1. Wandb\n",
40
+ "\n"
41
+ ],
42
+ "metadata": {
43
+ "id": "S-tXDWP8WTLc"
44
+ }
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "source": [
49
+ "from getpass import getpass\n",
50
+ "import os\n",
51
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
+ ],
53
+ "metadata": {
54
+ "id": "1ZtdYgxWNGwZ"
55
+ },
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "source": [
62
+ "2. enjoy.py parameters"
63
+ ],
64
+ "metadata": {
65
+ "id": "ao0nAh3MOdN7"
66
+ }
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "source": [
71
+ "WANDB_RUN_PATH=\"sgoodfriend/rl-algo-impls-benchmarks/rd0lisee\""
72
+ ],
73
+ "metadata": {
74
+ "id": "jKL_NFhVOjSc"
75
+ },
76
+ "execution_count": 2,
77
+ "outputs": []
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "source": [
82
+ "## Setup\n",
83
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
84
+ ],
85
+ "metadata": {
86
+ "id": "bsG35Io0hmKG"
87
+ }
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "source": [
92
+ "%%capture\n",
93
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
94
+ ],
95
+ "metadata": {
96
+ "id": "k5ynTV25hdAf"
97
+ },
98
+ "execution_count": 3,
99
+ "outputs": []
100
+ },
101
+ {
102
+ "cell_type": "markdown",
103
+ "source": [
104
+ "Installing the correct packages:\n",
105
+ "\n",
106
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
107
+ ],
108
+ "metadata": {
109
+ "id": "jKxGok-ElYQ7"
110
+ }
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "source": [
115
+ "%%capture\n",
116
+ "!apt install python-opengl\n",
117
+ "!apt install ffmpeg\n",
118
+ "!apt install xvfb\n",
119
+ "!apt install swig"
120
+ ],
121
+ "metadata": {
122
+ "id": "nn6EETTc2Ewf"
123
+ },
124
+ "execution_count": 4,
125
+ "outputs": []
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "source": [
130
+ "%%capture\n",
131
+ "%cd /content/rl-algo-impls\n",
132
+ "python -m pip install ."
133
+ ],
134
+ "metadata": {
135
+ "id": "AfZh9rH3yQii"
136
+ },
137
+ "execution_count": 5,
138
+ "outputs": []
139
+ },
140
+ {
141
+ "cell_type": "markdown",
142
+ "source": [
143
+ "## Run Once Per Runtime"
144
+ ],
145
+ "metadata": {
146
+ "id": "4o5HOLjc4wq7"
147
+ }
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "source": [
152
+ "import wandb\n",
153
+ "wandb.login()"
154
+ ],
155
+ "metadata": {
156
+ "id": "PCXa5tdS2qFX"
157
+ },
158
+ "execution_count": null,
159
+ "outputs": []
160
+ },
161
+ {
162
+ "cell_type": "markdown",
163
+ "source": [
164
+ "## Restart Session beteween runs"
165
+ ],
166
+ "metadata": {
167
+ "id": "AZBZfSUV43JQ"
168
+ }
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "source": [
173
+ "%%capture\n",
174
+ "from pyvirtualdisplay import Display\n",
175
+ "\n",
176
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
177
+ "virtual_display.start()"
178
+ ],
179
+ "metadata": {
180
+ "id": "VzemeQJP2NO9"
181
+ },
182
+ "execution_count": 7,
183
+ "outputs": []
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "source": [
188
+ "%cd /content/rl-algo-impls\n",
189
+ "!python enjoy.py --wandb-run-path={WANDB_RUN_PATH}"
190
+ ],
191
+ "metadata": {
192
+ "id": "07aHYFH1zfXa"
193
+ },
194
+ "execution_count": null,
195
+ "outputs": []
196
+ }
197
+ ]
198
+ }
colab/colab_pybullet.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ALGO="ppo"
2
+ ENVS="HalfCheetahBulletEnv-v0 AntBulletEnv-v0 HopperBulletEnv-v0 Walker2DBulletEnv-v0"
3
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
4
+ bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
colab/colab_train.ipynb ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "machine_shape": "hm",
8
+ "authorship_tag": "ABX9TyMmemQnx6G7GOnn6XBdjgxY",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ },
18
+ "gpuClass": "standard",
19
+ "accelerator": "GPU"
20
+ },
21
+ "cells": [
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "view-in-github",
26
+ "colab_type": "text"
27
+ },
28
+ "source": [
29
+ "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "source": [
35
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
+ "## Parameters\n",
37
+ "\n",
38
+ "\n",
39
+ "1. Wandb\n",
40
+ "\n"
41
+ ],
42
+ "metadata": {
43
+ "id": "S-tXDWP8WTLc"
44
+ }
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "source": [
49
+ "from getpass import getpass\n",
50
+ "import os\n",
51
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
+ ],
53
+ "metadata": {
54
+ "id": "1ZtdYgxWNGwZ"
55
+ },
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "source": [
62
+ "2. train run parameters"
63
+ ],
64
+ "metadata": {
65
+ "id": "ao0nAh3MOdN7"
66
+ }
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "source": [
71
+ "ALGO = \"ppo\"\n",
72
+ "ENV = \"CartPole-v1\"\n",
73
+ "SEED = 1"
74
+ ],
75
+ "metadata": {
76
+ "id": "jKL_NFhVOjSc"
77
+ },
78
+ "execution_count": null,
79
+ "outputs": []
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "source": [
84
+ "## Setup\n",
85
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
86
+ ],
87
+ "metadata": {
88
+ "id": "bsG35Io0hmKG"
89
+ }
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "source": [
94
+ "%%capture\n",
95
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
96
+ ],
97
+ "metadata": {
98
+ "id": "k5ynTV25hdAf"
99
+ },
100
+ "execution_count": null,
101
+ "outputs": []
102
+ },
103
+ {
104
+ "cell_type": "markdown",
105
+ "source": [
106
+ "Installing the correct packages:\n",
107
+ "\n",
108
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
109
+ ],
110
+ "metadata": {
111
+ "id": "jKxGok-ElYQ7"
112
+ }
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "source": [
117
+ "%%capture\n",
118
+ "!apt install python-opengl\n",
119
+ "!apt install ffmpeg\n",
120
+ "!apt install xvfb\n",
121
+ "!apt install swig"
122
+ ],
123
+ "metadata": {
124
+ "id": "nn6EETTc2Ewf"
125
+ },
126
+ "execution_count": null,
127
+ "outputs": []
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "source": [
132
+ "%%capture\n",
133
+ "%cd /content/rl-algo-impls\n",
134
+ "python -m pip install ."
135
+ ],
136
+ "metadata": {
137
+ "id": "AfZh9rH3yQii"
138
+ },
139
+ "execution_count": null,
140
+ "outputs": []
141
+ },
142
+ {
143
+ "cell_type": "markdown",
144
+ "source": [
145
+ "## Run Once Per Runtime"
146
+ ],
147
+ "metadata": {
148
+ "id": "4o5HOLjc4wq7"
149
+ }
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "source": [
154
+ "import wandb\n",
155
+ "wandb.login()"
156
+ ],
157
+ "metadata": {
158
+ "id": "PCXa5tdS2qFX"
159
+ },
160
+ "execution_count": null,
161
+ "outputs": []
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "source": [
166
+ "## Restart Session beteween runs"
167
+ ],
168
+ "metadata": {
169
+ "id": "AZBZfSUV43JQ"
170
+ }
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "source": [
175
+ "%%capture\n",
176
+ "from pyvirtualdisplay import Display\n",
177
+ "\n",
178
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
179
+ "virtual_display.start()"
180
+ ],
181
+ "metadata": {
182
+ "id": "VzemeQJP2NO9"
183
+ },
184
+ "execution_count": null,
185
+ "outputs": []
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "source": [
190
+ "%cd /content/rl-algo-impls\n",
191
+ "!python train.py --algo {ALGO} --env {ENV} --seed {SEED}"
192
+ ],
193
+ "metadata": {
194
+ "id": "07aHYFH1zfXa"
195
+ },
196
+ "execution_count": null,
197
+ "outputs": []
198
+ }
199
+ ]
200
+ }
compare_runs.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from rl_algo_impls.compare_runs import compare_runs
2
+
3
+ if __name__ == "__main__":
4
+ compare_runs()
enjoy.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from rl_algo_impls.enjoy import enjoy
2
+
3
+ if __name__ == "__main__":
4
+ enjoy()
environment.yml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: rl_algo_impls
2
+ channels:
3
+ - pytorch
4
+ - conda-forge
5
+ - nodefaults
6
+ dependencies:
7
+ - python>=3.8, <3.10
8
+ - mamba
9
+ - pip
10
+ - pytorch
11
+ - torchvision
12
+ - torchaudio
huggingface_publish.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from rl_algo_impls.huggingface_publish import huggingface_publish
2
+
3
+ if __name__ == "__main__":
4
+ huggingface_publish()
optimize.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from rl_algo_impls.optimize import optimize
2
+
3
+ if __name__ == "__main__":
4
+ optimize()
pyproject.toml ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "rl_algo_impls"
3
+ version = "0.0.9"
4
+ description = "Implementations of reinforcement learning algorithms"
5
+ authors = [
6
+ {name = "Scott Goodfriend", email = "goodfriend.scott@gmail.com"},
7
+ ]
8
+ license = {file = "LICENSE"}
9
+ readme = "README.md"
10
+ requires-python = ">= 3.8"
11
+ classifiers = [
12
+ "License :: OSI Approved :: MIT License",
13
+ "Development Status :: 3 - Alpha",
14
+ "Programming Language :: Python :: 3.8",
15
+ "Programming Language :: Python :: 3.9",
16
+ "Programming Language :: Python :: 3.10",
17
+ ]
18
+ dependencies = [
19
+ "cmake",
20
+ "swig",
21
+ "scipy",
22
+ "torch",
23
+ "torchvision",
24
+ "tensorboard >= 2.11.2, < 2.12",
25
+ "AutoROM.accept-rom-license >= 0.4.2, < 0.5",
26
+ "stable-baselines3[extra] >= 1.7.0, < 1.8",
27
+ "gym[box2d] >= 0.21.0, < 0.22",
28
+ "pyglet == 1.5.27",
29
+ "wandb",
30
+ "pyvirtualdisplay",
31
+ "pybullet",
32
+ "tabulate",
33
+ "huggingface-hub",
34
+ "optuna",
35
+ "dash",
36
+ "kaleido",
37
+ "PyYAML",
38
+ "scikit-learn",
39
+ ]
40
+
41
+ [tool.setuptools]
42
+ packages = ["rl_algo_impls"]
43
+
44
+ [project.optional-dependencies]
45
+ test = [
46
+ "pytest",
47
+ "black",
48
+ "mypy",
49
+ "flake8",
50
+ "flake8-bugbear",
51
+ "isort",
52
+ ]
53
+ procgen = [
54
+ "numexpr >= 2.8.4",
55
+ "gym3",
56
+ "glfw >= 1.12.0, < 1.13",
57
+ "procgen; platform_machine=='x86_64'",
58
+ ]
59
+ microrts-ppo = [
60
+ "numpy < 1.24.0", # Support for gym-microrts < 0.6.0
61
+ "gym-microrts == 0.2.0", # Match ppo-implementation-details
62
+ ]
63
+ microrts-paper = [
64
+ "numpy < 1.24.0", # Support for gym-microrts < 0.6.0
65
+ "gym-microrts == 0.3.2",
66
+ ]
67
+ microrts = [
68
+ "gym-microrts",
69
+ ]
70
+ jupyter = [
71
+ "jupyter",
72
+ "notebook"
73
+ ]
74
+ all = [
75
+ "rl-algo-impls[test]",
76
+ "rl-algo-impls[procgen]",
77
+ "rl-algo-impls[microrts]",
78
+ ]
79
+
80
+ [project.urls]
81
+ "Homepage" = "https://github.com/sgoodfriend/rl-algo-impls"
82
+
83
+ [build-system]
84
+ requires = ["setuptools==65.5.0", "setuptools-scm"]
85
+ build-backend = "setuptools.build_meta"
86
+
87
+ [tool.isort]
88
+ profile = "black"
replay.meta.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "640x640", "-pix_fmt", "rgb24", "-framerate", "150", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "150", "/tmp/tmpiof1jwbr/ppo-Microrts-selfplay-unet/replay.mp4"]}, "episode": {"r": 1.0, "l": 794, "t": 10.908112}}
replay.mp4 ADDED
Binary file (322 kB). View file
 
rl_algo_impls/a2c/a2c.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from time import perf_counter
3
+ from typing import List, Optional, TypeVar
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.utils.tensorboard.writer import SummaryWriter
10
+
11
+ from rl_algo_impls.shared.algorithm import Algorithm
12
+ from rl_algo_impls.shared.callbacks import Callback
13
+ from rl_algo_impls.shared.gae import compute_advantages
14
+ from rl_algo_impls.shared.policy.actor_critic import ActorCritic
15
+ from rl_algo_impls.shared.schedule import schedule, update_learning_rate
16
+ from rl_algo_impls.shared.stats import log_scalars
17
+ from rl_algo_impls.wrappers.vectorable_wrapper import (
18
+ VecEnv,
19
+ single_action_space,
20
+ single_observation_space,
21
+ )
22
+
23
+ A2CSelf = TypeVar("A2CSelf", bound="A2C")
24
+
25
+
26
+ class A2C(Algorithm):
27
+ def __init__(
28
+ self,
29
+ policy: ActorCritic,
30
+ env: VecEnv,
31
+ device: torch.device,
32
+ tb_writer: SummaryWriter,
33
+ learning_rate: float = 7e-4,
34
+ learning_rate_decay: str = "none",
35
+ n_steps: int = 5,
36
+ gamma: float = 0.99,
37
+ gae_lambda: float = 1.0,
38
+ ent_coef: float = 0.0,
39
+ ent_coef_decay: str = "none",
40
+ vf_coef: float = 0.5,
41
+ max_grad_norm: float = 0.5,
42
+ rms_prop_eps: float = 1e-5,
43
+ use_rms_prop: bool = True,
44
+ sde_sample_freq: int = -1,
45
+ normalize_advantage: bool = False,
46
+ ) -> None:
47
+ super().__init__(policy, env, device, tb_writer)
48
+ self.policy = policy
49
+
50
+ self.lr_schedule = schedule(learning_rate_decay, learning_rate)
51
+ if use_rms_prop:
52
+ self.optimizer = torch.optim.RMSprop(
53
+ policy.parameters(), lr=learning_rate, eps=rms_prop_eps
54
+ )
55
+ else:
56
+ self.optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate)
57
+
58
+ self.n_steps = n_steps
59
+
60
+ self.gamma = gamma
61
+ self.gae_lambda = gae_lambda
62
+
63
+ self.vf_coef = vf_coef
64
+ self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
65
+ self.max_grad_norm = max_grad_norm
66
+
67
+ self.sde_sample_freq = sde_sample_freq
68
+ self.normalize_advantage = normalize_advantage
69
+
70
+ def learn(
71
+ self: A2CSelf,
72
+ train_timesteps: int,
73
+ callbacks: Optional[List[Callback]] = None,
74
+ total_timesteps: Optional[int] = None,
75
+ start_timesteps: int = 0,
76
+ ) -> A2CSelf:
77
+ if total_timesteps is None:
78
+ total_timesteps = train_timesteps
79
+ assert start_timesteps + train_timesteps <= total_timesteps
80
+ epoch_dim = (self.n_steps, self.env.num_envs)
81
+ step_dim = (self.env.num_envs,)
82
+ obs_space = single_observation_space(self.env)
83
+ act_space = single_action_space(self.env)
84
+
85
+ obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)
86
+ actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)
87
+ rewards = np.zeros(epoch_dim, dtype=np.float32)
88
+ episode_starts = np.zeros(epoch_dim, dtype=np.bool8)
89
+ values = np.zeros(epoch_dim, dtype=np.float32)
90
+ logprobs = np.zeros(epoch_dim, dtype=np.float32)
91
+
92
+ next_obs = self.env.reset()
93
+ next_episode_starts = np.full(step_dim, True, dtype=np.bool8)
94
+
95
+ timesteps_elapsed = start_timesteps
96
+ while timesteps_elapsed < start_timesteps + train_timesteps:
97
+ start_time = perf_counter()
98
+
99
+ progress = timesteps_elapsed / total_timesteps
100
+ ent_coef = self.ent_coef_schedule(progress)
101
+ learning_rate = self.lr_schedule(progress)
102
+ update_learning_rate(self.optimizer, learning_rate)
103
+ log_scalars(
104
+ self.tb_writer,
105
+ "charts",
106
+ {
107
+ "ent_coef": ent_coef,
108
+ "learning_rate": learning_rate,
109
+ },
110
+ timesteps_elapsed,
111
+ )
112
+
113
+ self.policy.eval()
114
+ self.policy.reset_noise()
115
+ for s in range(self.n_steps):
116
+ timesteps_elapsed += self.env.num_envs
117
+ if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
118
+ self.policy.reset_noise()
119
+
120
+ obs[s] = next_obs
121
+ episode_starts[s] = next_episode_starts
122
+
123
+ actions[s], values[s], logprobs[s], clamped_action = self.policy.step(
124
+ next_obs
125
+ )
126
+ next_obs, rewards[s], next_episode_starts, _ = self.env.step(
127
+ clamped_action
128
+ )
129
+
130
+ advantages = compute_advantages(
131
+ rewards,
132
+ values,
133
+ episode_starts,
134
+ next_episode_starts,
135
+ next_obs,
136
+ self.policy,
137
+ self.gamma,
138
+ self.gae_lambda,
139
+ )
140
+ returns = advantages + values
141
+
142
+ b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)
143
+ b_actions = torch.tensor(actions.reshape((-1,) + act_space.shape)).to(
144
+ self.device
145
+ )
146
+ b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
147
+ b_returns = torch.tensor(returns.reshape(-1)).to(self.device)
148
+
149
+ if self.normalize_advantage:
150
+ b_advantages = (b_advantages - b_advantages.mean()) / (
151
+ b_advantages.std() + 1e-8
152
+ )
153
+
154
+ self.policy.train()
155
+ logp_a, entropy, v = self.policy(b_obs, b_actions)
156
+
157
+ pi_loss = -(b_advantages * logp_a).mean()
158
+ value_loss = F.mse_loss(b_returns, v)
159
+ entropy_loss = -entropy.mean()
160
+
161
+ loss = pi_loss + self.vf_coef * value_loss + ent_coef * entropy_loss
162
+
163
+ self.optimizer.zero_grad()
164
+ loss.backward()
165
+ nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
166
+ self.optimizer.step()
167
+
168
+ y_pred = values.reshape(-1)
169
+ y_true = returns.reshape(-1)
170
+ var_y = np.var(y_true).item()
171
+ explained_var = (
172
+ np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
173
+ )
174
+
175
+ end_time = perf_counter()
176
+ rollout_steps = self.n_steps * self.env.num_envs
177
+ self.tb_writer.add_scalar(
178
+ "train/steps_per_second",
179
+ (rollout_steps) / (end_time - start_time),
180
+ timesteps_elapsed,
181
+ )
182
+
183
+ log_scalars(
184
+ self.tb_writer,
185
+ "losses",
186
+ {
187
+ "loss": loss.item(),
188
+ "pi_loss": pi_loss.item(),
189
+ "v_loss": value_loss.item(),
190
+ "entropy_loss": entropy_loss.item(),
191
+ "explained_var": explained_var,
192
+ },
193
+ timesteps_elapsed,
194
+ )
195
+
196
+ if callbacks:
197
+ if not all(
198
+ c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
199
+ ):
200
+ logging.info(
201
+ f"Callback terminated training at {timesteps_elapsed} timesteps"
202
+ )
203
+ break
204
+
205
+ return self
rl_algo_impls/a2c/optimize.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import optuna
2
+
3
+ from copy import deepcopy
4
+
5
+ from rl_algo_impls.runner.config import Config, Hyperparams, EnvHyperparams
6
+ from rl_algo_impls.shared.vec_env import make_eval_env
7
+ from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
8
+ from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
9
+
10
+
11
+ def sample_params(
12
+ trial: optuna.Trial,
13
+ base_hyperparams: Hyperparams,
14
+ base_config: Config,
15
+ ) -> Hyperparams:
16
+ hyperparams = deepcopy(base_hyperparams)
17
+
18
+ base_env_hyperparams = EnvHyperparams(**hyperparams.env_hyperparams)
19
+ env = make_eval_env(base_config, base_env_hyperparams, override_n_envs=1)
20
+
21
+ # env_hyperparams
22
+ env_hyperparams = sample_env_hyperparams(trial, hyperparams.env_hyperparams, env)
23
+
24
+ # policy_hyperparams
25
+ policy_hyperparams = sample_on_policy_hyperparams(
26
+ trial, hyperparams.policy_hyperparams, env
27
+ )
28
+
29
+ # algo_hyperparams
30
+ algo_hyperparams = hyperparams.algo_hyperparams
31
+
32
+ learning_rate = trial.suggest_float("learning_rate", 1e-5, 2e-3, log=True)
33
+ learning_rate_decay = trial.suggest_categorical(
34
+ "learning_rate_decay", ["none", "linear"]
35
+ )
36
+ n_steps_exp = trial.suggest_int("n_steps_exp", 1, 10)
37
+ n_steps = 2**n_steps_exp
38
+ trial.set_user_attr("n_steps", n_steps)
39
+ gamma = 1.0 - trial.suggest_float("gamma_om", 1e-4, 1e-1, log=True)
40
+ trial.set_user_attr("gamma", gamma)
41
+ gae_lambda = 1 - trial.suggest_float("gae_lambda_om", 1e-4, 1e-1)
42
+ trial.set_user_attr("gae_lambda", gae_lambda)
43
+ ent_coef = trial.suggest_float("ent_coef", 1e-8, 2.5e-2, log=True)
44
+ ent_coef_decay = trial.suggest_categorical("ent_coef_decay", ["none", "linear"])
45
+ vf_coef = trial.suggest_float("vf_coef", 0.1, 0.7)
46
+ max_grad_norm = trial.suggest_float("max_grad_norm", 1e-1, 1e1, log=True)
47
+ use_rms_prop = trial.suggest_categorical("use_rms_prop", [True, False])
48
+ normalize_advantage = trial.suggest_categorical(
49
+ "normalize_advantage", [True, False]
50
+ )
51
+
52
+ algo_hyperparams.update(
53
+ {
54
+ "learning_rate": learning_rate,
55
+ "learning_rate_decay": learning_rate_decay,
56
+ "n_steps": n_steps,
57
+ "gamma": gamma,
58
+ "gae_lambda": gae_lambda,
59
+ "ent_coef": ent_coef,
60
+ "ent_coef_decay": ent_coef_decay,
61
+ "vf_coef": vf_coef,
62
+ "max_grad_norm": max_grad_norm,
63
+ "use_rms_prop": use_rms_prop,
64
+ "normalize_advantage": normalize_advantage,
65
+ }
66
+ )
67
+
68
+ if policy_hyperparams.get("use_sde", False):
69
+ sde_sample_freq = 2 ** trial.suggest_int("sde_sample_freq_exp", 0, n_steps_exp)
70
+ trial.set_user_attr("sde_sample_freq", sde_sample_freq)
71
+ algo_hyperparams["sde_sample_freq"] = sde_sample_freq
72
+ elif "sde_sample_freq" in algo_hyperparams:
73
+ del algo_hyperparams["sde_sample_freq"]
74
+
75
+ env.close()
76
+
77
+ return hyperparams
rl_algo_impls/benchmark_publish.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import subprocess
3
+ import wandb
4
+ import wandb.apis.public
5
+
6
+ from collections import defaultdict
7
+ from multiprocessing.pool import ThreadPool
8
+ from typing import List, NamedTuple
9
+
10
+
11
+ class RunGroup(NamedTuple):
12
+ algo: str
13
+ env_id: str
14
+
15
+
16
+ def benchmark_publish() -> None:
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument(
19
+ "--wandb-project-name",
20
+ type=str,
21
+ default="rl-algo-impls-benchmarks",
22
+ help="WandB project name to load runs from",
23
+ )
24
+ parser.add_argument(
25
+ "--wandb-entity",
26
+ type=str,
27
+ default=None,
28
+ help="WandB team of project. None uses default entity",
29
+ )
30
+ parser.add_argument("--wandb-tags", type=str, nargs="+", help="WandB tags")
31
+ parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
32
+ parser.add_argument(
33
+ "--envs", type=str, nargs="*", help="Optional filter down to these envs"
34
+ )
35
+ parser.add_argument(
36
+ "--exclude-envs",
37
+ type=str,
38
+ nargs="*",
39
+ help="Environments to exclude from publishing",
40
+ )
41
+ parser.add_argument(
42
+ "--huggingface-user",
43
+ type=str,
44
+ default=None,
45
+ help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
46
+ )
47
+ parser.add_argument(
48
+ "--pool-size",
49
+ type=int,
50
+ default=3,
51
+ help="How many publish jobs can run in parallel",
52
+ )
53
+ parser.add_argument(
54
+ "--virtual-display", action="store_true", help="Use headless virtual display"
55
+ )
56
+ # parser.set_defaults(
57
+ # wandb_tags=["benchmark_e47a44c", "host_129-146-2-230"],
58
+ # wandb_report_url="https://api.wandb.ai/links/sgoodfriend/v4wd7cp5",
59
+ # envs=[],
60
+ # exclude_envs=[],
61
+ # )
62
+ args = parser.parse_args()
63
+ print(args)
64
+
65
+ api = wandb.Api()
66
+ all_runs = api.runs(
67
+ f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
68
+ )
69
+
70
+ required_tags = set(args.wandb_tags)
71
+ runs: List[wandb.apis.public.Run] = [
72
+ r
73
+ for r in all_runs
74
+ if required_tags.issubset(set(r.config.get("wandb_tags", [])))
75
+ ]
76
+
77
+ runs_paths_by_group = defaultdict(list)
78
+ for r in runs:
79
+ if r.state != "finished":
80
+ continue
81
+ algo = r.config["algo"]
82
+ env = r.config["env"]
83
+ if args.envs and env not in args.envs:
84
+ continue
85
+ if args.exclude_envs and env in args.exclude_envs:
86
+ continue
87
+ run_group = RunGroup(algo, env)
88
+ runs_paths_by_group[run_group].append("/".join(r.path))
89
+
90
+ def run(run_paths: List[str]) -> None:
91
+ publish_args = ["python", "huggingface_publish.py"]
92
+ publish_args.append("--wandb-run-paths")
93
+ publish_args.extend(run_paths)
94
+ publish_args.append("--wandb-report-url")
95
+ publish_args.append(args.wandb_report_url)
96
+ if args.huggingface_user:
97
+ publish_args.append("--huggingface-user")
98
+ publish_args.append(args.huggingface_user)
99
+ if args.virtual_display:
100
+ publish_args.append("--virtual-display")
101
+ subprocess.run(publish_args)
102
+
103
+ tp = ThreadPool(args.pool_size)
104
+ for run_paths in runs_paths_by_group.values():
105
+ tp.apply_async(run, (run_paths,))
106
+ tp.close()
107
+ tp.join()
108
+
109
+
110
+ if __name__ == "__main__":
111
+ benchmark_publish()
rl_algo_impls/compare_runs.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import itertools
3
+ import numpy as np
4
+ import pandas as pd
5
+ import wandb
6
+ import wandb.apis.public
7
+
8
+ from collections import defaultdict
9
+ from dataclasses import dataclass
10
+ from typing import Dict, Iterable, List, TypeVar
11
+
12
+ from rl_algo_impls.benchmark_publish import RunGroup
13
+
14
+
15
+ @dataclass
16
+ class Comparison:
17
+ control_values: List[float]
18
+ experiment_values: List[float]
19
+
20
+ def mean_diff_percentage(self) -> float:
21
+ return self._diff_percentage(
22
+ np.mean(self.control_values).item(), np.mean(self.experiment_values).item()
23
+ )
24
+
25
+ def median_diff_percentage(self) -> float:
26
+ return self._diff_percentage(
27
+ np.median(self.control_values).item(),
28
+ np.median(self.experiment_values).item(),
29
+ )
30
+
31
+ def _diff_percentage(self, c: float, e: float) -> float:
32
+ if c == e:
33
+ return 0
34
+ elif c == 0:
35
+ return float("inf") if e > 0 else float("-inf")
36
+ return 100 * (e - c) / c
37
+
38
+ def score(self) -> float:
39
+ return (
40
+ np.sum(
41
+ np.sign((self.mean_diff_percentage(), self.median_diff_percentage()))
42
+ ).item()
43
+ / 2
44
+ )
45
+
46
+
47
+ RunGroupRunsSelf = TypeVar("RunGroupRunsSelf", bound="RunGroupRuns")
48
+
49
+
50
+ class RunGroupRuns:
51
+ def __init__(
52
+ self,
53
+ run_group: RunGroup,
54
+ control: List[str],
55
+ experiment: List[str],
56
+ summary_stats: List[str] = ["best_eval", "eval", "train_rolling"],
57
+ summary_metrics: List[str] = ["mean", "result"],
58
+ ) -> None:
59
+ self.algo = run_group.algo
60
+ self.env = run_group.env_id
61
+ self.control = set(control)
62
+ self.experiment = set(experiment)
63
+
64
+ self.summary_stats = summary_stats
65
+ self.summary_metrics = summary_metrics
66
+
67
+ self.control_runs = []
68
+ self.experiment_runs = []
69
+
70
+ def add_run(self, run: wandb.apis.public.Run) -> None:
71
+ wandb_tags = set(run.config.get("wandb_tags", []))
72
+ if self.control & wandb_tags:
73
+ self.control_runs.append(run)
74
+ elif self.experiment & wandb_tags:
75
+ self.experiment_runs.append(run)
76
+
77
+ def comparisons_by_metric(self) -> Dict[str, Comparison]:
78
+ c_by_m = {}
79
+ for metric in (
80
+ f"{s}/{m}"
81
+ for s, m in itertools.product(self.summary_stats, self.summary_metrics)
82
+ ):
83
+ c_by_m[metric] = Comparison(
84
+ [c.summary[metric] for c in self.control_runs],
85
+ [e.summary[metric] for e in self.experiment_runs],
86
+ )
87
+ return c_by_m
88
+
89
+ @staticmethod
90
+ def data_frame(rows: Iterable[RunGroupRunsSelf]) -> pd.DataFrame:
91
+ results = defaultdict(list)
92
+ for r in rows:
93
+ if not r.control_runs or not r.experiment_runs:
94
+ continue
95
+ results["algo"].append(r.algo)
96
+ results["env"].append(r.env)
97
+ results["control"].append(r.control)
98
+ results["expierment"].append(r.experiment)
99
+ c_by_m = r.comparisons_by_metric()
100
+ results["score"].append(
101
+ sum(m.score() for m in c_by_m.values()) / len(c_by_m)
102
+ )
103
+ for m, c in c_by_m.items():
104
+ results[f"{m}_mean"].append(c.mean_diff_percentage())
105
+ results[f"{m}_median"].append(c.median_diff_percentage())
106
+ return pd.DataFrame(results)
107
+
108
+
109
+ def compare_runs() -> None:
110
+ parser = argparse.ArgumentParser()
111
+ parser.add_argument(
112
+ "-p",
113
+ "--wandb-project-name",
114
+ type=str,
115
+ default="rl-algo-impls-benchmarks",
116
+ help="WandB project name to load runs from",
117
+ )
118
+ parser.add_argument(
119
+ "--wandb-entity",
120
+ type=str,
121
+ default=None,
122
+ help="WandB team. None uses default entity",
123
+ )
124
+ parser.add_argument(
125
+ "-n",
126
+ "--wandb-hostname-tag",
127
+ type=str,
128
+ nargs="*",
129
+ help="WandB tags for hostname (i.e. host_192-9-145-26)",
130
+ )
131
+ parser.add_argument(
132
+ "-c",
133
+ "--wandb-control-tag",
134
+ type=str,
135
+ nargs="+",
136
+ help="WandB tag for control commit (i.e. benchmark_5598ebc)",
137
+ )
138
+ parser.add_argument(
139
+ "-e",
140
+ "--wandb-experiment-tag",
141
+ type=str,
142
+ nargs="+",
143
+ help="WandB tag for experiment commit (i.e. benchmark_5540e1f)",
144
+ )
145
+ parser.add_argument(
146
+ "--envs",
147
+ type=str,
148
+ nargs="*",
149
+ help="If specified, only compare these envs",
150
+ )
151
+ parser.add_argument(
152
+ "--exclude-envs",
153
+ type=str,
154
+ nargs="*",
155
+ help="Environments to exclude from comparison",
156
+ )
157
+ # parser.set_defaults(
158
+ # wandb_hostname_tag=["host_150-230-44-105", "host_155-248-214-128"],
159
+ # wandb_control_tag=["benchmark_fbc943f"],
160
+ # wandb_experiment_tag=["benchmark_f59bf74"],
161
+ # exclude_envs=[],
162
+ # )
163
+ args = parser.parse_args()
164
+ print(args)
165
+
166
+ api = wandb.Api()
167
+ all_runs = api.runs(
168
+ path=f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}",
169
+ order="+created_at",
170
+ )
171
+
172
+ runs_by_run_group: Dict[RunGroup, RunGroupRuns] = {}
173
+ wandb_hostname_tags = set(args.wandb_hostname_tag)
174
+ for r in all_runs:
175
+ if r.state != "finished":
176
+ continue
177
+ wandb_tags = set(r.config.get("wandb_tags", []))
178
+ if not wandb_tags or not wandb_hostname_tags & wandb_tags:
179
+ continue
180
+ rg = RunGroup(r.config["algo"], r.config.get("env_id") or r.config["env"])
181
+ if args.exclude_envs and rg.env_id in args.exclude_envs:
182
+ continue
183
+ if args.envs and rg.env_id not in args.envs:
184
+ continue
185
+ if rg not in runs_by_run_group:
186
+ runs_by_run_group[rg] = RunGroupRuns(
187
+ rg,
188
+ args.wandb_control_tag,
189
+ args.wandb_experiment_tag,
190
+ )
191
+ runs_by_run_group[rg].add_run(r)
192
+ df = RunGroupRuns.data_frame(runs_by_run_group.values()).round(decimals=2)
193
+ print(f"**Total Score: {sum(df.score)}**")
194
+ df.loc["mean"] = df.mean(numeric_only=True)
195
+ print(df.to_markdown())
196
+
197
+
198
+ if __name__ == "__main__":
199
+ compare_runs()
rl_algo_impls/dqn/dqn.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import random
4
+ from collections import deque
5
+ from typing import List, NamedTuple, Optional, TypeVar
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.optim import Adam
12
+ from torch.utils.tensorboard.writer import SummaryWriter
13
+
14
+ from rl_algo_impls.dqn.policy import DQNPolicy
15
+ from rl_algo_impls.shared.algorithm import Algorithm
16
+ from rl_algo_impls.shared.callbacks import Callback
17
+ from rl_algo_impls.shared.schedule import linear_schedule
18
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs
19
+
20
+
21
+ class Transition(NamedTuple):
22
+ obs: np.ndarray
23
+ action: np.ndarray
24
+ reward: float
25
+ done: bool
26
+ next_obs: np.ndarray
27
+
28
+
29
+ class Batch(NamedTuple):
30
+ obs: np.ndarray
31
+ actions: np.ndarray
32
+ rewards: np.ndarray
33
+ dones: np.ndarray
34
+ next_obs: np.ndarray
35
+
36
+
37
+ class ReplayBuffer:
38
+ def __init__(self, num_envs: int, maxlen: int) -> None:
39
+ self.num_envs = num_envs
40
+ self.buffer = deque(maxlen=maxlen)
41
+
42
+ def add(
43
+ self,
44
+ obs: VecEnvObs,
45
+ action: np.ndarray,
46
+ reward: np.ndarray,
47
+ done: np.ndarray,
48
+ next_obs: VecEnvObs,
49
+ ) -> None:
50
+ assert isinstance(obs, np.ndarray)
51
+ assert isinstance(next_obs, np.ndarray)
52
+ for i in range(self.num_envs):
53
+ self.buffer.append(
54
+ Transition(obs[i], action[i], reward[i], done[i], next_obs[i])
55
+ )
56
+
57
+ def sample(self, batch_size: int) -> Batch:
58
+ ts = random.sample(self.buffer, batch_size)
59
+ return Batch(
60
+ obs=np.array([t.obs for t in ts]),
61
+ actions=np.array([t.action for t in ts]),
62
+ rewards=np.array([t.reward for t in ts]),
63
+ dones=np.array([t.done for t in ts]),
64
+ next_obs=np.array([t.next_obs for t in ts]),
65
+ )
66
+
67
+ def __len__(self) -> int:
68
+ return len(self.buffer)
69
+
70
+
71
+ DQNSelf = TypeVar("DQNSelf", bound="DQN")
72
+
73
+
74
+ class DQN(Algorithm):
75
+ def __init__(
76
+ self,
77
+ policy: DQNPolicy,
78
+ env: VecEnv,
79
+ device: torch.device,
80
+ tb_writer: SummaryWriter,
81
+ learning_rate: float = 1e-4,
82
+ buffer_size: int = 1_000_000,
83
+ learning_starts: int = 50_000,
84
+ batch_size: int = 32,
85
+ tau: float = 1.0,
86
+ gamma: float = 0.99,
87
+ train_freq: int = 4,
88
+ gradient_steps: int = 1,
89
+ target_update_interval: int = 10_000,
90
+ exploration_fraction: float = 0.1,
91
+ exploration_initial_eps: float = 1.0,
92
+ exploration_final_eps: float = 0.05,
93
+ max_grad_norm: float = 10.0,
94
+ ) -> None:
95
+ super().__init__(policy, env, device, tb_writer)
96
+ self.policy = policy
97
+
98
+ self.optimizer = Adam(self.policy.q_net.parameters(), lr=learning_rate)
99
+
100
+ self.target_q_net = copy.deepcopy(self.policy.q_net).to(self.device)
101
+ self.target_q_net.train(False)
102
+ self.tau = tau
103
+ self.target_update_interval = target_update_interval
104
+
105
+ self.replay_buffer = ReplayBuffer(self.env.num_envs, buffer_size)
106
+ self.batch_size = batch_size
107
+
108
+ self.learning_starts = learning_starts
109
+ self.train_freq = train_freq
110
+ self.gradient_steps = gradient_steps
111
+
112
+ self.gamma = gamma
113
+ self.exploration_eps_schedule = linear_schedule(
114
+ exploration_initial_eps,
115
+ exploration_final_eps,
116
+ end_fraction=exploration_fraction,
117
+ )
118
+
119
+ self.max_grad_norm = max_grad_norm
120
+
121
+ def learn(
122
+ self: DQNSelf, total_timesteps: int, callbacks: Optional[List[Callback]] = None
123
+ ) -> DQNSelf:
124
+ self.policy.train(True)
125
+ obs = self.env.reset()
126
+ obs = self._collect_rollout(self.learning_starts, obs, 1)
127
+ learning_steps = total_timesteps - self.learning_starts
128
+ timesteps_elapsed = 0
129
+ steps_since_target_update = 0
130
+ while timesteps_elapsed < learning_steps:
131
+ progress = timesteps_elapsed / learning_steps
132
+ eps = self.exploration_eps_schedule(progress)
133
+ obs = self._collect_rollout(self.train_freq, obs, eps)
134
+ rollout_steps = self.train_freq
135
+ timesteps_elapsed += rollout_steps
136
+ for _ in range(
137
+ self.gradient_steps if self.gradient_steps > 0 else self.train_freq
138
+ ):
139
+ self.train()
140
+ steps_since_target_update += rollout_steps
141
+ if steps_since_target_update >= self.target_update_interval:
142
+ self._update_target()
143
+ steps_since_target_update = 0
144
+ if callbacks:
145
+ if not all(
146
+ c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
147
+ ):
148
+ logging.info(
149
+ f"Callback terminated training at {timesteps_elapsed} timesteps"
150
+ )
151
+ break
152
+ return self
153
+
154
+ def train(self) -> None:
155
+ if len(self.replay_buffer) < self.batch_size:
156
+ return
157
+ o, a, r, d, next_o = self.replay_buffer.sample(self.batch_size)
158
+ o = torch.as_tensor(o, device=self.device)
159
+ a = torch.as_tensor(a, device=self.device).unsqueeze(1)
160
+ r = torch.as_tensor(r, dtype=torch.float32, device=self.device)
161
+ d = torch.as_tensor(d, dtype=torch.long, device=self.device)
162
+ next_o = torch.as_tensor(next_o, device=self.device)
163
+
164
+ with torch.no_grad():
165
+ target = r + (1 - d) * self.gamma * self.target_q_net(next_o).max(1).values
166
+ current = self.policy.q_net(o).gather(dim=1, index=a).squeeze(1)
167
+ loss = F.smooth_l1_loss(current, target)
168
+
169
+ self.optimizer.zero_grad()
170
+ loss.backward()
171
+ if self.max_grad_norm:
172
+ nn.utils.clip_grad_norm_(self.policy.q_net.parameters(), self.max_grad_norm)
173
+ self.optimizer.step()
174
+
175
+ def _collect_rollout(self, timesteps: int, obs: VecEnvObs, eps: float) -> VecEnvObs:
176
+ for _ in range(0, timesteps, self.env.num_envs):
177
+ action = self.policy.act(obs, eps, deterministic=False)
178
+ next_obs, reward, done, _ = self.env.step(action)
179
+ self.replay_buffer.add(obs, action, reward, done, next_obs)
180
+ obs = next_obs
181
+ return obs
182
+
183
+ def _update_target(self) -> None:
184
+ for target_param, param in zip(
185
+ self.target_q_net.parameters(), self.policy.q_net.parameters()
186
+ ):
187
+ target_param.data.copy_(
188
+ self.tau * param.data + (1 - self.tau) * target_param.data
189
+ )
rl_algo_impls/dqn/policy.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional, Sequence, TypeVar
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from rl_algo_impls.dqn.q_net import QNetwork
8
+ from rl_algo_impls.shared.policy.policy import Policy
9
+ from rl_algo_impls.wrappers.vectorable_wrapper import (
10
+ VecEnv,
11
+ VecEnvObs,
12
+ single_action_space,
13
+ single_observation_space,
14
+ )
15
+
16
+ DQNPolicySelf = TypeVar("DQNPolicySelf", bound="DQNPolicy")
17
+
18
+
19
+ class DQNPolicy(Policy):
20
+ def __init__(
21
+ self,
22
+ env: VecEnv,
23
+ hidden_sizes: Sequence[int] = [],
24
+ cnn_flatten_dim: int = 512,
25
+ cnn_style: str = "nature",
26
+ cnn_layers_init_orthogonal: Optional[bool] = None,
27
+ impala_channels: Sequence[int] = (16, 32, 32),
28
+ **kwargs,
29
+ ) -> None:
30
+ super().__init__(env, **kwargs)
31
+ self.q_net = QNetwork(
32
+ single_observation_space(env),
33
+ single_action_space(env),
34
+ hidden_sizes,
35
+ cnn_flatten_dim=cnn_flatten_dim,
36
+ cnn_style=cnn_style,
37
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
38
+ impala_channels=impala_channels,
39
+ )
40
+
41
+ def act(
42
+ self,
43
+ obs: VecEnvObs,
44
+ eps: float = 0,
45
+ deterministic: bool = True,
46
+ action_masks: Optional[np.ndarray] = None,
47
+ ) -> np.ndarray:
48
+ assert eps == 0 if deterministic else eps >= 0
49
+ assert (
50
+ action_masks is None
51
+ ), f"action_masks not currently supported in {self.__class__.__name__}"
52
+ if not deterministic and np.random.random() < eps:
53
+ return np.array(
54
+ [
55
+ single_action_space(self.env).sample()
56
+ for _ in range(self.env.num_envs)
57
+ ]
58
+ )
59
+ else:
60
+ o = self._as_tensor(obs)
61
+ with torch.no_grad():
62
+ return self.q_net(o).argmax(axis=1).cpu().numpy()
rl_algo_impls/dqn/q_net.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Sequence, Type
2
+
3
+ import gym
4
+ import torch as th
5
+ import torch.nn as nn
6
+ from gym.spaces import Discrete
7
+
8
+ from rl_algo_impls.shared.encoder import Encoder
9
+ from rl_algo_impls.shared.module.utils import mlp
10
+
11
+
12
+ class QNetwork(nn.Module):
13
+ def __init__(
14
+ self,
15
+ observation_space: gym.Space,
16
+ action_space: gym.Space,
17
+ hidden_sizes: Sequence[int] = [],
18
+ activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
19
+ cnn_flatten_dim: int = 512,
20
+ cnn_style: str = "nature",
21
+ cnn_layers_init_orthogonal: Optional[bool] = None,
22
+ impala_channels: Sequence[int] = (16, 32, 32),
23
+ ) -> None:
24
+ super().__init__()
25
+ assert isinstance(action_space, Discrete)
26
+ self._feature_extractor = Encoder(
27
+ observation_space,
28
+ activation,
29
+ cnn_flatten_dim=cnn_flatten_dim,
30
+ cnn_style=cnn_style,
31
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
32
+ impala_channels=impala_channels,
33
+ )
34
+ layer_sizes = (
35
+ (self._feature_extractor.out_dim,) + tuple(hidden_sizes) + (action_space.n,)
36
+ )
37
+ self._fc = mlp(layer_sizes, activation)
38
+
39
+ def forward(self, obs: th.Tensor) -> th.Tensor:
40
+ x = self._feature_extractor(obs)
41
+ return self._fc(x)
rl_algo_impls/enjoy.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
2
+ import os
3
+
4
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
+
6
+ from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
7
+ from rl_algo_impls.runner.running_utils import base_parser
8
+
9
+
10
+ def enjoy() -> None:
11
+ parser = base_parser(multiple=False)
12
+ parser.add_argument("--render", default=True, type=bool)
13
+ parser.add_argument("--best", default=True, type=bool)
14
+ parser.add_argument("--n_envs", default=1, type=int)
15
+ parser.add_argument("--n_episodes", default=3, type=int)
16
+ parser.add_argument("--deterministic-eval", default=None, type=bool)
17
+ parser.add_argument(
18
+ "--no-print-returns", action="store_true", help="Limit printing"
19
+ )
20
+ # wandb-run-path overrides base RunArgs
21
+ parser.add_argument("--wandb-run-path", default=None, type=str)
22
+ parser.set_defaults(
23
+ algo=["ppo"],
24
+ wandb_run_path="sgoodfriend/rl-algo-impls/m5c1t7g5",
25
+ )
26
+ args = parser.parse_args()
27
+ args.algo = args.algo[0]
28
+ args.env = args.env[0]
29
+ args = EvalArgs(**vars(args))
30
+
31
+ evaluate_model(args, os.getcwd())
32
+
33
+
34
+ if __name__ == "__main__":
35
+ enjoy()
rl_algo_impls/huggingface_publish.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
4
+
5
+ import argparse
6
+ import shutil
7
+ import subprocess
8
+ import tempfile
9
+ from typing import List, Optional
10
+
11
+ import requests
12
+ import wandb.apis.public
13
+ from huggingface_hub.hf_api import HfApi, upload_folder
14
+ from huggingface_hub.repocard import metadata_save
15
+ from pyvirtualdisplay.display import Display
16
+
17
+ import wandb
18
+ from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
19
+ from rl_algo_impls.runner.config import EnvHyperparams
20
+ from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
21
+ from rl_algo_impls.shared.callbacks.eval_callback import evaluate
22
+ from rl_algo_impls.shared.vec_env import make_eval_env
23
+ from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
24
+
25
+
26
+ def publish(
27
+ wandb_run_paths: List[str],
28
+ wandb_report_url: str,
29
+ huggingface_user: Optional[str] = None,
30
+ huggingface_token: Optional[str] = None,
31
+ virtual_display: bool = False,
32
+ ) -> None:
33
+ if virtual_display:
34
+ display = Display(visible=False, size=(1400, 900))
35
+ display.start()
36
+
37
+ api = wandb.Api()
38
+ runs = [api.run(rp) for rp in wandb_run_paths]
39
+ algo = runs[0].config["algo"]
40
+ hyperparam_id = runs[0].config["env"]
41
+ evaluations = [
42
+ evaluate_model(
43
+ EvalArgs(
44
+ algo,
45
+ hyperparam_id,
46
+ seed=r.config.get("seed", None),
47
+ render=False,
48
+ best=True,
49
+ n_envs=None,
50
+ n_episodes=10,
51
+ no_print_returns=True,
52
+ wandb_run_path="/".join(r.path),
53
+ ),
54
+ os.getcwd(),
55
+ )
56
+ for r in runs
57
+ ]
58
+ run_metadata = requests.get(runs[0].file("wandb-metadata.json").url).json()
59
+ table_data = list(EvalTableData(r, e) for r, e in zip(runs, evaluations))
60
+ best_eval = sorted(
61
+ table_data, key=lambda d: d.evaluation.stats.score, reverse=True
62
+ )[0]
63
+
64
+ with tempfile.TemporaryDirectory() as tmpdirname:
65
+ _, (policy, stats, config) = best_eval
66
+
67
+ repo_name = config.model_name(include_seed=False)
68
+ repo_dir_path = os.path.join(tmpdirname, repo_name)
69
+ # Locally clone this repo to a temp directory
70
+ subprocess.run(["git", "clone", ".", repo_dir_path])
71
+ shutil.rmtree(os.path.join(repo_dir_path, ".git"))
72
+ model_path = config.model_dir_path(best=True, downloaded=True)
73
+ shutil.copytree(
74
+ model_path,
75
+ os.path.join(
76
+ repo_dir_path, "saved_models", config.model_dir_name(best=True)
77
+ ),
78
+ )
79
+
80
+ github_url = "https://github.com/sgoodfriend/rl-algo-impls"
81
+ commit_hash = run_metadata.get("git", {}).get("commit", None)
82
+ env_id = runs[0].config.get("env_id") or runs[0].config["env"]
83
+ card_text = model_card_text(
84
+ algo,
85
+ env_id,
86
+ github_url,
87
+ commit_hash,
88
+ wandb_report_url,
89
+ table_data,
90
+ best_eval,
91
+ )
92
+ readme_filepath = os.path.join(repo_dir_path, "README.md")
93
+ os.remove(readme_filepath)
94
+ with open(readme_filepath, "w") as f:
95
+ f.write(card_text)
96
+
97
+ metadata = {
98
+ "library_name": "rl-algo-impls",
99
+ "tags": [
100
+ env_id,
101
+ algo,
102
+ "deep-reinforcement-learning",
103
+ "reinforcement-learning",
104
+ ],
105
+ "model-index": [
106
+ {
107
+ "name": algo,
108
+ "results": [
109
+ {
110
+ "metrics": [
111
+ {
112
+ "type": "mean_reward",
113
+ "value": str(stats.score),
114
+ "name": "mean_reward",
115
+ }
116
+ ],
117
+ "task": {
118
+ "type": "reinforcement-learning",
119
+ "name": "reinforcement-learning",
120
+ },
121
+ "dataset": {
122
+ "name": env_id,
123
+ "type": env_id,
124
+ },
125
+ }
126
+ ],
127
+ }
128
+ ],
129
+ }
130
+ metadata_save(readme_filepath, metadata)
131
+
132
+ video_env = VecEpisodeRecorder(
133
+ make_eval_env(
134
+ config,
135
+ EnvHyperparams(**config.env_hyperparams),
136
+ override_n_envs=1,
137
+ normalize_load_path=model_path,
138
+ ),
139
+ os.path.join(repo_dir_path, "replay"),
140
+ max_video_length=3600,
141
+ )
142
+ evaluate(
143
+ video_env,
144
+ policy,
145
+ 1,
146
+ deterministic=config.eval_hyperparams.get("deterministic", True),
147
+ )
148
+
149
+ api = HfApi()
150
+ huggingface_user = huggingface_user or api.whoami()["name"]
151
+ huggingface_repo = f"{huggingface_user}/{repo_name}"
152
+ api.create_repo(
153
+ token=huggingface_token,
154
+ repo_id=huggingface_repo,
155
+ private=False,
156
+ exist_ok=True,
157
+ )
158
+ repo_url = upload_folder(
159
+ repo_id=huggingface_repo,
160
+ folder_path=repo_dir_path,
161
+ path_in_repo="",
162
+ commit_message=f"{algo.upper()} playing {env_id} from {github_url}/tree/{commit_hash}",
163
+ token=huggingface_token,
164
+ delete_patterns="*",
165
+ )
166
+ print(f"Pushed model to the hub: {repo_url}")
167
+
168
+
169
+ def huggingface_publish():
170
+ parser = argparse.ArgumentParser()
171
+ parser.add_argument(
172
+ "--wandb-run-paths",
173
+ type=str,
174
+ nargs="+",
175
+ help="Run paths of the form entity/project/run_id",
176
+ )
177
+ parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
178
+ parser.add_argument(
179
+ "--huggingface-user",
180
+ type=str,
181
+ help="Huggingface user or team to upload model cards",
182
+ default=None,
183
+ )
184
+ parser.add_argument(
185
+ "--virtual-display", action="store_true", help="Use headless virtual display"
186
+ )
187
+ args = parser.parse_args()
188
+ print(args)
189
+ publish(**vars(args))
190
+
191
+
192
+ if __name__ == "__main__":
193
+ huggingface_publish()
rl_algo_impls/hyperparams/a2c.yml ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CartPole-v1: &cartpole-defaults
2
+ n_timesteps: !!float 5e5
3
+ env_hyperparams:
4
+ n_envs: 8
5
+
6
+ CartPole-v0:
7
+ <<: *cartpole-defaults
8
+
9
+ MountainCar-v0:
10
+ n_timesteps: !!float 1e6
11
+ env_hyperparams:
12
+ n_envs: 16
13
+ normalize: true
14
+
15
+ MountainCarContinuous-v0:
16
+ n_timesteps: !!float 1e5
17
+ env_hyperparams:
18
+ n_envs: 4
19
+ normalize: true
20
+ # policy_hyperparams:
21
+ # use_sde: true
22
+ # log_std_init: 0.0
23
+ # init_layers_orthogonal: false
24
+ algo_hyperparams:
25
+ n_steps: 100
26
+ sde_sample_freq: 16
27
+
28
+ Acrobot-v1:
29
+ n_timesteps: !!float 5e5
30
+ env_hyperparams:
31
+ normalize: true
32
+ n_envs: 16
33
+
34
+ # Tuned
35
+ LunarLander-v2:
36
+ device: cpu
37
+ n_timesteps: !!float 1e6
38
+ env_hyperparams:
39
+ n_envs: 4
40
+ normalize: true
41
+ algo_hyperparams:
42
+ n_steps: 2
43
+ gamma: 0.9955517404308908
44
+ gae_lambda: 0.9875340918797773
45
+ learning_rate: 0.0013814130817068916
46
+ learning_rate_decay: linear
47
+ ent_coef: !!float 3.388369146384422e-7
48
+ ent_coef_decay: none
49
+ max_grad_norm: 3.33982095073364
50
+ normalize_advantage: true
51
+ vf_coef: 0.1667838310548184
52
+
53
+ BipedalWalker-v3:
54
+ n_timesteps: !!float 5e6
55
+ env_hyperparams:
56
+ n_envs: 16
57
+ normalize: true
58
+ policy_hyperparams:
59
+ use_sde: true
60
+ log_std_init: -2
61
+ init_layers_orthogonal: false
62
+ algo_hyperparams:
63
+ ent_coef: 0
64
+ max_grad_norm: 0.5
65
+ n_steps: 8
66
+ gae_lambda: 0.9
67
+ vf_coef: 0.4
68
+ gamma: 0.99
69
+ learning_rate: !!float 9.6e-4
70
+ learning_rate_decay: linear
71
+
72
+ HalfCheetahBulletEnv-v0: &pybullet-defaults
73
+ n_timesteps: !!float 2e6
74
+ env_hyperparams:
75
+ n_envs: 4
76
+ normalize: true
77
+ policy_hyperparams:
78
+ use_sde: true
79
+ log_std_init: -2
80
+ init_layers_orthogonal: false
81
+ algo_hyperparams: &pybullet-algo-defaults
82
+ n_steps: 8
83
+ ent_coef: 0
84
+ max_grad_norm: 0.5
85
+ gae_lambda: 0.9
86
+ gamma: 0.99
87
+ vf_coef: 0.4
88
+ learning_rate: !!float 9.6e-4
89
+ learning_rate_decay: linear
90
+
91
+ AntBulletEnv-v0:
92
+ <<: *pybullet-defaults
93
+
94
+ Walker2DBulletEnv-v0:
95
+ <<: *pybullet-defaults
96
+
97
+ HopperBulletEnv-v0:
98
+ <<: *pybullet-defaults
99
+
100
+ # Tuned
101
+ CarRacing-v0:
102
+ n_timesteps: !!float 4e6
103
+ env_hyperparams:
104
+ n_envs: 16
105
+ frame_stack: 4
106
+ normalize: true
107
+ normalize_kwargs:
108
+ norm_obs: false
109
+ norm_reward: true
110
+ policy_hyperparams:
111
+ use_sde: false
112
+ log_std_init: -1.3502584927786276
113
+ init_layers_orthogonal: true
114
+ activation_fn: tanh
115
+ share_features_extractor: false
116
+ cnn_flatten_dim: 256
117
+ hidden_sizes: [256]
118
+ algo_hyperparams:
119
+ n_steps: 16
120
+ learning_rate: 0.000025630993245026736
121
+ learning_rate_decay: linear
122
+ gamma: 0.99957617037542
123
+ gae_lambda: 0.949455676599436
124
+ ent_coef: !!float 1.707983205298309e-7
125
+ vf_coef: 0.10428178193833336
126
+ max_grad_norm: 0.5406643389792273
127
+ normalize_advantage: true
128
+ use_rms_prop: false
129
+
130
+ _atari: &atari-defaults
131
+ n_timesteps: !!float 1e7
132
+ env_hyperparams: &atari-env-defaults
133
+ n_envs: 16
134
+ frame_stack: 4
135
+ no_reward_timeout_steps: 1000
136
+ no_reward_fire_steps: 500
137
+ vec_env_class: async
138
+ policy_hyperparams: &atari-policy-defaults
139
+ activation_fn: relu
140
+ algo_hyperparams:
141
+ ent_coef: 0.01
142
+ vf_coef: 0.25
rl_algo_impls/hyperparams/dqn.yml ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CartPole-v1: &cartpole-defaults
2
+ n_timesteps: !!float 5e4
3
+ env_hyperparams:
4
+ rolling_length: 50
5
+ policy_hyperparams:
6
+ hidden_sizes: [256, 256]
7
+ algo_hyperparams:
8
+ learning_rate: !!float 2.3e-3
9
+ batch_size: 64
10
+ buffer_size: 100000
11
+ learning_starts: 1000
12
+ gamma: 0.99
13
+ target_update_interval: 10
14
+ train_freq: 256
15
+ gradient_steps: 128
16
+ exploration_fraction: 0.16
17
+ exploration_final_eps: 0.04
18
+ eval_hyperparams:
19
+ step_freq: !!float 1e4
20
+
21
+ CartPole-v0:
22
+ <<: *cartpole-defaults
23
+ n_timesteps: !!float 4e4
24
+
25
+ MountainCar-v0:
26
+ n_timesteps: !!float 1.2e5
27
+ env_hyperparams:
28
+ rolling_length: 50
29
+ policy_hyperparams:
30
+ hidden_sizes: [256, 256]
31
+ algo_hyperparams:
32
+ learning_rate: !!float 4e-3
33
+ batch_size: 128
34
+ buffer_size: 10000
35
+ learning_starts: 1000
36
+ gamma: 0.98
37
+ target_update_interval: 600
38
+ train_freq: 16
39
+ gradient_steps: 8
40
+ exploration_fraction: 0.2
41
+ exploration_final_eps: 0.07
42
+
43
+ Acrobot-v1:
44
+ n_timesteps: !!float 1e5
45
+ env_hyperparams:
46
+ rolling_length: 50
47
+ policy_hyperparams:
48
+ hidden_sizes: [256, 256]
49
+ algo_hyperparams:
50
+ learning_rate: !!float 6.3e-4
51
+ batch_size: 128
52
+ buffer_size: 50000
53
+ learning_starts: 0
54
+ gamma: 0.99
55
+ target_update_interval: 250
56
+ train_freq: 4
57
+ gradient_steps: -1
58
+ exploration_fraction: 0.12
59
+ exploration_final_eps: 0.1
60
+
61
+ LunarLander-v2:
62
+ n_timesteps: !!float 5e5
63
+ env_hyperparams:
64
+ rolling_length: 50
65
+ policy_hyperparams:
66
+ hidden_sizes: [256, 256]
67
+ algo_hyperparams:
68
+ learning_rate: !!float 1e-4
69
+ batch_size: 256
70
+ buffer_size: 100000
71
+ learning_starts: 10000
72
+ gamma: 0.99
73
+ target_update_interval: 250
74
+ train_freq: 8
75
+ gradient_steps: -1
76
+ exploration_fraction: 0.12
77
+ exploration_final_eps: 0.1
78
+ max_grad_norm: 0.5
79
+ eval_hyperparams:
80
+ step_freq: 25_000
81
+
82
+ _atari: &atari-defaults
83
+ n_timesteps: !!float 1e7
84
+ env_hyperparams:
85
+ frame_stack: 4
86
+ no_reward_timeout_steps: 1_000
87
+ no_reward_fire_steps: 500
88
+ n_envs: 8
89
+ vec_env_class: async
90
+ algo_hyperparams:
91
+ buffer_size: 100000
92
+ learning_rate: !!float 1e-4
93
+ batch_size: 32
94
+ learning_starts: 100000
95
+ target_update_interval: 1000
96
+ train_freq: 8
97
+ gradient_steps: 2
98
+ exploration_fraction: 0.1
99
+ exploration_final_eps: 0.01
100
+ eval_hyperparams:
101
+ deterministic: false
102
+
103
+ PongNoFrameskip-v4:
104
+ <<: *atari-defaults
105
+ n_timesteps: !!float 2.5e6
106
+
107
+ _impala-atari: &impala-atari-defaults
108
+ <<: *atari-defaults
109
+ policy_hyperparams:
110
+ cnn_style: impala
111
+ cnn_flatten_dim: 256
112
+ init_layers_orthogonal: true
113
+ cnn_layers_init_orthogonal: false
114
+
115
+ impala-PongNoFrameskip-v4:
116
+ <<: *impala-atari-defaults
117
+ env_id: PongNoFrameskip-v4
118
+ n_timesteps: !!float 2.5e6
119
+
120
+ impala-BreakoutNoFrameskip-v4:
121
+ <<: *impala-atari-defaults
122
+ env_id: BreakoutNoFrameskip-v4
123
+
124
+ impala-SpaceInvadersNoFrameskip-v4:
125
+ <<: *impala-atari-defaults
126
+ env_id: SpaceInvadersNoFrameskip-v4
127
+
128
+ impala-QbertNoFrameskip-v4:
129
+ <<: *impala-atari-defaults
130
+ env_id: QbertNoFrameskip-v4
rl_algo_impls/hyperparams/ppo.yml ADDED
@@ -0,0 +1,616 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CartPole-v1: &cartpole-defaults
2
+ n_timesteps: !!float 1e5
3
+ env_hyperparams:
4
+ n_envs: 8
5
+ algo_hyperparams:
6
+ n_steps: 32
7
+ batch_size: 256
8
+ n_epochs: 20
9
+ gae_lambda: 0.8
10
+ gamma: 0.98
11
+ ent_coef: 0.0
12
+ learning_rate: 0.001
13
+ learning_rate_decay: linear
14
+ clip_range: 0.2
15
+ clip_range_decay: linear
16
+ eval_hyperparams:
17
+ step_freq: !!float 2.5e4
18
+
19
+ CartPole-v0:
20
+ <<: *cartpole-defaults
21
+ n_timesteps: !!float 5e4
22
+
23
+ MountainCar-v0:
24
+ n_timesteps: !!float 1e6
25
+ env_hyperparams:
26
+ normalize: true
27
+ n_envs: 16
28
+ algo_hyperparams:
29
+ n_steps: 16
30
+ n_epochs: 4
31
+ gae_lambda: 0.98
32
+ gamma: 0.99
33
+ ent_coef: 0.0
34
+
35
+ MountainCarContinuous-v0:
36
+ n_timesteps: !!float 1e5
37
+ env_hyperparams:
38
+ normalize: true
39
+ n_envs: 4
40
+ # policy_hyperparams:
41
+ # init_layers_orthogonal: false
42
+ # log_std_init: -3.29
43
+ # use_sde: true
44
+ algo_hyperparams:
45
+ n_steps: 512
46
+ batch_size: 256
47
+ n_epochs: 10
48
+ learning_rate: !!float 7.77e-5
49
+ ent_coef: 0.01 # 0.00429
50
+ ent_coef_decay: linear
51
+ clip_range: 0.1
52
+ gae_lambda: 0.9
53
+ max_grad_norm: 5
54
+ vf_coef: 0.19
55
+ eval_hyperparams:
56
+ step_freq: 5000
57
+
58
+ Acrobot-v1:
59
+ n_timesteps: !!float 1e6
60
+ env_hyperparams:
61
+ n_envs: 16
62
+ normalize: true
63
+ algo_hyperparams:
64
+ n_steps: 256
65
+ n_epochs: 4
66
+ gae_lambda: 0.94
67
+ gamma: 0.99
68
+ ent_coef: 0.0
69
+
70
+ LunarLander-v2:
71
+ n_timesteps: !!float 4e6
72
+ env_hyperparams:
73
+ n_envs: 16
74
+ algo_hyperparams:
75
+ n_steps: 1024
76
+ batch_size: 64
77
+ n_epochs: 4
78
+ gae_lambda: 0.98
79
+ gamma: 0.999
80
+ learning_rate: !!float 5e-4
81
+ learning_rate_decay: linear
82
+ clip_range: 0.2
83
+ clip_range_decay: linear
84
+ ent_coef: 0.01
85
+ normalize_advantage: false
86
+
87
+ BipedalWalker-v3:
88
+ n_timesteps: !!float 10e6
89
+ env_hyperparams:
90
+ n_envs: 16
91
+ normalize: true
92
+ algo_hyperparams:
93
+ n_steps: 2048
94
+ batch_size: 64
95
+ gae_lambda: 0.95
96
+ gamma: 0.99
97
+ n_epochs: 10
98
+ ent_coef: 0.001
99
+ learning_rate: !!float 2.5e-4
100
+ learning_rate_decay: linear
101
+ clip_range: 0.2
102
+ clip_range_decay: linear
103
+
104
+ CarRacing-v0: &carracing-defaults
105
+ n_timesteps: !!float 4e6
106
+ env_hyperparams:
107
+ n_envs: 8
108
+ frame_stack: 4
109
+ policy_hyperparams: &carracing-policy-defaults
110
+ use_sde: true
111
+ log_std_init: -2
112
+ init_layers_orthogonal: false
113
+ activation_fn: relu
114
+ share_features_extractor: false
115
+ cnn_flatten_dim: 256
116
+ hidden_sizes: [256]
117
+ algo_hyperparams:
118
+ n_steps: 512
119
+ batch_size: 128
120
+ n_epochs: 10
121
+ learning_rate: !!float 1e-4
122
+ learning_rate_decay: linear
123
+ gamma: 0.99
124
+ gae_lambda: 0.95
125
+ ent_coef: 0.0
126
+ sde_sample_freq: 4
127
+ max_grad_norm: 0.5
128
+ vf_coef: 0.5
129
+ clip_range: 0.2
130
+
131
+ impala-CarRacing-v0:
132
+ <<: *carracing-defaults
133
+ env_id: CarRacing-v0
134
+ policy_hyperparams:
135
+ <<: *carracing-policy-defaults
136
+ cnn_style: impala
137
+ init_layers_orthogonal: true
138
+ cnn_layers_init_orthogonal: false
139
+ hidden_sizes: []
140
+
141
+ # BreakoutNoFrameskip-v4
142
+ # PongNoFrameskip-v4
143
+ # SpaceInvadersNoFrameskip-v4
144
+ # QbertNoFrameskip-v4
145
+ _atari: &atari-defaults
146
+ n_timesteps: !!float 1e7
147
+ env_hyperparams: &atari-env-defaults
148
+ n_envs: 8
149
+ frame_stack: 4
150
+ no_reward_timeout_steps: 1000
151
+ no_reward_fire_steps: 500
152
+ vec_env_class: async
153
+ policy_hyperparams: &atari-policy-defaults
154
+ activation_fn: relu
155
+ algo_hyperparams: &atari-algo-defaults
156
+ n_steps: 128
157
+ batch_size: 256
158
+ n_epochs: 4
159
+ learning_rate: !!float 2.5e-4
160
+ learning_rate_decay: linear
161
+ clip_range: 0.1
162
+ clip_range_decay: linear
163
+ vf_coef: 0.5
164
+ ent_coef: 0.01
165
+ eval_hyperparams:
166
+ deterministic: false
167
+
168
+ _norm-rewards-atari: &norm-rewards-atari-default
169
+ <<: *atari-defaults
170
+ env_hyperparams:
171
+ <<: *atari-env-defaults
172
+ clip_atari_rewards: false
173
+ normalize: true
174
+ normalize_kwargs:
175
+ norm_obs: false
176
+ norm_reward: true
177
+
178
+ norm-rewards-BreakoutNoFrameskip-v4:
179
+ <<: *norm-rewards-atari-default
180
+ env_id: BreakoutNoFrameskip-v4
181
+
182
+ debug-PongNoFrameskip-v4:
183
+ <<: *atari-defaults
184
+ device: cpu
185
+ env_id: PongNoFrameskip-v4
186
+ env_hyperparams:
187
+ <<: *atari-env-defaults
188
+ vec_env_class: sync
189
+
190
+ _impala-atari: &impala-atari-defaults
191
+ <<: *atari-defaults
192
+ policy_hyperparams:
193
+ <<: *atari-policy-defaults
194
+ cnn_style: impala
195
+ cnn_flatten_dim: 256
196
+ init_layers_orthogonal: true
197
+ cnn_layers_init_orthogonal: false
198
+
199
+ impala-PongNoFrameskip-v4:
200
+ <<: *impala-atari-defaults
201
+ env_id: PongNoFrameskip-v4
202
+
203
+ impala-BreakoutNoFrameskip-v4:
204
+ <<: *impala-atari-defaults
205
+ env_id: BreakoutNoFrameskip-v4
206
+
207
+ impala-SpaceInvadersNoFrameskip-v4:
208
+ <<: *impala-atari-defaults
209
+ env_id: SpaceInvadersNoFrameskip-v4
210
+
211
+ impala-QbertNoFrameskip-v4:
212
+ <<: *impala-atari-defaults
213
+ env_id: QbertNoFrameskip-v4
214
+
215
+ _microrts: &microrts-defaults
216
+ <<: *atari-defaults
217
+ n_timesteps: !!float 2e6
218
+ env_hyperparams: &microrts-env-defaults
219
+ n_envs: 8
220
+ vec_env_class: sync
221
+ mask_actions: true
222
+ policy_hyperparams: &microrts-policy-defaults
223
+ <<: *atari-policy-defaults
224
+ cnn_style: microrts
225
+ cnn_flatten_dim: 128
226
+ algo_hyperparams: &microrts-algo-defaults
227
+ <<: *atari-algo-defaults
228
+ clip_range_decay: none
229
+ clip_range_vf: 0.1
230
+ ppo2_vf_coef_halving: true
231
+ eval_hyperparams: &microrts-eval-defaults
232
+ deterministic: false # Good idea because MultiCategorical mode isn't great
233
+
234
+ _no-mask-microrts: &no-mask-microrts-defaults
235
+ <<: *microrts-defaults
236
+ env_hyperparams:
237
+ <<: *microrts-env-defaults
238
+ mask_actions: false
239
+
240
+ MicrortsMining-v1-NoMask:
241
+ <<: *no-mask-microrts-defaults
242
+ env_id: MicrortsMining-v1
243
+
244
+ MicrortsAttackShapedReward-v1-NoMask:
245
+ <<: *no-mask-microrts-defaults
246
+ env_id: MicrortsAttackShapedReward-v1
247
+
248
+ MicrortsRandomEnemyShapedReward3-v1-NoMask:
249
+ <<: *no-mask-microrts-defaults
250
+ env_id: MicrortsRandomEnemyShapedReward3-v1
251
+
252
+ _microrts_ai: &microrts-ai-defaults
253
+ <<: *microrts-defaults
254
+ n_timesteps: !!float 100e6
255
+ additional_keys_to_log: ["microrts_stats"]
256
+ env_hyperparams: &microrts-ai-env-defaults
257
+ n_envs: 24
258
+ env_type: microrts
259
+ make_kwargs: &microrts-ai-env-make-kwargs-defaults
260
+ num_selfplay_envs: 0
261
+ max_steps: 2000
262
+ render_theme: 2
263
+ map_paths: [maps/16x16/basesWorkers16x16.xml]
264
+ reward_weight: [10.0, 1.0, 1.0, 0.2, 1.0, 4.0]
265
+ policy_hyperparams: &microrts-ai-policy-defaults
266
+ <<: *microrts-policy-defaults
267
+ cnn_flatten_dim: 256
268
+ actor_head_style: gridnet
269
+ algo_hyperparams: &microrts-ai-algo-defaults
270
+ <<: *microrts-algo-defaults
271
+ learning_rate: !!float 2.5e-4
272
+ learning_rate_decay: linear
273
+ n_steps: 512
274
+ batch_size: 3072
275
+ n_epochs: 4
276
+ ent_coef: 0.01
277
+ vf_coef: 0.5
278
+ max_grad_norm: 0.5
279
+ clip_range: 0.1
280
+ clip_range_vf: 0.1
281
+ eval_hyperparams: &microrts-ai-eval-defaults
282
+ <<: *microrts-eval-defaults
283
+ score_function: mean
284
+ max_video_length: 4000
285
+ env_overrides: &microrts-ai-eval-env-overrides
286
+ make_kwargs:
287
+ <<: *microrts-ai-env-make-kwargs-defaults
288
+ max_steps: 4000
289
+ reward_weight: [1.0, 0, 0, 0, 0, 0]
290
+
291
+ MicrortsAttackPassiveEnemySparseReward-v3:
292
+ <<: *microrts-ai-defaults
293
+ n_timesteps: !!float 2e6
294
+ env_id: MicrortsAttackPassiveEnemySparseReward-v3 # Workaround to keep model name simple
295
+ env_hyperparams:
296
+ <<: *microrts-ai-env-defaults
297
+ bots:
298
+ passiveAI: 24
299
+
300
+ MicrortsDefeatRandomEnemySparseReward-v3: &microrts-random-ai-defaults
301
+ <<: *microrts-ai-defaults
302
+ n_timesteps: !!float 2e6
303
+ env_id: MicrortsDefeatRandomEnemySparseReward-v3 # Workaround to keep model name simple
304
+ env_hyperparams:
305
+ <<: *microrts-ai-env-defaults
306
+ bots:
307
+ randomBiasedAI: 24
308
+
309
+ enc-dec-MicrortsDefeatRandomEnemySparseReward-v3:
310
+ <<: *microrts-random-ai-defaults
311
+ policy_hyperparams:
312
+ <<: *microrts-ai-policy-defaults
313
+ cnn_style: gridnet_encoder
314
+ actor_head_style: gridnet_decoder
315
+ v_hidden_sizes: [128]
316
+
317
+ unet-MicrortsDefeatRandomEnemySparseReward-v3:
318
+ <<: *microrts-random-ai-defaults
319
+ # device: cpu
320
+ policy_hyperparams:
321
+ <<: *microrts-ai-policy-defaults
322
+ actor_head_style: unet
323
+ v_hidden_sizes: [256, 128]
324
+ algo_hyperparams:
325
+ <<: *microrts-ai-algo-defaults
326
+ learning_rate: !!float 2.5e-4
327
+ learning_rate_decay: spike
328
+
329
+ MicrortsDefeatCoacAIShaped-v3: &microrts-coacai-defaults
330
+ <<: *microrts-ai-defaults
331
+ env_id: MicrortsDefeatCoacAIShaped-v3 # Workaround to keep model name simple
332
+ n_timesteps: !!float 300e6
333
+ env_hyperparams: &microrts-coacai-env-defaults
334
+ <<: *microrts-ai-env-defaults
335
+ bots:
336
+ coacAI: 24
337
+ eval_hyperparams: &microrts-coacai-eval-defaults
338
+ <<: *microrts-ai-eval-defaults
339
+ step_freq: !!float 1e6
340
+ n_episodes: 26
341
+ env_overrides: &microrts-coacai-eval-env-overrides
342
+ <<: *microrts-ai-eval-env-overrides
343
+ n_envs: 26
344
+ bots:
345
+ coacAI: 2
346
+ randomBiasedAI: 2
347
+ randomAI: 2
348
+ passiveAI: 2
349
+ workerRushAI: 2
350
+ lightRushAI: 2
351
+ naiveMCTSAI: 2
352
+ mixedBot: 2
353
+ rojo: 2
354
+ izanagi: 2
355
+ tiamat: 2
356
+ droplet: 2
357
+ guidedRojoA3N: 2
358
+
359
+ MicrortsDefeatCoacAIShaped-v3-diverseBots: &microrts-diverse-defaults
360
+ <<: *microrts-coacai-defaults
361
+ env_hyperparams:
362
+ <<: *microrts-coacai-env-defaults
363
+ bots:
364
+ coacAI: 18
365
+ randomBiasedAI: 2
366
+ lightRushAI: 2
367
+ workerRushAI: 2
368
+
369
+ enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
370
+ &microrts-env-dec-diverse-defaults
371
+ <<: *microrts-diverse-defaults
372
+ policy_hyperparams:
373
+ <<: *microrts-ai-policy-defaults
374
+ cnn_style: gridnet_encoder
375
+ actor_head_style: gridnet_decoder
376
+ v_hidden_sizes: [128]
377
+
378
+ debug-enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
379
+ <<: *microrts-env-dec-diverse-defaults
380
+ n_timesteps: !!float 1e6
381
+
382
+ unet-MicrortsDefeatCoacAIShaped-v3-diverseBots: &microrts-unet-defaults
383
+ <<: *microrts-diverse-defaults
384
+ policy_hyperparams:
385
+ <<: *microrts-ai-policy-defaults
386
+ actor_head_style: unet
387
+ v_hidden_sizes: [256, 128]
388
+ algo_hyperparams: &microrts-unet-algo-defaults
389
+ <<: *microrts-ai-algo-defaults
390
+ learning_rate: !!float 2.5e-4
391
+ learning_rate_decay: spike
392
+
393
+ Microrts-selfplay-unet: &microrts-selfplay-defaults
394
+ <<: *microrts-unet-defaults
395
+ env_hyperparams: &microrts-selfplay-env-defaults
396
+ <<: *microrts-ai-env-defaults
397
+ make_kwargs: &microrts-selfplay-env-make-kwargs-defaults
398
+ <<: *microrts-ai-env-make-kwargs-defaults
399
+ num_selfplay_envs: 36
400
+ self_play_kwargs:
401
+ num_old_policies: 12
402
+ save_steps: 200000
403
+ swap_steps: 10000
404
+ swap_window_size: 4
405
+ window: 25
406
+ eval_hyperparams: &microrts-selfplay-eval-defaults
407
+ <<: *microrts-coacai-eval-defaults
408
+ env_overrides: &microrts-selfplay-eval-env-overrides
409
+ <<: *microrts-coacai-eval-env-overrides
410
+ self_play_kwargs: {}
411
+
412
+ Microrts-selfplay-unet-winloss: &microrts-selfplay-winloss-defaults
413
+ <<: *microrts-selfplay-defaults
414
+ env_hyperparams:
415
+ <<: *microrts-selfplay-env-defaults
416
+ make_kwargs:
417
+ <<: *microrts-selfplay-env-make-kwargs-defaults
418
+ reward_weight: [1.0, 0, 0, 0, 0, 0]
419
+ algo_hyperparams: &microrts-selfplay-winloss-algo-defaults
420
+ <<: *microrts-unet-algo-defaults
421
+ gamma: 0.999
422
+
423
+ Microrts-selfplay-unet-decay:
424
+ <<: *microrts-selfplay-winloss-defaults
425
+ microrts_reward_decay_callback: true
426
+ algo_hyperparams:
427
+ <<: *microrts-selfplay-winloss-algo-defaults
428
+ gamma_end: 0.999
429
+
430
+ Microrts-selfplay-unet-debug: &microrts-selfplay-debug-defaults
431
+ <<: *microrts-selfplay-defaults
432
+ eval_hyperparams:
433
+ <<: *microrts-selfplay-eval-defaults
434
+ step_freq: !!float 1e5
435
+ env_overrides:
436
+ <<: *microrts-selfplay-eval-env-overrides
437
+ n_envs: 24
438
+ bots:
439
+ coacAI: 12
440
+ randomBiasedAI: 4
441
+ workerRushAI: 4
442
+ lightRushAI: 4
443
+
444
+ Microrts-selfplay-unet-debug-mps:
445
+ <<: *microrts-selfplay-debug-defaults
446
+ device: mps
447
+
448
+ HalfCheetahBulletEnv-v0: &pybullet-defaults
449
+ n_timesteps: !!float 2e6
450
+ env_hyperparams: &pybullet-env-defaults
451
+ n_envs: 16
452
+ normalize: true
453
+ policy_hyperparams: &pybullet-policy-defaults
454
+ pi_hidden_sizes: [256, 256]
455
+ v_hidden_sizes: [256, 256]
456
+ activation_fn: relu
457
+ algo_hyperparams: &pybullet-algo-defaults
458
+ n_steps: 512
459
+ batch_size: 128
460
+ n_epochs: 20
461
+ gamma: 0.99
462
+ gae_lambda: 0.9
463
+ ent_coef: 0.0
464
+ max_grad_norm: 0.5
465
+ vf_coef: 0.5
466
+ learning_rate: !!float 3e-5
467
+ clip_range: 0.4
468
+
469
+ AntBulletEnv-v0:
470
+ <<: *pybullet-defaults
471
+ policy_hyperparams:
472
+ <<: *pybullet-policy-defaults
473
+ algo_hyperparams:
474
+ <<: *pybullet-algo-defaults
475
+
476
+ Walker2DBulletEnv-v0:
477
+ <<: *pybullet-defaults
478
+ algo_hyperparams:
479
+ <<: *pybullet-algo-defaults
480
+ clip_range_decay: linear
481
+
482
+ HopperBulletEnv-v0:
483
+ <<: *pybullet-defaults
484
+ algo_hyperparams:
485
+ <<: *pybullet-algo-defaults
486
+ clip_range_decay: linear
487
+
488
+ HumanoidBulletEnv-v0:
489
+ <<: *pybullet-defaults
490
+ n_timesteps: !!float 1e7
491
+ env_hyperparams:
492
+ <<: *pybullet-env-defaults
493
+ n_envs: 8
494
+ policy_hyperparams:
495
+ <<: *pybullet-policy-defaults
496
+ # log_std_init: -1
497
+ algo_hyperparams:
498
+ <<: *pybullet-algo-defaults
499
+ n_steps: 2048
500
+ batch_size: 64
501
+ n_epochs: 10
502
+ gae_lambda: 0.95
503
+ learning_rate: !!float 2.5e-4
504
+ clip_range: 0.2
505
+
506
+ _procgen: &procgen-defaults
507
+ env_hyperparams: &procgen-env-defaults
508
+ env_type: procgen
509
+ n_envs: 64
510
+ # grayscale: false
511
+ # frame_stack: 4
512
+ normalize: true # procgen only normalizes reward
513
+ make_kwargs: &procgen-make-kwargs-defaults
514
+ num_threads: 8
515
+ policy_hyperparams: &procgen-policy-defaults
516
+ activation_fn: relu
517
+ cnn_style: impala
518
+ cnn_flatten_dim: 256
519
+ init_layers_orthogonal: true
520
+ cnn_layers_init_orthogonal: false
521
+ algo_hyperparams: &procgen-algo-defaults
522
+ gamma: 0.999
523
+ gae_lambda: 0.95
524
+ n_steps: 256
525
+ batch_size: 2048
526
+ n_epochs: 3
527
+ ent_coef: 0.01
528
+ clip_range: 0.2
529
+ # clip_range_decay: linear
530
+ clip_range_vf: 0.2
531
+ learning_rate: !!float 5e-4
532
+ # learning_rate_decay: linear
533
+ vf_coef: 0.5
534
+ eval_hyperparams: &procgen-eval-defaults
535
+ ignore_first_episode: true
536
+ # deterministic: false
537
+ step_freq: !!float 1e5
538
+
539
+ _procgen-easy: &procgen-easy-defaults
540
+ <<: *procgen-defaults
541
+ n_timesteps: !!float 25e6
542
+ env_hyperparams: &procgen-easy-env-defaults
543
+ <<: *procgen-env-defaults
544
+ make_kwargs:
545
+ <<: *procgen-make-kwargs-defaults
546
+ distribution_mode: easy
547
+
548
+ procgen-coinrun-easy: &coinrun-easy-defaults
549
+ <<: *procgen-easy-defaults
550
+ env_id: coinrun
551
+
552
+ debug-procgen-coinrun:
553
+ <<: *coinrun-easy-defaults
554
+ device: cpu
555
+
556
+ procgen-starpilot-easy:
557
+ <<: *procgen-easy-defaults
558
+ env_id: starpilot
559
+
560
+ procgen-bossfight-easy:
561
+ <<: *procgen-easy-defaults
562
+ env_id: bossfight
563
+
564
+ procgen-bigfish-easy:
565
+ <<: *procgen-easy-defaults
566
+ env_id: bigfish
567
+
568
+ _procgen-hard: &procgen-hard-defaults
569
+ <<: *procgen-defaults
570
+ n_timesteps: !!float 200e6
571
+ env_hyperparams: &procgen-hard-env-defaults
572
+ <<: *procgen-env-defaults
573
+ n_envs: 256
574
+ make_kwargs:
575
+ <<: *procgen-make-kwargs-defaults
576
+ distribution_mode: hard
577
+ algo_hyperparams: &procgen-hard-algo-defaults
578
+ <<: *procgen-algo-defaults
579
+ batch_size: 8192
580
+ clip_range_decay: linear
581
+ learning_rate_decay: linear
582
+ eval_hyperparams:
583
+ <<: *procgen-eval-defaults
584
+ step_freq: !!float 5e5
585
+
586
+ procgen-starpilot-hard: &procgen-starpilot-hard-defaults
587
+ <<: *procgen-hard-defaults
588
+ env_id: starpilot
589
+
590
+ procgen-starpilot-hard-2xIMPALA:
591
+ <<: *procgen-starpilot-hard-defaults
592
+ policy_hyperparams:
593
+ <<: *procgen-policy-defaults
594
+ impala_channels: [32, 64, 64]
595
+ algo_hyperparams:
596
+ <<: *procgen-hard-algo-defaults
597
+ learning_rate: !!float 3.3e-4
598
+
599
+ procgen-starpilot-hard-2xIMPALA-fat:
600
+ <<: *procgen-starpilot-hard-defaults
601
+ policy_hyperparams:
602
+ <<: *procgen-policy-defaults
603
+ impala_channels: [32, 64, 64]
604
+ cnn_flatten_dim: 512
605
+ algo_hyperparams:
606
+ <<: *procgen-hard-algo-defaults
607
+ learning_rate: !!float 2.5e-4
608
+
609
+ procgen-starpilot-hard-4xIMPALA:
610
+ <<: *procgen-starpilot-hard-defaults
611
+ policy_hyperparams:
612
+ <<: *procgen-policy-defaults
613
+ impala_channels: [64, 128, 128]
614
+ algo_hyperparams:
615
+ <<: *procgen-hard-algo-defaults
616
+ learning_rate: !!float 2.1e-4
rl_algo_impls/hyperparams/vpg.yml ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CartPole-v1: &cartpole-defaults
2
+ n_timesteps: !!float 4e5
3
+ algo_hyperparams:
4
+ n_steps: 4096
5
+ pi_lr: 0.01
6
+ gamma: 0.99
7
+ gae_lambda: 1
8
+ val_lr: 0.01
9
+ train_v_iters: 80
10
+ eval_hyperparams:
11
+ step_freq: !!float 2.5e4
12
+
13
+ CartPole-v0:
14
+ <<: *cartpole-defaults
15
+ n_timesteps: !!float 1e5
16
+ algo_hyperparams:
17
+ n_steps: 1024
18
+ pi_lr: 0.01
19
+ gamma: 0.99
20
+ gae_lambda: 1
21
+ val_lr: 0.01
22
+ train_v_iters: 80
23
+
24
+ MountainCar-v0:
25
+ n_timesteps: !!float 1e6
26
+ env_hyperparams:
27
+ normalize: true
28
+ n_envs: 16
29
+ algo_hyperparams:
30
+ n_steps: 200
31
+ pi_lr: 0.005
32
+ gamma: 0.99
33
+ gae_lambda: 0.97
34
+ val_lr: 0.01
35
+ train_v_iters: 80
36
+ max_grad_norm: 0.5
37
+
38
+ MountainCarContinuous-v0:
39
+ n_timesteps: !!float 3e5
40
+ env_hyperparams:
41
+ normalize: true
42
+ n_envs: 4
43
+ # policy_hyperparams:
44
+ # init_layers_orthogonal: false
45
+ # log_std_init: -3.29
46
+ # use_sde: true
47
+ algo_hyperparams:
48
+ n_steps: 1000
49
+ pi_lr: !!float 5e-4
50
+ gamma: 0.99
51
+ gae_lambda: 0.9
52
+ val_lr: !!float 1e-3
53
+ train_v_iters: 80
54
+ max_grad_norm: 5
55
+ eval_hyperparams:
56
+ step_freq: 5000
57
+
58
+ Acrobot-v1:
59
+ n_timesteps: !!float 2e5
60
+ algo_hyperparams:
61
+ n_steps: 2048
62
+ pi_lr: 0.005
63
+ gamma: 0.99
64
+ gae_lambda: 0.97
65
+ val_lr: 0.01
66
+ train_v_iters: 80
67
+ max_grad_norm: 0.5
68
+
69
+ LunarLander-v2:
70
+ n_timesteps: !!float 4e6
71
+ policy_hyperparams:
72
+ hidden_sizes: [256, 256]
73
+ algo_hyperparams:
74
+ n_steps: 2048
75
+ pi_lr: 0.0001
76
+ gamma: 0.999
77
+ gae_lambda: 0.97
78
+ val_lr: 0.0001
79
+ train_v_iters: 80
80
+ max_grad_norm: 0.5
81
+ eval_hyperparams:
82
+ deterministic: false
83
+
84
+ BipedalWalker-v3:
85
+ n_timesteps: !!float 10e6
86
+ env_hyperparams:
87
+ n_envs: 16
88
+ normalize: true
89
+ policy_hyperparams:
90
+ hidden_sizes: [256, 256]
91
+ algo_hyperparams:
92
+ n_steps: 1600
93
+ gae_lambda: 0.95
94
+ gamma: 0.99
95
+ pi_lr: !!float 1e-4
96
+ val_lr: !!float 1e-4
97
+ train_v_iters: 80
98
+ max_grad_norm: 0.5
99
+ eval_hyperparams:
100
+ deterministic: false
101
+
102
+ CarRacing-v0:
103
+ n_timesteps: !!float 4e6
104
+ env_hyperparams:
105
+ frame_stack: 4
106
+ n_envs: 4
107
+ vec_env_class: sync
108
+ policy_hyperparams:
109
+ use_sde: true
110
+ log_std_init: -2
111
+ init_layers_orthogonal: false
112
+ activation_fn: relu
113
+ cnn_flatten_dim: 256
114
+ hidden_sizes: [256]
115
+ algo_hyperparams:
116
+ n_steps: 1000
117
+ pi_lr: !!float 5e-5
118
+ gamma: 0.99
119
+ gae_lambda: 0.95
120
+ val_lr: !!float 1e-4
121
+ train_v_iters: 40
122
+ max_grad_norm: 0.5
123
+ sde_sample_freq: 4
124
+
125
+ HalfCheetahBulletEnv-v0: &pybullet-defaults
126
+ n_timesteps: !!float 2e6
127
+ env_hyperparams: &pybullet-env-defaults
128
+ normalize: true
129
+ policy_hyperparams: &pybullet-policy-defaults
130
+ hidden_sizes: [256, 256]
131
+ algo_hyperparams: &pybullet-algo-defaults
132
+ n_steps: 4000
133
+ pi_lr: !!float 3e-4
134
+ gamma: 0.99
135
+ gae_lambda: 0.97
136
+ val_lr: !!float 1e-3
137
+ train_v_iters: 80
138
+ max_grad_norm: 0.5
139
+
140
+ AntBulletEnv-v0:
141
+ <<: *pybullet-defaults
142
+ policy_hyperparams:
143
+ <<: *pybullet-policy-defaults
144
+ hidden_sizes: [400, 300]
145
+ algo_hyperparams:
146
+ <<: *pybullet-algo-defaults
147
+ pi_lr: !!float 7e-4
148
+ val_lr: !!float 7e-3
149
+
150
+ HopperBulletEnv-v0:
151
+ <<: *pybullet-defaults
152
+
153
+ Walker2DBulletEnv-v0:
154
+ <<: *pybullet-defaults
155
+
156
+ FrozenLake-v1:
157
+ n_timesteps: !!float 8e5
158
+ env_params:
159
+ make_kwargs:
160
+ map_name: 8x8
161
+ is_slippery: true
162
+ policy_hyperparams:
163
+ hidden_sizes: [64]
164
+ algo_hyperparams:
165
+ n_steps: 2048
166
+ pi_lr: 0.01
167
+ gamma: 0.99
168
+ gae_lambda: 0.98
169
+ val_lr: 0.01
170
+ train_v_iters: 80
171
+ max_grad_norm: 0.5
172
+ eval_hyperparams:
173
+ step_freq: !!float 5e4
174
+ n_episodes: 10
175
+ save_best: true
176
+
177
+ _atari: &atari-defaults
178
+ n_timesteps: !!float 10e6
179
+ env_hyperparams:
180
+ n_envs: 2
181
+ frame_stack: 4
182
+ no_reward_timeout_steps: 1000
183
+ no_reward_fire_steps: 500
184
+ vec_env_class: async
185
+ policy_hyperparams:
186
+ activation_fn: relu
187
+ algo_hyperparams:
188
+ n_steps: 3072
189
+ pi_lr: !!float 5e-5
190
+ gamma: 0.99
191
+ gae_lambda: 0.95
192
+ val_lr: !!float 1e-4
193
+ train_v_iters: 80
194
+ max_grad_norm: 0.5
195
+ ent_coef: 0.01
196
+ eval_hyperparams:
197
+ deterministic: false
rl_algo_impls/optimize.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import gc
3
+ import inspect
4
+ import logging
5
+ import os
6
+ from dataclasses import asdict, dataclass
7
+ from typing import Callable, List, NamedTuple, Optional, Sequence, Union
8
+
9
+ import numpy as np
10
+ import optuna
11
+ import torch
12
+ from optuna.pruners import HyperbandPruner
13
+ from optuna.samplers import TPESampler
14
+ from optuna.visualization import plot_optimization_history, plot_param_importances
15
+ from torch.utils.tensorboard.writer import SummaryWriter
16
+
17
+ import wandb
18
+ from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
19
+ from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
20
+ from rl_algo_impls.runner.running_utils import (
21
+ ALGOS,
22
+ base_parser,
23
+ get_device,
24
+ hparam_dict,
25
+ load_hyperparams,
26
+ make_policy,
27
+ set_seeds,
28
+ )
29
+ from rl_algo_impls.shared.callbacks import Callback
30
+ from rl_algo_impls.shared.callbacks.microrts_reward_decay_callback import (
31
+ MicrortsRewardDecayCallback,
32
+ )
33
+ from rl_algo_impls.shared.callbacks.optimize_callback import (
34
+ Evaluation,
35
+ OptimizeCallback,
36
+ evaluation,
37
+ )
38
+ from rl_algo_impls.shared.callbacks.self_play_callback import SelfPlayCallback
39
+ from rl_algo_impls.shared.stats import EpisodesStats
40
+ from rl_algo_impls.shared.vec_env import make_env, make_eval_env
41
+ from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
42
+ from rl_algo_impls.wrappers.vectorable_wrapper import find_wrapper
43
+
44
+
45
+ @dataclass
46
+ class StudyArgs:
47
+ load_study: bool
48
+ study_name: Optional[str] = None
49
+ storage_path: Optional[str] = None
50
+ n_trials: int = 100
51
+ n_jobs: int = 1
52
+ n_evaluations: int = 4
53
+ n_eval_envs: int = 8
54
+ n_eval_episodes: int = 16
55
+ timeout: Union[int, float, None] = None
56
+ wandb_project_name: Optional[str] = None
57
+ wandb_entity: Optional[str] = None
58
+ wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
59
+ wandb_group: Optional[str] = None
60
+ virtual_display: bool = False
61
+
62
+
63
+ class Args(NamedTuple):
64
+ train_args: Sequence[RunArgs]
65
+ study_args: StudyArgs
66
+
67
+
68
+ def parse_args() -> Args:
69
+ parser = base_parser()
70
+ parser.add_argument(
71
+ "--load-study",
72
+ action="store_true",
73
+ help="Load a preexisting study, useful for parallelization",
74
+ )
75
+ parser.add_argument("--study-name", type=str, help="Optuna study name")
76
+ parser.add_argument(
77
+ "--storage-path",
78
+ type=str,
79
+ help="Path of database for Optuna to persist to",
80
+ )
81
+ parser.add_argument(
82
+ "--wandb-project-name",
83
+ type=str,
84
+ default="rl-algo-impls-tuning",
85
+ help="WandB project name to upload tuning data to. If none, won't upload",
86
+ )
87
+ parser.add_argument(
88
+ "--wandb-entity",
89
+ type=str,
90
+ help="WandB team. None uses the default entity",
91
+ )
92
+ parser.add_argument(
93
+ "--wandb-tags", type=str, nargs="*", help="WandB tags to add to run"
94
+ )
95
+ parser.add_argument(
96
+ "--wandb-group", type=str, help="WandB group to group trials under"
97
+ )
98
+ parser.add_argument(
99
+ "--n-trials", type=int, default=100, help="Maximum number of trials"
100
+ )
101
+ parser.add_argument(
102
+ "--n-jobs", type=int, default=1, help="Number of jobs to run in parallel"
103
+ )
104
+ parser.add_argument(
105
+ "--n-evaluations",
106
+ type=int,
107
+ default=4,
108
+ help="Number of evaluations during the training",
109
+ )
110
+ parser.add_argument(
111
+ "--n-eval-envs",
112
+ type=int,
113
+ default=8,
114
+ help="Number of envs in vectorized eval environment",
115
+ )
116
+ parser.add_argument(
117
+ "--n-eval-episodes",
118
+ type=int,
119
+ default=16,
120
+ help="Number of episodes to complete for evaluation",
121
+ )
122
+ parser.add_argument("--timeout", type=int, help="Seconds to timeout optimization")
123
+ parser.add_argument(
124
+ "--virtual-display", action="store_true", help="Use headless virtual display"
125
+ )
126
+ # parser.set_defaults(
127
+ # algo=["a2c"],
128
+ # env=["CartPole-v1"],
129
+ # seed=[100, 200, 300],
130
+ # n_trials=5,
131
+ # virtual_display=True,
132
+ # )
133
+ train_dict, study_dict = {}, {}
134
+ for k, v in vars(parser.parse_args()).items():
135
+ if k in inspect.signature(StudyArgs).parameters:
136
+ study_dict[k] = v
137
+ else:
138
+ train_dict[k] = v
139
+
140
+ study_args = StudyArgs(**study_dict)
141
+ # Hyperparameter tuning across algos and envs not supported
142
+ assert len(train_dict["algo"]) == 1
143
+ assert len(train_dict["env"]) == 1
144
+ train_args = RunArgs.expand_from_dict(train_dict)
145
+
146
+ if not all((study_args.study_name, study_args.storage_path)):
147
+ hyperparams = load_hyperparams(train_args[0].algo, train_args[0].env)
148
+ config = Config(train_args[0], hyperparams, os.getcwd())
149
+ if study_args.study_name is None:
150
+ study_args.study_name = config.run_name(include_seed=False)
151
+ if study_args.storage_path is None:
152
+ study_args.storage_path = (
153
+ f"sqlite:///{os.path.join(config.runs_dir, 'tuning.db')}"
154
+ )
155
+ # Default set group name to study name
156
+ study_args.wandb_group = study_args.wandb_group or study_args.study_name
157
+
158
+ return Args(train_args, study_args)
159
+
160
+
161
+ def objective_fn(
162
+ args: Sequence[RunArgs], study_args: StudyArgs
163
+ ) -> Callable[[optuna.Trial], float]:
164
+ def objective(trial: optuna.Trial) -> float:
165
+ if len(args) == 1:
166
+ return simple_optimize(trial, args[0], study_args)
167
+ else:
168
+ return stepwise_optimize(trial, args, study_args)
169
+
170
+ return objective
171
+
172
+
173
+ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -> float:
174
+ base_hyperparams = load_hyperparams(args.algo, args.env)
175
+ base_config = Config(args, base_hyperparams, os.getcwd())
176
+ if args.algo == "a2c":
177
+ hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
178
+ else:
179
+ raise ValueError(f"Optimizing {args.algo} isn't supported")
180
+ config = Config(args, hyperparams, os.getcwd())
181
+
182
+ wandb_enabled = bool(study_args.wandb_project_name)
183
+ if wandb_enabled:
184
+ wandb.init(
185
+ project=study_args.wandb_project_name,
186
+ entity=study_args.wandb_entity,
187
+ config=asdict(hyperparams),
188
+ name=f"{config.model_name()}-{str(trial.number)}",
189
+ tags=study_args.wandb_tags,
190
+ group=study_args.wandb_group,
191
+ sync_tensorboard=True,
192
+ monitor_gym=True,
193
+ save_code=True,
194
+ reinit=True,
195
+ )
196
+ wandb.config.update(args)
197
+
198
+ tb_writer = SummaryWriter(config.tensorboard_summary_path)
199
+ set_seeds(args.seed, args.use_deterministic_algorithms)
200
+
201
+ env = make_env(
202
+ config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
203
+ )
204
+ device = get_device(config, env)
205
+ policy_factory = lambda: make_policy(
206
+ args.algo, env, device, **config.policy_hyperparams
207
+ )
208
+ policy = policy_factory()
209
+ algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
210
+
211
+ eval_env = make_eval_env(
212
+ config,
213
+ EnvHyperparams(**config.env_hyperparams),
214
+ override_n_envs=study_args.n_eval_envs,
215
+ )
216
+ optimize_callback = OptimizeCallback(
217
+ policy,
218
+ eval_env,
219
+ trial,
220
+ tb_writer,
221
+ step_freq=config.n_timesteps // study_args.n_evaluations,
222
+ n_episodes=study_args.n_eval_episodes,
223
+ deterministic=config.eval_hyperparams.get("deterministic", True),
224
+ )
225
+ callbacks: List[Callback] = [optimize_callback]
226
+ if config.hyperparams.microrts_reward_decay_callback:
227
+ callbacks.append(MicrortsRewardDecayCallback(config, env))
228
+ selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
229
+ if selfPlayWrapper:
230
+ callbacks.append(SelfPlayCallback(policy, policy_factory, selfPlayWrapper))
231
+ try:
232
+ algo.learn(config.n_timesteps, callbacks=callbacks)
233
+
234
+ if not optimize_callback.is_pruned:
235
+ optimize_callback.evaluate()
236
+ if not optimize_callback.is_pruned:
237
+ policy.save(config.model_dir_path(best=False))
238
+
239
+ eval_stat: EpisodesStats = callback.last_eval_stat # type: ignore
240
+ train_stat: EpisodesStats = callback.last_train_stat # type: ignore
241
+
242
+ tb_writer.add_hparams(
243
+ hparam_dict(hyperparams, vars(args)),
244
+ {
245
+ "hparam/last_mean": eval_stat.score.mean,
246
+ "hparam/last_result": eval_stat.score.mean - eval_stat.score.std,
247
+ "hparam/train_mean": train_stat.score.mean,
248
+ "hparam/train_result": train_stat.score.mean - train_stat.score.std,
249
+ "hparam/score": optimize_callback.last_score,
250
+ "hparam/is_pruned": optimize_callback.is_pruned,
251
+ },
252
+ None,
253
+ config.run_name(),
254
+ )
255
+ tb_writer.close()
256
+
257
+ if wandb_enabled:
258
+ wandb.run.summary["state"] = ( # type: ignore
259
+ "Pruned" if optimize_callback.is_pruned else "Complete"
260
+ )
261
+ wandb.finish(quiet=True)
262
+
263
+ if optimize_callback.is_pruned:
264
+ raise optuna.exceptions.TrialPruned()
265
+
266
+ return optimize_callback.last_score
267
+ except AssertionError as e:
268
+ logging.warning(e)
269
+ return np.nan
270
+ finally:
271
+ env.close()
272
+ eval_env.close()
273
+ gc.collect()
274
+ torch.cuda.empty_cache()
275
+
276
+
277
+ def stepwise_optimize(
278
+ trial: optuna.Trial, args: Sequence[RunArgs], study_args: StudyArgs
279
+ ) -> float:
280
+ algo = args[0].algo
281
+ env_id = args[0].env
282
+ base_hyperparams = load_hyperparams(algo, env_id)
283
+ base_config = Config(args[0], base_hyperparams, os.getcwd())
284
+ if algo == "a2c":
285
+ hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
286
+ else:
287
+ raise ValueError(f"Optimizing {algo} isn't supported")
288
+
289
+ wandb_enabled = bool(study_args.wandb_project_name)
290
+ if wandb_enabled:
291
+ wandb.init(
292
+ project=study_args.wandb_project_name,
293
+ entity=study_args.wandb_entity,
294
+ config=asdict(hyperparams),
295
+ name=f"{str(trial.number)}-S{base_config.seed()}",
296
+ tags=study_args.wandb_tags,
297
+ group=study_args.wandb_group,
298
+ save_code=True,
299
+ reinit=True,
300
+ )
301
+
302
+ score = -np.inf
303
+
304
+ for i in range(study_args.n_evaluations):
305
+ evaluations: List[Evaluation] = []
306
+
307
+ for arg in args:
308
+ config = Config(arg, hyperparams, os.getcwd())
309
+
310
+ tb_writer = SummaryWriter(config.tensorboard_summary_path)
311
+ set_seeds(arg.seed, arg.use_deterministic_algorithms)
312
+
313
+ env = make_env(
314
+ config,
315
+ EnvHyperparams(**config.env_hyperparams),
316
+ normalize_load_path=config.model_dir_path() if i > 0 else None,
317
+ tb_writer=tb_writer,
318
+ )
319
+ device = get_device(config, env)
320
+ policy_factory = lambda: make_policy(
321
+ arg.algo, env, device, **config.policy_hyperparams
322
+ )
323
+ policy = policy_factory()
324
+ if i > 0:
325
+ policy.load(config.model_dir_path())
326
+ algo = ALGOS[arg.algo](
327
+ policy, env, device, tb_writer, **config.algo_hyperparams
328
+ )
329
+
330
+ eval_env = make_eval_env(
331
+ config,
332
+ EnvHyperparams(**config.env_hyperparams),
333
+ normalize_load_path=config.model_dir_path() if i > 0 else None,
334
+ override_n_envs=study_args.n_eval_envs,
335
+ )
336
+
337
+ start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations)
338
+ train_timesteps = (
339
+ int((i + 1) * config.n_timesteps / study_args.n_evaluations)
340
+ - start_timesteps
341
+ )
342
+
343
+ callbacks = []
344
+ if config.hyperparams.microrts_reward_decay_callback:
345
+ callbacks.append(
346
+ MicrortsRewardDecayCallback(
347
+ config, env, start_timesteps=start_timesteps
348
+ )
349
+ )
350
+ selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
351
+ if selfPlayWrapper:
352
+ callbacks.append(
353
+ SelfPlayCallback(policy, policy_factory, selfPlayWrapper)
354
+ )
355
+ try:
356
+ algo.learn(
357
+ train_timesteps,
358
+ callbacks=callbacks,
359
+ total_timesteps=config.n_timesteps,
360
+ start_timesteps=start_timesteps,
361
+ )
362
+
363
+ evaluations.append(
364
+ evaluation(
365
+ policy,
366
+ eval_env,
367
+ tb_writer,
368
+ study_args.n_eval_episodes,
369
+ config.eval_hyperparams.get("deterministic", True),
370
+ start_timesteps + train_timesteps,
371
+ )
372
+ )
373
+
374
+ policy.save(config.model_dir_path())
375
+
376
+ tb_writer.close()
377
+
378
+ except AssertionError as e:
379
+ logging.warning(e)
380
+ if wandb_enabled:
381
+ wandb_finish("Error")
382
+ return np.nan
383
+ finally:
384
+ env.close()
385
+ eval_env.close()
386
+ gc.collect()
387
+ torch.cuda.empty_cache()
388
+
389
+ d = {}
390
+ for idx, e in enumerate(evaluations):
391
+ d[f"{idx}/eval_mean"] = e.eval_stat.score.mean
392
+ d[f"{idx}/train_mean"] = e.train_stat.score.mean
393
+ d[f"{idx}/score"] = e.score
394
+ d["eval"] = np.mean([e.eval_stat.score.mean for e in evaluations]).item()
395
+ d["train"] = np.mean([e.train_stat.score.mean for e in evaluations]).item()
396
+ score = np.mean([e.score for e in evaluations]).item()
397
+ d["score"] = score
398
+
399
+ step = i + 1
400
+ wandb.log(d, step=step)
401
+
402
+ print(f"Trial #{trial.number} Step {step} Score: {round(score, 2)}")
403
+ trial.report(score, step)
404
+ if trial.should_prune():
405
+ if wandb_enabled:
406
+ wandb_finish("Pruned")
407
+ raise optuna.exceptions.TrialPruned()
408
+
409
+ if wandb_enabled:
410
+ wandb_finish("Complete")
411
+ return score
412
+
413
+
414
+ def wandb_finish(state: str) -> None:
415
+ wandb.run.summary["state"] = state # type: ignore
416
+ wandb.finish(quiet=True)
417
+
418
+
419
+ def optimize() -> None:
420
+ from pyvirtualdisplay.display import Display
421
+
422
+ train_args, study_args = parse_args()
423
+ if study_args.virtual_display:
424
+ virtual_display = Display(visible=False, size=(1400, 900))
425
+ virtual_display.start()
426
+
427
+ sampler = TPESampler(**TPESampler.hyperopt_parameters())
428
+ pruner = HyperbandPruner()
429
+ if study_args.load_study:
430
+ assert study_args.study_name
431
+ assert study_args.storage_path
432
+ study = optuna.load_study(
433
+ study_name=study_args.study_name,
434
+ storage=study_args.storage_path,
435
+ sampler=sampler,
436
+ pruner=pruner,
437
+ )
438
+ else:
439
+ study = optuna.create_study(
440
+ study_name=study_args.study_name,
441
+ storage=study_args.storage_path,
442
+ sampler=sampler,
443
+ pruner=pruner,
444
+ direction="maximize",
445
+ )
446
+
447
+ try:
448
+ study.optimize(
449
+ objective_fn(train_args, study_args),
450
+ n_trials=study_args.n_trials,
451
+ n_jobs=study_args.n_jobs,
452
+ timeout=study_args.timeout,
453
+ )
454
+ except KeyboardInterrupt:
455
+ pass
456
+
457
+ best = study.best_trial
458
+ print(f"Best Trial Value: {best.value}")
459
+ print("Attributes:")
460
+ for key, value in list(best.params.items()) + list(best.user_attrs.items()):
461
+ print(f" {key}: {value}")
462
+
463
+ df = study.trials_dataframe()
464
+ df = df[df.state == "COMPLETE"].sort_values(by=["value"], ascending=False)
465
+ print(df.to_markdown(index=False))
466
+
467
+ fig1 = plot_optimization_history(study)
468
+ fig1.write_image("opt_history.png")
469
+
470
+ fig2 = plot_param_importances(study)
471
+ fig2.write_image("param_importances.png")
472
+
473
+
474
+ if __name__ == "__main__":
475
+ optimize()
rl_algo_impls/ppo/ppo.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import asdict, dataclass
3
+ from time import perf_counter
4
+ from typing import List, NamedTuple, Optional, TypeVar
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.optim import Adam
10
+ from torch.utils.tensorboard.writer import SummaryWriter
11
+
12
+ from rl_algo_impls.shared.algorithm import Algorithm
13
+ from rl_algo_impls.shared.callbacks import Callback
14
+ from rl_algo_impls.shared.gae import compute_advantages
15
+ from rl_algo_impls.shared.policy.actor_critic import ActorCritic
16
+ from rl_algo_impls.shared.schedule import (
17
+ constant_schedule,
18
+ linear_schedule,
19
+ schedule,
20
+ update_learning_rate,
21
+ )
22
+ from rl_algo_impls.shared.stats import log_scalars
23
+ from rl_algo_impls.wrappers.vectorable_wrapper import (
24
+ VecEnv,
25
+ single_action_space,
26
+ single_observation_space,
27
+ )
28
+
29
+
30
+ class TrainStepStats(NamedTuple):
31
+ loss: float
32
+ pi_loss: float
33
+ v_loss: float
34
+ entropy_loss: float
35
+ approx_kl: float
36
+ clipped_frac: float
37
+ val_clipped_frac: float
38
+
39
+
40
+ @dataclass
41
+ class TrainStats:
42
+ loss: float
43
+ pi_loss: float
44
+ v_loss: float
45
+ entropy_loss: float
46
+ approx_kl: float
47
+ clipped_frac: float
48
+ val_clipped_frac: float
49
+ explained_var: float
50
+
51
+ def __init__(self, step_stats: List[TrainStepStats], explained_var: float) -> None:
52
+ self.loss = np.mean([s.loss for s in step_stats]).item()
53
+ self.pi_loss = np.mean([s.pi_loss for s in step_stats]).item()
54
+ self.v_loss = np.mean([s.v_loss for s in step_stats]).item()
55
+ self.entropy_loss = np.mean([s.entropy_loss for s in step_stats]).item()
56
+ self.approx_kl = np.mean([s.approx_kl for s in step_stats]).item()
57
+ self.clipped_frac = np.mean([s.clipped_frac for s in step_stats]).item()
58
+ self.val_clipped_frac = np.mean([s.val_clipped_frac for s in step_stats]).item()
59
+ self.explained_var = explained_var
60
+
61
+ def write_to_tensorboard(self, tb_writer: SummaryWriter, global_step: int) -> None:
62
+ for name, value in asdict(self).items():
63
+ tb_writer.add_scalar(f"losses/{name}", value, global_step=global_step)
64
+
65
+ def __repr__(self) -> str:
66
+ return " | ".join(
67
+ [
68
+ f"Loss: {round(self.loss, 2)}",
69
+ f"Pi L: {round(self.pi_loss, 2)}",
70
+ f"V L: {round(self.v_loss, 2)}",
71
+ f"E L: {round(self.entropy_loss, 2)}",
72
+ f"Apx KL Div: {round(self.approx_kl, 2)}",
73
+ f"Clip Frac: {round(self.clipped_frac, 2)}",
74
+ f"Val Clip Frac: {round(self.val_clipped_frac, 2)}",
75
+ ]
76
+ )
77
+
78
+
79
+ PPOSelf = TypeVar("PPOSelf", bound="PPO")
80
+
81
+
82
+ class PPO(Algorithm):
83
+ def __init__(
84
+ self,
85
+ policy: ActorCritic,
86
+ env: VecEnv,
87
+ device: torch.device,
88
+ tb_writer: SummaryWriter,
89
+ learning_rate: float = 3e-4,
90
+ learning_rate_decay: str = "none",
91
+ n_steps: int = 2048,
92
+ batch_size: int = 64,
93
+ n_epochs: int = 10,
94
+ gamma: float = 0.99,
95
+ gae_lambda: float = 0.95,
96
+ clip_range: float = 0.2,
97
+ clip_range_decay: str = "none",
98
+ clip_range_vf: Optional[float] = None,
99
+ clip_range_vf_decay: str = "none",
100
+ normalize_advantage: bool = True,
101
+ ent_coef: float = 0.0,
102
+ ent_coef_decay: str = "none",
103
+ vf_coef: float = 0.5,
104
+ ppo2_vf_coef_halving: bool = False,
105
+ max_grad_norm: float = 0.5,
106
+ sde_sample_freq: int = -1,
107
+ update_advantage_between_epochs: bool = True,
108
+ update_returns_between_epochs: bool = False,
109
+ gamma_end: Optional[float] = None,
110
+ ) -> None:
111
+ super().__init__(policy, env, device, tb_writer)
112
+ self.policy = policy
113
+ self.get_action_mask = getattr(env, "get_action_mask")
114
+
115
+ self.gamma_schedule = (
116
+ linear_schedule(gamma, gamma_end)
117
+ if gamma_end is not None
118
+ else constant_schedule(gamma)
119
+ )
120
+ self.gae_lambda = gae_lambda
121
+ self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
122
+ self.lr_schedule = schedule(learning_rate_decay, learning_rate)
123
+ self.max_grad_norm = max_grad_norm
124
+ self.clip_range_schedule = schedule(clip_range_decay, clip_range)
125
+ self.clip_range_vf_schedule = None
126
+ if clip_range_vf:
127
+ self.clip_range_vf_schedule = schedule(clip_range_vf_decay, clip_range_vf)
128
+
129
+ if normalize_advantage:
130
+ assert (
131
+ env.num_envs * n_steps > 1 and batch_size > 1
132
+ ), f"Each minibatch must be larger than 1 to support normalization"
133
+ self.normalize_advantage = normalize_advantage
134
+
135
+ self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
136
+ self.vf_coef = vf_coef
137
+ self.ppo2_vf_coef_halving = ppo2_vf_coef_halving
138
+
139
+ self.n_steps = n_steps
140
+ self.batch_size = batch_size
141
+ self.n_epochs = n_epochs
142
+ self.sde_sample_freq = sde_sample_freq
143
+
144
+ self.update_advantage_between_epochs = update_advantage_between_epochs
145
+ self.update_returns_between_epochs = update_returns_between_epochs
146
+
147
+ def learn(
148
+ self: PPOSelf,
149
+ train_timesteps: int,
150
+ callbacks: Optional[List[Callback]] = None,
151
+ total_timesteps: Optional[int] = None,
152
+ start_timesteps: int = 0,
153
+ ) -> PPOSelf:
154
+ if total_timesteps is None:
155
+ total_timesteps = train_timesteps
156
+ assert start_timesteps + train_timesteps <= total_timesteps
157
+
158
+ epoch_dim = (self.n_steps, self.env.num_envs)
159
+ step_dim = (self.env.num_envs,)
160
+ obs_space = single_observation_space(self.env)
161
+ act_space = single_action_space(self.env)
162
+ act_shape = self.policy.action_shape
163
+
164
+ next_obs = self.env.reset()
165
+ next_action_masks = self.get_action_mask() if self.get_action_mask else None
166
+ next_episode_starts = np.full(step_dim, True, dtype=np.bool_)
167
+
168
+ obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype) # type: ignore
169
+ actions = np.zeros(epoch_dim + act_shape, dtype=act_space.dtype) # type: ignore
170
+ rewards = np.zeros(epoch_dim, dtype=np.float32)
171
+ episode_starts = np.zeros(epoch_dim, dtype=np.bool_)
172
+ values = np.zeros(epoch_dim, dtype=np.float32)
173
+ logprobs = np.zeros(epoch_dim, dtype=np.float32)
174
+ action_masks = (
175
+ np.zeros(
176
+ (self.n_steps,) + next_action_masks.shape, dtype=next_action_masks.dtype
177
+ )
178
+ if next_action_masks is not None
179
+ else None
180
+ )
181
+
182
+ timesteps_elapsed = start_timesteps
183
+ while timesteps_elapsed < start_timesteps + train_timesteps:
184
+ start_time = perf_counter()
185
+
186
+ progress = timesteps_elapsed / total_timesteps
187
+ ent_coef = self.ent_coef_schedule(progress)
188
+ learning_rate = self.lr_schedule(progress)
189
+ update_learning_rate(self.optimizer, learning_rate)
190
+ pi_clip = self.clip_range_schedule(progress)
191
+ gamma = self.gamma_schedule(progress)
192
+ chart_scalars = {
193
+ "learning_rate": self.optimizer.param_groups[0]["lr"],
194
+ "ent_coef": ent_coef,
195
+ "pi_clip": pi_clip,
196
+ "gamma": gamma,
197
+ }
198
+ if self.clip_range_vf_schedule:
199
+ v_clip = self.clip_range_vf_schedule(progress)
200
+ chart_scalars["v_clip"] = v_clip
201
+ else:
202
+ v_clip = None
203
+ log_scalars(self.tb_writer, "charts", chart_scalars, timesteps_elapsed)
204
+
205
+ self.policy.eval()
206
+ self.policy.reset_noise()
207
+ for s in range(self.n_steps):
208
+ timesteps_elapsed += self.env.num_envs
209
+ if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
210
+ self.policy.reset_noise()
211
+
212
+ obs[s] = next_obs
213
+ episode_starts[s] = next_episode_starts
214
+ if action_masks is not None:
215
+ action_masks[s] = next_action_masks
216
+
217
+ (
218
+ actions[s],
219
+ values[s],
220
+ logprobs[s],
221
+ clamped_action,
222
+ ) = self.policy.step(next_obs, action_masks=next_action_masks)
223
+ next_obs, rewards[s], next_episode_starts, _ = self.env.step(
224
+ clamped_action
225
+ )
226
+ next_action_masks = (
227
+ self.get_action_mask() if self.get_action_mask else None
228
+ )
229
+
230
+ self.policy.train()
231
+
232
+ b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device) # type: ignore
233
+ b_actions = torch.tensor(actions.reshape((-1,) + act_shape)).to( # type: ignore
234
+ self.device
235
+ )
236
+ b_logprobs = torch.tensor(logprobs.reshape(-1)).to(self.device)
237
+ b_action_masks = (
238
+ torch.tensor(action_masks.reshape((-1,) + next_action_masks.shape[1:])).to( # type: ignore
239
+ self.device
240
+ )
241
+ if action_masks is not None
242
+ else None
243
+ )
244
+
245
+ y_pred = values.reshape(-1)
246
+ b_values = torch.tensor(y_pred).to(self.device)
247
+
248
+ step_stats = []
249
+ # Define variables that will definitely be set through the first epoch
250
+ advantages: np.ndarray = None # type: ignore
251
+ b_advantages: torch.Tensor = None # type: ignore
252
+ y_true: np.ndarray = None # type: ignore
253
+ b_returns: torch.Tensor = None # type: ignore
254
+ for e in range(self.n_epochs):
255
+ if e == 0 or self.update_advantage_between_epochs:
256
+ advantages = compute_advantages(
257
+ rewards,
258
+ values,
259
+ episode_starts,
260
+ next_episode_starts,
261
+ next_obs,
262
+ self.policy,
263
+ gamma,
264
+ self.gae_lambda,
265
+ )
266
+ b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
267
+ if e == 0 or self.update_returns_between_epochs:
268
+ returns = advantages + values
269
+ y_true = returns.reshape(-1)
270
+ b_returns = torch.tensor(y_true).to(self.device)
271
+
272
+ b_idxs = torch.randperm(len(b_obs))
273
+ # Only record last epoch's stats
274
+ step_stats.clear()
275
+ for i in range(0, len(b_obs), self.batch_size):
276
+ self.policy.reset_noise(self.batch_size)
277
+
278
+ mb_idxs = b_idxs[i : i + self.batch_size]
279
+
280
+ mb_obs = b_obs[mb_idxs]
281
+ mb_actions = b_actions[mb_idxs]
282
+ mb_values = b_values[mb_idxs]
283
+ mb_logprobs = b_logprobs[mb_idxs]
284
+ mb_action_masks = (
285
+ b_action_masks[mb_idxs] if b_action_masks is not None else None
286
+ )
287
+
288
+ mb_adv = b_advantages[mb_idxs]
289
+ if self.normalize_advantage:
290
+ mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)
291
+ mb_returns = b_returns[mb_idxs]
292
+
293
+ new_logprobs, entropy, new_values = self.policy(
294
+ mb_obs, mb_actions, action_masks=mb_action_masks
295
+ )
296
+
297
+ logratio = new_logprobs - mb_logprobs
298
+ ratio = torch.exp(logratio)
299
+ clipped_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
300
+ pi_loss = torch.max(-ratio * mb_adv, -clipped_ratio * mb_adv).mean()
301
+
302
+ v_loss_unclipped = (new_values - mb_returns) ** 2
303
+ if v_clip:
304
+ v_loss_clipped = (
305
+ mb_values
306
+ + torch.clamp(new_values - mb_values, -v_clip, v_clip)
307
+ - mb_returns
308
+ ) ** 2
309
+ v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
310
+ else:
311
+ v_loss = v_loss_unclipped.mean()
312
+
313
+ if self.ppo2_vf_coef_halving:
314
+ v_loss *= 0.5
315
+
316
+ entropy_loss = -entropy.mean()
317
+
318
+ loss = pi_loss + ent_coef * entropy_loss + self.vf_coef * v_loss
319
+
320
+ self.optimizer.zero_grad()
321
+ loss.backward()
322
+ nn.utils.clip_grad_norm_(
323
+ self.policy.parameters(), self.max_grad_norm
324
+ )
325
+ self.optimizer.step()
326
+
327
+ with torch.no_grad():
328
+ approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
329
+ clipped_frac = (
330
+ ((ratio - 1).abs() > pi_clip)
331
+ .float()
332
+ .mean()
333
+ .cpu()
334
+ .numpy()
335
+ .item()
336
+ )
337
+ val_clipped_frac = (
338
+ ((new_values - mb_values).abs() > v_clip)
339
+ .float()
340
+ .mean()
341
+ .cpu()
342
+ .numpy()
343
+ .item()
344
+ if v_clip
345
+ else 0
346
+ )
347
+
348
+ step_stats.append(
349
+ TrainStepStats(
350
+ loss.item(),
351
+ pi_loss.item(),
352
+ v_loss.item(),
353
+ entropy_loss.item(),
354
+ approx_kl,
355
+ clipped_frac,
356
+ val_clipped_frac,
357
+ )
358
+ )
359
+
360
+ var_y = np.var(y_true).item()
361
+ explained_var = (
362
+ np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
363
+ )
364
+ TrainStats(step_stats, explained_var).write_to_tensorboard(
365
+ self.tb_writer, timesteps_elapsed
366
+ )
367
+
368
+ end_time = perf_counter()
369
+ rollout_steps = self.n_steps * self.env.num_envs
370
+ self.tb_writer.add_scalar(
371
+ "train/steps_per_second",
372
+ rollout_steps / (end_time - start_time),
373
+ timesteps_elapsed,
374
+ )
375
+
376
+ if callbacks:
377
+ if not all(
378
+ c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
379
+ ):
380
+ logging.info(
381
+ f"Callback terminated training at {timesteps_elapsed} timesteps"
382
+ )
383
+ break
384
+
385
+ return self
rl_algo_impls/publish/markdown_format.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import wandb.apis.public
4
+ import yaml
5
+
6
+ from collections import defaultdict
7
+ from dataclasses import dataclass, asdict
8
+ from typing import Any, Dict, Iterable, List, NamedTuple, Optional, TypeVar
9
+ from urllib.parse import urlparse
10
+
11
+ from rl_algo_impls.runner.evaluate import Evaluation
12
+
13
+ EvaluationRowSelf = TypeVar("EvaluationRowSelf", bound="EvaluationRow")
14
+
15
+
16
+ @dataclass
17
+ class EvaluationRow:
18
+ algo: str
19
+ env: str
20
+ seed: Optional[int]
21
+ reward_mean: float
22
+ reward_std: float
23
+ eval_episodes: int
24
+ best: str
25
+ wandb_url: str
26
+
27
+ @staticmethod
28
+ def data_frame(rows: List[EvaluationRowSelf]) -> pd.DataFrame:
29
+ results = defaultdict(list)
30
+ for r in rows:
31
+ for k, v in asdict(r).items():
32
+ results[k].append(v)
33
+ return pd.DataFrame(results)
34
+
35
+
36
+ class EvalTableData(NamedTuple):
37
+ run: wandb.apis.public.Run
38
+ evaluation: Evaluation
39
+
40
+
41
+ def evaluation_table(table_data: Iterable[EvalTableData]) -> str:
42
+ best_stats = sorted(
43
+ [d.evaluation.stats for d in table_data], key=lambda r: r.score, reverse=True
44
+ )[0]
45
+ table_data = sorted(table_data, key=lambda d: d.evaluation.config.seed() or 0)
46
+ rows = [
47
+ EvaluationRow(
48
+ config.algo,
49
+ config.env_id,
50
+ config.seed(),
51
+ stats.score.mean,
52
+ stats.score.std,
53
+ len(stats),
54
+ "*" if stats == best_stats else "",
55
+ f"[wandb]({r.url})",
56
+ )
57
+ for (r, (_, stats, config)) in table_data
58
+ ]
59
+ df = EvaluationRow.data_frame(rows)
60
+ return df.to_markdown(index=False)
61
+
62
+
63
+ def github_project_link(github_url: str) -> str:
64
+ return f"[{urlparse(github_url).path}]({github_url})"
65
+
66
+
67
+ def header_section(algo: str, env: str, github_url: str, wandb_report_url: str) -> str:
68
+ algo_caps = algo.upper()
69
+ lines = [
70
+ f"# **{algo_caps}** Agent playing **{env}**",
71
+ f"This is a trained model of a **{algo_caps}** agent playing **{env}** using "
72
+ f"the {github_project_link(github_url)} repo.",
73
+ f"All models trained at this commit can be found at {wandb_report_url}.",
74
+ ]
75
+ return "\n\n".join(lines)
76
+
77
+
78
+ def github_tree_link(github_url: str, commit_hash: Optional[str]) -> str:
79
+ if not commit_hash:
80
+ return github_project_link(github_url)
81
+ return f"[{commit_hash[:7]}]({github_url}/tree/{commit_hash})"
82
+
83
+
84
+ def results_section(
85
+ table_data: List[EvalTableData], algo: str, github_url: str, commit_hash: str
86
+ ) -> str:
87
+ # type: ignore
88
+ lines = [
89
+ "## Training Results",
90
+ f"This model was trained from {len(table_data)} trainings of **{algo.upper()}** "
91
+ + "agents using different initial seeds. "
92
+ + f"These agents were trained by checking out "
93
+ + f"{github_tree_link(github_url, commit_hash)}. "
94
+ + "The best and last models were kept from each training. "
95
+ + "This submission has loaded the best models from each training, reevaluates "
96
+ + "them, and selects the best model from these latest evaluations (mean - std).",
97
+ ]
98
+ lines.append(evaluation_table(table_data))
99
+ return "\n\n".join(lines)
100
+
101
+
102
+ def prerequisites_section() -> str:
103
+ return """
104
+ ### Prerequisites: Weights & Biases (WandB)
105
+ Training and benchmarking assumes you have a Weights & Biases project to upload runs to.
106
+ By default training goes to a rl-algo-impls project while benchmarks go to
107
+ rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
108
+ models and the model weights are uploaded to WandB.
109
+
110
+ Before doing anything below, you'll need to create a wandb account and run `wandb
111
+ login`.
112
+ """
113
+
114
+
115
+ def usage_section(github_url: str, run_path: str, commit_hash: str) -> str:
116
+ return f"""
117
+ ## Usage
118
+ {urlparse(github_url).path}: {github_url}
119
+
120
+ Note: While the model state dictionary and hyperaparameters are saved, the latest
121
+ implementation could be sufficiently different to not be able to reproduce similar
122
+ results. You might need to checkout the commit the agent was trained on:
123
+ {github_tree_link(github_url, commit_hash)}.
124
+ ```
125
+ # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
126
+ python enjoy.py --wandb-run-path={run_path}
127
+ ```
128
+
129
+ Setup hasn't been completely worked out yet, so you might be best served by using Google
130
+ Colab starting from the
131
+ [colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
132
+ notebook.
133
+ """
134
+
135
+
136
+ def training_setion(
137
+ github_url: str, commit_hash: str, algo: str, env: str, seed: Optional[int]
138
+ ) -> str:
139
+ return f"""
140
+ ## Training
141
+ If you want the highest chance to reproduce these results, you'll want to checkout the
142
+ commit the agent was trained on: {github_tree_link(github_url, commit_hash)}. While
143
+ training is deterministic, different hardware will give different results.
144
+
145
+ ```
146
+ python train.py --algo {algo} --env {env} {'--seed ' + str(seed) if seed is not None else ''}
147
+ ```
148
+
149
+ Setup hasn't been completely worked out yet, so you might be best served by using Google
150
+ Colab starting from the
151
+ [colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
152
+ notebook.
153
+ """
154
+
155
+
156
+ def benchmarking_section(report_url: str) -> str:
157
+ return f"""
158
+ ## Benchmarking (with Lambda Labs instance)
159
+ This and other models from {report_url} were generated by running a script on a Lambda
160
+ Labs instance. In a Lambda Labs instance terminal:
161
+ ```
162
+ git clone git@github.com:sgoodfriend/rl-algo-impls.git
163
+ cd rl-algo-impls
164
+ bash ./lambda_labs/setup.sh
165
+ wandb login
166
+ bash ./lambda_labs/benchmark.sh [-a {{"ppo a2c dqn vpg"}}] [-e ENVS] [-j {{6}}] [-p {{rl-algo-impls-benchmarks}}] [-s {{"1 2 3"}}]
167
+ ```
168
+
169
+ ### Alternative: Google Colab Pro+
170
+ As an alternative,
171
+ [colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
172
+ can be used. However, this requires a Google Colab Pro+ subscription and running across
173
+ 4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
174
+ """
175
+
176
+
177
+ def hyperparams_section(run_config: Dict[str, Any]) -> str:
178
+ return f"""
179
+ ## Hyperparameters
180
+ This isn't exactly the format of hyperparams in {os.path.join("hyperparams",
181
+ run_config["algo"] + ".yml")}, but instead the Wandb Run Config. However, it's very
182
+ close and has some additional data:
183
+ ```
184
+ {yaml.dump(run_config)}
185
+ ```
186
+ """
187
+
188
+
189
+ def model_card_text(
190
+ algo: str,
191
+ env: str,
192
+ github_url: str,
193
+ commit_hash: str,
194
+ wandb_report_url: str,
195
+ table_data: List[EvalTableData],
196
+ best_eval: EvalTableData,
197
+ ) -> str:
198
+ run, (_, _, config) = best_eval
199
+ run_path = "/".join(run.path)
200
+ return "\n\n".join(
201
+ [
202
+ header_section(algo, env, github_url, wandb_report_url),
203
+ results_section(table_data, algo, github_url, commit_hash),
204
+ prerequisites_section(),
205
+ usage_section(github_url, run_path, commit_hash),
206
+ training_setion(github_url, commit_hash, algo, env, config.seed()),
207
+ benchmarking_section(wandb_report_url),
208
+ hyperparams_section(run.config),
209
+ ]
210
+ )
rl_algo_impls/runner/config.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import inspect
3
+ import itertools
4
+ import os
5
+ from dataclasses import dataclass
6
+ from datetime import datetime
7
+ from typing import Any, Dict, List, Optional, Type, TypeVar, Union
8
+
9
+ RunArgsSelf = TypeVar("RunArgsSelf", bound="RunArgs")
10
+
11
+
12
+ @dataclass
13
+ class RunArgs:
14
+ algo: str
15
+ env: str
16
+ seed: Optional[int] = None
17
+ use_deterministic_algorithms: bool = True
18
+
19
+ @classmethod
20
+ def expand_from_dict(
21
+ cls: Type[RunArgsSelf], d: Dict[str, Any]
22
+ ) -> List[RunArgsSelf]:
23
+ maybe_listify = lambda v: [v] if isinstance(v, str) or isinstance(v, int) else v
24
+ algos = maybe_listify(d["algo"])
25
+ envs = maybe_listify(d["env"])
26
+ seeds = maybe_listify(d["seed"])
27
+ args = []
28
+ for algo, env, seed in itertools.product(algos, envs, seeds):
29
+ _d = d.copy()
30
+ _d.update({"algo": algo, "env": env, "seed": seed})
31
+ args.append(cls(**_d))
32
+ return args
33
+
34
+
35
+ @dataclass
36
+ class EnvHyperparams:
37
+ env_type: str = "gymvec"
38
+ n_envs: int = 1
39
+ frame_stack: int = 1
40
+ make_kwargs: Optional[Dict[str, Any]] = None
41
+ no_reward_timeout_steps: Optional[int] = None
42
+ no_reward_fire_steps: Optional[int] = None
43
+ vec_env_class: str = "sync"
44
+ normalize: bool = False
45
+ normalize_kwargs: Optional[Dict[str, Any]] = None
46
+ rolling_length: int = 100
47
+ train_record_video: bool = False
48
+ video_step_interval: Union[int, float] = 1_000_000
49
+ initial_steps_to_truncate: Optional[int] = None
50
+ clip_atari_rewards: bool = True
51
+ normalize_type: Optional[str] = None
52
+ mask_actions: bool = False
53
+ bots: Optional[Dict[str, int]] = None
54
+ self_play_kwargs: Optional[Dict[str, Any]] = None
55
+
56
+
57
+ HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
58
+
59
+
60
+ @dataclass
61
+ class Hyperparams:
62
+ device: str = "auto"
63
+ n_timesteps: Union[int, float] = 100_000
64
+ env_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
65
+ policy_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
66
+ algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
67
+ eval_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
68
+ env_id: Optional[str] = None
69
+ additional_keys_to_log: List[str] = dataclasses.field(default_factory=list)
70
+ microrts_reward_decay_callback: bool = False
71
+
72
+ @classmethod
73
+ def from_dict_with_extra_fields(
74
+ cls: Type[HyperparamsSelf], d: Dict[str, Any]
75
+ ) -> HyperparamsSelf:
76
+ return cls(
77
+ **{k: v for k, v in d.items() if k in inspect.signature(cls).parameters}
78
+ )
79
+
80
+
81
+ @dataclass
82
+ class Config:
83
+ args: RunArgs
84
+ hyperparams: Hyperparams
85
+ root_dir: str
86
+ run_id: str = datetime.now().isoformat()
87
+
88
+ def seed(self, training: bool = True) -> Optional[int]:
89
+ seed = self.args.seed
90
+ if training or seed is None:
91
+ return seed
92
+ return seed + self.env_hyperparams.get("n_envs", 1)
93
+
94
+ @property
95
+ def device(self) -> str:
96
+ return self.hyperparams.device
97
+
98
+ @property
99
+ def n_timesteps(self) -> int:
100
+ return int(self.hyperparams.n_timesteps)
101
+
102
+ @property
103
+ def env_hyperparams(self) -> Dict[str, Any]:
104
+ return self.hyperparams.env_hyperparams
105
+
106
+ @property
107
+ def policy_hyperparams(self) -> Dict[str, Any]:
108
+ return self.hyperparams.policy_hyperparams
109
+
110
+ @property
111
+ def algo_hyperparams(self) -> Dict[str, Any]:
112
+ return self.hyperparams.algo_hyperparams
113
+
114
+ @property
115
+ def eval_hyperparams(self) -> Dict[str, Any]:
116
+ return self.hyperparams.eval_hyperparams
117
+
118
+ def eval_callback_params(self) -> Dict[str, Any]:
119
+ eval_hyperparams = self.eval_hyperparams.copy()
120
+ if "env_overrides" in eval_hyperparams:
121
+ del eval_hyperparams["env_overrides"]
122
+ return eval_hyperparams
123
+
124
+ @property
125
+ def algo(self) -> str:
126
+ return self.args.algo
127
+
128
+ @property
129
+ def env_id(self) -> str:
130
+ return self.hyperparams.env_id or self.args.env
131
+
132
+ @property
133
+ def additional_keys_to_log(self) -> List[str]:
134
+ return self.hyperparams.additional_keys_to_log
135
+
136
+ def model_name(self, include_seed: bool = True) -> str:
137
+ # Use arg env name instead of environment name
138
+ parts = [self.algo, self.args.env]
139
+ if include_seed and self.args.seed is not None:
140
+ parts.append(f"S{self.args.seed}")
141
+
142
+ # Assume that the custom arg name already has the necessary information
143
+ if not self.hyperparams.env_id:
144
+ make_kwargs = self.env_hyperparams.get("make_kwargs", {})
145
+ if make_kwargs:
146
+ for k, v in make_kwargs.items():
147
+ if type(v) == bool and v:
148
+ parts.append(k)
149
+ elif type(v) == int and v:
150
+ parts.append(f"{k}{v}")
151
+ else:
152
+ parts.append(str(v))
153
+
154
+ return "-".join(parts)
155
+
156
+ def run_name(self, include_seed: bool = True) -> str:
157
+ parts = [self.model_name(include_seed=include_seed), self.run_id]
158
+ return "-".join(parts)
159
+
160
+ @property
161
+ def saved_models_dir(self) -> str:
162
+ return os.path.join(self.root_dir, "saved_models")
163
+
164
+ @property
165
+ def downloaded_models_dir(self) -> str:
166
+ return os.path.join(self.root_dir, "downloaded_models")
167
+
168
+ def model_dir_name(
169
+ self,
170
+ best: bool = False,
171
+ extension: str = "",
172
+ ) -> str:
173
+ return self.model_name() + ("-best" if best else "") + extension
174
+
175
+ def model_dir_path(self, best: bool = False, downloaded: bool = False) -> str:
176
+ return os.path.join(
177
+ self.saved_models_dir if not downloaded else self.downloaded_models_dir,
178
+ self.model_dir_name(best=best),
179
+ )
180
+
181
+ @property
182
+ def runs_dir(self) -> str:
183
+ return os.path.join(self.root_dir, "runs")
184
+
185
+ @property
186
+ def tensorboard_summary_path(self) -> str:
187
+ return os.path.join(self.runs_dir, self.run_name())
188
+
189
+ @property
190
+ def logs_path(self) -> str:
191
+ return os.path.join(self.runs_dir, f"log.yml")
192
+
193
+ @property
194
+ def videos_dir(self) -> str:
195
+ return os.path.join(self.root_dir, "videos")
196
+
197
+ @property
198
+ def video_prefix(self) -> str:
199
+ return os.path.join(self.videos_dir, self.model_name())
200
+
201
+ @property
202
+ def best_videos_dir(self) -> str:
203
+ return os.path.join(self.videos_dir, f"{self.model_name()}-best")
rl_algo_impls/runner/evaluate.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from dataclasses import dataclass
4
+ from typing import NamedTuple, Optional
5
+
6
+ from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
7
+ from rl_algo_impls.runner.running_utils import (
8
+ get_device,
9
+ load_hyperparams,
10
+ make_policy,
11
+ set_seeds,
12
+ )
13
+ from rl_algo_impls.shared.callbacks.eval_callback import evaluate
14
+ from rl_algo_impls.shared.policy.policy import Policy
15
+ from rl_algo_impls.shared.stats import EpisodesStats
16
+ from rl_algo_impls.shared.vec_env import make_eval_env
17
+
18
+
19
+ @dataclass
20
+ class EvalArgs(RunArgs):
21
+ render: bool = True
22
+ best: bool = True
23
+ n_envs: Optional[int] = 1
24
+ n_episodes: int = 3
25
+ deterministic_eval: Optional[bool] = None
26
+ no_print_returns: bool = False
27
+ wandb_run_path: Optional[str] = None
28
+
29
+
30
+ class Evaluation(NamedTuple):
31
+ policy: Policy
32
+ stats: EpisodesStats
33
+ config: Config
34
+
35
+
36
+ def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
37
+ if args.wandb_run_path:
38
+ import wandb
39
+
40
+ api = wandb.Api()
41
+ run = api.run(args.wandb_run_path)
42
+ params = run.config
43
+
44
+ args.algo = params["algo"]
45
+ args.env = params["env"]
46
+ args.seed = params.get("seed", None)
47
+ args.use_deterministic_algorithms = params.get(
48
+ "use_deterministic_algorithms", True
49
+ )
50
+
51
+ config = Config(args, Hyperparams.from_dict_with_extra_fields(params), root_dir)
52
+ model_path = config.model_dir_path(best=args.best, downloaded=True)
53
+
54
+ model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
55
+ run.file(model_archive_name).download()
56
+ if os.path.isdir(model_path):
57
+ shutil.rmtree(model_path)
58
+ shutil.unpack_archive(model_archive_name, model_path)
59
+ os.remove(model_archive_name)
60
+ else:
61
+ hyperparams = load_hyperparams(args.algo, args.env)
62
+
63
+ config = Config(args, hyperparams, root_dir)
64
+ model_path = config.model_dir_path(best=args.best)
65
+
66
+ print(args)
67
+
68
+ set_seeds(args.seed, args.use_deterministic_algorithms)
69
+
70
+ env = make_eval_env(
71
+ config,
72
+ EnvHyperparams(**config.env_hyperparams),
73
+ override_n_envs=args.n_envs,
74
+ render=args.render,
75
+ normalize_load_path=model_path,
76
+ )
77
+ device = get_device(config, env)
78
+ policy = make_policy(
79
+ args.algo,
80
+ env,
81
+ device,
82
+ load_path=model_path,
83
+ **config.policy_hyperparams,
84
+ ).eval()
85
+
86
+ deterministic = (
87
+ args.deterministic_eval
88
+ if args.deterministic_eval is not None
89
+ else config.eval_hyperparams.get("deterministic", True)
90
+ )
91
+ return Evaluation(
92
+ policy,
93
+ evaluate(
94
+ env,
95
+ policy,
96
+ args.n_episodes,
97
+ render=args.render,
98
+ deterministic=deterministic,
99
+ print_returns=not args.no_print_returns,
100
+ ),
101
+ config,
102
+ )
rl_algo_impls/runner/running_utils.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import random
5
+ from dataclasses import asdict
6
+ from pathlib import Path
7
+ from typing import Dict, Optional, Type, Union
8
+
9
+ import gym
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import torch
13
+ import torch.backends.cudnn
14
+ import yaml
15
+ from gym.spaces import Box, Discrete
16
+ from torch.utils.tensorboard.writer import SummaryWriter
17
+
18
+ from rl_algo_impls.a2c.a2c import A2C
19
+ from rl_algo_impls.dqn.dqn import DQN
20
+ from rl_algo_impls.dqn.policy import DQNPolicy
21
+ from rl_algo_impls.ppo.ppo import PPO
22
+ from rl_algo_impls.runner.config import Config, Hyperparams
23
+ from rl_algo_impls.shared.algorithm import Algorithm
24
+ from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
25
+ from rl_algo_impls.shared.policy.actor_critic import ActorCritic
26
+ from rl_algo_impls.shared.policy.policy import Policy
27
+ from rl_algo_impls.shared.vec_env.utils import import_for_env_id, is_microrts
28
+ from rl_algo_impls.vpg.policy import VPGActorCritic
29
+ from rl_algo_impls.vpg.vpg import VanillaPolicyGradient
30
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, single_observation_space
31
+
32
+ ALGOS: Dict[str, Type[Algorithm]] = {
33
+ "dqn": DQN,
34
+ "vpg": VanillaPolicyGradient,
35
+ "ppo": PPO,
36
+ "a2c": A2C,
37
+ }
38
+ POLICIES: Dict[str, Type[Policy]] = {
39
+ "dqn": DQNPolicy,
40
+ "vpg": VPGActorCritic,
41
+ "ppo": ActorCritic,
42
+ "a2c": ActorCritic,
43
+ }
44
+
45
+ HYPERPARAMS_PATH = "hyperparams"
46
+
47
+
48
+ def base_parser(multiple: bool = True) -> argparse.ArgumentParser:
49
+ parser = argparse.ArgumentParser()
50
+ parser.add_argument(
51
+ "--algo",
52
+ default=["dqn"],
53
+ type=str,
54
+ choices=list(ALGOS.keys()),
55
+ nargs="+" if multiple else 1,
56
+ help="Abbreviation(s) of algorithm(s)",
57
+ )
58
+ parser.add_argument(
59
+ "--env",
60
+ default=["CartPole-v1"],
61
+ type=str,
62
+ nargs="+" if multiple else 1,
63
+ help="Name of environment(s) in gym",
64
+ )
65
+ parser.add_argument(
66
+ "--seed",
67
+ default=[1],
68
+ type=int,
69
+ nargs="*" if multiple else "?",
70
+ help="Seeds to run experiment. Unset will do one run with no set seed",
71
+ )
72
+ return parser
73
+
74
+
75
+ def load_hyperparams(algo: str, env_id: str) -> Hyperparams:
76
+ root_path = Path(__file__).parent.parent
77
+ hyperparams_path = os.path.join(root_path, HYPERPARAMS_PATH, f"{algo}.yml")
78
+ with open(hyperparams_path, "r") as f:
79
+ hyperparams_dict = yaml.safe_load(f)
80
+
81
+ if env_id in hyperparams_dict:
82
+ return Hyperparams(**hyperparams_dict[env_id])
83
+
84
+ import_for_env_id(env_id)
85
+ spec = gym.spec(env_id)
86
+ entry_point_name = str(spec.entry_point) # type: ignore
87
+ if "AtariEnv" in entry_point_name and "_atari" in hyperparams_dict:
88
+ return Hyperparams(**hyperparams_dict["_atari"])
89
+ elif "gym_microrts" in entry_point_name and "_microrts" in hyperparams_dict:
90
+ return Hyperparams(**hyperparams_dict["_microrts"])
91
+ else:
92
+ raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
93
+
94
+
95
+ def get_device(config: Config, env: VecEnv) -> torch.device:
96
+ device = config.device
97
+ # cuda by default
98
+ if device == "auto":
99
+ device = "cuda"
100
+ # Apple MPS is a second choice (sometimes)
101
+ if device == "cuda" and not torch.cuda.is_available():
102
+ device = "mps"
103
+ # If no MPS, fallback to cpu
104
+ if device == "mps" and not torch.backends.mps.is_available():
105
+ device = "cpu"
106
+ # Simple environments like Discreet and 1-D Boxes might also be better
107
+ # served with the CPU.
108
+ if device == "mps":
109
+ obs_space = single_observation_space(env)
110
+ if isinstance(obs_space, Discrete):
111
+ device = "cpu"
112
+ elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
113
+ device = "cpu"
114
+ if is_microrts(config):
115
+ device = "cpu"
116
+ print(f"Device: {device}")
117
+ return torch.device(device)
118
+
119
+
120
+ def set_seeds(seed: Optional[int], use_deterministic_algorithms: bool) -> None:
121
+ if seed is None:
122
+ return
123
+ random.seed(seed)
124
+ np.random.seed(seed)
125
+ torch.manual_seed(seed)
126
+ torch.backends.cudnn.benchmark = False
127
+ torch.use_deterministic_algorithms(use_deterministic_algorithms)
128
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
129
+ # Stop warning and it would introduce stochasticity if I was using TF
130
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
131
+
132
+
133
+ def make_policy(
134
+ algo: str,
135
+ env: VecEnv,
136
+ device: torch.device,
137
+ load_path: Optional[str] = None,
138
+ **kwargs,
139
+ ) -> Policy:
140
+ policy = POLICIES[algo](env, **kwargs).to(device)
141
+ if load_path:
142
+ policy.load(load_path)
143
+ return policy
144
+
145
+
146
+ def plot_eval_callback(callback: EvalCallback, tb_writer: SummaryWriter, run_name: str):
147
+ figure = plt.figure()
148
+ cumulative_steps = [
149
+ (idx + 1) * callback.step_freq for idx in range(len(callback.stats))
150
+ ]
151
+ plt.plot(
152
+ cumulative_steps,
153
+ [s.score.mean for s in callback.stats],
154
+ "b-",
155
+ label="mean",
156
+ )
157
+ plt.plot(
158
+ cumulative_steps,
159
+ [s.score.mean - s.score.std for s in callback.stats],
160
+ "g--",
161
+ label="mean-std",
162
+ )
163
+ plt.fill_between(
164
+ cumulative_steps,
165
+ [s.score.min for s in callback.stats], # type: ignore
166
+ [s.score.max for s in callback.stats], # type: ignore
167
+ facecolor="cyan",
168
+ label="range",
169
+ )
170
+ plt.xlabel("Steps")
171
+ plt.ylabel("Score")
172
+ plt.legend()
173
+ plt.title(f"Eval {run_name}")
174
+ tb_writer.add_figure("eval", figure)
175
+
176
+
177
+ Scalar = Union[bool, str, float, int, None]
178
+
179
+
180
+ def hparam_dict(
181
+ hyperparams: Hyperparams, args: Dict[str, Union[Scalar, list]]
182
+ ) -> Dict[str, Scalar]:
183
+ flattened = args.copy()
184
+ for k, v in flattened.items():
185
+ if isinstance(v, list):
186
+ flattened[k] = json.dumps(v)
187
+ for k, v in asdict(hyperparams).items():
188
+ if isinstance(v, dict):
189
+ for sk, sv in v.items():
190
+ key = f"{k}/{sk}"
191
+ if isinstance(sv, dict) or isinstance(sv, list):
192
+ flattened[key] = str(sv)
193
+ else:
194
+ flattened[key] = sv
195
+ elif isinstance(v, list):
196
+ flattened[k] = json.dumps(v)
197
+ else:
198
+ flattened[k] = v # type: ignore
199
+ return flattened # type: ignore
rl_algo_impls/runner/train.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
2
+ import os
3
+
4
+ from rl_algo_impls.shared.callbacks import Callback
5
+ from rl_algo_impls.shared.callbacks.self_play_callback import SelfPlayCallback
6
+ from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
7
+ from rl_algo_impls.wrappers.vectorable_wrapper import find_wrapper
8
+
9
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
10
+
11
+ import dataclasses
12
+ import shutil
13
+ from dataclasses import asdict, dataclass
14
+ from typing import Any, Dict, List, Optional, Sequence
15
+
16
+ import yaml
17
+ from torch.utils.tensorboard.writer import SummaryWriter
18
+
19
+ import wandb
20
+ from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
21
+ from rl_algo_impls.runner.running_utils import (
22
+ ALGOS,
23
+ get_device,
24
+ hparam_dict,
25
+ load_hyperparams,
26
+ make_policy,
27
+ plot_eval_callback,
28
+ set_seeds,
29
+ )
30
+ from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
31
+ from rl_algo_impls.shared.callbacks.microrts_reward_decay_callback import (
32
+ MicrortsRewardDecayCallback,
33
+ )
34
+ from rl_algo_impls.shared.stats import EpisodesStats
35
+ from rl_algo_impls.shared.vec_env import make_env, make_eval_env
36
+
37
+
38
+ @dataclass
39
+ class TrainArgs(RunArgs):
40
+ wandb_project_name: Optional[str] = None
41
+ wandb_entity: Optional[str] = None
42
+ wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
43
+ wandb_group: Optional[str] = None
44
+
45
+
46
+ def train(args: TrainArgs):
47
+ print(args)
48
+ hyperparams = load_hyperparams(args.algo, args.env)
49
+ print(hyperparams)
50
+ config = Config(args, hyperparams, os.getcwd())
51
+
52
+ wandb_enabled = args.wandb_project_name
53
+ if wandb_enabled:
54
+ wandb.tensorboard.patch(
55
+ root_logdir=config.tensorboard_summary_path, pytorch=True
56
+ )
57
+ wandb.init(
58
+ project=args.wandb_project_name,
59
+ entity=args.wandb_entity,
60
+ config=asdict(hyperparams),
61
+ name=config.run_name(),
62
+ monitor_gym=True,
63
+ save_code=True,
64
+ tags=args.wandb_tags,
65
+ group=args.wandb_group,
66
+ )
67
+ wandb.config.update(args)
68
+
69
+ tb_writer = SummaryWriter(config.tensorboard_summary_path)
70
+
71
+ set_seeds(args.seed, args.use_deterministic_algorithms)
72
+
73
+ env = make_env(
74
+ config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
75
+ )
76
+ device = get_device(config, env)
77
+ policy_factory = lambda: make_policy(
78
+ args.algo, env, device, **config.policy_hyperparams
79
+ )
80
+ policy = policy_factory()
81
+ algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
82
+
83
+ num_parameters = policy.num_parameters()
84
+ num_trainable_parameters = policy.num_trainable_parameters()
85
+ if wandb_enabled:
86
+ wandb.run.summary["num_parameters"] = num_parameters # type: ignore
87
+ wandb.run.summary["num_trainable_parameters"] = num_trainable_parameters # type: ignore
88
+ else:
89
+ print(
90
+ f"num_parameters = {num_parameters} ; "
91
+ f"num_trainable_parameters = {num_trainable_parameters}"
92
+ )
93
+
94
+ eval_env = make_eval_env(config, EnvHyperparams(**config.env_hyperparams))
95
+ record_best_videos = config.eval_hyperparams.get("record_best_videos", True)
96
+ eval_callback = EvalCallback(
97
+ policy,
98
+ eval_env,
99
+ tb_writer,
100
+ best_model_path=config.model_dir_path(best=True),
101
+ **config.eval_callback_params(),
102
+ video_env=make_eval_env(
103
+ config, EnvHyperparams(**config.env_hyperparams), override_n_envs=1
104
+ )
105
+ if record_best_videos
106
+ else None,
107
+ best_video_dir=config.best_videos_dir,
108
+ additional_keys_to_log=config.additional_keys_to_log,
109
+ )
110
+ callbacks: List[Callback] = [eval_callback]
111
+ if config.hyperparams.microrts_reward_decay_callback:
112
+ callbacks.append(MicrortsRewardDecayCallback(config, env))
113
+ selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
114
+ if selfPlayWrapper:
115
+ callbacks.append(SelfPlayCallback(policy, policy_factory, selfPlayWrapper))
116
+ algo.learn(config.n_timesteps, callbacks=callbacks)
117
+
118
+ policy.save(config.model_dir_path(best=False))
119
+
120
+ eval_stats = eval_callback.evaluate(n_episodes=10, print_returns=True)
121
+
122
+ plot_eval_callback(eval_callback, tb_writer, config.run_name())
123
+
124
+ log_dict: Dict[str, Any] = {
125
+ "eval": eval_stats._asdict(),
126
+ }
127
+ if eval_callback.best:
128
+ log_dict["best_eval"] = eval_callback.best._asdict()
129
+ log_dict.update(asdict(hyperparams))
130
+ log_dict.update(vars(args))
131
+ with open(config.logs_path, "a") as f:
132
+ yaml.dump({config.run_name(): log_dict}, f)
133
+
134
+ best_eval_stats: EpisodesStats = eval_callback.best # type: ignore
135
+ tb_writer.add_hparams(
136
+ hparam_dict(hyperparams, vars(args)),
137
+ {
138
+ "hparam/best_mean": best_eval_stats.score.mean,
139
+ "hparam/best_result": best_eval_stats.score.mean
140
+ - best_eval_stats.score.std,
141
+ "hparam/last_mean": eval_stats.score.mean,
142
+ "hparam/last_result": eval_stats.score.mean - eval_stats.score.std,
143
+ },
144
+ None,
145
+ config.run_name(),
146
+ )
147
+
148
+ tb_writer.close()
149
+
150
+ if wandb_enabled:
151
+ shutil.make_archive(
152
+ os.path.join(wandb.run.dir, config.model_dir_name()),
153
+ "zip",
154
+ config.model_dir_path(),
155
+ )
156
+ shutil.make_archive(
157
+ os.path.join(wandb.run.dir, config.model_dir_name(best=True)),
158
+ "zip",
159
+ config.model_dir_path(best=True),
160
+ )
161
+ wandb.finish()
rl_algo_impls/shared/actor/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward
2
+ from rl_algo_impls.shared.actor.make_actor import actor_head
rl_algo_impls/shared/actor/actor.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import NamedTuple, Optional, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.distributions import Distribution
8
+
9
+
10
+ class PiForward(NamedTuple):
11
+ pi: Distribution
12
+ logp_a: Optional[torch.Tensor]
13
+ entropy: Optional[torch.Tensor]
14
+
15
+
16
+ class Actor(nn.Module, ABC):
17
+ @abstractmethod
18
+ def forward(
19
+ self,
20
+ obs: torch.Tensor,
21
+ actions: Optional[torch.Tensor] = None,
22
+ action_masks: Optional[torch.Tensor] = None,
23
+ ) -> PiForward:
24
+ ...
25
+
26
+ def sample_weights(self, batch_size: int = 1) -> None:
27
+ pass
28
+
29
+ @property
30
+ @abstractmethod
31
+ def action_shape(self) -> Tuple[int, ...]:
32
+ ...
33
+
34
+
35
+ def pi_forward(
36
+ distribution: Distribution, actions: Optional[torch.Tensor] = None
37
+ ) -> PiForward:
38
+ logp_a = None
39
+ entropy = None
40
+ if actions is not None:
41
+ logp_a = distribution.log_prob(actions)
42
+ entropy = distribution.entropy()
43
+ return PiForward(distribution, logp_a, entropy)
rl_algo_impls/shared/actor/categorical.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Type
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.distributions import Categorical
6
+
7
+ from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward
8
+ from rl_algo_impls.shared.module.utils import mlp
9
+
10
+
11
+ class MaskedCategorical(Categorical):
12
+ def __init__(
13
+ self,
14
+ probs=None,
15
+ logits=None,
16
+ validate_args=None,
17
+ mask: Optional[torch.Tensor] = None,
18
+ ):
19
+ if mask is not None:
20
+ assert logits is not None, "mask requires logits and not probs"
21
+ logits = torch.where(mask, logits, -1e8)
22
+ self.mask = mask
23
+ super().__init__(probs, logits, validate_args)
24
+
25
+ def entropy(self) -> torch.Tensor:
26
+ if self.mask is None:
27
+ return super().entropy()
28
+ # If mask set, then use approximation for entropy
29
+ p_log_p = self.logits * self.probs # type: ignore
30
+ masked = torch.where(self.mask, p_log_p, 0)
31
+ return -masked.sum(-1)
32
+
33
+
34
+ class CategoricalActorHead(Actor):
35
+ def __init__(
36
+ self,
37
+ act_dim: int,
38
+ in_dim: int,
39
+ hidden_sizes: Tuple[int, ...] = (32,),
40
+ activation: Type[nn.Module] = nn.Tanh,
41
+ init_layers_orthogonal: bool = True,
42
+ ) -> None:
43
+ super().__init__()
44
+ layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
45
+ self._fc = mlp(
46
+ layer_sizes,
47
+ activation,
48
+ init_layers_orthogonal=init_layers_orthogonal,
49
+ final_layer_gain=0.01,
50
+ )
51
+
52
+ def forward(
53
+ self,
54
+ obs: torch.Tensor,
55
+ actions: Optional[torch.Tensor] = None,
56
+ action_masks: Optional[torch.Tensor] = None,
57
+ ) -> PiForward:
58
+ logits = self._fc(obs)
59
+ pi = MaskedCategorical(logits=logits, mask=action_masks)
60
+ return pi_forward(pi, actions)
61
+
62
+ @property
63
+ def action_shape(self) -> Tuple[int, ...]:
64
+ return ()
rl_algo_impls/shared/actor/gaussian.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Type
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.distributions import Distribution, Normal
6
+
7
+ from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward
8
+ from rl_algo_impls.shared.module.utils import mlp
9
+
10
+
11
+ class GaussianDistribution(Normal):
12
+ def log_prob(self, a: torch.Tensor) -> torch.Tensor:
13
+ return super().log_prob(a).sum(axis=-1)
14
+
15
+ def sample(self) -> torch.Tensor:
16
+ return self.rsample()
17
+
18
+
19
+ class GaussianActorHead(Actor):
20
+ def __init__(
21
+ self,
22
+ act_dim: int,
23
+ in_dim: int,
24
+ hidden_sizes: Tuple[int, ...] = (32,),
25
+ activation: Type[nn.Module] = nn.Tanh,
26
+ init_layers_orthogonal: bool = True,
27
+ log_std_init: float = -0.5,
28
+ ) -> None:
29
+ super().__init__()
30
+ self.act_dim = act_dim
31
+ layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
32
+ self.mu_net = mlp(
33
+ layer_sizes,
34
+ activation,
35
+ init_layers_orthogonal=init_layers_orthogonal,
36
+ final_layer_gain=0.01,
37
+ )
38
+ self.log_std = nn.Parameter(
39
+ torch.ones(act_dim, dtype=torch.float32) * log_std_init
40
+ )
41
+
42
+ def _distribution(self, obs: torch.Tensor) -> Distribution:
43
+ mu = self.mu_net(obs)
44
+ std = torch.exp(self.log_std)
45
+ return GaussianDistribution(mu, std)
46
+
47
+ def forward(
48
+ self,
49
+ obs: torch.Tensor,
50
+ actions: Optional[torch.Tensor] = None,
51
+ action_masks: Optional[torch.Tensor] = None,
52
+ ) -> PiForward:
53
+ assert (
54
+ not action_masks
55
+ ), f"{self.__class__.__name__} does not support action_masks"
56
+ pi = self._distribution(obs)
57
+ return pi_forward(pi, actions)
58
+
59
+ @property
60
+ def action_shape(self) -> Tuple[int, ...]:
61
+ return (self.act_dim,)
rl_algo_impls/shared/actor/gridnet.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Tuple, Type
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from numpy.typing import NDArray
7
+ from torch.distributions import Distribution, constraints
8
+
9
+ from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward
10
+ from rl_algo_impls.shared.actor.categorical import MaskedCategorical
11
+ from rl_algo_impls.shared.encoder import EncoderOutDim
12
+ from rl_algo_impls.shared.module.utils import mlp
13
+
14
+
15
+ class GridnetDistribution(Distribution):
16
+ def __init__(
17
+ self,
18
+ map_size: int,
19
+ action_vec: NDArray[np.int64],
20
+ logits: torch.Tensor,
21
+ masks: torch.Tensor,
22
+ validate_args: Optional[bool] = None,
23
+ ) -> None:
24
+ self.map_size = map_size
25
+ self.action_vec = action_vec
26
+
27
+ masks = masks.view(-1, masks.shape[-1])
28
+ split_masks = torch.split(masks, action_vec.tolist(), dim=1)
29
+
30
+ grid_logits = logits.reshape(-1, action_vec.sum())
31
+ split_logits = torch.split(grid_logits, action_vec.tolist(), dim=1)
32
+ self.categoricals = [
33
+ MaskedCategorical(logits=lg, validate_args=validate_args, mask=m)
34
+ for lg, m in zip(split_logits, split_masks)
35
+ ]
36
+
37
+ batch_shape = logits.size()[:-1] if logits.ndimension() > 1 else torch.Size()
38
+ super().__init__(batch_shape=batch_shape, validate_args=validate_args)
39
+
40
+ def log_prob(self, action: torch.Tensor) -> torch.Tensor:
41
+ prob_stack = torch.stack(
42
+ [
43
+ c.log_prob(a)
44
+ for a, c in zip(action.view(-1, action.shape[-1]).T, self.categoricals)
45
+ ],
46
+ dim=-1,
47
+ )
48
+ logprob = prob_stack.view(-1, self.map_size, len(self.action_vec))
49
+ return logprob.sum(dim=(1, 2))
50
+
51
+ def entropy(self) -> torch.Tensor:
52
+ ent = torch.stack([c.entropy() for c in self.categoricals], dim=-1)
53
+ ent = ent.view(-1, self.map_size, len(self.action_vec))
54
+ return ent.sum(dim=(1, 2))
55
+
56
+ def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
57
+ s = torch.stack([c.sample(sample_shape) for c in self.categoricals], dim=-1)
58
+ return s.view(-1, self.map_size, len(self.action_vec))
59
+
60
+ @property
61
+ def mode(self) -> torch.Tensor:
62
+ m = torch.stack([c.mode for c in self.categoricals], dim=-1)
63
+ return m.view(-1, self.map_size, len(self.action_vec))
64
+
65
+ @property
66
+ def arg_constraints(self) -> Dict[str, constraints.Constraint]:
67
+ # Constraints handled by child distributions in dist
68
+ return {}
69
+
70
+
71
+ class GridnetActorHead(Actor):
72
+ def __init__(
73
+ self,
74
+ map_size: int,
75
+ action_vec: NDArray[np.int64],
76
+ in_dim: EncoderOutDim,
77
+ hidden_sizes: Tuple[int, ...] = (32,),
78
+ activation: Type[nn.Module] = nn.ReLU,
79
+ init_layers_orthogonal: bool = True,
80
+ ) -> None:
81
+ super().__init__()
82
+ self.map_size = map_size
83
+ self.action_vec = action_vec
84
+ assert isinstance(in_dim, int)
85
+ layer_sizes = (in_dim,) + hidden_sizes + (map_size * action_vec.sum(),)
86
+ self._fc = mlp(
87
+ layer_sizes,
88
+ activation,
89
+ init_layers_orthogonal=init_layers_orthogonal,
90
+ final_layer_gain=0.01,
91
+ )
92
+
93
+ def forward(
94
+ self,
95
+ obs: torch.Tensor,
96
+ actions: Optional[torch.Tensor] = None,
97
+ action_masks: Optional[torch.Tensor] = None,
98
+ ) -> PiForward:
99
+ assert (
100
+ action_masks is not None
101
+ ), f"No mask case unhandled in {self.__class__.__name__}"
102
+ logits = self._fc(obs)
103
+ pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
104
+ return pi_forward(pi, actions)
105
+
106
+ @property
107
+ def action_shape(self) -> Tuple[int, ...]:
108
+ return (self.map_size, len(self.action_vec))
rl_algo_impls/shared/actor/gridnet_decoder.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Type
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from numpy.typing import NDArray
7
+
8
+ from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward
9
+ from rl_algo_impls.shared.actor.gridnet import GridnetDistribution
10
+ from rl_algo_impls.shared.encoder import EncoderOutDim
11
+ from rl_algo_impls.shared.module.utils import layer_init
12
+
13
+
14
+ class Transpose(nn.Module):
15
+ def __init__(self, permutation: Tuple[int, ...]) -> None:
16
+ super().__init__()
17
+ self.permutation = permutation
18
+
19
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
20
+ return x.permute(self.permutation)
21
+
22
+
23
+ class GridnetDecoder(Actor):
24
+ def __init__(
25
+ self,
26
+ map_size: int,
27
+ action_vec: NDArray[np.int64],
28
+ in_dim: EncoderOutDim,
29
+ activation: Type[nn.Module] = nn.ReLU,
30
+ init_layers_orthogonal: bool = True,
31
+ ) -> None:
32
+ super().__init__()
33
+ self.map_size = map_size
34
+ self.action_vec = action_vec
35
+ assert isinstance(in_dim, tuple)
36
+ self.deconv = nn.Sequential(
37
+ layer_init(
38
+ nn.ConvTranspose2d(
39
+ in_dim[0], 128, 3, stride=2, padding=1, output_padding=1
40
+ ),
41
+ init_layers_orthogonal=init_layers_orthogonal,
42
+ ),
43
+ activation(),
44
+ layer_init(
45
+ nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
46
+ init_layers_orthogonal=init_layers_orthogonal,
47
+ ),
48
+ activation(),
49
+ layer_init(
50
+ nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
51
+ init_layers_orthogonal=init_layers_orthogonal,
52
+ ),
53
+ activation(),
54
+ layer_init(
55
+ nn.ConvTranspose2d(
56
+ 32, action_vec.sum(), 3, stride=2, padding=1, output_padding=1
57
+ ),
58
+ init_layers_orthogonal=init_layers_orthogonal,
59
+ std=0.01,
60
+ ),
61
+ Transpose((0, 2, 3, 1)),
62
+ )
63
+
64
+ def forward(
65
+ self,
66
+ obs: torch.Tensor,
67
+ actions: Optional[torch.Tensor] = None,
68
+ action_masks: Optional[torch.Tensor] = None,
69
+ ) -> PiForward:
70
+ assert (
71
+ action_masks is not None
72
+ ), f"No mask case unhandled in {self.__class__.__name__}"
73
+ logits = self.deconv(obs)
74
+ pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
75
+ return pi_forward(pi, actions)
76
+
77
+ @property
78
+ def action_shape(self) -> Tuple[int, ...]:
79
+ return (self.map_size, len(self.action_vec))
rl_algo_impls/shared/actor/make_actor.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Type
2
+
3
+ import gym
4
+ import torch.nn as nn
5
+ from gym.spaces import Box, Discrete, MultiDiscrete
6
+
7
+ from rl_algo_impls.shared.actor.actor import Actor
8
+ from rl_algo_impls.shared.actor.categorical import CategoricalActorHead
9
+ from rl_algo_impls.shared.actor.gaussian import GaussianActorHead
10
+ from rl_algo_impls.shared.actor.gridnet import GridnetActorHead
11
+ from rl_algo_impls.shared.actor.gridnet_decoder import GridnetDecoder
12
+ from rl_algo_impls.shared.actor.multi_discrete import MultiDiscreteActorHead
13
+ from rl_algo_impls.shared.actor.state_dependent_noise import (
14
+ StateDependentNoiseActorHead,
15
+ )
16
+ from rl_algo_impls.shared.encoder import EncoderOutDim
17
+
18
+
19
+ def actor_head(
20
+ action_space: gym.Space,
21
+ in_dim: EncoderOutDim,
22
+ hidden_sizes: Tuple[int, ...],
23
+ init_layers_orthogonal: bool,
24
+ activation: Type[nn.Module],
25
+ log_std_init: float = -0.5,
26
+ use_sde: bool = False,
27
+ full_std: bool = True,
28
+ squash_output: bool = False,
29
+ actor_head_style: str = "single",
30
+ action_plane_space: Optional[bool] = None,
31
+ ) -> Actor:
32
+ assert not use_sde or isinstance(
33
+ action_space, Box
34
+ ), "use_sde only valid if Box action_space"
35
+ assert not squash_output or use_sde, "squash_output only valid if use_sde"
36
+ if isinstance(action_space, Discrete):
37
+ assert isinstance(in_dim, int)
38
+ return CategoricalActorHead(
39
+ action_space.n, # type: ignore
40
+ in_dim=in_dim,
41
+ hidden_sizes=hidden_sizes,
42
+ activation=activation,
43
+ init_layers_orthogonal=init_layers_orthogonal,
44
+ )
45
+ elif isinstance(action_space, Box):
46
+ assert isinstance(in_dim, int)
47
+ if use_sde:
48
+ return StateDependentNoiseActorHead(
49
+ action_space.shape[0], # type: ignore
50
+ in_dim=in_dim,
51
+ hidden_sizes=hidden_sizes,
52
+ activation=activation,
53
+ init_layers_orthogonal=init_layers_orthogonal,
54
+ log_std_init=log_std_init,
55
+ full_std=full_std,
56
+ squash_output=squash_output,
57
+ )
58
+ else:
59
+ return GaussianActorHead(
60
+ action_space.shape[0], # type: ignore
61
+ in_dim=in_dim,
62
+ hidden_sizes=hidden_sizes,
63
+ activation=activation,
64
+ init_layers_orthogonal=init_layers_orthogonal,
65
+ log_std_init=log_std_init,
66
+ )
67
+ elif isinstance(action_space, MultiDiscrete):
68
+ if actor_head_style == "single":
69
+ return MultiDiscreteActorHead(
70
+ action_space.nvec, # type: ignore
71
+ in_dim=in_dim,
72
+ hidden_sizes=hidden_sizes,
73
+ activation=activation,
74
+ init_layers_orthogonal=init_layers_orthogonal,
75
+ )
76
+ elif actor_head_style == "gridnet":
77
+ assert isinstance(action_plane_space, MultiDiscrete)
78
+ return GridnetActorHead(
79
+ len(action_space.nvec) // len(action_plane_space.nvec), # type: ignore
80
+ action_plane_space.nvec, # type: ignore
81
+ in_dim=in_dim,
82
+ hidden_sizes=hidden_sizes,
83
+ activation=activation,
84
+ init_layers_orthogonal=init_layers_orthogonal,
85
+ )
86
+ elif actor_head_style == "gridnet_decoder":
87
+ assert isinstance(action_plane_space, MultiDiscrete)
88
+ return GridnetDecoder(
89
+ len(action_space.nvec) // len(action_plane_space.nvec), # type: ignore
90
+ action_plane_space.nvec, # type: ignore
91
+ in_dim=in_dim,
92
+ activation=activation,
93
+ init_layers_orthogonal=init_layers_orthogonal,
94
+ )
95
+ else:
96
+ raise ValueError(f"Doesn't support actor_head_style {actor_head_style}")
97
+ else:
98
+ raise ValueError(f"Unsupported action space: {action_space}")
rl_algo_impls/shared/actor/multi_discrete.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Tuple, Type
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from numpy.typing import NDArray
7
+ from torch.distributions import Distribution, constraints
8
+
9
+ from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward
10
+ from rl_algo_impls.shared.actor.categorical import MaskedCategorical
11
+ from rl_algo_impls.shared.encoder import EncoderOutDim
12
+ from rl_algo_impls.shared.module.utils import mlp
13
+
14
+
15
+ class MultiCategorical(Distribution):
16
+ def __init__(
17
+ self,
18
+ nvec: NDArray[np.int64],
19
+ probs=None,
20
+ logits=None,
21
+ validate_args=None,
22
+ masks: Optional[torch.Tensor] = None,
23
+ ):
24
+ # Either probs or logits should be set
25
+ assert (probs is None) != (logits is None)
26
+ masks_split = (
27
+ torch.split(masks, nvec.tolist(), dim=1)
28
+ if masks is not None
29
+ else [None] * len(nvec)
30
+ )
31
+ if probs:
32
+ self.dists = [
33
+ MaskedCategorical(probs=p, validate_args=validate_args, mask=m)
34
+ for p, m in zip(torch.split(probs, nvec.tolist(), dim=1), masks_split)
35
+ ]
36
+ param = probs
37
+ else:
38
+ assert logits is not None
39
+ self.dists = [
40
+ MaskedCategorical(logits=lg, validate_args=validate_args, mask=m)
41
+ for lg, m in zip(torch.split(logits, nvec.tolist(), dim=1), masks_split)
42
+ ]
43
+ param = logits
44
+ batch_shape = param.size()[:-1] if param.ndimension() > 1 else torch.Size()
45
+ super().__init__(batch_shape=batch_shape, validate_args=validate_args)
46
+
47
+ def log_prob(self, action: torch.Tensor) -> torch.Tensor:
48
+ prob_stack = torch.stack(
49
+ [c.log_prob(a) for a, c in zip(action.T, self.dists)], dim=-1
50
+ )
51
+ return prob_stack.sum(dim=-1)
52
+
53
+ def entropy(self) -> torch.Tensor:
54
+ return torch.stack([c.entropy() for c in self.dists], dim=-1).sum(dim=-1)
55
+
56
+ def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
57
+ return torch.stack([c.sample(sample_shape) for c in self.dists], dim=-1)
58
+
59
+ @property
60
+ def mode(self) -> torch.Tensor:
61
+ return torch.stack([c.mode for c in self.dists], dim=-1)
62
+
63
+ @property
64
+ def arg_constraints(self) -> Dict[str, constraints.Constraint]:
65
+ # Constraints handled by child distributions in dist
66
+ return {}
67
+
68
+
69
+ class MultiDiscreteActorHead(Actor):
70
+ def __init__(
71
+ self,
72
+ nvec: NDArray[np.int64],
73
+ in_dim: EncoderOutDim,
74
+ hidden_sizes: Tuple[int, ...] = (32,),
75
+ activation: Type[nn.Module] = nn.ReLU,
76
+ init_layers_orthogonal: bool = True,
77
+ ) -> None:
78
+ super().__init__()
79
+ self.nvec = nvec
80
+ assert isinstance(in_dim, int)
81
+ layer_sizes = (in_dim,) + hidden_sizes + (nvec.sum(),)
82
+ self._fc = mlp(
83
+ layer_sizes,
84
+ activation,
85
+ init_layers_orthogonal=init_layers_orthogonal,
86
+ final_layer_gain=0.01,
87
+ )
88
+
89
+ def forward(
90
+ self,
91
+ obs: torch.Tensor,
92
+ actions: Optional[torch.Tensor] = None,
93
+ action_masks: Optional[torch.Tensor] = None,
94
+ ) -> PiForward:
95
+ logits = self._fc(obs)
96
+ pi = MultiCategorical(self.nvec, logits=logits, masks=action_masks)
97
+ return pi_forward(pi, actions)
98
+
99
+ @property
100
+ def action_shape(self) -> Tuple[int, ...]:
101
+ return (len(self.nvec),)
rl_algo_impls/shared/actor/state_dependent_noise.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Type, TypeVar, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.distributions import Distribution, Normal
6
+
7
+ from rl_algo_impls.shared.actor.actor import Actor, PiForward
8
+ from rl_algo_impls.shared.module.utils import mlp
9
+
10
+
11
+ class TanhBijector:
12
+ def __init__(self, epsilon: float = 1e-6) -> None:
13
+ self.epsilon = epsilon
14
+
15
+ @staticmethod
16
+ def forward(x: torch.Tensor) -> torch.Tensor:
17
+ return torch.tanh(x)
18
+
19
+ @staticmethod
20
+ def inverse(y: torch.Tensor) -> torch.Tensor:
21
+ eps = torch.finfo(y.dtype).eps
22
+ clamped_y = y.clamp(min=-1.0 + eps, max=1.0 - eps)
23
+ return torch.atanh(clamped_y)
24
+
25
+ def log_prob_correction(self, x: torch.Tensor) -> torch.Tensor:
26
+ return torch.log(1.0 - torch.tanh(x) ** 2 + self.epsilon)
27
+
28
+
29
+ def sum_independent_dims(tensor: torch.Tensor) -> torch.Tensor:
30
+ if len(tensor.shape) > 1:
31
+ return tensor.sum(dim=1)
32
+ return tensor.sum()
33
+
34
+
35
+ class StateDependentNoiseDistribution(Normal):
36
+ def __init__(
37
+ self,
38
+ loc,
39
+ scale,
40
+ latent_sde: torch.Tensor,
41
+ exploration_mat: torch.Tensor,
42
+ exploration_matrices: torch.Tensor,
43
+ bijector: Optional[TanhBijector] = None,
44
+ validate_args=None,
45
+ ):
46
+ super().__init__(loc, scale, validate_args)
47
+ self.latent_sde = latent_sde
48
+ self.exploration_mat = exploration_mat
49
+ self.exploration_matrices = exploration_matrices
50
+ self.bijector = bijector
51
+
52
+ def log_prob(self, a: torch.Tensor) -> torch.Tensor:
53
+ gaussian_a = self.bijector.inverse(a) if self.bijector else a
54
+ log_prob = sum_independent_dims(super().log_prob(gaussian_a))
55
+ if self.bijector:
56
+ log_prob -= torch.sum(self.bijector.log_prob_correction(gaussian_a), dim=1)
57
+ return log_prob
58
+
59
+ def sample(self) -> torch.Tensor:
60
+ noise = self._get_noise()
61
+ actions = self.mean + noise
62
+ return self.bijector.forward(actions) if self.bijector else actions
63
+
64
+ def _get_noise(self) -> torch.Tensor:
65
+ if len(self.latent_sde) == 1 or len(self.latent_sde) != len(
66
+ self.exploration_matrices
67
+ ):
68
+ return torch.mm(self.latent_sde, self.exploration_mat)
69
+ # (batch_size, n_features) -> (batch_size, 1, n_features)
70
+ latent_sde = self.latent_sde.unsqueeze(dim=1)
71
+ # (batch_size, 1, n_actions)
72
+ noise = torch.bmm(latent_sde, self.exploration_matrices)
73
+ return noise.squeeze(dim=1)
74
+
75
+ @property
76
+ def mode(self) -> torch.Tensor:
77
+ mean = super().mode
78
+ return self.bijector.forward(mean) if self.bijector else mean
79
+
80
+
81
+ StateDependentNoiseActorHeadSelf = TypeVar(
82
+ "StateDependentNoiseActorHeadSelf", bound="StateDependentNoiseActorHead"
83
+ )
84
+
85
+
86
+ class StateDependentNoiseActorHead(Actor):
87
+ def __init__(
88
+ self,
89
+ act_dim: int,
90
+ in_dim: int,
91
+ hidden_sizes: Tuple[int, ...] = (32,),
92
+ activation: Type[nn.Module] = nn.Tanh,
93
+ init_layers_orthogonal: bool = True,
94
+ log_std_init: float = -0.5,
95
+ full_std: bool = True,
96
+ squash_output: bool = False,
97
+ learn_std: bool = False,
98
+ ) -> None:
99
+ super().__init__()
100
+ self.act_dim = act_dim
101
+ layer_sizes = (in_dim,) + hidden_sizes + (act_dim,)
102
+ if len(layer_sizes) == 2:
103
+ self.latent_net = nn.Identity()
104
+ elif len(layer_sizes) > 2:
105
+ self.latent_net = mlp(
106
+ layer_sizes[:-1],
107
+ activation,
108
+ output_activation=activation,
109
+ init_layers_orthogonal=init_layers_orthogonal,
110
+ )
111
+ self.mu_net = mlp(
112
+ layer_sizes[-2:],
113
+ activation,
114
+ init_layers_orthogonal=init_layers_orthogonal,
115
+ final_layer_gain=0.01,
116
+ )
117
+ self.full_std = full_std
118
+ std_dim = (layer_sizes[-2], act_dim if self.full_std else 1)
119
+ self.log_std = nn.Parameter(
120
+ torch.ones(std_dim, dtype=torch.float32) * log_std_init
121
+ )
122
+ self.bijector = TanhBijector() if squash_output else None
123
+ self.learn_std = learn_std
124
+ self.device = None
125
+
126
+ self.exploration_mat = None
127
+ self.exploration_matrices = None
128
+ self.sample_weights()
129
+
130
+ def to(
131
+ self: StateDependentNoiseActorHeadSelf,
132
+ device: Optional[torch.device] = None,
133
+ dtype: Optional[Union[torch.dtype, str]] = None,
134
+ non_blocking: bool = False,
135
+ ) -> StateDependentNoiseActorHeadSelf:
136
+ super().to(device, dtype, non_blocking)
137
+ self.device = device
138
+ return self
139
+
140
+ def _distribution(self, obs: torch.Tensor) -> Distribution:
141
+ latent = self.latent_net(obs)
142
+ mu = self.mu_net(latent)
143
+ latent_sde = latent if self.learn_std else latent.detach()
144
+ variance = torch.mm(latent_sde**2, self._get_std() ** 2)
145
+ assert self.exploration_mat is not None
146
+ assert self.exploration_matrices is not None
147
+ return StateDependentNoiseDistribution(
148
+ mu,
149
+ torch.sqrt(variance + 1e-6),
150
+ latent_sde,
151
+ self.exploration_mat,
152
+ self.exploration_matrices,
153
+ self.bijector,
154
+ )
155
+
156
+ def _get_std(self) -> torch.Tensor:
157
+ std = torch.exp(self.log_std)
158
+ if self.full_std:
159
+ return std
160
+ ones = torch.ones(self.log_std.shape[0], self.act_dim)
161
+ if self.device:
162
+ ones = ones.to(self.device)
163
+ return ones * std
164
+
165
+ def forward(
166
+ self,
167
+ obs: torch.Tensor,
168
+ actions: Optional[torch.Tensor] = None,
169
+ action_masks: Optional[torch.Tensor] = None,
170
+ ) -> PiForward:
171
+ assert (
172
+ not action_masks
173
+ ), f"{self.__class__.__name__} does not support action_masks"
174
+ pi = self._distribution(obs)
175
+ return pi_forward(pi, actions)
176
+
177
+ def sample_weights(self, batch_size: int = 1) -> None:
178
+ std = self._get_std()
179
+ weights_dist = Normal(torch.zeros_like(std), std)
180
+ # Reparametrization trick to pass gradients
181
+ self.exploration_mat = weights_dist.rsample()
182
+ self.exploration_matrices = weights_dist.rsample(torch.Size((batch_size,)))
183
+
184
+ @property
185
+ def action_shape(self) -> Tuple[int, ...]:
186
+ return (self.act_dim,)
187
+
188
+
189
+ def pi_forward(
190
+ distribution: Distribution, actions: Optional[torch.Tensor] = None
191
+ ) -> PiForward:
192
+ logp_a = None
193
+ entropy = None
194
+ if actions is not None:
195
+ logp_a = distribution.log_prob(actions)
196
+ entropy = (
197
+ -logp_a if self.bijector else sum_independent_dims(distribution.entropy())
198
+ )
199
+ return PiForward(distribution, logp_a, entropy)
rl_algo_impls/shared/algorithm.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Optional, TypeVar
3
+
4
+ import gym
5
+ import torch
6
+ from torch.utils.tensorboard.writer import SummaryWriter
7
+
8
+ from rl_algo_impls.shared.callbacks import Callback
9
+ from rl_algo_impls.shared.policy.policy import Policy
10
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
11
+
12
+ AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")
13
+
14
+
15
+ class Algorithm(ABC):
16
+ @abstractmethod
17
+ def __init__(
18
+ self,
19
+ policy: Policy,
20
+ env: VecEnv,
21
+ device: torch.device,
22
+ tb_writer: SummaryWriter,
23
+ **kwargs,
24
+ ) -> None:
25
+ super().__init__()
26
+ self.policy = policy
27
+ self.env = env
28
+ self.device = device
29
+ self.tb_writer = tb_writer
30
+
31
+ @abstractmethod
32
+ def learn(
33
+ self: AlgorithmSelf,
34
+ train_timesteps: int,
35
+ callbacks: Optional[List[Callback]] = None,
36
+ total_timesteps: Optional[int] = None,
37
+ start_timesteps: int = 0,
38
+ ) -> AlgorithmSelf:
39
+ ...