sgoodfriend commited on
Commit
cb23ed3
1 Parent(s): 8bf4dee

VPG playing MountainCarContinuous-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. benchmarks/benchmark_test.sh +0 -32
  2. benchmarks/colab_atari1.sh +0 -5
  3. benchmarks/colab_atari2.sh +0 -5
  4. benchmarks/colab_basic.sh +0 -5
  5. benchmarks/colab_benchmark.ipynb +0 -195
  6. benchmarks/colab_carracing.sh +0 -5
  7. benchmarks/colab_pybullet.sh +0 -5
  8. benchmarks/train_loop.sh +0 -15
  9. colab_enjoy.ipynb +0 -198
  10. colab_requirements.txt +0 -14
  11. colab_train.ipynb +0 -200
  12. dqn/dqn.py +0 -182
  13. dqn/policy.py +0 -45
  14. dqn/q_net.py +0 -39
  15. hf-deep-rl/dqn_SpaceInvadersNoFrameskip_v4.ipynb +0 -0
  16. hyperparams/dqn.yml +0 -130
  17. hyperparams/ppo.yml +0 -334
  18. hyperparams/vpg.yml +0 -176
  19. lambda_labs/benchmark.sh +0 -32
  20. lambda_labs/impala_atari_benchmark.sh +0 -19
  21. lambda_labs/lambda_requirements.txt +0 -16
  22. lambda_labs/procgen_benchmark.sh +0 -18
  23. lambda_labs/setup.sh +0 -10
  24. lambda_labs/starpilot_hard_benchmark.sh +0 -18
  25. poetry.lock +0 -0
  26. ppo/policy.py +0 -31
  27. ppo/ppo.py +0 -311
  28. publish/markdown_format.py +0 -210
  29. replay.meta.json +1 -1
  30. rl_algo_impls/benchmark_publish.py +2 -2
  31. rl_algo_impls/huggingface_publish.py +1 -0
  32. runner/config.py +0 -154
  33. runner/env.py +0 -256
  34. runner/evaluate.py +0 -103
  35. runner/running_utils.py +0 -192
  36. runner/train.py +0 -134
  37. saved_models/vpg-MountainCarContinuous-v0-S3-best/vecnormalize.pkl +0 -3
  38. shared/algorithm.py +0 -35
  39. shared/callbacks/callback.py +0 -12
  40. shared/callbacks/eval_callback.py +0 -206
  41. shared/gae.py +0 -67
  42. shared/module/feature_extractor.py +0 -215
  43. shared/module/module.py +0 -40
  44. shared/policy/actor.py +0 -305
  45. shared/policy/critic.py +0 -28
  46. shared/policy/on_policy.py +0 -201
  47. shared/policy/policy.py +0 -74
  48. shared/schedule.py +0 -19
  49. shared/stats.py +0 -153
  50. shared/trajectory.py +0 -80
