sgoodfriend committed on
Commit
9768175
1 Parent(s): d4e691b

PPO playing BipedalWalker-v3 from https://github.com/sgoodfriend/rl-algo-impls/tree/7026bf7f4f56a8a5b0dab7193256d2fbf823b308

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +147 -0
  2. LICENSE +21 -0
  3. README.md +133 -0
  4. benchmark_publish.py +107 -0
  5. benchmarks/benchmark_test.sh +32 -0
  6. benchmarks/colab_atari1.sh +5 -0
  7. benchmarks/colab_atari2.sh +5 -0
  8. benchmarks/colab_basic.sh +5 -0
  9. benchmarks/colab_benchmark.ipynb +195 -0
  10. benchmarks/colab_carracing.sh +5 -0
  11. benchmarks/colab_pybullet.sh +5 -0
  12. benchmarks/train_loop.sh +15 -0
  13. colab_enjoy.ipynb +198 -0
  14. colab_requirements.txt +14 -0
  15. colab_train.ipynb +200 -0
  16. compare_runs.py +187 -0
  17. dqn/dqn.py +182 -0
  18. dqn/policy.py +47 -0
  19. dqn/q_net.py +39 -0
  20. enjoy.py +30 -0
  21. environment.yml +17 -0
  22. hf-deep-rl/dqn_SpaceInvadersNoFrameskip_v4.ipynb +0 -0
  23. huggingface_publish.py +189 -0
  24. hyperparams/dqn.yml +130 -0
  25. hyperparams/ppo.yml +380 -0
  26. hyperparams/vpg.yml +195 -0
  27. lambda_labs/benchmark.sh +33 -0
  28. lambda_labs/impala_atari_benchmark.sh +19 -0
  29. lambda_labs/lambda_requirements.txt +16 -0
  30. lambda_labs/procgen_benchmark.sh +18 -0
  31. lambda_labs/setup.sh +10 -0
  32. lambda_labs/starpilot_hard_benchmark.sh +16 -0
  33. poetry.lock +0 -0
  34. ppo/ppo.py +349 -0
  35. publish/markdown_format.py +210 -0
  36. pyproject.toml +35 -0
  37. replay.meta.json +1 -0
  38. replay.mp4 +0 -0
  39. runner/config.py +155 -0
  40. runner/env.py +262 -0
  41. runner/evaluate.py +103 -0
  42. runner/running_utils.py +192 -0
  43. runner/train.py +141 -0
  44. saved_models/ppo-BipedalWalker-v3-S2-best/model.pth +3 -0
  45. saved_models/ppo-BipedalWalker-v3-S2-best/vecnormalize.pkl +3 -0
  46. shared/algorithm.py +35 -0
  47. shared/callbacks/callback.py +12 -0
  48. shared/callbacks/eval_callback.py +214 -0
  49. shared/gae.py +67 -0
  50. shared/module/feature_extractor.py +215 -0