benchmarks/benchmark_test.sh DELETED
@@ -1,32 +0,0 @@
1
- source benchmarks/train_loop.sh
2
-
3
- export WANDB_PROJECT_NAME="rl-algo-impls"
4
-
5
- BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
6
-
7
- ALGOS=(
8
- # "vpg"
9
- "dqn"
10
- # "ppo"
11
- )
12
- ENVS=(
13
- # Basic
14
- "CartPole-v1"
15
- "MountainCar-v0"
16
- # "MountainCarContinuous-v0"
17
- "Acrobot-v1"
18
- "LunarLander-v2"
19
- # # PyBullet
20
- # "HalfCheetahBulletEnv-v0"
21
- # "AntBulletEnv-v0"
22
- # "HopperBulletEnv-v0"
23
- # "Walker2DBulletEnv-v0"
24
- # # CarRacing
25
- # "CarRacing-v0"
26
- # Atari
27
- "PongNoFrameskip-v4"
28
- "BreakoutNoFrameskip-v4"
29
- "SpaceInvadersNoFrameskip-v4"
30
- "QbertNoFrameskip-v4"
31
- )
32
- train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmarks/colab_atari1.sh DELETED
@@ -1,5 +0,0 @@
1
- source benchmarks/train_loop.sh
2
- ALGOS="ppo"
3
- ENVS="PongNoFrameskip-v4 BreakoutNoFrameskip-v4"
4
- BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
- train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
 
 
 
 
 
 
benchmarks/colab_atari2.sh DELETED
@@ -1,5 +0,0 @@
1
- source benchmarks/train_loop.sh
2
- ALGOS="ppo"
3
- ENVS="SpaceInvadersNoFrameskip-v4 QbertNoFrameskip-v4"
4
- BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
- train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
 
 
 
 
 
 
benchmarks/colab_basic.sh DELETED
@@ -1,5 +0,0 @@
1
- source benchmarks/train_loop.sh
2
- ALGOS="ppo"
3
- ENVS="CartPole-v1 MountainCar-v0 MountainCarContinuous-v0 Acrobot-v1 LunarLander-v2"
4
- BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
- train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
 
 
 
 
 
 
benchmarks/colab_benchmark.ipynb DELETED
@@ -1,195 +0,0 @@
1
- {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "colab": {
6
- "provenance": [],
7
- "machine_shape": "hm",
8
- "authorship_tag": "ABX9TyOGIH7rqgasim3Sz7b1rpoE",
9
- "include_colab_link": true
10
- },
11
- "kernelspec": {
12
- "name": "python3",
13
- "display_name": "Python 3"
14
- },
15
- "language_info": {
16
- "name": "python"
17
- },
18
- "gpuClass": "standard",
19
- "accelerator": "GPU"
20
- },
21
- "cells": [
22
- {
23
- "cell_type": "markdown",
24
- "metadata": {
25
- "id": "view-in-github",
26
- "colab_type": "text"
27
- },
28
- "source": [
29
- "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/benchmarks/colab_benchmark.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
- ]
31
- },
32
- {
33
- "cell_type": "markdown",
34
- "source": [
35
- "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
- "## Parameters\n",
37
- "\n",
38
- "\n",
39
- "1. Wandb\n",
40
- "\n"
41
- ],
42
- "metadata": {
43
- "id": "S-tXDWP8WTLc"
44
- }
45
- },
46
- {
47
- "cell_type": "code",
48
- "source": [
49
- "from getpass import getpass\n",
50
- "import os\n",
51
- "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
- ],
53
- "metadata": {
54
- "id": "1ZtdYgxWNGwZ"
55
- },
56
- "execution_count": null,
57
- "outputs": []
58
- },
59
- {
60
- "cell_type": "markdown",
61
- "source": [
62
- "## Setup\n",
63
- "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
64
- ],
65
- "metadata": {
66
- "id": "bsG35Io0hmKG"
67
- }
68
- },
69
- {
70
- "cell_type": "code",
71
- "source": [
72
- "%%capture\n",
73
- "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
74
- ],
75
- "metadata": {
76
- "id": "k5ynTV25hdAf"
77
- },
78
- "execution_count": null,
79
- "outputs": []
80
- },
81
- {
82
- "cell_type": "markdown",
83
- "source": [
84
- "Installing the correct packages:\n",
85
- "\n",
86
- "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
87
- ],
88
- "metadata": {
89
- "id": "jKxGok-ElYQ7"
90
- }
91
- },
92
- {
93
- "cell_type": "code",
94
- "source": [
95
- "%%capture\n",
96
- "!apt install python-opengl\n",
97
- "!apt install ffmpeg\n",
98
- "!apt install xvfb\n",
99
- "!apt install swig"
100
- ],
101
- "metadata": {
102
- "id": "nn6EETTc2Ewf"
103
- },
104
- "execution_count": null,
105
- "outputs": []
106
- },
107
- {
108
- "cell_type": "code",
109
- "source": [
110
- "%%capture\n",
111
- "%cd /content/rl-algo-impls\n",
112
- "!pip install -r colab_requirements.txt"
113
- ],
114
- "metadata": {
115
- "id": "AfZh9rH3yQii"
116
- },
117
- "execution_count": null,
118
- "outputs": []
119
- },
120
- {
121
- "cell_type": "markdown",
122
- "source": [
123
- "## Run Once Per Runtime"
124
- ],
125
- "metadata": {
126
- "id": "4o5HOLjc4wq7"
127
- }
128
- },
129
- {
130
- "cell_type": "code",
131
- "source": [
132
- "import wandb\n",
133
- "wandb.login()"
134
- ],
135
- "metadata": {
136
- "id": "PCXa5tdS2qFX"
137
- },
138
- "execution_count": null,
139
- "outputs": []
140
- },
141
- {
142
- "cell_type": "markdown",
143
- "source": [
144
- "## Restart Session beteween runs"
145
- ],
146
- "metadata": {
147
- "id": "AZBZfSUV43JQ"
148
- }
149
- },
150
- {
151
- "cell_type": "code",
152
- "source": [
153
- "%%capture\n",
154
- "from pyvirtualdisplay import Display\n",
155
- "\n",
156
- "virtual_display = Display(visible=0, size=(1400, 900))\n",
157
- "virtual_display.start()"
158
- ],
159
- "metadata": {
160
- "id": "VzemeQJP2NO9"
161
- },
162
- "execution_count": null,
163
- "outputs": []
164
- },
165
- {
166
- "cell_type": "markdown",
167
- "source": [
168
- "The below 5 bash scripts train agents on environments with 3 seeds each:\n",
169
- "- colab_basic.sh and colab_pybullet.sh test on a set of basic gym environments and 4 PyBullet environments. Running both together will likely take about 18 hours. This is likely to run into runtime limits for free Colab and Colab Pro, but is fine for Colab Pro+.\n",
170
- "- colab_carracing.sh only trains 3 seeds on CarRacing-v0, which takes almost 22 hours on Colab Pro+ on high-RAM, standard GPU.\n",
171
- "- colab_atari1.sh and colab_atari2.sh likely need to be run separately because each takes about 19 hours on high-RAM, standard GPU."
172
- ],
173
- "metadata": {
174
- "id": "nSHfna0hLlO1"
175
- }
176
- },
177
- {
178
- "cell_type": "code",
179
- "source": [
180
- "%cd /content/rl-algo-impls\n",
181
- "os.environ[\"BENCHMARK_MAX_PROCS\"] = str(1) # Can't reliably raise this to 2+, but would make it faster.\n",
182
- "!./benchmarks/colab_basic.sh\n",
183
- "!./benchmarks/colab_pybullet.sh\n",
184
- "# !./benchmarks/colab_carracing.sh\n",
185
- "# !./benchmarks/colab_atari1.sh\n",
186
- "# !./benchmarks/colab_atari2.sh"
187
- ],
188
- "metadata": {
189
- "id": "07aHYFH1zfXa"
190
- },
191
- "execution_count": null,
192
- "outputs": []
193
- }
194
- ]
195
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
benchmarks/colab_carracing.sh DELETED
@@ -1,5 +0,0 @@
1
- source benchmarks/train_loop.sh
2
- ALGOS="ppo"
3
- ENVS="CarRacing-v0"
4
- BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
- train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
 
 
 
 
 
 
benchmarks/colab_pybullet.sh DELETED
@@ -1,5 +0,0 @@
1
- source benchmarks/train_loop.sh
2
- ALGOS="ppo"
3
- ENVS="HalfCheetahBulletEnv-v0 AntBulletEnv-v0 HopperBulletEnv-v0 Walker2DBulletEnv-v0"
4
- BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
- train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
 
 
 
 
 
 
benchmarks/train_loop.sh DELETED
@@ -1,15 +0,0 @@
1
- train_loop () {
2
- local WANDB_TAGS="benchmark_$(git rev-parse --short HEAD) host_$(hostname)"
3
- local algo
4
- local env
5
- local seed
6
- local WANDB_PROJECT_NAME="${WANDB_PROJECT_NAME:-rl-algo-impls-benchmarks}"
7
- local SEEDS="${SEEDS:-1 2 3}"
8
- for algo in $(echo $1); do
9
- for env in $(echo $2); do
10
- for seed in $SEEDS; do
11
- echo python train.py --algo $algo --env $env --seed $seed --pool-size 1 --wandb-tags $WANDB_TAGS --wandb-project-name $WANDB_PROJECT_NAME
12
- done
13
- done
14
- done
15
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
colab_enjoy.ipynb DELETED
@@ -1,198 +0,0 @@
1
- {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "colab": {
6
- "provenance": [],
7
- "machine_shape": "hm",
8
- "authorship_tag": "ABX9TyN6S7kyJKrM5x0OOiN+CgTc",
9
- "include_colab_link": true
10
- },
11
- "kernelspec": {
12
- "name": "python3",
13
- "display_name": "Python 3"
14
- },
15
- "language_info": {
16
- "name": "python"
17
- },
18
- "gpuClass": "standard",
19
- "accelerator": "GPU"
20
- },
21
- "cells": [
22
- {
23
- "cell_type": "markdown",
24
- "metadata": {
25
- "id": "view-in-github",
26
- "colab_type": "text"
27
- },
28
- "source": [
29
- "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
- ]
31
- },
32
- {
33
- "cell_type": "markdown",
34
- "source": [
35
- "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
- "## Parameters\n",
37
- "\n",
38
- "\n",
39
- "1. Wandb\n",
40
- "\n"
41
- ],
42
- "metadata": {
43
- "id": "S-tXDWP8WTLc"
44
- }
45
- },
46
- {
47
- "cell_type": "code",
48
- "source": [
49
- "from getpass import getpass\n",
50
- "import os\n",
51
- "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
- ],
53
- "metadata": {
54
- "id": "1ZtdYgxWNGwZ"
55
- },
56
- "execution_count": null,
57
- "outputs": []
58
- },
59
- {
60
- "cell_type": "markdown",
61
- "source": [
62
- "2. enjoy.py parameters"
63
- ],
64
- "metadata": {
65
- "id": "ao0nAh3MOdN7"
66
- }
67
- },
68
- {
69
- "cell_type": "code",
70
- "source": [
71
- "WANDB_RUN_PATH=\"sgoodfriend/rl-algo-impls-benchmarks/rd0lisee\""
72
- ],
73
- "metadata": {
74
- "id": "jKL_NFhVOjSc"
75
- },
76
- "execution_count": 2,
77
- "outputs": []
78
- },
79
- {
80
- "cell_type": "markdown",
81
- "source": [
82
- "## Setup\n",
83
- "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
84
- ],
85
- "metadata": {
86
- "id": "bsG35Io0hmKG"
87
- }
88
- },
89
- {
90
- "cell_type": "code",
91
- "source": [
92
- "%%capture\n",
93
- "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
94
- ],
95
- "metadata": {
96
- "id": "k5ynTV25hdAf"
97
- },
98
- "execution_count": 3,
99
- "outputs": []
100
- },
101
- {
102
- "cell_type": "markdown",
103
- "source": [
104
- "Installing the correct packages:\n",
105
- "\n",
106
- "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
107
- ],
108
- "metadata": {
109
- "id": "jKxGok-ElYQ7"
110
- }
111
- },
112
- {
113
- "cell_type": "code",
114
- "source": [
115
- "%%capture\n",
116
- "!apt install python-opengl\n",
117
- "!apt install ffmpeg\n",
118
- "!apt install xvfb\n",
119
- "!apt install swig"
120
- ],
121
- "metadata": {
122
- "id": "nn6EETTc2Ewf"
123
- },
124
- "execution_count": 4,
125
- "outputs": []
126
- },
127
- {
128
- "cell_type": "code",
129
- "source": [
130
- "%%capture\n",
131
- "%cd /content/rl-algo-impls\n",
132
- "!pip install -r colab_requirements.txt"
133
- ],
134
- "metadata": {
135
- "id": "AfZh9rH3yQii"
136
- },
137
- "execution_count": 5,
138
- "outputs": []
139
- },
140
- {
141
- "cell_type": "markdown",
142
- "source": [
143
- "## Run Once Per Runtime"
144
- ],
145
- "metadata": {
146
- "id": "4o5HOLjc4wq7"
147
- }
148
- },
149
- {
150
- "cell_type": "code",
151
- "source": [
152
- "import wandb\n",
153
- "wandb.login()"
154
- ],
155
- "metadata": {
156
- "id": "PCXa5tdS2qFX"
157
- },
158
- "execution_count": null,
159
- "outputs": []
160
- },
161
- {
162
- "cell_type": "markdown",
163
- "source": [
164
- "## Restart Session beteween runs"
165
- ],
166
- "metadata": {
167
- "id": "AZBZfSUV43JQ"
168
- }
169
- },
170
- {
171
- "cell_type": "code",
172
- "source": [
173
- "%%capture\n",
174
- "from pyvirtualdisplay import Display\n",
175
- "\n",
176
- "virtual_display = Display(visible=0, size=(1400, 900))\n",
177
- "virtual_display.start()"
178
- ],
179
- "metadata": {
180
- "id": "VzemeQJP2NO9"
181
- },
182
- "execution_count": 7,
183
- "outputs": []
184
- },
185
- {
186
- "cell_type": "code",
187
- "source": [
188
- "%cd /content/rl-algo-impls\n",
189
- "!python enjoy.py --wandb-run-path={WANDB_RUN_PATH}"
190
- ],
191
- "metadata": {
192
- "id": "07aHYFH1zfXa"
193
- },
194
- "execution_count": null,
195
- "outputs": []
196
- }
197
- ]
198
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
colab_requirements.txt DELETED
@@ -1,14 +0,0 @@
1
- AutoROM.accept-rom-license >= 0.4.2, < 0.5
2
- stable-baselines3[extra] >= 1.7.0, < 1.8
3
- gym[box2d] >= 0.21.0, < 0.22
4
- pyglet == 1.5.27
5
- wandb >= 0.13.10, < 0.14
6
- pyvirtualdisplay == 3.0
7
- pybullet >= 3.2.5, < 3.3
8
- tabulate >= 0.9.0, < 0.10
9
- huggingface-hub >= 0.12.0, < 0.13
10
- numexpr >= 2.8.4, < 2.9
11
- gym3 >= 0.3.3, < 0.4
12
- glfw >= 1.12.0, < 1.13
13
- procgen >= 0.10.7, < 0.11
14
- ipython >= 8.10.0, < 8.11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
colab_train.ipynb DELETED
@@ -1,200 +0,0 @@
1
- {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "colab": {
6
- "provenance": [],
7
- "machine_shape": "hm",
8
- "authorship_tag": "ABX9TyMmemQnx6G7GOnn6XBdjgxY",
9
- "include_colab_link": true
10
- },
11
- "kernelspec": {
12
- "name": "python3",
13
- "display_name": "Python 3"
14
- },
15
- "language_info": {
16
- "name": "python"
17
- },
18
- "gpuClass": "standard",
19
- "accelerator": "GPU"
20
- },
21
- "cells": [
22
- {
23
- "cell_type": "markdown",
24
- "metadata": {
25
- "id": "view-in-github",
26
- "colab_type": "text"
27
- },
28
- "source": [
29
- "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
- ]
31
- },
32
- {
33
- "cell_type": "markdown",
34
- "source": [
35
- "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
- "## Parameters\n",
37
- "\n",
38
- "\n",
39
- "1. Wandb\n",
40
- "\n"
41
- ],
42
- "metadata": {
43
- "id": "S-tXDWP8WTLc"
44
- }
45
- },
46
- {
47
- "cell_type": "code",
48
- "source": [
49
- "from getpass import getpass\n",
50
- "import os\n",
51
- "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
- ],
53
- "metadata": {
54
- "id": "1ZtdYgxWNGwZ"
55
- },
56
- "execution_count": null,
57
- "outputs": []
58
- },
59
- {
60
- "cell_type": "markdown",
61
- "source": [
62
- "2. train run parameters"
63
- ],
64
- "metadata": {
65
- "id": "ao0nAh3MOdN7"
66
- }
67
- },
68
- {
69
- "cell_type": "code",
70
- "source": [
71
- "ALGO = \"ppo\"\n",
72
- "ENV = \"CartPole-v1\"\n",
73
- "SEED = 1"
74
- ],
75
- "metadata": {
76
- "id": "jKL_NFhVOjSc"
77
- },
78
- "execution_count": null,
79
- "outputs": []
80
- },
81
- {
82
- "cell_type": "markdown",
83
- "source": [
84
- "## Setup\n",
85
- "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
86
- ],
87
- "metadata": {
88
- "id": "bsG35Io0hmKG"
89
- }
90
- },
91
- {
92
- "cell_type": "code",
93
- "source": [
94
- "%%capture\n",
95
- "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
96
- ],
97
- "metadata": {
98
- "id": "k5ynTV25hdAf"
99
- },
100
- "execution_count": null,
101
- "outputs": []
102
- },
103
- {
104
- "cell_type": "markdown",
105
- "source": [
106
- "Installing the correct packages:\n",
107
- "\n",
108
- "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
109
- ],
110
- "metadata": {
111
- "id": "jKxGok-ElYQ7"
112
- }
113
- },
114
- {
115
- "cell_type": "code",
116
- "source": [
117
- "%%capture\n",
118
- "!apt install python-opengl\n",
119
- "!apt install ffmpeg\n",
120
- "!apt install xvfb\n",
121
- "!apt install swig"
122
- ],
123
- "metadata": {
124
- "id": "nn6EETTc2Ewf"
125
- },
126
- "execution_count": null,
127
- "outputs": []
128
- },
129
- {
130
- "cell_type": "code",
131
- "source": [
132
- "%%capture\n",
133
- "%cd /content/rl-algo-impls\n",
134
- "!pip install -r colab_requirements.txt"
135
- ],
136
- "metadata": {
137
- "id": "AfZh9rH3yQii"
138
- },
139
- "execution_count": null,
140
- "outputs": []
141
- },
142
- {
143
- "cell_type": "markdown",
144
- "source": [
145
- "## Run Once Per Runtime"
146
- ],
147
- "metadata": {
148
- "id": "4o5HOLjc4wq7"
149
- }
150
- },
151
- {
152
- "cell_type": "code",
153
- "source": [
154
- "import wandb\n",
155
- "wandb.login()"
156
- ],
157
- "metadata": {
158
- "id": "PCXa5tdS2qFX"
159
- },
160
- "execution_count": null,
161
- "outputs": []
162
- },
163
- {
164
- "cell_type": "markdown",
165
- "source": [
166
- "## Restart Session beteween runs"
167
- ],
168
- "metadata": {
169
- "id": "AZBZfSUV43JQ"
170
- }
171
- },
172
- {
173
- "cell_type": "code",
174
- "source": [
175
- "%%capture\n",
176
- "from pyvirtualdisplay import Display\n",
177
- "\n",
178
- "virtual_display = Display(visible=0, size=(1400, 900))\n",
179
- "virtual_display.start()"
180
- ],
181
- "metadata": {
182
- "id": "VzemeQJP2NO9"
183
- },
184
- "execution_count": null,
185
- "outputs": []
186
- },
187
- {
188
- "cell_type": "code",
189
- "source": [
190
- "%cd /content/rl-algo-impls\n",
191
- "!python train.py --algo {ALGO} --env {ENV} --seed {SEED}"
192
- ],
193
- "metadata": {
194
- "id": "07aHYFH1zfXa"
195
- },
196
- "execution_count": null,
197
- "outputs": []
198
- }
199
- ]
200
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dqn/dqn.py DELETED
@@ -1,182 +0,0 @@
1
- import copy
2
- import numpy as np
3
- import random
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
- from collections import deque
9
- from torch.optim import Adam
10
- from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
11
- from torch.utils.tensorboard.writer import SummaryWriter
12
- from typing import List, NamedTuple, Optional, TypeVar
13
-
14
- from dqn.policy import DQNPolicy
15
- from shared.algorithm import Algorithm
16
- from shared.callbacks.callback import Callback
17
- from shared.schedule import linear_schedule
18
-
19
-
20
- class Transition(NamedTuple):
21
- obs: np.ndarray
22
- action: np.ndarray
23
- reward: float
24
- done: bool
25
- next_obs: np.ndarray
26
-
27
-
28
- class Batch(NamedTuple):
29
- obs: np.ndarray
30
- actions: np.ndarray
31
- rewards: np.ndarray
32
- dones: np.ndarray
33
- next_obs: np.ndarray
34
-
35
-
36
- class ReplayBuffer:
37
- def __init__(self, num_envs: int, maxlen: int) -> None:
38
- self.num_envs = num_envs
39
- self.buffer = deque(maxlen=maxlen)
40
-
41
- def add(
42
- self,
43
- obs: VecEnvObs,
44
- action: np.ndarray,
45
- reward: np.ndarray,
46
- done: np.ndarray,
47
- next_obs: VecEnvObs,
48
- ) -> None:
49
- assert isinstance(obs, np.ndarray)
50
- assert isinstance(next_obs, np.ndarray)
51
- for i in range(self.num_envs):
52
- self.buffer.append(
53
- Transition(obs[i], action[i], reward[i], done[i], next_obs[i])
54
- )
55
-
56
- def sample(self, batch_size: int) -> Batch:
57
- ts = random.sample(self.buffer, batch_size)
58
- return Batch(
59
- obs=np.array([t.obs for t in ts]),
60
- actions=np.array([t.action for t in ts]),
61
- rewards=np.array([t.reward for t in ts]),
62
- dones=np.array([t.done for t in ts]),
63
- next_obs=np.array([t.next_obs for t in ts]),
64
- )
65
-
66
- def __len__(self) -> int:
67
- return len(self.buffer)
68
-
69
-
70
- DQNSelf = TypeVar("DQNSelf", bound="DQN")
71
-
72
-
73
- class DQN(Algorithm):
74
- def __init__(
75
- self,
76
- policy: DQNPolicy,
77
- env: VecEnv,
78
- device: torch.device,
79
- tb_writer: SummaryWriter,
80
- learning_rate: float = 1e-4,
81
- buffer_size: int = 1_000_000,
82
- learning_starts: int = 50_000,
83
- batch_size: int = 32,
84
- tau: float = 1.0,
85
- gamma: float = 0.99,
86
- train_freq: int = 4,
87
- gradient_steps: int = 1,
88
- target_update_interval: int = 10_000,
89
- exploration_fraction: float = 0.1,
90
- exploration_initial_eps: float = 1.0,
91
- exploration_final_eps: float = 0.05,
92
- max_grad_norm: float = 10.0,
93
- ) -> None:
94
- super().__init__(policy, env, device, tb_writer)
95
- self.policy = policy
96
-
97
- self.optimizer = Adam(self.policy.q_net.parameters(), lr=learning_rate)
98
-
99
- self.target_q_net = copy.deepcopy(self.policy.q_net).to(self.device)
100
- self.target_q_net.train(False)
101
- self.tau = tau
102
- self.target_update_interval = target_update_interval
103
-
104
- self.replay_buffer = ReplayBuffer(self.env.num_envs, buffer_size)
105
- self.batch_size = batch_size
106
-
107
- self.learning_starts = learning_starts
108
- self.train_freq = train_freq
109
- self.gradient_steps = gradient_steps
110
-
111
- self.gamma = gamma
112
- self.exploration_eps_schedule = linear_schedule(
113
- exploration_initial_eps,
114
- exploration_final_eps,
115
- end_fraction=exploration_fraction,
116
- )
117
-
118
- self.max_grad_norm = max_grad_norm
119
-
120
- def learn(
121
- self: DQNSelf, total_timesteps: int, callback: Optional[Callback] = None
122
- ) -> DQNSelf:
123
- self.policy.train(True)
124
- obs = self.env.reset()
125
- obs = self._collect_rollout(self.learning_starts, obs, 1)
126
- learning_steps = total_timesteps - self.learning_starts
127
- timesteps_elapsed = 0
128
- steps_since_target_update = 0
129
- while timesteps_elapsed < learning_steps:
130
- progress = timesteps_elapsed / learning_steps
131
- eps = self.exploration_eps_schedule(progress)
132
- obs = self._collect_rollout(self.train_freq, obs, eps)
133
- rollout_steps = self.train_freq
134
- timesteps_elapsed += rollout_steps
135
- for _ in range(
136
- self.gradient_steps if self.gradient_steps > 0 else self.train_freq
137
- ):
138
- self.train()
139
- steps_since_target_update += rollout_steps
140
- if steps_since_target_update >= self.target_update_interval:
141
- self._update_target()
142
- steps_since_target_update = 0
143
- if callback:
144
- callback.on_step(timesteps_elapsed=rollout_steps)
145
- return self
146
-
147
- def train(self) -> None:
148
- if len(self.replay_buffer) < self.batch_size:
149
- return
150
- o, a, r, d, next_o = self.replay_buffer.sample(self.batch_size)
151
- o = torch.as_tensor(o, device=self.device)
152
- a = torch.as_tensor(a, device=self.device).unsqueeze(1)
153
- r = torch.as_tensor(r, dtype=torch.float32, device=self.device)
154
- d = torch.as_tensor(d, dtype=torch.long, device=self.device)
155
- next_o = torch.as_tensor(next_o, device=self.device)
156
-
157
- with torch.no_grad():
158
- target = r + (1 - d) * self.gamma * self.target_q_net(next_o).max(1).values
159
- current = self.policy.q_net(o).gather(dim=1, index=a).squeeze(1)
160
- loss = F.smooth_l1_loss(current, target)
161
-
162
- self.optimizer.zero_grad()
163
- loss.backward()
164
- if self.max_grad_norm:
165
- nn.utils.clip_grad_norm_(self.policy.q_net.parameters(), self.max_grad_norm)
166
- self.optimizer.step()
167
-
168
- def _collect_rollout(self, timesteps: int, obs: VecEnvObs, eps: float) -> VecEnvObs:
169
- for _ in range(0, timesteps, self.env.num_envs):
170
- action = self.policy.act(obs, eps, deterministic=False)
171
- next_obs, reward, done, _ = self.env.step(action)
172
- self.replay_buffer.add(obs, action, reward, done, next_obs)
173
- obs = next_obs
174
- return obs
175
-
176
- def _update_target(self) -> None:
177
- for target_param, param in zip(
178
- self.target_q_net.parameters(), self.policy.q_net.parameters()
179
- ):
180
- target_param.data.copy_(
181
- self.tau * param.data + (1 - self.tau) * target_param.data
182
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dqn/policy.py DELETED
@@ -1,45 +0,0 @@
1
- import numpy as np
2
- import os
3
- import torch
4
-
5
- from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
6
- from typing import Optional, Sequence, TypeVar
7
-
8
- from dqn.q_net import QNetwork
9
- from shared.policy.policy import Policy
10
-
11
- DQNPolicySelf = TypeVar("DQNPolicySelf", bound="DQNPolicy")
12
-
13
-
14
- class DQNPolicy(Policy):
15
- def __init__(
16
- self,
17
- env: VecEnv,
18
- hidden_sizes: Sequence[int] = [],
19
- cnn_feature_dim: int = 512,
20
- cnn_style: str = "nature",
21
- cnn_layers_init_orthogonal: Optional[bool] = None,
22
- **kwargs,
23
- ) -> None:
24
- super().__init__(env, **kwargs)
25
- self.q_net = QNetwork(
26
- env.observation_space,
27
- env.action_space,
28
- hidden_sizes,
29
- cnn_feature_dim=cnn_feature_dim,
30
- cnn_style=cnn_style,
31
- cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
32
- )
33
-
34
- def act(
35
- self, obs: VecEnvObs, eps: float = 0, deterministic: bool = True
36
- ) -> np.ndarray:
37
- assert eps == 0 if deterministic else eps >= 0
38
- if not deterministic and np.random.random() < eps:
39
- return np.array(
40
- [self.env.action_space.sample() for _ in range(self.env.num_envs)]
41
- )
42
- else:
43
- o = self._as_tensor(obs)
44
- with torch.no_grad():
45
- return self.q_net(o).argmax(axis=1).cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dqn/q_net.py DELETED
@@ -1,39 +0,0 @@
1
- import gym
2
- import torch as th
3
- import torch.nn as nn
4
-
5
- from gym.spaces import Discrete
6
- from typing import Optional, Sequence, Type
7
-
8
- from shared.module.feature_extractor import FeatureExtractor
9
- from shared.module.module import mlp
10
-
11
-
12
- class QNetwork(nn.Module):
13
- def __init__(
14
- self,
15
- observation_space: gym.Space,
16
- action_space: gym.Space,
17
- hidden_sizes: Sequence[int] = [],
18
- activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
19
- cnn_feature_dim: int = 512,
20
- cnn_style: str = "nature",
21
- cnn_layers_init_orthogonal: Optional[bool] = None,
22
- ) -> None:
23
- super().__init__()
24
- assert isinstance(action_space, Discrete)
25
- self._feature_extractor = FeatureExtractor(
26
- observation_space,
27
- activation,
28
- cnn_feature_dim=cnn_feature_dim,
29
- cnn_style=cnn_style,
30
- cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
31
- )
32
- layer_sizes = (
33
- (self._feature_extractor.out_dim,) + tuple(hidden_sizes) + (action_space.n,)
34
- )
35
- self._fc = mlp(layer_sizes, activation)
36
-
37
- def forward(self, obs: th.Tensor) -> th.Tensor:
38
- x = self._feature_extractor(obs)
39
- return self._fc(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hf-deep-rl/dqn_SpaceInvadersNoFrameskip_v4.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
hyperparams/dqn.yml DELETED
@@ -1,130 +0,0 @@
1
- CartPole-v1: &cartpole-defaults
2
- n_timesteps: !!float 5e4
3
- env_hyperparams:
4
- rolling_length: 50
5
- policy_hyperparams:
6
- hidden_sizes: [256, 256]
7
- algo_hyperparams:
8
- learning_rate: !!float 2.3e-3
9
- batch_size: 64
10
- buffer_size: 100000
11
- learning_starts: 1000
12
- gamma: 0.99
13
- target_update_interval: 10
14
- train_freq: 256
15
- gradient_steps: 128
16
- exploration_fraction: 0.16
17
- exploration_final_eps: 0.04
18
- eval_params:
19
- step_freq: !!float 1e4
20
-
21
- CartPole-v0:
22
- <<: *cartpole-defaults
23
- n_timesteps: !!float 4e4
24
-
25
- MountainCar-v0:
26
- n_timesteps: !!float 1.2e5
27
- env_hyperparams:
28
- rolling_length: 50
29
- policy_hyperparams:
30
- hidden_sizes: [256, 256]
31
- algo_hyperparams:
32
- learning_rate: !!float 4e-3
33
- batch_size: 128
34
- buffer_size: 10000
35
- learning_starts: 1000
36
- gamma: 0.98
37
- target_update_interval: 600
38
- train_freq: 16
39
- gradient_steps: 8
40
- exploration_fraction: 0.2
41
- exploration_final_eps: 0.07
42
-
43
- Acrobot-v1:
44
- n_timesteps: !!float 1e5
45
- env_hyperparams:
46
- rolling_length: 50
47
- policy_hyperparams:
48
- hidden_sizes: [256, 256]
49
- algo_hyperparams:
50
- learning_rate: !!float 6.3e-4
51
- batch_size: 128
52
- buffer_size: 50000
53
- learning_starts: 0
54
- gamma: 0.99
55
- target_update_interval: 250
56
- train_freq: 4
57
- gradient_steps: -1
58
- exploration_fraction: 0.12
59
- exploration_final_eps: 0.1
60
-
61
- LunarLander-v2:
62
- n_timesteps: !!float 5e5
63
- env_hyperparams:
64
- rolling_length: 50
65
- policy_hyperparams:
66
- hidden_sizes: [256, 256]
67
- algo_hyperparams:
68
- learning_rate: !!float 1e-4
69
- batch_size: 256
70
- buffer_size: 100000
71
- learning_starts: 10000
72
- gamma: 0.99
73
- target_update_interval: 250
74
- train_freq: 8
75
- gradient_steps: -1
76
- exploration_fraction: 0.12
77
- exploration_final_eps: 0.1
78
- max_grad_norm: 0.5
79
- eval_params:
80
- step_freq: 25_000
81
-
82
- _atari: &atari-defaults
83
- n_timesteps: !!float 1e7
84
- env_hyperparams:
85
- frame_stack: 4
86
- no_reward_timeout_steps: 1_000
87
- no_reward_fire_steps: 500
88
- n_envs: 8
89
- vec_env_class: "subproc"
90
- algo_hyperparams:
91
- buffer_size: 100000
92
- learning_rate: !!float 1e-4
93
- batch_size: 32
94
- learning_starts: 100000
95
- target_update_interval: 1000
96
- train_freq: 8
97
- gradient_steps: 2
98
- exploration_fraction: 0.1
99
- exploration_final_eps: 0.01
100
- eval_params:
101
- deterministic: false
102
-
103
- PongNoFrameskip-v4:
104
- <<: *atari-defaults
105
- n_timesteps: !!float 2.5e6
106
-
107
- _impala-atari: &impala-atari-defaults
108
- <<: *atari-defaults
109
- policy_hyperparams:
110
- cnn_style: impala
111
- cnn_feature_dim: 256
112
- init_layers_orthogonal: true
113
- cnn_layers_init_orthogonal: false
114
-
115
- impala-PongNoFrameskip-v4:
116
- <<: *impala-atari-defaults
117
- env_id: PongNoFrameskip-v4
118
- n_timesteps: !!float 2.5e6
119
-
120
- impala-BreakoutNoFrameskip-v4:
121
- <<: *impala-atari-defaults
122
- env_id: BreakoutNoFrameskip-v4
123
-
124
- impala-SpaceInvadersNoFrameskip-v4:
125
- <<: *impala-atari-defaults
126
- env_id: SpaceInvadersNoFrameskip-v4
127
-
128
- impala-QbertNoFrameskip-v4:
129
- <<: *impala-atari-defaults
130
- env_id: QbertNoFrameskip-v4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyperparams/ppo.yml DELETED
@@ -1,334 +0,0 @@
1
- CartPole-v1: &cartpole-defaults
2
- n_timesteps: !!float 1e5
3
- env_hyperparams:
4
- n_envs: 8
5
- algo_hyperparams:
6
- n_steps: 32
7
- batch_size: 256
8
- n_epochs: 20
9
- gae_lambda: 0.8
10
- gamma: 0.98
11
- ent_coef: 0.0
12
- learning_rate: 0.001
13
- learning_rate_decay: linear
14
- clip_range: 0.2
15
- clip_range_decay: linear
16
- eval_params:
17
- step_freq: !!float 2.5e4
18
-
19
- CartPole-v0:
20
- <<: *cartpole-defaults
21
- n_timesteps: !!float 5e4
22
-
23
- MountainCar-v0:
24
- n_timesteps: !!float 1e6
25
- env_hyperparams:
26
- normalize: true
27
- n_envs: 16
28
- algo_hyperparams:
29
- n_steps: 16
30
- n_epochs: 4
31
- gae_lambda: 0.98
32
- gamma: 0.99
33
- ent_coef: 0.0
34
-
35
- MountainCarContinuous-v0:
36
- n_timesteps: !!float 1e5
37
- env_hyperparams:
38
- normalize: true
39
- n_envs: 4
40
- # policy_hyperparams:
41
- # init_layers_orthogonal: false
42
- # log_std_init: -3.29
43
- # use_sde: true
44
- algo_hyperparams:
45
- n_steps: 512
46
- batch_size: 256
47
- n_epochs: 10
48
- learning_rate: !!float 7.77e-5
49
- ent_coef: 0.01 # 0.00429
50
- ent_coef_decay: linear
51
- clip_range: 0.1
52
- gae_lambda: 0.9
53
- max_grad_norm: 5
54
- vf_coef: 0.19
55
- eval_params:
56
- step_freq: 5000
57
-
58
- Acrobot-v1:
59
- n_timesteps: !!float 1e6
60
- env_hyperparams:
61
- n_envs: 16
62
- normalize: true
63
- algo_hyperparams:
64
- n_steps: 256
65
- n_epochs: 4
66
- gae_lambda: 0.94
67
- gamma: 0.99
68
- ent_coef: 0.0
69
-
70
- LunarLander-v2:
71
- n_timesteps: !!float 1e6
72
- env_hyperparams:
73
- n_envs: 16
74
- algo_hyperparams:
75
- n_steps: 1024
76
- batch_size: 64
77
- n_epochs: 4
78
- gae_lambda: 0.98
79
- gamma: 0.999
80
- ent_coef: 0.01
81
- ent_coef_decay: linear
82
- normalize_advantage: false
83
-
84
- CarRacing-v0: &carracing-defaults
85
- n_timesteps: !!float 4e6
86
- env_hyperparams:
87
- n_envs: 8
88
- frame_stack: 4
89
- policy_hyperparams: &carracing-policy-defaults
90
- use_sde: true
91
- log_std_init: -2
92
- init_layers_orthogonal: false
93
- activation_fn: relu
94
- share_features_extractor: false
95
- cnn_feature_dim: 256
96
- hidden_sizes: [256]
97
- algo_hyperparams:
98
- n_steps: 512
99
- batch_size: 128
100
- n_epochs: 10
101
- learning_rate: !!float 1e-4
102
- learning_rate_decay: linear
103
- gamma: 0.99
104
- gae_lambda: 0.95
105
- ent_coef: 0.0
106
- sde_sample_freq: 4
107
- max_grad_norm: 0.5
108
- vf_coef: 0.5
109
- clip_range: 0.2
110
-
111
- impala-CarRacing-v0:
112
- <<: *carracing-defaults
113
- env_id: CarRacing-v0
114
- policy_hyperparams:
115
- <<: *carracing-policy-defaults
116
- cnn_style: impala
117
- init_layers_orthogonal: true
118
- cnn_layers_init_orthogonal: false
119
- hidden_sizes: []
120
-
121
- # BreakoutNoFrameskip-v4
122
- # PongNoFrameskip-v4
123
- # SpaceInvadersNoFrameskip-v4
124
- # QbertNoFrameskip-v4
125
- _atari: &atari-defaults
126
- n_timesteps: !!float 1e7
127
- env_hyperparams: &atari-env-defaults
128
- n_envs: 8
129
- frame_stack: 4
130
- no_reward_timeout_steps: 1000
131
- no_reward_fire_steps: 500
132
- vec_env_class: subproc
133
- policy_hyperparams: &atari-policy-defaults
134
- activation_fn: relu
135
- algo_hyperparams:
136
- n_steps: 128
137
- batch_size: 256
138
- n_epochs: 4
139
- learning_rate: !!float 2.5e-4
140
- learning_rate_decay: linear
141
- clip_range: 0.1
142
- clip_range_decay: linear
143
- vf_coef: 0.5
144
- ent_coef: 0.01
145
- eval_params:
146
- deterministic: false
147
-
148
- debug-PongNoFrameskip-v4:
149
- <<: *atari-defaults
150
- device: cpu
151
- env_id: PongNoFrameskip-v4
152
- env_hyperparams:
153
- <<: *atari-env-defaults
154
- vec_env_class: dummy
155
-
156
- _impala-atari: &impala-atari-defaults
157
- <<: *atari-defaults
158
- policy_hyperparams:
159
- <<: *atari-policy-defaults
160
- cnn_style: impala
161
- cnn_feature_dim: 256
162
- init_layers_orthogonal: true
163
- cnn_layers_init_orthogonal: false
164
-
165
- impala-PongNoFrameskip-v4:
166
- <<: *impala-atari-defaults
167
- env_id: PongNoFrameskip-v4
168
-
169
- impala-BreakoutNoFrameskip-v4:
170
- <<: *impala-atari-defaults
171
- env_id: BreakoutNoFrameskip-v4
172
-
173
- impala-SpaceInvadersNoFrameskip-v4:
174
- <<: *impala-atari-defaults
175
- env_id: SpaceInvadersNoFrameskip-v4
176
-
177
- impala-QbertNoFrameskip-v4:
178
- <<: *impala-atari-defaults
179
- env_id: QbertNoFrameskip-v4
180
-
181
- HalfCheetahBulletEnv-v0: &pybullet-defaults
182
- n_timesteps: !!float 2e6
183
- env_hyperparams: &pybullet-env-defaults
184
- n_envs: 16
185
- normalize: true
186
- policy_hyperparams: &pybullet-policy-defaults
187
- pi_hidden_sizes: [256, 256]
188
- v_hidden_sizes: [256, 256]
189
- activation_fn: relu
190
- algo_hyperparams: &pybullet-algo-defaults
191
- n_steps: 512
192
- batch_size: 128
193
- n_epochs: 20
194
- gamma: 0.99
195
- gae_lambda: 0.9
196
- ent_coef: 0.0
197
- max_grad_norm: 0.5
198
- vf_coef: 0.5
199
- learning_rate: !!float 3e-5
200
- clip_range: 0.4
201
-
202
- AntBulletEnv-v0:
203
- <<: *pybullet-defaults
204
- policy_hyperparams:
205
- <<: *pybullet-policy-defaults
206
- algo_hyperparams:
207
- <<: *pybullet-algo-defaults
208
-
209
- Walker2DBulletEnv-v0:
210
- <<: *pybullet-defaults
211
- algo_hyperparams:
212
- <<: *pybullet-algo-defaults
213
- clip_range_decay: linear
214
-
215
- HopperBulletEnv-v0:
216
- <<: *pybullet-defaults
217
- algo_hyperparams:
218
- <<: *pybullet-algo-defaults
219
- clip_range_decay: linear
220
-
221
- HumanoidBulletEnv-v0:
222
- <<: *pybullet-defaults
223
- n_timesteps: !!float 1e7
224
- env_hyperparams:
225
- <<: *pybullet-env-defaults
226
- n_envs: 8
227
- policy_hyperparams:
228
- <<: *pybullet-policy-defaults
229
- # log_std_init: -1
230
- algo_hyperparams:
231
- <<: *pybullet-algo-defaults
232
- n_steps: 2048
233
- batch_size: 64
234
- n_epochs: 10
235
- gae_lambda: 0.95
236
- learning_rate: !!float 2.5e-4
237
- clip_range: 0.2
238
-
239
- _procgen: &procgen-defaults
240
- env_hyperparams: &procgen-env-defaults
241
- is_procgen: true
242
- n_envs: 64
243
- # grayscale: false
244
- # frame_stack: 4
245
- normalize: true # procgen only normalizes reward
246
- policy_hyperparams: &procgen-policy-defaults
247
- activation_fn: relu
248
- cnn_style: impala
249
- cnn_feature_dim: 256
250
- init_layers_orthogonal: true
251
- cnn_layers_init_orthogonal: false
252
- algo_hyperparams: &procgen-algo-defaults
253
- gamma: 0.999
254
- gae_lambda: 0.95
255
- n_steps: 256
256
- batch_size: 2048
257
- n_epochs: 3
258
- ent_coef: 0.01
259
- clip_range: 0.2
260
- # clip_range_decay: linear
261
- clip_range_vf: 0.2
262
- learning_rate: !!float 5e-4
263
- # learning_rate_decay: linear
264
- vf_coef: 0.5
265
- eval_params: &procgen-eval-defaults
266
- ignore_first_episode: true
267
- # deterministic: false
268
- step_freq: !!float 1e5
269
-
270
- _procgen-easy: &procgen-easy-defaults
271
- <<: *procgen-defaults
272
- n_timesteps: !!float 25e6
273
- env_hyperparams: &procgen-easy-env-defaults
274
- <<: *procgen-env-defaults
275
- make_kwargs:
276
- distribution_mode: easy
277
-
278
- procgen-coinrun-easy: &coinrun-easy-defaults
279
- <<: *procgen-easy-defaults
280
- env_id: coinrun
281
-
282
- debug-procgen-coinrun:
283
- <<: *coinrun-easy-defaults
284
- device: cpu
285
-
286
- procgen-starpilot-easy:
287
- <<: *procgen-easy-defaults
288
- env_id: starpilot
289
-
290
- procgen-bossfight-easy:
291
- <<: *procgen-easy-defaults
292
- env_id: bossfight
293
-
294
- procgen-bigfish-easy:
295
- <<: *procgen-easy-defaults
296
- env_id: bigfish
297
-
298
- _procgen-hard: &procgen-hard-defaults
299
- <<: *procgen-defaults
300
- n_timesteps: !!float 200e6
301
- env_hyperparams: &procgen-hard-env-defaults
302
- <<: *procgen-env-defaults
303
- n_envs: 256
304
- make_kwargs:
305
- distribution_mode: hard
306
- algo_hyperparams:
307
- <<: *procgen-algo-defaults
308
- batch_size: 8192
309
- eval_params:
310
- <<: *procgen-eval-defaults
311
- step_freq: !!float 5e5
312
-
313
- procgen-starpilot-hard: &procgen-starpilot-hard-defaults
314
- <<: *procgen-hard-defaults
315
- env_id: starpilot
316
-
317
- procgen-starpilot-hard-2xIMPALA:
318
- <<: *procgen-starpilot-hard-defaults
319
- policy_hyperparams:
320
- <<: *procgen-policy-defaults
321
- impala_channels: [32, 64, 64]
322
-
323
- procgen-starpilot-hard-2xIMPALA-fat:
324
- <<: *procgen-starpilot-hard-defaults
325
- policy_hyperparams:
326
- <<: *procgen-policy-defaults
327
- impala_channels: [32, 64, 64]
328
- cnn_feature_dim: 512
329
-
330
- procgen-starpilot-hard-4xIMPALA:
331
- <<: *procgen-starpilot-hard-defaults
332
- policy_hyperparams:
333
- <<: *procgen-policy-defaults
334
- impala_channels: [64, 128, 128]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hyperparams/vpg.yml DELETED
@@ -1,176 +0,0 @@
1
- CartPole-v1: &cartpole-defaults
2
- n_timesteps: !!float 4e5
3
- algo_hyperparams:
4
- n_steps: 4096
5
- pi_lr: 0.01
6
- gamma: 0.99
7
- gae_lambda: 1
8
- val_lr: 0.01
9
- train_v_iters: 80
10
- eval_params:
11
- step_freq: !!float 2.5e4
12
-
13
- CartPole-v0:
14
- <<: *cartpole-defaults
15
- n_timesteps: !!float 1e5
16
- algo_hyperparams:
17
- n_steps: 1024
18
- pi_lr: 0.01
19
- gamma: 0.99
20
- gae_lambda: 1
21
- val_lr: 0.01
22
- train_v_iters: 80
23
-
24
- MountainCar-v0:
25
- n_timesteps: !!float 1e6
26
- env_hyperparams:
27
- normalize: true
28
- n_envs: 16
29
- algo_hyperparams:
30
- n_steps: 200
31
- pi_lr: 0.005
32
- gamma: 0.99
33
- gae_lambda: 0.97
34
- val_lr: 0.01
35
- train_v_iters: 80
36
- max_grad_norm: 0.5
37
-
38
- MountainCarContinuous-v0:
39
- n_timesteps: !!float 3e5
40
- env_hyperparams:
41
- normalize: true
42
- n_envs: 4
43
- # policy_hyperparams:
44
- # init_layers_orthogonal: false
45
- # log_std_init: -3.29
46
- # use_sde: true
47
- algo_hyperparams:
48
- n_steps: 1000
49
- pi_lr: !!float 5e-4
50
- gamma: 0.99
51
- gae_lambda: 0.9
52
- val_lr: !!float 1e-3
53
- train_v_iters: 80
54
- max_grad_norm: 5
55
- eval_params:
56
- step_freq: 5000
57
-
58
- Acrobot-v1:
59
- n_timesteps: !!float 2e5
60
- algo_hyperparams:
61
- n_steps: 2048
62
- pi_lr: 0.005
63
- gamma: 0.99
64
- gae_lambda: 0.97
65
- val_lr: 0.01
66
- train_v_iters: 80
67
- max_grad_norm: 0.5
68
-
69
- LunarLander-v2:
70
- n_timesteps: !!float 4e6
71
- policy_hyperparams:
72
- hidden_sizes: [256, 256]
73
- algo_hyperparams:
74
- n_steps: 2048
75
- pi_lr: 0.0001
76
- gamma: 0.999
77
- gae_lambda: 0.97
78
- val_lr: 0.0001
79
- train_v_iters: 80
80
- max_grad_norm: 0.5
81
- eval_params:
82
- deterministic: false
83
-
84
- CarRacing-v0:
85
- n_timesteps: !!float 4e6
86
- env_hyperparams:
87
- frame_stack: 4
88
- n_envs: 4
89
- vec_env_class: "dummy"
90
- policy_hyperparams:
91
- use_sde: true
92
- log_std_init: -2
93
- init_layers_orthogonal: false
94
- activation_fn: relu
95
- cnn_feature_dim: 256
96
- hidden_sizes: [256]
97
- algo_hyperparams:
98
- n_steps: 1000
99
- pi_lr: !!float 5e-5
100
- gamma: 0.99
101
- gae_lambda: 0.95
102
- val_lr: !!float 1e-4
103
- train_v_iters: 40
104
- max_grad_norm: 0.5
105
- sde_sample_freq: 4
106
-
107
- HalfCheetahBulletEnv-v0: &pybullet-defaults
108
- n_timesteps: !!float 2e6
109
- policy_hyperparams: &pybullet-policy-defaults
110
- hidden_sizes: [256, 256]
111
- algo_hyperparams: &pybullet-algo-defaults
112
- n_steps: 4000
113
- pi_lr: !!float 3e-4
114
- gamma: 0.99
115
- gae_lambda: 0.97
116
- val_lr: !!float 1e-3
117
- train_v_iters: 80
118
- max_grad_norm: 0.5
119
-
120
- AntBulletEnv-v0:
121
- <<: *pybullet-defaults
122
- policy_hyperparams:
123
- <<: *pybullet-policy-defaults
124
- hidden_sizes: [400, 300]
125
- algo_hyperparams:
126
- <<: *pybullet-algo-defaults
127
- pi_lr: !!float 7e-4
128
- val_lr: !!float 7e-3
129
-
130
- HopperBulletEnv-v0:
131
- <<: *pybullet-defaults
132
-
133
- Walker2DBulletEnv-v0:
134
- <<: *pybullet-defaults
135
-
136
- FrozenLake-v1:
137
- n_timesteps: !!float 8e5
138
- env_params:
139
- make_kwargs:
140
- map_name: 8x8
141
- is_slippery: true
142
- policy_hyperparams:
143
- hidden_sizes: [64]
144
- algo_hyperparams:
145
- n_steps: 2048
146
- pi_lr: 0.01
147
- gamma: 0.99
148
- gae_lambda: 0.98
149
- val_lr: 0.01
150
- train_v_iters: 80
151
- max_grad_norm: 0.5
152
- eval_params:
153
- step_freq: !!float 5e4
154
- n_episodes: 10
155
- save_best: true
156
-
157
- _atari: &atari-defaults
158
- n_timesteps: !!float 1e7
159
- env_hyperparams:
160
- n_envs: 4
161
- frame_stack: 4
162
- no_reward_timeout_steps: 1000
163
- no_reward_fire_steps: 500
164
- vec_env_class: subproc
165
- policy_hyperparams:
166
- activation_fn: relu
167
- algo_hyperparams:
168
- n_steps: 2048
169
- pi_lr: !!float 1e-4
170
- gamma: 0.99
171
- gae_lambda: 0.95
172
- val_lr: !!float 2e-4
173
- train_v_iters: 80
174
- max_grad_norm: 0.5
175
- eval_params:
176
- deterministic: false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lambda_labs/benchmark.sh DELETED
@@ -1,32 +0,0 @@
1
- source benchmarks/train_loop.sh
2
-
3
- # export WANDB_PROJECT_NAME="rl-algo-impls"
4
-
5
- BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-6}"
6
-
7
- ALGOS=(
8
- # "vpg"
9
- # "dqn"
10
- "ppo"
11
- )
12
- ENVS=(
13
- # Basic
14
- "CartPole-v1"
15
- "MountainCar-v0"
16
- "MountainCarContinuous-v0"
17
- "Acrobot-v1"
18
- "LunarLander-v2"
19
- # PyBullet
20
- "HalfCheetahBulletEnv-v0"
21
- "AntBulletEnv-v0"
22
- "HopperBulletEnv-v0"
23
- "Walker2DBulletEnv-v0"
24
- # CarRacing
25
- "CarRacing-v0"
26
- # Atari
27
- "PongNoFrameskip-v4"
28
- "BreakoutNoFrameskip-v4"
29
- "SpaceInvadersNoFrameskip-v4"
30
- "QbertNoFrameskip-v4"
31
- )
32
- train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lambda_labs/impala_atari_benchmark.sh DELETED
@@ -1,19 +0,0 @@
1
- source benchmarks/train_loop.sh
2
-
3
- # export WANDB_PROJECT_NAME="rl-algo-impls"
4
-
5
- BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-5}"
6
-
7
- ALGOS=(
8
- # "vpg"
9
- # "dqn"
10
- "ppo"
11
- )
12
- ENVS=(
13
- "impala-PongNoFrameskip-v4"
14
- "impala-BreakoutNoFrameskip-v4"
15
- "impala-SpaceInvadersNoFrameskip-v4"
16
- "impala-QbertNoFrameskip-v4"
17
- "impala-CarRacing-v0"
18
- )
19
- train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lambda_labs/lambda_requirements.txt DELETED
@@ -1,16 +0,0 @@
1
- scipy >= 1.10.0, < 1.11
2
- tensorboard >= ^2.11.0, < 2.12
3
- AutoROM.accept-rom-license >= 0.4.2, < 0.5
4
- stable-baselines3[extra] >= 1.7.0, < 1.8
5
- gym[box2d] >= 0.21.0, < 0.22
6
- pyglet == 1.5.27
7
- wandb >= 0.13.10, < 0.14
8
- pyvirtualdisplay == 3.0
9
- pybullet >= 3.2.5, < 3.3
10
- tabulate >= 0.9.0, < 0.10
11
- huggingface-hub >= 0.12.0, < 0.13
12
- numexpr >= 2.8.4, < 2.9
13
- gym3 >= 0.3.3, < 0.4
14
- glfw >= 1.12.0, < 1.13
15
- procgen >= 0.10.7, < 0.11
16
- ipython >= 8.10.0, < 8.11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lambda_labs/procgen_benchmark.sh DELETED
@@ -1,18 +0,0 @@
1
- source benchmarks/train_loop.sh
2
-
3
- # export WANDB_PROJECT_NAME="rl-algo-impls"
4
-
5
- BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
6
-
7
- ALGOS=(
8
- # "vpg"
9
- # "dqn"
10
- "ppo"
11
- )
12
- ENVS=(
13
- "procgen-coinrun-easy"
14
- "procgen-starpilot-easy"
15
- "procgen-bossfight-easy"
16
- "procgen-bigfish-easy"
17
- )
18
- train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lambda_labs/setup.sh DELETED
@@ -1,10 +0,0 @@
1
- sudo apt update
2
- sudo apt install -y python-opengl
3
- sudo apt install -y ffmpeg
4
- sudo apt install -y xvfb
5
- sudo apt install -y swig
6
-
7
- python3 -m pip install --upgrade pip
8
- pip install --upgrade torch torchvision torchaudio
9
-
10
- pip install --upgrade -r ~/rl-algo-impls/lambda_labs/lambda_requirements.txt
 
 
 
 
 
 
 
 
 
 
 
lambda_labs/starpilot_hard_benchmark.sh DELETED
@@ -1,18 +0,0 @@
1
- source benchmarks/train_loop.sh
2
-
3
- # export WANDB_PROJECT_NAME="rl-algo-impls"
4
-
5
- BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
6
-
7
- ALGOS=(
8
- # "vpg"
9
- # "dqn"
10
- "ppo"
11
- )
12
- ENVS=(
13
- "procgen-starpilot-hard"
14
- "procgen-starpilot-hard-2xIMPALA"
15
- "procgen-starpilot-hard-2xIMPALA-fat"
16
- "procgen-starpilot-hard-4xIMPALA"
17
- )
18
- train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
poetry.lock DELETED
The diff for this file is too large to render. See raw diff
 
ppo/policy.py DELETED
@@ -1,31 +0,0 @@
1
- from stable_baselines3.common.vec_env.base_vec_env import VecEnv
2
- from typing import Optional, Sequence
3
-
4
- from gym.spaces import Box, Discrete
5
- from shared.policy.on_policy import ActorCritic, default_hidden_sizes
6
-
7
-
8
- class PPOActorCritic(ActorCritic):
9
- def __init__(
10
- self,
11
- env: VecEnv,
12
- pi_hidden_sizes: Optional[Sequence[int]] = None,
13
- v_hidden_sizes: Optional[Sequence[int]] = None,
14
- **kwargs,
15
- ) -> None:
16
- pi_hidden_sizes = (
17
- pi_hidden_sizes
18
- if pi_hidden_sizes is not None
19
- else default_hidden_sizes(env.observation_space)
20
- )
21
- v_hidden_sizes = (
22
- v_hidden_sizes
23
- if v_hidden_sizes is not None
24
- else default_hidden_sizes(env.observation_space)
25
- )
26
- super().__init__(
27
- env,
28
- pi_hidden_sizes,
29
- v_hidden_sizes,
30
- **kwargs,
31
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ppo/ppo.py DELETED
@@ -1,311 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
-
6
- from dataclasses import asdict, dataclass, field
7
- from torch.optim import Adam
8
- from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
9
- from torch.utils.tensorboard.writer import SummaryWriter
10
- from typing import List, Optional, Sequence, NamedTuple, TypeVar
11
-
12
- from shared.algorithm import Algorithm
13
- from shared.callbacks.callback import Callback
14
- from shared.gae import compute_advantage, compute_rtg_and_advantage, RtgAdvantage
15
- from shared.policy.on_policy import ActorCritic
16
- from shared.schedule import constant_schedule, linear_schedule
17
- from shared.trajectory import Trajectory, TrajectoryAccumulator
18
-
19
-
20
- @dataclass
21
- class PPOTrajectory(Trajectory):
22
- logp_a: List[float] = field(default_factory=list)
23
-
24
- def add(
25
- self,
26
- obs: np.ndarray,
27
- act: np.ndarray,
28
- next_obs: np.ndarray,
29
- rew: float,
30
- terminated: bool,
31
- v: float,
32
- logp_a: float,
33
- ):
34
- super().add(obs, act, next_obs, rew, terminated, v)
35
- self.logp_a.append(logp_a)
36
-
37
-
38
- class PPOTrajectoryAccumulator(TrajectoryAccumulator):
39
- def __init__(self, num_envs: int) -> None:
40
- super().__init__(num_envs, PPOTrajectory)
41
-
42
- def step(
43
- self,
44
- obs: VecEnvObs,
45
- action: np.ndarray,
46
- next_obs: VecEnvObs,
47
- reward: np.ndarray,
48
- done: np.ndarray,
49
- val: np.ndarray,
50
- logp_a: np.ndarray,
51
- ) -> None:
52
- super().step(obs, action, next_obs, reward, done, val, logp_a)
53
-
54
-
55
- class TrainStepStats(NamedTuple):
56
- loss: float
57
- pi_loss: float
58
- v_loss: float
59
- entropy_loss: float
60
- approx_kl: float
61
- clipped_frac: float
62
-
63
-
64
- @dataclass
65
- class TrainStats:
66
- loss: float
67
- pi_loss: float
68
- v_loss: float
69
- entropy_loss: float
70
- approx_kl: float
71
- clipped_frac: float
72
-
73
- def __init__(self, step_stats: List[TrainStepStats]) -> None:
74
- self.loss = np.mean([s.loss for s in step_stats]).item()
75
- self.pi_loss = np.mean([s.pi_loss for s in step_stats]).item()
76
- self.v_loss = np.mean([s.v_loss for s in step_stats]).item()
77
- self.entropy_loss = np.mean([s.entropy_loss for s in step_stats]).item()
78
- self.approx_kl = np.mean([s.approx_kl for s in step_stats]).item()
79
- self.clipped_frac = np.mean([s.clipped_frac for s in step_stats]).item()
80
-
81
- def write_to_tensorboard(self, tb_writer: SummaryWriter, global_step: int) -> None:
82
- tb_writer.add_scalars("losses", asdict(self), global_step=global_step)
83
-
84
- def __repr__(self) -> str:
85
- return " | ".join(
86
- [
87
- f"Loss: {round(self.loss, 2)}",
88
- f"Pi L: {round(self.pi_loss, 2)}",
89
- f"V L: {round(self.v_loss, 2)}",
90
- f"E L: {round(self.entropy_loss, 2)}",
91
- f"Apx KL Div: {round(self.approx_kl, 2)}",
92
- f"Clip Frac: {round(self.clipped_frac, 2)}",
93
- ]
94
- )
95
-
96
-
97
- PPOSelf = TypeVar("PPOSelf", bound="PPO")
98
-
99
-
100
- class PPO(Algorithm):
101
- def __init__(
102
- self,
103
- policy: ActorCritic,
104
- env: VecEnv,
105
- device: torch.device,
106
- tb_writer: SummaryWriter,
107
- learning_rate: float = 3e-4,
108
- learning_rate_decay: str = "none",
109
- n_steps: int = 2048,
110
- batch_size: int = 64,
111
- n_epochs: int = 10,
112
- gamma: float = 0.99,
113
- gae_lambda: float = 0.95,
114
- clip_range: float = 0.2,
115
- clip_range_decay: str = "none",
116
- clip_range_vf: Optional[float] = None,
117
- clip_range_vf_decay: str = "none",
118
- normalize_advantage: bool = True,
119
- ent_coef: float = 0.0,
120
- ent_coef_decay: str = "none",
121
- vf_coef: float = 0.5,
122
- ppo2_vf_coef_halving: bool = False,
123
- max_grad_norm: float = 0.5,
124
- update_rtg_between_epochs: bool = False,
125
- sde_sample_freq: int = -1,
126
- ) -> None:
127
- super().__init__(policy, env, device, tb_writer)
128
- self.policy = policy
129
-
130
- self.gamma = gamma
131
- self.gae_lambda = gae_lambda
132
- self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
133
- self.lr_schedule = (
134
- linear_schedule(learning_rate, 0)
135
- if learning_rate_decay == "linear"
136
- else constant_schedule(learning_rate)
137
- )
138
- self.max_grad_norm = max_grad_norm
139
- self.clip_range_schedule = (
140
- linear_schedule(clip_range, 0)
141
- if clip_range_decay == "linear"
142
- else constant_schedule(clip_range)
143
- )
144
- self.clip_range_vf_schedule = None
145
- if clip_range_vf:
146
- self.clip_range_vf_schedule = (
147
- linear_schedule(clip_range_vf, 0)
148
- if clip_range_vf_decay == "linear"
149
- else constant_schedule(clip_range_vf)
150
- )
151
- self.normalize_advantage = normalize_advantage
152
- self.ent_coef_schedule = (
153
- linear_schedule(ent_coef, 0)
154
- if ent_coef_decay == "linear"
155
- else constant_schedule(ent_coef)
156
- )
157
- self.vf_coef = vf_coef
158
- self.ppo2_vf_coef_halving = ppo2_vf_coef_halving
159
-
160
- self.n_steps = n_steps
161
- self.batch_size = batch_size
162
- self.n_epochs = n_epochs
163
- self.sde_sample_freq = sde_sample_freq
164
-
165
- self.update_rtg_between_epochs = update_rtg_between_epochs
166
-
167
- def learn(
168
- self: PPOSelf,
169
- total_timesteps: int,
170
- callback: Optional[Callback] = None,
171
- ) -> PPOSelf:
172
- obs = self.env.reset()
173
- ts_elapsed = 0
174
- while ts_elapsed < total_timesteps:
175
- accumulator = self._collect_trajectories(obs)
176
- progress = ts_elapsed / total_timesteps
177
- train_stats = self.train(accumulator.all_trajectories, progress)
178
- rollout_steps = self.n_steps * self.env.num_envs
179
- ts_elapsed += rollout_steps
180
- train_stats.write_to_tensorboard(self.tb_writer, ts_elapsed)
181
- if callback:
182
- callback.on_step(timesteps_elapsed=rollout_steps)
183
-
184
- return self
185
-
186
- def _collect_trajectories(self, obs: VecEnvObs) -> PPOTrajectoryAccumulator:
187
- self.policy.eval()
188
- accumulator = PPOTrajectoryAccumulator(self.env.num_envs)
189
- self.policy.reset_noise()
190
- for i in range(self.n_steps):
191
- if self.sde_sample_freq > 0 and i > 0 and i % self.sde_sample_freq == 0:
192
- self.policy.reset_noise()
193
- action, value, logp_a, clamped_action = self.policy.step(obs)
194
- next_obs, reward, done, _ = self.env.step(clamped_action)
195
- accumulator.step(obs, action, next_obs, reward, done, value, logp_a)
196
- obs = next_obs
197
- return accumulator
198
-
199
- def train(self, trajectories: List[PPOTrajectory], progress: float) -> TrainStats:
200
- self.policy.train()
201
- learning_rate = self.lr_schedule(progress)
202
- self.optimizer.param_groups[0]["lr"] = learning_rate
203
-
204
- pi_clip = self.clip_range_schedule(progress)
205
- v_clip = (
206
- self.clip_range_vf_schedule(progress)
207
- if self.clip_range_vf_schedule
208
- else None
209
- )
210
- ent_coef = self.ent_coef_schedule(progress)
211
-
212
- obs = torch.as_tensor(
213
- np.concatenate([np.array(t.obs) for t in trajectories]), device=self.device
214
- )
215
- act = torch.as_tensor(
216
- np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
217
- )
218
- rtg, adv = compute_rtg_and_advantage(
219
- trajectories, self.policy, self.gamma, self.gae_lambda, self.device
220
- )
221
- orig_v = torch.as_tensor(
222
- np.concatenate([np.array(t.v) for t in trajectories]), device=self.device
223
- )
224
- orig_logp_a = torch.as_tensor(
225
- np.concatenate([np.array(t.logp_a) for t in trajectories]),
226
- device=self.device,
227
- )
228
-
229
- step_stats = []
230
- for _ in range(self.n_epochs):
231
- if self.update_rtg_between_epochs:
232
- rtg, adv = compute_rtg_and_advantage(
233
- trajectories, self.policy, self.gamma, self.gae_lambda, self.device
234
- )
235
- else:
236
- adv = compute_advantage(
237
- trajectories, self.policy, self.gamma, self.gae_lambda, self.device
238
- )
239
- idxs = torch.randperm(len(obs))
240
- for i in range(0, len(obs), self.batch_size):
241
- mb_idxs = idxs[i : i + self.batch_size]
242
- mb_adv = adv[mb_idxs]
243
- if self.normalize_advantage:
244
- mb_adv = (mb_adv - mb_adv.mean(-1)) / (mb_adv.std(-1) + 1e-8)
245
- step_stats.append(
246
- self._train_step(
247
- pi_clip,
248
- v_clip,
249
- ent_coef,
250
- obs[mb_idxs],
251
- act[mb_idxs],
252
- rtg[mb_idxs],
253
- mb_adv,
254
- orig_v[mb_idxs],
255
- orig_logp_a[mb_idxs],
256
- )
257
- )
258
-
259
- return TrainStats(step_stats)
260
-
261
- def _train_step(
262
- self,
263
- pi_clip: float,
264
- v_clip: Optional[float],
265
- ent_coef: float,
266
- obs: torch.Tensor,
267
- act: torch.Tensor,
268
- rtg: torch.Tensor,
269
- adv: torch.Tensor,
270
- orig_v: torch.Tensor,
271
- orig_logp_a: torch.Tensor,
272
- ) -> TrainStepStats:
273
- logp_a, entropy, v = self.policy(obs, act)
274
- logratio = logp_a - orig_logp_a
275
- ratio = torch.exp(logratio)
276
- clip_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
277
- pi_loss = torch.maximum(-ratio * adv, -clip_ratio * adv).mean()
278
-
279
- v_loss_unclipped = (v - rtg) ** 2
280
- if v_clip:
281
- v_loss_clipped = (
282
- orig_v + torch.clamp(v - orig_v, -v_clip, v_clip) - rtg
283
- ) ** 2
284
- v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
285
- else:
286
- v_loss = v_loss_unclipped.mean()
287
- if self.ppo2_vf_coef_halving:
288
- v_loss *= 0.5
289
-
290
- entropy_loss = entropy.mean()
291
-
292
- loss = pi_loss - ent_coef * entropy_loss + self.vf_coef * v_loss
293
-
294
- self.optimizer.zero_grad()
295
- loss.backward()
296
- nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
297
- self.optimizer.step()
298
-
299
- with torch.no_grad():
300
- approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
301
- clipped_frac = (
302
- ((ratio - 1).abs() > pi_clip).float().mean().cpu().numpy().item()
303
- )
304
- return TrainStepStats(
305
- loss.item(),
306
- pi_loss.item(),
307
- v_loss.item(),
308
- entropy_loss.item(),
309
- approx_kl,
310
- clipped_frac,
311
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
publish/markdown_format.py DELETED
@@ -1,210 +0,0 @@
1
- import os
2
- import pandas as pd
3
- import wandb.apis.public
4
- import yaml
5
-
6
- from collections import defaultdict
7
- from dataclasses import dataclass, asdict
8
- from typing import Any, Dict, Iterable, List, NamedTuple, Optional, TypeVar
9
- from urllib.parse import urlparse
10
-
11
- from runner.evaluate import Evaluation
12
-
13
- EvaluationRowSelf = TypeVar("EvaluationRowSelf", bound="EvaluationRow")
14
-
15
-
16
- @dataclass
17
- class EvaluationRow:
18
- algo: str
19
- env: str
20
- seed: Optional[int]
21
- reward_mean: float
22
- reward_std: float
23
- eval_episodes: int
24
- best: str
25
- wandb_url: str
26
-
27
- @staticmethod
28
- def data_frame(rows: List[EvaluationRowSelf]) -> pd.DataFrame:
29
- results = defaultdict(list)
30
- for r in rows:
31
- for k, v in asdict(r).items():
32
- results[k].append(v)
33
- return pd.DataFrame(results)
34
-
35
-
36
- class EvalTableData(NamedTuple):
37
- run: wandb.apis.public.Run
38
- evaluation: Evaluation
39
-
40
-
41
- def evaluation_table(table_data: Iterable[EvalTableData]) -> str:
42
- best_stats = sorted(
43
- [d.evaluation.stats for d in table_data], key=lambda r: r.score, reverse=True
44
- )[0]
45
- table_data = sorted(table_data, key=lambda d: d.evaluation.config.seed() or 0)
46
- rows = [
47
- EvaluationRow(
48
- config.algo,
49
- config.env_id,
50
- config.seed(),
51
- stats.score.mean,
52
- stats.score.std,
53
- len(stats),
54
- "*" if stats == best_stats else "",
55
- f"[wandb]({r.url})",
56
- )
57
- for (r, (_, stats, config)) in table_data
58
- ]
59
- df = EvaluationRow.data_frame(rows)
60
- return df.to_markdown(index=False)
61
-
62
-
63
- def github_project_link(github_url: str) -> str:
64
- return f"[{urlparse(github_url).path}]({github_url})"
65
-
66
-
67
- def header_section(algo: str, env: str, github_url: str, wandb_report_url: str) -> str:
68
- algo_caps = algo.upper()
69
- lines = [
70
- f"# **{algo_caps}** Agent playing **{env}**",
71
- f"This is a trained model of a **{algo_caps}** agent playing **{env}** using "
72
- f"the {github_project_link(github_url)} repo.",
73
- f"All models trained at this commit can be found at {wandb_report_url}.",
74
- ]
75
- return "\n\n".join(lines)
76
-
77
-
78
- def github_tree_link(github_url: str, commit_hash: Optional[str]) -> str:
79
- if not commit_hash:
80
- return github_project_link(github_url)
81
- return f"[{commit_hash[:7]}]({github_url}/tree/{commit_hash})"
82
-
83
-
84
- def results_section(
85
- table_data: List[EvalTableData], algo: str, github_url: str, commit_hash: str
86
- ) -> str:
87
- # type: ignore
88
- lines = [
89
- "## Training Results",
90
- f"This model was trained from {len(table_data)} trainings of **{algo.upper()}** "
91
- + "agents using different initial seeds. "
92
- + f"These agents were trained by checking out "
93
- + f"{github_tree_link(github_url, commit_hash)}. "
94
- + "The best and last models were kept from each training. "
95
- + "This submission has loaded the best models from each training, reevaluates "
96
- + "them, and selects the best model from these latest evaluations (mean - std).",
97
- ]
98
- lines.append(evaluation_table(table_data))
99
- return "\n\n".join(lines)
100
-
101
-
102
- def prerequisites_section() -> str:
103
- return """
104
- ### Prerequisites: Weights & Biases (WandB)
105
- Training and benchmarking assumes you have a Weights & Biases project to upload runs to.
106
- By default training goes to a rl-algo-impls project while benchmarks go to
107
- rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
108
- models and the model weights are uploaded to WandB.
109
-
110
- Before doing anything below, you'll need to create a wandb account and run `wandb
111
- login`.
112
- """
113
-
114
-
115
- def usage_section(github_url: str, run_path: str, commit_hash: str) -> str:
116
- return f"""
117
- ## Usage
118
- {urlparse(github_url).path}: {github_url}
119
-
120
- Note: While the model state dictionary and hyperaparameters are saved, the latest
121
- implementation could be sufficiently different to not be able to reproduce similar
122
- results. You might need to checkout the commit the agent was trained on:
123
- {github_tree_link(github_url, commit_hash)}.
124
- ```
125
- # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
126
- python enjoy.py --wandb-run-path={run_path}
127
- ```
128
-
129
- Setup hasn't been completely worked out yet, so you might be best served by using Google
130
- Colab starting from the
131
- [colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
132
- notebook.
133
- """
134
-
135
-
136
- def training_setion(
137
- github_url: str, commit_hash: str, algo: str, env: str, seed: Optional[int]
138
- ) -> str:
139
- return f"""
140
- ## Training
141
- If you want the highest chance to reproduce these results, you'll want to checkout the
142
- commit the agent was trained on: {github_tree_link(github_url, commit_hash)}. While
143
- training is deterministic, different hardware will give different results.
144
-
145
- ```
146
- python train.py --algo {algo} --env {env} {'--seed ' + str(seed) if seed is not None else ''}
147
- ```
148
-
149
- Setup hasn't been completely worked out yet, so you might be best served by using Google
150
- Colab starting from the
151
- [colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
152
- notebook.
153
- """
154
-
155
-
156
- def benchmarking_section(report_url: str) -> str:
157
- return f"""
158
- ## Benchmarking (with Lambda Labs instance)
159
- This and other models from {report_url} were generated by running a script on a Lambda
160
- Labs instance. In a Lambda Labs instance terminal:
161
- ```
162
- git clone git@github.com:sgoodfriend/rl-algo-impls.git
163
- cd rl-algo-impls
164
- bash ./lambda_labs/setup.sh
165
- wandb login
166
- bash ./lambda_labs/benchmark.sh
167
- ```
168
-
169
- ### Alternative: Google Colab Pro+
170
- As an alternative,
171
- [colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
172
- can be used. However, this requires a Google Colab Pro+ subscription and running across
173
- 4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
174
- """
175
-
176
-
177
- def hyperparams_section(run_config: Dict[str, Any]) -> str:
178
- return f"""
179
- ## Hyperparameters
180
- This isn't exactly the format of hyperparams in {os.path.join("hyperparams",
181
- run_config["algo"] + ".yml")}, but instead the Wandb Run Config. However, it's very
182
- close and has some additional data:
183
- ```
184
- {yaml.dump(run_config)}
185
- ```
186
- """
187
-
188
-
189
- def model_card_text(
190
- algo: str,
191
- env: str,
192
- github_url: str,
193
- commit_hash: str,
194
- wandb_report_url: str,
195
- table_data: List[EvalTableData],
196
- best_eval: EvalTableData,
197
- ) -> str:
198
- run, (_, _, config) = best_eval
199
- run_path = "/".join(run.path)
200
- return "\n\n".join(
201
- [
202
- header_section(algo, env, github_url, wandb_report_url),
203
- results_section(table_data, algo, github_url, commit_hash),
204
- prerequisites_section(),
205
- usage_section(github_url, run_path, commit_hash),
206
- training_setion(github_url, commit_hash, algo, env, config.seed()),
207
- benchmarking_section(wandb_report_url),
208
- hyperparams_section(run.config),
209
- ]
210
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
replay.meta.json CHANGED
@@ -1 +1 @@
1
- {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "30", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "30", "/tmp/tmp60zfhp1t/vpg-MountainCarContinuous-v0/replay.mp4"]}, "episode": {"r": 99.03748321533203, "l": 757, "t": 4.666058}}
 
1
+ {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "30", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "30", "/tmp/tmp1r6b3tjy/vpg-MountainCarContinuous-v0/replay.mp4"]}, "episode": {"r": 99.03748321533203, "l": 757, "t": 4.424438}}
rl_algo_impls/benchmark_publish.py CHANGED
@@ -54,8 +54,8 @@ def benchmark_publish() -> None:
54
  "--virtual-display", action="store_true", help="Use headless virtual display"
55
  )
56
  # parser.set_defaults(
57
- # wandb_tags=["benchmark_e47a44c", "host_129-146-2-230"],
58
- # wandb_report_url="https://api.wandb.ai/links/sgoodfriend/v4wd7cp5",
59
  # envs=[],
60
  # exclude_envs=[],
61
  # )
 
54
  "--virtual-display", action="store_true", help="Use headless virtual display"
55
  )
56
  # parser.set_defaults(
57
+ # wandb_tags=["benchmark_2067e21", "host_155-248-199-228"],
58
+ # wandb_report_url="https://api.wandb.ai/links/sgoodfriend/09frjfcs",
59
  # envs=[],
60
  # exclude_envs=[],
61
  # )
rl_algo_impls/huggingface_publish.py CHANGED
@@ -162,6 +162,7 @@ def publish(
162
  path_in_repo="",
163
  commit_message=f"{algo.upper()} playing {env_id} from {github_url}/tree/{commit_hash}",
164
  token=huggingface_token,
 
165
  )
166
  print(f"Pushed model to the hub: {repo_url}")
167
 
 
162
  path_in_repo="",
163
  commit_message=f"{algo.upper()} playing {env_id} from {github_url}/tree/{commit_hash}",
164
  token=huggingface_token,
165
+ delete_patterns="*",
166
  )