.gitignore ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ # Logging into tensorboard and wandb
132
+ runs/*
133
+ wandb
134
+
135
+ # macOS
136
+ .DS_STORE
137
+
138
+ # Local scratch work
139
+ scratch/*
140
+
141
+ # vscode
142
+ .vscode/
143
+
144
+ # Don't bother tracking saved_models or videos
145
+ saved_models/*
146
+ downloaded_models/*
147
+ videos/*
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Scott Goodfriend
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: rl-algo-impls
3
+ tags:
4
+ - BipedalWalker-v3
5
+ - ppo
6
+ - deep-reinforcement-learning
7
+ - reinforcement-learning
8
+ model-index:
9
+ - name: ppo
10
+ results:
11
+ - metrics:
12
+ - type: mean_reward
13
+ value: 324.52 +/- 0.73
14
+ name: mean_reward
15
+ task:
16
+ type: reinforcement-learning
17
+ name: reinforcement-learning
18
+ dataset:
19
+ name: BipedalWalker-v3
20
+ type: BipedalWalker-v3
21
+ ---
22
+ # **PPO** Agent playing **BipedalWalker-v3**
23
+
24
+ This is a trained model of a **PPO** agent playing **BipedalWalker-v3** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
25
+
26
+ All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/tgbav2ux.
27
+
28
+ ## Training Results
29
+
30
+ This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [7026bf7](https://github.com/sgoodfriend/rl-algo-impls/tree/7026bf7f4f56a8a5b0dab7193256d2fbf823b308). The best and last models were kept from each training. This submission loaded the best model from each training, reevaluated them, and selected the best model from these latest evaluations (mean - std).
31
+
32
+ | algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
33
+ |:-------|:-----------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
34
+ | ppo | BipedalWalker-v3 | 1 | 315.428 | 36.4316 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/8nykj4o0) |
35
+ | ppo | BipedalWalker-v3 | 2 | 324.516 | 0.733634 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/onmmcskv) |
36
+ | ppo | BipedalWalker-v3 | 3 | 310.169 | 1.03936 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/xdgop7ak) |
37
+
38
+
39
+ ### Prerequisites: Weights & Biases (WandB)
40
+ Training and benchmarking assume you have a Weights & Biases project to upload runs to.
41
+ By default training goes to a rl-algo-impls project while benchmarks go to
42
+ rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
43
+ models and the model weights are uploaded to WandB.
44
+
45
+ Before doing anything below, you'll need to create a wandb account and run `wandb
46
+ login`.
47
+
48
+
49
+
50
+ ## Usage
51
+ /sgoodfriend/rl-algo-impls: https://github.com/sgoodfriend/rl-algo-impls
52
+
53
+ Note: While the model state dictionary and hyperparameters are saved, the latest
54
+ implementation could be sufficiently different to not be able to reproduce similar
55
+ results. You might need to checkout the commit the agent was trained on:
56
+ [7026bf7](https://github.com/sgoodfriend/rl-algo-impls/tree/7026bf7f4f56a8a5b0dab7193256d2fbf823b308).
57
+ ```
58
+ # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
59
+ python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/onmmcskv
60
+ ```
61
+
62
+ Setup hasn't been completely worked out yet, so you might be best served by using Google
63
+ Colab starting from the
64
+ [colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
65
+ notebook.
66
+
67
+
68
+
69
+ ## Training
70
+ If you want the highest chance to reproduce these results, you'll want to checkout the
71
+ commit the agent was trained on: [7026bf7](https://github.com/sgoodfriend/rl-algo-impls/tree/7026bf7f4f56a8a5b0dab7193256d2fbf823b308). While
72
+ training is deterministic, different hardware will give different results.
73
+
74
+ ```
75
+ python train.py --algo ppo --env BipedalWalker-v3 --seed 2
76
+ ```
77
+
78
+ Setup hasn't been completely worked out yet, so you might be best served by using Google
79
+ Colab starting from the
80
+ [colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
81
+ notebook.
82
+
83
+
84
+
85
+ ## Benchmarking (with Lambda Labs instance)
86
+ This and other models from https://api.wandb.ai/links/sgoodfriend/tgbav2ux were generated by running a script on a Lambda
87
+ Labs instance. In a Lambda Labs instance terminal:
88
+ ```
89
+ git clone git@github.com:sgoodfriend/rl-algo-impls.git
90
+ cd rl-algo-impls
91
+ bash ./lambda_labs/setup.sh
92
+ wandb login
93
+ bash ./lambda_labs/benchmark.sh
94
+ ```
95
+
96
+ ### Alternative: Google Colab Pro+
97
+ As an alternative,
98
+ [colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
99
+ can be used. However, this requires a Google Colab Pro+ subscription and running across
100
+ 4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
101
+
102
+
103
+
104
+ ## Hyperparameters
105
+ This isn't exactly the format of hyperparams in hyperparams/ppo.yml, but instead the Wandb Run Config. However, it's very
106
+ close and has some additional data:
107
+ ```
108
+ algo: ppo
109
+ algo_hyperparams:
110
+ batch_size: 64
111
+ clip_range: 0.2
112
+ clip_range_decay: linear
113
+ ent_coef: 0.001
114
+ gae_lambda: 0.95
115
+ gamma: 0.99
116
+ learning_rate: 0.00025
117
+ learning_rate_decay: linear
118
+ n_epochs: 10
119
+ n_steps: 2048
120
+ env: BipedalWalker-v3
121
+ env_hyperparams:
122
+ n_envs: 16
123
+ normalize: true
124
+ n_timesteps: 10000000
125
+ seed: 2
126
+ use_deterministic_algorithms: true
127
+ wandb_entity: null
128
+ wandb_project_name: rl-algo-impls-benchmarks
129
+ wandb_tags:
130
+ - benchmark_7026bf7
131
+ - host_slg-m1max.local
132
+
133
+ ```
benchmark_publish.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import subprocess
3
+ import wandb
4
+ import wandb.apis.public
5
+
6
+ from collections import defaultdict
7
+ from multiprocessing.pool import ThreadPool
8
+ from typing import List, NamedTuple
9
+
10
+
11
class RunGroup(NamedTuple):
    """Hashable (algorithm, environment) pair identifying a group of runs."""

    algo: str  # algorithm name, e.g. "ppo"
    env_id: str  # environment id, e.g. "BipedalWalker-v3"
14
+
15
+
16
if __name__ == "__main__":
    # Command-line entry point: select finished WandB runs that carry all of the
    # requested tags, group their paths by (algo, env), and launch one
    # `huggingface_publish.py` subprocess per group on a bounded thread pool.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls-benchmarks",
        help="WandB project name to load runs from",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        default=None,
        help="WandB team of project. None uses default entity",
    )
    parser.add_argument("--wandb-tags", type=str, nargs="+", help="WandB tags")
    parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
    parser.add_argument(
        "--envs", type=str, nargs="*", help="Optional filter down to these envs"
    )
    parser.add_argument(
        "--exclude-envs",
        type=str,
        nargs="*",
        help="Environments to exclude from publishing",
    )
    parser.add_argument(
        "--huggingface-user",
        type=str,
        default=None,
        help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
    )
    parser.add_argument(
        "--pool-size",
        type=int,
        default=3,
        help="How many publish jobs can run in parallel",
    )
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    # parser.set_defaults(
    #     wandb_tags=["benchmark_e47a44c", "host_129-146-2-230"],
    #     wandb_report_url="https://api.wandb.ai/links/sgoodfriend/v4wd7cp5",
    #     envs=[],
    #     exclude_envs=[],
    # )
    args = parser.parse_args()
    print(args)

    # Both values are used unconditionally below. Checking after parsing (rather
    # than required=True) keeps parser.set_defaults(...) usable as a way to
    # supply them.
    if args.wandb_tags is None:
        parser.error("--wandb-tags is required to select which runs to publish")
    if args.wandb_report_url is None:
        parser.error("--wandb-report-url is required by huggingface_publish.py")

    api = wandb.Api()
    all_runs = api.runs(
        f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
    )

    # Keep only runs whose config lists every requested tag.
    required_tags = set(args.wandb_tags)
    runs: List[wandb.apis.public.Run] = [
        r
        for r in all_runs
        if required_tags.issubset(r.config.get("wandb_tags") or [])
    ]

    # Map each (algo, env) group to the "/"-joined paths of its finished runs.
    runs_paths_by_group = defaultdict(list)
    for r in runs:
        if r.state != "finished":
            continue
        algo = r.config["algo"]
        env = r.config["env"]
        if args.envs and env not in args.envs:
            continue
        if args.exclude_envs and env in args.exclude_envs:
            continue
        runs_paths_by_group[RunGroup(algo, env)].append("/".join(r.path))

    def run(run_paths: List[str]) -> int:
        """Run huggingface_publish.py for one group and return its exit code."""
        publish_args = [
            "python",
            "huggingface_publish.py",
            "--wandb-run-paths",
            *run_paths,
            "--wandb-report-url",
            args.wandb_report_url,
        ]
        if args.huggingface_user:
            publish_args.extend(["--huggingface-user", args.huggingface_user])
        if args.virtual_display:
            publish_args.append("--virtual-display")
        return subprocess.run(publish_args).returncode

    tp = ThreadPool(args.pool_size)
    pending = [
        (run_paths, tp.apply_async(run, (run_paths,)))
        for run_paths in runs_paths_by_group.values()
    ]
    tp.close()
    tp.join()

    # apply_async stores worker exceptions on the AsyncResult; they only surface
    # through .get(). Inspect every result so a failed publish is reported
    # instead of being silently dropped.
    failures = 0
    for run_paths, result in pending:
        try:
            returncode = result.get()
        except Exception as e:  # top-level boundary: report, then keep checking
            print(f"Publishing {run_paths} raised {e!r}")
            failures += 1
            continue
        if returncode != 0:
            print(f"Publishing {run_paths} exited with code {returncode}")
            failures += 1
    if failures:
        raise SystemExit(f"{failures} publish job(s) failed")
benchmarks/benchmark_test.sh ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+
3
+ export WANDB_PROJECT_NAME="rl-algo-impls"
4
+
5
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
6
+
7
+ ALGOS=(
8
+ # "vpg"
9
+ "dqn"
10
+ # "ppo"
11
+ )
12
+ ENVS=(
13
+ # Basic
14
+ "CartPole-v1"
15
+ "MountainCar-v0"
16
+ # "MountainCarContinuous-v0"
17
+ "Acrobot-v1"
18
+ "LunarLander-v2"
19
+ # # PyBullet
20
+ # "HalfCheetahBulletEnv-v0"
21
+ # "AntBulletEnv-v0"
22
+ # "HopperBulletEnv-v0"
23
+ # "Walker2DBulletEnv-v0"
24
+ # # CarRacing
25
+ # "CarRacing-v0"
26
+ # Atari
27
+ "PongNoFrameskip-v4"
28
+ "BreakoutNoFrameskip-v4"
29
+ "SpaceInvadersNoFrameskip-v4"
30
+ "QbertNoFrameskip-v4"
31
+ )
32
+ train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
benchmarks/colab_atari1.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+ ALGOS="ppo"
3
+ ENVS="PongNoFrameskip-v4 BreakoutNoFrameskip-v4"
4
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
+ train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
benchmarks/colab_atari2.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+ ALGOS="ppo"
3
+ ENVS="SpaceInvadersNoFrameskip-v4 QbertNoFrameskip-v4"
4
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
+ train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
benchmarks/colab_basic.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+ ALGOS="ppo"
3
+ ENVS="CartPole-v1 MountainCar-v0 MountainCarContinuous-v0 Acrobot-v1 LunarLander-v2"
4
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
+ train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
benchmarks/colab_benchmark.ipynb ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "machine_shape": "hm",
8
+ "authorship_tag": "ABX9TyOGIH7rqgasim3Sz7b1rpoE",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ },
18
+ "gpuClass": "standard",
19
+ "accelerator": "GPU"
20
+ },
21
+ "cells": [
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "view-in-github",
26
+ "colab_type": "text"
27
+ },
28
+ "source": [
29
+ "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/benchmarks/colab_benchmark.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "source": [
35
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
+ "## Parameters\n",
37
+ "\n",
38
+ "\n",
39
+ "1. Wandb\n",
40
+ "\n"
41
+ ],
42
+ "metadata": {
43
+ "id": "S-tXDWP8WTLc"
44
+ }
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "source": [
49
+ "from getpass import getpass\n",
50
+ "import os\n",
51
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
+ ],
53
+ "metadata": {
54
+ "id": "1ZtdYgxWNGwZ"
55
+ },
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "source": [
62
+ "## Setup\n",
63
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
64
+ ],
65
+ "metadata": {
66
+ "id": "bsG35Io0hmKG"
67
+ }
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "source": [
72
+ "%%capture\n",
73
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
74
+ ],
75
+ "metadata": {
76
+ "id": "k5ynTV25hdAf"
77
+ },
78
+ "execution_count": null,
79
+ "outputs": []
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "source": [
84
+ "Installing the correct packages:\n",
85
+ "\n",
86
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
87
+ ],
88
+ "metadata": {
89
+ "id": "jKxGok-ElYQ7"
90
+ }
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "source": [
95
+ "%%capture\n",
96
+ "!apt install python-opengl\n",
97
+ "!apt install ffmpeg\n",
98
+ "!apt install xvfb\n",
99
+ "!apt install swig"
100
+ ],
101
+ "metadata": {
102
+ "id": "nn6EETTc2Ewf"
103
+ },
104
+ "execution_count": null,
105
+ "outputs": []
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "source": [
110
+ "%%capture\n",
111
+ "%cd /content/rl-algo-impls\n",
112
+ "!pip install -r colab_requirements.txt"
113
+ ],
114
+ "metadata": {
115
+ "id": "AfZh9rH3yQii"
116
+ },
117
+ "execution_count": null,
118
+ "outputs": []
119
+ },
120
+ {
121
+ "cell_type": "markdown",
122
+ "source": [
123
+ "## Run Once Per Runtime"
124
+ ],
125
+ "metadata": {
126
+ "id": "4o5HOLjc4wq7"
127
+ }
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "source": [
132
+ "import wandb\n",
133
+ "wandb.login()"
134
+ ],
135
+ "metadata": {
136
+ "id": "PCXa5tdS2qFX"
137
+ },
138
+ "execution_count": null,
139
+ "outputs": []
140
+ },
141
+ {
142
+ "cell_type": "markdown",
143
+ "source": [
144
+ "## Restart Session between runs"
145
+ ],
146
+ "metadata": {
147
+ "id": "AZBZfSUV43JQ"
148
+ }
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "source": [
153
+ "%%capture\n",
154
+ "from pyvirtualdisplay import Display\n",
155
+ "\n",
156
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
157
+ "virtual_display.start()"
158
+ ],
159
+ "metadata": {
160
+ "id": "VzemeQJP2NO9"
161
+ },
162
+ "execution_count": null,
163
+ "outputs": []
164
+ },
165
+ {
166
+ "cell_type": "markdown",
167
+ "source": [
168
+ "The below 5 bash scripts train agents on environments with 3 seeds each:\n",
169
+ "- colab_basic.sh and colab_pybullet.sh test on a set of basic gym environments and 4 PyBullet environments. Running both together will likely take about 18 hours. This is likely to run into runtime limits for free Colab and Colab Pro, but is fine for Colab Pro+.\n",
170
+ "- colab_carracing.sh only trains 3 seeds on CarRacing-v0, which takes almost 22 hours on Colab Pro+ on high-RAM, standard GPU.\n",
171
+ "- colab_atari1.sh and colab_atari2.sh likely need to be run separately because each takes about 19 hours on high-RAM, standard GPU."
172
+ ],
173
+ "metadata": {
174
+ "id": "nSHfna0hLlO1"
175
+ }
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "source": [
180
+ "%cd /content/rl-algo-impls\n",
181
+ "os.environ[\"BENCHMARK_MAX_PROCS\"] = str(1) # Can't reliably raise this to 2+, but would make it faster.\n",
182
+ "!./benchmarks/colab_basic.sh\n",
183
+ "!./benchmarks/colab_pybullet.sh\n",
184
+ "# !./benchmarks/colab_carracing.sh\n",
185
+ "# !./benchmarks/colab_atari1.sh\n",
186
+ "# !./benchmarks/colab_atari2.sh"
187
+ ],
188
+ "metadata": {
189
+ "id": "07aHYFH1zfXa"
190
+ },
191
+ "execution_count": null,
192
+ "outputs": []
193
+ }
194
+ ]
195
+ }
benchmarks/colab_carracing.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+ ALGOS="ppo"
3
+ ENVS="CarRacing-v0"
4
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
+ train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
benchmarks/colab_pybullet.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+ ALGOS="ppo"
3
+ ENVS="HalfCheetahBulletEnv-v0 AntBulletEnv-v0 HopperBulletEnv-v0 Walker2DBulletEnv-v0"
4
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
+ train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
benchmarks/train_loop.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Print one `python train.py ...` command line per (algo, env, seed) combination.
#   $1: whitespace-separated list of algorithm names
#   $2: whitespace-separated list of environment ids
# Optional environment overrides:
#   SEEDS               seeds to iterate over (default "1 2 3")
#   WANDB_PROJECT_NAME  project passed to train.py (default "rl-algo-impls-benchmarks")
# The commands are only echoed, never executed here.
train_loop () {
  # Tag every run with the current commit and the host name.
  local WANDB_TAGS="benchmark_$(git rev-parse --short HEAD) host_$(hostname)"
  local algo
  local env
  local seed
  local WANDB_PROJECT_NAME="${WANDB_PROJECT_NAME:-rl-algo-impls-benchmarks}"
  local SEEDS="${SEEDS:-1 2 3}"
  # $1, $2 and $SEEDS are left unquoted on purpose so each list splits into words.
  for algo in $1; do
    for env in $2; do
      for seed in $SEEDS; do
        echo python train.py --algo $algo --env $env --seed $seed --pool-size 1 --wandb-tags $WANDB_TAGS --wandb-project-name $WANDB_PROJECT_NAME
      done
    done
  done
}
colab_enjoy.ipynb ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "machine_shape": "hm",
8
+ "authorship_tag": "ABX9TyN6S7kyJKrM5x0OOiN+CgTc",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ },
18
+ "gpuClass": "standard",
19
+ "accelerator": "GPU"
20
+ },
21
+ "cells": [
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "view-in-github",
26
+ "colab_type": "text"
27
+ },
28
+ "source": [
29
+ "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "source": [
35
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
+ "## Parameters\n",
37
+ "\n",
38
+ "\n",
39
+ "1. Wandb\n",
40
+ "\n"
41
+ ],
42
+ "metadata": {
43
+ "id": "S-tXDWP8WTLc"
44
+ }
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "source": [
49
+ "from getpass import getpass\n",
50
+ "import os\n",
51
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
+ ],
53
+ "metadata": {
54
+ "id": "1ZtdYgxWNGwZ"
55
+ },
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "source": [
62
+ "2. enjoy.py parameters"
63
+ ],
64
+ "metadata": {
65
+ "id": "ao0nAh3MOdN7"
66
+ }
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "source": [
71
+ "WANDB_RUN_PATH=\"sgoodfriend/rl-algo-impls-benchmarks/rd0lisee\""
72
+ ],
73
+ "metadata": {
74
+ "id": "jKL_NFhVOjSc"
75
+ },
76
+ "execution_count": 2,
77
+ "outputs": []
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "source": [
82
+ "## Setup\n",
83
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
84
+ ],
85
+ "metadata": {
86
+ "id": "bsG35Io0hmKG"
87
+ }
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "source": [
92
+ "%%capture\n",
93
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
94
+ ],
95
+ "metadata": {
96
+ "id": "k5ynTV25hdAf"
97
+ },
98
+ "execution_count": 3,
99
+ "outputs": []
100
+ },
101
+ {
102
+ "cell_type": "markdown",
103
+ "source": [
104
+ "Installing the correct packages:\n",
105
+ "\n",
106
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
107
+ ],
108
+ "metadata": {
109
+ "id": "jKxGok-ElYQ7"
110
+ }
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "source": [
115
+ "%%capture\n",
116
+ "!apt install python-opengl\n",
117
+ "!apt install ffmpeg\n",
118
+ "!apt install xvfb\n",
119
+ "!apt install swig"
120
+ ],
121
+ "metadata": {
122
+ "id": "nn6EETTc2Ewf"
123
+ },
124
+ "execution_count": 4,
125
+ "outputs": []
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "source": [
130
+ "%%capture\n",
131
+ "%cd /content/rl-algo-impls\n",
132
+ "!pip install -r colab_requirements.txt"
133
+ ],
134
+ "metadata": {
135
+ "id": "AfZh9rH3yQii"
136
+ },
137
+ "execution_count": 5,
138
+ "outputs": []
139
+ },
140
+ {
141
+ "cell_type": "markdown",
142
+ "source": [
143
+ "## Run Once Per Runtime"
144
+ ],
145
+ "metadata": {
146
+ "id": "4o5HOLjc4wq7"
147
+ }
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "source": [
152
+ "import wandb\n",
153
+ "wandb.login()"
154
+ ],
155
+ "metadata": {
156
+ "id": "PCXa5tdS2qFX"
157
+ },
158
+ "execution_count": null,
159
+ "outputs": []
160
+ },
161
+ {
162
+ "cell_type": "markdown",
163
+ "source": [
164
+ "## Restart Session beteween runs"
165
+ ],
166
+ "metadata": {
167
+ "id": "AZBZfSUV43JQ"
168
+ }
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "source": [
173
+ "%%capture\n",
174
+ "from pyvirtualdisplay import Display\n",
175
+ "\n",
176
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
177
+ "virtual_display.start()"
178
+ ],
179
+ "metadata": {
180
+ "id": "VzemeQJP2NO9"
181
+ },
182
+ "execution_count": 7,
183
+ "outputs": []
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "source": [
188
+ "%cd /content/rl-algo-impls\n",
189
+ "!python enjoy.py --wandb-run-path={WANDB_RUN_PATH}"
190
+ ],
191
+ "metadata": {
192
+ "id": "07aHYFH1zfXa"
193
+ },
194
+ "execution_count": null,
195
+ "outputs": []
196
+ }
197
+ ]
198
+ }
colab_requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ AutoROM.accept-rom-license >= 0.4.2, < 0.5
2
+ stable-baselines3[extra] >= 1.7.0, < 1.8
3
+ gym[box2d] >= 0.21.0, < 0.22
4
+ pyglet == 1.5.27
5
+ wandb >= 0.13.10, < 0.14
6
+ pyvirtualdisplay == 3.0
7
+ pybullet >= 3.2.5, < 3.3
8
+ tabulate >= 0.9.0, < 0.10
9
+ huggingface-hub >= 0.12.0, < 0.13
10
+ numexpr >= 2.8.4, < 2.9
11
+ gym3 >= 0.3.3, < 0.4
12
+ glfw >= 1.12.0, < 1.13
13
+ procgen >= 0.10.7, < 0.11
14
+ ipython >= 8.10.0, < 8.11
colab_train.ipynb ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "machine_shape": "hm",
8
+ "authorship_tag": "ABX9TyMmemQnx6G7GOnn6XBdjgxY",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ },
18
+ "gpuClass": "standard",
19
+ "accelerator": "GPU"
20
+ },
21
+ "cells": [
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "view-in-github",
26
+ "colab_type": "text"
27
+ },
28
+ "source": [
29
+ "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "source": [
35
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
+ "## Parameters\n",
37
+ "\n",
38
+ "\n",
39
+ "1. Wandb\n",
40
+ "\n"
41
+ ],
42
+ "metadata": {
43
+ "id": "S-tXDWP8WTLc"
44
+ }
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "source": [
49
+ "from getpass import getpass\n",
50
+ "import os\n",
51
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
+ ],
53
+ "metadata": {
54
+ "id": "1ZtdYgxWNGwZ"
55
+ },
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "source": [
62
+ "2. train run parameters"
63
+ ],
64
+ "metadata": {
65
+ "id": "ao0nAh3MOdN7"
66
+ }
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "source": [
71
+ "ALGO = \"ppo\"\n",
72
+ "ENV = \"CartPole-v1\"\n",
73
+ "SEED = 1"
74
+ ],
75
+ "metadata": {
76
+ "id": "jKL_NFhVOjSc"
77
+ },
78
+ "execution_count": null,
79
+ "outputs": []
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "source": [
84
+ "## Setup\n",
85
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
86
+ ],
87
+ "metadata": {
88
+ "id": "bsG35Io0hmKG"
89
+ }
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "source": [
94
+ "%%capture\n",
95
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
96
+ ],
97
+ "metadata": {
98
+ "id": "k5ynTV25hdAf"
99
+ },
100
+ "execution_count": null,
101
+ "outputs": []
102
+ },
103
+ {
104
+ "cell_type": "markdown",
105
+ "source": [
106
+ "Installing the correct packages:\n",
107
+ "\n",
108
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
109
+ ],
110
+ "metadata": {
111
+ "id": "jKxGok-ElYQ7"
112
+ }
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "source": [
117
+ "%%capture\n",
118
+ "!apt install python-opengl\n",
119
+ "!apt install ffmpeg\n",
120
+ "!apt install xvfb\n",
121
+ "!apt install swig"
122
+ ],
123
+ "metadata": {
124
+ "id": "nn6EETTc2Ewf"
125
+ },
126
+ "execution_count": null,
127
+ "outputs": []
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "source": [
132
+ "%%capture\n",
133
+ "%cd /content/rl-algo-impls\n",
134
+ "!pip install -r colab_requirements.txt"
135
+ ],
136
+ "metadata": {
137
+ "id": "AfZh9rH3yQii"
138
+ },
139
+ "execution_count": null,
140
+ "outputs": []
141
+ },
142
+ {
143
+ "cell_type": "markdown",
144
+ "source": [
145
+ "## Run Once Per Runtime"
146
+ ],
147
+ "metadata": {
148
+ "id": "4o5HOLjc4wq7"
149
+ }
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "source": [
154
+ "import wandb\n",
155
+ "wandb.login()"
156
+ ],
157
+ "metadata": {
158
+ "id": "PCXa5tdS2qFX"
159
+ },
160
+ "execution_count": null,
161
+ "outputs": []
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "source": [
166
+ "## Restart Session beteween runs"
167
+ ],
168
+ "metadata": {
169
+ "id": "AZBZfSUV43JQ"
170
+ }
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "source": [
175
+ "%%capture\n",
176
+ "from pyvirtualdisplay import Display\n",
177
+ "\n",
178
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
179
+ "virtual_display.start()"
180
+ ],
181
+ "metadata": {
182
+ "id": "VzemeQJP2NO9"
183
+ },
184
+ "execution_count": null,
185
+ "outputs": []
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "source": [
190
+ "%cd /content/rl-algo-impls\n",
191
+ "!python train.py --algo {ALGO} --env {ENV} --seed {SEED}"
192
+ ],
193
+ "metadata": {
194
+ "id": "07aHYFH1zfXa"
195
+ },
196
+ "execution_count": null,
197
+ "outputs": []
198
+ }
199
+ ]
200
+ }
compare_runs.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import itertools
3
+ import numpy as np
4
+ import pandas as pd
5
+ import wandb
6
+ import wandb.apis.public
7
+
8
+ from collections import defaultdict
9
+ from dataclasses import dataclass
10
+ from typing import Dict, Iterable, List, TypeVar
11
+
12
+ from benchmark_publish import RunGroup
13
+
14
+
15
@dataclass
class Comparison:
    """Control and experiment samples of a single summary metric.

    Attributes are plain lists of floats; the methods reduce each list to a
    mean or median and report how far the experiment moved from the control.
    """

    # Metric values taken from the control runs.
    control_values: List[float]
    # Metric values taken from the experiment runs.
    experiment_values: List[float]

    def mean_diff_percentage(self) -> float:
        """Return the percentage change of the experiment mean vs. the control mean."""
        return self._diff_percentage(
            np.mean(self.control_values).item(), np.mean(self.experiment_values).item()
        )

    def median_diff_percentage(self) -> float:
        """Return the percentage change of the experiment median vs. the control median."""
        return self._diff_percentage(
            np.median(self.control_values).item(),
            np.median(self.experiment_values).item(),
        )

    def _diff_percentage(self, c: float, e: float) -> float:
        """Return the change from control ``c`` to experiment ``e`` as a percentage of ``|c|``.

        Returns 0.0 when both values are equal and +/-inf when the control is
        exactly zero (the sign follows the sign of ``e``).
        """
        if c == e:
            return 0.0
        if c == 0:
            return float("inf") if e > 0 else float("-inf")
        # Normalize by |c|: dividing by a negative control would flip the sign
        # and report an increase (e.g. -100 -> -50) as a negative change.
        return 100 * (e - c) / abs(c)

    def score(self) -> float:
        """Return the average sign of the mean and median changes.

        The result is in [-1, 1]: 1 when both changed upward, -1 when both
        changed downward, and intermediate values when they disagree or are
        unchanged.
        """
        signs = np.sign((self.mean_diff_percentage(), self.median_diff_percentage()))
        return np.sum(signs).item() / 2
45
+
46
+
47
RunGroupRunsSelf = TypeVar("RunGroupRunsSelf", bound="RunGroupRuns")


class RunGroupRuns:
    """Control and experiment WandB runs belonging to one run group.

    Runs are sorted into the control or experiment bucket by their
    ``wandb_tags`` config entry and can then be compared metric by metric.
    """

    def __init__(
        self,
        run_group: RunGroup,
        control: List[str],
        experiment: List[str],
        summary_stats: Iterable[str] = ("best_eval", "eval", "train_rolling"),
        summary_metrics: Iterable[str] = ("mean", "result"),
    ) -> None:
        """Create an empty group.

        Args:
            run_group: Supplies the ``algo`` and ``env_id`` of this group.
            control: Tags identifying control runs.
            experiment: Tags identifying experiment runs.
            summary_stats: Prefixes of the run-summary keys to compare.
            summary_metrics: Suffixes of the run-summary keys to compare;
                every ``f"{stat}_{metric}"`` combination is compared.
        """
        self.algo = run_group.algo
        self.env = run_group.env_id
        self.control = set(control)
        self.experiment = set(experiment)

        # Copy so that the instance never shares (or mutates) the caller's
        # sequence or the default values.
        self.summary_stats = list(summary_stats)
        self.summary_metrics = list(summary_metrics)

        self.control_runs = []
        self.experiment_runs = []

    def add_run(self, run: wandb.apis.public.Run) -> None:
        """File ``run`` under control or experiment according to its tags.

        Control tags take precedence when a run carries both kinds; a run
        carrying neither is ignored.
        """
        wandb_tags = set(run.config.get("wandb_tags", []))
        if self.control & wandb_tags:
            self.control_runs.append(run)
        elif self.experiment & wandb_tags:
            self.experiment_runs.append(run)

    def comparisons_by_metric(self) -> Dict[str, Comparison]:
        """Return one ``Comparison`` per ``f"{stat}_{metric}"`` summary key."""
        c_by_m = {}
        for stat, metric_suffix in itertools.product(
            self.summary_stats, self.summary_metrics
        ):
            metric = f"{stat}_{metric_suffix}"
            c_by_m[metric] = Comparison(
                [c.summary[metric] for c in self.control_runs],
                [e.summary[metric] for e in self.experiment_runs],
            )
        return c_by_m

    @staticmethod
    def data_frame(rows: Iterable[RunGroupRunsSelf]) -> pd.DataFrame:
        """Build a table with one row per group that has both kinds of runs.

        Groups missing either control or experiment runs are skipped. Each
        row holds the group identity, the average comparison score, and the
        mean/median percentage change of every compared metric.
        """
        results = defaultdict(list)
        for r in rows:
            if not r.control_runs or not r.experiment_runs:
                continue
            results["algo"].append(r.algo)
            results["env"].append(r.env)
            results["control"].append(r.control)
            # Column name was previously misspelled as "expierment".
            results["experiment"].append(r.experiment)
            c_by_m = r.comparisons_by_metric()
            results["score"].append(
                sum(m.score() for m in c_by_m.values()) / len(c_by_m)
            )
            for m, c in c_by_m.items():
                results[f"{m}_mean"].append(c.mean_diff_percentage())
                results[f"{m}_median"].append(c.median_diff_percentage())
        return pd.DataFrame(results)
107
+
108
+
109
if __name__ == "__main__":
    # Compare control vs. experiment benchmark runs stored in a WandB project
    # and print the result as a markdown table.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-p",
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls-benchmarks",
        help="WandB project name to load runs from",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        default=None,
        help="WandB team. None uses default entity",
    )
    parser.add_argument(
        "-n",
        "--wandb-hostname-tag",
        type=str,
        nargs="*",
        help="WandB tags for hostname (i.e. host_192-9-145-26)",
    )
    parser.add_argument(
        "-c",
        "--wandb-control-tag",
        type=str,
        nargs="+",
        help="WandB tag for control commit (i.e. benchmark_5598ebc)",
    )
    parser.add_argument(
        "-e",
        "--wandb-experiment-tag",
        type=str,
        nargs="+",
        help="WandB tag for experiment commit (i.e. benchmark_5540e1f)",
    )
    parser.add_argument(
        "--exclude-envs",
        type=str,
        nargs="*",
        help="Environments to exclude from comparison",
    )
    # parser.set_defaults(
    #     wandb_hostname_tag=["host_150-230-44-105", "host_155-248-214-128"],
    #     wandb_control_tag=["benchmark_fbc943f"],
    #     wandb_experiment_tag=["benchmark_f59bf74"],
    #     exclude_envs=[],
    # )
    args = parser.parse_args()
    print(args)

    # Checked after parsing (rather than with required=True) so that values
    # supplied through parser.set_defaults() are still honored.
    if not args.wandb_control_tag or not args.wandb_experiment_tag:
        parser.error("--wandb-control-tag and --wandb-experiment-tag are required")

    api = wandb.Api()
    all_runs = api.runs(
        path=f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}",
        order="+created_at",
    )

    runs_by_run_group: Dict[RunGroup, RunGroupRuns] = {}
    # argparse leaves the option as None when the flag is omitted.
    wandb_hostname_tags = set(args.wandb_hostname_tag or [])
    excluded_envs = set(args.exclude_envs or [])
    for r in all_runs:
        if r.state != "finished":
            continue
        wandb_tags = set(r.config.get("wandb_tags", []))
        if not wandb_tags:
            continue
        # Only filter by hostname when hostname tags were actually supplied.
        if wandb_hostname_tags and not wandb_hostname_tags & wandb_tags:
            continue
        rg = RunGroup(r.config["algo"], r.config.get("env_id") or r.config["env"])
        if rg.env_id in excluded_envs:
            continue
        if rg not in runs_by_run_group:
            runs_by_run_group[rg] = RunGroupRuns(
                rg,
                args.wandb_control_tag,
                args.wandb_experiment_tag,
            )
        runs_by_run_group[rg].add_run(r)

    df = RunGroupRuns.data_frame(runs_by_run_group.values()).round(decimals=2)
    if df.empty:
        # Without rows the frame has no "score" column to summarize.
        print("No run group has both control and experiment runs; nothing to compare.")
    else:
        print(f"**Total Score: {sum(df['score'])}**")
        df.loc["mean"] = df.mean(numeric_only=True)
        print(df.to_markdown())
dqn/dqn.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from collections import deque
9
+ from torch.optim import Adam
10
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
11
+ from torch.utils.tensorboard.writer import SummaryWriter
12
+ from typing import List, NamedTuple, Optional, TypeVar
13
+
14
+ from dqn.policy import DQNPolicy
15
+ from shared.algorithm import Algorithm
16
+ from shared.callbacks.callback import Callback
17
+ from shared.schedule import linear_schedule
18
+
19
+
20
class Transition(NamedTuple):
    """One environment step of a single environment, as stored in the replay buffer."""

    # Observation the action was taken from.
    obs: np.ndarray
    # Action taken for this environment.
    action: np.ndarray
    # Reward returned by the step.
    reward: float
    # Done flag returned by the step.
    done: bool
    # Observation returned by the step.
    next_obs: np.ndarray
26
+
27
+
28
class Batch(NamedTuple):
    """Sampled transitions, with each field stacked into one array across the sample."""

    # Stacked Transition.obs values.
    obs: np.ndarray
    # Stacked Transition.action values.
    actions: np.ndarray
    # Stacked Transition.reward values.
    rewards: np.ndarray
    # Stacked Transition.done values.
    dones: np.ndarray
    # Stacked Transition.next_obs values.
    next_obs: np.ndarray
34
+
35
+
36
class ReplayBuffer:
    """Bounded FIFO store of per-environment transitions.

    Backed by a ``deque`` with ``maxlen``, so the oldest transitions are
    dropped once the capacity is reached.
    """

    def __init__(self, num_envs: int, maxlen: int) -> None:
        """Create an empty buffer.

        Args:
            num_envs: Number of environments contributing to each ``add`` call.
            maxlen: Maximum number of transitions retained.
        """
        self.num_envs = num_envs
        self.buffer = deque(maxlen=maxlen)

    def add(
        self,
        obs: VecEnvObs,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        next_obs: VecEnvObs,
    ) -> None:
        """Append one transition per environment from a vectorized step.

        The first ``num_envs`` entries of every argument are split into
        individual ``Transition`` records. Observations must be numpy arrays.
        """
        assert isinstance(obs, np.ndarray)
        assert isinstance(next_obs, np.ndarray)
        self.buffer.extend(
            Transition(
                obs[env_idx],
                action[env_idx],
                reward[env_idx],
                done[env_idx],
                next_obs[env_idx],
            )
            for env_idx in range(self.num_envs)
        )

    def sample(self, batch_size: int) -> Batch:
        """Draw ``batch_size`` distinct transitions at random and stack them by field."""
        picked = random.sample(self.buffer, batch_size)

        def column(field: str) -> np.ndarray:
            # Gather one Transition field across all sampled transitions.
            return np.array([getattr(transition, field) for transition in picked])

        return Batch(
            obs=column("obs"),
            actions=column("action"),
            rewards=column("reward"),
            dones=column("done"),
            next_obs=column("next_obs"),
        )

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.buffer)
68
+
69
+
70
DQNSelf = TypeVar("DQNSelf", bound="DQN")


class DQN(Algorithm):
    """DQN trainer: replay-buffer sampling, a target Q-network, and a scheduled epsilon."""

    def __init__(
        self,
        policy: DQNPolicy,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        learning_rate: float = 1e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 50_000,
        batch_size: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: int = 4,
        gradient_steps: int = 1,
        target_update_interval: int = 10_000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10.0,
    ) -> None:
        """Set up the optimizer, target network, replay buffer and schedules.

        Args:
            policy: Policy whose ``q_net`` is optimized.
            env: Vectorized environment to collect transitions from.
            device: Device for the target network and training tensors.
            tb_writer: Forwarded to the ``Algorithm`` base class.
            learning_rate: Adam learning rate.
            buffer_size: Capacity of the replay buffer, in transitions.
            learning_starts: Timesteps collected with ``eps=1`` before learning.
            batch_size: Transitions sampled per gradient step.
            tau: Weight of the online parameters in a target update
                (1.0 copies them outright).
            gamma: Discount factor applied to the bootstrapped value.
            train_freq: Timesteps collected between training phases.
            gradient_steps: Gradient steps per training phase; a non-positive
                value means ``train_freq`` steps.
            target_update_interval: Timesteps between target-network updates.
            exploration_fraction: Passed to ``linear_schedule`` as
                ``end_fraction``.
            exploration_initial_eps: Starting value given to the eps schedule.
            exploration_final_eps: Ending value given to the eps schedule.
            max_grad_norm: Gradient-norm clipping threshold; a falsy value
                disables clipping.
        """
        super().__init__(policy, env, device, tb_writer)
        self.policy = policy

        # Plain hyperparameters.
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.learning_starts = learning_starts
        self.train_freq = train_freq
        self.gradient_steps = gradient_steps
        self.target_update_interval = target_update_interval
        self.max_grad_norm = max_grad_norm

        self.optimizer = Adam(self.policy.q_net.parameters(), lr=learning_rate)

        # Target network: a copy of the online network, kept in eval mode.
        self.target_q_net = copy.deepcopy(self.policy.q_net).to(self.device)
        self.target_q_net.train(False)

        self.replay_buffer = ReplayBuffer(self.env.num_envs, buffer_size)

        self.exploration_eps_schedule = linear_schedule(
            exploration_initial_eps,
            exploration_final_eps,
            end_fraction=exploration_fraction,
        )

    def learn(
        self: DQNSelf, total_timesteps: int, callback: Optional[Callback] = None
    ) -> DQNSelf:
        """Run the training loop and return ``self``.

        First collects ``learning_starts`` timesteps with ``eps=1``, then
        alternates between collecting ``train_freq`` timesteps and running
        gradient steps until ``total_timesteps - learning_starts`` timesteps
        have been counted. ``callback.on_step`` is invoked once per iteration
        with the timesteps of that iteration.
        """
        self.policy.train(True)
        obs = self._collect_rollout(self.learning_starts, self.env.reset(), 1)

        learning_steps = total_timesteps - self.learning_starts
        timesteps_elapsed = 0
        steps_since_sync = 0
        while timesteps_elapsed < learning_steps:
            # Epsilon follows the schedule as a function of learning progress.
            eps = self.exploration_eps_schedule(timesteps_elapsed / learning_steps)
            obs = self._collect_rollout(self.train_freq, obs, eps)
            rollout_steps = self.train_freq
            timesteps_elapsed += rollout_steps

            n_updates = self.gradient_steps
            if n_updates <= 0:
                n_updates = self.train_freq
            for _ in range(n_updates):
                self.train()

            steps_since_sync += rollout_steps
            if steps_since_sync >= self.target_update_interval:
                self._update_target()
                steps_since_sync = 0

            if callback:
                callback.on_step(timesteps_elapsed=rollout_steps)
        return self

    def train(self) -> None:
        """Perform one gradient step on a sampled batch.

        Does nothing while the replay buffer holds fewer than ``batch_size``
        transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return

        batch = self.replay_buffer.sample(self.batch_size)
        obs = torch.as_tensor(batch.obs, device=self.device)
        # Extra trailing dimension so the actions can index the Q-values via gather.
        actions = torch.as_tensor(batch.actions, device=self.device).unsqueeze(1)
        rewards = torch.as_tensor(
            batch.rewards, dtype=torch.float32, device=self.device
        )
        dones = torch.as_tensor(batch.dones, dtype=torch.long, device=self.device)
        next_obs = torch.as_tensor(batch.next_obs, device=self.device)

        with torch.no_grad():
            next_q = self.target_q_net(next_obs).max(1).values
            # (1 - dones) removes the bootstrapped term where done is set.
            target = rewards + (1 - dones) * self.gamma * next_q
        current = self.policy.q_net(obs).gather(dim=1, index=actions).squeeze(1)
        loss = F.smooth_l1_loss(current, target)

        self.optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm:
            nn.utils.clip_grad_norm_(self.policy.q_net.parameters(), self.max_grad_norm)
        self.optimizer.step()

    def _collect_rollout(self, timesteps: int, obs: VecEnvObs, eps: float) -> VecEnvObs:
        """Step the environments and store the transitions in the replay buffer.

        Each vectorized step is counted as ``num_envs`` timesteps. Returns the
        observation following the last step (or ``obs`` unchanged when no step
        is taken).
        """
        for _ in range(0, timesteps, self.env.num_envs):
            action = self.policy.act(obs, eps, deterministic=False)
            step_obs, reward, done, _ = self.env.step(action)
            self.replay_buffer.add(obs, action, reward, done, step_obs)
            obs = step_obs
        return obs

    def _update_target(self) -> None:
        """Blend the online parameters into the target network with weight ``tau``."""
        param_pairs = zip(
            self.target_q_net.parameters(), self.policy.q_net.parameters()
        )
        for target_param, param in param_pairs:
            blended = self.tau * param.data + (1 - self.tau) * target_param.data
            target_param.data.copy_(blended)
+ )
dqn/policy.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import torch
4
+
5
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
6
+ from typing import Optional, Sequence, TypeVar
7
+
8
+ from dqn.q_net import QNetwork
9
+ from shared.policy.policy import Policy
10
+
11
DQNPolicySelf = TypeVar("DQNPolicySelf", bound="DQNPolicy")


class DQNPolicy(Policy):
    """Epsilon-greedy policy that picks actions from a single Q-network."""

    def __init__(
        self,
        env: VecEnv,
        hidden_sizes: Sequence[int] = (),
        cnn_feature_dim: int = 512,
        cnn_style: str = "nature",
        cnn_layers_init_orthogonal: Optional[bool] = None,
        impala_channels: Sequence[int] = (16, 32, 32),
        **kwargs,
    ) -> None:
        """Build the Q-network for ``env``'s observation and action spaces.

        Args:
            env: Vectorized environment supplying the spaces.
            hidden_sizes: Forwarded to ``QNetwork``. Defaults to an immutable
                empty tuple rather than a shared mutable list.
            cnn_feature_dim: Forwarded to ``QNetwork``.
            cnn_style: Forwarded to ``QNetwork``.
            cnn_layers_init_orthogonal: Forwarded to ``QNetwork``.
            impala_channels: Forwarded to ``QNetwork``.
            **kwargs: Forwarded to the ``Policy`` base class.
        """
        super().__init__(env, **kwargs)
        # NOTE(review): QNetwork.__init__ must accept `impala_channels` for
        # this call to succeed - confirm its signature matches.
        self.q_net = QNetwork(
            env.observation_space,
            env.action_space,
            hidden_sizes,
            cnn_feature_dim=cnn_feature_dim,
            cnn_style=cnn_style,
            cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
            impala_channels=impala_channels,
        )

    def act(
        self, obs: VecEnvObs, eps: float = 0, deterministic: bool = True
    ) -> np.ndarray:
        """Return one action per environment.

        With ``deterministic=False``, a single random draw decides for the
        whole call: with probability ``eps`` every environment gets an action
        sampled from the action space, otherwise every environment gets the
        argmax of the Q-network output.

        Args:
            obs: Current observations.
            eps: Exploration probability; must be 0 when ``deterministic``.
            deterministic: When True, always act greedily.
        """
        assert eps == 0 if deterministic else eps >= 0
        if not deterministic and np.random.random() < eps:
            return np.array(
                [self.env.action_space.sample() for _ in range(self.env.num_envs)]
            )
        else:
            o = self._as_tensor(obs)
            with torch.no_grad():
                return self.q_net(o).argmax(axis=1).cpu().numpy()
dqn/q_net.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import torch as th
3
+ import torch.nn as nn
4
+
5
+ from gym.spaces import Discrete
6
+ from typing import Optional, Sequence, Type
7
+
8
+ from shared.module.feature_extractor import FeatureExtractor
9
+ from shared.module.module import mlp
10
+
11
+
12
class QNetwork(nn.Module):
    """Q-value network: a feature extractor followed by an MLP head.

    The head has one output per discrete action.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        hidden_sizes: Sequence[int] = (),
        activation: Type[nn.Module] = nn.ReLU,  # Used by stable-baselines3
        cnn_feature_dim: int = 512,
        cnn_style: str = "nature",
        cnn_layers_init_orthogonal: Optional[bool] = None,
        impala_channels: Sequence[int] = (16, 32, 32),
    ) -> None:
        """Build the feature extractor and the fully connected head.

        Args:
            observation_space: Observation space handed to the feature extractor.
            action_space: Must be ``Discrete``; ``n`` sets the output size.
            hidden_sizes: Hidden layer widths of the head.
            activation: Activation module class used by both parts.
            cnn_feature_dim: Forwarded to ``FeatureExtractor``.
            cnn_style: Forwarded to ``FeatureExtractor``.
            cnn_layers_init_orthogonal: Forwarded to ``FeatureExtractor``.
            impala_channels: Forwarded to ``FeatureExtractor``. Added because
                ``DQNPolicy`` passes this keyword, which previously raised a
                ``TypeError`` here.
        """
        super().__init__()
        assert isinstance(action_space, Discrete)
        # NOTE(review): assumes FeatureExtractor accepts `impala_channels`
        # (DQNPolicy forwards it, so the shared module presumably does) -
        # confirm against shared/module/feature_extractor.py.
        self._feature_extractor = FeatureExtractor(
            observation_space,
            activation,
            cnn_feature_dim=cnn_feature_dim,
            cnn_style=cnn_style,
            cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
            impala_channels=impala_channels,
        )
        layer_sizes = (
            (self._feature_extractor.out_dim,) + tuple(hidden_sizes) + (action_space.n,)
        )
        self._fc = mlp(layer_sizes, activation)

    def forward(self, obs: th.Tensor) -> th.Tensor:
        """Return the head's output for the extracted features of ``obs``."""
        x = self._feature_extractor(obs)
        return self._fc(x)
enjoy.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
2
+ import os
3
+
4
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
+
6
+ from runner.evaluate import EvalArgs, evaluate_model
7
+ from runner.running_utils import base_parser
8
+
9
+
10
def parse_bool(value):
    """Parse a command-line string into a bool.

    ``type=bool`` treats every non-empty string (including "False") as True.
    This accepts the usual spellings instead; the empty string stays False so
    the previous way of disabling a flag keeps working.

    Raises:
        ValueError: If ``value`` is not a recognized boolean spelling
            (argparse reports this as an invalid argument value).
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = base_parser(multiple=False)
    parser.add_argument("--render", default=True, type=parse_bool)
    parser.add_argument("--best", default=True, type=parse_bool)
    parser.add_argument("--n_envs", default=1, type=int)
    parser.add_argument("--n_episodes", default=3, type=int)
    parser.add_argument("--deterministic-eval", default=None, type=parse_bool)
    parser.add_argument(
        "--no-print-returns", action="store_true", help="Limit printing"
    )
    # wandb-run-path overrides base RunArgs
    parser.add_argument("--wandb-run-path", default=None, type=str)
    parser.set_defaults(
        algo=["ppo"],
    )
    args = parser.parse_args()
    # The parsed algo/env values are lists; EvalArgs is given the first entry.
    args.algo = args.algo[0]
    args.env = args.env[0]
    args = EvalArgs(**vars(args))

    evaluate_model(args, os.path.dirname(__file__))
environment.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: rl_algo_impls
2
+ channels:
3
+ - pytorch
4
+ - conda-forge
5
+ - nodefaults
6
+ dependencies:
7
+ - python=3.10.*
8
+ - mamba
9
+ - pip
10
+ - poetry
11
+ - pytorch
12
+ - torchvision
13
+ - torchaudio
14
+ - cmake
15
+ - swig
16
+ - ipywidgets
17
+ - black
hf-deep-rl/dqn_SpaceInvadersNoFrameskip_v4.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
huggingface_publish.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
4
+
5
+ import argparse
6
+ import requests
7
+ import shutil
8
+ import subprocess
9
+ import tempfile
10
+ import wandb
11
+ import wandb.apis.public
12
+
13
+ from typing import List, Optional
14
+
15
+ from huggingface_hub.hf_api import HfApi, upload_folder
16
+ from huggingface_hub.repocard import metadata_save
17
+ from pyvirtualdisplay.display import Display
18
+
19
+ from publish.markdown_format import EvalTableData, model_card_text
20
+ from runner.config import EnvHyperparams
21
+ from runner.evaluate import EvalArgs, evaluate_model
22
+ from runner.env import make_eval_env
23
+ from shared.callbacks.eval_callback import evaluate
24
+ from wrappers.vec_episode_recorder import VecEpisodeRecorder
25
+
26
+
27
def publish(
    wandb_run_paths: List[str],
    wandb_report_url: str,
    huggingface_user: Optional[str] = None,
    huggingface_token: Optional[str] = None,
    virtual_display: bool = False,
) -> None:
    """Evaluate WandB runs and upload the best model to the Hugging Face Hub.

    Args:
        wandb_run_paths: Run paths of the form ``entity/project/run_id``. All
            runs are evaluated; algo, env and metadata come from the first.
        wandb_report_url: Link to the WandB report, passed to the model card.
        huggingface_user: User or team owning the repo; defaults to the
            account returned by ``whoami``.
        huggingface_token: Token used for every Hub call.
        virtual_display: Start a headless virtual display for the duration
            of the call.
    """
    display = None
    if virtual_display:
        display = Display(visible=False, size=(1400, 900))
        display.start()
    try:
        _publish_runs(
            wandb_run_paths, wandb_report_url, huggingface_user, huggingface_token
        )
    finally:
        # Stop the virtual display even when publishing fails.
        if display is not None:
            display.stop()


def _publish_runs(
    wandb_run_paths: List[str],
    wandb_report_url: str,
    huggingface_user: Optional[str],
    huggingface_token: Optional[str],
) -> None:
    """Evaluate the runs, assemble the repo in a temp directory, and upload it."""
    wandb_api = wandb.Api()
    runs = [wandb_api.run(rp) for rp in wandb_run_paths]
    algo = runs[0].config["algo"]
    hyperparam_id = runs[0].config["env"]
    evaluations = [
        evaluate_model(
            EvalArgs(
                algo,
                hyperparam_id,
                seed=r.config.get("seed", None),
                render=False,
                best=True,
                n_envs=None,
                n_episodes=10,
                no_print_returns=True,
                wandb_run_path="/".join(r.path),
            ),
            os.path.dirname(__file__),
        )
        for r in runs
    ]
    run_metadata = requests.get(runs[0].file("wandb-metadata.json").url).json()
    table_data = [EvalTableData(r, e) for r, e in zip(runs, evaluations)]
    # max() returns the first of tied entries, like the previous
    # sorted(..., reverse=True)[0].
    best_eval = max(table_data, key=lambda d: d.evaluation.stats.score)

    with tempfile.TemporaryDirectory() as tmpdirname:
        _, (policy, stats, config) = best_eval

        repo_name = config.model_name(include_seed=False)
        repo_dir_path = os.path.join(tmpdirname, repo_name)
        # Locally clone this repo to a temp directory. check=True surfaces a
        # failed clone here instead of at a later, less obvious step.
        subprocess.run(["git", "clone", ".", repo_dir_path], check=True)
        shutil.rmtree(os.path.join(repo_dir_path, ".git"))
        model_path = config.model_dir_path(best=True, downloaded=True)
        shutil.copytree(
            model_path,
            os.path.join(
                repo_dir_path, "saved_models", config.model_dir_name(best=True)
            ),
        )

        github_url = "https://github.com/sgoodfriend/rl-algo-impls"
        commit_hash = run_metadata.get("git", {}).get("commit", None)
        env_id = runs[0].config.get("env_id") or runs[0].config["env"]
        card_text = model_card_text(
            algo,
            env_id,
            github_url,
            commit_hash,
            wandb_report_url,
            table_data,
            best_eval,
        )
        readme_filepath = os.path.join(repo_dir_path, "README.md")
        # Mode "w" replaces the cloned README's content, so no separate
        # os.remove() (which fails when the file is absent) is needed.
        with open(readme_filepath, "w") as f:
            f.write(card_text)

        metadata = {
            "library_name": "rl-algo-impls",
            "tags": [
                env_id,
                algo,
                "deep-reinforcement-learning",
                "reinforcement-learning",
            ],
            "model-index": [
                {
                    "name": algo,
                    "results": [
                        {
                            "metrics": [
                                {
                                    "type": "mean_reward",
                                    "value": str(stats.score),
                                    "name": "mean_reward",
                                }
                            ],
                            "task": {
                                "type": "reinforcement-learning",
                                "name": "reinforcement-learning",
                            },
                            "dataset": {
                                "name": env_id,
                                "type": env_id,
                            },
                        }
                    ],
                }
            ],
        }
        metadata_save(readme_filepath, metadata)

        # Record a single replay episode into the repo directory.
        video_env = VecEpisodeRecorder(
            make_eval_env(
                config,
                EnvHyperparams(**config.env_hyperparams),
                override_n_envs=1,
                normalize_load_path=model_path,
            ),
            os.path.join(repo_dir_path, "replay"),
            max_video_length=3600,
        )
        evaluate(
            video_env,
            policy,
            1,
            deterministic=config.eval_params.get("deterministic", True),
        )

        hf_api = HfApi()
        # Resolve the account with the same token used for the upload.
        huggingface_user = (
            huggingface_user or hf_api.whoami(token=huggingface_token)["name"]
        )
        huggingface_repo = f"{huggingface_user}/{repo_name}"
        hf_api.create_repo(
            token=huggingface_token,
            repo_id=huggingface_repo,
            private=False,
            exist_ok=True,
        )
        repo_url = upload_folder(
            repo_id=huggingface_repo,
            folder_path=repo_dir_path,
            path_in_repo="",
            commit_message=f"{algo.upper()} playing {env_id} from {github_url}/tree/{commit_hash}",
            token=huggingface_token,
        )
        print(f"Pushed model to the hub: {repo_url}")
167
+
168
+
169
if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them to publish().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wandb-run-paths",
        type=str,
        nargs="+",
        # publish() iterates over the paths; without them it fails with a TypeError.
        required=True,
        help="Run paths of the form entity/project/run_id",
    )
    parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
    parser.add_argument(
        "--huggingface-user",
        type=str,
        help="Huggingface user or team to upload model cards",
        default=None,
    )
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    args = parser.parse_args()
    print(args)
    publish(**vars(args))
hyperparams/dqn.yml ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CartPole-v1: &cartpole-defaults
2
+ n_timesteps: !!float 5e4
3
+ env_hyperparams:
4
+ rolling_length: 50
5
+ policy_hyperparams:
6
+ hidden_sizes: [256, 256]
7
+ algo_hyperparams:
8
+ learning_rate: !!float 2.3e-3
9
+ batch_size: 64
10
+ buffer_size: 100000
11
+ learning_starts: 1000
12
+ gamma: 0.99
13
+ target_update_interval: 10
14
+ train_freq: 256
15
+ gradient_steps: 128
16
+ exploration_fraction: 0.16
17
+ exploration_final_eps: 0.04
18
+ eval_params:
19
+ step_freq: !!float 1e4
20
+
21
+ CartPole-v0:
22
+ <<: *cartpole-defaults
23
+ n_timesteps: !!float 4e4
24
+
25
+ MountainCar-v0:
26
+ n_timesteps: !!float 1.2e5
27
+ env_hyperparams:
28
+ rolling_length: 50
29
+ policy_hyperparams:
30
+ hidden_sizes: [256, 256]
31
+ algo_hyperparams:
32
+ learning_rate: !!float 4e-3
33
+ batch_size: 128
34
+ buffer_size: 10000
35
+ learning_starts: 1000
36
+ gamma: 0.98
37
+ target_update_interval: 600
38
+ train_freq: 16
39
+ gradient_steps: 8
40
+ exploration_fraction: 0.2
41
+ exploration_final_eps: 0.07
42
+
43
+ Acrobot-v1:
44
+ n_timesteps: !!float 1e5
45
+ env_hyperparams:
46
+ rolling_length: 50
47
+ policy_hyperparams:
48
+ hidden_sizes: [256, 256]
49
+ algo_hyperparams:
50
+ learning_rate: !!float 6.3e-4
51
+ batch_size: 128
52
+ buffer_size: 50000
53
+ learning_starts: 0
54
+ gamma: 0.99
55
+ target_update_interval: 250
56
+ train_freq: 4
57
+ gradient_steps: -1
58
+ exploration_fraction: 0.12
59
+ exploration_final_eps: 0.1
60
+
61
+ LunarLander-v2:
62
+ n_timesteps: !!float 5e5
63
+ env_hyperparams:
64
+ rolling_length: 50
65
+ policy_hyperparams:
66
+ hidden_sizes: [256, 256]
67
+ algo_hyperparams:
68
+ learning_rate: !!float 1e-4
69
+ batch_size: 256
70
+ buffer_size: 100000
71
+ learning_starts: 10000
72
+ gamma: 0.99
73
+ target_update_interval: 250
74
+ train_freq: 8
75
+ gradient_steps: -1
76
+ exploration_fraction: 0.12
77
+ exploration_final_eps: 0.1
78
+ max_grad_norm: 0.5
79
+ eval_params:
80
+ step_freq: 25_000
81
+
82
+ _atari: &atari-defaults
83
+ n_timesteps: !!float 1e7
84
+ env_hyperparams:
85
+ frame_stack: 4
86
+ no_reward_timeout_steps: 1_000
87
+ no_reward_fire_steps: 500
88
+ n_envs: 8
89
+ vec_env_class: "subproc"
90
+ algo_hyperparams:
91
+ buffer_size: 100000
92
+ learning_rate: !!float 1e-4
93
+ batch_size: 32
94
+ learning_starts: 100000
95
+ target_update_interval: 1000
96
+ train_freq: 8
97
+ gradient_steps: 2
98
+ exploration_fraction: 0.1
99
+ exploration_final_eps: 0.01
100
+ eval_params:
101
+ deterministic: false
102
+
103
+ PongNoFrameskip-v4:
104
+ <<: *atari-defaults
105
+ n_timesteps: !!float 2.5e6
106
+
107
+ _impala-atari: &impala-atari-defaults
108
+ <<: *atari-defaults
109
+ policy_hyperparams:
110
+ cnn_style: impala
111
+ cnn_feature_dim: 256
112
+ init_layers_orthogonal: true
113
+ cnn_layers_init_orthogonal: false
114
+
115
+ impala-PongNoFrameskip-v4:
116
+ <<: *impala-atari-defaults
117
+ env_id: PongNoFrameskip-v4
118
+ n_timesteps: !!float 2.5e6
119
+
120
+ impala-BreakoutNoFrameskip-v4:
121
+ <<: *impala-atari-defaults
122
+ env_id: BreakoutNoFrameskip-v4
123
+
124
+ impala-SpaceInvadersNoFrameskip-v4:
125
+ <<: *impala-atari-defaults
126
+ env_id: SpaceInvadersNoFrameskip-v4
127
+
128
+ impala-QbertNoFrameskip-v4:
129
+ <<: *impala-atari-defaults
130
+ env_id: QbertNoFrameskip-v4
hyperparams/ppo.yml ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CartPole-v1: &cartpole-defaults
2
+ n_timesteps: !!float 1e5
3
+ env_hyperparams:
4
+ n_envs: 8
5
+ algo_hyperparams:
6
+ n_steps: 32
7
+ batch_size: 256
8
+ n_epochs: 20
9
+ gae_lambda: 0.8
10
+ gamma: 0.98
11
+ ent_coef: 0.0
12
+ learning_rate: 0.001
13
+ learning_rate_decay: linear
14
+ clip_range: 0.2
15
+ clip_range_decay: linear
16
+ eval_params:
17
+ step_freq: !!float 2.5e4
18
+
19
+ CartPole-v0:
20
+ <<: *cartpole-defaults
21
+ n_timesteps: !!float 5e4
22
+
23
+ MountainCar-v0:
24
+ n_timesteps: !!float 1e6
25
+ env_hyperparams:
26
+ normalize: true
27
+ n_envs: 16
28
+ algo_hyperparams:
29
+ n_steps: 16
30
+ n_epochs: 4
31
+ gae_lambda: 0.98
32
+ gamma: 0.99
33
+ ent_coef: 0.0
34
+
35
+ MountainCarContinuous-v0:
36
+ n_timesteps: !!float 1e5
37
+ env_hyperparams:
38
+ normalize: true
39
+ n_envs: 4
40
+ # policy_hyperparams:
41
+ # init_layers_orthogonal: false
42
+ # log_std_init: -3.29
43
+ # use_sde: true
44
+ algo_hyperparams:
45
+ n_steps: 512
46
+ batch_size: 256
47
+ n_epochs: 10
48
+ learning_rate: !!float 7.77e-5
49
+ ent_coef: 0.01 # 0.00429
50
+ ent_coef_decay: linear
51
+ clip_range: 0.1
52
+ gae_lambda: 0.9
53
+ max_grad_norm: 5
54
+ vf_coef: 0.19
55
+ eval_params:
56
+ step_freq: 5000
57
+
58
+ Acrobot-v1:
59
+ n_timesteps: !!float 1e6
60
+ env_hyperparams:
61
+ n_envs: 16
62
+ normalize: true
63
+ algo_hyperparams:
64
+ n_steps: 256
65
+ n_epochs: 4
66
+ gae_lambda: 0.94
67
+ gamma: 0.99
68
+ ent_coef: 0.0
69
+
70
+ LunarLander-v2:
71
+ n_timesteps: !!float 1e6
72
+ env_hyperparams:
73
+ n_envs: 16
74
+ algo_hyperparams:
75
+ n_steps: 1024
76
+ batch_size: 64
77
+ n_epochs: 4
78
+ gae_lambda: 0.98
79
+ gamma: 0.999
80
+ ent_coef: 0.01
81
+ ent_coef_decay: linear
82
+ normalize_advantage: false
83
+
84
+ BipedalWalker-v3:
85
+ n_timesteps: !!float 10e6
86
+ env_hyperparams:
87
+ n_envs: 16
88
+ normalize: true
89
+ algo_hyperparams:
90
+ n_steps: 2048
91
+ batch_size: 64
92
+ gae_lambda: 0.95
93
+ gamma: 0.99
94
+ n_epochs: 10
95
+ ent_coef: 0.001
96
+ learning_rate: !!float 2.5e-4
97
+ learning_rate_decay: linear
98
+ clip_range: 0.2
99
+ clip_range_decay: linear
100
+
101
+ CarRacing-v0: &carracing-defaults
102
+ n_timesteps: !!float 4e6
103
+ env_hyperparams:
104
+ n_envs: 8
105
+ frame_stack: 4
106
+ policy_hyperparams: &carracing-policy-defaults
107
+ use_sde: true
108
+ log_std_init: -2
109
+ init_layers_orthogonal: false
110
+ activation_fn: relu
111
+ share_features_extractor: false
112
+ cnn_feature_dim: 256
113
+ hidden_sizes: [256]
114
+ algo_hyperparams:
115
+ n_steps: 512
116
+ batch_size: 128
117
+ n_epochs: 10
118
+ learning_rate: !!float 1e-4
119
+ learning_rate_decay: linear
120
+ gamma: 0.99
121
+ gae_lambda: 0.95
122
+ ent_coef: 0.0
123
+ sde_sample_freq: 4
124
+ max_grad_norm: 0.5
125
+ vf_coef: 0.5
126
+ clip_range: 0.2
127
+
128
+ impala-CarRacing-v0:
129
+ <<: *carracing-defaults
130
+ env_id: CarRacing-v0
131
+ policy_hyperparams:
132
+ <<: *carracing-policy-defaults
133
+ cnn_style: impala
134
+ init_layers_orthogonal: true
135
+ cnn_layers_init_orthogonal: false
136
+ hidden_sizes: []
137
+
138
+ # BreakoutNoFrameskip-v4
139
+ # PongNoFrameskip-v4
140
+ # SpaceInvadersNoFrameskip-v4
141
+ # QbertNoFrameskip-v4
142
+ _atari: &atari-defaults
143
+ n_timesteps: !!float 1e7
144
+ env_hyperparams: &atari-env-defaults
145
+ n_envs: 8
146
+ frame_stack: 4
147
+ no_reward_timeout_steps: 1000
148
+ no_reward_fire_steps: 500
149
+ vec_env_class: subproc
150
+ policy_hyperparams: &atari-policy-defaults
151
+ activation_fn: relu
152
+ algo_hyperparams:
153
+ n_steps: 128
154
+ batch_size: 256
155
+ n_epochs: 4
156
+ learning_rate: !!float 2.5e-4
157
+ learning_rate_decay: linear
158
+ clip_range: 0.1
159
+ clip_range_decay: linear
160
+ vf_coef: 0.5
161
+ ent_coef: 0.01
162
+ eval_params:
163
+ deterministic: false
164
+
165
+ _norm-rewards-atari: &norm-rewards-atari-default
166
+ <<: *atari-defaults
167
+ env_hyperparams:
168
+ <<: *atari-env-defaults
169
+ clip_atari_rewards: false
170
+ normalize: true
171
+ normalize_kwargs:
172
+ norm_obs: false
173
+ norm_reward: true
174
+
175
+ norm-rewards-BreakoutNoFrameskip-v4:
176
+ <<: *norm-rewards-atari-default
177
+ env_id: BreakoutNoFrameskip-v4
178
+
179
+ debug-PongNoFrameskip-v4:
180
+ <<: *atari-defaults
181
+ device: cpu
182
+ env_id: PongNoFrameskip-v4
183
+ env_hyperparams:
184
+ <<: *atari-env-defaults
185
+ vec_env_class: dummy
186
+
187
+ _impala-atari: &impala-atari-defaults
188
+ <<: *atari-defaults
189
+ policy_hyperparams:
190
+ <<: *atari-policy-defaults
191
+ cnn_style: impala
192
+ cnn_feature_dim: 256
193
+ init_layers_orthogonal: true
194
+ cnn_layers_init_orthogonal: false
195
+
196
+ impala-PongNoFrameskip-v4:
197
+ <<: *impala-atari-defaults
198
+ env_id: PongNoFrameskip-v4
199
+
200
+ impala-BreakoutNoFrameskip-v4:
201
+ <<: *impala-atari-defaults
202
+ env_id: BreakoutNoFrameskip-v4
203
+
204
+ impala-SpaceInvadersNoFrameskip-v4:
205
+ <<: *impala-atari-defaults
206
+ env_id: SpaceInvadersNoFrameskip-v4
207
+
208
+ impala-QbertNoFrameskip-v4:
209
+ <<: *impala-atari-defaults
210
+ env_id: QbertNoFrameskip-v4
211
+
212
+ HalfCheetahBulletEnv-v0: &pybullet-defaults
213
+ n_timesteps: !!float 2e6
214
+ env_hyperparams: &pybullet-env-defaults
215
+ n_envs: 16
216
+ normalize: true
217
+ policy_hyperparams: &pybullet-policy-defaults
218
+ pi_hidden_sizes: [256, 256]
219
+ v_hidden_sizes: [256, 256]
220
+ activation_fn: relu
221
+ algo_hyperparams: &pybullet-algo-defaults
222
+ n_steps: 512
223
+ batch_size: 128
224
+ n_epochs: 20
225
+ gamma: 0.99
226
+ gae_lambda: 0.9
227
+ ent_coef: 0.0
228
+ max_grad_norm: 0.5
229
+ vf_coef: 0.5
230
+ learning_rate: !!float 3e-5
231
+ clip_range: 0.4
232
+
233
+ AntBulletEnv-v0:
234
+ <<: *pybullet-defaults
235
+ policy_hyperparams:
236
+ <<: *pybullet-policy-defaults
237
+ algo_hyperparams:
238
+ <<: *pybullet-algo-defaults
239
+
240
+ Walker2DBulletEnv-v0:
241
+ <<: *pybullet-defaults
242
+ algo_hyperparams:
243
+ <<: *pybullet-algo-defaults
244
+ clip_range_decay: linear
245
+
246
+ HopperBulletEnv-v0:
247
+ <<: *pybullet-defaults
248
+ algo_hyperparams:
249
+ <<: *pybullet-algo-defaults
250
+ clip_range_decay: linear
251
+
252
+ HumanoidBulletEnv-v0:
253
+ <<: *pybullet-defaults
254
+ n_timesteps: !!float 1e7
255
+ env_hyperparams:
256
+ <<: *pybullet-env-defaults
257
+ n_envs: 8
258
+ policy_hyperparams:
259
+ <<: *pybullet-policy-defaults
260
+ # log_std_init: -1
261
+ algo_hyperparams:
262
+ <<: *pybullet-algo-defaults
263
+ n_steps: 2048
264
+ batch_size: 64
265
+ n_epochs: 10
266
+ gae_lambda: 0.95
267
+ learning_rate: !!float 2.5e-4
268
+ clip_range: 0.2
269
+
270
+ _procgen: &procgen-defaults
271
+ env_hyperparams: &procgen-env-defaults
272
+ is_procgen: true
273
+ n_envs: 64
274
+ # grayscale: false
275
+ # frame_stack: 4
276
+ normalize: true # procgen only normalizes reward
277
+ make_kwargs: &procgen-make-kwargs-defaults
278
+ num_threads: 8
279
+ policy_hyperparams: &procgen-policy-defaults
280
+ activation_fn: relu
281
+ cnn_style: impala
282
+ cnn_feature_dim: 256
283
+ init_layers_orthogonal: true
284
+ cnn_layers_init_orthogonal: false
285
+ algo_hyperparams: &procgen-algo-defaults
286
+ gamma: 0.999
287
+ gae_lambda: 0.95
288
+ n_steps: 256
289
+ batch_size: 2048
290
+ n_epochs: 3
291
+ ent_coef: 0.01
292
+ clip_range: 0.2
293
+ # clip_range_decay: linear
294
+ clip_range_vf: 0.2
295
+ learning_rate: !!float 5e-4
296
+ # learning_rate_decay: linear
297
+ vf_coef: 0.5
298
+ eval_params: &procgen-eval-defaults
299
+ ignore_first_episode: true
300
+ # deterministic: false
301
+ step_freq: !!float 1e5
302
+
303
+ _procgen-easy: &procgen-easy-defaults
304
+ <<: *procgen-defaults
305
+ n_timesteps: !!float 25e6
306
+ env_hyperparams: &procgen-easy-env-defaults
307
+ <<: *procgen-env-defaults
308
+ make_kwargs:
309
+ <<: *procgen-make-kwargs-defaults
310
+ distribution_mode: easy
311
+
312
+ procgen-coinrun-easy: &coinrun-easy-defaults
313
+ <<: *procgen-easy-defaults
314
+ env_id: coinrun
315
+
316
+ debug-procgen-coinrun:
317
+ <<: *coinrun-easy-defaults
318
+ device: cpu
319
+
320
+ procgen-starpilot-easy:
321
+ <<: *procgen-easy-defaults
322
+ env_id: starpilot
323
+
324
+ procgen-bossfight-easy:
325
+ <<: *procgen-easy-defaults
326
+ env_id: bossfight
327
+
328
+ procgen-bigfish-easy:
329
+ <<: *procgen-easy-defaults
330
+ env_id: bigfish
331
+
332
+ _procgen-hard: &procgen-hard-defaults
333
+ <<: *procgen-defaults
334
+ n_timesteps: !!float 200e6
335
+ env_hyperparams: &procgen-hard-env-defaults
336
+ <<: *procgen-env-defaults
337
+ n_envs: 256
338
+ make_kwargs:
339
+ <<: *procgen-make-kwargs-defaults
340
+ distribution_mode: hard
341
+ algo_hyperparams: &procgen-hard-algo-defaults
342
+ <<: *procgen-algo-defaults
343
+ batch_size: 8192
344
+ clip_range_decay: linear
345
+ learning_rate_decay: linear
346
+ eval_params:
347
+ <<: *procgen-eval-defaults
348
+ step_freq: !!float 5e5
349
+
350
+ procgen-starpilot-hard: &procgen-starpilot-hard-defaults
351
+ <<: *procgen-hard-defaults
352
+ env_id: starpilot
353
+
354
+ procgen-starpilot-hard-2xIMPALA:
355
+ <<: *procgen-starpilot-hard-defaults
356
+ policy_hyperparams:
357
+ <<: *procgen-policy-defaults
358
+ impala_channels: [32, 64, 64]
359
+ algo_hyperparams:
360
+ <<: *procgen-hard-algo-defaults
361
+ learning_rate: !!float 3.3e-4
362
+
363
+ procgen-starpilot-hard-2xIMPALA-fat:
364
+ <<: *procgen-starpilot-hard-defaults
365
+ policy_hyperparams:
366
+ <<: *procgen-policy-defaults
367
+ impala_channels: [32, 64, 64]
368
+ cnn_feature_dim: 512
369
+ algo_hyperparams:
370
+ <<: *procgen-hard-algo-defaults
371
+ learning_rate: !!float 2.5e-4
372
+
373
+ procgen-starpilot-hard-4xIMPALA:
374
+ <<: *procgen-starpilot-hard-defaults
375
+ policy_hyperparams:
376
+ <<: *procgen-policy-defaults
377
+ impala_channels: [64, 128, 128]
378
+ algo_hyperparams:
379
+ <<: *procgen-hard-algo-defaults
380
+ learning_rate: !!float 2.1e-4
hyperparams/vpg.yml ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CartPole-v1: &cartpole-defaults
2
+ n_timesteps: !!float 4e5
3
+ algo_hyperparams:
4
+ n_steps: 4096
5
+ pi_lr: 0.01
6
+ gamma: 0.99
7
+ gae_lambda: 1
8
+ val_lr: 0.01
9
+ train_v_iters: 80
10
+ eval_params:
11
+ step_freq: !!float 2.5e4
12
+
13
+ CartPole-v0:
14
+ <<: *cartpole-defaults
15
+ n_timesteps: !!float 1e5
16
+ algo_hyperparams:
17
+ n_steps: 1024
18
+ pi_lr: 0.01
19
+ gamma: 0.99
20
+ gae_lambda: 1
21
+ val_lr: 0.01
22
+ train_v_iters: 80
23
+
24
+ MountainCar-v0:
25
+ n_timesteps: !!float 1e6
26
+ env_hyperparams:
27
+ normalize: true
28
+ n_envs: 16
29
+ algo_hyperparams:
30
+ n_steps: 200
31
+ pi_lr: 0.005
32
+ gamma: 0.99
33
+ gae_lambda: 0.97
34
+ val_lr: 0.01
35
+ train_v_iters: 80
36
+ max_grad_norm: 0.5
37
+
38
+ MountainCarContinuous-v0:
39
+ n_timesteps: !!float 3e5
40
+ env_hyperparams:
41
+ normalize: true
42
+ n_envs: 4
43
+ # policy_hyperparams:
44
+ # init_layers_orthogonal: false
45
+ # log_std_init: -3.29
46
+ # use_sde: true
47
+ algo_hyperparams:
48
+ n_steps: 1000
49
+ pi_lr: !!float 5e-4
50
+ gamma: 0.99
51
+ gae_lambda: 0.9
52
+ val_lr: !!float 1e-3
53
+ train_v_iters: 80
54
+ max_grad_norm: 5
55
+ eval_params:
56
+ step_freq: 5000
57
+
58
+ Acrobot-v1:
59
+ n_timesteps: !!float 2e5
60
+ algo_hyperparams:
61
+ n_steps: 2048
62
+ pi_lr: 0.005
63
+ gamma: 0.99
64
+ gae_lambda: 0.97
65
+ val_lr: 0.01
66
+ train_v_iters: 80
67
+ max_grad_norm: 0.5
68
+
69
+ LunarLander-v2:
70
+ n_timesteps: !!float 4e6
71
+ policy_hyperparams:
72
+ hidden_sizes: [256, 256]
73
+ algo_hyperparams:
74
+ n_steps: 2048
75
+ pi_lr: 0.0001
76
+ gamma: 0.999
77
+ gae_lambda: 0.97
78
+ val_lr: 0.0001
79
+ train_v_iters: 80
80
+ max_grad_norm: 0.5
81
+ eval_params:
82
+ deterministic: false
83
+
84
+ BipedalWalker-v3:
85
+ n_timesteps: !!float 10e6
86
+ env_hyperparams:
87
+ n_envs: 16
88
+ normalize: true
89
+ policy_hyperparams:
90
+ hidden_sizes: [256, 256]
91
+ algo_hyperparams:
92
+ n_steps: 1600
93
+ gae_lambda: 0.95
94
+ gamma: 0.99
95
+ pi_lr: !!float 1e-4
96
+ val_lr: !!float 1e-4
97
+ train_v_iters: 80
98
+ max_grad_norm: 0.5
99
+ eval_params:
100
+ deterministic: false
101
+
102
+ CarRacing-v0:
103
+ n_timesteps: !!float 4e6
104
+ env_hyperparams:
105
+ frame_stack: 4
106
+ n_envs: 4
107
+ vec_env_class: "dummy"
108
+ policy_hyperparams:
109
+ use_sde: true
110
+ log_std_init: -2
111
+ init_layers_orthogonal: false
112
+ activation_fn: relu
113
+ cnn_feature_dim: 256
114
+ hidden_sizes: [256]
115
+ algo_hyperparams:
116
+ n_steps: 1000
117
+ pi_lr: !!float 5e-5
118
+ gamma: 0.99
119
+ gae_lambda: 0.95
120
+ val_lr: !!float 1e-4
121
+ train_v_iters: 40
122
+ max_grad_norm: 0.5
123
+ sde_sample_freq: 4
124
+
125
+ HalfCheetahBulletEnv-v0: &pybullet-defaults
126
+ n_timesteps: !!float 2e6
127
+ policy_hyperparams: &pybullet-policy-defaults
128
+ hidden_sizes: [256, 256]
129
+ algo_hyperparams: &pybullet-algo-defaults
130
+ n_steps: 4000
131
+ pi_lr: !!float 3e-4
132
+ gamma: 0.99
133
+ gae_lambda: 0.97
134
+ val_lr: !!float 1e-3
135
+ train_v_iters: 80
136
+ max_grad_norm: 0.5
137
+
138
+ AntBulletEnv-v0:
139
+ <<: *pybullet-defaults
140
+ policy_hyperparams:
141
+ <<: *pybullet-policy-defaults
142
+ hidden_sizes: [400, 300]
143
+ algo_hyperparams:
144
+ <<: *pybullet-algo-defaults
145
+ pi_lr: !!float 7e-4
146
+ val_lr: !!float 7e-3
147
+
148
+ HopperBulletEnv-v0:
149
+ <<: *pybullet-defaults
150
+
151
+ Walker2DBulletEnv-v0:
152
+ <<: *pybullet-defaults
153
+
154
+ FrozenLake-v1:
155
+ n_timesteps: !!float 8e5
156
+ env_hyperparams:
157
+ make_kwargs:
158
+ map_name: 8x8
159
+ is_slippery: true
160
+ policy_hyperparams:
161
+ hidden_sizes: [64]
162
+ algo_hyperparams:
163
+ n_steps: 2048
164
+ pi_lr: 0.01
165
+ gamma: 0.99
166
+ gae_lambda: 0.98
167
+ val_lr: 0.01
168
+ train_v_iters: 80
169
+ max_grad_norm: 0.5
170
+ eval_params:
171
+ step_freq: !!float 5e4
172
+ n_episodes: 10
173
+ save_best: true
174
+
175
+ _atari: &atari-defaults
176
+ n_timesteps: !!float 1e7
177
+ env_hyperparams:
178
+ n_envs: 4
179
+ frame_stack: 4
180
+ no_reward_timeout_steps: 1000
181
+ no_reward_fire_steps: 500
182
+ vec_env_class: subproc
183
+ policy_hyperparams:
184
+ activation_fn: relu
185
+ algo_hyperparams:
186
+ n_steps: 2048
187
+ pi_lr: !!float 5e-5
188
+ gamma: 0.99
189
+ gae_lambda: 0.95
190
+ val_lr: !!float 1e-4
191
+ train_v_iters: 80
192
+ max_grad_norm: 0.5
193
+ ent_coef: 0.01
194
+ eval_params:
195
+ deterministic: false
lambda_labs/benchmark.sh ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+
3
+ # export WANDB_PROJECT_NAME="rl-algo-impls"
4
+
5
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-6}"
6
+
7
+ ALGOS=(
8
+ # "vpg"
9
+ # "dqn"
10
+ "ppo"
11
+ )
12
+ ENVS=(
13
+ # Basic
14
+ "CartPole-v1"
15
+ "MountainCar-v0"
16
+ "MountainCarContinuous-v0"
17
+ "Acrobot-v1"
18
+ "LunarLander-v2"
19
+ "BipedalWalker-v3"
20
+ # PyBullet
21
+ "HalfCheetahBulletEnv-v0"
22
+ "AntBulletEnv-v0"
23
+ "HopperBulletEnv-v0"
24
+ "Walker2DBulletEnv-v0"
25
+ # CarRacing
26
+ "CarRacing-v0"
27
+ # Atari
28
+ "PongNoFrameskip-v4"
29
+ "BreakoutNoFrameskip-v4"
30
+ "SpaceInvadersNoFrameskip-v4"
31
+ "QbertNoFrameskip-v4"
32
+ )
33
+ train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
lambda_labs/impala_atari_benchmark.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+
3
+ # export WANDB_PROJECT_NAME="rl-algo-impls"
4
+
5
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-5}"
6
+
7
+ ALGOS=(
8
+ # "vpg"
9
+ # "dqn"
10
+ "ppo"
11
+ )
12
+ ENVS=(
13
+ "impala-PongNoFrameskip-v4"
14
+ "impala-BreakoutNoFrameskip-v4"
15
+ "impala-SpaceInvadersNoFrameskip-v4"
16
+ "impala-QbertNoFrameskip-v4"
17
+ "impala-CarRacing-v0"
18
+ )
19
+ train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
lambda_labs/lambda_requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scipy >= 1.10.0, < 1.11
2
+ tensorboard >= 2.11.0, < 2.12
3
+ AutoROM.accept-rom-license >= 0.4.2, < 0.5
4
+ stable-baselines3[extra] >= 1.7.0, < 1.8
5
+ gym[box2d] >= 0.21.0, < 0.22
6
+ pyglet == 1.5.27
7
+ wandb >= 0.13.10, < 0.14
8
+ pyvirtualdisplay == 3.0
9
+ pybullet >= 3.2.5, < 3.3
10
+ tabulate >= 0.9.0, < 0.10
11
+ huggingface-hub >= 0.12.0, < 0.13
12
+ numexpr >= 2.8.4, < 2.9
13
+ gym3 >= 0.3.3, < 0.4
14
+ glfw >= 1.12.0, < 1.13
15
+ procgen >= 0.10.7, < 0.11
16
+ ipython >= 8.10.0, < 8.11
lambda_labs/procgen_benchmark.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+
3
+ # export WANDB_PROJECT_NAME="rl-algo-impls"
4
+
5
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
6
+
7
+ ALGOS=(
8
+ # "vpg"
9
+ # "dqn"
10
+ "ppo"
11
+ )
12
+ ENVS=(
13
+ "procgen-coinrun-easy"
14
+ "procgen-starpilot-easy"
15
+ "procgen-bossfight-easy"
16
+ "procgen-bigfish-easy"
17
+ )
18
+ train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
lambda_labs/setup.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ sudo apt update
2
+ sudo apt install -y python-opengl
3
+ sudo apt install -y ffmpeg
4
+ sudo apt install -y xvfb
5
+ sudo apt install -y swig
6
+
7
+ python3 -m pip install --upgrade pip
8
+ pip install --upgrade torch torchvision torchaudio
9
+
10
+ pip install --upgrade -r ~/rl-algo-impls/lambda_labs/lambda_requirements.txt
lambda_labs/starpilot_hard_benchmark.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+
3
+ # export WANDB_PROJECT_NAME="rl-algo-impls"
4
+
5
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-1}"
6
+
7
+ ALGOS=(
8
+ "ppo"
9
+ )
10
+ ENVS=(
11
+ "procgen-starpilot-hard"
12
+ "procgen-starpilot-hard-2xIMPALA"
13
+ "procgen-starpilot-hard-2xIMPALA-fat"
14
+ "procgen-starpilot-hard-4xIMPALA"
15
+ )
16
+ train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
ppo/ppo.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from dataclasses import asdict, dataclass, field
7
+ from time import perf_counter
8
+ from torch.optim import Adam
9
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
10
+ from torch.utils.tensorboard.writer import SummaryWriter
11
+ from typing import List, Optional, NamedTuple, TypeVar
12
+
13
+ from shared.algorithm import Algorithm
14
+ from shared.callbacks.callback import Callback
15
+ from shared.gae import compute_advantage, compute_rtg_and_advantage
16
+ from shared.policy.on_policy import ActorCritic
17
+ from shared.schedule import constant_schedule, linear_schedule
18
+ from shared.trajectory import Trajectory, TrajectoryAccumulator
19
+
20
+
21
@dataclass
class PPOTrajectory(Trajectory):
    """Trajectory that also stores the log-probability of every action taken."""

    # One entry per recorded transition, appended by `add`.
    logp_a: List[float] = field(default_factory=list)

    def add(
        self,
        obs: np.ndarray,
        act: np.ndarray,
        next_obs: np.ndarray,
        rew: float,
        terminated: bool,
        v: float,
        logp_a: float,
    ):
        """Record one transition.

        The first six arguments are forwarded unchanged to `Trajectory.add`;
        `logp_a` (the log-probability of `act`) is stored on this class.
        """
        # Let the base class store the transition first so that a failure there
        # does not leave `logp_a` one element longer than the other buffers.
        super().add(obs, act, next_obs, rew, terminated, v)
        self.logp_a.append(logp_a)
37
+
38
+
39
class PPOTrajectoryAccumulator(TrajectoryAccumulator):
    """TrajectoryAccumulator that builds `PPOTrajectory` objects."""

    def __init__(self, num_envs: int) -> None:
        """Create an accumulator for `num_envs` parallel environments."""
        super().__init__(num_envs, PPOTrajectory)

    def step(
        self,
        obs: VecEnvObs,
        action: np.ndarray,
        next_obs: VecEnvObs,
        reward: np.ndarray,
        done: np.ndarray,
        val: np.ndarray,
        logp_a: np.ndarray,
    ) -> None:
        """Record one vectorised environment step.

        This override only narrows the signature: every argument is passed to
        the base class in the same order, with `logp_a` as the extra value.
        """
        super().step(obs, action, next_obs, reward, done, val, logp_a)
54
+
55
+
56
class TrainStepStats(NamedTuple):
    """Scalar diagnostics produced by a single minibatch update."""

    loss: float  # total optimised loss
    pi_loss: float  # policy (surrogate) loss term
    v_loss: float  # value-function loss term
    entropy_loss: float  # mean policy entropy of the minibatch
    approx_kl: float  # estimate of KL(old policy || new policy)
    clipped_frac: float  # fraction of samples whose ratio left the clip range
    val_clipped_frac: float  # fraction of value predictions beyond the value clip
64
+
65
+
66
@dataclass
class TrainStats:
    """Diagnostics for one training call, averaged over its minibatch updates."""

    loss: float
    pi_loss: float
    v_loss: float
    entropy_loss: float
    approx_kl: float
    clipped_frac: float
    val_clipped_frac: float
    explained_var: float

    def __init__(self, step_stats: List[TrainStepStats], explained_var: float) -> None:
        """Average each per-step statistic and attach `explained_var` as given."""
        averaged_fields = (
            "loss",
            "pi_loss",
            "v_loss",
            "entropy_loss",
            "approx_kl",
            "clipped_frac",
            "val_clipped_frac",
        )
        for field_name in averaged_fields:
            samples = [getattr(step, field_name) for step in step_stats]
            setattr(self, field_name, np.mean(samples).item())
        self.explained_var = explained_var

    def write_to_tensorboard(self, tb_writer: SummaryWriter, global_step: int) -> None:
        """Log every field as the scalar `losses/<field name>` at `global_step`."""
        metrics = asdict(self)
        for metric_name in metrics:
            tb_writer.add_scalar(
                f"losses/{metric_name}", metrics[metric_name], global_step=global_step
            )

    def __repr__(self) -> str:
        """Return a one-line summary with each value rounded to two decimals."""
        labelled = (
            ("Loss", self.loss),
            ("Pi L", self.pi_loss),
            ("V L", self.v_loss),
            ("E L", self.entropy_loss),
            ("Apx KL Div", self.approx_kl),
            ("Clip Frac", self.clipped_frac),
            ("Val Clip Frac", self.val_clipped_frac),
        )
        return " | ".join(f"{label}: {round(value, 2)}" for label, value in labelled)
103
+
104
+
105
+ PPOSelf = TypeVar("PPOSelf", bound="PPO")
106
+
107
+
108
class PPO(Algorithm):
    """Proximal Policy Optimization trainer for an `ActorCritic` policy.

    Every iteration of `learn` collects `n_steps` transitions from each
    environment of the `VecEnv`, then performs `n_epochs` passes of minibatch
    updates combining the clipped surrogate policy loss, an (optionally clipped)
    value loss and an entropy bonus.
    """

    def __init__(
        self,
        policy: ActorCritic,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        learning_rate: float = 3e-4,
        learning_rate_decay: str = "none",
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: float = 0.2,
        clip_range_decay: str = "none",
        clip_range_vf: Optional[float] = None,
        clip_range_vf_decay: str = "none",
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        ent_coef_decay: str = "none",
        vf_coef: float = 0.5,
        ppo2_vf_coef_halving: bool = False,
        max_grad_norm: float = 0.5,
        update_rtg_between_epochs: bool = False,
        sde_sample_freq: int = -1,
    ) -> None:
        """Store hyperparameters and build the optimizer and decay schedules.

        Args:
            policy: Actor-critic network that is both rolled out and optimised.
            env: Vectorised environment used to collect rollouts.
            device: Device on which training tensors are created.
            tb_writer: TensorBoard writer for charts and loss statistics.
            learning_rate: Initial Adam learning rate.
            learning_rate_decay: "linear" uses `linear_schedule(learning_rate, 0)`;
                any other value keeps the rate constant. The three other
                `*_decay` arguments behave the same way for their value.
            n_steps: Environment steps collected per environment per rollout.
            batch_size: Minibatch size for each gradient step.
            n_epochs: Passes over the rollout data per call to `train`.
            gamma: Discount factor passed to the advantage computation.
            gae_lambda: GAE lambda passed to the advantage computation.
            clip_range: Clip range for the policy probability ratio.
            clip_range_decay: Decay mode for `clip_range`.
            clip_range_vf: Clip range for value updates; a falsy value (None
                or 0) disables value clipping.
            clip_range_vf_decay: Decay mode for `clip_range_vf`.
            normalize_advantage: Standardise advantages within each minibatch.
            ent_coef: Weight of the entropy bonus in the loss.
            ent_coef_decay: Decay mode for `ent_coef`.
            vf_coef: Weight of the value loss in the loss.
            ppo2_vf_coef_halving: If true, the value loss is multiplied by 0.5.
            max_grad_norm: Maximum gradient norm used for gradient clipping.
            update_rtg_between_epochs: If true, rewards-to-go and advantages are
                recomputed every epoch; otherwise only advantages are.
            sde_sample_freq: If positive, `policy.reset_noise()` is called every
                this many steps during a rollout (in addition to once at its
                start).
        """
        super().__init__(policy, env, device, tb_writer)
        self.policy = policy

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
        self.lr_schedule = self._schedule(learning_rate, learning_rate_decay)
        self.max_grad_norm = max_grad_norm
        self.clip_range_schedule = self._schedule(clip_range, clip_range_decay)
        # Value clipping is optional: no schedule means "do not clip".
        self.clip_range_vf_schedule = None
        if clip_range_vf:
            self.clip_range_vf_schedule = self._schedule(
                clip_range_vf, clip_range_vf_decay
            )
        self.normalize_advantage = normalize_advantage
        self.ent_coef_schedule = self._schedule(ent_coef, ent_coef_decay)
        self.vf_coef = vf_coef
        self.ppo2_vf_coef_halving = ppo2_vf_coef_halving

        self.n_steps = n_steps
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.sde_sample_freq = sde_sample_freq

        self.update_rtg_between_epochs = update_rtg_between_epochs

        # Observation at which the next rollout must start; set by `learn` and
        # advanced by `_collect_trajectories`.
        self._last_obs: Optional[VecEnvObs] = None

    @staticmethod
    def _schedule(value: float, decay: str):
        """Return `linear_schedule(value, 0)` for "linear", else a constant schedule."""
        if decay == "linear":
            return linear_schedule(value, 0)
        return constant_schedule(value)

    def learn(
        self: PPOSelf,
        total_timesteps: int,
        callback: Optional[Callback] = None,
    ) -> PPOSelf:
        """Alternate rollout collection and training for `total_timesteps` steps.

        Whole rollouts of `n_steps * env.num_envs` steps are always collected,
        so the final rollout may overshoot `total_timesteps`.

        Args:
            total_timesteps: Number of environment steps to train for.
            callback: Optional callback; `on_step` is called after every
                rollout with the number of steps that rollout contained.

        Returns:
            This instance, to allow call chaining.
        """
        self._last_obs = self.env.reset()
        ts_elapsed = 0
        while ts_elapsed < total_timesteps:
            start_time = perf_counter()
            # Start from where the previous rollout ended. Reusing the initial
            # reset observation would pair a stale observation with the
            # environments' current state on every rollout after the first.
            accumulator = self._collect_trajectories(self._last_obs)
            rollout_steps = self.n_steps * self.env.num_envs
            ts_elapsed += rollout_steps
            # The last rollout can overshoot total_timesteps; cap progress so
            # the decay schedules are never evaluated past their end point.
            progress = min(1.0, ts_elapsed / total_timesteps)
            train_stats = self.train(accumulator.all_trajectories, progress, ts_elapsed)
            train_stats.write_to_tensorboard(self.tb_writer, ts_elapsed)
            end_time = perf_counter()
            self.tb_writer.add_scalar(
                "train/steps_per_second",
                rollout_steps / (end_time - start_time),
                ts_elapsed,
            )
            if callback:
                callback.on_step(timesteps_elapsed=rollout_steps)

        return self

    def _collect_trajectories(self, obs: VecEnvObs) -> PPOTrajectoryAccumulator:
        """Roll the policy out for `n_steps` steps starting from `obs`.

        The observation returned by the final environment step is saved in
        `self._last_obs` so the caller can resume from it.

        Returns:
            Accumulator holding the transitions of this rollout.
        """
        self.policy.eval()
        accumulator = PPOTrajectoryAccumulator(self.env.num_envs)
        self.policy.reset_noise()
        for i in range(self.n_steps):
            # Periodically resample exploration noise when enabled.
            if self.sde_sample_freq > 0 and i > 0 and i % self.sde_sample_freq == 0:
                self.policy.reset_noise()
            action, value, logp_a, clamped_action = self.policy.step(obs)
            # The environment receives the clamped action while the unclamped
            # action is what gets stored for training.
            next_obs, reward, done, _ = self.env.step(clamped_action)
            accumulator.step(obs, action, next_obs, reward, done, value, logp_a)
            obs = next_obs
        self._last_obs = obs
        return accumulator

    def train(
        self, trajectories: List[PPOTrajectory], progress: float, timesteps_elapsed: int
    ) -> TrainStats:
        """Run `n_epochs` of minibatch updates on the collected trajectories.

        Args:
            trajectories: Rollout data to optimise on.
            progress: Training progress passed to the decay schedules.
            timesteps_elapsed: Global step used for TensorBoard logging.

        Returns:
            Statistics averaged over the minibatch updates of the final epoch,
            plus the explained variance of the rollout-time value estimates.
        """
        self.policy.train()
        learning_rate = self.lr_schedule(progress)
        self.optimizer.param_groups[0]["lr"] = learning_rate
        self.tb_writer.add_scalar(
            "charts/learning_rate",
            self.optimizer.param_groups[0]["lr"],
            timesteps_elapsed,
        )

        pi_clip = self.clip_range_schedule(progress)
        self.tb_writer.add_scalar("charts/pi_clip", pi_clip, timesteps_elapsed)
        if self.clip_range_vf_schedule:
            v_clip = self.clip_range_vf_schedule(progress)
            self.tb_writer.add_scalar("charts/v_clip", v_clip, timesteps_elapsed)
        else:
            v_clip = None
        ent_coef = self.ent_coef_schedule(progress)
        self.tb_writer.add_scalar("charts/ent_coef", ent_coef, timesteps_elapsed)

        # Flatten all trajectories into single tensors on the training device.
        obs = torch.as_tensor(
            np.concatenate([np.array(t.obs) for t in trajectories]), device=self.device
        )
        act = torch.as_tensor(
            np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
        )
        rtg, adv = compute_rtg_and_advantage(
            trajectories, self.policy, self.gamma, self.gae_lambda, self.device
        )
        orig_v = torch.as_tensor(
            np.concatenate([np.array(t.v) for t in trajectories]), device=self.device
        )
        orig_logp_a = torch.as_tensor(
            np.concatenate([np.array(t.logp_a) for t in trajectories]),
            device=self.device,
        )

        step_stats = []
        for _ in range(self.n_epochs):
            # Only the statistics of the most recent epoch are reported.
            step_stats.clear()
            if self.update_rtg_between_epochs:
                rtg, adv = compute_rtg_and_advantage(
                    trajectories, self.policy, self.gamma, self.gae_lambda, self.device
                )
            else:
                adv = compute_advantage(
                    trajectories, self.policy, self.gamma, self.gae_lambda, self.device
                )
            idxs = torch.randperm(len(obs))
            for i in range(0, len(obs), self.batch_size):
                mb_idxs = idxs[i : i + self.batch_size]
                mb_adv = adv[mb_idxs]
                # A trailing minibatch of one sample has an undefined (NaN)
                # standard deviation, which would turn the loss and then the
                # weights into NaN; leave such a minibatch unnormalised.
                if self.normalize_advantage and mb_adv.shape[-1] > 1:
                    mb_adv = (mb_adv - mb_adv.mean(-1)) / (mb_adv.std(-1) + 1e-8)
                step_stats.append(
                    self._train_step(
                        pi_clip,
                        v_clip,
                        ent_coef,
                        obs[mb_idxs],
                        act[mb_idxs],
                        rtg[mb_idxs],
                        mb_adv,
                        orig_v[mb_idxs],
                        orig_logp_a[mb_idxs],
                    )
                )

        # Explained variance of the rollout-time value estimates w.r.t. rtg.
        y_pred, y_true = orig_v.cpu().numpy(), rtg.cpu().numpy()
        var_y = np.var(y_true).item()
        explained_var = (
            np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
        )

        return TrainStats(step_stats, explained_var)

    def _train_step(
        self,
        pi_clip: float,
        v_clip: Optional[float],
        ent_coef: float,
        obs: torch.Tensor,
        act: torch.Tensor,
        rtg: torch.Tensor,
        adv: torch.Tensor,
        orig_v: torch.Tensor,
        orig_logp_a: torch.Tensor,
    ) -> TrainStepStats:
        """Perform one gradient step on a minibatch and report its statistics.

        Args:
            pi_clip: Clip range for the probability ratio.
            v_clip: Clip range for the value update, or a falsy value to disable.
            ent_coef: Weight of the entropy bonus.
            obs: Minibatch observations.
            act: Minibatch actions.
            rtg: Value targets (rewards-to-go).
            adv: Advantages (already normalised if enabled).
            orig_v: Value estimates recorded during the rollout.
            orig_logp_a: Action log-probabilities recorded during the rollout.
        """
        logp_a, entropy, v = self.policy(obs, act)
        logratio = logp_a - orig_logp_a
        ratio = torch.exp(logratio)
        clip_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
        # Clipped surrogate, written as a loss (negated objective), hence max.
        pi_loss = torch.maximum(-ratio * adv, -clip_ratio * adv).mean()

        v_loss_unclipped = (v - rtg) ** 2
        if v_clip:
            # Limit how far the new value may move from the rollout estimate.
            v_loss_clipped = (
                orig_v + torch.clamp(v - orig_v, -v_clip, v_clip) - rtg
            ) ** 2
            v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
        else:
            v_loss = v_loss_unclipped.mean()
        if self.ppo2_vf_coef_halving:
            v_loss *= 0.5

        entropy_loss = entropy.mean()

        # Entropy enters with a minus sign: higher entropy lowers the loss.
        loss = pi_loss - ent_coef * entropy_loss + self.vf_coef * v_loss

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.optimizer.step()

        with torch.no_grad():
            approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
            clipped_frac = (
                ((ratio - 1).abs() > pi_clip).float().mean().cpu().numpy().item()
            )
            val_clipped_frac = (
                (((v - orig_v).abs() > v_clip).float().mean().cpu().numpy().item())
                if v_clip
                else 0
            )

        return TrainStepStats(
            loss.item(),
            pi_loss.item(),
            v_loss.item(),
            entropy_loss.item(),
            approx_kl,
            clipped_frac,
            val_clipped_frac,
        )
publish/markdown_format.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import wandb.apis.public
4
+ import yaml
5
+
6
+ from collections import defaultdict
7
+ from dataclasses import dataclass, asdict
8
+ from typing import Any, Dict, Iterable, List, NamedTuple, Optional, TypeVar
9
+ from urllib.parse import urlparse
10
+
11
+ from runner.evaluate import Evaluation
12
+
13
+ EvaluationRowSelf = TypeVar("EvaluationRowSelf", bound="EvaluationRow")
14
+
15
+
16
@dataclass
class EvaluationRow:
    """One row of the evaluation results table."""

    algo: str
    env: str
    seed: Optional[int]
    reward_mean: float
    reward_std: float
    eval_episodes: int
    best: str
    wandb_url: str

    @staticmethod
    def data_frame(rows: List[EvaluationRowSelf]) -> pd.DataFrame:
        """Return a DataFrame with one column per field and one row per entry."""
        columns = defaultdict(list)
        for record in map(asdict, rows):
            for field_name, value in record.items():
                columns[field_name].append(value)
        return pd.DataFrame(columns)
34
+
35
+
36
class EvalTableData(NamedTuple):
    """Pairs a Weights & Biases run with the evaluation of its model."""

    run: wandb.apis.public.Run  # source of the run URL, path and config
    evaluation: Evaluation  # unpacked elsewhere in this module as (_, stats, config)
39
+
40
+
41
def evaluation_table(table_data: Iterable[EvalTableData]) -> str:
    """Render the evaluations as a markdown table ordered by seed.

    The row whose stats equal the highest-scoring stats is flagged with "*" in
    the `best` column.

    Args:
        table_data: Run/evaluation pairs to tabulate. Any iterable is accepted,
            including one-shot iterators.

    Returns:
        The markdown table, without an index column.

    Raises:
        IndexError: If `table_data` is empty.
    """
    # Materialise once: the data is traversed twice below, so a generator would
    # otherwise be exhausted by the first pass and produce an empty table.
    entries = list(table_data)
    if not entries:
        raise IndexError("evaluation_table requires at least one evaluation")
    best_stats = max((d.evaluation.stats for d in entries), key=lambda s: s.score)
    # Runs without a seed (None) sort as seed 0.
    entries.sort(key=lambda d: d.evaluation.config.seed() or 0)
    rows = [
        EvaluationRow(
            config.algo,
            config.env_id,
            config.seed(),
            stats.score.mean,
            stats.score.std,
            len(stats),
            "*" if stats == best_stats else "",
            f"[wandb]({run.url})",
        )
        for (run, (_, stats, config)) in entries
    ]
    df = EvaluationRow.data_frame(rows)
    return df.to_markdown(index=False)
61
+
62
+
63
def github_project_link(github_url: str) -> str:
    """Return a markdown link to `github_url` labelled with the URL's path."""
    label = urlparse(github_url).path
    return "[{}]({})".format(label, github_url)
65
+
66
+
67
def header_section(algo: str, env: str, github_url: str, wandb_report_url: str) -> str:
    """Return the model card header: title, description and report pointer.

    The three paragraphs are separated by blank lines; `algo` is upper-cased.
    """
    algo_caps = algo.upper()
    title = f"# **{algo_caps}** Agent playing **{env}**"
    description = (
        f"This is a trained model of a **{algo_caps}** agent playing **{env}** using "
        f"the {github_project_link(github_url)} repo."
    )
    report = f"All models trained at this commit can be found at {wandb_report_url}."
    return "\n\n".join((title, description, report))
76
+
77
+
78
def github_tree_link(github_url: str, commit_hash: Optional[str]) -> str:
    """Return a markdown link to the repository tree at `commit_hash`.

    The label is the first seven characters of the hash. Without a commit hash
    (None or empty) the plain project link is returned instead.
    """
    if commit_hash:
        short_hash = commit_hash[:7]
        return f"[{short_hash}]({github_url}/tree/{commit_hash})"
    return github_project_link(github_url)
82
+
83
+
84
def results_section(
    table_data: List[EvalTableData], algo: str, github_url: str, commit_hash: str
) -> str:
    """Return the "Training Results" section: heading, summary and results table."""
    summary = "".join(
        [
            f"This model was trained from {len(table_data)} trainings of **{algo.upper()}** ",
            "agents using different initial seeds. ",
            "These agents were trained by checking out ",
            f"{github_tree_link(github_url, commit_hash)}. ",
            "The best and last models were kept from each training. ",
            "This submission has loaded the best models from each training, reevaluates ",
            "them, and selects the best model from these latest evaluations (mean - std).",
        ]
    )
    table = evaluation_table(table_data)
    return "\n\n".join(["## Training Results", summary, table])
100
+
101
+
102
def prerequisites_section() -> str:
    """Return the static section describing the Weights & Biases prerequisite."""
    section = """
### Prerequisites: Weights & Biases (WandB)
Training and benchmarking assumes you have a Weights & Biases project to upload runs to.
By default training goes to a rl-algo-impls project while benchmarks go to
rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
models and the model weights are uploaded to WandB.

Before doing anything below, you'll need to create a wandb account and run `wandb
login`.
"""
    return section
113
+
114
+
115
def usage_section(github_url: str, run_path: str, commit_hash: str) -> str:
    """Return the "Usage" section showing how to run the published model.

    Args:
        github_url: Repository URL; its path component is used as the label.
        run_path: Run path inserted into the `enjoy.py --wandb-run-path=` command.
        commit_hash: Commit the agent was trained on, linked in the note.
    """
    project_path = urlparse(github_url).path
    commit_link = github_tree_link(github_url, commit_hash)
    return f"""
## Usage
{project_path}: {github_url}

Note: While the model state dictionary and hyperaparameters are saved, the latest
implementation could be sufficiently different to not be able to reproduce similar
results. You might need to checkout the commit the agent was trained on:
{commit_link}.
```
# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
python enjoy.py --wandb-run-path={run_path}
```

Setup hasn't been completely worked out yet, so you might be best served by using Google
Colab starting from the
[colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
notebook.
"""
134
+
135
+
136
def training_setion(
    github_url: str, commit_hash: str, algo: str, env: str, seed: Optional[int]
) -> str:
    """Return the "Training" section with the command to reproduce the run.

    The `--seed` flag is included only when `seed` is not None.
    """
    commit_link = github_tree_link(github_url, commit_hash)
    if seed is not None:
        seed_flag = "--seed " + str(seed)
    else:
        seed_flag = ""
    return f"""
## Training
If you want the highest chance to reproduce these results, you'll want to checkout the
commit the agent was trained on: {commit_link}. While
training is deterministic, different hardware will give different results.

```
python train.py --algo {algo} --env {env} {seed_flag}
```

Setup hasn't been completely worked out yet, so you might be best served by using Google
Colab starting from the
[colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
notebook.
"""
154
+
155
+
156
def benchmarking_section(report_url: str) -> str:
    """Return the "Benchmarking" section, referencing the report at `report_url`."""
    section = f"""
## Benchmarking (with Lambda Labs instance)
This and other models from {report_url} were generated by running a script on a Lambda
Labs instance. In a Lambda Labs instance terminal:
```
git clone git@github.com:sgoodfriend/rl-algo-impls.git
cd rl-algo-impls
bash ./lambda_labs/setup.sh
wandb login
bash ./lambda_labs/benchmark.sh
```

### Alternative: Google Colab Pro+
As an alternative,
[colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
can be used. However, this requires a Google Colab Pro+ subscription and running across
4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
"""
    return section
175
+
176
+
177
def hyperparams_section(run_config: Dict[str, Any]) -> str:
    """Return the "Hyperparameters" section containing a YAML dump of `run_config`.

    `run_config["algo"]` is used to name the matching hyperparams file.
    """
    hyperparams_path = os.path.join("hyperparams", run_config["algo"] + ".yml")
    config_yaml = yaml.dump(run_config)
    return f"""
## Hyperparameters
This isn't exactly the format of hyperparams in {hyperparams_path}, but instead the Wandb Run Config. However, it's very
close and has some additional data:
```
{config_yaml}
```
"""
187
+
188
+
189
def model_card_text(
    algo: str,
    env: str,
    github_url: str,
    commit_hash: str,
    wandb_report_url: str,
    table_data: List[EvalTableData],
    best_eval: EvalTableData,
) -> str:
    """Assemble the full model card from its sections.

    The usage command, training seed and hyperparameter dump are taken from
    `best_eval`; the results table is built from `table_data`. Sections are
    separated by blank lines.
    """
    run, (_, _, config) = best_eval
    run_path = "/".join(run.path)
    sections = (
        header_section(algo, env, github_url, wandb_report_url),
        results_section(table_data, algo, github_url, commit_hash),
        prerequisites_section(),
        usage_section(github_url, run_path, commit_hash),
        training_setion(github_url, commit_hash, algo, env, config.seed()),
        benchmarking_section(wandb_report_url),
        hyperparams_section(run.config),
    )
    return "\n\n".join(sections)
pyproject.toml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "rl-algo-impls"
3
+ version = "0.1.0"
4
+ description = "Implementations of reinforcement learning algorithms"
5
+ authors = ["Scott Goodfriend <goodfriend.scott@gmail.com>"]
6
+ license = "MIT License"
7
+ readme = "README.md"
8
+ packages = [{include = "rl_algo_impls"}]
9
+
10
+ [tool.poetry.dependencies]
11
+ python = "~3.10"
12
+ "AutoROM.accept-rom-license" = "^0.4.2"
13
+ stable-baselines3 = {extras = ["extra"], version = "^1.7.0"}
14
+ scipy = "^1.10.0"
15
+ gym = {extras = ["box2d"], version = "^0.21.0"}
16
+ pyglet = "1.5.27"
17
+ PyYAML = "^6.0"
18
+ tensorboard = "^2.11.0"
19
+ pybullet = "^3.2.5"
20
+ wandb = "^0.13.9"
21
+ conda-lock = "^1.3.0"
22
+ torch-tb-profiler = "^0.4.1"
23
+ jupyter = "^1.0.0"
24
+ tabulate = "^0.9.0"
25
+ huggingface-hub = "^0.12.0"
26
+ cryptography = "39.0.1"
27
+ pyvirtualdisplay = "^3.0"
28
+ numexpr = "^2.8.4"
29
+ gym3 = "^0.3.3"
30
+ glfw = "1.12.0"
31
+ ipython = "^8.10.0"
32
+
33
+ [build-system]
34
+ requires = ["poetry-core"]
35
+ build-backend = "poetry.core.masonry.api"
replay.meta.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with clang version 14.0.6\\nconfiguration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-pic --enable-shared --disable-static --enable-version3 --enable-zlib --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/pkg-config\\nlibavutil 57. 28.100 / 57. 28.100\\nlibavcodec 59. 37.100 / 59. 37.100\\nlibavformat 59. 27.100 / 59. 27.100\\nlibavdevice 59. 7.100 / 59. 7.100\\nlibavfilter 8. 44.100 / 8. 44.100\\nlibswscale 6. 7.100 / 6. 7.100\\nlibswresample 4. 7.100 / 4. 7.100\\nlibpostproc 56. 6.100 / 56. 
6.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "1200x800", "-pix_fmt", "rgb24", "-framerate", "30", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "30", "/var/folders/9g/my5557_91xddp6lx00nkzly80000gn/T/tmp3exrpbmd/ppo-BipedalWalker-v3/replay.mp4"]}, "episode": {"r": 325.3429870605469, "l": 827, "t": 19.199028}}
replay.mp4 ADDED
Binary file (838 kB). View file
 
runner/config.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from datetime import datetime
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, NamedTuple, Optional, TypedDict, Union
6
+
7
+
8
@dataclass
class RunArgs:
    """Arguments that identify a single run: algorithm, environment and seeding."""

    # Algorithm abbreviation; exposed again through Config.algo.
    algo: str
    # Environment name given on the command line; used in model names.
    env: str
    # Random seed, or None when the run is not seeded.
    seed: Optional[int] = None
    # Flag stored alongside the seed (see set_seeds in runner/running_utils.py).
    use_deterministic_algorithms: bool = True
14
+
15
+
16
class EnvHyperparams(NamedTuple):
    """Environment-construction settings, each with a default.

    NOTE: other modules unpack this tuple positionally, so the field order is
    part of the interface. Append new fields at the end.
    """

    # Choose the procgen construction path instead of the gym one.
    is_procgen: bool = False
    # Number of parallel environments.
    n_envs: int = 1
    # Number of frames stacked by the FrameStack wrapper.
    frame_stack: int = 1
    # Extra keyword arguments for environment creation.
    make_kwargs: Optional[Dict[str, Any]] = None
    # Settings for the NoRewardTimeout wrapper (disabled when falsy).
    no_reward_timeout_steps: Optional[int] = None
    no_reward_fire_steps: Optional[int] = None
    # Key selecting the vectorized env implementation ("dummy" or "subproc").
    vec_env_class: str = "dummy"
    # Whether to add normalization, and keyword arguments for it.
    normalize: bool = False
    normalize_kwargs: Optional[Dict[str, Any]] = None
    # Window length handed to the episode statistics writer.
    rolling_length: int = 100
    # Training video recording toggle and its step interval.
    train_record_video: bool = False
    video_step_interval: Union[int, float] = 1_000_000
    # Budget for the InitialStepTruncateWrapper (disabled when falsy).
    initial_steps_to_truncate: Optional[int] = None
    # Whether Atari environments get the ClipRewardEnv wrapper.
    clip_atari_rewards: bool = True
31
+
32
+
33
class Hyperparams(TypedDict, total=False):
    """Typed view of one environment's entry in a hyperparameter file.

    Every key is optional (``total=False``); ``Config`` supplies a default
    for each key it reads.
    """

    # Device name; Config falls back to "auto".
    device: str
    # Total training timesteps; Config converts the value to int.
    n_timesteps: Union[int, float]
    # Optional override of the environment id. Config.env_id and
    # Config.model_name read this key, so it is declared here to keep the
    # type in step with its use.
    env_id: str
    # Keyword groups forwarded to the environment, policy, algorithm and
    # evaluation callback respectively.
    env_hyperparams: Dict[str, Any]
    policy_hyperparams: Dict[str, Any]
    algo_hyperparams: Dict[str, Any]
    eval_params: Dict[str, Any]
40
+
41
+
42
@dataclass
class Config:
    """Run arguments, hyperparameters and the paths derived from them.

    Attributes are documented inline. ``run_id`` is stamped per instance:
    the previous class-level default ``datetime.now().isoformat()`` was
    evaluated once at import time, so every ``Config`` created in the same
    process shared one id.
    """

    # Arguments identifying the run (algo, env, seed, ...).
    args: RunArgs
    # Hyperparameters for this algo/env combination.
    hyperparams: Hyperparams
    # Directory under which models, runs and videos are stored.
    root_dir: str
    # Identifier appended to run names. Left empty so that __post_init__
    # can stamp the creation time; an explicit value is kept as given.
    run_id: str = ""

    def __post_init__(self) -> None:
        """Give the instance its own timestamp when no run_id was passed."""
        if not self.run_id:
            self.run_id = datetime.now().isoformat()

    def seed(self, training: bool = True) -> Optional[int]:
        """Return the seed to use for environments.

        Args:
            training: When False and a seed is set, the seed is offset by the
                configured ``n_envs`` (default 1).

        Returns:
            ``args.seed`` unchanged when training or when no seed is set,
            otherwise ``args.seed + n_envs``.
        """
        seed = self.args.seed
        if training or seed is None:
            return seed
        return seed + self.env_hyperparams.get("n_envs", 1)

    @property
    def device(self) -> str:
        """Configured device name, ``"auto"`` when unspecified."""
        return self.hyperparams.get("device", "auto")

    @property
    def n_timesteps(self) -> int:
        """Configured training timesteps as an int (default 100_000)."""
        return int(self.hyperparams.get("n_timesteps", 100_000))

    @property
    def env_hyperparams(self) -> Dict[str, Any]:
        """Environment hyperparameters, empty when unspecified."""
        return self.hyperparams.get("env_hyperparams", {})

    @property
    def policy_hyperparams(self) -> Dict[str, Any]:
        """Policy hyperparameters, empty when unspecified."""
        return self.hyperparams.get("policy_hyperparams", {})

    @property
    def algo_hyperparams(self) -> Dict[str, Any]:
        """Algorithm hyperparameters, empty when unspecified."""
        return self.hyperparams.get("algo_hyperparams", {})

    @property
    def eval_params(self) -> Dict[str, Any]:
        """Evaluation parameters, empty when unspecified."""
        return self.hyperparams.get("eval_params", {})

    @property
    def algo(self) -> str:
        """Algorithm abbreviation from the run arguments."""
        return self.args.algo

    @property
    def env_id(self) -> str:
        """Environment id: the ``env_id`` hyperparameter, else ``args.env``."""
        return self.hyperparams.get("env_id") or self.args.env

    def model_name(self, include_seed: bool = True) -> str:
        """Build the dash-joined model name.

        The name starts with the algorithm and the environment name from the
        arguments, optionally followed by ``S<seed>``. When no ``env_id``
        override is configured, each ``make_kwargs`` entry adds a part: the
        key for a true bool, ``key + value`` for a non-zero int, and
        ``str(value)`` for anything else.
        """
        # Use arg env name instead of environment name
        parts = [self.algo, self.args.env]
        if include_seed and self.args.seed is not None:
            parts.append(f"S{self.args.seed}")

        # Assume that the custom arg name already has the necessary information
        if not self.hyperparams.get("env_id"):
            make_kwargs = self.env_hyperparams.get("make_kwargs", {})
            if make_kwargs:
                for k, v in make_kwargs.items():
                    # Exact type checks keep bool and int handling separate.
                    if type(v) is bool and v:
                        parts.append(k)
                    elif type(v) is int and v:
                        parts.append(f"{k}{v}")
                    else:
                        parts.append(str(v))

        return "-".join(parts)

    @property
    def run_name(self) -> str:
        """Model name followed by the run id."""
        return "-".join([self.model_name(), self.run_id])

    @property
    def saved_models_dir(self) -> str:
        """Directory for locally saved models."""
        return os.path.join(self.root_dir, "saved_models")

    @property
    def downloaded_models_dir(self) -> str:
        """Directory for downloaded models."""
        return os.path.join(self.root_dir, "downloaded_models")

    def model_dir_name(
        self,
        best: bool = False,
        extension: str = "",
    ) -> str:
        """Model name plus an optional ``-best`` suffix and extension."""
        return self.model_name() + ("-best" if best else "") + extension

    def model_dir_path(self, best: bool = False, downloaded: bool = False) -> str:
        """Path of the model directory under the saved or downloaded root."""
        return os.path.join(
            self.saved_models_dir if not downloaded else self.downloaded_models_dir,
            self.model_dir_name(best=best),
        )

    @property
    def runs_dir(self) -> str:
        """Directory for run outputs."""
        return os.path.join(self.root_dir, "runs")

    @property
    def tensorboard_summary_path(self) -> str:
        """Per-run directory inside ``runs_dir``."""
        return os.path.join(self.runs_dir, self.run_name)

    @property
    def logs_path(self) -> str:
        """Path of the shared YAML log file inside ``runs_dir``."""
        return os.path.join(self.runs_dir, "log.yml")

    @property
    def videos_dir(self) -> str:
        """Directory for videos."""
        return os.path.join(self.root_dir, "videos")

    @property
    def video_prefix(self) -> str:
        """Path prefix for videos of this model."""
        return os.path.join(self.videos_dir, self.model_name())

    @property
    def best_videos_dir(self) -> str:
        """Directory for videos of the best model."""
        return os.path.join(self.videos_dir, f"{self.model_name()}-best")
runner/env.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import numpy as np
3
+ import os
4
+
5
+ from gym.wrappers.resize_observation import ResizeObservation
6
+ from gym.wrappers.gray_scale_observation import GrayScaleObservation
7
+ from gym.wrappers.frame_stack import FrameStack
8
+ from stable_baselines3.common.atari_wrappers import (
9
+ MaxAndSkipEnv,
10
+ NoopResetEnv,
11
+ )
12
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv
13
+ from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
14
+ from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
15
+ from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
16
+ from torch.utils.tensorboard.writer import SummaryWriter
17
+ from typing import Callable, Optional, Union
18
+
19
+ from runner.config import Config, EnvHyperparams
20
+ from shared.policy.policy import VEC_NORMALIZE_FILENAME
21
+ from wrappers.atari_wrappers import EpisodicLifeEnv, FireOnLifeStarttEnv, ClipRewardEnv
22
+ from wrappers.episode_record_video import EpisodeRecordVideo
23
+ from wrappers.episode_stats_writer import EpisodeStatsWriter
24
+ from wrappers.get_rgb_observation import GetRgbObservation
25
+ from wrappers.initial_step_truncate_wrapper import InitialStepTruncateWrapper
26
+ from wrappers.is_vector_env import IsVectorEnv
27
+ from wrappers.noop_env_seed import NoopEnvSeed
28
+ from wrappers.transpose_image_observation import TransposeImageObservation
29
+ from wrappers.video_compat_wrapper import VideoCompatWrapper
30
+
31
+ GeneralVecEnv = Union[VecEnv, gym.vector.VectorEnv, gym.Wrapper]
32
+
33
+
34
def make_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> GeneralVecEnv:
    """Create a vectorized environment through the matching builder.

    ``hparams.is_procgen`` selects ``_make_procgen_env``; otherwise
    ``_make_vec_env`` is used. All other arguments are forwarded unchanged.
    """
    builder = _make_procgen_env if hparams.is_procgen else _make_vec_env
    return builder(
        config,
        hparams,
        training=training,
        render=render,
        normalize_load_path=normalize_load_path,
        tb_writer=tb_writer,
    )
60
+
61
+
62
def make_eval_env(
    config: Config,
    hparams: EnvHyperparams,
    override_n_envs: Optional[int] = None,
    **kwargs
) -> GeneralVecEnv:
    """Create an environment for evaluation (``training=False``).

    Args:
        config: Run configuration forwarded to ``make_env``.
        hparams: Environment hyperparameters; not modified.
        override_n_envs: When given, replaces ``n_envs``. A value of 1 also
            forces ``vec_env_class`` to ``"dummy"``.
        **kwargs: Extra keyword arguments for ``make_env``; any ``training``
            entry is overwritten with False.
    """
    env_kwargs = {**kwargs, "training": False}
    if override_n_envs is not None:
        overrides = {"n_envs": override_n_envs}
        if override_n_envs == 1:
            overrides["vec_env_class"] = "dummy"
        hparams = EnvHyperparams(**{**hparams._asdict(), **overrides})
    return make_env(config, hparams, **env_kwargs)
77
+
78
+
79
def _make_vec_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> GeneralVecEnv:
    """Build a vectorized gym environment with the configured wrapper stack.

    Args:
        config: Supplies the environment id, the seed and the video prefix.
        hparams: Environment hyperparameters, read by field name.
        training: Enables training-only wrappers (video recording, initial
            step truncation, episode stats writer) and is passed on to the
            wrappers that accept it.
        render: Requests rendering for Bullet envs; for procgen ids it
            suppresses the ``rgb_array`` render mode.
        normalize_load_path: Directory holding saved normalization
            statistics; when set they are loaded instead of starting fresh.
        tb_writer: Required (asserted) when ``training`` is True.

    Returns:
        The vectorized environment, wrapped in ``VecNormalize`` when
        ``hparams.normalize`` is set.
    """
    n_envs = hparams.n_envs
    frame_stack = hparams.frame_stack
    env_id = config.env_id
    is_bullet = "BulletEnv" in env_id
    is_car_racing = "CarRacing" in env_id
    is_procgen = "procgen" in env_id

    if is_bullet:
        # Imported only for its side effects (presumably registering the
        # Bullet environments with gym -- TODO confirm).
        import pybullet_envs

    spec = gym.spec(env_id)
    seed = config.seed(training=training)

    def _to_stacked_grayscale(env: gym.Env, size: int) -> gym.Env:
        # Shared tail of the Atari and CarRacing pipelines.
        env = ResizeObservation(env, (size, size))
        env = GrayScaleObservation(env, keep_dim=False)
        return FrameStack(env, frame_stack)

    def make(idx: int) -> Callable[[], gym.Env]:
        # Copy so per-environment tweaks never leak into hparams.
        env_kwargs = {} if hparams.make_kwargs is None else hparams.make_kwargs.copy()
        if is_bullet and render:
            env_kwargs["render"] = True
        if is_car_racing:
            env_kwargs["verbose"] = 0
        if is_procgen and not render:
            env_kwargs["render_mode"] = "rgb_array"

        def _make() -> gym.Env:
            env = gym.make(env_id, **env_kwargs)
            env = gym.wrappers.RecordEpisodeStatistics(env)
            env = VideoCompatWrapper(env)
            # Only the first training environment records video.
            if training and hparams.train_record_video and idx == 0:
                env = EpisodeRecordVideo(
                    env,
                    config.video_prefix,
                    step_increment=n_envs,
                    video_step_interval=int(hparams.video_step_interval),
                )
            if training and hparams.initial_steps_to_truncate:
                env = InitialStepTruncateWrapper(
                    env, idx * hparams.initial_steps_to_truncate // n_envs
                )

            if "AtariEnv" in spec.entry_point:  # type: ignore
                env = NoopResetEnv(env, noop_max=30)
                env = MaxAndSkipEnv(env, skip=4)
                env = EpisodicLifeEnv(env, training=training)
                action_meanings = env.unwrapped.get_action_meanings()
                if "FIRE" in action_meanings:  # type: ignore
                    env = FireOnLifeStarttEnv(env, action_meanings.index("FIRE"))
                if hparams.clip_atari_rewards:
                    env = ClipRewardEnv(env, training=training)
                env = _to_stacked_grayscale(env, 84)
            elif is_car_racing:
                env = _to_stacked_grayscale(env, 64)
            elif is_procgen:
                # No grayscale conversion is applied for procgen.
                env = NoopEnvSeed(env)
                env = TransposeImageObservation(env)
                if frame_stack > 1:
                    env = FrameStack(env, frame_stack)

            if hparams.no_reward_timeout_steps:
                from wrappers.no_reward_timeout import NoRewardTimeout

                env = NoRewardTimeout(
                    env,
                    hparams.no_reward_timeout_steps,
                    n_fire_steps=hparams.no_reward_fire_steps,
                )

            if seed is not None:
                # Each environment gets its own seed derived from its index.
                env_seed = seed + idx
                env.seed(env_seed)
                env.action_space.seed(env_seed)
                env.observation_space.seed(env_seed)

            return env

        return _make

    vec_env_classes = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}
    VecEnvClass = vec_env_classes[hparams.vec_env_class]
    venv = VecEnvClass([make(i) for i in range(n_envs)])
    if training:
        assert tb_writer
        venv = EpisodeStatsWriter(
            venv, tb_writer, training=training, rolling_length=hparams.rolling_length
        )

    if not hparams.normalize:
        return venv

    if normalize_load_path:
        venv = VecNormalize.load(
            os.path.join(normalize_load_path, VEC_NORMALIZE_FILENAME),
            venv,  # type: ignore
        )
    else:
        venv = VecNormalize(
            venv,  # type: ignore
            training=training,
            **(hparams.normalize_kwargs or {}),
        )
    if not training:
        # Rewards are left unnormalized outside of training.
        venv.norm_reward = False
    return venv
196
+
197
+
198
def _make_procgen_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> GeneralVecEnv:
    """Build a procgen (gym3-backed) vectorized environment.

    Fields of ``hparams`` are read by name. The previous positional unpack
    listed 13 targets while ``EnvHyperparams`` has 14 fields, so it raised
    ``ValueError`` on every call.

    Args:
        config: Supplies the environment id and the seed.
        hparams: Environment hyperparameters. ``n_envs``, ``make_kwargs``,
            ``normalize``, ``normalize_kwargs`` and ``rolling_length`` are
            used; ``make_kwargs`` is copied, never modified.
        training: Enables the episode stats writer and reward normalization.
        render: Adds the gym3 ``ViewerWrapper``.
        normalize_load_path: Accepted for signature parity with
            ``_make_vec_env``; not used here.
        tb_writer: Required (asserted) when ``training`` is True.

    Returns:
        The wrapped procgen environment.
    """
    from procgen.env import ProcgenGym3Env, ToBaselinesVecEnv
    from gym3 import ViewerWrapper, ExtractDictObWrapper

    seed = config.seed(training=training)

    # Copy first: writing into hparams.make_kwargs directly would leak
    # render_mode/rand_seed into every later user of the same hparams.
    make_kwargs = dict(hparams.make_kwargs or {})
    make_kwargs["render_mode"] = "rgb_array"
    if seed is not None:
        make_kwargs["rand_seed"] = seed

    envs = ProcgenGym3Env(hparams.n_envs, config.env_id, **make_kwargs)
    envs = ExtractDictObWrapper(envs, key="rgb")
    if render:
        envs = ViewerWrapper(envs, info_key="rgb")
    envs = ToBaselinesVecEnv(envs)
    envs = IsVectorEnv(envs)
    # TODO: Handle Grayscale and/or FrameStack
    envs = TransposeImageObservation(envs)

    envs = gym.wrappers.RecordEpisodeStatistics(envs)

    if seed is not None:
        envs.action_space.seed(seed)
        envs.observation_space.seed(seed)

    if training:
        assert tb_writer
        envs = EpisodeStatsWriter(
            envs, tb_writer, training=training, rolling_length=hparams.rolling_length
        )
    if hparams.normalize and training:
        normalize_kwargs = hparams.normalize_kwargs or {}
        # TODO: Handle reward stats saving/loading/syncing, but it's only important
        # for checkpointing
        envs = gym.wrappers.NormalizeReward(envs)
        clip_reward = normalize_kwargs.get("clip_reward", 10.0)
        envs = gym.wrappers.TransformReward(
            envs, lambda r: np.clip(r, -clip_reward, clip_reward)
        )

    return envs
runner/evaluate.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+
4
+ from dataclasses import dataclass
5
+ from typing import NamedTuple, Optional
6
+
7
+ from runner.env import make_eval_env
8
+ from runner.config import Config, EnvHyperparams, RunArgs
9
+ from runner.running_utils import (
10
+ load_hyperparams,
11
+ set_seeds,
12
+ get_device,
13
+ make_policy,
14
+ )
15
+ from shared.callbacks.eval_callback import evaluate
16
+ from shared.policy.policy import Policy
17
+ from shared.stats import EpisodesStats
18
+
19
+
20
@dataclass
class EvalArgs(RunArgs):
    """Run arguments extended with the options used by ``evaluate_model``."""

    # Passed as ``render`` to both the eval environment and ``evaluate``.
    render: bool = True
    # Select the "-best" model directory/archive instead of the final one.
    best: bool = True
    # Passed as ``override_n_envs`` when building the eval environment.
    n_envs: Optional[int] = 1
    # Number of episodes to evaluate.
    n_episodes: int = 3
    # Explicit deterministic setting; None defers to eval_params.
    deterministic_eval: Optional[bool] = None
    # Suppress printing of returns during evaluation.
    no_print_returns: bool = False
    # When set, config and model are fetched from this Wandb run.
    wandb_run_path: Optional[str] = None
29
+
30
+
31
class Evaluation(NamedTuple):
    """Result triple returned by ``evaluate_model``."""

    # The loaded policy that was evaluated.
    policy: Policy
    # Statistics returned by the evaluation.
    stats: EpisodesStats
    # Configuration the evaluation ran with.
    config: Config
35
+
36
+
37
def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
    """Load a trained policy and evaluate it.

    With ``args.wandb_run_path`` set, the run's config replaces the local
    hyperparameters, ``args`` is updated in place from it (algo, env, seed,
    use_deterministic_algorithms), and the model archive is downloaded and
    unpacked into the downloaded-models directory. Otherwise the
    hyperparameters and model come from ``root_dir``.

    Args:
        args: Evaluation arguments; mutated in the Wandb case.
        root_dir: Project root used for hyperparameters and model paths.

    Returns:
        An ``Evaluation`` with the policy, the evaluation stats and the
        config.
    """
    use_best = args.best
    if args.wandb_run_path:
        import wandb

        run = wandb.Api().run(args.wandb_run_path)
        hyperparams = run.config

        args.algo = hyperparams["algo"]
        args.env = hyperparams["env"]
        args.seed = hyperparams.get("seed", None)
        args.use_deterministic_algorithms = hyperparams.get(
            "use_deterministic_algorithms", True
        )

        config = Config(args, hyperparams, root_dir)
        model_path = config.model_dir_path(best=use_best, downloaded=True)

        archive_name = config.model_dir_name(best=use_best, extension=".zip")
        run.file(archive_name).download()
        # Start from a clean directory so stale files cannot survive.
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        shutil.unpack_archive(archive_name, model_path)
        os.remove(archive_name)
    else:
        hyperparams = load_hyperparams(args.algo, args.env, root_dir)

        config = Config(args, hyperparams, root_dir)
        model_path = config.model_dir_path(best=use_best)

    print(args)

    set_seeds(args.seed, args.use_deterministic_algorithms)

    env = make_eval_env(
        config,
        EnvHyperparams(**config.env_hyperparams),
        override_n_envs=args.n_envs,
        render=args.render,
        normalize_load_path=model_path,
    )
    device = get_device(config.device, env)
    policy = make_policy(
        args.algo,
        env,
        device,
        load_path=model_path,
        **config.policy_hyperparams,
    ).eval()

    # An explicit command-line choice wins over the configured default.
    if args.deterministic_eval is not None:
        deterministic = args.deterministic_eval
    else:
        deterministic = config.eval_params.get("deterministic", True)

    stats = evaluate(
        env,
        policy,
        args.n_episodes,
        render=args.render,
        deterministic=deterministic,
        print_returns=not args.no_print_returns,
    )
    return Evaluation(policy, stats, config)
runner/running_utils.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gym
3
+ import json
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import os
7
+ import random
8
+ import torch
9
+ import torch.backends.cudnn
10
+ import yaml
11
+
12
+ from gym.spaces import Box, Discrete
13
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv
14
+ from torch.utils.tensorboard.writer import SummaryWriter
15
+ from typing import Dict, Optional, Type, Union
16
+
17
+ from runner.config import Hyperparams
18
+ from shared.algorithm import Algorithm
19
+ from shared.callbacks.eval_callback import EvalCallback
20
+ from shared.policy.policy import Policy
21
+
22
+ from dqn.dqn import DQN
23
+ from dqn.policy import DQNPolicy
24
+ from vpg.vpg import VanillaPolicyGradient
25
+ from vpg.policy import VPGActorCritic
26
+ from ppo.ppo import PPO
27
+ from ppo.policy import PPOActorCritic
28
+
29
+ ALGOS: Dict[str, Type[Algorithm]] = {
30
+ "dqn": DQN,
31
+ "vpg": VanillaPolicyGradient,
32
+ "ppo": PPO,
33
+ }
34
+ POLICIES: Dict[str, Type[Policy]] = {
35
+ "dqn": DQNPolicy,
36
+ "vpg": VPGActorCritic,
37
+ "ppo": PPOActorCritic,
38
+ }
39
+
40
+ HYPERPARAMS_PATH = "hyperparams"
41
+
42
+
43
def _str_to_bool(value: str) -> bool:
    """Convert a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value, including ``"False"``, becomes True. This parser
    accepts the usual spellings instead. The empty string stays False, as it
    was with ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognized boolean.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def base_parser(multiple: bool = True) -> argparse.ArgumentParser:
    """Create the argument parser shared by the command-line entry points.

    Args:
        multiple: When True, ``--algo`` and ``--env`` take one or more values
            and ``--seed`` takes any number. When False, ``--algo`` and
            ``--env`` take exactly one value (still parsed into a list) and
            ``--seed`` takes at most one.

    Returns:
        A parser with ``--algo``, ``--env``, ``--seed`` and
        ``--use-deterministic-algorithms``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--algo",
        default=["dqn"],
        type=str,
        choices=list(ALGOS.keys()),
        nargs="+" if multiple else 1,
        help="Abbreviation(s) of algorithm(s)",
    )
    parser.add_argument(
        "--env",
        default=["CartPole-v1"],
        type=str,
        nargs="+" if multiple else 1,
        help="Name of environment(s) in gym",
    )
    parser.add_argument(
        "--seed",
        default=[1],
        type=int,
        nargs="*" if multiple else "?",
        help="Seeds to run experiment. Unset will do one run with no set seed",
    )
    parser.add_argument(
        "--use-deterministic-algorithms",
        default=True,
        # Not ``bool``: that would turn "False" into True.
        type=_str_to_bool,
        help="If seed set, set torch.use_deterministic_algorithms",
    )
    return parser
74
+
75
+
76
def load_hyperparams(algo: str, env_id: str, root_path: str) -> Hyperparams:
    """Load the hyperparameters for an environment from ``<algo>.yml``.

    The entry keyed by ``env_id`` is returned when present. Otherwise, if the
    environment's gym spec has an entry point containing ``"AtariEnv"`` and
    the file has an ``"_atari"`` entry, that entry is returned.

    Args:
        algo: Algorithm abbreviation naming the YAML file.
        env_id: Environment id to look up.
        root_path: Directory containing the hyperparameters folder.

    Raises:
        ValueError: If neither lookup finds an entry.
    """
    yml_path = os.path.join(root_path, HYPERPARAMS_PATH, f"{algo}.yml")
    with open(yml_path, "r") as f:
        per_env = yaml.safe_load(f)

    if env_id in per_env:
        return per_env[env_id]

    if "BulletEnv" in env_id:
        # Imported only for its side effects before the spec lookup.
        import pybullet_envs
    entry_point = str(gym.spec(env_id).entry_point)
    if not ("AtariEnv" in entry_point and "_atari" in per_env):
        raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
    return per_env["_atari"]
91
+
92
+
93
def get_device(device: str, env: VecEnv) -> torch.device:
    """Resolve a device name to a ``torch.device`` and print the choice.

    ``"auto"`` means ``"cuda"``. ``"cuda"`` falls back to ``"mps"`` when CUDA
    is unavailable, and ``"mps"`` falls back to ``"cpu"`` when MPS is
    unavailable. ``"mps"`` is also replaced by ``"cpu"`` when the observation
    space is ``Discrete`` or a one-dimensional ``Box``.

    Args:
        device: Requested device name.
        env: Environment whose ``observation_space`` is inspected, but only
            when the choice is ``"mps"``.
    """
    chosen = "cuda" if device == "auto" else device
    # Apple MPS is a second choice (sometimes)
    if chosen == "cuda" and not torch.cuda.is_available():
        chosen = "mps"
    # If no MPS, fallback to cpu
    if chosen == "mps" and not torch.backends.mps.is_available():
        chosen = "cpu"
    # Simple environments like Discrete and 1-D Boxes might also be better
    # served with the CPU.
    if chosen == "mps":
        obs_space = env.observation_space
        is_flat_box = isinstance(obs_space, Box) and len(obs_space.shape) == 1
        if isinstance(obs_space, Discrete) or is_flat_box:
            chosen = "cpu"
    print(f"Device: {chosen}")
    return torch.device(chosen)
113
+
114
+
115
def set_seeds(seed: Optional[int], use_deterministic_algorithms: bool) -> None:
    """Seed Python, NumPy and PyTorch and set the determinism switches.

    Does nothing when ``seed`` is None.

    Args:
        seed: Seed applied to ``random``, ``numpy.random`` and ``torch``.
        use_deterministic_algorithms: Forwarded to
            ``torch.use_deterministic_algorithms``.
    """
    if seed is None:
        return
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(use_deterministic_algorithms)
    os.environ.update(
        {
            "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
            # Stop warning and it would introduce stochasticity if I was using TF
            "TF_ENABLE_ONEDNN_OPTS": "0",
        }
    )
126
+
127
+
128
def make_policy(
    algo: str,
    env: VecEnv,
    device: torch.device,
    load_path: Optional[str] = None,
    **kwargs,
) -> Policy:
    """Instantiate the policy class registered for ``algo`` on ``device``.

    Args:
        algo: Key into ``POLICIES``.
        env: Environment handed to the policy constructor.
        device: Device the policy is moved to.
        load_path: When truthy, passed to ``policy.load`` after construction.
        **kwargs: Extra keyword arguments for the policy constructor.
    """
    policy_cls = POLICIES[algo]
    policy = policy_cls(env, **kwargs).to(device)
    if load_path:
        policy.load(load_path)
    return policy
139
+
140
+
141
def plot_eval_callback(callback: EvalCallback, tb_writer: SummaryWriter, run_name: str):
    """Plot the evaluation score history and add the figure to TensorBoard.

    The x axis is ``(index + 1) * callback.step_freq`` for each recorded
    evaluation. The figure shows the mean, mean minus standard deviation, and
    the min-max range of the scores.

    Args:
        callback: Source of ``stats`` (per-evaluation scores) and ``step_freq``.
        tb_writer: Writer that receives the figure under the tag ``"eval"``.
        run_name: Name shown in the plot title.
    """
    figure = plt.figure()
    stats = callback.stats
    cumulative_steps = [n * callback.step_freq for n in range(1, len(stats) + 1)]
    means = [s.score.mean for s in stats]
    lower_band = [s.score.mean - s.score.std for s in stats]
    minimums = [s.score.min for s in stats]
    maximums = [s.score.max for s in stats]

    plt.plot(cumulative_steps, means, "b-", label="mean")
    plt.plot(cumulative_steps, lower_band, "g--", label="mean-std")
    plt.fill_between(
        cumulative_steps,
        minimums,  # type: ignore
        maximums,  # type: ignore
        facecolor="cyan",
        label="range",
    )
    plt.xlabel("Steps")
    plt.ylabel("Score")
    plt.legend()
    plt.title(f"Eval {run_name}")
    tb_writer.add_figure("eval", figure)
170
+
171
+
172
+ Scalar = Union[bool, str, float, int, None]
173
+
174
+
175
def hparam_dict(
    hyperparams: Hyperparams, args: Dict[str, Union[Scalar, list]]
) -> Dict[str, Scalar]:
    """Flatten run arguments and hyperparameters into one scalar-valued dict.

    List-valued arguments are JSON-encoded. Each dict-valued hyperparameter
    is expanded into ``"<key>/<subkey>"`` entries, whose dict or list values
    are converted with ``str``. Other hyperparameters are copied as they
    are. ``args`` itself is not modified.
    """
    flattened = {
        name: json.dumps(value) if isinstance(value, list) else value
        for name, value in args.items()
    }
    for group, group_value in hyperparams.items():
        if not isinstance(group_value, dict):
            flattened[group] = group_value  # type: ignore
            continue
        for sub_key, sub_value in group_value.items():
            nested = isinstance(sub_value, (dict, list))
            flattened[f"{group}/{sub_key}"] = str(sub_value) if nested else sub_value
    return flattened  # type: ignore
runner/train.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
2
+ import os
3
+
4
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
+
6
+ import dataclasses
7
+ import shutil
8
+ import wandb
9
+ import yaml
10
+
11
+ from dataclasses import dataclass
12
+ from torch.utils.tensorboard.writer import SummaryWriter
13
+ from typing import Any, Dict, Optional, Sequence
14
+
15
+ from shared.callbacks.eval_callback import EvalCallback
16
+ from runner.config import Config, EnvHyperparams, RunArgs
17
+ from runner.env import make_env, make_eval_env
18
+ from runner.running_utils import (
19
+ ALGOS,
20
+ load_hyperparams,
21
+ set_seeds,
22
+ get_device,
23
+ make_policy,
24
+ plot_eval_callback,
25
+ hparam_dict,
26
+ )
27
+ from shared.stats import EpisodesStats
28
+
29
+
30
@dataclass
class TrainArgs(RunArgs):
    """Run arguments extended with the Wandb options used by ``train``."""

    # Wandb project; Wandb logging is enabled only when this is set.
    wandb_project_name: Optional[str] = None
    # Wandb entity passed to wandb.init.
    wandb_entity: Optional[str] = None
    # Tags passed to wandb.init; each instance gets its own empty list.
    wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
35
+
36
+
37
def train(args: TrainArgs):
    """Train one algorithm on one environment and record the results.

    Steps: load hyperparameters from the current working directory,
    optionally start a Wandb run, build the environment, policy and
    algorithm, train with an evaluation callback, save the final policy, run
    a final 10-episode evaluation, then write the YAML log and TensorBoard
    hparams. With Wandb enabled, the model directories are archived into the
    Wandb run directory before the run is finished.

    Best-evaluation metrics and the best-model archive are written only when
    they exist. Previously a run whose callback never recorded a best
    evaluation crashed at the very end with ``AttributeError``, after
    training had already completed.

    Args:
        args: Training arguments (algo, env, seed and Wandb options).
    """
    print(args)
    hyperparams = load_hyperparams(args.algo, args.env, os.getcwd())
    print(hyperparams)
    config = Config(args, hyperparams, os.getcwd())

    wandb_enabled = args.wandb_project_name
    if wandb_enabled:
        wandb.tensorboard.patch(
            root_logdir=config.tensorboard_summary_path, pytorch=True
        )
        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            config=hyperparams,  # type: ignore
            name=config.run_name,
            monitor_gym=True,
            save_code=True,
            tags=args.wandb_tags,
        )
        wandb.config.update(args)

    tb_writer = SummaryWriter(config.tensorboard_summary_path)

    set_seeds(args.seed, args.use_deterministic_algorithms)

    env = make_env(
        config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
    )
    device = get_device(config.device, env)
    policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
    algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)

    num_parameters = policy.num_parameters()
    num_trainable_parameters = policy.num_trainable_parameters()
    if wandb_enabled:
        wandb.run.summary["num_parameters"] = num_parameters
        wandb.run.summary["num_trainable_parameters"] = num_trainable_parameters
    else:
        print(
            f"num_parameters = {num_parameters} ; "
            f"num_trainable_parameters = {num_trainable_parameters}"
        )

    eval_env = make_eval_env(config, EnvHyperparams(**config.env_hyperparams))
    record_best_videos = config.eval_params.get("record_best_videos", True)
    callback = EvalCallback(
        policy,
        eval_env,
        tb_writer,
        best_model_path=config.model_dir_path(best=True),
        **config.eval_params,
        video_env=make_eval_env(
            config, EnvHyperparams(**config.env_hyperparams), override_n_envs=1
        )
        if record_best_videos
        else None,
        best_video_dir=config.best_videos_dir,
    )
    algo.learn(config.n_timesteps, callback=callback)

    policy.save(config.model_dir_path(best=False))

    eval_stats = callback.evaluate(n_episodes=10, print_returns=True)

    plot_eval_callback(callback, tb_writer, config.run_name)

    best_eval_stats = callback.best

    log_dict: Dict[str, Any] = {
        "eval": eval_stats._asdict(),
    }
    if best_eval_stats:
        log_dict["best_eval"] = best_eval_stats._asdict()
    log_dict.update(hyperparams)
    log_dict.update(vars(args))
    with open(config.logs_path, "a") as f:
        yaml.dump({config.run_name: log_dict}, f)

    hparam_metrics: Dict[str, Any] = {
        "hparam/last_mean": eval_stats.score.mean,
        "hparam/last_result": eval_stats.score.mean - eval_stats.score.std,
    }
    # Same guard as for the YAML log above: best may never have been set.
    if best_eval_stats:
        hparam_metrics["hparam/best_mean"] = best_eval_stats.score.mean
        hparam_metrics["hparam/best_result"] = (
            best_eval_stats.score.mean - best_eval_stats.score.std
        )
    tb_writer.add_hparams(
        hparam_dict(hyperparams, vars(args)),
        hparam_metrics,
        None,
        config.run_name,
    )

    tb_writer.close()

    if wandb_enabled:
        shutil.make_archive(
            os.path.join(wandb.run.dir, config.model_dir_name()),
            "zip",
            config.model_dir_path(),
        )
        best_model_dir = config.model_dir_path(best=True)
        # Archiving a missing directory raises; skip it when no best model
        # was ever saved.
        if os.path.isdir(best_model_dir):
            shutil.make_archive(
                os.path.join(wandb.run.dir, config.model_dir_name(best=True)),
                "zip",
                best_model_dir,
            )
        wandb.finish()
saved_models/ppo-BipedalWalker-v3-S2-best/model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea5f9bf8495d29f3616d5ef7478ac49b3e0a0400a543781e080adf04256d4c98
3
+ size 51104
saved_models/ppo-BipedalWalker-v3-S2-best/vecnormalize.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea2f45021090e8ab4b74cf7dac0fc992f78d11765cf0734d738c157d04ddb86f
3
+ size 8776
shared/algorithm.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import torch
3
+
4
+ from abc import ABC, abstractmethod
5
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv
6
+ from torch.utils.tensorboard.writer import SummaryWriter
7
+ from typing import List, Optional, TypeVar
8
+
9
+ from shared.callbacks.callback import Callback
10
+ from shared.policy.policy import Policy
11
+ from shared.stats import EpisodesStats
12
+
13
# Self-type for Algorithm subclasses: lets `learn` be annotated as returning
# the same concrete subclass it was called on.
AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")
14
+
15
class Algorithm(ABC):
    """Abstract base class for learning algorithms.

    Holds the four collaborators every algorithm receives (policy, vectorized
    environment, torch device and TensorBoard writer). Both ``__init__`` and
    ``learn`` are abstract, so concrete algorithms must override each of them.
    """

    @abstractmethod
    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        **kwargs,
    ) -> None:
        """Store the shared collaborators on the instance.

        Args:
            policy: Policy object kept as ``self.policy``.
            env: Vectorized environment kept as ``self.env``.
            device: Torch device kept as ``self.device``.
            tb_writer: TensorBoard writer kept as ``self.tb_writer``.
            **kwargs: Accepted for subclass-specific options; unused here.
        """
        super().__init__()
        self.policy, self.env = policy, env
        self.device, self.tb_writer = device, tb_writer

    @abstractmethod
    def learn(
        self: AlgorithmSelf, total_timesteps: int, callback: Optional[Callback] = None
    ) -> AlgorithmSelf:
        """Run training; implemented by subclasses.

        Args:
            total_timesteps: Timestep budget passed by the caller.
            callback: Optional callback supplied by the caller.

        Returns:
            An instance of the same concrete type (per the annotation).
        """
        ...
shared/callbacks/callback.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
class Callback(ABC):
    """Base callback that keeps a running count of elapsed timesteps."""

    def __init__(self) -> None:
        super().__init__()
        # Total of every increment reported through on_step so far.
        self.timesteps_elapsed = 0

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        """Add ``timesteps_elapsed`` to the running total.

        Returns:
            Always ``True``. NOTE(review): presumably a "keep going" signal
            for the caller -- confirm against the algorithms that use it.
        """
        self.timesteps_elapsed = self.timesteps_elapsed + timesteps_elapsed
        return True
shared/callbacks/eval_callback.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import numpy as np
3
+ import os
4
+
5
+ from copy import deepcopy
6
+ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
7
+ from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
8
+ from time import perf_counter
9
+ from torch.utils.tensorboard.writer import SummaryWriter
10
+ from typing import List, Optional, Union
11
+
12
+ from shared.callbacks.callback import Callback
13
+ from shared.policy.policy import Policy
14
+ from shared.stats import Episode, EpisodeAccumulator, EpisodesStats
15
+ from wrappers.vec_episode_recorder import VecEpisodeRecorder
16
+
17
+
18
class EvaluateAccumulator(EpisodeAccumulator):
    """Collects finished episodes per environment until each has its quota.

    Every environment index is asked for ``ceil(goal_episodes / num_envs)``
    episodes, so the total collected can exceed ``goal_episodes`` when it is
    not a multiple of ``num_envs``.
    """

    def __init__(
        self,
        num_envs: int,
        goal_episodes: int,
        print_returns: bool = True,
        ignore_first_episode: bool = False,
    ):
        """
        Args:
            num_envs: Number of environment slots being tracked.
            goal_episodes: Requested total number of episodes.
            print_returns: Print a line for every recorded episode.
            ignore_first_episode: Discard the first finished episode of each
                environment index instead of recording it.
        """
        super().__init__(num_envs)
        self.completed_episodes_by_env_idx = [[] for _ in range(num_envs)]
        self.goal_episodes_per_env = int(np.ceil(goal_episodes / num_envs))
        self.print_returns = print_returns
        if ignore_first_episode:
            # Environment indices whose first episode has already finished.
            first_done = set()

            def should_record_done(idx: int) -> bool:
                has_done_first_episode = idx in first_done
                first_done.add(idx)
                return has_done_first_episode

            self.should_record_done = should_record_done
        else:
            self.should_record_done = lambda idx: True

    def on_done(self, ep_idx: int, episode: Episode) -> None:
        """Record a finished episode for environment ``ep_idx`` if it counts.

        NOTE(review): presumably invoked by ``EpisodeAccumulator`` when an
        environment reports ``done`` -- confirm against ``shared.stats``.
        """
        # An episode the filter rejects (the ignored first episode) must be
        # dropped. The previous `A and B` test only returned early when the
        # quota was also full, so the "ignored" first episode was recorded
        # anyway.
        if not self.should_record_done(ep_idx):
            return
        recorded = self.completed_episodes_by_env_idx[ep_idx]
        if len(recorded) >= self.goal_episodes_per_env:
            # This environment already has its quota; drop extra episodes.
            return
        recorded.append(episode)
        if self.print_returns:
            print(
                f"Episode {len(self)} | "
                f"Score {episode.score} | "
                f"Length {episode.length}"
            )

    def __len__(self) -> int:
        """Total number of recorded episodes across all environments."""
        return sum(len(ce) for ce in self.completed_episodes_by_env_idx)

    @property
    def episodes(self) -> List[Episode]:
        """All recorded episodes, flattened in environment-index order."""
        return list(itertools.chain(*self.completed_episodes_by_env_idx))

    def is_done(self) -> bool:
        """True once every environment has exactly its per-env quota."""
        return all(
            len(ce) == self.goal_episodes_per_env
            for ce in self.completed_episodes_by_env_idx
        )
69
+
70
+
71
def evaluate(
    env: VecEnv,
    policy: Policy,
    n_episodes: int,
    render: bool = False,
    deterministic: bool = True,
    print_returns: bool = True,
    ignore_first_episode: bool = False,
) -> EpisodesStats:
    """Roll out ``policy`` in ``env`` and summarize the finished episodes.

    The policy is switched to eval mode and is NOT switched back here.

    Args:
        env: Vectorized environment to step.
        policy: Policy queried through ``policy.act``.
        n_episodes: Requested number of episodes (see ``EvaluateAccumulator``
            for how this is split across environments).
        render: Call ``env.render()`` after every step.
        deterministic: Forwarded to ``policy.act``.
        print_returns: Print per-episode lines and the final stats.
        ignore_first_episode: Forwarded to ``EvaluateAccumulator``.

    Returns:
        ``EpisodesStats`` built from the collected episodes.
    """
    policy.eval()
    accumulator = EvaluateAccumulator(
        env.num_envs, n_episodes, print_returns, ignore_first_episode
    )

    observation = env.reset()
    while not accumulator.is_done():
        action = policy.act(observation, deterministic=deterministic)
        observation, reward, finished, _ = env.step(action)
        accumulator.step(reward, finished)
        if render:
            env.render()

    summary = EpisodesStats(accumulator.episodes)
    if print_returns:
        print(summary)
    return summary
96
+
97
+
98
class EvalCallback(Callback):
    """Callback that evaluates the policy every ``step_freq`` timesteps.

    Each evaluation is appended to ``self.stats``. The best result so far is
    kept in ``self.best``; on a new best the policy can be saved and, when the
    result is strictly better, a video of one episode can be recorded.
    """

    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        tb_writer: SummaryWriter,
        best_model_path: Optional[str] = None,
        step_freq: Union[int, float] = 50_000,
        n_episodes: int = 10,
        save_best: bool = True,
        deterministic: bool = True,
        record_best_videos: bool = True,
        video_env: Optional[VecEnv] = None,
        best_video_dir: Optional[str] = None,
        max_video_length: int = 3600,
        ignore_first_episode: bool = False,
    ) -> None:
        """
        Args:
            policy: Policy to evaluate (and save when it is the best).
            env: Evaluation environment.
            tb_writer: TensorBoard writer for evaluation scalars.
            best_model_path: Where the best policy is saved; must be set when
                ``save_best`` is true and a best result is found.
            step_freq: Timesteps between evaluations (truncated to ``int``).
            n_episodes: Default episode count per evaluation.
            save_best: Save the policy whenever a new best is reached.
            deterministic: Forwarded to the module-level ``evaluate``.
            record_best_videos: Record a video on strictly better results;
                requires ``video_env`` and ``best_video_dir``.
            video_env: Environment used for video recording.
            best_video_dir: Directory for videos; created if given.
            max_video_length: Forwarded to ``VecEpisodeRecorder``.
            ignore_first_episode: Forwarded to the module-level ``evaluate``.
        """
        super().__init__()
        self.policy = policy
        self.env = env
        self.tb_writer = tb_writer
        self.best_model_path = best_model_path
        self.step_freq = int(step_freq)
        self.n_episodes = n_episodes
        self.save_best = save_best
        self.deterministic = deterministic
        self.stats: List[EpisodesStats] = []
        self.best = None

        self.record_best_videos = record_best_videos
        # Video recording needs both an environment and an output directory.
        assert video_env or not record_best_videos
        assert best_video_dir or not record_best_videos
        self.video_env = video_env
        self.best_video_dir = best_video_dir
        if best_video_dir:
            os.makedirs(best_video_dir, exist_ok=True)
        self.max_video_length = max_video_length
        self.best_video_base_path = None

        self.ignore_first_episode = ignore_first_episode

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        """Advance the timestep counter and evaluate when one is due.

        An evaluation is due while the number of whole ``step_freq`` intervals
        elapsed is at least the number of evaluations already stored, so the
        first call always evaluates.
        """
        super().on_step(timesteps_elapsed)
        intervals_elapsed = self.timesteps_elapsed // self.step_freq
        if intervals_elapsed >= len(self.stats):
            sync_vec_normalize(self.policy.vec_normalize, self.env)
            self.evaluate()
        return True

    def evaluate(
        self, n_episodes: Optional[int] = None, print_returns: Optional[bool] = None
    ) -> EpisodesStats:
        """Run one evaluation, log it, and update best-model bookkeeping.

        Args:
            n_episodes: Episode count; falls back to ``self.n_episodes`` when
                falsy.
            print_returns: Print per-episode output; falsy means no output.

        Returns:
            The stats of this evaluation (also appended to ``self.stats``).
        """
        started = perf_counter()
        eval_stat = evaluate(
            self.env,
            self.policy,
            n_episodes or self.n_episodes,
            deterministic=self.deterministic,
            print_returns=print_returns or False,
            ignore_first_episode=self.ignore_first_episode,
        )
        elapsed = perf_counter() - started
        self.tb_writer.add_scalar(
            "eval/steps_per_second",
            eval_stat.length.sum() / elapsed,
            self.timesteps_elapsed,
        )
        # The module-level evaluate() left the policy in eval mode.
        self.policy.train(True)
        print(f"Eval Timesteps: {self.timesteps_elapsed} | {eval_stat}")

        self.stats.append(eval_stat)

        if not self.best or eval_stat >= self.best:
            # Must be decided before self.best is overwritten below.
            strictly_better = not self.best or eval_stat > self.best
            self.best = eval_stat
            if self.save_best:
                assert self.best_model_path
                self.policy.save(self.best_model_path)
                print("Saved best model")
            self.best.write_to_tensorboard(
                self.tb_writer, "best_eval", self.timesteps_elapsed
            )
            if strictly_better and self.record_best_videos:
                self._record_best_video()

        eval_stat.write_to_tensorboard(self.tb_writer, "eval", self.timesteps_elapsed)

        return eval_stat

    def _record_best_video(self) -> None:
        """Record one episode on ``video_env`` under the current timestep."""
        assert self.video_env and self.best_video_dir
        sync_vec_normalize(self.policy.vec_normalize, self.video_env)
        self.best_video_base_path = os.path.join(
            self.best_video_dir, str(self.timesteps_elapsed)
        )
        recorder = VecEpisodeRecorder(
            self.video_env,
            self.best_video_base_path,
            max_video_length=self.max_video_length,
        )
        video_stats = evaluate(
            recorder,
            self.policy,
            1,
            deterministic=self.deterministic,
            print_returns=False,
        )
        print(f"Saved best video: {video_stats}")
202
+
203
+
204
def sync_vec_normalize(
    origin_vec_normalize: Optional[VecNormalize], destination_env: VecEnv
) -> None:
    """Copy normalization statistics into every VecNormalize in a wrapper chain.

    Walks ``destination_env`` through its ``.venv`` links for as long as the
    current object is a ``VecEnvWrapper``. Each ``VecNormalize`` found gets a
    deep copy of the origin's ``ret_rms`` and, when the origin has one, of its
    ``obs_rms``. Does nothing when ``origin_vec_normalize`` is ``None``.
    """
    if origin_vec_normalize is None:
        return

    current = destination_env
    while isinstance(current, VecEnvWrapper):
        if isinstance(current, VecNormalize):
            # obs_rms is only copied when the origin actually has one.
            if hasattr(origin_vec_normalize, "obs_rms"):
                current.obs_rms = deepcopy(origin_vec_normalize.obs_rms)
            current.ret_rms = deepcopy(origin_vec_normalize.ret_rms)
        current = current.venv
shared/gae.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ from typing import NamedTuple, Sequence
5
+
6
+ from shared.policy.on_policy import OnPolicy
7
+ from shared.trajectory import Trajectory
8
+
9
+
10
class RtgAdvantage(NamedTuple):
    """Pair of tensors returned by ``compute_rtg_and_advantage``."""

    # Rewards-to-go, computed there as value estimates plus advantages.
    rewards_to_go: torch.Tensor
    # Advantage estimates for the same timesteps.
    advantage: torch.Tensor
13
+
14
+
15
def discounted_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
    """Return the reverse discounted cumulative sum of ``x``.

    Computes ``out[i] = x[i] + gamma * out[i + 1]`` with ``out[-1] = x[-1]``.
    The input is never modified.

    Args:
        x: Sequence of values (array-like; indexed along its first axis).
        gamma: Discount factor applied per step.

    Returns:
        A new array of the same shape. Floating and complex dtypes are
        preserved; any other dtype (e.g. integers) is promoted to float64 so
        the discounted terms are not truncated on in-place accumulation.
    """
    values = np.asarray(x)
    if np.issubdtype(values.dtype, np.inexact):
        dc = values.copy()
    else:
        # An integer buffer would silently floor every gamma-weighted term.
        dc = values.astype(np.float64)
    for i in reversed(range(len(dc) - 1)):
        dc[i] += gamma * dc[i + 1]
    return dc
20
+
21
+
22
def compute_advantage(
    trajectories: Sequence[Trajectory],
    policy: OnPolicy,
    gamma: float,
    gae_lambda: float,
    device: torch.device,
) -> torch.Tensor:
    """Compute GAE advantages for every trajectory and concatenate them.

    For each trajectory the TD errors ``r_t + gamma * V_{t+1} - V_t`` are
    accumulated backwards with factor ``gamma * gae_lambda``. The value after
    the last step is ``policy.value(traj.next_obs)`` when the trajectory is
    not terminated and has a ``next_obs``; otherwise it is 0.

    Returns:
        A float32 tensor on ``device`` holding all advantages in trajectory
        order.
    """
    decay = gamma * gae_lambda
    per_trajectory = []
    for traj in trajectories:
        if traj.terminated or traj.next_obs is None:
            bootstrap = 0
        else:
            bootstrap = policy.value(traj.next_obs)
        rewards = np.append(np.array(traj.rew), bootstrap)
        values = np.append(np.array(traj.v), bootstrap)
        td_errors = rewards[:-1] + gamma * values[1:] - values[:-1]
        per_trajectory.append(discounted_cumsum(td_errors, decay))
    stacked = np.concatenate(per_trajectory)
    return torch.as_tensor(stacked, dtype=torch.float32, device=device)
41
+
42
+
43
def compute_rtg_and_advantage(
    trajectories: Sequence[Trajectory],
    policy: OnPolicy,
    gamma: float,
    gae_lambda: float,
    device: torch.device,
) -> RtgAdvantage:
    """Compute rewards-to-go and GAE advantages for every trajectory.

    Advantages are the backward accumulation (factor ``gamma * gae_lambda``)
    of the TD errors ``r_t + gamma * V_{t+1} - V_t``; rewards-to-go are the
    stored value estimates plus those advantages. The value after the last
    step is ``policy.value(traj.next_obs)`` when the trajectory is not
    terminated and has a ``next_obs``; otherwise it is 0.

    Returns:
        ``RtgAdvantage`` of two float32 tensors on ``device``, each the
        concatenation over trajectories in order.
    """
    decay = gamma * gae_lambda
    rtg_chunks = []
    adv_chunks = []
    for traj in trajectories:
        if traj.terminated or traj.next_obs is None:
            bootstrap = 0
        else:
            bootstrap = policy.value(traj.next_obs)
        rewards = np.append(np.array(traj.rew), bootstrap)
        values = np.append(np.array(traj.v), bootstrap)
        td_errors = rewards[:-1] + gamma * values[1:] - values[:-1]
        trajectory_adv = discounted_cumsum(td_errors, decay)
        adv_chunks.append(trajectory_adv)
        rtg_chunks.append(values[:-1] + trajectory_adv)

    rtg_tensor = torch.as_tensor(
        np.concatenate(rtg_chunks), dtype=torch.float32, device=device
    )
    adv_tensor = torch.as_tensor(
        np.concatenate(adv_chunks), dtype=torch.float32, device=device
    )
    return RtgAdvantage(rtg_tensor, adv_tensor)
shared/module/feature_extractor.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from abc import ABC, abstractmethod
7
+ from gym.spaces import Box, Discrete
8
+ from stable_baselines3.common.preprocessing import get_flattened_obs_dim
9
+ from typing import Dict, Optional, Sequence, Type
10
+
11
+ from shared.module.module import layer_init
12
+
13
+
14
class CnnFeatureExtractor(nn.Module, ABC):
    """Abstract base class fixing the constructor signature of CNN extractors.

    Cannot be instantiated directly because ``__init__`` is abstract.
    """

    @abstractmethod
    def __init__(
        self,
        in_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: Optional[bool] = None,
        **kwargs,
    ) -> None:
        """Initialize the underlying ``nn.Module``; arguments are not stored.

        Args:
            in_channels: Number of input channels.
            activation: Activation module class.
            init_layers_orthogonal: Layer-initialization flag; ``None`` lets
                the subclass choose its own default.
            **kwargs: Extra options that individual subclasses may consume.
        """
        super().__init__()
24
+
25
+
26
class NatureCnn(CnnFeatureExtractor):
    """
    CNN from DQN Nature paper: Mnih, Volodymyr, et al.
    "Human-level control through deep reinforcement learning."
    Nature 518.7540 (2015): 529-533.

    Three convolutions, each followed by the activation, then a flatten.
    """

    def __init__(
        self,
        in_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: Optional[bool] = None,
        **kwargs,
    ) -> None:
        """Build the conv stack; ``init_layers_orthogonal`` defaults to True."""
        if init_layers_orthogonal is None:
            init_layers_orthogonal = True
        super().__init__(in_channels, activation, init_layers_orthogonal)

        # (in_channels, out_channels, kernel_size, stride) for each conv layer.
        conv_specs = (
            (in_channels, 32, 8, 4),
            (32, 64, 4, 2),
            (64, 64, 3, 1),
        )
        layers = []
        for n_in, n_out, kernel, stride in conv_specs:
            conv = nn.Conv2d(n_in, n_out, kernel_size=kernel, stride=stride)
            layers.append(layer_init(conv, init_layers_orthogonal))
            layers.append(activation())
        layers.append(nn.Flatten())
        self.cnn = nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Apply the conv stack and return the flattened features."""
        return self.cnn(obs)
64
+
65
+
66
class ResidualBlock(nn.Module):
    """Residual block: input plus (activation -> conv -> activation -> conv).

    Both convolutions are 3x3 with padding 1 and keep the channel count, so
    the residual branch can be added to the input.
    """

    def __init__(
        self,
        channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: bool = False,
    ) -> None:
        super().__init__()

        def make_conv():
            # Channel-preserving 3x3 convolution passed through layer_init.
            return layer_init(
                nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
            )

        branch = [activation(), make_conv(), activation(), make_conv()]
        self.residual = nn.Sequential(*branch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the residual branch."""
        shortcut = x
        return shortcut + self.residual(x)
87
+
88
+
89
class ConvSequence(nn.Module):
    """Conv (3x3, padding 1) -> max-pool (3, stride 2) -> two residual blocks."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: bool = False,
    ) -> None:
        super().__init__()
        entry_conv = layer_init(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            init_layers_orthogonal,
        )
        downsample = nn.MaxPool2d(3, stride=2, padding=1)
        blocks = [
            ResidualBlock(out_channels, activation, init_layers_orthogonal)
            for _ in range(2)
        ]
        self.seq = nn.Sequential(entry_conv, downsample, *blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the conv, pooling and residual stages."""
        return self.seq(x)
110
+
111
+
112
class ImpalaCnn(CnnFeatureExtractor):
    """
    IMPALA-style CNN architecture

    One ``ConvSequence`` per entry of ``impala_channels`` (each consuming the
    previous stage's channel count), followed by an activation and a flatten.
    """

    def __init__(
        self,
        in_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: Optional[bool] = None,
        impala_channels: Sequence[int] = (16, 32, 32),
        **kwargs,
    ) -> None:
        """Build the stages; ``init_layers_orthogonal`` defaults to False."""
        if init_layers_orthogonal is None:
            init_layers_orthogonal = False
        super().__init__(in_channels, activation, init_layers_orthogonal)

        stages = []
        previous_channels = in_channels
        for stage_channels in impala_channels:
            stage = ConvSequence(
                previous_channels, stage_channels, activation, init_layers_orthogonal
            )
            stages.append(stage)
            previous_channels = stage_channels
        self.seq = nn.Sequential(*stages, activation(), nn.Flatten())

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Apply all stages and return the flattened features."""
        return self.seq(obs)
146
+
147
+
148
# Maps the `cnn_style` string accepted by FeatureExtractor to the CNN class
# that implements it. An unknown style raises KeyError at lookup time.
CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnFeatureExtractor]] = {
    "nature": NatureCnn,
    "impala": ImpalaCnn,
}
152
+
153
+
154
class FeatureExtractor(nn.Module):
    """Turns raw observations into a flat feature tensor of width ``out_dim``.

    Supported observation spaces:
      * ``Box`` with a 3-element shape: values are cast to float and divided
        by 255, then passed through a CNN selected by ``cnn_style`` plus a
        linear layer of width ``cnn_feature_dim`` and an activation.
      * ``Box`` with a 1-element shape: cast to float and flattened.
      * ``Discrete``: one-hot encoded, cast to float and flattened.
    """

    def __init__(
        self,
        obs_space: gym.Space,
        activation: Type[nn.Module],
        init_layers_orthogonal: bool = False,
        cnn_feature_dim: int = 512,
        cnn_style: str = "nature",
        cnn_layers_init_orthogonal: Optional[bool] = None,
        impala_channels: Sequence[int] = (16, 32, 32),
    ) -> None:
        """
        Args:
            obs_space: Observation space that decides which pipeline is built.
            activation: Activation module class.
            init_layers_orthogonal: Flag passed to ``layer_init`` for the
                linear layer after the CNN.
            cnn_feature_dim: Output width of that linear layer.
            cnn_style: Key into ``CNN_EXTRACTORS_BY_STYLE``.
            cnn_layers_init_orthogonal: Initialization flag forwarded to the
                CNN; ``None`` lets the CNN class pick its default.
            impala_channels: Forwarded to the CNN constructor.

        Raises:
            ValueError: ``Box`` space whose shape has neither 1 nor 3 entries.
            NotImplementedError: Any other kind of space.
        """
        super().__init__()
        if isinstance(obs_space, Box):
            n_dims = len(obs_space.shape)
            # Conv2D: (channels, height, width)
            if n_dims == 3:
                cnn_cls = CNN_EXTRACTORS_BY_STYLE[cnn_style]
                cnn = cnn_cls(
                    obs_space.shape[0],
                    activation,
                    init_layers_orthogonal=cnn_layers_init_orthogonal,
                    impala_channels=impala_channels,
                )

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    # A single unbatched image gets a leading batch axis.
                    batched = obs.unsqueeze(0) if len(obs.shape) == 3 else obs
                    return batched.float() / 255.0

                # Probe the CNN with a sampled observation to learn the width
                # of its flattened output.
                with torch.no_grad():
                    probe = preprocess(torch.as_tensor(obs_space.sample()))
                    cnn_out = cnn(probe)
                self.preprocess = preprocess
                projection = layer_init(
                    nn.Linear(cnn_out.shape[1], cnn_feature_dim),
                    init_layers_orthogonal,
                )
                self.feature_extractor = nn.Sequential(cnn, projection, activation())
                self.out_dim = cnn_feature_dim
            elif n_dims == 1:

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    # A single unbatched vector gets a leading batch axis.
                    batched = obs.unsqueeze(0) if len(obs.shape) == 1 else obs
                    return batched.float()

                self.preprocess = preprocess
                self.feature_extractor = nn.Flatten()
                self.out_dim = get_flattened_obs_dim(obs_space)
            else:
                raise ValueError(f"Unsupported observation space: {obs_space}")
        elif isinstance(obs_space, Discrete):

            def one_hot(x: torch.Tensor) -> torch.Tensor:
                return F.one_hot(x, obs_space.n).float()

            self.preprocess = one_hot
            self.feature_extractor = nn.Flatten()
            self.out_dim = obs_space.n
        else:
            raise NotImplementedError

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Preprocess ``obs`` (when a preprocessor is set) and extract features."""
        prepared = self.preprocess(obs) if self.preprocess else obs
        return self.feature_extractor(prepared)