167
  print(f"Pushed model to the hub: {repo_url}")
168
 
runner/config.py DELETED
@@ -1,154 +0,0 @@
1
- import os
2
-
3
- from datetime import datetime
4
- from dataclasses import dataclass
5
- from typing import Any, Dict, NamedTuple, Optional, TypedDict, Union
6
-
7
-
8
- @dataclass
9
- class RunArgs:
10
- algo: str
11
- env: str
12
- seed: Optional[int] = None
13
- use_deterministic_algorithms: bool = True
14
-
15
-
16
- class EnvHyperparams(NamedTuple):
17
- is_procgen: bool = False
18
- n_envs: int = 1
19
- frame_stack: int = 1
20
- make_kwargs: Optional[Dict[str, Any]] = None
21
- no_reward_timeout_steps: Optional[int] = None
22
- no_reward_fire_steps: Optional[int] = None
23
- vec_env_class: str = "dummy"
24
- normalize: bool = False
25
- normalize_kwargs: Optional[Dict[str, Any]] = None
26
- rolling_length: int = 100
27
- train_record_video: bool = False
28
- video_step_interval: Union[int, float] = 1_000_000
29
- initial_steps_to_truncate: Optional[int] = None
30
-
31
-
32
- class Hyperparams(TypedDict, total=False):
33
- device: str
34
- n_timesteps: Union[int, float]
35
- env_hyperparams: Dict[str, Any]
36
- policy_hyperparams: Dict[str, Any]
37
- algo_hyperparams: Dict[str, Any]
38
- eval_params: Dict[str, Any]
39
-
40
-
41
- @dataclass
42
- class Config:
43
- args: RunArgs
44
- hyperparams: Hyperparams
45
- root_dir: str
46
- run_id: str = datetime.now().isoformat()
47
-
48
- def seed(self, training: bool = True) -> Optional[int]:
49
- seed = self.args.seed
50
- if training or seed is None:
51
- return seed
52
- return seed + self.env_hyperparams.get("n_envs", 1)
53
-
54
- @property
55
- def device(self) -> str:
56
- return self.hyperparams.get("device", "auto")
57
-
58
- @property
59
- def n_timesteps(self) -> int:
60
- return int(self.hyperparams.get("n_timesteps", 100_000))
61
-
62
- @property
63
- def env_hyperparams(self) -> Dict[str, Any]:
64
- return self.hyperparams.get("env_hyperparams", {})
65
-
66
- @property
67
- def policy_hyperparams(self) -> Dict[str, Any]:
68
- return self.hyperparams.get("policy_hyperparams", {})
69
-
70
- @property
71
- def algo_hyperparams(self) -> Dict[str, Any]:
72
- return self.hyperparams.get("algo_hyperparams", {})
73
-
74
- @property
75
- def eval_params(self) -> Dict[str, Any]:
76
- return self.hyperparams.get("eval_params", {})
77
-
78
- @property
79
- def algo(self) -> str:
80
- return self.args.algo
81
-
82
- @property
83
- def env_id(self) -> str:
84
- return self.hyperparams.get("env_id") or self.args.env
85
-
86
- def model_name(self, include_seed: bool = True) -> str:
87
- # Use arg env name instead of environment name
88
- parts = [self.algo, self.args.env]
89
- if include_seed and self.args.seed is not None:
90
- parts.append(f"S{self.args.seed}")
91
-
92
- # Assume that the custom arg name already has the necessary information
93
- if not self.hyperparams.get("env_id"):
94
- make_kwargs = self.env_hyperparams.get("make_kwargs", {})
95
- if make_kwargs:
96
- for k, v in make_kwargs.items():
97
- if type(v) == bool and v:
98
- parts.append(k)
99
- elif type(v) == int and v:
100
- parts.append(f"{k}{v}")
101
- else:
102
- parts.append(str(v))
103
-
104
- return "-".join(parts)
105
-
106
- @property
107
- def run_name(self) -> str:
108
- parts = [self.model_name(), self.run_id]
109
- return "-".join(parts)
110
-
111
- @property
112
- def saved_models_dir(self) -> str:
113
- return os.path.join(self.root_dir, "saved_models")
114
-
115
- @property
116
- def downloaded_models_dir(self) -> str:
117
- return os.path.join(self.root_dir, "downloaded_models")
118
-
119
- def model_dir_name(
120
- self,
121
- best: bool = False,
122
- extension: str = "",
123
- ) -> str:
124
- return self.model_name() + ("-best" if best else "") + extension
125
-
126
- def model_dir_path(self, best: bool = False, downloaded: bool = False) -> str:
127
- return os.path.join(
128
- self.saved_models_dir if not downloaded else self.downloaded_models_dir,
129
- self.model_dir_name(best=best),
130
- )
131
-
132
- @property
133
- def runs_dir(self) -> str:
134
- return os.path.join(self.root_dir, "runs")
135
-
136
- @property
137
- def tensorboard_summary_path(self) -> str:
138
- return os.path.join(self.runs_dir, self.run_name)
139
-
140
- @property
141
- def logs_path(self) -> str:
142
- return os.path.join(self.runs_dir, f"log.yml")
143
-
144
- @property
145
- def videos_dir(self) -> str:
146
- return os.path.join(self.root_dir, "videos")
147
-
148
- @property
149
- def video_prefix(self) -> str:
150
- return os.path.join(self.videos_dir, self.model_name())
151
-
152
- @property
153
- def best_videos_dir(self) -> str:
154
- return os.path.join(self.videos_dir, f"{self.model_name()}-best")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
runner/env.py DELETED
@@ -1,256 +0,0 @@
1
- import gym
2
- import numpy as np
3
- import os
4
-
5
- from gym.wrappers.resize_observation import ResizeObservation
6
- from gym.wrappers.gray_scale_observation import GrayScaleObservation
7
- from gym.wrappers.frame_stack import FrameStack
8
- from procgen.env import ProcgenEnv
9
- from stable_baselines3.common.atari_wrappers import (
10
- MaxAndSkipEnv,
11
- NoopResetEnv,
12
- )
13
- from stable_baselines3.common.vec_env.base_vec_env import VecEnv
14
- from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
15
- from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
16
- from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
17
- from torch.utils.tensorboard.writer import SummaryWriter
18
- from typing import Callable, Optional, Union
19
-
20
- from runner.config import Config, EnvHyperparams
21
- from shared.policy.policy import VEC_NORMALIZE_FILENAME
22
- from wrappers.atari_wrappers import EpisodicLifeEnv, FireOnLifeStarttEnv, ClipRewardEnv
23
- from wrappers.episode_record_video import EpisodeRecordVideo
24
- from wrappers.episode_stats_writer import EpisodeStatsWriter
25
- from wrappers.get_rgb_observation import GetRgbObservation
26
- from wrappers.initial_step_truncate_wrapper import InitialStepTruncateWrapper
27
- from wrappers.is_vector_env import IsVectorEnv
28
- from wrappers.noop_env_seed import NoopEnvSeed
29
- from wrappers.transpose_image_observation import TransposeImageObservation
30
- from wrappers.video_compat_wrapper import VideoCompatWrapper
31
-
32
- GeneralVecEnv = Union[VecEnv, gym.vector.VectorEnv, gym.Wrapper]
33
-
34
-
35
- def make_env(
36
- config: Config,
37
- hparams: EnvHyperparams,
38
- training: bool = True,
39
- render: bool = False,
40
- normalize_load_path: Optional[str] = None,
41
- tb_writer: Optional[SummaryWriter] = None,
42
- ) -> GeneralVecEnv:
43
- if hparams.is_procgen:
44
- return _make_procgen_env(
45
- config,
46
- hparams,
47
- training=training,
48
- render=render,
49
- normalize_load_path=normalize_load_path,
50
- tb_writer=tb_writer,
51
- )
52
- else:
53
- return _make_vec_env(
54
- config,
55
- hparams,
56
- training=training,
57
- render=render,
58
- normalize_load_path=normalize_load_path,
59
- tb_writer=tb_writer,
60
- )
61
-
62
-
63
- def make_eval_env(
64
- config: Config,
65
- hparams: EnvHyperparams,
66
- override_n_envs: Optional[int] = None,
67
- **kwargs
68
- ) -> GeneralVecEnv:
69
- kwargs = kwargs.copy()
70
- kwargs["training"] = False
71
- if override_n_envs is not None:
72
- hparams_kwargs = hparams._asdict()
73
- hparams_kwargs["n_envs"] = override_n_envs
74
- if override_n_envs == 1:
75
- hparams_kwargs["vec_env_class"] = "dummy"
76
- hparams = EnvHyperparams(**hparams_kwargs)
77
- return make_env(config, hparams, **kwargs)
78
-
79
-
80
- def _make_vec_env(
81
- config: Config,
82
- hparams: EnvHyperparams,
83
- training: bool = True,
84
- render: bool = False,
85
- normalize_load_path: Optional[str] = None,
86
- tb_writer: Optional[SummaryWriter] = None,
87
- ) -> GeneralVecEnv:
88
- (
89
- _,
90
- n_envs,
91
- frame_stack,
92
- make_kwargs,
93
- no_reward_timeout_steps,
94
- no_reward_fire_steps,
95
- vec_env_class,
96
- normalize,
97
- normalize_kwargs,
98
- rolling_length,
99
- train_record_video,
100
- video_step_interval,
101
- initial_steps_to_truncate,
102
- ) = hparams
103
-
104
- if "BulletEnv" in config.env_id:
105
- import pybullet_envs
106
-
107
- spec = gym.spec(config.env_id)
108
- seed = config.seed(training=training)
109
-
110
- def make(idx: int) -> Callable[[], gym.Env]:
111
- env_kwargs = make_kwargs.copy() if make_kwargs is not None else {}
112
- if "BulletEnv" in config.env_id and render:
113
- env_kwargs["render"] = True
114
- if "CarRacing" in config.env_id:
115
- env_kwargs["verbose"] = 0
116
- if "procgen" in config.env_id:
117
- if not render:
118
- env_kwargs["render_mode"] = "rgb_array"
119
-
120
- def _make() -> gym.Env:
121
- env = gym.make(config.env_id, **env_kwargs)
122
- env = gym.wrappers.RecordEpisodeStatistics(env)
123
- env = VideoCompatWrapper(env)
124
- if training and train_record_video and idx == 0:
125
- env = EpisodeRecordVideo(
126
- env,
127
- config.video_prefix,
128
- step_increment=n_envs,
129
- video_step_interval=int(video_step_interval),
130
- )
131
- if training and initial_steps_to_truncate:
132
- env = InitialStepTruncateWrapper(
133
- env, idx * initial_steps_to_truncate // n_envs
134
- )
135
- if "AtariEnv" in spec.entry_point: # type: ignore
136
- env = NoopResetEnv(env, noop_max=30)
137
- env = MaxAndSkipEnv(env, skip=4)
138
- env = EpisodicLifeEnv(env, training=training)
139
- action_meanings = env.unwrapped.get_action_meanings()
140
- if "FIRE" in action_meanings: # type: ignore
141
- env = FireOnLifeStarttEnv(env, action_meanings.index("FIRE"))
142
- env = ClipRewardEnv(env, training=training)
143
- env = ResizeObservation(env, (84, 84))
144
- env = GrayScaleObservation(env, keep_dim=False)
145
- env = FrameStack(env, frame_stack)
146
- elif "CarRacing" in config.env_id:
147
- env = ResizeObservation(env, (64, 64))
148
- env = GrayScaleObservation(env, keep_dim=False)
149
- env = FrameStack(env, frame_stack)
150
- elif "procgen" in config.env_id:
151
- # env = GrayScaleObservation(env, keep_dim=False)
152
- env = NoopEnvSeed(env)
153
- env = TransposeImageObservation(env)
154
- if frame_stack > 1:
155
- env = FrameStack(env, frame_stack)
156
-
157
- if no_reward_timeout_steps:
158
- from wrappers.no_reward_timeout import NoRewardTimeout
159
-
160
- env = NoRewardTimeout(
161
- env, no_reward_timeout_steps, n_fire_steps=no_reward_fire_steps
162
- )
163
-
164
- if seed is not None:
165
- env.seed(seed + idx)
166
- env.action_space.seed(seed + idx)
167
- env.observation_space.seed(seed + idx)
168
-
169
- return env
170
-
171
- return _make
172
-
173
- VecEnvClass = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_class]
174
- venv = VecEnvClass([make(i) for i in range(n_envs)])
175
- if training:
176
- assert tb_writer
177
- venv = EpisodeStatsWriter(
178
- venv, tb_writer, training=training, rolling_length=rolling_length
179
- )
180
- if normalize:
181
- if normalize_load_path:
182
- venv = VecNormalize.load(
183
- os.path.join(normalize_load_path, VEC_NORMALIZE_FILENAME),
184
- venv, # type: ignore
185
- )
186
- else:
187
- venv = VecNormalize(
188
- venv, # type: ignore
189
- training=training,
190
- **(normalize_kwargs or {}),
191
- )
192
- if not training:
193
- venv.norm_reward = False
194
- return venv
195
-
196
-
197
- def _make_procgen_env(
198
- config: Config,
199
- hparams: EnvHyperparams,
200
- training: bool = True,
201
- render: bool = False,
202
- normalize_load_path: Optional[str] = None,
203
- tb_writer: Optional[SummaryWriter] = None,
204
- ) -> GeneralVecEnv:
205
- (
206
- _,
207
- n_envs,
208
- frame_stack,
209
- make_kwargs,
210
- _, # no_reward_timeout_steps
211
- _, # no_reward_fire_steps
212
- _, # vec_env_class
213
- normalize,
214
- normalize_kwargs,
215
- rolling_length,
216
- _, # train_record_video
217
- _, # video_step_interval
218
- _, # initial_steps_to_truncate
219
- ) = hparams
220
-
221
- seed = config.seed(training=training)
222
-
223
- make_kwargs = make_kwargs or {}
224
- if not render:
225
- make_kwargs["render_mode"] = "rgb_array"
226
- if seed is not None:
227
- make_kwargs["rand_seed"] = seed
228
-
229
- envs = ProcgenEnv(n_envs, config.env_id, **make_kwargs)
230
- envs = IsVectorEnv(envs)
231
- envs = GetRgbObservation(envs)
232
- # TODO: Handle Grayscale and/or FrameStack
233
- envs = TransposeImageObservation(envs)
234
-
235
- envs = gym.wrappers.RecordEpisodeStatistics(envs)
236
-
237
- if seed is not None:
238
- envs.action_space.seed(seed)
239
- envs.observation_space.seed(seed)
240
-
241
- if training:
242
- assert tb_writer
243
- envs = EpisodeStatsWriter(
244
- envs, tb_writer, training=training, rolling_length=rolling_length
245
- )
246
- if normalize and training:
247
- normalize_kwargs = normalize_kwargs or {}
248
- # TODO: Handle reward stats saving/loading/syncing, but it's only important
249
- # for checkpointing
250
- envs = gym.wrappers.NormalizeReward(envs)
251
- clip_obs = normalize_kwargs.get("clip_reward", 10.0)
252
- envs = gym.wrappers.TransformReward(
253
- envs, lambda r: np.clip(r, -clip_obs, clip_obs)
254
- )
255
-
256
- return envs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
runner/evaluate.py DELETED
@@ -1,103 +0,0 @@
1
- import os
2
- import shutil
3
-
4
- from dataclasses import dataclass
5
- from typing import NamedTuple, Optional
6
-
7
- from runner.env import make_eval_env
8
- from runner.config import Config, EnvHyperparams, RunArgs
9
- from runner.running_utils import (
10
- load_hyperparams,
11
- set_seeds,
12
- get_device,
13
- make_policy,
14
- )
15
- from shared.callbacks.eval_callback import evaluate
16
- from shared.policy.policy import Policy
17
- from shared.stats import EpisodesStats
18
-
19
-
20
- @dataclass
21
- class EvalArgs(RunArgs):
22
- render: bool = True
23
- best: bool = True
24
- n_envs: Optional[int] = 1
25
- n_episodes: int = 3
26
- deterministic_eval: Optional[bool] = None
27
- no_print_returns: bool = False
28
- wandb_run_path: Optional[str] = None
29
-
30
-
31
- class Evaluation(NamedTuple):
32
- policy: Policy
33
- stats: EpisodesStats
34
- config: Config
35
-
36
-
37
- def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
38
- if args.wandb_run_path:
39
- import wandb
40
-
41
- api = wandb.Api()
42
- run = api.run(args.wandb_run_path)
43
- hyperparams = run.config
44
-
45
- args.algo = hyperparams["algo"]
46
- args.env = hyperparams["env"]
47
- args.seed = hyperparams.get("seed", None)
48
- args.use_deterministic_algorithms = hyperparams.get(
49
- "use_deterministic_algorithms", True
50
- )
51
-
52
- config = Config(args, hyperparams, root_dir)
53
- model_path = config.model_dir_path(best=args.best, downloaded=True)
54
-
55
- model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
56
- run.file(model_archive_name).download()
57
- if os.path.isdir(model_path):
58
- shutil.rmtree(model_path)
59
- shutil.unpack_archive(model_archive_name, model_path)
60
- os.remove(model_archive_name)
61
- else:
62
- hyperparams = load_hyperparams(args.algo, args.env, root_dir)
63
-
64
- config = Config(args, hyperparams, root_dir)
65
- model_path = config.model_dir_path(best=args.best)
66
-
67
- print(args)
68
-
69
- set_seeds(args.seed, args.use_deterministic_algorithms)
70
-
71
- env = make_eval_env(
72
- config,
73
- EnvHyperparams(**config.env_hyperparams),
74
- override_n_envs=args.n_envs,
75
- render=args.render,
76
- normalize_load_path=model_path,
77
- )
78
- device = get_device(config.device, env)
79
- policy = make_policy(
80
- args.algo,
81
- env,
82
- device,
83
- load_path=model_path,
84
- **config.policy_hyperparams,
85
- ).eval()
86
-
87
- deterministic = (
88
- args.deterministic_eval
89
- if args.deterministic_eval is not None
90
- else config.eval_params.get("deterministic", True)
91
- )
92
- return Evaluation(
93
- policy,
94
- evaluate(
95
- env,
96
- policy,
97
- args.n_episodes,
98
- render=args.render,
99
- deterministic=deterministic,
100
- print_returns=not args.no_print_returns,
101
- ),
102
- config,
103
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
runner/running_utils.py DELETED
@@ -1,192 +0,0 @@
1
- import argparse
2
- import gym
3
- import json
4
- import matplotlib.pyplot as plt
5
- import numpy as np
6
- import os
7
- import random
8
- import torch
9
- import torch.backends.cudnn
10
- import yaml
11
-
12
- from gym.spaces import Box, Discrete
13
- from stable_baselines3.common.vec_env.base_vec_env import VecEnv
14
- from torch.utils.tensorboard.writer import SummaryWriter
15
- from typing import Dict, Optional, Type, Union
16
-
17
- from runner.config import Hyperparams
18
- from shared.algorithm import Algorithm
19
- from shared.callbacks.eval_callback import EvalCallback
20
- from shared.policy.policy import Policy
21
-
22
- from dqn.dqn import DQN
23
- from dqn.policy import DQNPolicy
24
- from vpg.vpg import VanillaPolicyGradient
25
- from vpg.policy import VPGActorCritic
26
- from ppo.ppo import PPO
27
- from ppo.policy import PPOActorCritic
28
-
29
- ALGOS: Dict[str, Type[Algorithm]] = {
30
- "dqn": DQN,
31
- "vpg": VanillaPolicyGradient,
32
- "ppo": PPO,
33
- }
34
- POLICIES: Dict[str, Type[Policy]] = {
35
- "dqn": DQNPolicy,
36
- "vpg": VPGActorCritic,
37
- "ppo": PPOActorCritic,
38
- }
39
-
40
- HYPERPARAMS_PATH = "hyperparams"
41
-
42
-
43
- def base_parser(multiple: bool = True) -> argparse.ArgumentParser:
44
- parser = argparse.ArgumentParser()
45
- parser.add_argument(
46
- "--algo",
47
- default=["dqn"],
48
- type=str,
49
- choices=list(ALGOS.keys()),
50
- nargs="+" if multiple else 1,
51
- help="Abbreviation(s) of algorithm(s)",
52
- )
53
- parser.add_argument(
54
- "--env",
55
- default=["CartPole-v1"],
56
- type=str,
57
- nargs="+" if multiple else 1,
58
- help="Name of environment(s) in gym",
59
- )
60
- parser.add_argument(
61
- "--seed",
62
- default=[1],
63
- type=int,
64
- nargs="*" if multiple else "?",
65
- help="Seeds to run experiment. Unset will do one run with no set seed",
66
- )
67
- parser.add_argument(
68
- "--use-deterministic-algorithms",
69
- default=True,
70
- type=bool,
71
- help="If seed set, set torch.use_deterministic_algorithms",
72
- )
73
- return parser
74
-
75
-
76
- def load_hyperparams(algo: str, env_id: str, root_path: str) -> Hyperparams:
77
- hyperparams_path = os.path.join(root_path, HYPERPARAMS_PATH, f"{algo}.yml")
78
- with open(hyperparams_path, "r") as f:
79
- hyperparams_dict = yaml.safe_load(f)
80
-
81
- if env_id in hyperparams_dict:
82
- return hyperparams_dict[env_id]
83
-
84
- if "BulletEnv" in env_id:
85
- import pybullet_envs
86
- spec = gym.spec(env_id)
87
- if "AtariEnv" in str(spec.entry_point) and "_atari" in hyperparams_dict:
88
- return hyperparams_dict["_atari"]
89
- else:
90
- raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
91
-
92
-
93
- def get_device(device: str, env: VecEnv) -> torch.device:
94
- # cuda by default
95
- if device == "auto":
96
- device = "cuda"
97
- # Apple MPS is a second choice (sometimes)
98
- if device == "cuda" and not torch.cuda.is_available():
99
- device = "mps"
100
- # If no MPS, fallback to cpu
101
- if device == "mps" and not torch.backends.mps.is_available():
102
- device = "cpu"
103
- # Simple environments like Discreet and 1-D Boxes might also be better
104
- # served with the CPU.
105
- if device == "mps":
106
- obs_space = env.observation_space
107
- if isinstance(obs_space, Discrete):
108
- device = "cpu"
109
- elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
110
- device = "cpu"
111
- print(f"Device: {device}")
112
- return torch.device(device)
113
-
114
-
115
- def set_seeds(seed: Optional[int], use_deterministic_algorithms: bool) -> None:
116
- if seed is None:
117
- return
118
- random.seed(seed)
119
- np.random.seed(seed)
120
- torch.manual_seed(seed)
121
- torch.backends.cudnn.benchmark = False
122
- torch.use_deterministic_algorithms(use_deterministic_algorithms)
123
- os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
124
- # Stop warning and it would introduce stochasticity if I was using TF
125
- os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
126
-
127
-
128
- def make_policy(
129
- algo: str,
130
- env: VecEnv,
131
- device: torch.device,
132
- load_path: Optional[str] = None,
133
- **kwargs,
134
- ) -> Policy:
135
- policy = POLICIES[algo](env, **kwargs).to(device)
136
- if load_path:
137
- policy.load(load_path)
138
- return policy
139
-
140
-
141
- def plot_eval_callback(callback: EvalCallback, tb_writer: SummaryWriter, run_name: str):
142
- figure = plt.figure()
143
- cumulative_steps = [
144
- (idx + 1) * callback.step_freq for idx in range(len(callback.stats))
145
- ]
146
- plt.plot(
147
- cumulative_steps,
148
- [s.score.mean for s in callback.stats],
149
- "b-",
150
- label="mean",
151
- )
152
- plt.plot(
153
- cumulative_steps,
154
- [s.score.mean - s.score.std for s in callback.stats],
155
- "g--",
156
- label="mean-std",
157
- )
158
- plt.fill_between(
159
- cumulative_steps,
160
- [s.score.min for s in callback.stats], # type: ignore
161
- [s.score.max for s in callback.stats], # type: ignore
162
- facecolor="cyan",
163
- label="range",
164
- )
165
- plt.xlabel("Steps")
166
- plt.ylabel("Score")
167
- plt.legend()
168
- plt.title(f"Eval {run_name}")
169
- tb_writer.add_figure("eval", figure)
170
-
171
-
172
- Scalar = Union[bool, str, float, int, None]
173
-
174
-
175
- def hparam_dict(
176
- hyperparams: Hyperparams, args: Dict[str, Union[Scalar, list]]
177
- ) -> Dict[str, Scalar]:
178
- flattened = args.copy()
179
- for k, v in flattened.items():
180
- if isinstance(v, list):
181
- flattened[k] = json.dumps(v)
182
- for k, v in hyperparams.items():
183
- if isinstance(v, dict):
184
- for sk, sv in v.items():
185
- key = f"{k}/{sk}"
186
- if isinstance(sv, dict) or isinstance(sv, list):
187
- flattened[key] = str(sv)
188
- else:
189
- flattened[key] = sv
190
- else:
191
- flattened[k] = v # type: ignore
192
- return flattened # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
runner/train.py DELETED
@@ -1,134 +0,0 @@
1
- # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
2
- import os
3
-
4
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
-
6
- import dataclasses
7
- import shutil
8
- import wandb
9
- import yaml
10
-
11
- from dataclasses import dataclass
12
- from torch.utils.tensorboard.writer import SummaryWriter
13
- from typing import Any, Dict, Optional, Sequence
14
-
15
- from shared.callbacks.eval_callback import EvalCallback
16
- from runner.config import Config, EnvHyperparams, RunArgs
17
- from runner.env import make_env, make_eval_env
18
- from runner.running_utils import (
19
- ALGOS,
20
- load_hyperparams,
21
- set_seeds,
22
- get_device,
23
- make_policy,
24
- plot_eval_callback,
25
- hparam_dict,
26
- )
27
- from shared.stats import EpisodesStats
28
-
29
-
30
- @dataclass
31
- class TrainArgs(RunArgs):
32
- wandb_project_name: Optional[str] = None
33
- wandb_entity: Optional[str] = None
34
- wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
35
-
36
-
37
- def train(args: TrainArgs):
38
- print(args)
39
- hyperparams = load_hyperparams(args.algo, args.env, os.getcwd())
40
- print(hyperparams)
41
- config = Config(args, hyperparams, os.getcwd())
42
-
43
- wandb_enabled = args.wandb_project_name
44
- if wandb_enabled:
45
- wandb.tensorboard.patch(
46
- root_logdir=config.tensorboard_summary_path, pytorch=True
47
- )
48
- wandb.init(
49
- project=args.wandb_project_name,
50
- entity=args.wandb_entity,
51
- config=hyperparams, # type: ignore
52
- name=config.run_name,
53
- monitor_gym=True,
54
- save_code=True,
55
- tags=args.wandb_tags,
56
- )
57
- wandb.config.update(args)
58
-
59
- tb_writer = SummaryWriter(config.tensorboard_summary_path)
60
-
61
- set_seeds(args.seed, args.use_deterministic_algorithms)
62
-
63
- env = make_env(
64
- config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
65
- )
66
- device = get_device(config.device, env)
67
- policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
68
- algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
69
-
70
- eval_env = make_eval_env(config, EnvHyperparams(**config.env_hyperparams))
71
- record_best_videos = config.eval_params.get("record_best_videos", True)
72
- callback = EvalCallback(
73
- policy,
74
- eval_env,
75
- tb_writer,
76
- best_model_path=config.model_dir_path(best=True),
77
- **config.eval_params,
78
- video_env=make_eval_env(
79
- config, EnvHyperparams(**config.env_hyperparams), override_n_envs=1
80
- )
81
- if record_best_videos
82
- else None,
83
- best_video_dir=config.best_videos_dir,
84
- )
85
- algo.learn(config.n_timesteps, callback=callback)
86
-
87
- policy.save(config.model_dir_path(best=False))
88
-
89
- eval_stats = callback.evaluate(n_episodes=10, print_returns=True)
90
-
91
- plot_eval_callback(callback, tb_writer, config.run_name)
92
-
93
- log_dict: Dict[str, Any] = {
94
- "eval": eval_stats._asdict(),
95
- }
96
- if callback.best:
97
- log_dict["best_eval"] = callback.best._asdict()
98
- log_dict.update(hyperparams)
99
- log_dict.update(vars(args))
100
- with open(config.logs_path, "a") as f:
101
- yaml.dump({config.run_name: log_dict}, f)
102
-
103
- best_eval_stats: EpisodesStats = callback.best # type: ignore
104
- tb_writer.add_hparams(
105
- hparam_dict(hyperparams, vars(args)),
106
- {
107
- "hparam/best_mean": best_eval_stats.score.mean,
108
- "hparam/best_result": best_eval_stats.score.mean
109
- - best_eval_stats.score.std,
110
- "hparam/last_mean": eval_stats.score.mean,
111
- "hparam/last_result": eval_stats.score.mean - eval_stats.score.std,
112
- },
113
- None,
114
- config.run_name,
115
- )
116
-
117
- tb_writer.close()
118
-
119
- if wandb_enabled:
120
- wandb.run.summary["num_parameters"] = policy.num_parameters()
121
- wandb.run.summary[
122
- "num_trainable_parameters"
123
- ] = policy.num_trainable_parameters()
124
- shutil.make_archive(
125
- os.path.join(wandb.run.dir, config.model_dir_name()),
126
- "zip",
127
- config.model_dir_path(),
128
- )
129
- shutil.make_archive(
130
- os.path.join(wandb.run.dir, config.model_dir_name(best=True)),
131
- "zip",
132
- config.model_dir_path(best=True),
133
- )
134
- wandb.finish()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
saved_models/vpg-MountainCarContinuous-v0-S3-best/vecnormalize.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:cdae91d5d329b0e30416b6c517f2c590d532fc6239ffca077a5d58986c39edd4
3
- size 6572
 
 
 
 
shared/algorithm.py DELETED
@@ -1,35 +0,0 @@
1
- import gym
2
- import torch
3
-
4
- from abc import ABC, abstractmethod
5
- from stable_baselines3.common.vec_env.base_vec_env import VecEnv
6
- from torch.utils.tensorboard.writer import SummaryWriter
7
- from typing import List, Optional, TypeVar
8
-
9
- from shared.callbacks.callback import Callback
10
- from shared.policy.policy import Policy
11
- from shared.stats import EpisodesStats
12
-
13
- AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")
14
-
15
- class Algorithm(ABC):
16
- @abstractmethod
17
- def __init__(
18
- self,
19
- policy: Policy,
20
- env: VecEnv,
21
- device: torch.device,
22
- tb_writer: SummaryWriter,
23
- **kwargs,
24
- ) -> None:
25
- super().__init__()
26
- self.policy = policy
27
- self.env = env
28
- self.device = device
29
- self.tb_writer = tb_writer
30
-
31
- @abstractmethod
32
- def learn(
33
- self: AlgorithmSelf, total_timesteps: int, callback: Optional[Callback] = None
34
- ) -> AlgorithmSelf:
35
- ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
shared/callbacks/callback.py DELETED
@@ -1,12 +0,0 @@
1
- from abc import ABC, abstractmethod
2
-
3
-
4
- class Callback(ABC):
5
-
6
- def __init__(self) -> None:
7
- super().__init__()
8
- self.timesteps_elapsed = 0
9
-
10
- def on_step(self, timesteps_elapsed: int = 1) -> bool:
11
- self.timesteps_elapsed += timesteps_elapsed
12
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
shared/callbacks/eval_callback.py DELETED
@@ -1,206 +0,0 @@
1
- import itertools
2
- import numpy as np
3
- import os
4
-
5
- from copy import deepcopy
6
- from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
7
- from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
8
- from torch.utils.tensorboard.writer import SummaryWriter
9
- from typing import List, Optional, Union
10
-
11
- from shared.callbacks.callback import Callback
12
- from shared.policy.policy import Policy
13
- from shared.stats import Episode, EpisodeAccumulator, EpisodesStats
14
- from wrappers.vec_episode_recorder import VecEpisodeRecorder
15
-
16
-
17
- class EvaluateAccumulator(EpisodeAccumulator):
18
- def __init__(
19
- self,
20
- num_envs: int,
21
- goal_episodes: int,
22
- print_returns: bool = True,
23
- ignore_first_episode: bool = False,
24
- ):
25
- super().__init__(num_envs)
26
- self.completed_episodes_by_env_idx = [[] for _ in range(num_envs)]
27
- self.goal_episodes_per_env = int(np.ceil(goal_episodes / num_envs))
28
- self.print_returns = print_returns
29
- if ignore_first_episode:
30
- first_done = set()
31
-
32
- def should_record_done(idx: int) -> bool:
33
- has_done_first_episode = idx in first_done
34
- first_done.add(idx)
35
- return has_done_first_episode
36
-
37
- self.should_record_done = should_record_done
38
- else:
39
- self.should_record_done = lambda idx: True
40
-
41
- def on_done(self, ep_idx: int, episode: Episode) -> None:
42
- if (
43
- self.should_record_done(ep_idx)
44
- and len(self.completed_episodes_by_env_idx[ep_idx])
45
- >= self.goal_episodes_per_env
46
- ):
47
- return
48
- self.completed_episodes_by_env_idx[ep_idx].append(episode)
49
- if self.print_returns:
50
- print(
51
- f"Episode {len(self)} | "
52
- f"Score {episode.score} | "
53
- f"Length {episode.length}"
54
- )
55
-
56
- def __len__(self) -> int:
57
- return sum(len(ce) for ce in self.completed_episodes_by_env_idx)
58
-
59
- @property
60
- def episodes(self) -> List[Episode]:
61
- return list(itertools.chain(*self.completed_episodes_by_env_idx))
62
-
63
- def is_done(self) -> bool:
64
- return all(
65
- len(ce) == self.goal_episodes_per_env
66
- for ce in self.completed_episodes_by_env_idx
67
- )
68
-
69
-
70
- def evaluate(
71
- env: VecEnv,
72
- policy: Policy,
73
- n_episodes: int,
74
- render: bool = False,
75
- deterministic: bool = True,
76
- print_returns: bool = True,
77
- ignore_first_episode: bool = False,
78
- ) -> EpisodesStats:
79
- policy.eval()
80
- episodes = EvaluateAccumulator(
81
- env.num_envs, n_episodes, print_returns, ignore_first_episode
82
- )
83
-
84
- obs = env.reset()
85
- while not episodes.is_done():
86
- act = policy.act(obs, deterministic=deterministic)
87
- obs, rew, done, _ = env.step(act)
88
- episodes.step(rew, done)
89
- if render:
90
- env.render()
91
- stats = EpisodesStats(episodes.episodes)
92
- if print_returns:
93
- print(stats)
94
- return stats
95
-
96
-
97
- class EvalCallback(Callback):
98
- def __init__(
99
- self,
100
- policy: Policy,
101
- env: VecEnv,
102
- tb_writer: SummaryWriter,
103
- best_model_path: Optional[str] = None,
104
- step_freq: Union[int, float] = 50_000,
105
- n_episodes: int = 10,
106
- save_best: bool = True,
107
- deterministic: bool = True,
108
- record_best_videos: bool = True,
109
- video_env: Optional[VecEnv] = None,
110
- best_video_dir: Optional[str] = None,
111
- max_video_length: int = 3600,
112
- ignore_first_episode: bool = False,
113
- ) -> None:
114
- super().__init__()
115
- self.policy = policy
116
- self.env = env
117
- self.tb_writer = tb_writer
118
- self.best_model_path = best_model_path
119
- self.step_freq = int(step_freq)
120
- self.n_episodes = n_episodes
121
- self.save_best = save_best
122
- self.deterministic = deterministic
123
- self.stats: List[EpisodesStats] = []
124
- self.best = None
125
-
126
- self.record_best_videos = record_best_videos
127
- assert video_env or not record_best_videos
128
- self.video_env = video_env
129
- assert best_video_dir or not record_best_videos
130
- self.best_video_dir = best_video_dir
131
- if best_video_dir:
132
- os.makedirs(best_video_dir, exist_ok=True)
133
- self.max_video_length = max_video_length
134
- self.best_video_base_path = None
135
-
136
- self.ignore_first_episode = ignore_first_episode
137
-
138
- def on_step(self, timesteps_elapsed: int = 1) -> bool:
139
- super().on_step(timesteps_elapsed)
140
- if self.timesteps_elapsed // self.step_freq >= len(self.stats):
141
- sync_vec_normalize(self.policy.vec_normalize, self.env)
142
- self.evaluate()
143
- return True
144
-
145
- def evaluate(
146
- self, n_episodes: Optional[int] = None, print_returns: Optional[bool] = None
147
- ) -> EpisodesStats:
148
- eval_stat = evaluate(
149
- self.env,
150
- self.policy,
151
- n_episodes or self.n_episodes,
152
- deterministic=self.deterministic,
153
- print_returns=print_returns or False,
154
- ignore_first_episode=self.ignore_first_episode,
155
- )
156
- self.policy.train(True)
157
- print(f"Eval Timesteps: {self.timesteps_elapsed} | {eval_stat}")
158
-
159
- self.stats.append(eval_stat)
160
-
161
- if not self.best or eval_stat >= self.best:
162
- strictly_better = not self.best or eval_stat > self.best
163
- self.best = eval_stat
164
- if self.save_best:
165
- assert self.best_model_path
166
- self.policy.save(self.best_model_path)
167
- print("Saved best model")
168
- self.best.write_to_tensorboard(
169
- self.tb_writer, "best_eval", self.timesteps_elapsed
170
- )
171
- if strictly_better and self.record_best_videos:
172
- assert self.video_env and self.best_video_dir
173
- sync_vec_normalize(self.policy.vec_normalize, self.video_env)
174
- self.best_video_base_path = os.path.join(
175
- self.best_video_dir, str(self.timesteps_elapsed)
176
- )
177
- video_wrapped = VecEpisodeRecorder(
178
- self.video_env,
179
- self.best_video_base_path,
180
- max_video_length=self.max_video_length,
181
- )
182
- video_stats = evaluate(
183
- video_wrapped,
184
- self.policy,
185
- 1,
186
- deterministic=self.deterministic,
187
- print_returns=False,
188
- )
189
- print(f"Saved best video: {video_stats}")
190
-
191
- eval_stat.write_to_tensorboard(self.tb_writer, "eval", self.timesteps_elapsed)
192
-
193
- return eval_stat
194
-
195
-
196
- def sync_vec_normalize(
197
- origin_vec_normalize: Optional[VecNormalize], destination_env: VecEnv
198
- ) -> None:
199
- if origin_vec_normalize is not None:
200
- eval_env_wrapper = destination_env
201
- while isinstance(eval_env_wrapper, VecEnvWrapper):
202
- if isinstance(eval_env_wrapper, VecNormalize):
203
- if hasattr(origin_vec_normalize, "obs_rms"):
204
- eval_env_wrapper.obs_rms = deepcopy(origin_vec_normalize.obs_rms)
205
- eval_env_wrapper.ret_rms = deepcopy(origin_vec_normalize.ret_rms)
206
- eval_env_wrapper = eval_env_wrapper.venv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
shared/gae.py DELETED
@@ -1,67 +0,0 @@
1
- import numpy as np
2
- import torch
3
-
4
- from typing import NamedTuple, Sequence
5
-
6
- from shared.policy.on_policy import OnPolicy
7
- from shared.trajectory import Trajectory
8
-
9
-
10
- class RtgAdvantage(NamedTuple):
11
- rewards_to_go: torch.Tensor
12
- advantage: torch.Tensor
13
-
14
-
15
- def discounted_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
16
- dc = x.copy()
17
- for i in reversed(range(len(x) - 1)):
18
- dc[i] += gamma * dc[i + 1]
19
- return dc
20
-
21
-
22
- def compute_advantage(
23
- trajectories: Sequence[Trajectory],
24
- policy: OnPolicy,
25
- gamma: float,
26
- gae_lambda: float,
27
- device: torch.device,
28
- ) -> torch.Tensor:
29
- advantage = []
30
- for traj in trajectories:
31
- last_val = 0
32
- if not traj.terminated and traj.next_obs is not None:
33
- last_val = policy.value(traj.next_obs)
34
- rew = np.append(np.array(traj.rew), last_val)
35
- v = np.append(np.array(traj.v), last_val)
36
- deltas = rew[:-1] + gamma * v[1:] - v[:-1]
37
- advantage.append(discounted_cumsum(deltas, gamma * gae_lambda))
38
- return torch.as_tensor(
39
- np.concatenate(advantage), dtype=torch.float32, device=device
40
- )
41
-
42
-
43
- def compute_rtg_and_advantage(
44
- trajectories: Sequence[Trajectory],
45
- policy: OnPolicy,
46
- gamma: float,
47
- gae_lambda: float,
48
- device: torch.device,
49
- ) -> RtgAdvantage:
50
- rewards_to_go = []
51
- advantages = []
52
- for traj in trajectories:
53
- last_val = 0
54
- if not traj.terminated and traj.next_obs is not None:
55
- last_val = policy.value(traj.next_obs)
56
- rew = np.append(np.array(traj.rew), last_val)
57
- v = np.append(np.array(traj.v), last_val)
58
- deltas = rew[:-1] + gamma * v[1:] - v[:-1]
59
- adv = discounted_cumsum(deltas, gamma * gae_lambda)
60
- advantages.append(adv)
61
- rewards_to_go.append(v[:-1] + adv)
62
- return RtgAdvantage(
63
- torch.as_tensor(
64
- np.concatenate(rewards_to_go), dtype=torch.float32, device=device
65
- ),
66
- torch.as_tensor(np.concatenate(advantages), dtype=torch.float32, device=device),
67
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
shared/module/feature_extractor.py DELETED
@@ -1,215 +0,0 @@
1
- import gym
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
-
6
- from abc import ABC, abstractmethod
7
- from gym.spaces import Box, Discrete
8
- from stable_baselines3.common.preprocessing import get_flattened_obs_dim
9
- from typing import Dict, Optional, Sequence, Type
10
-
11
- from shared.module.module import layer_init
12
-
13
-
14
- class CnnFeatureExtractor(nn.Module, ABC):
15
- @abstractmethod
16
- def __init__(
17
- self,
18
- in_channels: int,
19
- activation: Type[nn.Module] = nn.ReLU,
20
- init_layers_orthogonal: Optional[bool] = None,
21
- **kwargs,
22
- ) -> None:
23
- super().__init__()
24
-
25
-
26
- class NatureCnn(CnnFeatureExtractor):
27
- """
28
- CNN from DQN Nature paper: Mnih, Volodymyr, et al.
29
- "Human-level control through deep reinforcement learning."
30
- Nature 518.7540 (2015): 529-533.
31
- """
32
-
33
- def __init__(
34
- self,
35
- in_channels: int,
36
- activation: Type[nn.Module] = nn.ReLU,
37
- init_layers_orthogonal: Optional[bool] = None,
38
- **kwargs,
39
- ) -> None:
40
- if init_layers_orthogonal is None:
41
- init_layers_orthogonal = True
42
- super().__init__(in_channels, activation, init_layers_orthogonal)
43
- self.cnn = nn.Sequential(
44
- layer_init(
45
- nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
46
- init_layers_orthogonal,
47
- ),
48
- activation(),
49
- layer_init(
50
- nn.Conv2d(32, 64, kernel_size=4, stride=2),
51
- init_layers_orthogonal,
52
- ),
53
- activation(),
54
- layer_init(
55
- nn.Conv2d(64, 64, kernel_size=3, stride=1),
56
- init_layers_orthogonal,
57
- ),
58
- activation(),
59
- nn.Flatten(),
60
- )
61
-
62
- def forward(self, obs: torch.Tensor) -> torch.Tensor:
63
- return self.cnn(obs)
64
-
65
-
66
- class ResidualBlock(nn.Module):
67
- def __init__(
68
- self,
69
- channels: int,
70
- activation: Type[nn.Module] = nn.ReLU,
71
- init_layers_orthogonal: bool = False,
72
- ) -> None:
73
- super().__init__()
74
- self.residual = nn.Sequential(
75
- activation(),
76
- layer_init(
77
- nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
78
- ),
79
- activation(),
80
- layer_init(
81
- nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
82
- ),
83
- )
84
-
85
- def forward(self, x: torch.Tensor) -> torch.Tensor:
86
- return x + self.residual(x)
87
-
88
-
89
- class ConvSequence(nn.Module):
90
- def __init__(
91
- self,
92
- in_channels: int,
93
- out_channels: int,
94
- activation: Type[nn.Module] = nn.ReLU,
95
- init_layers_orthogonal: bool = False,
96
- ) -> None:
97
- super().__init__()
98
- self.seq = nn.Sequential(
99
- layer_init(
100
- nn.Conv2d(in_channels, out_channels, 3, padding=1),
101
- init_layers_orthogonal,
102
- ),
103
- nn.MaxPool2d(3, stride=2, padding=1),
104
- ResidualBlock(out_channels, activation, init_layers_orthogonal),
105
- ResidualBlock(out_channels, activation, init_layers_orthogonal),
106
- )
107
-
108
- def forward(self, x: torch.Tensor) -> torch.Tensor:
109
- return self.seq(x)
110
-
111
-
112
- class ImpalaCnn(CnnFeatureExtractor):
113
- """
114
- IMPALA-style CNN architecture
115
- """
116
-
117
- def __init__(
118
- self,
119
- in_channels: int,
120
- activation: Type[nn.Module] = nn.ReLU,
121
- init_layers_orthogonal: Optional[bool] = None,
122
- impala_channels: Sequence[int] = (16, 32, 32),
123
- **kwargs,
124
- ) -> None:
125
- if init_layers_orthogonal is None:
126
- init_layers_orthogonal = False
127
- super().__init__(in_channels, activation, init_layers_orthogonal)
128
- sequences = []
129
- for out_channels in impala_channels:
130
- sequences.append(
131
- ConvSequence(
132
- in_channels, out_channels, activation, init_layers_orthogonal
133
- )
134
- )
135
- in_channels = out_channels
136
- sequences.extend(
137
- [
138
- activation(),
139
- nn.Flatten(),
140
- ]
141
- )
142
- self.seq = nn.Sequential(*sequences)
143
-
144
- def forward(self, obs: torch.Tensor) -> torch.Tensor:
145
- return self.seq(obs)
146
-
147
-
148
- CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnFeatureExtractor]] = {
149
- "nature": NatureCnn,
150
- "impala": ImpalaCnn,
151
- }
152
-
153
-
154
- class FeatureExtractor(nn.Module):
155
- def __init__(
156
- self,
157
- obs_space: gym.Space,
158
- activation: Type[nn.Module],
159
- init_layers_orthogonal: bool = False,
160
- cnn_feature_dim: int = 512,
161
- cnn_style: str = "nature",
162
- cnn_layers_init_orthogonal: Optional[bool] = None,
163
- impala_channels: Sequence[int] = (16, 32, 32),
164
- ) -> None:
165
- super().__init__()
166
- if isinstance(obs_space, Box):
167
- # Conv2D: (channels, height, width)
168
- if len(obs_space.shape) == 3:
169
- cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
170
- obs_space.shape[0],
171
- activation,
172
- init_layers_orthogonal=cnn_layers_init_orthogonal,
173
- impala_channels=impala_channels,
174
- )
175
-
176
- def preprocess(obs: torch.Tensor) -> torch.Tensor:
177
- if len(obs.shape) == 3:
178
- obs = obs.unsqueeze(0)
179
- return obs.float() / 255.0
180
-
181
- with torch.no_grad():
182
- cnn_out = cnn(preprocess(torch.as_tensor(obs_space.sample())))
183
- self.preprocess = preprocess
184
- self.feature_extractor = nn.Sequential(
185
- cnn,
186
- layer_init(
187
- nn.Linear(cnn_out.shape[1], cnn_feature_dim),
188
- init_layers_orthogonal,
189
- ),
190
- activation(),
191
- )
192
- self.out_dim = cnn_feature_dim
193
- elif len(obs_space.shape) == 1:
194
-
195
- def preprocess(obs: torch.Tensor) -> torch.Tensor:
196
- if len(obs.shape) == 1:
197
- obs = obs.unsqueeze(0)
198
- return obs.float()
199
-
200
- self.preprocess = preprocess
201
- self.feature_extractor = nn.Flatten()
202
- self.out_dim = get_flattened_obs_dim(obs_space)
203
- else:
204
- raise ValueError(f"Unsupported observation space: {obs_space}")
205
- elif isinstance(obs_space, Discrete):
206
- self.preprocess = lambda x: F.one_hot(x, obs_space.n).float()
207
- self.feature_extractor = nn.Flatten()
208
- self.out_dim = obs_space.n
209
- else:
210
- raise NotImplementedError
211
-
212
- def forward(self, obs: torch.Tensor) -> torch.Tensor:
213
- if self.preprocess:
214
- obs = self.preprocess(obs)
215
- return self.feature_extractor(obs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
shared/module/module.py DELETED
@@ -1,40 +0,0 @@
1
- import numpy as np
2
- import torch.nn as nn
3
-
4
- from typing import Sequence, Type
5
-
6
-
7
- def mlp(
8
- layer_sizes: Sequence[int],
9
- activation: Type[nn.Module],
10
- output_activation: Type[nn.Module] = nn.Identity,
11
- init_layers_orthogonal: bool = False,
12
- final_layer_gain: float = np.sqrt(2),
13
- ) -> nn.Module:
14
- layers = []
15
- for i in range(len(layer_sizes) - 2):
16
- layers.append(
17
- layer_init(
18
- nn.Linear(layer_sizes[i], layer_sizes[i + 1]), init_layers_orthogonal
19
- )
20
- )
21
- layers.append(activation())
22
- layers.append(
23
- layer_init(
24
- nn.Linear(layer_sizes[-2], layer_sizes[-1]),
25
- init_layers_orthogonal,
26
- std=final_layer_gain,
27
- )
28
- )
29
- layers.append(output_activation())
30
- return nn.Sequential(*layers)
31
-
32
-
33
- def layer_init(
34
- layer: nn.Module, init_layers_orthogonal: bool, std: float = np.sqrt(2)
35
- ) -> nn.Module:
36
- if not init_layers_orthogonal:
37
- return layer
38
- nn.init.orthogonal_(layer.weight, std) # type: ignore
39
- nn.init.constant_(layer.bias, 0.0) # type: ignore
40
- return layer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
shared/policy/actor.py DELETED
@@ -1,305 +0,0 @@
1
- import gym
2
- import torch
3
- import torch.nn as nn
4
-
5
- from abc import ABC, abstractmethod
6
- from gym.spaces import Box, Discrete
7
- from torch.distributions import Categorical, Distribution, Normal
8
- from typing import NamedTuple, Optional, Sequence, Type, TypeVar, Union
9
-
10
- from shared.module.feature_extractor import FeatureExtractor
11
- from shared.module.module import mlp
12
-
13
-
14
- class PiForward(NamedTuple):
15
- pi: Distribution
16
- logp_a: Optional[torch.Tensor]
17
- entropy: Optional[torch.Tensor]
18
-
19
-
20
- class Actor(nn.Module, ABC):
21
- @abstractmethod
22
- def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
23
- ...
24
-
25
-
26
- class CategoricalActorHead(Actor):
27
- def __init__(
28
- self,
29
- act_dim: int,
30
- hidden_sizes: Sequence[int] = (32,),
31
- activation: Type[nn.Module] = nn.Tanh,
32
- init_layers_orthogonal: bool = True,
33
- ) -> None:
34
- super().__init__()
35
- layer_sizes = tuple(hidden_sizes) + (act_dim,)
36
- self._fc = mlp(
37
- layer_sizes,
38
- activation,
39
- init_layers_orthogonal=init_layers_orthogonal,
40
- final_layer_gain=0.01,
41
- )
42
-
43
- def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
44
- logits = self._fc(obs)
45
- pi = Categorical(logits=logits)
46
- logp_a = None
47
- entropy = None
48
- if a is not None:
49
- logp_a = pi.log_prob(a)
50
- entropy = pi.entropy()
51
- return PiForward(pi, logp_a, entropy)
52
-
53
-
54
- class GaussianDistribution(Normal):
55
- def log_prob(self, a: torch.Tensor) -> torch.Tensor:
56
- return super().log_prob(a).sum(axis=-1)
57
-
58
- def sample(self) -> torch.Tensor:
59
- return self.rsample()
60
-
61
-
62
- class GaussianActorHead(Actor):
63
- def __init__(
64
- self,
65
- act_dim: int,
66
- hidden_sizes: Sequence[int] = (32,),
67
- activation: Type[nn.Module] = nn.Tanh,
68
- init_layers_orthogonal: bool = True,
69
- log_std_init: float = -0.5,
70
- ) -> None:
71
- super().__init__()
72
- layer_sizes = tuple(hidden_sizes) + (act_dim,)
73
- self.mu_net = mlp(
74
- layer_sizes,
75
- activation,
76
- init_layers_orthogonal=init_layers_orthogonal,
77
- final_layer_gain=0.01,
78
- )
79
- self.log_std = nn.Parameter(
80
- torch.ones(act_dim, dtype=torch.float32) * log_std_init
81
- )
82
-
83
- def _distribution(self, obs: torch.Tensor) -> Distribution:
84
- mu = self.mu_net(obs)
85
- std = torch.exp(self.log_std)
86
- return GaussianDistribution(mu, std)
87
-
88
- def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
89
- pi = self._distribution(obs)
90
- logp_a = None
91
- entropy = None
92
- if a is not None:
93
- logp_a = pi.log_prob(a)
94
- entropy = pi.entropy()
95
- return PiForward(pi, logp_a, entropy)
96
-
97
-
98
- class TanhBijector:
99
- def __init__(self, epsilon: float = 1e-6) -> None:
100
- self.epsilon = epsilon
101
-
102
- @staticmethod
103
- def forward(x: torch.Tensor) -> torch.Tensor:
104
- return torch.tanh(x)
105
-
106
- @staticmethod
107
- def inverse(y: torch.Tensor) -> torch.Tensor:
108
- eps = torch.finfo(y.dtype).eps
109
- clamped_y = y.clamp(min=-1.0 + eps, max=1.0 - eps)
110
- return torch.atanh(clamped_y)
111
-
112
- def log_prob_correction(self, x: torch.Tensor) -> torch.Tensor:
113
- return torch.log(1.0 - torch.tanh(x) ** 2 + self.epsilon)
114
-
115
-
116
- class StateDependentNoiseDistribution(Normal):
117
- def __init__(
118
- self,
119
- loc,
120
- scale,
121
- latent_sde: torch.Tensor,
122
- exploration_mat: torch.Tensor,
123
- exploration_matrices: torch.Tensor,
124
- bijector: Optional[TanhBijector] = None,
125
- validate_args=None,
126
- ):
127
- super().__init__(loc, scale, validate_args)
128
- self.latent_sde = latent_sde
129
- self.exploration_mat = exploration_mat
130
- self.exploration_matrices = exploration_matrices
131
- self.bijector = bijector
132
-
133
- def log_prob(self, a: torch.Tensor) -> torch.Tensor:
134
- gaussian_a = self.bijector.inverse(a) if self.bijector else a
135
- log_prob = super().log_prob(gaussian_a).sum(axis=-1)
136
- if self.bijector:
137
- log_prob -= torch.sum(self.bijector.log_prob_correction(gaussian_a), dim=1)
138
- return log_prob
139
-
140
- def sample(self) -> torch.Tensor:
141
- noise = self._get_noise()
142
- actions = self.mean + noise
143
- return self.bijector.forward(actions) if self.bijector else actions
144
-
145
- def _get_noise(self) -> torch.Tensor:
146
- if len(self.latent_sde) == 1 or len(self.latent_sde) != len(
147
- self.exploration_matrices
148
- ):
149
- return torch.mm(self.latent_sde, self.exploration_mat)
150
- # (batch_size, n_features) -> (batch_size, 1, n_features)
151
- latent_sde = self.latent_sde.unsqueeze(dim=1)
152
- # (batch_size, 1, n_actions)
153
- noise = torch.bmm(latent_sde, self.exploration_matrices)
154
- return noise.squeeze(dim=1)
155
-
156
- @property
157
- def mode(self) -> torch.Tensor:
158
- mean = super().mode
159
- return self.bijector.forward(mean) if self.bijector else mean
160
-
161
-
162
- StateDependentNoiseActorHeadSelf = TypeVar(
163
- "StateDependentNoiseActorHeadSelf", bound="StateDependentNoiseActorHead"
164
- )
165
-
166
-
167
- class StateDependentNoiseActorHead(Actor):
168
- def __init__(
169
- self,
170
- act_dim: int,
171
- hidden_sizes: Sequence[int] = (32,),
172
- activation: Type[nn.Module] = nn.Tanh,
173
- init_layers_orthogonal: bool = True,
174
- log_std_init: float = -0.5,
175
- full_std: bool = True,
176
- squash_output: bool = False,
177
- learn_std: bool = False,
178
- ) -> None:
179
- super().__init__()
180
- self.act_dim = act_dim
181
- layer_sizes = tuple(hidden_sizes) + (self.act_dim,)
182
- if len(layer_sizes) == 2:
183
- self.latent_net = nn.Identity()
184
- elif len(layer_sizes) > 2:
185
- self.latent_net = mlp(
186
- layer_sizes[:-1],
187
- activation,
188
- output_activation=activation,
189
- init_layers_orthogonal=init_layers_orthogonal,
190
- )
191
- else:
192
- raise ValueError("hidden_sizes must be of at least length 1")
193
- self.mu_net = mlp(
194
- layer_sizes[-2:],
195
- activation,
196
- init_layers_orthogonal=init_layers_orthogonal,
197
- final_layer_gain=0.01,
198
- )
199
- self.full_std = full_std
200
- std_dim = (hidden_sizes[-1], act_dim if self.full_std else 1)
201
- self.log_std = nn.Parameter(
202
- torch.ones(std_dim, dtype=torch.float32) * log_std_init
203
- )
204
- self.bijector = TanhBijector() if squash_output else None
205
- self.learn_std = learn_std
206
- self.device = None
207
-
208
- self.exploration_mat = None
209
- self.exploration_matrices = None
210
- self.sample_weights()
211
-
212
- def to(
213
- self: StateDependentNoiseActorHeadSelf,
214
- device: Optional[torch.device] = None,
215
- dtype: Optional[Union[torch.dtype, str]] = None,
216
- non_blocking: bool = False,
217
- ) -> StateDependentNoiseActorHeadSelf:
218
- super().to(device, dtype, non_blocking)
219
- self.device = device
220
- return self
221
-
222
- def _distribution(self, obs: torch.Tensor) -> Distribution:
223
- latent = self.latent_net(obs)
224
- mu = self.mu_net(latent)
225
- latent_sde = latent if self.learn_std else latent.detach()
226
- variance = torch.mm(latent_sde**2, self._get_std() ** 2)
227
- assert self.exploration_mat is not None
228
- assert self.exploration_matrices is not None
229
- return StateDependentNoiseDistribution(
230
- mu,
231
- torch.sqrt(variance + 1e-6),
232
- latent_sde,
233
- self.exploration_mat,
234
- self.exploration_matrices,
235
- self.bijector,
236
- )
237
-
238
- def _get_std(self) -> torch.Tensor:
239
- std = torch.exp(self.log_std)
240
- if self.full_std:
241
- return std
242
- ones = torch.ones(self.log_std.shape[0], self.act_dim)
243
- if self.device:
244
- ones = ones.to(self.device)
245
- return ones * std
246
-
247
- def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
248
- pi = self._distribution(obs)
249
- logp_a = None
250
- entropy = None
251
- if a is not None:
252
- logp_a = pi.log_prob(a)
253
- entropy = -logp_a
254
- return PiForward(pi, logp_a, entropy)
255
-
256
- def sample_weights(self, batch_size: int = 1) -> None:
257
- std = self._get_std()
258
- weights_dist = Normal(torch.zeros_like(std), std)
259
- # Reparametrization trick to pass gradients
260
- self.exploration_mat = weights_dist.rsample()
261
- self.exploration_matrices = weights_dist.rsample(torch.Size((batch_size,)))
262
-
263
-
264
- def actor_head(
265
- action_space: gym.Space,
266
- hidden_sizes: Sequence[int],
267
- init_layers_orthogonal: bool,
268
- activation: Type[nn.Module],
269
- log_std_init: float = -0.5,
270
- use_sde: bool = False,
271
- full_std: bool = True,
272
- squash_output: bool = False,
273
- ) -> Actor:
274
- assert not use_sde or isinstance(
275
- action_space, Box
276
- ), "use_sde only valid if Box action_space"
277
- assert not squash_output or use_sde, "squash_output only valid if use_sde"
278
- if isinstance(action_space, Discrete):
279
- return CategoricalActorHead(
280
- action_space.n,
281
- hidden_sizes=hidden_sizes,
282
- activation=activation,
283
- init_layers_orthogonal=init_layers_orthogonal,
284
- )
285
- elif isinstance(action_space, Box):
286
- if use_sde:
287
- return StateDependentNoiseActorHead(
288
- action_space.shape[0],
289
- hidden_sizes=hidden_sizes,
290
- activation=activation,
291
- init_layers_orthogonal=init_layers_orthogonal,
292
- log_std_init=log_std_init,
293
- full_std=full_std,
294
- squash_output=squash_output,
295
- )
296
- else:
297
- return GaussianActorHead(
298
- action_space.shape[0],
299
- hidden_sizes=hidden_sizes,
300
- activation=activation,
301
- init_layers_orthogonal=init_layers_orthogonal,
302
- log_std_init=log_std_init,
303
- )
304
- else:
305
- raise ValueError(f"Unsupported action space: {action_space}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
shared/policy/critic.py DELETED
@@ -1,28 +0,0 @@
1
- import gym
2
- import torch
3
- import torch.nn as nn
4
-
5
- from typing import Sequence, Type
6
- from shared.module.feature_extractor import FeatureExtractor
7
- from shared.module.module import mlp
8
-
9
-
10
- class CriticHead(nn.Module):
11
- def __init__(
12
- self,
13
- hidden_sizes: Sequence[int] = (32,),
14
- activation: Type[nn.Module] = nn.Tanh,
15
- init_layers_orthogonal: bool = True,
16
- ) -> None:
17
- super().__init__()
18
- layer_sizes = tuple(hidden_sizes) + (1,)
19
- self._fc = mlp(
20
- layer_sizes,
21
- activation,
22
- init_layers_orthogonal=init_layers_orthogonal,
23
- final_layer_gain=1.0,
24
- )
25
-
26
- def forward(self, obs: torch.Tensor) -> torch.Tensor:
27
- v = self._fc(obs)
28
- return v.squeeze(-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
shared/policy/on_policy.py DELETED
@@ -1,201 +0,0 @@
1
- import gym
2
- import numpy as np
3
- import torch
4
-
5
- from abc import abstractmethod
6
- from gym.spaces import Box, Discrete, Space
7
- from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
8
- from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
9
-
10
- from shared.module.feature_extractor import FeatureExtractor
11
- from shared.policy.actor import PiForward, StateDependentNoiseActorHead, actor_head
12
- from shared.policy.critic import CriticHead
13
- from shared.policy.policy import ACTIVATION, Policy
14
-
15
-
16
- class Step(NamedTuple):
17
- a: np.ndarray
18
- v: np.ndarray
19
- logp_a: np.ndarray
20
- clamped_a: np.ndarray
21
-
22
-
23
- class ACForward(NamedTuple):
24
- logp_a: torch.Tensor
25
- entropy: torch.Tensor
26
- v: torch.Tensor
27
-
28
-
29
- FEAT_EXT_FILE_NAME = "feat_ext.pt"
30
- V_FEAT_EXT_FILE_NAME = "v_feat_ext.pt"
31
- PI_FILE_NAME = "pi.pt"
32
- V_FILE_NAME = "v.pt"
33
- ActorCriticSelf = TypeVar("ActorCriticSelf", bound="ActorCritic")
34
-
35
-
36
- def clamp_actions(
37
- actions: np.ndarray, action_space: gym.Space, squash_output: bool
38
- ) -> np.ndarray:
39
- if isinstance(action_space, Box):
40
- low, high = action_space.low, action_space.high # type: ignore
41
- if squash_output:
42
- # Squashed output is already between -1 and 1. Rescale if the actual
43
- # output needs to something other than -1 and 1
44
- return low + 0.5 * (actions + 1) * (high - low)
45
- else:
46
- return np.clip(actions, low, high)
47
- return actions
48
-
49
-
50
- def default_hidden_sizes(obs_space: Space) -> Sequence[int]:
51
- if isinstance(obs_space, Box):
52
- if len(obs_space.shape) == 3:
53
- # By default feature extractor to output has no hidden layers
54
- return []
55
- elif len(obs_space.shape) == 1:
56
- return [64, 64]
57
- else:
58
- raise ValueError(f"Unsupported observation space: {obs_space}")
59
- elif isinstance(obs_space, Discrete):
60
- return [64]
61
- else:
62
- raise ValueError(f"Unsupported observation space: {obs_space}")
63
-
64
-
65
- class OnPolicy(Policy):
66
- @abstractmethod
67
- def value(self, obs: VecEnvObs) -> np.ndarray:
68
- ...
69
-
70
- @abstractmethod
71
- def step(self, obs: VecEnvObs) -> Step:
72
- ...
73
-
74
-
75
- class ActorCritic(OnPolicy):
76
- def __init__(
77
- self,
78
- env: VecEnv,
79
- pi_hidden_sizes: Sequence[int],
80
- v_hidden_sizes: Sequence[int],
81
- init_layers_orthogonal: bool = True,
82
- activation_fn: str = "tanh",
83
- log_std_init: float = -0.5,
84
- use_sde: bool = False,
85
- full_std: bool = True,
86
- squash_output: bool = False,
87
- share_features_extractor: bool = True,
88
- cnn_feature_dim: int = 512,
89
- cnn_style: str = "nature",
90
- cnn_layers_init_orthogonal: Optional[bool] = None,
91
- **kwargs,
92
- ) -> None:
93
- super().__init__(env, **kwargs)
94
- activation = ACTIVATION[activation_fn]
95
- observation_space = env.observation_space
96
- self.action_space = env.action_space
97
- self.squash_output = squash_output
98
- self.share_features_extractor = share_features_extractor
99
- self._feature_extractor = FeatureExtractor(
100
- observation_space,
101
- activation,
102
- init_layers_orthogonal=init_layers_orthogonal,
103
- cnn_feature_dim=cnn_feature_dim,
104
- cnn_style=cnn_style,
105
- cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
106
- )
107
- self._pi = actor_head(
108
- self.action_space,
109
- (self._feature_extractor.out_dim,) + tuple(pi_hidden_sizes),
110
- init_layers_orthogonal,
111
- activation,
112
- log_std_init=log_std_init,
113
- use_sde=use_sde,
114
- full_std=full_std,
115
- squash_output=squash_output,
116
- )
117
-
118
- if not share_features_extractor:
119
- self._v_feature_extractor = FeatureExtractor(
120
- observation_space,
121
- activation,
122
- init_layers_orthogonal=init_layers_orthogonal,
123
- cnn_feature_dim=cnn_feature_dim,
124
- cnn_style=cnn_style,
125
- cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
126
- )
127
- v_hidden_sizes = (self._v_feature_extractor.out_dim,) + tuple(
128
- v_hidden_sizes
129
- )
130
- else:
131
- self._v_feature_extractor = None
132
- v_hidden_sizes = (self._feature_extractor.out_dim,) + tuple(v_hidden_sizes)
133
- self._v = CriticHead(
134
- hidden_sizes=v_hidden_sizes,
135
- activation=activation,
136
- init_layers_orthogonal=init_layers_orthogonal,
137
- )
138
-
139
- def _pi_forward(
140
- self, obs: torch.Tensor, action: Optional[torch.Tensor] = None
141
- ) -> Tuple[PiForward, torch.Tensor]:
142
- p_fe = self._feature_extractor(obs)
143
- pi_forward = self._pi(p_fe, action)
144
-
145
- return pi_forward, p_fe
146
-
147
- def _v_forward(self, obs: torch.Tensor, p_fc: torch.Tensor) -> torch.Tensor:
148
- v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc
149
- return self._v(v_fe)
150
-
151
- def forward(self, obs: torch.Tensor, action: torch.Tensor) -> ACForward:
152
- (_, logp_a, entropy), p_fc = self._pi_forward(obs, action)
153
- v = self._v_forward(obs, p_fc)
154
-
155
- assert logp_a is not None
156
- assert entropy is not None
157
- return ACForward(logp_a, entropy, v)
158
-
159
- def value(self, obs: VecEnvObs) -> np.ndarray:
160
- o = self._as_tensor(obs)
161
- with torch.no_grad():
162
- fe = (
163
- self._v_feature_extractor(o)
164
- if self._v_feature_extractor
165
- else self._feature_extractor(o)
166
- )
167
- v = self._v(fe)
168
- return v.cpu().numpy()
169
-
170
- def step(self, obs: VecEnvObs) -> Step:
171
- o = self._as_tensor(obs)
172
- with torch.no_grad():
173
- (pi, _, _), p_fc = self._pi_forward(o)
174
- a = pi.sample()
175
- logp_a = pi.log_prob(a)
176
-
177
- v = self._v_forward(o, p_fc)
178
-
179
- a_np = a.cpu().numpy()
180
- clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
181
- return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
182
-
183
- def act(self, obs: np.ndarray, deterministic: bool = True) -> np.ndarray:
184
- if not deterministic:
185
- return self.step(obs).clamped_a
186
- else:
187
- o = self._as_tensor(obs)
188
- with torch.no_grad():
189
- (pi, _, _), _ = self._pi_forward(o)
190
- a = pi.mode
191
- return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
192
-
193
- def load(self, path: str) -> None:
194
- super().load(path)
195
- self.reset_noise()
196
-
197
- def reset_noise(self, batch_size: Optional[int] = None) -> None:
198
- if isinstance(self._pi, StateDependentNoiseActorHead):
199
- self._pi.sample_weights(
200
- batch_size=batch_size if batch_size else self.env.num_envs
201
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
shared/policy/policy.py DELETED
@@ -1,74 +0,0 @@
1
- import numpy as np
2
- import os
3
- import torch
4
- import torch.nn as nn
5
-
6
- from abc import ABC, abstractmethod
7
- from stable_baselines3.common.vec_env import unwrap_vec_normalize
8
- from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
9
- from typing import Dict, Optional, Type, TypeVar, Union
10
-
11
- ACTIVATION: Dict[str, Type[nn.Module]] = {
12
- "tanh": nn.Tanh,
13
- "relu": nn.ReLU,
14
- }
15
-
16
- VEC_NORMALIZE_FILENAME = "vecnormalize.pkl"
17
- MODEL_FILENAME = "model.pth"
18
-
19
- PolicySelf = TypeVar("PolicySelf", bound="Policy")
20
-
21
-
22
- class Policy(nn.Module, ABC):
23
- @abstractmethod
24
- def __init__(self, env: VecEnv, **kwargs) -> None:
25
- super().__init__()
26
- self.env = env
27
- self.vec_normalize = unwrap_vec_normalize(env)
28
- self.device = None
29
-
30
- def to(
31
- self: PolicySelf,
32
- device: Optional[torch.device] = None,
33
- dtype: Optional[Union[torch.dtype, str]] = None,
34
- non_blocking: bool = False,
35
- ) -> PolicySelf:
36
- super().to(device, dtype, non_blocking)
37
- self.device = device
38
- return self
39
-
40
- @abstractmethod
41
- def act(self, obs: VecEnvObs, deterministic: bool = True) -> np.ndarray:
42
- ...
43
-
44
- def save(self, path: str) -> None:
45
- os.makedirs(path, exist_ok=True)
46
-
47
- if self.vec_normalize:
48
- self.vec_normalize.save(os.path.join(path, VEC_NORMALIZE_FILENAME))
49
- torch.save(
50
- self.state_dict(),
51
- os.path.join(path, MODEL_FILENAME),
52
- )
53
-
54
- def load(self, path: str) -> None:
55
- # VecNormalize load occurs in env.py
56
- self.load_state_dict(
57
- torch.load(os.path.join(path, MODEL_FILENAME), map_location=self.device)
58
- )
59
-
60
- def reset_noise(self) -> None:
61
- pass
62
-
63
- def _as_tensor(self, obs: VecEnvObs) -> torch.Tensor:
64
- assert isinstance(obs, np.ndarray)
65
- o = torch.as_tensor(obs)
66
- if self.device is not None:
67
- o = o.to(self.device)
68
- return o
69
-
70
- def num_trainable_parameters(self) -> int:
71
- return sum(p.numel() for p in self.parameters() if p.requires_grad)
72
-
73
- def num_parameters(self) -> int:
74
- return sum(p.numel() for p in self.parameters())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
shared/schedule.py DELETED
@@ -1,19 +0,0 @@
1
- from typing import Callable
2
-
3
- Schedule = Callable[[float], float]
4
-
5
-
6
- def linear_schedule(
7
- start_val: float, end_val: float, end_fraction: float = 1.0
8
- ) -> Schedule:
9
- def func(progress_fraction: float) -> float:
10
- if progress_fraction >= end_fraction:
11
- return end_val
12
- else:
13
- return start_val + (end_val - start_val) * progress_fraction / end_fraction
14
-
15
- return func
16
-
17
-
18
- def constant_schedule(val: float) -> Schedule:
19
- return lambda f: val
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
shared/stats.py DELETED
@@ -1,153 +0,0 @@
1
- import numpy as np
2
-
3
- from dataclasses import dataclass
4
- from torch.utils.tensorboard.writer import SummaryWriter
5
- from typing import Dict, List, Optional, Sequence, TypeVar
6
-
7
-
8
- @dataclass
9
- class Episode:
10
- score: float = 0
11
- length: int = 0
12
-
13
-
14
- StatisticSelf = TypeVar("StatisticSelf", bound="Statistic")
15
-
16
-
17
- @dataclass
18
- class Statistic:
19
- values: np.ndarray
20
- round_digits: int = 2
21
-
22
- @property
23
- def mean(self) -> float:
24
- return np.mean(self.values).item()
25
-
26
- @property
27
- def std(self) -> float:
28
- return np.std(self.values).item()
29
-
30
- @property
31
- def min(self) -> float:
32
- return np.min(self.values).item()
33
-
34
- @property
35
- def max(self) -> float:
36
- return np.max(self.values).item()
37
-
38
- def sum(self) -> float:
39
- return np.sum(self.values).item()
40
-
41
- def __len__(self) -> int:
42
- return len(self.values)
43
-
44
- def _diff(self: StatisticSelf, o: StatisticSelf) -> float:
45
- return (self.mean - self.std) - (o.mean - o.std)
46
-
47
- def __gt__(self: StatisticSelf, o: StatisticSelf) -> bool:
48
- return self._diff(o) > 0
49
-
50
- def __ge__(self: StatisticSelf, o: StatisticSelf) -> bool:
51
- return self._diff(o) >= 0
52
-
53
- def __repr__(self) -> str:
54
- mean = round(self.mean, self.round_digits)
55
- std = round(self.std, self.round_digits)
56
- if self.round_digits == 0:
57
- mean = int(mean)
58
- std = int(std)
59
- return f"{mean} +/- {std}"
60
-
61
- def to_dict(self) -> Dict[str, float]:
62
- return {
63
- "mean": self.mean,
64
- "std": self.std,
65
- "min": self.min,
66
- "max": self.max,
67
- }
68
-
69
-
70
- EpisodesStatsSelf = TypeVar("EpisodesStatsSelf", bound="EpisodesStats")
71
-
72
-
73
- class EpisodesStats:
74
- episodes: Sequence[Episode]
75
- simple: bool
76
- score: Statistic
77
- length: Statistic
78
-
79
- def __init__(self, episodes: Sequence[Episode], simple: bool = False) -> None:
80
- self.episodes = episodes
81
- self.simple = simple
82
- self.score = Statistic(np.array([e.score for e in episodes]))
83
- self.length = Statistic(np.array([e.length for e in episodes]), round_digits=0)
84
-
85
- def __gt__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
86
- return self.score > o.score
87
-
88
- def __ge__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
89
- return self.score >= o.score
90
-
91
- def __repr__(self) -> str:
92
- return (
93
- f"Score: {self.score} ({round(self.score.mean - self.score.std, 2)}) | "
94
- f"Length: {self.length}"
95
- )
96
-
97
- def __len__(self) -> int:
98
- return len(self.episodes)
99
-
100
- def _asdict(self) -> dict:
101
- return {
102
- "n_episodes": len(self.episodes),
103
- "score": self.score.to_dict(),
104
- "length": self.length.to_dict(),
105
- }
106
-
107
- def write_to_tensorboard(
108
- self, tb_writer: SummaryWriter, main_tag: str, global_step: Optional[int] = None
109
- ) -> None:
110
- stats = {"mean": self.score.mean}
111
- if not self.simple:
112
- stats.update(
113
- {
114
- "min": self.score.min,
115
- "max": self.score.max,
116
- "result": self.score.mean - self.score.std,
117
- "n_episodes": len(self.episodes),
118
- "length": self.length.mean,
119
- }
120
- )
121
- tb_writer.add_scalars(
122
- main_tag,
123
- stats,
124
- global_step=global_step,
125
- )
126
-
127
-
128
- class EpisodeAccumulator:
129
- def __init__(self, num_envs: int):
130
- self._episodes = []
131
- self.current_episodes = [Episode() for _ in range(num_envs)]
132
-
133
- @property
134
- def episodes(self) -> List[Episode]:
135
- return self._episodes
136
-
137
- def step(self, reward: np.ndarray, done: np.ndarray) -> None:
138
- for idx, current in enumerate(self.current_episodes):
139
- current.score += reward[idx]
140
- current.length += 1
141
- if done[idx]:
142
- self._episodes.append(current)
143
- self.current_episodes[idx] = Episode()
144
- self.on_done(idx, current)
145
-
146
- def __len__(self) -> int:
147
- return len(self.episodes)
148
-
149
- def on_done(self, ep_idx: int, episode: Episode) -> None:
150
- pass
151
-
152
- def stats(self) -> EpisodesStats:
153
- return EpisodesStats(self.episodes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
shared/trajectory.py DELETED
@@ -1,80 +0,0 @@
1
- import numpy as np
2
-
3
- from dataclasses import dataclass, field
4
- from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
5
- from typing import Generic, List, Optional, Type, TypeVar
6
-
7
-
8
- @dataclass
9
- class Trajectory:
10
- obs: List[np.ndarray] = field(default_factory=list)
11
- act: List[np.ndarray] = field(default_factory=list)
12
- next_obs: Optional[np.ndarray] = None
13
- rew: List[float] = field(default_factory=list)
14
- terminated: bool = False
15
- v: List[float] = field(default_factory=list)
16
-
17
- def add(
18
- self,
19
- obs: np.ndarray,
20
- act: np.ndarray,
21
- next_obs: np.ndarray,
22
- rew: float,
23
- terminated: bool,
24
- v: float,
25
- ):
26
- self.obs.append(obs)
27
- self.act.append(act)
28
- self.next_obs = next_obs if not terminated else None
29
- self.rew.append(rew)
30
- self.terminated = terminated
31
- self.v.append(v)
32
-
33
- def __len__(self) -> int:
34
- return len(self.obs)
35
-
36
-
37
- T = TypeVar("T", bound=Trajectory)
38
-
39
-
40
- class TrajectoryAccumulator(Generic[T]):
41
- def __init__(self, num_envs: int, trajectory_class: Type[T] = Trajectory) -> None:
42
- self.num_envs = num_envs
43
- self.trajectory_class = trajectory_class
44
-
45
- self._trajectories = []
46
- self._current_trajectories = [trajectory_class() for _ in range(num_envs)]
47
-
48
- def step(
49
- self,
50
- obs: VecEnvObs,
51
- action: np.ndarray,
52
- next_obs: VecEnvObs,
53
- reward: np.ndarray,
54
- done: np.ndarray,
55
- val: np.ndarray,
56
- *args,
57
- ) -> None:
58
- assert isinstance(obs, np.ndarray)
59
- assert isinstance(next_obs, np.ndarray)
60
- for i, args in enumerate(zip(obs, action, next_obs, reward, done, val, *args)):
61
- trajectory = self._current_trajectories[i]
62
- # TODO: Eventually take advantage of terminated/truncated differentiation in
63
- # later versions of gym.
64
- trajectory.add(*args)
65
- if done[i]:
66
- self._trajectories.append(trajectory)
67
- self._current_trajectories[i] = self.trajectory_class()
68
- self.on_done(i, trajectory)
69
-
70
- @property
71
- def all_trajectories(self) -> List[T]:
72
- return self._trajectories + list(
73
- filter(lambda t: len(t), self._current_trajectories)
74
- )
75
-
76
- def n_timesteps(self) -> int:
77
- return sum(len(t) for t in self.all_trajectories)
78
-
79
- def on_done(self, env_idx: int, trajectory: T) -> None:
80
- pass