00BER committed on
Commit
e085e3b
1 Parent(s): cc9a97a

Upload 36 files

Makefile ADDED
@@ -0,0 +1,30 @@
.PHONY: create-atari-env
create-atari-env: ## Creates the Atari/retro conda environment
	conda env create -f environment.atari.yml --force

.PHONY: create-procgen-env
create-procgen-env: ## Creates the procgen conda environment
	conda env create -f environment.procgen.yml --force

.PHONY: setup-env
setup-env: ## Installs dependencies into the active conda environment
	conda install pytorch torchvision numpy -c pytorch -y
	pip install gym-retro
	pip install "gym[atari]==0.21.0"
	pip install importlib-metadata==4.13.0

.PHONY: run-air-dqn
run-air-dqn: ## Runs the Airstriker Genesis DQN agent
	python ./src/airstriker-genesis/run-airstriker-dqn.py

.PHONY: run-air-ddqn
run-air-ddqn: ## Runs the Airstriker Genesis Double DQN agent
	python ./src/airstriker-genesis/run-airstriker-ddqn.py

.PHONY: run-starpilot-dqn
run-starpilot-dqn: ## Runs the Starpilot DQN agent
	python ./src/procgen/run-starpilot-dqn.py

.PHONY: run-starpilot-ddqn
run-starpilot-ddqn: ## Runs the Starpilot Double DQN agent
	python ./src/procgen/run-starpilot-ddqn.py
README.md CHANGED
@@ -0,0 +1,358 @@
# **Abstract**

In 2013, DeepMind published a paper called "Playing Atari with Deep
Reinforcement Learning", introducing an algorithm called the Deep
Q-Network (DQN) that revolutionized the field of reinforcement
learning. For the first time, Deep Learning and Q-learning were
brought together at this scale, and the resulting agents showed
impressive results on Atari games, performing at or above human-level
expertise on several of the games they were trained on.
A Deep Q-Network uses a deep neural network to estimate the q-value
of each action, allowing the policy to select the action with the
maximum q-value. Using a neural network to obtain q-values is
immensely more scalable than q-table look-ups, and it widened the
applicability of Q-learning to far more complex reinforcement
learning environments.
While revolutionary, the original version of DQN had a few problems,
especially its slow and sample-inefficient learning process. Over the
past 9 years, several improved versions of DQN have become popular.
This project is an attempt to study the effectiveness of a few of
these DQN flavors, understand what problems they solve, and compare
their performance in the same reinforcement learning environment.

# Deep Q-Networks and their flavors

- **Vanilla DQN**

  The vanilla (original) DQN uses two neural networks: the **online**
  network and the **target** network. The online network is the main
  network that the agent uses to select the best action for a given
  state. The target network is a periodically updated copy of the
  online network, and it is used to get the "target" q-values for
  each action in a particular state. That is, since we don't have
  actual ground truths for future q-values during the learning phase,
  the q-values from the target network are used as the labels against
  which the online network is optimized.

  The target network calculates the target q-values using the
  following Bellman equation: \[\begin{aligned}
  Q(s_t, a_t) =
  r_{t+1} + \gamma \max_{a_{t+1} \in A} Q(s_{t+1}, a_{t+1})
  \end{aligned}\] where,
  \(Q(s_t, a_t)\) = the target q-value (ground truth) for a past
  experience in the replay memory

  \(r_{t+1}\) = the reward that was obtained for taking the chosen
  action in that particular experience

  \(\gamma\) = the discount factor for future rewards

  \(Q(s_{t+1}, a_{t+1})\) = the q-value of the best action (based on
  the policy) in the next state of that particular experience

- **Double DQN**

  One of the problems with the vanilla DQN is the way it calculates
  its target values (ground truth). We can see from the Bellman
  equation above that the target network uses the **max** q-value
  directly in the equation. This almost always overestimates the
  q-value, because the **max** operator introduces maximization bias
  into our estimates: max returns the largest value even when that
  value is an outlier, which skews our estimates.
  The Double DQN addresses this problem by changing the original
  algorithm as follows:

  1. Instead of using the **max** function, first use the online
     network to estimate the best action for the next state

  2. Calculate the target q-values for the next state for each
     possible action using the target network

  3. From the q-values calculated by the target network, use the
     q-value of the action chosen in step 1.

  This can be represented by the following equation: \[\begin{aligned}
  Q(s_t, a_t) =
  r_{t+1} + \gamma Q_{target}(s_{t+1}, a'_{t+1})
  \end{aligned}\] where, \[\begin{aligned}
  a'_{t+1} = \arg\max_{a} Q_{online}(s_{t+1}, a)
  \end{aligned}\]

- **Dueling DQN**

  The Dueling DQN was an attempt to improve upon the original DQN by
  changing the architecture of the neural network used in deep
  Q-learning. The Dueling DQN splits the last layer of the network
  into two parts, a **value stream** and an **advantage stream**,
  whose outputs are combined in an aggregating layer that produces
  the final q-value. One of the main problems with the original DQN
  is that the q-values of the different actions are often very close,
  so selecting the action with the max q-value may not always be the
  best choice. The Dueling DQN attempts to mitigate this by using the
  advantage, a measure of how much better an action is compared to
  the other actions available in a given state. The value stream, on
  the other hand, learns how good or bad it is to be in a specific
  state, e.g. moving straight towards an obstacle in a racing game,
  or being in the path of a projectile in Space Invaders. Instead of
  learning to predict a single q-value, separating the value and
  advantage streams helps the network generalize better. (A short
  PyTorch sketch contrasting how the three variants compute their
  q-value targets follows this list.)

  ![image](./docs/dueling.png)
  Fig: The Dueling DQN architecture (image taken from the original
  paper by Wang et al.)


  The q-value in a Dueling DQN architecture is given by
  \[\begin{aligned}
  Q(s_t, a) = V(s_t) + A(s_t, a)
  \end{aligned}\] where,
  \(V(s_t)\) = the value of the current state (how advantageous it is
  to be in that state)

  \(A(s_t, a)\) = the advantage of taking action \(a\) in that state

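Below is a minimal, self-contained PyTorch sketch of how the three
variants differ. It is illustrative only and not the project's exact
code: the networks, batch tensors and the `DuelingHead` module are
hypothetical placeholders, and the dueling aggregation subtracts the
mean advantage as in Wang et al.

```python
import torch
import torch.nn as nn

# Assumed: online_net(s) and target_net(s) return q-values of shape (B, n_actions);
# rewards and dones are 1-D tensors of length B.

@torch.no_grad()
def vanilla_dqn_target(target_net, rewards, next_states, dones, gamma=0.99):
    # Vanilla DQN: bootstrap from the max q-value of the target network.
    next_q = target_net(next_states).max(dim=1).values
    return rewards + gamma * (1.0 - dones.float()) * next_q

@torch.no_grad()
def double_dqn_target(online_net, target_net, rewards, next_states, dones, gamma=0.99):
    # Double DQN: the online network chooses the action,
    # the target network evaluates it.
    best_actions = online_net(next_states).argmax(dim=1, keepdim=True)
    next_q = target_net(next_states).gather(1, best_actions).squeeze(1)
    return rewards + gamma * (1.0 - dones.float()) * next_q

class DuelingHead(nn.Module):
    """Dueling DQN head: separate value and advantage streams, then aggregate."""
    def __init__(self, in_features, n_actions):
        super().__init__()
        self.value = nn.Linear(in_features, 1)
        self.advantage = nn.Linear(in_features, n_actions)

    def forward(self, features):
        v = self.value(features)          # (B, 1)
        a = self.advantage(features)      # (B, n_actions)
        # Subtracting the mean advantage keeps V and A identifiable.
        return v + a - a.mean(dim=1, keepdim=True)
```
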
# About the project

My original goal for the project was to train an agent using DQN to
play **Airstriker Genesis**, a space-shooting game, and then evaluate
the same agent's performance on another, similar game called
**Starpilot**. Unfortunately, I was unable to train a good enough
agent on the first game, which made it meaningless to evaluate its
performance on yet another game.

Because I still want to do the original project some time in the
future, I thought it would be better to prepare by first learning
in depth how Deep Q-Networks work, what their shortcomings are and
how they can be improved. For this reason, and because of time
constraints, I changed my project for this class to a comparison of
various DQN versions.

# Dataset

I used the excellent [Gym](https://github.com/openai/gym) library to
run my environments. A total of 9 agents were trained: 1 on
Airstriker Genesis, 4 on Starpilot and 4 on Lunar Lander.

| **Game** | **Observation Space** | **Action Space** |
| :--- | :--- | :--- |
| Airstriker Genesis | RGB values of each pixel of the game screen (255, 255, 3) | Discrete(12), one entry per button on the old Genesis controller. Since only three of those buttons are used in the game, the action space was reduced to 3 during training (Left, Right, Fire). |
| Starpilot | RGB values of each pixel of the game screen (64, 64, 3) | Discrete(15), one entry per button combo (Left, Right, Up, Down, Up + Right, Up + Left, Down + Right, Down + Left, W, A, S, D, Q, E, Do nothing) |
| Lunar Lander | 8-dimensional vector: (X-coordinate, Y-coordinate, Linear velocity in X, Linear velocity in Y, Angle, Angular velocity, Boolean (Leg 1 in contact with ground), Boolean (Leg 2 in contact with ground)) | Discrete(4) (Do nothing, Fire left engine, Fire main engine, Fire right engine) |


**Environment/Libraries**:
Miniconda, Python 3.9, Gym, PyTorch, NumPy, TensorBoard on my
personal MacBook Pro (M1)

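For reference, the three environments can be created roughly as
follows. This is a minimal sketch assuming the gym 0.21-era API and
the gym-retro and procgen packages pinned in the environment files;
the exact wrapper setup used for training lives in the run scripts
under `src/`.

```python
import gym
import retro  # gym-retro, used for Airstriker Genesis

# Airstriker Genesis ships with gym-retro as its free test ROM.
air_env = retro.make(game="Airstriker-Genesis")

# Procgen games are exposed through gym's registry.
star_env = gym.make("procgen:procgen-starpilot-v0")

# Box2D game with a small 8-dimensional observation vector.
lunar_env = gym.make("LunarLander-v2")

for name, env in [("Airstriker", air_env), ("Starpilot", star_env),
                  ("LunarLander", lunar_env)]:
    print(name, env.observation_space, env.action_space)
```
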
# ML Methodology

Each agent was trained using DQN or one of its flavors. All agents
for a particular game were trained with the same hyperparameters;
only the underlying algorithm differed. The following metrics were
used to evaluate each agent (a short logging sketch follows this
list):

- **Epsilon value over each episode** Shows what the exploration
  rate was at the end of each episode.

- **Average Q-value for the last 100 episodes** A measure of the
  average q-value (for the chosen action) over the last 100
  episodes.

- **Average length for the last 100 episodes** A measure of the
  average number of steps taken in each episode.

- **Average loss for the last 100 episodes** A measure of the loss
  during learning over the last 100 episodes (a Huber loss was used).

- **Average reward for the last 100 episodes** A measure of the
  average reward the agent accumulated over the last 100 episodes.

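Condensed, the TensorBoard logging behind these metrics looks roughly
like the sketch below. The scalar tags mirror the ones written by the
`MetricLogger` class in `src/`; the log directory and the surrounding
bookkeeping are simplified placeholders.

```python
import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/starpilot-dqn")  # placeholder path
ep_rewards, ep_lengths, ep_losses, ep_qs = [], [], [], []  # appended once per episode

def record(episode, epsilon):
    # 100-episode moving averages, matching the metrics listed above.
    writer.add_scalar("Mean reward last 100 episodes", np.mean(ep_rewards[-100:]), episode)
    writer.add_scalar("Mean length last 100 episodes", np.mean(ep_lengths[-100:]), episode)
    writer.add_scalar("Mean loss last 100 episodes", np.mean(ep_losses[-100:]), episode)
    writer.add_scalar("Mean Q Value last 100 episodes", np.mean(ep_qs[-100:]), episode)
    writer.add_scalar("Epsilon value", epsilon, episode)
    writer.flush()
```
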
## Preprocessing

For the Airstriker and Starpilot games:

1. Changed each frame to grayscale
   Since color shouldn't matter to the agent, I decided to convert
   the RGB image to grayscale.

2. Changed the observation shape from (height, width, channels) to
   (channels, height, width) to make it compatible with PyTorch
   PyTorch expects channels-first tensors, which is different from
   the direct output of the gym environment. For this reason, I had
   to reshape each observation to match PyTorch's scheme (this took
   me a very long time to figure out, but I had an "Aha!" moment
   when I remembered you saying something similar in class).

3. Frame stacking
   Instead of processing 1 frame at a time, process 4 frames at a
   time, because a single frame does not carry enough information
   (e.g. no sense of velocity) for the agent to decide what action
   to take. (A sketch of these wrappers is shown after this
   section.)

For Lunar Lander, since the reward changes are very drastic (sudden
+100, -100, +200 rewards), I experimented with reward clipping
(clipping the rewards to the \[-1, 1\] range), but this didn't seem
to make much difference in my agent's performance.

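A minimal sketch of this preprocessing using gym 0.21-style wrappers.
It is illustrative only: the class names are mine, the grayscale
conversion is a simple channel average, and the project's actual
wrappers live in the scripts under `src/`.

```python
import gym
import numpy as np
from gym.wrappers import FrameStack

class GrayScaleChannelsFirst(gym.ObservationWrapper):
    """RGB (H, W, 3) uint8 frame -> grayscale, channels-first (1, H, W) float32."""
    def __init__(self, env):
        super().__init__(env)
        h, w, _ = env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(1, h, w), dtype=np.float32)

    def observation(self, obs):
        gray = obs.astype(np.float32).mean(axis=2) / 255.0  # rough luminance
        return gray[None, :, :]                             # channel axis first

class ClipReward(gym.RewardWrapper):
    """Clip rewards to [-1, 1] (the reward clipping tried for Lunar Lander)."""
    def reward(self, reward):
        return float(np.clip(reward, -1.0, 1.0))

env = gym.make("procgen:procgen-starpilot-v0")
env = GrayScaleChannelsFirst(env)
env = FrameStack(env, num_stack=4)  # observations become (4, 1, H, W) LazyFrames
```
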
# Results

- **Airstriker Genesis**
  The loss went down until about 5200 episodes, but after that it
  stopped going down any further. Consequently, the average reward
  the agent accumulated over the last 100 episodes pretty much
  plateaued after about 5000 episodes. On analysis, I noticed that my
  exploration rate at the end of the 7000th episode was still about
  0.65, which means that the agent was taking random actions more
  than half of the time. In hindsight, I feel I should have trained
  for much longer, at least until the epsilon value (exploration
  rate) had completely decayed to 5% (see the decay calculation after
  this section).
  ![image](./docs/air1.png) ![image](./docs/air2.png) ![image](./docs/air3.png)


- **Starpilot**

  I trained DQN, Double DQN, Dueling DQN and Dueling Double DQN
  agents on this game to compare the different algorithms.
  From the graph of mean q-values, we can tell that the vanilla DQN
  versions indeed give higher q-values, while their Double DQN
  counterparts give lower values, which makes me think that my
  implementation of the Double DQN algorithm was OK. I had expected
  the Double and Dueling versions to start accumulating higher
  rewards much earlier, but since the average reward was almost the
  same for all the agents, I could not notice any stark differences
  between the performance of the agents.

  ![image](./docs/star1.png)

  ![image](./docs/star2.png)

  | | |
  | :------------------ | :------------------ |
  | ![image](./docs/star3.png) | ![image](./docs/star4.png) |


- **Lunar Lander**

  Since I did not gain much insight from the agents in the Starpilot
  game, I suspected that I was not training long enough. So I tried
  training the same agents on Lunar Lander, which is a comparatively
  simpler game with a much smaller observation space, and one that a
  DQN algorithm should be able to converge on fairly quickly (based
  on comments by other people in the RL community).
  ![image](./docs/lunar1.png)

  ![image](./docs/lunar2.png)

  | | |
  | :------------------- | :------------------- |
  | ![image](./docs/lunar3.png) | ![image](./docs/lunar4.png) |



  The results here were interesting. Although I did not find any vast
  difference between the different variations of the DQN algorithm, I
  found that the performance of my agents suddenly got worse at
  around 300 episodes. While researching why this may have happened,
  I learned that DQN agents can suffer from **catastrophic
  forgetting**, i.e. after training extensively, the network suddenly
  forgets what it has learned in the past and starts performing
  worse. Initially I thought this might have been the case here, but
  since I hadn't trained for very long, and because all the models
  started performing worse at almost exactly the same episode number,
  I think this is more likely a problem with my code or with some
  hyperparameter that I used.

  Upon checking what the agent was doing in the actual game, I found
  that it was playing it very safe, just constantly hovering in the
  air and never attempting to land the spaceship (the goal of the
  agent is to land within the yellow flags). I thought that
  penalizing the reward for taking too many steps in an episode might
  help, but that didn't work either (a sketch of the kind of step
  penalty I tried is shown after this section).

  ![image](./docs/check.png)

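Two small sketches related to the observations above. The decay
calculation uses the per-step epsilon decay value from the agents in
this repository (0.9999999, multiplicative per step); the step-penalty
wrapper is only an illustration of the kind of shaping I tried, and
the penalty value is a made-up placeholder.

```python
import math
import gym

def steps_to_reach(eps_target, eps_start=1.0, decay=0.9999999):
    """Number of steps for a multiplicative per-step decay to reach eps_target."""
    return math.log(eps_target / eps_start) / math.log(decay)

print(f"steps until epsilon = 0.65: {steps_to_reach(0.65):,.0f}")  # ~4.3 million
print(f"steps until epsilon = 0.05: {steps_to_reach(0.05):,.0f}")  # ~30 million

class StepPenalty(gym.RewardWrapper):
    """Subtract a small constant every step so that endless hovering is discouraged."""
    def __init__(self, env, penalty=0.01):  # penalty value is illustrative
        super().__init__(env)
        self.penalty = penalty

    def reward(self, reward):
        return reward - self.penalty
```
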
# Problems Faced

Here are a few of the problems that I faced while training my agents:

- Understanding the various hyperparameters of the algorithm. DQN has
  a lot of moving parts, so tuning each parameter was a difficult
  task. There were about 8 different hyperparameters (some of them
  correlated) that impacted the agent's training performance. I
  struggled to understand how each parameter affected the agent and
  to figure out good values for them, and I ended up tuning them by
  trial and error (the main knobs are listed in the sketch after this
  list).

- I got stuck for a long time figuring out why my convolutional layer
  was not working. I didn't realize that PyTorch puts the channels in
  the first dimension, and because of that I was passing huge numbers
  like 255 (the height of the image) into the input-channel dimension
  of a Conv2d layer.

- I struggled with knowing how long is long enough to conclude that a
  model is not working. I trained a model on Airstriker Genesis for
  14 hours only to realize later that I had set a parameter
  incorrectly and had to retrain from scratch.

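For reference, these are the main knobs exposed by the agents in this
repository, with the default values from `DQNAgent` in
`src/airstriker-genesis/agent.py`. The dictionary below is just an
illustrative way of grouping them, not code from the project.

```python
# Defaults taken from DQNAgent.__init__ in src/airstriker-genesis/agent.py.
hyperparameters = {
    "learning_rate": 0.00025,              # AdamW step size
    "gamma": 0.9,                          # discount factor for future rewards
    "batch_size": 32,                      # experiences sampled per update
    "max_memory_size": 100_000,            # replay buffer capacity
    "exploration_rate": 1.0,               # initial epsilon
    "exploration_rate_decay": 0.9999999,   # multiplicative per-step epsilon decay
    "exploration_rate_min": 0.1,           # floor for epsilon
    "learning_starts": 1000,               # steps collected before training begins
    "training_frequency": 1,               # env steps between gradient updates
    "target_network_sync_frequency": 500,  # steps between target-network syncs
}
```
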
# What Next?

Although I didn't get a final working agent for any of the games I
tried, I feel like I have learned a lot about reinforcement learning,
especially about deep Q-learning. I plan to improve upon this further
and hopefully get an agent to go far into at least one of the games.
Next time, I will start by debugging my current code to see if I have
any implementation mistakes. Then I will train the agents a lot
longer than I did this time and see if that works. While learning
about the different flavors of DQN, I also learned a little about
NoisyNet DQN, Rainbow DQN and Prioritized Experience Replay. I
couldn't implement these for this project, but I would like to try
them out some time soon.

# Lessons Learned

- Reinforcement learning is a very challenging problem. It takes a
  substantial amount of time to train, it is hard to debug, and it is
  very difficult to tune the hyperparameters just right. It also
  differs from supervised learning in that there are no actual
  labels, which makes optimization very difficult.

- I tried training agents on the retro Airstriker Genesis and the
  procgen Starpilot games using just the CPU, but this took a very
  long time. This is understandable because the inputs are images, so
  using a GPU would obviously have been better. Next time, I will
  definitely try using a GPU to make training faster.

- When faced with the problem of my agent not learning, I went into
  research mode and got to learn a lot about DQN and its improved
  versions. I am not a master of these algorithms yet (I have yet to
  get an agent to perform well in a game), but I feel like I
  understand how each version works.

- Rather than just following someone's tutorial, also reading the
  actual paper for a particular algorithm helped me understand the
  algorithm better and implement it.

- Doing this project reinforced for me that I love the concept of
  reinforcement learning. It has made me even more interested in
  exploring the field further and learning more.

# References / Resources

- [Reinforcement Learning (DQN) Tutorial, Adam
  Paszke](https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html)

- [Train a Mario-playing RL agent, Yuansong Feng, Suraj Subramanian,
  Howard Wang, Steven
  Guo](https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html)

- [About Double DQN and Dueling
  DQN](https://horomary.hatenablog.com/entry/2021/02/06/013412)

- [Dueling Network Architectures for Deep Reinforcement Learning
  (Wang et al., 2015)](https://arxiv.org/abs/1511.06581)


*(The final source code for the project can be found
[here](https://github.com/00ber/ml-reinforcement-learning).)*
environment.atari.yml ADDED
@@ -0,0 +1,153 @@
1
+ name: mlrl
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - absl-py=1.3.0=py37hecd8cb5_0
7
+ - aiohttp=3.8.3=py37h6c40b1e_0
8
+ - aiosignal=1.2.0=pyhd3eb1b0_0
9
+ - appnope=0.1.2=py37hecd8cb5_1001
10
+ - async-timeout=4.0.2=py37hecd8cb5_0
11
+ - asynctest=0.13.0=py_0
12
+ - attrs=22.1.0=py37hecd8cb5_0
13
+ - backcall=0.2.0=pyhd3eb1b0_0
14
+ - blas=1.0=mkl
15
+ - blinker=1.4=py37hecd8cb5_0
16
+ - brotli=1.0.9=hca72f7f_7
17
+ - brotli-bin=1.0.9=hca72f7f_7
18
+ - brotlipy=0.7.0=py37h9ed2024_1003
19
+ - bzip2=1.0.8=h1de35cc_0
20
+ - c-ares=1.18.1=hca72f7f_0
21
+ - ca-certificates=2022.10.11=hecd8cb5_0
22
+ - cachetools=4.2.2=pyhd3eb1b0_0
23
+ - cairo=1.14.12=hc4e6be7_4
24
+ - certifi=2022.9.24=py37hecd8cb5_0
25
+ - cffi=1.15.0=py37hca72f7f_0
26
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
27
+ - click=8.0.4=py37hecd8cb5_0
28
+ - cryptography=38.0.1=py37hf6deb26_0
29
+ - cycler=0.11.0=pyhd3eb1b0_0
30
+ - dataclasses=0.8=pyh6d0b6a4_7
31
+ - decorator=5.1.1=pyhd3eb1b0_0
32
+ - expat=2.4.9=he9d5cce_0
33
+ - ffmpeg=4.0=h01ea3c9_0
34
+ - flit-core=3.6.0=pyhd3eb1b0_0
35
+ - fontconfig=2.14.1=hedf32ac_1
36
+ - fonttools=4.25.0=pyhd3eb1b0_0
37
+ - freetype=2.12.1=hd8bbffd_0
38
+ - frozenlist=1.3.3=py37h6c40b1e_0
39
+ - gettext=0.21.0=h7535e17_0
40
+ - giflib=5.2.1=haf1e3a3_0
41
+ - glib=2.63.1=hd977a24_0
42
+ - google-auth=2.6.0=pyhd3eb1b0_0
43
+ - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
44
+ - graphite2=1.3.14=he9d5cce_1
45
+ - grpcio=1.42.0=py37ha29bfda_0
46
+ - harfbuzz=1.8.8=hb8d4a28_0
47
+ - hdf5=1.10.2=hfa1e0ec_1
48
+ - icu=58.2=h0a44026_3
49
+ - idna=3.4=py37hecd8cb5_0
50
+ - intel-openmp=2021.4.0=hecd8cb5_3538
51
+ - ipython=7.31.1=py37hecd8cb5_1
52
+ - jasper=2.0.14=h0129ec2_2
53
+ - jedi=0.18.1=py37hecd8cb5_1
54
+ - jpeg=9e=hca72f7f_0
55
+ - kiwisolver=1.4.2=py37he9d5cce_0
56
+ - lcms2=2.12=hf1fd2bf_0
57
+ - lerc=3.0=he9d5cce_0
58
+ - libbrotlicommon=1.0.9=hca72f7f_7
59
+ - libbrotlidec=1.0.9=hca72f7f_7
60
+ - libbrotlienc=1.0.9=hca72f7f_7
61
+ - libcxx=14.0.6=h9765a3e_0
62
+ - libdeflate=1.8=h9ed2024_5
63
+ - libedit=3.1.20221030=h6c40b1e_0
64
+ - libffi=3.2.1=h0a44026_1007
65
+ - libgfortran=3.0.1=h93005f0_2
66
+ - libiconv=1.16=hca72f7f_2
67
+ - libopencv=3.4.2=h7c891bd_1
68
+ - libopus=1.3.1=h1de35cc_0
69
+ - libpng=1.6.37=ha441bb4_0
70
+ - libprotobuf=3.20.1=h8346a28_0
71
+ - libtiff=4.4.0=h2cd0358_2
72
+ - libvpx=1.7.0=h378b8a2_0
73
+ - libwebp=1.2.4=h56c3ce4_0
74
+ - libwebp-base=1.2.4=hca72f7f_0
75
+ - libxml2=2.9.14=hbf8cd5e_0
76
+ - llvm-openmp=14.0.6=h0dcd299_0
77
+ - lz4-c=1.9.4=hcec6c5f_0
78
+ - markdown=3.3.4=py37hecd8cb5_0
79
+ - matplotlib=3.1.2=py37h9aa3819_0
80
+ - matplotlib-inline=0.1.6=py37hecd8cb5_0
81
+ - mkl=2021.4.0=hecd8cb5_637
82
+ - mkl-service=2.4.0=py37h9ed2024_0
83
+ - mkl_fft=1.3.1=py37h4ab4a9b_0
84
+ - mkl_random=1.2.2=py37hb2f4e1b_0
85
+ - multidict=6.0.2=py37hca72f7f_0
86
+ - munkres=1.1.4=py_0
87
+ - ncurses=6.3=hca72f7f_3
88
+ - numpy=1.21.5=py37h2e5f0a9_3
89
+ - numpy-base=1.21.5=py37h3b1a694_3
90
+ - oauthlib=3.2.1=py37hecd8cb5_0
91
+ - olefile=0.46=py37_0
92
+ - opencv=3.4.2=py37h6fd60c2_1
93
+ - openssl=1.1.1s=hca72f7f_0
94
+ - packaging=21.3=pyhd3eb1b0_0
95
+ - parso=0.8.3=pyhd3eb1b0_0
96
+ - pcre=8.45=h23ab428_0
97
+ - pexpect=4.8.0=pyhd3eb1b0_3
98
+ - pickleshare=0.7.5=pyhd3eb1b0_1003
99
+ - pillow=6.1.0=py37hb68e598_0
100
+ - pip=22.3.1=py37hecd8cb5_0
101
+ - pixman=0.40.0=h9ed2024_1
102
+ - prompt-toolkit=3.0.20=pyhd3eb1b0_0
103
+ - protobuf=3.20.1=py37he9d5cce_0
104
+ - ptyprocess=0.7.0=pyhd3eb1b0_2
105
+ - py-opencv=3.4.2=py37h7c891bd_1
106
+ - pyasn1=0.4.8=pyhd3eb1b0_0
107
+ - pyasn1-modules=0.2.8=py_0
108
+ - pycparser=2.21=pyhd3eb1b0_0
109
+ - pygments=2.11.2=pyhd3eb1b0_0
110
+ - pyjwt=2.4.0=py37hecd8cb5_0
111
+ - pyopenssl=22.0.0=pyhd3eb1b0_0
112
+ - pyparsing=3.0.9=py37hecd8cb5_0
113
+ - pysocks=1.7.1=py37hecd8cb5_0
114
+ - python=3.7.3=h359304d_0
115
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
116
+ - pytorch=1.13.1=py3.7_0
117
+ - readline=7.0=h1de35cc_5
118
+ - requests=2.28.1=py37hecd8cb5_0
119
+ - requests-oauthlib=1.3.0=py_0
120
+ - rsa=4.7.2=pyhd3eb1b0_1
121
+ - setuptools=65.5.0=py37hecd8cb5_0
122
+ - six=1.16.0=pyhd3eb1b0_1
123
+ - sqlite=3.33.0=hffcf06c_0
124
+ - tensorboard=2.9.0=py37hecd8cb5_0
125
+ - tensorboard-data-server=0.6.1=py37h7242b5c_0
126
+ - tensorboard-plugin-wit=1.6.0=py_0
127
+ - tk=8.6.12=h5d9f67b_0
128
+ - torchvision=0.2.2=py_3
129
+ - tornado=6.2=py37hca72f7f_0
130
+ - tqdm=4.64.1=py37hecd8cb5_0
131
+ - traitlets=5.7.1=py37hecd8cb5_0
132
+ - typing-extensions=4.4.0=py37hecd8cb5_0
133
+ - typing_extensions=4.4.0=py37hecd8cb5_0
134
+ - urllib3=1.26.13=py37hecd8cb5_0
135
+ - wcwidth=0.2.5=pyhd3eb1b0_0
136
+ - werkzeug=2.0.3=pyhd3eb1b0_0
137
+ - wheel=0.37.1=pyhd3eb1b0_0
138
+ - xz=5.2.8=h6c40b1e_0
139
+ - yarl=1.8.1=py37hca72f7f_0
140
+ - zlib=1.2.13=h4dc903c_0
141
+ - zstd=1.5.2=hcb37349_0
142
+ - pip:
143
+ - ale-py==0.7.5
144
+ - cloudpickle==2.2.0
145
+ - gym==0.21.0
146
+ - gym-notices==0.0.8
147
+ - gym-retro==0.8.0
148
+ - importlib-metadata==4.13.0
149
+ - importlib-resources==5.10.1
150
+ - pygame==2.1.0
151
+ - pyglet==1.5.27
152
+ - zipp==3.11.0
153
+ prefix: /Users/karkisushant/miniconda3/envs/mlrl
environment.procgen-v2.yml ADDED
@@ -0,0 +1,135 @@
1
+ name: procgen
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - absl-py=1.3.0=py39hecd8cb5_0
7
+ - aiohttp=3.8.3=py39h6c40b1e_0
8
+ - aiosignal=1.2.0=pyhd3eb1b0_0
9
+ - async-timeout=4.0.2=py39hecd8cb5_0
10
+ - attrs=22.1.0=py39hecd8cb5_0
11
+ - blas=1.0=mkl
12
+ - blinker=1.4=py39hecd8cb5_0
13
+ - brotli=1.0.9=hca72f7f_7
14
+ - brotli-bin=1.0.9=hca72f7f_7
15
+ - brotlipy=0.7.0=py39h9ed2024_1003
16
+ - bzip2=1.0.8=h1de35cc_0
17
+ - c-ares=1.18.1=hca72f7f_0
18
+ - ca-certificates=2022.10.11=hecd8cb5_0
19
+ - cachetools=4.2.2=pyhd3eb1b0_0
20
+ - certifi=2022.9.24=py39hecd8cb5_0
21
+ - cffi=1.15.1=py39h6c40b1e_3
22
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
23
+ - click=8.0.4=py39hecd8cb5_0
24
+ - contourpy=1.0.5=py39haf03e11_0
25
+ - cryptography=38.0.1=py39hf6deb26_0
26
+ - cycler=0.11.0=pyhd3eb1b0_0
27
+ - ffmpeg=4.3=h0a44026_0
28
+ - flit-core=3.6.0=pyhd3eb1b0_0
29
+ - fonttools=4.25.0=pyhd3eb1b0_0
30
+ - freetype=2.12.1=hd8bbffd_0
31
+ - frozenlist=1.3.3=py39h6c40b1e_0
32
+ - gettext=0.21.0=h7535e17_0
33
+ - giflib=5.2.1=haf1e3a3_0
34
+ - gmp=6.2.1=he9d5cce_3
35
+ - gnutls=3.6.15=hed9c0bf_0
36
+ - google-auth=2.6.0=pyhd3eb1b0_0
37
+ - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
38
+ - grpcio=1.42.0=py39ha29bfda_0
39
+ - icu=58.2=h0a44026_3
40
+ - idna=3.4=py39hecd8cb5_0
41
+ - importlib-metadata=4.11.3=py39hecd8cb5_0
42
+ - intel-openmp=2021.4.0=hecd8cb5_3538
43
+ - jpeg=9e=hca72f7f_0
44
+ - kiwisolver=1.4.2=py39he9d5cce_0
45
+ - lame=3.100=h1de35cc_0
46
+ - lcms2=2.12=hf1fd2bf_0
47
+ - lerc=3.0=he9d5cce_0
48
+ - libbrotlicommon=1.0.9=hca72f7f_7
49
+ - libbrotlidec=1.0.9=hca72f7f_7
50
+ - libbrotlienc=1.0.9=hca72f7f_7
51
+ - libcxx=14.0.6=h9765a3e_0
52
+ - libdeflate=1.8=h9ed2024_5
53
+ - libffi=3.4.2=hecd8cb5_6
54
+ - libiconv=1.16=hca72f7f_2
55
+ - libidn2=2.3.2=h9ed2024_0
56
+ - libpng=1.6.37=ha441bb4_0
57
+ - libprotobuf=3.20.1=h8346a28_0
58
+ - libtasn1=4.16.0=h9ed2024_0
59
+ - libtiff=4.4.0=h2cd0358_2
60
+ - libunistring=0.9.10=h9ed2024_0
61
+ - libwebp=1.2.4=h56c3ce4_0
62
+ - libwebp-base=1.2.4=hca72f7f_0
63
+ - libxml2=2.9.14=hbf8cd5e_0
64
+ - llvm-openmp=14.0.6=h0dcd299_0
65
+ - lz4-c=1.9.4=hcec6c5f_0
66
+ - markdown=3.3.4=py39hecd8cb5_0
67
+ - markupsafe=2.1.1=py39hca72f7f_0
68
+ - matplotlib=3.6.2=py39hecd8cb5_0
69
+ - matplotlib-base=3.6.2=py39h220de94_0
70
+ - mkl=2021.4.0=hecd8cb5_637
71
+ - mkl-service=2.4.0=py39h9ed2024_0
72
+ - mkl_fft=1.3.1=py39h4ab4a9b_0
73
+ - mkl_random=1.2.2=py39hb2f4e1b_0
74
+ - multidict=6.0.2=py39hca72f7f_0
75
+ - munkres=1.1.4=py_0
76
+ - ncurses=6.3=hca72f7f_3
77
+ - nettle=3.7.3=h230ac6f_1
78
+ - numpy=1.23.4=py39he696674_0
79
+ - numpy-base=1.23.4=py39h9cd3388_0
80
+ - oauthlib=3.2.1=py39hecd8cb5_0
81
+ - openh264=2.1.1=h8346a28_0
82
+ - openssl=1.1.1s=hca72f7f_0
83
+ - packaging=21.3=pyhd3eb1b0_0
84
+ - pillow=9.2.0=py39hde71d04_1
85
+ - pip=22.3.1=py39hecd8cb5_0
86
+ - protobuf=3.20.1=py39he9d5cce_0
87
+ - pyasn1=0.4.8=pyhd3eb1b0_0
88
+ - pyasn1-modules=0.2.8=py_0
89
+ - pycparser=2.21=pyhd3eb1b0_0
90
+ - pyjwt=2.4.0=py39hecd8cb5_0
91
+ - pyopenssl=22.0.0=pyhd3eb1b0_0
92
+ - pyparsing=3.0.9=py39hecd8cb5_0
93
+ - pysocks=1.7.1=py39hecd8cb5_0
94
+ - python=3.9.15=h218abb5_2
95
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
96
+ - pytorch=1.13.1=py3.9_0
97
+ - readline=8.2=hca72f7f_0
98
+ - requests=2.28.1=py39hecd8cb5_0
99
+ - requests-oauthlib=1.3.0=py_0
100
+ - rsa=4.7.2=pyhd3eb1b0_1
101
+ - setuptools=65.5.0=py39hecd8cb5_0
102
+ - six=1.16.0=pyhd3eb1b0_1
103
+ - sqlite=3.40.0=h880c91c_0
104
+ - tensorboard=2.9.0=py39hecd8cb5_0
105
+ - tensorboard-data-server=0.6.1=py39h7242b5c_0
106
+ - tensorboard-plugin-wit=1.6.0=py_0
107
+ - tk=8.6.12=h5d9f67b_0
108
+ - torchvision=0.14.1=py39_cpu
109
+ - tornado=6.2=py39hca72f7f_0
110
+ - tqdm=4.64.1=py39hecd8cb5_0
111
+ - typing_extensions=4.4.0=py39hecd8cb5_0
112
+ - tzdata=2022g=h04d1e81_0
113
+ - urllib3=1.26.13=py39hecd8cb5_0
114
+ - werkzeug=2.2.2=py39hecd8cb5_0
115
+ - wheel=0.37.1=pyhd3eb1b0_0
116
+ - xz=5.2.8=h6c40b1e_0
117
+ - yarl=1.8.1=py39hca72f7f_0
118
+ - zipp=3.8.0=py39hecd8cb5_0
119
+ - zlib=1.2.13=h4dc903c_0
120
+ - zstd=1.5.2=hcb37349_0
121
+ - pip:
122
+ - cloudpickle==2.2.0
123
+ - filelock==3.8.2
124
+ - glcontext==2.3.7
125
+ - glfw==1.12.0
126
+ - gym==0.21.0
127
+ - gym-notices==0.0.8
128
+ - gym3==0.3.3
129
+ - imageio==2.22.4
130
+ - imageio-ffmpeg==0.3.0
131
+ - moderngl==5.7.4
132
+ - opencv-python==4.6.0.66
133
+ - procgen==0.10.7
134
+ - pyglet==1.5.27
135
+ prefix: /Users/karkisushant/miniconda3/envs/v2
environment.procgen.yml ADDED
@@ -0,0 +1,135 @@
1
+ name: procgen
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - absl-py=1.3.0=py39hecd8cb5_0
7
+ - aiohttp=3.8.3=py39h6c40b1e_0
8
+ - aiosignal=1.2.0=pyhd3eb1b0_0
9
+ - async-timeout=4.0.2=py39hecd8cb5_0
10
+ - attrs=22.1.0=py39hecd8cb5_0
11
+ - blas=1.0=mkl
12
+ - blinker=1.4=py39hecd8cb5_0
13
+ - brotli=1.0.9=hca72f7f_7
14
+ - brotli-bin=1.0.9=hca72f7f_7
15
+ - brotlipy=0.7.0=py39h9ed2024_1003
16
+ - bzip2=1.0.8=h1de35cc_0
17
+ - c-ares=1.18.1=hca72f7f_0
18
+ - ca-certificates=2022.10.11=hecd8cb5_0
19
+ - cachetools=4.2.2=pyhd3eb1b0_0
20
+ - certifi=2022.9.24=py39hecd8cb5_0
21
+ - cffi=1.15.1=py39h6c40b1e_3
22
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
23
+ - click=8.0.4=py39hecd8cb5_0
24
+ - contourpy=1.0.5=py39haf03e11_0
25
+ - cryptography=38.0.1=py39hf6deb26_0
26
+ - cycler=0.11.0=pyhd3eb1b0_0
27
+ - ffmpeg=4.3=h0a44026_0
28
+ - flit-core=3.6.0=pyhd3eb1b0_0
29
+ - fonttools=4.25.0=pyhd3eb1b0_0
30
+ - freetype=2.12.1=hd8bbffd_0
31
+ - frozenlist=1.3.3=py39h6c40b1e_0
32
+ - gettext=0.21.0=h7535e17_0
33
+ - giflib=5.2.1=haf1e3a3_0
34
+ - gmp=6.2.1=he9d5cce_3
35
+ - gnutls=3.6.15=hed9c0bf_0
36
+ - google-auth=2.6.0=pyhd3eb1b0_0
37
+ - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
38
+ - grpcio=1.42.0=py39ha29bfda_0
39
+ - icu=58.2=h0a44026_3
40
+ - idna=3.4=py39hecd8cb5_0
41
+ - importlib-metadata=4.11.3=py39hecd8cb5_0
42
+ - intel-openmp=2021.4.0=hecd8cb5_3538
43
+ - jpeg=9e=hca72f7f_0
44
+ - kiwisolver=1.4.2=py39he9d5cce_0
45
+ - lame=3.100=h1de35cc_0
46
+ - lcms2=2.12=hf1fd2bf_0
47
+ - lerc=3.0=he9d5cce_0
48
+ - libbrotlicommon=1.0.9=hca72f7f_7
49
+ - libbrotlidec=1.0.9=hca72f7f_7
50
+ - libbrotlienc=1.0.9=hca72f7f_7
51
+ - libcxx=14.0.6=h9765a3e_0
52
+ - libdeflate=1.8=h9ed2024_5
53
+ - libffi=3.4.2=hecd8cb5_6
54
+ - libiconv=1.16=hca72f7f_2
55
+ - libidn2=2.3.2=h9ed2024_0
56
+ - libpng=1.6.37=ha441bb4_0
57
+ - libprotobuf=3.20.1=h8346a28_0
58
+ - libtasn1=4.16.0=h9ed2024_0
59
+ - libtiff=4.4.0=h2cd0358_2
60
+ - libunistring=0.9.10=h9ed2024_0
61
+ - libwebp=1.2.4=h56c3ce4_0
62
+ - libwebp-base=1.2.4=hca72f7f_0
63
+ - libxml2=2.9.14=hbf8cd5e_0
64
+ - llvm-openmp=14.0.6=h0dcd299_0
65
+ - lz4-c=1.9.4=hcec6c5f_0
66
+ - markdown=3.3.4=py39hecd8cb5_0
67
+ - markupsafe=2.1.1=py39hca72f7f_0
68
+ - matplotlib=3.6.2=py39hecd8cb5_0
69
+ - matplotlib-base=3.6.2=py39h220de94_0
70
+ - mkl=2021.4.0=hecd8cb5_637
71
+ - mkl-service=2.4.0=py39h9ed2024_0
72
+ - mkl_fft=1.3.1=py39h4ab4a9b_0
73
+ - mkl_random=1.2.2=py39hb2f4e1b_0
74
+ - multidict=6.0.2=py39hca72f7f_0
75
+ - munkres=1.1.4=py_0
76
+ - ncurses=6.3=hca72f7f_3
77
+ - nettle=3.7.3=h230ac6f_1
78
+ - numpy=1.23.4=py39he696674_0
79
+ - numpy-base=1.23.4=py39h9cd3388_0
80
+ - oauthlib=3.2.1=py39hecd8cb5_0
81
+ - openh264=2.1.1=h8346a28_0
82
+ - openssl=1.1.1s=hca72f7f_0
83
+ - packaging=21.3=pyhd3eb1b0_0
84
+ - pillow=9.2.0=py39hde71d04_1
85
+ - pip=22.3.1=py39hecd8cb5_0
86
+ - protobuf=3.20.1=py39he9d5cce_0
87
+ - pyasn1=0.4.8=pyhd3eb1b0_0
88
+ - pyasn1-modules=0.2.8=py_0
89
+ - pycparser=2.21=pyhd3eb1b0_0
90
+ - pyjwt=2.4.0=py39hecd8cb5_0
91
+ - pyopenssl=22.0.0=pyhd3eb1b0_0
92
+ - pyparsing=3.0.9=py39hecd8cb5_0
93
+ - pysocks=1.7.1=py39hecd8cb5_0
94
+ - python=3.9.15=h218abb5_2
95
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
96
+ - pytorch=1.13.1=py3.9_0
97
+ - readline=8.2=hca72f7f_0
98
+ - requests=2.28.1=py39hecd8cb5_0
99
+ - requests-oauthlib=1.3.0=py_0
100
+ - rsa=4.7.2=pyhd3eb1b0_1
101
+ - setuptools=65.5.0=py39hecd8cb5_0
102
+ - six=1.16.0=pyhd3eb1b0_1
103
+ - sqlite=3.40.0=h880c91c_0
104
+ - tensorboard=2.9.0=py39hecd8cb5_0
105
+ - tensorboard-data-server=0.6.1=py39h7242b5c_0
106
+ - tensorboard-plugin-wit=1.6.0=py_0
107
+ - tk=8.6.12=h5d9f67b_0
108
+ - torchvision=0.14.1=py39_cpu
109
+ - tornado=6.2=py39hca72f7f_0
110
+ - tqdm=4.64.1=py39hecd8cb5_0
111
+ - typing_extensions=4.4.0=py39hecd8cb5_0
112
+ - tzdata=2022g=h04d1e81_0
113
+ - urllib3=1.26.13=py39hecd8cb5_0
114
+ - werkzeug=2.2.2=py39hecd8cb5_0
115
+ - wheel=0.37.1=pyhd3eb1b0_0
116
+ - xz=5.2.8=h6c40b1e_0
117
+ - yarl=1.8.1=py39hca72f7f_0
118
+ - zipp=3.8.0=py39hecd8cb5_0
119
+ - zlib=1.2.13=h4dc903c_0
120
+ - zstd=1.5.2=hcb37349_0
121
+ - pip:
122
+ - cloudpickle==2.2.0
123
+ - filelock==3.8.2
124
+ - glcontext==2.3.7
125
+ - glfw==1.12.0
126
+ - gym==0.21.0
127
+ - gym-notices==0.0.8
128
+ - gym3==0.3.3
129
+ - imageio==2.22.4
130
+ - imageio-ffmpeg==0.3.0
131
+ - moderngl==5.7.4
132
+ - opencv-python==4.6.0.66
133
+ - procgen==0.10.7
134
+ - pyglet==1.5.27
135
+ prefix: /Users/karkisushant/miniconda3/envs/procgen
requirements-v1.txt ADDED
@@ -0,0 +1,76 @@
1
+ absl-py==1.3.0
2
+ ale-py==0.7.5
3
+ astunparse==1.6.3
4
+ attrs==22.1.0
5
+ box2d-py==2.3.5
6
+ cachetools==5.2.0
7
+ certifi==2022.12.7
8
+ cffi==1.15.1
9
+ charset-normalizer==2.1.1
10
+ cloudpickle==2.2.0
11
+ cycler==0.11.0
12
+ Cython==0.29.32
13
+ fasteners==0.18
14
+ flatbuffers==22.12.6
15
+ fonttools==4.38.0
16
+ future==0.18.2
17
+ gast==0.4.0
18
+ glfw==2.5.5
19
+ google-auth==2.15.0
20
+ google-auth-oauthlib==0.4.6
21
+ google-pasta==0.2.0
22
+ grpcio==1.51.1
23
+ gym==0.21.0
24
+ gym-notices==0.0.8
25
+ gym-retro==0.8.0
26
+ h5py==3.7.0
27
+ idna==3.4
28
+ imageio==2.22.4
29
+ importlib-metadata==4.13.0
30
+ importlib-resources==5.10.1
31
+ iniconfig==1.1.1
32
+ keras==2.11.0
33
+ kiwisolver==1.4.4
34
+ libclang==14.0.6
35
+ lz4==4.0.2
36
+ Markdown==3.4.1
37
+ MarkupSafe==2.1.1
38
+ matplotlib==3.5.3
39
+ mujoco==2.2.0
40
+ mujoco-py==2.1.2.14
41
+ numpy==1.21.6
42
+ oauthlib==3.2.2
43
+ opencv-python==4.6.0.66
44
+ opt-einsum==3.3.0
45
+ packaging==22.0
46
+ Pillow==9.3.0
47
+ pluggy==1.0.0
48
+ protobuf==3.19.6
49
+ py==1.11.0
50
+ pyasn1==0.4.8
51
+ pyasn1-modules==0.2.8
52
+ pycparser==2.21
53
+ pygame==2.1.0
54
+ pyglet==1.5.11
55
+ PyOpenGL==3.1.6
56
+ pyparsing==3.0.9
57
+ pytest==7.0.1
58
+ python-dateutil==2.8.2
59
+ requests==2.28.1
60
+ requests-oauthlib==1.3.1
61
+ rsa==4.9
62
+ six==1.16.0
63
+ swig==4.1.1
64
+ tensorboard==2.11.0
65
+ tensorboard-data-server==0.6.1
66
+ tensorboard-plugin-wit==1.8.1
67
+ tensorflow==2.11.0
68
+ tensorflow-estimator==2.11.0
69
+ tensorflow-io-gcs-filesystem==0.28.0
70
+ termcolor==2.1.1
71
+ tomli==2.0.1
72
+ typing_extensions==4.4.0
73
+ urllib3==1.26.13
74
+ Werkzeug==2.2.2
75
+ wrapt==1.14.1
76
+ zipp==3.11.0
requirements.txt ADDED
@@ -0,0 +1,42 @@
1
+ absl-py==1.3.0
2
+ ale-py==0.7.5
3
+ attrs==22.1.0
4
+ box2d-py==2.3.5
5
+ cffi==1.15.1
6
+ cloudpickle==2.2.0
7
+ cycler==0.11.0
8
+ Cython==0.29.32
9
+ fasteners==0.18
10
+ fonttools==4.38.0
11
+ future==0.18.2
12
+ glfw==2.5.5
13
+ gym==0.21.0
14
+ gym-notices==0.0.8
15
+ gym-retro==0.8.0
16
+ imageio==2.22.4
17
+ importlib-metadata==4.13.0
18
+ importlib-resources==5.10.1
19
+ iniconfig==1.1.1
20
+ kiwisolver==1.4.4
21
+ lz4==4.0.2
22
+ matplotlib==3.5.3
23
+ mujoco==2.2.0
24
+ mujoco-py==2.1.2.14
25
+ numpy==1.18.0
26
+ opencv-python==4.6.0.66
27
+ packaging==22.0
28
+ Pillow==9.3.0
29
+ pluggy==1.0.0
30
+ py==1.11.0
31
+ pycparser==2.21
32
+ pygame==2.1.0
33
+ pyglet==1.5.11
34
+ PyOpenGL==3.1.6
35
+ pyparsing==3.0.9
36
+ pytest==7.0.1
37
+ python-dateutil==2.8.2
38
+ six==1.16.0
39
+ swig==4.1.1
40
+ tomli==2.0.1
41
+ typing_extensions==4.4.0
42
+ zipp==3.11.0
src/airstriker-genesis/__init__.py ADDED
File without changes
src/airstriker-genesis/agent.py ADDED
@@ -0,0 +1,400 @@
1
+ import torch
2
+ import numpy as np
3
+ import random
4
+ import torch.nn as nn
5
+ import copy
6
+ import time, datetime
7
+ import matplotlib.pyplot as plt
8
+ from collections import deque
9
+ from torch.utils.tensorboard import SummaryWriter
10
+ import pickle
11
+
12
+
13
+ class DQNet(nn.Module):
14
+ """mini cnn structure
15
+ input -> (conv2d + relu) x 3 -> flatten -> (dense + relu) x 2 -> output
16
+ """
17
+
18
+ def __init__(self, input_dim, output_dim):
19
+ super().__init__()
20
+ print("#################################")
21
+ print("#################################")
22
+ print(input_dim)
23
+ print(output_dim)
24
+ print("#################################")
25
+ print("#################################")
26
+ c, h, w = input_dim
27
+
28
+ # if h != 84:
29
+ # raise ValueError(f"Expecting input height: 84, got: {h}")
30
+ # if w != 84:
31
+ # raise ValueError(f"Expecting input width: 84, got: {w}")
32
+
33
+ self.online = nn.Sequential(
34
+ nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
35
+ nn.ReLU(),
36
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
37
+ nn.ReLU(),
38
+ nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
39
+ nn.ReLU(),
40
+ nn.Flatten(),
41
+ nn.Linear(17024, 512),
42
+ nn.ReLU(),
43
+ nn.Linear(512, output_dim),
44
+ )
45
+
46
+
47
+ self.target = copy.deepcopy(self.online)
48
+
49
+ # Q_target parameters are frozen.
50
+ for p in self.target.parameters():
51
+ p.requires_grad = False
52
+
53
+ def forward(self, input, model):
54
+ if model == "online":
55
+ return self.online(input)
56
+ elif model == "target":
57
+ return self.target(input)
58
+
59
+
60
+
61
+ class MetricLogger:
62
+ def __init__(self, save_dir):
63
+ self.writer = SummaryWriter(log_dir=save_dir)
64
+ self.save_log = save_dir / "log"
65
+ with open(self.save_log, "w") as f:
66
+ f.write(
67
+ f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
68
+ f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
69
+ f"{'TimeDelta':>15}{'Time':>20}\n"
70
+ )
71
+ self.ep_rewards_plot = save_dir / "reward_plot.jpg"
72
+ self.ep_lengths_plot = save_dir / "length_plot.jpg"
73
+ self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
74
+ self.ep_avg_qs_plot = save_dir / "q_plot.jpg"
75
+
76
+ # History metrics
77
+ self.ep_rewards = []
78
+ self.ep_lengths = []
79
+ self.ep_avg_losses = []
80
+ self.ep_avg_qs = []
81
+
82
+ # Moving averages, added for every call to record()
83
+ self.moving_avg_ep_rewards = []
84
+ self.moving_avg_ep_lengths = []
85
+ self.moving_avg_ep_avg_losses = []
86
+ self.moving_avg_ep_avg_qs = []
87
+
88
+ # Current episode metric
89
+ self.init_episode()
90
+
91
+ # Timing
92
+ self.record_time = time.time()
93
+
94
+ def log_step(self, reward, loss, q):
95
+ self.curr_ep_reward += reward
96
+ self.curr_ep_length += 1
97
+ if loss:
98
+ self.curr_ep_loss += loss
99
+ self.curr_ep_q += q
100
+ self.curr_ep_loss_length += 1
101
+
102
+ def log_episode(self, episode_number):
103
+ "Mark end of episode"
104
+ self.ep_rewards.append(self.curr_ep_reward)
105
+ self.ep_lengths.append(self.curr_ep_length)
106
+ if self.curr_ep_loss_length == 0:
107
+ ep_avg_loss = 0
108
+ ep_avg_q = 0
109
+ else:
110
+ ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
111
+ ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
112
+ self.ep_avg_losses.append(ep_avg_loss)
113
+ self.ep_avg_qs.append(ep_avg_q)
114
+ self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
115
+ self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
116
+ self.writer.flush()
117
+ self.init_episode()
118
+
119
+ def init_episode(self):
120
+ self.curr_ep_reward = 0.0
121
+ self.curr_ep_length = 0
122
+ self.curr_ep_loss = 0.0
123
+ self.curr_ep_q = 0.0
124
+ self.curr_ep_loss_length = 0
125
+
126
+ def record(self, episode, epsilon, step):
127
+ mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
128
+ mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
129
+ mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
130
+ mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
131
+ self.moving_avg_ep_rewards.append(mean_ep_reward)
132
+ self.moving_avg_ep_lengths.append(mean_ep_length)
133
+ self.moving_avg_ep_avg_losses.append(mean_ep_loss)
134
+ self.moving_avg_ep_avg_qs.append(mean_ep_q)
135
+
136
+ last_record_time = self.record_time
137
+ self.record_time = time.time()
138
+ time_since_last_record = np.round(self.record_time - last_record_time, 3)
139
+
140
+ print(
141
+ f"Episode {episode} - "
142
+ f"Step {step} - "
143
+ f"Epsilon {epsilon} - "
144
+ f"Mean Reward {mean_ep_reward} - "
145
+ f"Mean Length {mean_ep_length} - "
146
+ f"Mean Loss {mean_ep_loss} - "
147
+ f"Mean Q Value {mean_ep_q} - "
148
+ f"Time Delta {time_since_last_record} - "
149
+ f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
150
+ )
151
+ self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
152
+ self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
153
+ self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
154
+ self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
155
+ self.writer.add_scalar("Epsilon value", epsilon, episode)
156
+ self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
157
+ self.writer.flush()
158
+ with open(self.save_log, "a") as f:
159
+ f.write(
160
+ f"{episode:8d}{step:8d}{epsilon:10.3f}"
161
+ f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
162
+ f"{time_since_last_record:15.3f}"
163
+ f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
164
+ )
165
+
166
+ for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
167
+ plt.plot(getattr(self, f"moving_avg_{metric}"))
168
+ plt.savefig(getattr(self, f"{metric}_plot"))
169
+ plt.clf()
170
+
171
+
172
+ class DQNAgent:
173
+ def __init__(self,
174
+ state_dim,
175
+ action_dim,
176
+ save_dir,
177
+ checkpoint=None,
178
+ learning_rate=0.00025,
179
+ max_memory_size=100000,
180
+ batch_size=32,
181
+ exploration_rate=1,
182
+ exploration_rate_decay=0.9999999,
183
+ exploration_rate_min=0.1,
184
+ training_frequency=1,
185
+ learning_starts=1000,
186
+ target_network_sync_frequency=500,
187
+ reset_exploration_rate=False,
188
+ save_frequency=100000,
189
+ gamma=0.9,
190
+ load_replay_buffer=True):
191
+ self.state_dim = state_dim
192
+ self.action_dim = action_dim
193
+ self.max_memory_size = max_memory_size
194
+ self.memory = deque(maxlen=max_memory_size)
195
+ self.batch_size = batch_size
196
+
197
+ self.exploration_rate = exploration_rate
198
+ self.exploration_rate_decay = exploration_rate_decay
199
+ self.exploration_rate_min = exploration_rate_min
200
+ self.gamma = gamma
201
+
202
+ self.curr_step = 0
203
+ self.learning_starts = learning_starts # min. experiences before training
204
+
205
+ self.training_frequency = training_frequency # no. of experiences between updates to Q_online
206
+ self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
207
+
208
+ self.save_every = save_frequency # no. of experiences between saving Mario Net
209
+ self.save_dir = save_dir
210
+
211
+ self.use_cuda = torch.cuda.is_available()
212
+
213
+ # Mario's DNN to predict the most optimal action - we implement this in the Learn section
214
+ self.net = DQNet(self.state_dim, self.action_dim).float()
215
+ if self.use_cuda:
216
+ self.net = self.net.to(device='cuda')
217
+ if checkpoint:
218
+ self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
219
+
220
+ self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
221
+ self.loss_fn = torch.nn.SmoothL1Loss()
222
+
223
+
224
+ def act(self, state):
225
+ """
226
+ Given a state, choose an epsilon-greedy action and update value of step.
227
+
228
+ Inputs:
229
+ state(LazyFrame): A single observation of the current state, dimension is (state_dim)
230
+ Outputs:
231
+ action_idx (int): An integer representing which action Mario will perform
232
+ """
233
+ # EXPLORE
234
+ if np.random.rand() < self.exploration_rate:
235
+ action_idx = np.random.randint(self.action_dim)
236
+
237
+ # EXPLOIT
238
+ else:
239
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
240
+ state = state.unsqueeze(0)
241
+ action_values = self.net(state, model='online')
242
+ action_idx = torch.argmax(action_values, axis=1).item()
243
+
244
+ # decrease exploration_rate
245
+ self.exploration_rate *= self.exploration_rate_decay
246
+ self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
247
+
248
+ # increment step
249
+ self.curr_step += 1
250
+ return action_idx
251
+
252
+ def cache(self, state, next_state, action, reward, done):
253
+ """
254
+ Store the experience to self.memory (replay buffer)
255
+
256
+ Inputs:
257
+ state (LazyFrame),
258
+ next_state (LazyFrame),
259
+ action (int),
260
+ reward (float),
261
+ done(bool))
262
+ """
263
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
264
+ next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
265
+ action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
266
+ reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
267
+ done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
268
+
269
+ self.memory.append( (state, next_state, action, reward, done,) )
270
+
271
+
272
+ def recall(self):
273
+ """
274
+ Retrieve a batch of experiences from memory
275
+ """
276
+ batch = random.sample(self.memory, self.batch_size)
277
+ state, next_state, action, reward, done = map(torch.stack, zip(*batch))
278
+ return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
279
+
280
+
281
+ # def td_estimate(self, state, action):
282
+ # current_Q = self.net(state, model='online')[np.arange(0, self.batch_size), action] # Q_online(s,a)
283
+ # return current_Q
284
+
285
+
286
+ # @torch.no_grad()
287
+ # def td_target(self, reward, next_state, done):
288
+ # next_state_Q = self.net(next_state, model='online')
289
+ # best_action = torch.argmax(next_state_Q, axis=1)
290
+ # next_Q = self.net(next_state, model='target')[np.arange(0, self.batch_size), best_action]
291
+ # return (reward + (1 - done.float()) * self.gamma * next_Q).float()
292
+
293
+ def td_estimate(self, states, actions):
294
+ actions = actions.reshape(-1, 1)
295
+ predicted_qs = self.net(states, model='online')# Q_online(s,a)
296
+ predicted_qs = predicted_qs.gather(1, actions)
297
+ return predicted_qs
298
+
299
+
300
+ @torch.no_grad()
301
+ def td_target(self, rewards, next_states, dones):
302
+ rewards = rewards.reshape(-1, 1)
303
+ dones = dones.reshape(-1, 1)
304
+ target_qs = self.net(next_states, model='target')
305
+ target_qs = torch.max(target_qs, dim=1).values
306
+ target_qs = target_qs.reshape(-1, 1)
307
+ target_qs[dones] = 0.0
308
+ return (rewards + (self.gamma * target_qs))
309
+
310
+ def update_Q_online(self, td_estimate, td_target) :
311
+ loss = self.loss_fn(td_estimate, td_target)
312
+ self.optimizer.zero_grad()
313
+ loss.backward()
314
+ self.optimizer.step()
315
+ return loss.item()
316
+
317
+
318
+ def sync_Q_target(self):
319
+ self.net.target.load_state_dict(self.net.online.state_dict())
320
+
321
+
322
+ def learn(self):
323
+ if self.curr_step % self.target_network_sync_frequency == 0:
324
+ self.sync_Q_target()
325
+
326
+ if self.curr_step % self.save_every == 0:
327
+ self.save()
328
+
329
+ if self.curr_step < self.learning_starts:
330
+ return None, None
331
+
332
+ if self.curr_step % self.training_frequency != 0:
333
+ return None, None
334
+
335
+ # Sample from memory
336
+ state, next_state, action, reward, done = self.recall()
337
+
338
+ # Get TD Estimate
339
+ td_est = self.td_estimate(state, action)
340
+
341
+ # Get TD Target
342
+ td_tgt = self.td_target(reward, next_state, done)
343
+
344
+ # Backpropagate loss through Q_online
345
+ loss = self.update_Q_online(td_est, td_tgt)
346
+
347
+ return (td_est.mean().item(), loss)
348
+
349
+
350
+ def save(self):
351
+ save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
352
+ torch.save(
353
+ dict(
354
+ model=self.net.state_dict(),
355
+ exploration_rate=self.exploration_rate,
356
+ replay_memory=self.memory
357
+ ),
358
+ save_path
359
+ )
360
+
361
+ print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
362
+
363
+
364
+ def load(self, load_path, reset_exploration_rate, load_replay_buffer):
365
+ if not load_path.exists():
366
+ raise ValueError(f"{load_path} does not exist")
367
+
368
+ ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
369
+ exploration_rate = ckp.get('exploration_rate')
370
+ state_dict = ckp.get('model')
371
+
372
+
373
+ print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
374
+ self.net.load_state_dict(state_dict)
375
+
376
+ if load_replay_buffer:
377
+ replay_memory = ckp.get('replay_memory')
378
+ print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
379
+ self.memory = replay_memory if replay_memory else self.memory
380
+
381
+ if reset_exploration_rate:
382
+ print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
383
+ else:
384
+ print(f"Setting exploration rate to {exploration_rate} not loaded.")
385
+ self.exploration_rate = exploration_rate
386
+
387
+
388
+ class DDQNAgent(DQNAgent):
389
+ @torch.no_grad()
390
+ def td_target(self, rewards, next_states, dones):
391
+ print("Double dqn -----------------------")
392
+ rewards = rewards.reshape(-1, 1)
393
+ dones = dones.reshape(-1, 1)
394
+ q_vals = self.net(next_states, model='online')
395
+ target_actions = torch.argmax(q_vals, axis=1)
396
+ target_actions = target_actions.reshape(-1, 1)
397
+ target_qs = self.net(next_states, model='target').gather(1, target_actions)
398
+ target_qs = target_qs.reshape(-1, 1)
399
+ target_qs[dones] = 0.0
400
+ return (rewards + (self.gamma * target_qs))
src/airstriker-genesis/cartpole.py ADDED
@@ -0,0 +1,353 @@
1
+
2
+ import torch
3
+ import numpy as np
4
+ import random
5
+ import torch.nn as nn
6
+ import copy
7
+ import time, datetime
8
+ import matplotlib.pyplot as plt
9
+ from collections import deque
10
+ from torch.utils.tensorboard import SummaryWriter
11
+ import pickle
12
+
13
+
14
+ class MyDQN(nn.Module):
15
+ """mini cnn structure
16
+ input -> (conv2d + relu) x 3 -> flatten -> (dense + relu) x 2 -> output
17
+ """
18
+
19
+ def __init__(self, input_dim, output_dim):
20
+ super().__init__()
21
+
22
+ self.online = nn.Sequential(
23
+ nn.Linear(input_dim, 128),
24
+ nn.ReLU(),
25
+ nn.Linear(128, 128),
26
+ nn.ReLU(),
27
+ nn.Linear(128, output_dim)
28
+ )
29
+
30
+
31
+ self.target = copy.deepcopy(self.online)
32
+
33
+ # Q_target parameters are frozen.
34
+ for p in self.target.parameters():
35
+ p.requires_grad = False
36
+
37
+ def forward(self, input, model):
38
+ if model == "online":
39
+ return self.online(input)
40
+ elif model == "target":
41
+ return self.target(input)
42
+
43
+
44
+
45
+ class MetricLogger:
46
+ def __init__(self, save_dir):
47
+ self.writer = SummaryWriter(log_dir=save_dir)
48
+ self.save_log = save_dir / "log"
49
+ with open(self.save_log, "w") as f:
50
+ f.write(
51
+ f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
52
+ f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
53
+ f"{'TimeDelta':>15}{'Time':>20}\n"
54
+ )
55
+ self.ep_rewards_plot = save_dir / "reward_plot.jpg"
56
+ self.ep_lengths_plot = save_dir / "length_plot.jpg"
57
+ self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
58
+ self.ep_avg_qs_plot = save_dir / "q_plot.jpg"
59
+
60
+ # History metrics
61
+ self.ep_rewards = []
62
+ self.ep_lengths = []
63
+ self.ep_avg_losses = []
64
+ self.ep_avg_qs = []
65
+
66
+ # Moving averages, added for every call to record()
67
+ self.moving_avg_ep_rewards = []
68
+ self.moving_avg_ep_lengths = []
69
+ self.moving_avg_ep_avg_losses = []
70
+ self.moving_avg_ep_avg_qs = []
71
+
72
+ # Current episode metric
73
+ self.init_episode()
74
+
75
+ # Timing
76
+ self.record_time = time.time()
77
+
78
+ def log_step(self, reward, loss, q):
79
+ self.curr_ep_reward += reward
80
+ self.curr_ep_length += 1
81
+ if loss:
82
+ self.curr_ep_loss += loss
83
+ self.curr_ep_q += q
84
+ self.curr_ep_loss_length += 1
85
+
86
+ def log_episode(self, episode_number):
87
+ "Mark end of episode"
88
+ self.ep_rewards.append(self.curr_ep_reward)
89
+ self.ep_lengths.append(self.curr_ep_length)
90
+ if self.curr_ep_loss_length == 0:
91
+ ep_avg_loss = 0
92
+ ep_avg_q = 0
93
+ else:
94
+ ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
95
+ ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
96
+ self.ep_avg_losses.append(ep_avg_loss)
97
+ self.ep_avg_qs.append(ep_avg_q)
98
+ self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
99
+ self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
100
+ self.writer.flush()
101
+ self.init_episode()
102
+
103
+ def init_episode(self):
104
+ self.curr_ep_reward = 0.0
105
+ self.curr_ep_length = 0
106
+ self.curr_ep_loss = 0.0
107
+ self.curr_ep_q = 0.0
108
+ self.curr_ep_loss_length = 0
109
+
110
+ def record(self, episode, epsilon, step):
111
+ mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
112
+ mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
113
+ mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
114
+ mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
115
+ self.moving_avg_ep_rewards.append(mean_ep_reward)
116
+ self.moving_avg_ep_lengths.append(mean_ep_length)
117
+ self.moving_avg_ep_avg_losses.append(mean_ep_loss)
118
+ self.moving_avg_ep_avg_qs.append(mean_ep_q)
119
+
120
+ last_record_time = self.record_time
121
+ self.record_time = time.time()
122
+ time_since_last_record = np.round(self.record_time - last_record_time, 3)
123
+
124
+ print(
125
+ f"Episode {episode} - "
126
+ f"Step {step} - "
127
+ f"Epsilon {epsilon} - "
128
+ f"Mean Reward {mean_ep_reward} - "
129
+ f"Mean Length {mean_ep_length} - "
130
+ f"Mean Loss {mean_ep_loss} - "
131
+ f"Mean Q Value {mean_ep_q} - "
132
+ f"Time Delta {time_since_last_record} - "
133
+ f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
134
+ )
135
+ self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
136
+ self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
137
+ self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
139
+ self.writer.add_scalar("Epsilon value", epsilon, episode)
140
+ self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
141
+ self.writer.flush()
142
+ with open(self.save_log, "a") as f:
143
+ f.write(
144
+ f"{episode:8d}{step:8d}{epsilon:10.3f}"
145
+ f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
146
+ f"{time_since_last_record:15.3f}"
147
+ f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
148
+ )
149
+
150
+ for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
151
+ plt.plot(getattr(self, f"moving_avg_{metric}"))
152
+ plt.savefig(getattr(self, f"{metric}_plot"))
153
+ plt.clf()
154
+
155
+
156
+ class MyAgent:
157
+ def __init__(self, state_dim, action_dim, save_dir, checkpoint=None, reset_exploration_rate=False, max_memory_size=100000):
158
+ self.state_dim = state_dim
159
+ self.action_dim = action_dim
160
+ self.max_memory_size = max_memory_size
161
+ self.memory = deque(maxlen=max_memory_size)
162
+ # self.batch_size = 32
163
+ self.batch_size = 512
164
+
165
+ self.exploration_rate = 1
166
+ # self.exploration_rate_decay = 0.99999975
167
+ self.exploration_rate_decay = 0.9999999
168
+ self.exploration_rate_min = 0.1
169
+ self.gamma = 0.9
170
+
171
+ self.curr_step = 0
172
+ self.learning_start_threshold = 10000 # min. experiences before training
173
+
174
+ self.learn_every = 5 # no. of experiences between updates to Q_online
175
+ self.sync_every = 200 # no. of experiences between Q_target & Q_online sync
176
+
177
+ self.save_every = 200000 # no. of experiences between saving the network
178
+ self.save_dir = save_dir
179
+
180
+ self.use_cuda = torch.cuda.is_available()
181
+
182
+ # The agent's DQN used to predict Q-values for action selection; it is updated in learn()
183
+ self.net = MyDQN(self.state_dim, self.action_dim).float()
184
+ if self.use_cuda:
185
+ self.net = self.net.to(device='cuda')
186
+ if checkpoint:
187
+ self.load(checkpoint, reset_exploration_rate)
188
+
189
+ # self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00025)
190
+ self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=0.00025, amsgrad=True)
191
+ self.loss_fn = torch.nn.SmoothL1Loss()
192
+
193
+
194
+ def act(self, state):
195
+ """
196
+ Given a state, choose an epsilon-greedy action and update value of step.
197
+
198
+ Inputs:
199
+ state(LazyFrame): A single observation of the current state, dimension is (state_dim)
200
+ Outputs:
201
+ action_idx (int): An integer representing which action the agent will perform
202
+ """
203
+ # EXPLORE
204
+ if np.random.rand() < self.exploration_rate:
205
+ action_idx = np.random.randint(self.action_dim)
206
+
207
+ # EXPLOIT
208
+ else:
209
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
210
+ state = state.unsqueeze(0)
211
+ action_values = self.net(state, model='online')
212
+ action_idx = torch.argmax(action_values, axis=1).item()
213
+
214
+ # decrease exploration_rate
215
+ self.exploration_rate *= self.exploration_rate_decay
216
+ self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
217
+
218
+ # increment step
219
+ self.curr_step += 1
220
+ return action_idx
221
+
222
+ def cache(self, state, next_state, action, reward, done):
223
+ """
224
+ Store the experience to self.memory (replay buffer)
225
+
226
+ Inputs:
227
+ state (LazyFrame),
228
+ next_state (LazyFrame),
229
+ action (int),
230
+ reward (float),
231
+ done(bool))
232
+ """
233
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
234
+ next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
235
+ action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
236
+ reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
237
+ done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
238
+
239
+ self.memory.append( (state, next_state, action, reward, done,) )
240
+
241
+
242
+ def recall(self):
243
+ """
244
+ Retrieve a batch of experiences from memory
245
+ """
246
+ batch = random.sample(self.memory, self.batch_size)
247
+ state, next_state, action, reward, done = map(torch.stack, zip(*batch))
248
+ return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
249
+
250
+
251
+ # def td_estimate(self, state, action):
252
+ # current_Q = self.net(state, model='online')[np.arange(0, self.batch_size), action] # Q_online(s,a)
253
+ # return current_Q
254
+
255
+
256
+ # @torch.no_grad()
257
+ # def td_target(self, reward, next_state, done):
258
+ # next_state_Q = self.net(next_state, model='online')
259
+ # best_action = torch.argmax(next_state_Q, axis=1)
260
+ # next_Q = self.net(next_state, model='target')[np.arange(0, self.batch_size), best_action]
261
+ # return (reward + (1 - done.float()) * self.gamma * next_Q).float()
262
+
263
+ def td_estimate(self, states, actions):
264
+ actions = actions.reshape(-1, 1)
265
+ predicted_qs = self.net(states, model='online')# Q_online(s,a)
266
+ predicted_qs = predicted_qs.gather(1, actions)
267
+ return predicted_qs
268
+
269
+
270
+ @torch.no_grad()
271
+ def td_target(self, rewards, next_states, dones):
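+ # Vanilla DQN target: r + gamma * max_a Q_target(s', a); terminal transitions keep only the reward (target_qs is zeroed where done is True).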
272
+ rewards = rewards.reshape(-1, 1)
273
+ dones = dones.reshape(-1, 1)
274
+ target_qs = self.net(next_states, model='target')
275
+ target_qs = torch.max(target_qs, dim=1).values
276
+ target_qs = target_qs.reshape(-1, 1)
277
+ target_qs[dones] = 0.0
278
+ return (rewards + (self.gamma * target_qs))
279
+
280
+ def update_Q_online(self, td_estimate, td_target):
281
+ loss = self.loss_fn(td_estimate, td_target)
282
+ self.optimizer.zero_grad()
283
+ loss.backward()
284
+ self.optimizer.step()
285
+ return loss.item()
286
+
287
+
288
+ def sync_Q_target(self):
289
+ self.net.target.load_state_dict(self.net.online.state_dict())
290
+
291
+
292
+ def learn(self):
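+ # Training schedule: sync the target network every sync_every steps, checkpoint every save_every steps, skip updates until learning_start_threshold experiences exist, then train every learn_every steps.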
293
+ if self.curr_step % self.sync_every == 0:
294
+ self.sync_Q_target()
295
+
296
+ if self.curr_step % self.save_every == 0:
297
+ self.save()
298
+
299
+ if self.curr_step < self.learning_start_threshold:
300
+ return None, None
301
+
302
+ if self.curr_step % self.learn_every != 0:
303
+ return None, None
304
+
305
+ # Sample from memory
306
+ state, next_state, action, reward, done = self.recall()
307
+
308
+ # Get TD Estimate
309
+ td_est = self.td_estimate(state, action)
310
+
311
+ # Get TD Target
312
+ td_tgt = self.td_target(reward, next_state, done)
313
+
314
+ # Backpropagate loss through Q_online
315
+ loss = self.update_Q_online(td_est, td_tgt)
316
+
317
+ return (td_est.mean().item(), loss)
318
+
319
+
320
+ def save(self):
321
+ save_path = self.save_dir / f"cartpole_net_{int(self.curr_step // self.save_every)}.chkpt"
322
+ torch.save(
323
+ dict(
324
+ model=self.net.state_dict(),
325
+ exploration_rate=self.exploration_rate,
326
+ replay_memory=self.memory
327
+ ),
328
+ save_path
329
+ )
330
+
331
+ print(f"Cartpole Net saved to {save_path} at step {self.curr_step}")
332
+
333
+
334
+ def load(self, load_path, reset_exploration_rate=False):
335
+ if not load_path.exists():
336
+ raise ValueError(f"{load_path} does not exist")
337
+
338
+ ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
339
+ exploration_rate = ckp.get('exploration_rate')
340
+ state_dict = ckp.get('model')
341
+ replay_memory = ckp.get('replay_memory')
342
+
343
+ print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
344
+ self.net.load_state_dict(state_dict)
345
+
346
+ print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
347
+ self.memory = replay_memory if replay_memory else self.memory
348
+
349
+ if reset_exploration_rate:
350
+ print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
351
+ else:
352
+ print(f"Setting exploration rate to {exploration_rate} not loaded.")
353
+ self.exploration_rate = exploration_rate
src/airstriker-genesis/procgen_agent.py ADDED
@@ -0,0 +1,400 @@
1
+ import torch
2
+ import numpy as np
3
+ import random
4
+ import torch.nn as nn
5
+ import copy
6
+ import time, datetime
7
+ import matplotlib.pyplot as plt
8
+ from collections import deque
9
+ from torch.utils.tensorboard import SummaryWriter
10
+ import pickle
11
+
12
+
13
+ class DQNet(nn.Module):
14
+ """mini cnn structure
15
+ input -> (conv2d + relu) x 3 -> flatten -> (dense + relu) x 2 -> output
16
+ """
17
+
18
+ def __init__(self, input_dim, output_dim):
19
+ super().__init__()
20
+ print("#################################")
21
+ print("#################################")
22
+ print(input_dim)
23
+ print(output_dim)
24
+ print("#################################")
25
+ print("#################################")
26
+ c, h, w = input_dim
27
+
28
+ # if h != 84:
29
+ # raise ValueError(f"Expecting input height: 84, got: {h}")
30
+ # if w != 84:
31
+ # raise ValueError(f"Expecting input width: 84, got: {w}")
32
+
33
+ self.online = nn.Sequential(
34
+ nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
35
+ nn.ReLU(),
36
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
37
+ nn.ReLU(),
38
+ nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
39
+ nn.ReLU(),
40
+ nn.Flatten(),
41
+ nn.Linear(7168, 512),
42
+ nn.ReLU(),
43
+ nn.Linear(512, output_dim),
44
+ )
45
+
46
+
47
+ self.target = copy.deepcopy(self.online)
48
+
49
+ # Q_target parameters are frozen.
50
+ for p in self.target.parameters():
51
+ p.requires_grad = False
52
+
53
+ def forward(self, input, model):
54
+ if model == "online":
55
+ return self.online(input)
56
+ elif model == "target":
57
+ return self.target(input)
58
+
59
+
60
+
61
+ class MetricLogger:
62
+ def __init__(self, save_dir):
63
+ self.writer = SummaryWriter(log_dir=save_dir)
64
+ self.save_log = save_dir / "log"
65
+ with open(self.save_log, "w") as f:
66
+ f.write(
67
+ f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
68
+ f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
69
+ f"{'TimeDelta':>15}{'Time':>20}\n"
70
+ )
71
+ self.ep_rewards_plot = save_dir / "reward_plot.jpg"
72
+ self.ep_lengths_plot = save_dir / "length_plot.jpg"
73
+ self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
74
+ self.ep_avg_qs_plot = save_dir / "q_plot.jpg"
75
+
76
+ # History metrics
77
+ self.ep_rewards = []
78
+ self.ep_lengths = []
79
+ self.ep_avg_losses = []
80
+ self.ep_avg_qs = []
81
+
82
+ # Moving averages, added for every call to record()
83
+ self.moving_avg_ep_rewards = []
84
+ self.moving_avg_ep_lengths = []
85
+ self.moving_avg_ep_avg_losses = []
86
+ self.moving_avg_ep_avg_qs = []
87
+
88
+ # Current episode metric
89
+ self.init_episode()
90
+
91
+ # Timing
92
+ self.record_time = time.time()
93
+
94
+ def log_step(self, reward, loss, q):
95
+ self.curr_ep_reward += reward
96
+ self.curr_ep_length += 1
97
+ if loss:
98
+ self.curr_ep_loss += loss
99
+ self.curr_ep_q += q
100
+ self.curr_ep_loss_length += 1
101
+
102
+ def log_episode(self, episode_number):
103
+ "Mark end of episode"
104
+ self.ep_rewards.append(self.curr_ep_reward)
105
+ self.ep_lengths.append(self.curr_ep_length)
106
+ if self.curr_ep_loss_length == 0:
107
+ ep_avg_loss = 0
108
+ ep_avg_q = 0
109
+ else:
110
+ ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
111
+ ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
112
+ self.ep_avg_losses.append(ep_avg_loss)
113
+ self.ep_avg_qs.append(ep_avg_q)
114
+ self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
115
+ self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
116
+ self.writer.flush()
117
+ self.init_episode()
118
+
119
+ def init_episode(self):
120
+ self.curr_ep_reward = 0.0
121
+ self.curr_ep_length = 0
122
+ self.curr_ep_loss = 0.0
123
+ self.curr_ep_q = 0.0
124
+ self.curr_ep_loss_length = 0
125
+
126
+ def record(self, episode, epsilon, step):
127
+ mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
128
+ mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
129
+ mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
130
+ mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
131
+ self.moving_avg_ep_rewards.append(mean_ep_reward)
132
+ self.moving_avg_ep_lengths.append(mean_ep_length)
133
+ self.moving_avg_ep_avg_losses.append(mean_ep_loss)
134
+ self.moving_avg_ep_avg_qs.append(mean_ep_q)
135
+
136
+ last_record_time = self.record_time
137
+ self.record_time = time.time()
138
+ time_since_last_record = np.round(self.record_time - last_record_time, 3)
139
+
140
+ print(
141
+ f"Episode {episode} - "
142
+ f"Step {step} - "
143
+ f"Epsilon {epsilon} - "
144
+ f"Mean Reward {mean_ep_reward} - "
145
+ f"Mean Length {mean_ep_length} - "
146
+ f"Mean Loss {mean_ep_loss} - "
147
+ f"Mean Q Value {mean_ep_q} - "
148
+ f"Time Delta {time_since_last_record} - "
149
+ f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
150
+ )
151
+ self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
152
+ self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
153
+ self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
155
+ self.writer.add_scalar("Epsilon value", epsilon, episode)
156
+ self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
157
+ self.writer.flush()
158
+ with open(self.save_log, "a") as f:
159
+ f.write(
160
+ f"{episode:8d}{step:8d}{epsilon:10.3f}"
161
+ f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
162
+ f"{time_since_last_record:15.3f}"
163
+ f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
164
+ )
165
+
166
+ for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
167
+ plt.plot(getattr(self, f"moving_avg_{metric}"))
168
+ plt.savefig(getattr(self, f"{metric}_plot"))
169
+ plt.clf()
170
+
171
+
172
+ class DQNAgent:
173
+ def __init__(self,
174
+ state_dim,
175
+ action_dim,
176
+ save_dir,
177
+ checkpoint=None,
178
+ learning_rate=0.00025,
179
+ max_memory_size=100000,
180
+ batch_size=32,
181
+ exploration_rate=1,
182
+ exploration_rate_decay=0.9999999,
183
+ exploration_rate_min=0.1,
184
+ training_frequency=1,
185
+ learning_starts=1000,
186
+ target_network_sync_frequency=500,
187
+ reset_exploration_rate=False,
188
+ save_frequency=100000,
189
+ gamma=0.9,
190
+ load_replay_buffer=True):
191
+ self.state_dim = state_dim
192
+ self.action_dim = action_dim
193
+ self.max_memory_size = max_memory_size
194
+ self.memory = deque(maxlen=max_memory_size)
195
+ self.batch_size = batch_size
196
+
197
+ self.exploration_rate = exploration_rate
198
+ self.exploration_rate_decay = exploration_rate_decay
199
+ self.exploration_rate_min = exploration_rate_min
200
+ self.gamma = gamma
201
+
202
+ self.curr_step = 0
203
+ self.learning_starts = learning_starts # min. experiences before training
204
+
205
+ self.training_frequency = training_frequency # no. of experiences between updates to Q_online
206
+ self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
207
+
208
+ self.save_every = save_frequency # no. of experiences between saving the network
209
+ self.save_dir = save_dir
210
+
211
+ self.use_cuda = torch.cuda.is_available()
212
+
213
+ # The agent's DQN used to predict Q-values for action selection; it is updated in learn()
214
+ self.net = DQNet(self.state_dim, self.action_dim).float()
215
+ if self.use_cuda:
216
+ self.net = self.net.to(device='cuda')
217
+ if checkpoint:
218
+ self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
219
+
220
+ self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
221
+ self.loss_fn = torch.nn.SmoothL1Loss()
222
+
223
+
224
+ def act(self, state):
225
+ """
226
+ Given a state, choose an epsilon-greedy action and update value of step.
227
+
228
+ Inputs:
229
+ state(LazyFrame): A single observation of the current state, dimension is (state_dim)
230
+ Outputs:
231
+ action_idx (int): An integer representing which action the agent will perform
232
+ """
233
+ # EXPLORE
234
+ if np.random.rand() < self.exploration_rate:
235
+ action_idx = np.random.randint(self.action_dim)
236
+
237
+ # EXPLOIT
238
+ else:
239
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
240
+ state = state.unsqueeze(0)
241
+ action_values = self.net(state, model='online')
242
+ action_idx = torch.argmax(action_values, axis=1).item()
243
+
244
+ # decrease exploration_rate
245
+ self.exploration_rate *= self.exploration_rate_decay
246
+ self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
247
+
248
+ # increment step
249
+ self.curr_step += 1
250
+ return action_idx
251
+
252
+ def cache(self, state, next_state, action, reward, done):
253
+ """
254
+ Store the experience to self.memory (replay buffer)
255
+
256
+ Inputs:
257
+ state (LazyFrame),
258
+ next_state (LazyFrame),
259
+ action (int),
260
+ reward (float),
261
+ done(bool))
262
+ """
263
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
264
+ next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
265
+ action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
266
+ reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
267
+ done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
268
+
269
+ self.memory.append( (state, next_state, action, reward, done,) )
270
+
271
+
272
+ def recall(self):
273
+ """
274
+ Retrieve a batch of experiences from memory
275
+ """
276
+ batch = random.sample(self.memory, self.batch_size)
277
+ state, next_state, action, reward, done = map(torch.stack, zip(*batch))
278
+ return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
279
+
280
+
281
+ # def td_estimate(self, state, action):
282
+ # current_Q = self.net(state, model='online')[np.arange(0, self.batch_size), action] # Q_online(s,a)
283
+ # return current_Q
284
+
285
+
286
+ # @torch.no_grad()
287
+ # def td_target(self, reward, next_state, done):
288
+ # next_state_Q = self.net(next_state, model='online')
289
+ # best_action = torch.argmax(next_state_Q, axis=1)
290
+ # next_Q = self.net(next_state, model='target')[np.arange(0, self.batch_size), best_action]
291
+ # return (reward + (1 - done.float()) * self.gamma * next_Q).float()
292
+
293
+ def td_estimate(self, states, actions):
294
+ actions = actions.reshape(-1, 1)
295
+ predicted_qs = self.net(states, model='online')# Q_online(s,a)
296
+ predicted_qs = predicted_qs.gather(1, actions)
297
+ return predicted_qs
298
+
299
+
300
+ @torch.no_grad()
301
+ def td_target(self, rewards, next_states, dones):
302
+ rewards = rewards.reshape(-1, 1)
303
+ dones = dones.reshape(-1, 1)
304
+ target_qs = self.net(next_states, model='target')
305
+ target_qs = torch.max(target_qs, dim=1).values
306
+ target_qs = target_qs.reshape(-1, 1)
307
+ target_qs[dones] = 0.0
308
+ return (rewards + (self.gamma * target_qs))
309
+
310
+ def update_Q_online(self, td_estimate, td_target):
311
+ loss = self.loss_fn(td_estimate, td_target)
312
+ self.optimizer.zero_grad()
313
+ loss.backward()
314
+ self.optimizer.step()
315
+ return loss.item()
316
+
317
+
318
+ def sync_Q_target(self):
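+ # Hard update: copy the online network's weights into the frozen target network.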
319
+ self.net.target.load_state_dict(self.net.online.state_dict())
320
+
321
+
322
+ def learn(self):
323
+ if self.curr_step % self.target_network_sync_frequency == 0:
324
+ self.sync_Q_target()
325
+
326
+ if self.curr_step % self.save_every == 0:
327
+ self.save()
328
+
329
+ if self.curr_step < self.learning_starts:
330
+ return None, None
331
+
332
+ if self.curr_step % self.training_frequency != 0:
333
+ return None, None
334
+
335
+ # Sample from memory
336
+ state, next_state, action, reward, done = self.recall()
337
+
338
+ # Get TD Estimate
339
+ td_est = self.td_estimate(state, action)
340
+
341
+ # Get TD Target
342
+ td_tgt = self.td_target(reward, next_state, done)
343
+
344
+ # Backpropagate loss through Q_online
345
+ loss = self.update_Q_online(td_est, td_tgt)
346
+
347
+ return (td_est.mean().item(), loss)
348
+
349
+
350
+ def save(self):
351
+ save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
352
+ torch.save(
353
+ dict(
354
+ model=self.net.state_dict(),
355
+ exploration_rate=self.exploration_rate,
356
+ replay_memory=self.memory
357
+ ),
358
+ save_path
359
+ )
360
+
361
+ print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
362
+
363
+
364
+ def load(self, load_path, reset_exploration_rate, load_replay_buffer):
365
+ if not load_path.exists():
366
+ raise ValueError(f"{load_path} does not exist")
367
+
368
+ ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
369
+ exploration_rate = ckp.get('exploration_rate')
370
+ state_dict = ckp.get('model')
371
+
372
+
373
+ print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
374
+ self.net.load_state_dict(state_dict)
375
+
376
+ if load_replay_buffer:
377
+ replay_memory = ckp.get('replay_memory')
378
+ print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
379
+ self.memory = replay_memory if replay_memory else self.memory
380
+
381
+ if reset_exploration_rate:
382
+ print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
383
+ else:
384
+ print(f"Setting exploration rate to {exploration_rate} not loaded.")
385
+ self.exploration_rate = exploration_rate
386
+
387
+
388
+ class DDQNAgent(DQNAgent):
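+ # Same agent as DQNAgent except for the TD target: the online network picks argmax_a Q_online(s', a) and the target network evaluates that action (Double DQN).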
389
+ @torch.no_grad()
390
+ def td_target(self, rewards, next_states, dones):
391
+ print("Double dqn -----------------------")
392
+ rewards = rewards.reshape(-1, 1)
393
+ dones = dones.reshape(-1, 1)
394
+ q_vals = self.net(next_states, model='online')
395
+ target_actions = torch.argmax(q_vals, axis=1)
396
+ target_actions = target_actions.reshape(-1, 1)
397
+ target_qs = self.net(next_states, model='target').gather(1, target_actions)
398
+ target_qs = target_qs.reshape(-1, 1)
399
+ target_qs[dones] = 0.0
400
+ return (rewards + (self.gamma * target_qs))
src/airstriker-genesis/replay.py ADDED
@@ -0,0 +1,66 @@
1
+ import datetime
2
+ from pathlib import Path
3
+ from itertools import count
4
+ from agent import DQNAgent, MetricLogger
5
+ from wrappers import make_env, make_starpilot
6
+
7
+
8
+ env = make_starpilot()
9
+
10
+ env.reset()
11
+
12
+ save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
13
+ save_dir.mkdir(parents=True)
14
+
15
+ checkpoint = Path('checkpoints/procgen-starpilot-dqn/airstriker_net_3.chkpt')
16
+
17
+ agent = DQNAgent(
18
+ state_dim=(1, 64, 64),
19
+ action_dim=env.action_space.n,
20
+ save_dir=save_dir,
21
+ batch_size=256,
22
+ checkpoint=checkpoint,
23
+ reset_exploration_rate=True,
24
+ exploration_rate_decay=0.999999,
25
+ training_frequency=10,
26
+ target_network_sync_frequency=200,
27
+ max_memory_size=3000,
28
+ learning_rate=0.001,
29
+ save_frequency=2000
30
+
31
+ )
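+ # Evaluation run: pin epsilon to its minimum so the loaded policy acts (almost) greedily.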
32
+ agent.exploration_rate = agent.exploration_rate_min
33
+
34
+ # logger = MetricLogger(save_dir)
35
+
36
+ episodes = 100
37
+
38
+ for e in range(episodes):
39
+
40
+ state = env.reset()
41
+
42
+ while True:
43
+
44
+ env.render()
45
+
46
+ action = agent.act(state)
47
+
48
+ next_state, reward, done, info = env.step(action)
49
+
50
+ agent.cache(state, next_state, action, reward, done)
51
+
52
+ # logger.log_step(reward, None, None)
53
+
54
+ state = next_state
55
+
56
+ if done:
57
+ break
58
+
59
+ # logger.log_episode()
60
+
61
+ # if e % 20 == 0:
62
+ # logger.record(
63
+ # episode=e,
64
+ # epsilon=agent.exploration_rate,
65
+ # step=agent.curr_step
66
+ # )
src/airstriker-genesis/run-airstriker-ddqn.py ADDED
@@ -0,0 +1,120 @@
1
+ import os
2
+ import torch
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+
6
+ from pathlib import Path
7
+ from tqdm import trange
8
+ from agent import DQNAgent, DDQNAgent, MetricLogger
9
+ from wrappers import make_env
10
+
11
+
12
+ # set up matplotlib
13
+ is_ipython = 'inline' in matplotlib.get_backend()
14
+ if is_ipython:
15
+ from IPython import display
16
+
17
+ plt.ion()
18
+
19
+
20
+ env = make_env()
21
+
22
+ use_cuda = torch.cuda.is_available()
23
+ print(f"Using CUDA: {use_cuda}\n")
24
+
25
+
26
+ checkpoint = None
27
+ # checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
28
+
29
+ path = "checkpoints/airstriker-ddqn"
30
+ save_dir = Path(path)
31
+
32
+ isExist = os.path.exists(path)
33
+ if not isExist:
34
+ os.makedirs(path)
35
+
36
+ # Vanilla DQN
37
+ print("Training Vanilla DQN Agent!")
38
+ # agent = DQNAgent(
39
+ # state_dim=(1, 84, 84),
40
+ # action_dim=env.action_space.n,
41
+ # save_dir=save_dir,
42
+ # batch_size=128,
43
+ # checkpoint=checkpoint,
44
+ # exploration_rate_decay=0.995,
45
+ # exploration_rate_min=0.05,
46
+ # training_frequency=1,
47
+ # target_network_sync_frequency=500,
48
+ # max_memory_size=50000,
49
+ # learning_rate=0.0005,
50
+
51
+ # )
52
+
53
+ # Double DQN
54
+ print("Training DDQN Agent!")
55
+ agent = DDQNAgent(
56
+ state_dim=(1, 84, 84),
57
+ action_dim=env.action_space.n,
58
+ save_dir=save_dir,
59
+ batch_size=128,
60
+ checkpoint=checkpoint,
61
+ exploration_rate_decay=0.995,
62
+ exploration_rate_min=0.05,
63
+ training_frequency=1,
64
+ target_network_sync_frequency=500,
65
+ max_memory_size=50000,
66
+ learning_rate=0.0005,
67
+ )
68
+
69
+ logger = MetricLogger(save_dir)
70
+
71
+ def fill_memory(agent: DQNAgent, num_episodes=1000):
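+ # Warm-up: pre-fill the replay buffer with epsilon-greedy experience before training starts.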
72
+ print("Filling up memory....")
73
+ for _ in trange(num_episodes):
74
+ state = env.reset()
75
+ done = False
76
+ while not done:
77
+ action = agent.act(state)
78
+ next_state, reward, done, _ = env.step(action)
79
+ agent.cache(state, next_state, action, reward, done)
80
+ state = next_state
81
+
82
+
83
+ def train(agent: DQNAgent):
84
+ episodes = 10000000
85
+ for e in range(episodes):
86
+
87
+ state = env.reset()
88
+ # Play the game!
89
+ while True:
90
+
91
+ # print(state.shape)
92
+ # Run agent on the state
93
+ action = agent.act(state)
94
+
95
+ # Agent performs action
96
+ next_state, reward, done, info = env.step(action)
97
+
98
+ # Remember
99
+ agent.cache(state, next_state, action, reward, done)
100
+
101
+ # Learn
102
+ q, loss = agent.learn()
103
+
104
+ # Logging
105
+ logger.log_step(reward, loss, q)
106
+
107
+ # Update state
108
+ state = next_state
109
+
110
+ # Check if end of game
111
+ if done or info["gameover"] == 1:
112
+ break
113
+
114
+ logger.log_episode(e)
115
+
116
+ if e % 20 == 0:
117
+ logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
118
+
119
+ fill_memory(agent)
120
+ train(agent)
src/airstriker-genesis/run-airstriker-dqn.py ADDED
@@ -0,0 +1,115 @@
1
+ import os
2
+ import torch
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+
6
+ from pathlib import Path
7
+ from tqdm import trange
8
+ from agent import DQNAgent, DDQNAgent, MetricLogger
9
+ from wrappers import make_env
10
+
11
+
12
+ # set up matplotlib
13
+ is_ipython = 'inline' in matplotlib.get_backend()
14
+ if is_ipython:
15
+ from IPython import display
16
+
17
+ plt.ion()
18
+
19
+
20
+ env = make_env()
21
+
22
+ use_cuda = torch.cuda.is_available()
23
+ print(f"Using CUDA: {use_cuda}\n")
24
+
25
+
26
+ checkpoint = None
27
+ # checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
28
+
29
+ path = "checkpoints/airstriker-dqn-new"
30
+ save_dir = Path(path)
31
+
32
+ isExist = os.path.exists(path)
33
+ if not isExist:
34
+ os.makedirs(path)
35
+
36
+ # Vanilla DQN
37
+ print("Training Vanilla DQN Agent!")
38
+ agent = DQNAgent(
39
+ state_dim=(1, 84, 84),
40
+ action_dim=env.action_space.n,
41
+ save_dir=save_dir,
42
+ batch_size=128,
43
+ checkpoint=checkpoint,
44
+ exploration_rate_decay=0.995,
45
+ exploration_rate_min=0.05,
46
+ training_frequency=1,
47
+ target_network_sync_frequency=500,
48
+ max_memory_size=50000,
49
+ learning_rate=0.0005,
50
+
51
+ )
52
+
53
+ # Double DQN
54
+ # print("Training DDQN Agent!")
55
+ # agent = DDQNAgent(
56
+ # state_dim=(1, 84, 84),
57
+ # action_dim=env.action_space.n,
58
+ # save_dir=save_dir,
59
+ # checkpoint=checkpoint,
60
+ # reset_exploration_rate=True,
61
+ # max_memory_size=max_memory_size
62
+ # )
63
+
64
+ logger = MetricLogger(save_dir)
65
+
66
+ def fill_memory(agent: DQNAgent, num_episodes=1000):
67
+ print("Filling up memory....")
68
+ for _ in trange(num_episodes):
69
+ state = env.reset()
70
+ done = False
71
+ while not done:
72
+ action = agent.act(state)
73
+ next_state, reward, done, _ = env.step(action)
74
+ agent.cache(state, next_state, action, reward, done)
75
+ state = next_state
76
+
77
+
78
+ def train(agent: DQNAgent):
79
+ episodes = 10000000
80
+ for e in range(episodes):
81
+
82
+ state = env.reset()
83
+ # Play the game!
84
+ while True:
85
+
86
+ # print(state.shape)
87
+ # Run agent on the state
88
+ action = agent.act(state)
89
+
90
+ # Agent performs action
91
+ next_state, reward, done, info = env.step(action)
92
+
93
+ # Remember
94
+ agent.cache(state, next_state, action, reward, done)
95
+
96
+ # Learn
97
+ q, loss = agent.learn()
98
+
99
+ # Logging
100
+ logger.log_step(reward, loss, q)
101
+
102
+ # Update state
103
+ state = next_state
104
+
105
+ # Check if end of game
106
+ if done or info["gameover"] == 1:
107
+ break
108
+
109
+ logger.log_episode(e)
110
+
111
+ if e % 20 == 0:
112
+ logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
113
+
114
+ fill_memory(agent)
115
+ train(agent)
src/airstriker-genesis/run-cartpole.py ADDED
@@ -0,0 +1,120 @@
1
+ import os
2
+ import random, datetime
3
+ from pathlib import Path
4
+ import retro
5
+ from collections import namedtuple, deque
6
+ from itertools import count
7
+
8
+ import torch
9
+ import matplotlib
10
+ import matplotlib.pyplot as plt
11
+ # from agent import MyAgent, MyDQN, MetricLogger
12
+ from cartpole import MyAgent, MetricLogger
13
+ from wrappers import make_env
14
+ import pickle
15
+ import gym
16
+ from tqdm import trange
17
+
18
+ # set up matplotlib
19
+ is_ipython = 'inline' in matplotlib.get_backend()
20
+ if is_ipython:
21
+ from IPython import display
22
+
23
+ plt.ion()
24
+
25
+
26
+ # env = make_env()
27
+ env = gym.make('CartPole-v1')
28
+
29
+ use_cuda = torch.cuda.is_available()
30
+ print(f"Using CUDA: {use_cuda}")
31
+ print()
32
+
33
+ path = "checkpoints/cartpole/latest"
34
+ save_dir = Path(path)
35
+
36
+ isExist = os.path.exists(path)
37
+ if not isExist:
38
+ os.makedirs(path)
39
+
40
+ # save_dir.mkdir(parents=True)
41
+
42
+
43
+ checkpoint = None
44
+ # checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
45
+
46
+ # For cartpole
47
+ n_actions = env.action_space.n
48
+ state = env.reset()
49
+ n_observations = len(state)
50
+ max_memory_size=100000
51
+ agent = MyAgent(
52
+ state_dim=n_observations,
53
+ action_dim=n_actions,
54
+ save_dir=save_dir,
55
+ checkpoint=checkpoint,
56
+ reset_exploration_rate=True,
57
+ max_memory_size=max_memory_size
58
+ )
59
+
60
+ # For airstriker
61
+ # agent = MyAgent(state_dim=(1, 84, 84), action_dim=env.action_space.n, save_dir=save_dir, checkpoint=checkpoint, reset_exploration_rate=True)
62
+
63
+
64
+ logger = MetricLogger(save_dir)
65
+
66
+
67
+
68
+ def fill_memory(agent: MyAgent):
69
+ print("Filling up memory....")
70
+ for _ in trange(max_memory_size):
71
+ state = env.reset()
72
+ done = False
73
+ while not done:
74
+ action = agent.act(state)
75
+ next_state, reward, done, info = env.step(action)
76
+ agent.cache(state, next_state, action, reward, done)
77
+ state = next_state
78
+
79
+ def train(agent: MyAgent):
80
+ episodes = 10000000
81
+ for e in range(episodes):
82
+
83
+ state = env.reset()
84
+ # Play the game!
85
+ while True:
86
+
87
+ # print(state.shape)
88
+ # Run agent on the state
89
+ action = agent.act(state)
90
+
91
+ # Agent performs action
92
+ next_state, reward, done, info = env.step(action)
93
+
94
+ # Remember
95
+ agent.cache(state, next_state, action, reward, done)
96
+
97
+ # Learn
98
+ q, loss = agent.learn()
99
+
100
+ # Logging
101
+ logger.log_step(reward, loss, q)
102
+
103
+ # Update state
104
+ state = next_state
105
+
106
+ # # Check if end of game (for airstriker)
107
+ # if done or info["gameover"] == 1:
108
+ # break
109
+ # Check if end of game (for cartpole)
110
+ if done:
111
+ break
112
+
113
+ logger.log_episode(e)
114
+
115
+ if e % 20 == 0:
116
+ logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
117
+
118
+
119
+ fill_memory(agent)
120
+ train(agent)
src/airstriker-genesis/test.py ADDED
@@ -0,0 +1,405 @@
1
+ import retro
2
+ import gym
3
+ import math
4
+ import random
5
+ import numpy as np
6
+ import matplotlib
7
+ import matplotlib.pyplot as plt
8
+ from collections import namedtuple, deque
9
+ from itertools import count
10
+ from gym import spaces
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.optim as optim
15
+ import torch.nn.functional as F
16
+ import cv2
18
+ from torch.utils.tensorboard import SummaryWriter
19
+
20
+
21
+ class MaxAndSkipEnv(gym.Wrapper):
22
+ def __init__(self, env, skip=4):
23
+ """Return only every `skip`-th frame"""
24
+ gym.Wrapper.__init__(self, env)
25
+ # most recent raw observations (for max pooling across time steps)
26
+ self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
27
+ self._skip = skip
28
+
29
+ def step(self, action):
30
+ """Repeat action, sum reward, and max over last observations."""
31
+ total_reward = 0.0
32
+ done = None
33
+ for i in range(self._skip):
34
+ obs, reward, done, info = self.env.step(action)
35
+ if i == self._skip - 2: self._obs_buffer[0] = obs
36
+ if i == self._skip - 1: self._obs_buffer[1] = obs
37
+ total_reward += reward
38
+ if done:
39
+ break
40
+ # Note that the observation on the done=True frame
41
+ # doesn't matter
42
+ max_frame = self._obs_buffer.max(axis=0)
43
+
44
+ return max_frame, total_reward, done, info
45
+
46
+ def reset(self, **kwargs):
47
+ return self.env.reset(**kwargs)
48
+
49
+
50
+ class LazyFrames(object):
51
+ def __init__(self, frames):
52
+ """This object ensures that common frames between the observations are only stored once.
53
+ It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
54
+ buffers.
55
+ This object should only be converted to numpy array before being passed to the model.
56
+ You'd not believe how complex the previous solution was."""
57
+ self._frames = frames
58
+ self._out = None
59
+
60
+ def _force(self):
61
+ if self._out is None:
62
+ self._out = np.concatenate(self._frames, axis=2)
63
+ self._frames = None
64
+ return self._out
65
+
66
+ def __array__(self, dtype=None):
67
+ out = self._force()
68
+ if dtype is not None:
69
+ out = out.astype(dtype)
70
+ return out
71
+
72
+ def __len__(self):
73
+ return len(self._force())
74
+
75
+ def __getitem__(self, i):
76
+ return self._force()[i]
77
+
78
+
79
+ class FrameStack(gym.Wrapper):
80
+ def __init__(self, env, k):
81
+ """Stack k last frames.
82
+ Returns lazy array, which is much more memory efficient.
83
+ See Also
84
+ --------
85
+ baselines.common.atari_wrappers.LazyFrames
86
+ """
87
+ gym.Wrapper.__init__(self, env)
88
+ self.k = k
89
+ self.frames = deque([], maxlen=k)
90
+ shp = env.observation_space.shape
91
+ self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)
92
+
93
+ def reset(self):
94
+ ob = self.env.reset()
95
+ for _ in range(self.k):
96
+ self.frames.append(ob)
97
+ return self._get_ob()
98
+
99
+ def step(self, action):
100
+ ob, reward, done, info = self.env.step(action)
101
+ self.frames.append(ob)
102
+ return self._get_ob(), reward, done, info
103
+
104
+ def _get_ob(self):
105
+ assert len(self.frames) == self.k
106
+ return LazyFrames(list(self.frames))
107
+
108
+ class ClipRewardEnv(gym.RewardWrapper):
109
+ def __init__(self, env):
110
+ gym.RewardWrapper.__init__(self, env)
111
+
112
+ def reward(self, reward):
113
+ """Bin reward to {+1, 0, -1} by its sign."""
114
+ return np.sign(reward)
115
+
116
+
117
+ class ImageToPyTorch(gym.ObservationWrapper):
118
+ def __init__(self, env):
119
+ super(ImageToPyTorch, self).__init__(env)
120
+ old_shape = self.observation_space.shape
121
+ self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32)
122
+
123
+ def observation(self, observation):
124
+ return np.moveaxis(observation, 2, 0)
125
+
126
+
127
+ class WarpFrame(gym.ObservationWrapper):
128
+ def __init__(self, env):
129
+ """Warp frames to 84x84 as done in the Nature paper and later work."""
130
+ gym.ObservationWrapper.__init__(self, env)
131
+ self.width = 84
132
+ self.height = 84
133
+ self.observation_space = spaces.Box(low=0, high=255,
134
+ shape=(self.height, self.width, 1), dtype=np.uint8)
135
+
136
+ def observation(self, frame):
137
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
138
+ frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
139
+ return frame[:, :, None]
140
+
141
+ class AirstrikerDiscretizer(gym.ActionWrapper):
142
+ # Initialization
143
+ def __init__(self, env):
144
+ super(AirstrikerDiscretizer, self).__init__(env)
145
+ buttons = ['B', 'A', 'MODE', 'START', 'UP', 'DOWN', 'LEFT', 'RIGHT', 'C', 'Y', 'X', 'Z']
146
+ actions = [['LEFT'], ['RIGHT'], ['B']]
147
+ self._actions = []
148
+ for action in actions:
149
+ arr = np.array([False] * 12)
150
+ for button in action:
151
+ arr[buttons.index(button)] = True
152
+ self._actions.append(arr)
153
+ self.action_space = gym.spaces.Discrete(len(self._actions))
154
+
155
+ # Get the action
156
+ def action(self, a):
157
+ return self._actions[a].copy()
158
+
159
+
160
+ env = retro.make(game='Airstriker-Genesis')
161
+ env = MaxAndSkipEnv(env) ## Return only every `skip`-th frame
162
+ env = WarpFrame(env) ## Reshape image
163
+ env = ImageToPyTorch(env) ## Invert shape
164
+ env = FrameStack(env, 4) ## Stack last 4 frames
165
+ # env = ScaledFloatFrame(env) ## Scale frames
166
+ env = AirstrikerDiscretizer(env)
167
+ env = ClipRewardEnv(env)
168
+
169
+ # set up matplotlib
170
+ is_ipython = 'inline' in matplotlib.get_backend()
171
+ if is_ipython:
172
+ from IPython import display
173
+
174
+ plt.ion()
175
+
176
+ # if gpu is to be used
177
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
178
+
179
+ Transition = namedtuple('Transition',
180
+ ('state', 'action', 'next_state', 'reward'))
181
+
182
+
183
+ class ReplayMemory(object):
184
+
185
+ def __init__(self, capacity):
186
+ self.memory = deque([],maxlen=capacity)
187
+
188
+ def push(self, *args):
189
+ """Save a transition"""
190
+ self.memory.append(Transition(*args))
191
+
192
+ def sample(self, batch_size):
193
+ return random.sample(self.memory, batch_size)
194
+
195
+ def __len__(self):
196
+ return len(self.memory)
197
+
198
+
199
+ class DQN(nn.Module):
200
+
201
+ def __init__(self, n_observations, n_actions):
202
+ super(DQN, self).__init__()
203
+ # self.layer1 = nn.Linear(n_observations, 128)
204
+ # self.layer2 = nn.Linear(128, 128)
205
+ # self.layer3 = nn.Linear(128, n_actions)
206
+
207
+ self.layer1 = nn.Conv2d(in_channels=n_observations, out_channels=32, kernel_size=8, stride=4)
208
+ self.layer2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
209
+ self.layer3 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1), nn.ReLU(), nn.Flatten())
210
+ self.layer4 = nn.Linear(17024, 512)
211
+ self.layer5 = nn.Linear(512, n_actions)
212
+
213
+ # Called with either one element to determine next action, or a batch
214
+ # during optimization. Returns tensor([[left0exp,right0exp]...]).
215
+ def forward(self, x):
216
+ x = F.relu(self.layer1(x))
217
+ x = F.relu(self.layer2(x))
218
+ x = F.relu(self.layer3(x))
219
+ x = F.relu(self.layer4(x))
220
+ return self.layer5(x)
221
+
222
+
223
+ # BATCH_SIZE is the number of transitions sampled from the replay buffer
224
+ # GAMMA is the discount factor as mentioned in the previous section
225
+ # EPS_START is the starting value of epsilon
226
+ # EPS_END is the final value of epsilon
227
+ # EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
228
+ # TAU is the update rate of the target network
229
+ # LR is the learning rate of the AdamW optimizer
230
+ BATCH_SIZE = 512
231
+ GAMMA = 0.99
232
+ EPS_START = 1
233
+ EPS_END = 0.01
234
+ EPS_DECAY = 10000
235
+ TAU = 0.005
236
+ # LR = 1e-4
237
+ LR = 0.00025
238
+
239
+ # Get number of actions from gym action space
240
+ n_actions = env.action_space.n
241
+ state = env.reset()
242
+ n_observations = len(state)
243
+
244
+ policy_net = DQN(n_observations, n_actions).to(device)
245
+ target_net = DQN(n_observations, n_actions).to(device)
246
+ target_net.load_state_dict(policy_net.state_dict())
247
+
248
+ optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
249
+ memory = ReplayMemory(10000)
250
+
251
+
252
+ steps_done = 0
253
+
254
+
255
+ def select_action(state):
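+ # Epsilon-greedy with exponential decay: epsilon anneals from EPS_START towards EPS_END with a time constant of EPS_DECAY steps.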
256
+ global steps_done
257
+ sample = random.random()
258
+ eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
259
+ steps_done += 1
260
+ if sample > eps_threshold:
261
+ with torch.no_grad():
262
+ # t.max(1) will return largest column value of each row.
263
+ # second column on max result is index of where max element was
264
+ # found, so we pick action with the larger expected reward.
265
+ return policy_net(state).max(1)[1].view(1, 1), eps_threshold
266
+ else:
267
+ return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long), eps_threshold
268
+
269
+
270
+ episode_durations = []
271
+
272
+
273
+ def plot_durations(show_result=False):
274
+ plt.figure(1)
275
+ durations_t = torch.tensor(episode_durations, dtype=torch.float)
276
+ if show_result:
277
+ plt.title('Result')
278
+ else:
279
+ plt.clf()
280
+ plt.title('Training...')
281
+ plt.xlabel('Episode')
282
+ plt.ylabel('Duration')
283
+ plt.plot(durations_t.numpy())
284
+ # Take 100 episode averages and plot them too
285
+ if len(durations_t) >= 100:
286
+ means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
287
+ means = torch.cat((torch.zeros(99), means))
288
+ plt.plot(means.numpy())
289
+
290
+ plt.pause(0.001) # pause a bit so that plots are updated
291
+ if is_ipython:
292
+ if not show_result:
293
+ display.display(plt.gcf())
294
+ display.clear_output(wait=True)
295
+ else:
296
+ display.display(plt.gcf())
297
+
298
+
299
+
300
+ def optimize_model():
301
+ if len(memory) < BATCH_SIZE:
302
+ return
303
+ transitions = memory.sample(BATCH_SIZE)
304
+ # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
305
+ # detailed explanation). This converts batch-array of Transitions
306
+ # to Transition of batch-arrays.
307
+ batch = Transition(*zip(*transitions))
308
+
309
+ # Compute a mask of non-final states and concatenate the batch elements
310
+ # (a final state would've been the one after which simulation ended)
311
+ non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
312
+ batch.next_state)), device=device, dtype=torch.bool)
313
+ non_final_next_states = torch.cat([s for s in batch.next_state
314
+ if s is not None])
315
+ state_batch = torch.cat(batch.state)
316
+ action_batch = torch.cat(batch.action)
317
+ reward_batch = torch.cat(batch.reward)
318
+
319
+ # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
320
+ # columns of actions taken. These are the actions which would've been taken
321
+ # for each batch state according to policy_net
322
+ state_action_values = policy_net(state_batch).gather(1, action_batch)
323
+
324
+ # Compute V(s_{t+1}) for all next states.
325
+ # Expected values of actions for non_final_next_states are computed based
326
+ # on the "older" target_net; selecting their best reward with max(1)[0].
327
+ # This is merged based on the mask, such that we'll have either the expected
328
+ # state value or 0 in case the state was final.
329
+ next_state_values = torch.zeros(BATCH_SIZE, device=device)
330
+ with torch.no_grad():
331
+ next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
332
+ # Compute the expected Q values
333
+ expected_state_action_values = (next_state_values * GAMMA) + reward_batch
334
+
335
+ # Compute Huber loss
336
+ criterion = nn.SmoothL1Loss()
337
+ loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
338
+
339
+ # Optimize the model
340
+ optimizer.zero_grad()
341
+ loss.backward()
342
+ # In-place gradient clipping
343
+ torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
344
+ optimizer.step()
345
+
346
+
347
+ with SummaryWriter() as writer:
348
+ if torch.cuda.is_available():
349
+ num_episodes = 600
350
+ else:
351
+ num_episodes = 50
352
+ epsilon = 1
353
+ episode_rewards = []
354
+ for i_episode in range(num_episodes):
355
+
356
+ # Initialize the environment and get it's state
357
+ state = env.reset()
358
+ state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
359
+ episode_reward = 0
360
+ for t in count():
361
+ action, epsilon = select_action(state)
362
+ observation, reward, done, info = env.step(action.item())
363
+ reward = torch.tensor([reward], device=device)
364
+
365
+ done = done or info["gameover"] == 1
366
+ if done:
367
+ episode_durations.append(t + 1)
368
+ print(f"Episode {i_episode} done")
369
+ # plot_durations()
370
+ break
371
+ # if done:
372
+ # next_state = None
373
+ # else:
374
+ # next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
375
+
376
+ next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
377
+
378
+ # Store the transition in memory
379
+ memory.push(state, action, next_state, reward)
380
+ episode_reward += reward
381
+ # Move to the next state
382
+ state = next_state
383
+
384
+ # Perform one step of the optimization (on the policy network)
385
+ optimize_model()
386
+
387
+ # Soft update of the target network's weights
388
+ # θ′ ← τ θ + (1 −τ )θ′
389
+ target_net_state_dict = target_net.state_dict()
390
+ policy_net_state_dict = policy_net.state_dict()
391
+ for key in policy_net_state_dict:
392
+ target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
393
+ target_net.load_state_dict(target_net_state_dict)
394
+ # if done:
395
+ # episode_durations.append(t + 1)
396
+ # # plot_durations()
397
+ # break
398
+ # episode_rewards.append(episode_reward)
399
+ writer.add_scalar("Rewards/Episode", episode_reward, i_episode)
400
+ writer.add_scalar("Epsilon", epsilon, i_episode)
401
+ writer.flush()
402
+ print('Complete')
403
+ plot_durations(show_result=True)
404
+ plt.ioff()
405
+ plt.show()
src/airstriker-genesis/utils.py ADDED
@@ -0,0 +1,22 @@
1
+ import gym
2
+ import numpy as np
3
+
4
+
5
+ # Airstriker wrapper
6
+ class AirstrikerDiscretizer(gym.ActionWrapper):
7
+ # Initialization
8
+ def __init__(self, env):
9
+ super(AirstrikerDiscretizer, self).__init__(env)
10
+ buttons = ['B', 'A', 'MODE', 'START', 'UP', 'DOWN', 'LEFT', 'RIGHT', 'C', 'Y', 'X', 'Z']
11
+ actions = [['LEFT'], ['RIGHT'], ['B']]
12
+ self._actions = []
13
+ for action in actions:
14
+ arr = np.array([False] * 12)
15
+ for button in action:
16
+ arr[buttons.index(button)] = True
17
+ self._actions.append(arr)
18
+ self.action_space = gym.spaces.Discrete(len(self._actions))
19
+
20
+ # Get the action
21
+ def action(self, a):
22
+ return self._actions[a].copy()
src/airstriker-genesis/wrappers.py ADDED
@@ -0,0 +1,213 @@
1
+ import numpy as np
2
+ import os
3
+ from collections import deque
4
+ import gym
5
+ from gym import spaces
6
+ import cv2
7
+ import retro
8
+ from utils import AirstrikerDiscretizer
9
+
10
+
11
+ '''
12
+ Atari Wrapper copied from https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
13
+ '''
14
+
15
+
16
+ class LazyFrames(object):
17
+ def __init__(self, frames):
18
+ """This object ensures that common frames between the observations are only stored once.
19
+ It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
20
+ buffers.
21
+ This object should only be converted to numpy array before being passed to the model.
22
+ You'd not believe how complex the previous solution was."""
23
+ self._frames = frames
24
+ self._out = None
25
+
26
+ def _force(self):
27
+ if self._out is None:
28
+ self._out = np.concatenate(self._frames, axis=2)
29
+ self._frames = None
30
+ return self._out
31
+
32
+ def __array__(self, dtype=None):
33
+ out = self._force()
34
+ if dtype is not None:
35
+ out = out.astype(dtype)
36
+ return out
37
+
38
+ def __len__(self):
39
+ return len(self._force())
40
+
41
+ def __getitem__(self, i):
42
+ return self._force()[i]
43
+
44
+ class FireResetEnv(gym.Wrapper):
45
+ def __init__(self, env):
46
+ """Take action on reset for environments that are fixed until firing."""
47
+ gym.Wrapper.__init__(self, env)
48
+ assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
49
+ assert len(env.unwrapped.get_action_meanings()) >= 3
50
+
51
+ def reset(self, **kwargs):
52
+ self.env.reset(**kwargs)
53
+ obs, _, done, _ = self.env.step(1)
54
+ if done:
55
+ self.env.reset(**kwargs)
56
+ obs, _, done, _ = self.env.step(2)
57
+ if done:
58
+ self.env.reset(**kwargs)
59
+ return obs
60
+
61
+ def step(self, ac):
62
+ return self.env.step(ac)
63
+
64
+
65
+ class MaxAndSkipEnv(gym.Wrapper):
66
+ def __init__(self, env, skip=4):
67
+ """Return only every `skip`-th frame"""
68
+ gym.Wrapper.__init__(self, env)
69
+ # most recent raw observations (for max pooling across time steps)
70
+ self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
71
+ self._skip = skip
72
+
73
+ def step(self, action):
74
+ """Repeat action, sum reward, and max over last observations."""
75
+ total_reward = 0.0
76
+ done = None
77
+ for i in range(self._skip):
78
+ obs, reward, done, info = self.env.step(action)
79
+ if i == self._skip - 2: self._obs_buffer[0] = obs
80
+ if i == self._skip - 1: self._obs_buffer[1] = obs
81
+ total_reward += reward
82
+ if done:
83
+ break
84
+ # Note that the observation on the done=True frame
85
+ # doesn't matter
86
+ max_frame = self._obs_buffer.max(axis=0)
87
+
88
+ return max_frame, total_reward, done, info
89
+
90
+ def reset(self, **kwargs):
91
+ return self.env.reset(**kwargs)
92
+
93
+
94
+
95
+ class WarpFrame(gym.ObservationWrapper):
96
+ def __init__(self, env):
97
+ """Warp frames to 84x84 as done in the Nature paper and later work."""
98
+ gym.ObservationWrapper.__init__(self, env)
99
+ self.width = 84
100
+ self.height = 84
101
+ self.observation_space = spaces.Box(low=0, high=255,
102
+ shape=(self.height, self.width, 1), dtype=np.uint8)
103
+
104
+ def observation(self, frame):
105
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
106
+ frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
107
+ return frame[:, :, None]
108
+
109
+ class WarpFrameNoResize(gym.ObservationWrapper):
110
+ def __init__(self, env):
111
+ """Warp frames to 84x84 as done in the Nature paper and later work."""
112
+ gym.ObservationWrapper.__init__(self, env)
113
+
114
+ def observation(self, frame):
115
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
116
+ # frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
117
+ return frame[:, :, None]
118
+
119
+
120
+
121
+ class FrameStack(gym.Wrapper):
122
+ def __init__(self, env, k):
123
+ """Stack k last frames.
124
+ Returns lazy array, which is much more memory efficient.
125
+ See Also
126
+ --------
127
+ baselines.common.atari_wrappers.LazyFrames
128
+ """
129
+ gym.Wrapper.__init__(self, env)
130
+ self.k = k
131
+ self.frames = deque([], maxlen=k)
132
+ shp = env.observation_space.shape
133
+ self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)
134
+
135
+ def reset(self):
136
+ ob = self.env.reset()
137
+ for _ in range(self.k):
138
+ self.frames.append(ob)
139
+ return self._get_ob()
140
+
141
+ def step(self, action):
142
+ ob, reward, done, info = self.env.step(action)
143
+ self.frames.append(ob)
144
+ return self._get_ob(), reward, done, info
145
+
146
+ def _get_ob(self):
147
+ assert len(self.frames) == self.k
148
+ return LazyFrames(list(self.frames))
149
+
150
+
151
+ class ImageToPyTorch(gym.ObservationWrapper):
152
+ def __init__(self, env):
153
+ super(ImageToPyTorch, self).__init__(env)
154
+ old_shape = self.observation_space.shape
155
+ self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32)
156
+
157
+ def observation(self, observation):
158
+ return np.moveaxis(observation, 2, 0)
159
+
160
+
161
+ # class ImageToPyTorch(gym.ObservationWrapper):
162
+ # def __init__(self, env):
163
+ # super(ImageToPyTorch, self).__init__(env)
164
+ # old_shape = self.observation_space.shape
165
+ # new_shape = (old_shape[-1], old_shape[0], old_shape[1])
166
+ # print("Old: ", old_shape)
167
+ # print("New: ", new_shape)
168
+ # self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=new_shape, dtype=np.float32)
169
+
170
+ # def observation(self, observation):
171
+ # return np.moveaxis(observation, 2, 0)
172
+
173
+
174
+ class ScaledFloatFrame(gym.ObservationWrapper):
175
+ def __init__(self, env):
176
+ gym.ObservationWrapper.__init__(self, env)
177
+ self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)
178
+
179
+ def observation(self, observation):
180
+ # careful! This undoes the memory optimization, use
181
+ # with smaller replay buffers only.
182
+ return np.array(observation).astype(np.float32) / 255.0
183
+
184
+ class ClipRewardEnv(gym.RewardWrapper):
185
+ def __init__(self, env):
186
+ gym.RewardWrapper.__init__(self, env)
187
+
188
+ def reward(self, reward):
189
+ """Bin reward to {+1, 0, -1} by its sign."""
190
+ return np.sign(reward)
191
+
192
+
193
+ def make_env():
194
+
195
+ env = retro.make(game='Airstriker-Genesis')
196
+ env = MaxAndSkipEnv(env) ## Return only every `skip`-th frame
197
+ env = WarpFrame(env) ## Reshape image
198
+ env = ImageToPyTorch(env) ## Invert shape
199
+ env = FrameStack(env, 4) ## Stack last 4 frames
200
+ env = ScaledFloatFrame(env) ## Scale frames
201
+ env = AirstrikerDiscretizer(env)
202
+ env = ClipRewardEnv(env)
203
+ return env
204
+
205
+ def make_starpilot(render=False):
206
+ if render:
207
+ env = gym.make("procgen:procgen-starpilot-v0", distribution_mode="easy", render_mode="human")
208
+ else:
209
+ env = gym.make("procgen:procgen-starpilot-v0", distribution_mode="easy")
210
+ env = WarpFrameNoResize(env) ## Reshape image
211
+ env = ImageToPyTorch(env) ## Invert shape
212
+ env = FrameStack(env, 4) ## Stack last 4 frames
213
+ return env
src/lunar-lander/agent.py ADDED
@@ -0,0 +1,1104 @@
1
+ import torch
2
+ import numpy as np
3
+ import random
4
+ import torch.nn as nn
5
+ import copy
6
+ import time, datetime
7
+ import matplotlib.pyplot as plt
8
+ from collections import deque
9
+ from torch.utils.tensorboard import SummaryWriter
10
+
11
+
12
+ class DQNet(nn.Module):
13
+ """mini cnn structure"""
14
+
15
+ def __init__(self, input_dim, output_dim):
16
+ super().__init__()
17
+
18
+ self.online = nn.Sequential(
19
+ nn.Linear(input_dim, 150),
20
+ nn.ReLU(),
21
+ nn.Linear(150, 120),
22
+ nn.ReLU(),
23
+ nn.Linear(120, output_dim),
24
+ )
25
+
26
+
27
+ self.target = copy.deepcopy(self.online)
28
+
29
+ # Q_target parameters are frozen.
30
+ for p in self.target.parameters():
31
+ p.requires_grad = False
32
+
33
+ def forward(self, input, model):
34
+ if model == "online":
35
+ return self.online(input)
36
+ elif model == "target":
37
+ return self.target(input)
38
+
39
+
40
+
41
+ class MetricLogger:
42
+ def __init__(self, save_dir):
43
+ self.writer = SummaryWriter(log_dir=save_dir)
44
+ self.save_log = save_dir / "log"
45
+ with open(self.save_log, "w") as f:
46
+ f.write(
47
+ f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
48
+ f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
49
+ f"{'TimeDelta':>15}{'Time':>20}\n"
50
+ )
51
+ self.ep_rewards_plot = save_dir / "reward_plot.jpg"
52
+ self.ep_lengths_plot = save_dir / "length_plot.jpg"
53
+ self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
54
+ self.ep_avg_qs_plot = save_dir / "q_plot.jpg"
55
+
56
+ # History metrics
57
+ self.ep_rewards = []
58
+ self.ep_lengths = []
59
+ self.ep_avg_losses = []
60
+ self.ep_avg_qs = []
61
+
62
+ # Moving averages, added for every call to record()
63
+ self.moving_avg_ep_rewards = []
64
+ self.moving_avg_ep_lengths = []
65
+ self.moving_avg_ep_avg_losses = []
66
+ self.moving_avg_ep_avg_qs = []
67
+
68
+ # Current episode metric
69
+ self.init_episode()
70
+
71
+ # Timing
72
+ self.record_time = time.time()
73
+
74
+ def log_step(self, reward, loss, q):
75
+ self.curr_ep_reward += reward
76
+ self.curr_ep_length += 1
77
+ if loss:
78
+ self.curr_ep_loss += loss
79
+ self.curr_ep_q += q
80
+ self.curr_ep_loss_length += 1
81
+
82
+ def log_episode(self, episode_number):
83
+ "Mark end of episode"
84
+ self.ep_rewards.append(self.curr_ep_reward)
85
+ self.ep_lengths.append(self.curr_ep_length)
86
+ if self.curr_ep_loss_length == 0:
87
+ ep_avg_loss = 0
88
+ ep_avg_q = 0
89
+ else:
90
+ ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
91
+ ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
92
+ self.ep_avg_losses.append(ep_avg_loss)
93
+ self.ep_avg_qs.append(ep_avg_q)
94
+ self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
95
+ self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
96
+ self.writer.flush()
97
+ self.init_episode()
98
+
99
+ def init_episode(self):
100
+ self.curr_ep_reward = 0.0
101
+ self.curr_ep_length = 0
102
+ self.curr_ep_loss = 0.0
103
+ self.curr_ep_q = 0.0
104
+ self.curr_ep_loss_length = 0
105
+
106
+ def record(self, episode, epsilon, step):
107
+ mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
108
+ mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
109
+ mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
110
+ mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
111
+ self.moving_avg_ep_rewards.append(mean_ep_reward)
112
+ self.moving_avg_ep_lengths.append(mean_ep_length)
113
+ self.moving_avg_ep_avg_losses.append(mean_ep_loss)
114
+ self.moving_avg_ep_avg_qs.append(mean_ep_q)
115
+
116
+ last_record_time = self.record_time
117
+ self.record_time = time.time()
118
+ time_since_last_record = np.round(self.record_time - last_record_time, 3)
119
+
120
+ print(
121
+ f"Episode {episode} - "
122
+ f"Step {step} - "
123
+ f"Epsilon {epsilon} - "
124
+ f"Mean Reward {mean_ep_reward} - "
125
+ f"Mean Length {mean_ep_length} - "
126
+ f"Mean Loss {mean_ep_loss} - "
127
+ f"Mean Q Value {mean_ep_q} - "
128
+ f"Time Delta {time_since_last_record} - "
129
+ f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
130
+ )
131
+ self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
132
+ self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
133
+ self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
135
+ self.writer.add_scalar("Epsilon value", epsilon, episode)
136
+ self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
137
+ self.writer.flush()
138
+ with open(self.save_log, "a") as f:
139
+ f.write(
140
+ f"{episode:8d}{step:8d}{epsilon:10.3f}"
141
+ f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
142
+ f"{time_since_last_record:15.3f}"
143
+ f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
144
+ )
145
+
146
+ for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
147
+ plt.plot(getattr(self, f"moving_avg_{metric}"))
148
+ plt.savefig(getattr(self, f"{metric}_plot"))
149
+ plt.clf()
150
+
151
+
152
+ class DQNAgent:
153
+ def __init__(self,
154
+ state_dim,
155
+ action_dim,
156
+ save_dir,
157
+ checkpoint=None,
158
+ learning_rate=0.00025,
159
+ max_memory_size=100000,
160
+ batch_size=32,
161
+ exploration_rate=1,
162
+ exploration_rate_decay=0.9999999,
163
+ exploration_rate_min=0.1,
164
+ training_frequency=1,
165
+ learning_starts=1000,
166
+ target_network_sync_frequency=500,
167
+ reset_exploration_rate=False,
168
+ save_frequency=100000,
169
+ gamma=0.9,
170
+ load_replay_buffer=True):
171
+ self.state_dim = state_dim
172
+ self.action_dim = action_dim
173
+ self.max_memory_size = max_memory_size
174
+ self.memory = deque(maxlen=max_memory_size)
175
+ self.batch_size = batch_size
176
+
177
+ self.exploration_rate = exploration_rate
178
+ self.exploration_rate_decay = exploration_rate_decay
179
+ self.exploration_rate_min = exploration_rate_min
180
+ self.gamma = gamma
181
+
182
+ self.curr_step = 0
183
+ self.learning_starts = learning_starts # min. experiences before training
184
+
185
+ self.training_frequency = training_frequency # no. of experiences between updates to Q_online
186
+ self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
187
+
188
+ self.save_every = save_frequency # no. of experiences between saving the network
189
+ self.save_dir = save_dir
190
+
191
+ self.use_cuda = torch.cuda.is_available()
192
+
193
+ self.net = DQNet(self.state_dim, self.action_dim).float()
194
+ if self.use_cuda:
195
+ self.net = self.net.to(device='cuda')
196
+ if checkpoint:
197
+ self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
198
+
199
+ self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
200
+ self.loss_fn = torch.nn.SmoothL1Loss()
201
+ # self.optimizer = torch.optim.Adam(self.net.parameters(), lr=learning_rate)
202
+ # self.loss_fn = torch.nn.MSELoss()
203
+
204
+
205
+ def act(self, state):
206
+ """
207
+ Given a state, choose an epsilon-greedy action and update value of step.
208
+
209
+ Inputs:
210
+ state(LazyFrame): A single observation of the current state, dimension is (state_dim)
211
+ Outputs:
212
+ action_idx (int): An integer representing which action the agent will perform
213
+ """
214
+ # EXPLORE
215
+ if np.random.rand() < self.exploration_rate:
216
+ action_idx = np.random.randint(self.action_dim)
217
+
218
+ # EXPLOIT
219
+ else:
220
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
221
+ state = state.unsqueeze(0)
222
+ action_values = self.net(state, model='online')
223
+ action_idx = torch.argmax(action_values, axis=1).item()
224
+
225
+ # decrease exploration_rate
226
+
227
+ self.exploration_rate *= self.exploration_rate_decay
228
+ self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
229
+
230
+ # increment step
231
+ self.curr_step += 1
232
+ return action_idx
233
+
234
+ def cache(self, state, next_state, action, reward, done):
235
+ """
236
+ Store the experience to self.memory (replay buffer)
237
+
238
+ Inputs:
239
+ state (LazyFrame),
240
+ next_state (LazyFrame),
241
+ action (int),
242
+ reward (float),
243
+ done(bool))
244
+ """
245
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
246
+ next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
247
+ action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
248
+ reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
249
+ done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
250
+
251
+ self.memory.append( (state, next_state, action, reward, done,) )
252
+
253
+
254
+ def recall(self):
255
+ """
256
+ Retrieve a batch of experiences from memory
257
+ """
258
+ batch = random.sample(self.memory, self.batch_size)
259
+ state, next_state, action, reward, done = map(torch.stack, zip(*batch))
260
+ return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
261
+
262
+
263
+ def td_estimate(self, states, actions):
264
+ actions = actions.reshape(-1, 1)
265
+ predicted_qs = self.net(states, model='online')# Q_online(s,a)
266
+ predicted_qs = predicted_qs.gather(1, actions)
267
+ return predicted_qs
268
+
269
+
270
+ @torch.no_grad()
271
+ def td_target(self, rewards, next_states, dones):
272
+ rewards = rewards.reshape(-1, 1)
273
+ dones = dones.reshape(-1, 1)
274
+ target_qs = self.net(next_states, model='target')
275
+ target_qs = torch.max(target_qs, dim=1).values
276
+ target_qs = target_qs.reshape(-1, 1)
277
+ target_qs[dones] = 0.0
278
+ return (rewards + (self.gamma * target_qs))
279
+
280
+ def update_Q_online(self, td_estimate, td_target):
281
+ loss = self.loss_fn(td_estimate.float(), td_target.float())
282
+ self.optimizer.zero_grad()
283
+ loss.backward()
284
+ self.optimizer.step()
285
+ return loss.item()
286
+
287
+
288
+ def sync_Q_target(self):
289
+ self.net.target.load_state_dict(self.net.online.state_dict())
290
+
291
+
292
+ def learn(self):
293
+ if self.curr_step % self.target_network_sync_frequency == 0:
294
+ self.sync_Q_target()
295
+
296
+ if self.curr_step % self.save_every == 0:
297
+ self.save()
298
+
299
+ if self.curr_step < self.learning_starts:
300
+ return None, None
301
+
302
+ if self.curr_step % self.training_frequency != 0:
303
+ return None, None
304
+
305
+ # Sample from memory
306
+ state, next_state, action, reward, done = self.recall()
307
+
308
+ # Get TD Estimate
309
+ td_est = self.td_estimate(state, action)
310
+
311
+ # Get TD Target
312
+ td_tgt = self.td_target(reward, next_state, done)
313
+
314
+ # Backpropagate loss through Q_online
315
+
316
+ loss = self.update_Q_online(td_est, td_tgt)
317
+
318
+ return (td_est.mean().item(), loss)
319
+
320
+
321
+ def save(self):
322
+ save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
323
+ torch.save(
324
+ dict(
325
+ model=self.net.state_dict(),
326
+ exploration_rate=self.exploration_rate,
327
+ replay_memory=self.memory
328
+ ),
329
+ save_path
330
+ )
331
+
332
+ print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
333
+
334
+
335
+ def load(self, load_path, reset_exploration_rate, load_replay_buffer):
336
+ if not load_path.exists():
337
+ raise ValueError(f"{load_path} does not exist")
338
+
339
+ ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
340
+ exploration_rate = ckp.get('exploration_rate')
341
+ state_dict = ckp.get('model')
342
+
343
+
344
+ print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
345
+ self.net.load_state_dict(state_dict)
346
+
347
+ if load_replay_buffer:
348
+ replay_memory = ckp.get('replay_memory')
349
+ print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
350
+ self.memory = replay_memory if replay_memory else self.memory
351
+
352
+ if reset_exploration_rate:
353
+ print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
354
+ else:
355
+ print(f"Setting exploration rate to {exploration_rate} not loaded.")
356
+ self.exploration_rate = exploration_rate
357
+
358
+
359
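+ # Double DQN: reuses DQNAgent, but in the TD target the online network picks the
+ # greedy next action and the target network evaluates it, which reduces the
+ # overestimation bias of vanilla DQN's max operator.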
+ class DDQNAgent(DQNAgent):
360
+ @torch.no_grad()
361
+ def td_target(self, rewards, next_states, dones):
362
+ rewards = rewards.reshape(-1, 1)
363
+ dones = dones.reshape(-1, 1)
364
+ q_vals = self.net(next_states, model='online')
365
+ target_actions = torch.argmax(q_vals, axis=1)
366
+ target_actions = target_actions.reshape(-1, 1)
367
+
368
+ target_qs = self.net(next_states, model='target')
369
+ target_qs = target_qs.gather(1, target_actions)
370
+ target_qs = target_qs.reshape(-1, 1)
371
+ target_qs[dones] = 0.0
372
+ return (rewards + (self.gamma * target_qs))
373
+
374
+
375
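+ # Dueling network: a shared feature trunk feeds two heads, a scalar state value
+ # V(s) and per-action advantages A(s, a), recombined in forward() as
+ # Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).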
+ class DuelingDQNet(nn.Module):
376
+ def __init__(self, input_dim, output_dim):
377
+ super().__init__()
378
+ self.feature_layer = nn.Sequential(
379
+ nn.Linear(input_dim, 150),
380
+ nn.ReLU(),
381
+ nn.Linear(150, 120),
382
+ nn.ReLU()
383
+ )
384
+
385
+ self.value_layer = nn.Sequential(
386
+ nn.Linear(120, 120),
387
+ nn.ReLU(),
388
+ nn.Linear(120, 1)
389
+ )
390
+
391
+ self.advantage_layer = nn.Sequential(
392
+ nn.Linear(120, 120),
393
+ nn.ReLU(),
394
+ nn.Linear(120, output_dim)
395
+ )
396
+
397
+ def forward(self, state):
398
+ feature_output = self.feature_layer(state)
399
+ # feature_output = feature_output.view(feature_output.size(0), -1)
400
+ value = self.value_layer(feature_output)
401
+ advantage = self.advantage_layer(feature_output)
402
+ q_value = value + (advantage - advantage.mean(dim=1, keepdim=True))
403
+
404
+ return q_value
405
+
406
+
407
+ class DuelingDQNAgent:
408
+ def __init__(self,
409
+ state_dim,
410
+ action_dim,
411
+ save_dir,
412
+ checkpoint=None,
413
+ learning_rate=0.00025,
414
+ max_memory_size=100000,
415
+ batch_size=32,
416
+ exploration_rate=1,
417
+ exploration_rate_decay=0.9999999,
418
+ exploration_rate_min=0.1,
419
+ training_frequency=1,
420
+ learning_starts=1000,
421
+ target_network_sync_frequency=500,
422
+ reset_exploration_rate=False,
423
+ save_frequency=100000,
424
+ gamma=0.9,
425
+ load_replay_buffer=True):
426
+ self.state_dim = state_dim
427
+ self.action_dim = action_dim
428
+ self.max_memory_size = max_memory_size
429
+ self.memory = deque(maxlen=max_memory_size)
430
+ self.batch_size = batch_size
431
+
432
+ self.exploration_rate = exploration_rate
433
+ self.exploration_rate_decay = exploration_rate_decay
434
+ self.exploration_rate_min = exploration_rate_min
435
+ self.gamma = gamma
436
+
437
+ self.curr_step = 0
438
+ self.learning_starts = learning_starts # min. experiences before training
439
+
440
+ self.training_frequency = training_frequency # no. of experiences between updates to Q_online
441
+ self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
442
+
443
+ self.save_every = save_frequency # no. of experiences between saving the network
444
+ self.save_dir = save_dir
445
+
446
+ self.use_cuda = torch.cuda.is_available()
447
+
448
+
449
+ self.online_net = DuelingDQNet(self.state_dim, self.action_dim).float()
450
+ self.target_net = copy.deepcopy(self.online_net)
451
+ # Q_target parameters are frozen.
452
+ for p in self.target_net.parameters():
453
+ p.requires_grad = False
454
+
455
+ if self.use_cuda:
456
+ self.online_net = self.online_net.to(device='cuda')
457
+ self.target_net = self.target_net.to(device='cuda')
458
+ if checkpoint:
459
+ self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
460
+
461
+ self.optimizer = torch.optim.AdamW(self.online_net.parameters(), lr=learning_rate, amsgrad=True)
462
+ self.loss_fn = torch.nn.SmoothL1Loss()
463
+ # self.optimizer = torch.optim.Adam(self.online_net.parameters(), lr=learning_rate)
464
+ # self.loss_fn = torch.nn.MSELoss()
465
+
466
+
467
+ def act(self, state):
468
+ """
469
+ Given a state, choose an epsilon-greedy action and update value of step.
470
+
471
+ Inputs:
472
+ state(LazyFrame): A single observation of the current state, dimension is (state_dim)
473
+ Outputs:
474
+ action_idx (int): An integer representing which action the agent will perform
475
+ """
476
+ # EXPLORE
477
+ if np.random.rand() < self.exploration_rate:
478
+ action_idx = np.random.randint(self.action_dim)
479
+
480
+ # EXPLOIT
481
+ else:
482
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
483
+ state = state.unsqueeze(0)
484
+ action_values = self.online_net(state)
485
+ action_idx = torch.argmax(action_values, axis=1).item()
486
+
487
+ # decrease exploration_rate
488
+ self.exploration_rate *= self.exploration_rate_decay
489
+ self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
490
+
491
+ # increment step
492
+ self.curr_step += 1
493
+ return action_idx
494
+
495
+ def cache(self, state, next_state, action, reward, done):
496
+ """
497
+ Store the experience to self.memory (replay buffer)
498
+
499
+ Inputs:
500
+ state (LazyFrame),
501
+ next_state (LazyFrame),
502
+ action (int),
503
+ reward (float),
504
+ done(bool))
505
+ """
506
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
507
+ next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
508
+ action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
509
+ reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
510
+ done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
511
+
512
+ self.memory.append( (state, next_state, action, reward, done,) )
513
+
514
+
515
+ def recall(self):
516
+ """
517
+ Retrieve a batch of experiences from memory
518
+ """
519
+ batch = random.sample(self.memory, self.batch_size)
520
+ state, next_state, action, reward, done = map(torch.stack, zip(*batch))
521
+ return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
522
+
523
+
524
+ def td_estimate(self, states, actions):
525
+ actions = actions.reshape(-1, 1)
526
+ predicted_qs = self.online_net(states)# Q_online(s,a)
527
+ predicted_qs = predicted_qs.gather(1, actions)
528
+ return predicted_qs
529
+
530
+
531
+ @torch.no_grad()
532
+ def td_target(self, rewards, next_states, dones):
533
+ rewards = rewards.reshape(-1, 1)
534
+ dones = dones.reshape(-1, 1)
535
+ target_qs = self.target_net.forward(next_states)
536
+ target_qs = torch.max(target_qs, dim=1).values
537
+ target_qs = target_qs.reshape(-1, 1)
538
+ target_qs[dones] = 0.0
539
+ return (rewards + (self.gamma * target_qs))
540
+
541
+ def update_Q_online(self, td_estimate, td_target):
542
+ loss = self.loss_fn(td_estimate.float(), td_target.float())
543
+ self.optimizer.zero_grad()
544
+ loss.backward()
545
+ self.optimizer.step()
546
+ return loss.item()
547
+
548
+
549
+ def sync_Q_target(self):
550
+ self.target_net.load_state_dict(self.online_net.state_dict())
551
+
552
+
553
+ def learn(self):
554
+ if self.curr_step % self.target_network_sync_frequency == 0:
555
+ self.sync_Q_target()
556
+
557
+ if self.curr_step % self.save_every == 0:
558
+ self.save()
559
+
560
+ if self.curr_step < self.learning_starts:
561
+ return None, None
562
+
563
+ if self.curr_step % self.training_frequency != 0:
564
+ return None, None
565
+
566
+ # Sample from memory
567
+ state, next_state, action, reward, done = self.recall()
568
+
569
+ # Get TD Estimate
570
+ td_est = self.td_estimate(state, action)
571
+
572
+ # Get TD Target
573
+ td_tgt = self.td_target(reward, next_state, done)
574
+
575
+ # Backpropagate loss through Q_online
576
+ loss = self.update_Q_online(td_est, td_tgt)
577
+
578
+ return (td_est.mean().item(), loss)
579
+
580
+
581
+ def save(self):
582
+ save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
583
+ torch.save(
584
+ dict(
585
+ model=self.online_net.state_dict(),
586
+ exploration_rate=self.exploration_rate,
587
+ replay_memory=self.memory
588
+ ),
589
+ save_path
590
+ )
591
+
592
+ print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
593
+
594
+
595
+ def load(self, load_path, reset_exploration_rate, load_replay_buffer):
596
+ if not load_path.exists():
597
+ raise ValueError(f"{load_path} does not exist")
598
+
599
+ ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
600
+ exploration_rate = ckp.get('exploration_rate')
601
+ state_dict = ckp.get('model')
602
+
603
+
604
+ print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
605
+ self.online_net.load_state_dict(state_dict)
606
+ self.target_net = copy.deepcopy(self.online_net)
607
+ self.sync_Q_target()
608
+
609
+ if load_replay_buffer:
610
+ replay_memory = ckp.get('replay_memory')
611
+ print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
612
+ self.memory = replay_memory if replay_memory else self.memory
613
+
614
+ if reset_exploration_rate:
615
+ print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
616
+ else:
617
+ print(f"Setting exploration rate to {exploration_rate} not loaded.")
618
+ self.exploration_rate = exploration_rate
619
+
620
+
621
+
622
+
623
+ class DuelingDDQNAgent(DuelingDQNAgent):
624
+ @torch.no_grad()
625
+ def td_target(self, rewards, next_states, dones):
626
+ rewards = rewards.reshape(-1, 1)
627
+ dones = dones.reshape(-1, 1)
628
+ q_vals = self.online_net.forward(next_states)
629
+ target_actions = torch.argmax(q_vals, axis=1)
630
+ target_actions = target_actions.reshape(-1, 1)
631
+
632
+ target_qs = self.target_net.forward(next_states)
633
+ target_qs = target_qs.gather(1, target_actions)
634
+ target_qs = target_qs.reshape(-1, 1)
635
+ target_qs[dones] = 0.0
636
+ return (rewards + (self.gamma * target_qs))
637
+
638
+
639
+
640
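+ # Step-decay variant: identical to DQNAgent except that each cached transition
+ # also stores its step index within the episode, and td_target() uses that index
+ # to shrink the bootstrapped term as the episode progresses (see td_target below).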
+ class DQNAgentWithStepDecay:
641
+ def __init__(self,
642
+ state_dim,
643
+ action_dim,
644
+ save_dir,
645
+ checkpoint=None,
646
+ learning_rate=0.00025,
647
+ max_memory_size=100000,
648
+ batch_size=32,
649
+ exploration_rate=1,
650
+ exploration_rate_decay=0.9999999,
651
+ exploration_rate_min=0.1,
652
+ training_frequency=1,
653
+ learning_starts=1000,
654
+ target_network_sync_frequency=500,
655
+ reset_exploration_rate=False,
656
+ save_frequency=100000,
657
+ gamma=0.9,
658
+ load_replay_buffer=True):
659
+ self.state_dim = state_dim
660
+ self.action_dim = action_dim
661
+ self.max_memory_size = max_memory_size
662
+ self.memory = deque(maxlen=max_memory_size)
663
+ self.batch_size = batch_size
664
+
665
+ self.exploration_rate = exploration_rate
666
+ self.exploration_rate_decay = exploration_rate_decay
667
+ self.exploration_rate_min = exploration_rate_min
668
+ self.gamma = gamma
669
+
670
+ self.curr_step = 0
671
+ self.learning_starts = learning_starts # min. experiences before training
672
+
673
+ self.training_frequency = training_frequency # no. of experiences between updates to Q_online
674
+ self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
675
+
676
+ self.save_every = save_frequency # no. of experiences between saving the network
677
+ self.save_dir = save_dir
678
+
679
+ self.use_cuda = torch.cuda.is_available()
680
+
681
+ self.net = DQNet(self.state_dim, self.action_dim).float()
682
+ if self.use_cuda:
683
+ self.net = self.net.to(device='cuda')
684
+ if checkpoint:
685
+ self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
686
+
687
+ self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
688
+ self.loss_fn = torch.nn.SmoothL1Loss()
689
+ # self.optimizer = torch.optim.Adam(self.net.parameters(), lr=learning_rate)
690
+ # self.loss_fn = torch.nn.MSELoss()
691
+
692
+
693
+ def act(self, state):
694
+ """
695
+ Given a state, choose an epsilon-greedy action and update value of step.
696
+
697
+ Inputs:
698
+ state(LazyFrame): A single observation of the current state, dimension is (state_dim)
699
+ Outputs:
700
+ action_idx (int): An integer representing which action the agent will perform
701
+ """
702
+ # EXPLORE
703
+ if np.random.rand() < self.exploration_rate:
704
+ action_idx = np.random.randint(self.action_dim)
705
+
706
+ # EXPLOIT
707
+ else:
708
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
709
+ state = state.unsqueeze(0)
710
+ action_values = self.net(state, model='online')
711
+ action_idx = torch.argmax(action_values, axis=1).item()
712
+
713
+ # decrease exploration_rate
714
+
715
+ self.exploration_rate *= self.exploration_rate_decay
716
+ self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
717
+
718
+ # increment step
719
+ self.curr_step += 1
720
+ return action_idx
721
+
722
+ def cache(self, state, next_state, action, reward, done, stepnumber):
723
+ """
724
+ Store the experience to self.memory (replay buffer)
725
+
726
+ Inputs:
727
+ state (LazyFrame),
728
+ next_state (LazyFrame),
729
+ action (int),
730
+ reward (float),
731
+ done(bool))
732
+ """
733
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
734
+ next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
735
+ action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
736
+ reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
737
+ done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
738
+ stepnumber = torch.LongTensor([stepnumber]).cuda() if self.use_cuda else torch.LongTensor([stepnumber])
739
+
740
+ self.memory.append( (state, next_state, action, reward, done, stepnumber) )
741
+
742
+
743
+ def recall(self):
744
+ """
745
+ Retrieve a batch of experiences from memory
746
+ """
747
+ batch = random.sample(self.memory, self.batch_size)
748
+ state, next_state, action, reward, done, stepnumber = map(torch.stack, zip(*batch))
749
+ return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze(), stepnumber.squeeze()
750
+
751
+
752
+ def td_estimate(self, states, actions):
753
+ actions = actions.reshape(-1, 1)
754
+ predicted_qs = self.net(states, model='online')# Q_online(s,a)
755
+ predicted_qs = predicted_qs.gather(1, actions)
756
+ return predicted_qs
757
+
758
+
759
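+ # The bootstrapped term is capped at (200 - step) / 200, so transitions recorded
+ # late in an episode contribute a smaller future value (the code assumes a
+ # nominal 200-step episode horizon).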
+ @torch.no_grad()
760
+ def td_target(self, rewards, next_states, dones, stepnumbers):
761
+ rewards = rewards.reshape(-1, 1)
762
+ dones = dones.reshape(-1, 1)
763
+ stepnumbers = stepnumbers.reshape(-1, 1)
764
+ target_qs = self.net(next_states, model='target')
765
+ target_qs = torch.max(target_qs, dim=1).values
766
+ target_qs = target_qs.reshape(-1, 1)
767
+ target_qs[dones] = 0.0
768
+ discount = ((200 - stepnumbers)/200)
769
+ val = torch.minimum(discount, self.gamma * target_qs)
770
+ return (rewards + val)
771
+
772
+ def update_Q_online(self, td_estimate, td_target):
773
+ loss = self.loss_fn(td_estimate.float(), td_target.float())
774
+ self.optimizer.zero_grad()
775
+ loss.backward()
776
+ self.optimizer.step()
777
+ return loss.item()
778
+
779
+
780
+ def sync_Q_target(self):
781
+ self.net.target.load_state_dict(self.net.online.state_dict())
782
+
783
+
784
+ def learn(self):
785
+ if self.curr_step % self.target_network_sync_frequency == 0:
786
+ self.sync_Q_target()
787
+
788
+ if self.curr_step % self.save_every == 0:
789
+ self.save()
790
+
791
+ if self.curr_step < self.learning_starts:
792
+ return None, None
793
+
794
+ if self.curr_step % self.training_frequency != 0:
795
+ return None, None
796
+
797
+ # Sample from memory
798
+ state, next_state, action, reward, done, stepnumber = self.recall()
799
+
800
+ # Get TD Estimate
801
+ td_est = self.td_estimate(state, action)
802
+
803
+ # Get TD Target
804
+ td_tgt = self.td_target(reward, next_state, done, stepnumber)
805
+
806
+ # Backpropagate loss through Q_online
807
+
808
+ loss = self.update_Q_online(td_est, td_tgt)
809
+
810
+ return (td_est.mean().item(), loss)
811
+
812
+
813
+ def save(self):
814
+ save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
815
+ torch.save(
816
+ dict(
817
+ model=self.net.state_dict(),
818
+ exploration_rate=self.exploration_rate,
819
+ replay_memory=self.memory
820
+ ),
821
+ save_path
822
+ )
823
+
824
+ print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
825
+
826
+
827
+ def load(self, load_path, reset_exploration_rate, load_replay_buffer):
828
+ if not load_path.exists():
829
+ raise ValueError(f"{load_path} does not exist")
830
+
831
+ ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
832
+ exploration_rate = ckp.get('exploration_rate')
833
+ state_dict = ckp.get('model')
834
+
835
+
836
+ print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
837
+ self.net.load_state_dict(state_dict)
838
+
839
+ if load_replay_buffer:
840
+ replay_memory = ckp.get('replay_memory')
841
+ print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
842
+ self.memory = replay_memory if replay_memory else self.memory
843
+
844
+ if reset_exploration_rate:
845
+ print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
846
+ else:
847
+ print(f"Setting exploration rate to {exploration_rate} not loaded.")
848
+ self.exploration_rate = exploration_rate
849
+
850
+
851
+ class DDQNAgentWithStepDecay(DQNAgentWithStepDecay):
852
+ @torch.no_grad()
853
+ def td_target(self, rewards, next_states, dones, stepnumbers):
854
+ rewards = rewards.reshape(-1, 1)
855
+ dones = dones.reshape(-1, 1)
856
+ stepnumbers = stepnumbers.reshape(-1, 1)
857
+ q_vals = self.net(next_states, model='online')
858
+ target_actions = torch.argmax(q_vals, axis=1)
859
+ target_actions = target_actions.reshape(-1, 1)
860
+
861
+ target_qs = self.net(next_states, model='target')
862
+ target_qs = target_qs.gather(1, target_actions)
863
+ target_qs = target_qs.reshape(-1, 1)
864
+ target_qs[dones] = 0.0
865
+ discount = ((200 - stepnumbers)/200)
866
+ val = torch.minimum(discount, self.gamma * target_qs)
867
+ return (rewards + val)
868
+
869
+
870
+ class DuelingDQNAgentWithStepDecay:
871
+ def __init__(self,
872
+ state_dim,
873
+ action_dim,
874
+ save_dir,
875
+ checkpoint=None,
876
+ learning_rate=0.00025,
877
+ max_memory_size=100000,
878
+ batch_size=32,
879
+ exploration_rate=1,
880
+ exploration_rate_decay=0.9999999,
881
+ exploration_rate_min=0.1,
882
+ training_frequency=1,
883
+ learning_starts=1000,
884
+ target_network_sync_frequency=500,
885
+ reset_exploration_rate=False,
886
+ save_frequency=100000,
887
+ gamma=0.9,
888
+ load_replay_buffer=True):
889
+ self.state_dim = state_dim
890
+ self.action_dim = action_dim
891
+ self.max_memory_size = max_memory_size
892
+ self.memory = deque(maxlen=max_memory_size)
893
+ self.batch_size = batch_size
894
+
895
+ self.exploration_rate = exploration_rate
896
+ self.exploration_rate_decay = exploration_rate_decay
897
+ self.exploration_rate_min = exploration_rate_min
898
+ self.gamma = gamma
899
+
900
+ self.curr_step = 0
901
+ self.learning_starts = learning_starts # min. experiences before training
902
+
903
+ self.training_frequency = training_frequency # no. of experiences between updates to Q_online
904
+ self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
905
+
906
+ self.save_every = save_frequency # no. of experiences between saving the network
907
+ self.save_dir = save_dir
908
+
909
+ self.use_cuda = torch.cuda.is_available()
910
+
911
+
912
+ self.online_net = DuelingDQNet(self.state_dim, self.action_dim).float()
913
+ self.target_net = copy.deepcopy(self.online_net)
914
+ # Q_target parameters are frozen.
915
+ for p in self.target_net.parameters():
916
+ p.requires_grad = False
917
+
918
+ if self.use_cuda:
919
+ self.online_net = self.online_net.to(device='cuda')
920
+ self.target_net = self.target_net.to(device='cuda')
921
+ if checkpoint:
922
+ self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
923
+
924
+ self.optimizer = torch.optim.AdamW(self.online_net.parameters(), lr=learning_rate, amsgrad=True)
925
+ self.loss_fn = torch.nn.SmoothL1Loss()
926
+ # self.optimizer = torch.optim.Adam(self.online_net.parameters(), lr=learning_rate)
927
+ # self.loss_fn = torch.nn.MSELoss()
928
+
929
+
930
+ def act(self, state):
931
+ """
932
+ Given a state, choose an epsilon-greedy action and update value of step.
933
+
934
+ Inputs:
935
+ state(LazyFrame): A single observation of the current state, dimension is (state_dim)
936
+ Outputs:
937
+ action_idx (int): An integer representing which action the agent will perform
938
+ """
939
+ # EXPLORE
940
+ if np.random.rand() < self.exploration_rate:
941
+ action_idx = np.random.randint(self.action_dim)
942
+
943
+ # EXPLOIT
944
+ else:
945
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
946
+ state = state.unsqueeze(0)
947
+ action_values = self.online_net(state)
948
+ action_idx = torch.argmax(action_values, axis=1).item()
949
+
950
+ # decrease exploration_rate
951
+ self.exploration_rate *= self.exploration_rate_decay
952
+ self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
953
+
954
+ # increment step
955
+ self.curr_step += 1
956
+ return action_idx
957
+
958
+ def cache(self, state, next_state, action, reward, done, stepnumber):
959
+ """
960
+ Store the experience to self.memory (replay buffer)
961
+
962
+ Inputs:
963
+ state (LazyFrame),
964
+ next_state (LazyFrame),
965
+ action (int),
966
+ reward (float),
967
+ done(bool))
968
+ """
969
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
970
+ next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
971
+ action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
972
+ reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
973
+ done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
974
+ stepnumber = torch.LongTensor([stepnumber]).cuda() if self.use_cuda else torch.LongTensor([stepnumber])
975
+
976
+ self.memory.append( (state, next_state, action, reward, done, stepnumber) )
977
+
978
+
979
+ def recall(self):
980
+ """
981
+ Retrieve a batch of experiences from memory
982
+ """
983
+ batch = random.sample(self.memory, self.batch_size)
984
+ state, next_state, action, reward, done, stepnumber = map(torch.stack, zip(*batch))
985
+ return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze(), stepnumber.squeeze()
986
+
987
+
988
+ def td_estimate(self, states, actions):
989
+ actions = actions.reshape(-1, 1)
990
+ predicted_qs = self.online_net(states)# Q_online(s,a)
991
+ predicted_qs = predicted_qs.gather(1, actions)
992
+ return predicted_qs
993
+
994
+
995
+ @torch.no_grad()
996
+ def td_target(self, rewards, next_states, dones, stepnumbers):
997
+ rewards = rewards.reshape(-1, 1)
998
+ dones = dones.reshape(-1, 1)
999
+ stepnumbers = stepnumbers.reshape(-1, 1)
1000
+ target_qs = self.target_net.forward(next_states)
1001
+ target_qs = torch.max(target_qs, dim=1).values
1002
+ target_qs = target_qs.reshape(-1, 1)
1003
+ target_qs[dones] = 0.0
1004
+ discount = ((200 - stepnumbers)/200)
1005
+ val = torch.minimum(discount, self.gamma * target_qs)
1006
+ return (rewards + val)
1007
+
1008
+ def update_Q_online(self, td_estimate, td_target):
1009
+ loss = self.loss_fn(td_estimate.float(), td_target.float())
1010
+ self.optimizer.zero_grad()
1011
+ loss.backward()
1012
+ self.optimizer.step()
1013
+ return loss.item()
1014
+
1015
+
1016
+ def sync_Q_target(self):
1017
+ self.target_net.load_state_dict(self.online_net.state_dict())
1018
+
1019
+
1020
+ def learn(self):
1021
+ if self.curr_step % self.target_network_sync_frequency == 0:
1022
+ self.sync_Q_target()
1023
+
1024
+ if self.curr_step % self.save_every == 0:
1025
+ self.save()
1026
+
1027
+ if self.curr_step < self.learning_starts:
1028
+ return None, None
1029
+
1030
+ if self.curr_step % self.training_frequency != 0:
1031
+ return None, None
1032
+
1033
+ # Sample from memory
1034
+ state, next_state, action, reward, done, stepnumbers = self.recall()
1035
+
1036
+ # Get TD Estimate
1037
+ td_est = self.td_estimate(state, action)
1038
+
1039
+ # Get TD Target
1040
+ td_tgt = self.td_target(reward, next_state, done, stepnumbers)
1041
+
1042
+ # Backpropagate loss through Q_online
1043
+ loss = self.update_Q_online(td_est, td_tgt)
1044
+
1045
+ return (td_est.mean().item(), loss)
1046
+
1047
+
1048
+ def save(self):
1049
+ save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
1050
+ torch.save(
1051
+ dict(
1052
+ model=self.online_net.state_dict(),
1053
+ exploration_rate=self.exploration_rate,
1054
+ replay_memory=self.memory
1055
+ ),
1056
+ save_path
1057
+ )
1058
+
1059
+ print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
1060
+
1061
+
1062
+ def load(self, load_path, reset_exploration_rate, load_replay_buffer):
1063
+ if not load_path.exists():
1064
+ raise ValueError(f"{load_path} does not exist")
1065
+
1066
+ ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
1067
+ exploration_rate = ckp.get('exploration_rate')
1068
+ state_dict = ckp.get('model')
1069
+
1070
+
1071
+ print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
1072
+ self.online_net.load_state_dict(state_dict)
1073
+ self.target_net = copy.deepcopy(self.online_net)
1074
+ self.sync_Q_target()
1075
+
1076
+ if load_replay_buffer:
1077
+ replay_memory = ckp.get('replay_memory')
1078
+ print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
1079
+ self.memory = replay_memory if replay_memory else self.memory
1080
+
1081
+ if reset_exploration_rate:
1082
+ print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
1083
+ else:
1084
+ print(f"Setting exploration rate to {exploration_rate} not loaded.")
1085
+ self.exploration_rate = exploration_rate
1086
+
1087
+
1088
+ class DuelingDDQNAgentWithStepDecay(DuelingDQNAgentWithStepDecay):
1089
+ @torch.no_grad()
1090
+ def td_target(self, rewards, next_states, dones, stepnumbers):
1091
+ rewards = rewards.reshape(-1, 1)
1092
+ dones = dones.reshape(-1, 1)
1093
+ stepnumbers = stepnumbers.reshape(-1, 1)
1094
+ q_vals = self.online_net.forward(next_states)
1095
+ target_actions = torch.argmax(q_vals, axis=1)
1096
+ target_actions = target_actions.reshape(-1, 1)
1097
+
1098
+ target_qs = self.target_net.forward(next_states)
1099
+ target_qs = target_qs.gather(1, target_actions)
1100
+ target_qs = target_qs.reshape(-1, 1)
1101
+ target_qs[dones] = 0.0
1102
+ discount = ((200 - stepnumbers)/200)
1103
+ val = torch.minimum(discount, self.gamma * target_qs)
1104
+ return (rewards + val)
src/lunar-lander/params.py ADDED
@@ -0,0 +1,12 @@
1
+ hyperparams = dict(
2
+ batch_size=128,
3
+ exploration_rate=1,
4
+ exploration_rate_decay=0.99999,
5
+ exploration_rate_min=0.01,
6
+ training_frequency=1,
7
+ target_network_sync_frequency=20,
8
+ max_memory_size=1000000,
9
+ learning_rate=0.001,
10
+ learning_starts=128,
11
+ save_frequency=100000
12
+ )
src/lunar-lander/replay.py ADDED
@@ -0,0 +1,67 @@
1
+ import datetime
2
+ from pathlib import Path
3
+ from agent import DQNAgent, DDQNAgent, MetricLogger
4
+ from wrappers import make_lunar
5
+
6
+
7
+ env = make_lunar()
8
+
9
+ env.reset()
10
+
11
+ save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
12
+ save_dir.mkdir(parents=True)
13
+
14
+ # checkpoint = Path('checkpoints/lunar-lander-dueling-ddqn/airstriker_net_2.chkpt')
15
+ checkpoint = Path('checkpoints/lunar-lander-dqn-rc/airstriker_net_1.chkpt')
16
+
17
+ logger = MetricLogger(save_dir)
18
+
19
+ print("Testing Double DQN Agent!")
20
+ agent = DDQNAgent(
21
+ state_dim=8,
22
+ action_dim=env.action_space.n,
23
+ save_dir=save_dir,
24
+ batch_size=512,
25
+ checkpoint=checkpoint,
26
+ exploration_rate_decay=0.999995,
27
+ exploration_rate_min=0.05,
28
+ training_frequency=1,
29
+ target_network_sync_frequency=200,
30
+ max_memory_size=50000,
31
+ learning_rate=0.0005,
32
+ load_replay_buffer=False
33
+
34
+ )
35
+ agent.exploration_rate = agent.exploration_rate_min
36
+
37
+ episodes = 100
38
+
39
+ for e in range(episodes):
40
+
41
+ state = env.reset()
42
+
43
+ while True:
44
+
45
+ env.render()
46
+
47
+ action = agent.act(state)
48
+
49
+ next_state, reward, done, info = env.step(action)
50
+
51
+ # agent.cache(state, next_state, action, reward, done)
52
+
53
+ # logger.log_step(reward, None, None)
54
+
55
+ state = next_state
56
+
57
+ if done:
58
+ break
59
+
60
+ # logger.log_episode()
61
+
62
+ # if e % 20 == 0:
63
+ # logger.record(
64
+ # episode=e,
65
+ # epsilon=agent.exploration_rate,
66
+ # step=agent.curr_step
67
+ # )
src/lunar-lander/run-lunar-ddqn.py ADDED
@@ -0,0 +1,45 @@
1
+ import os
2
+ import torch
3
+ from pathlib import Path
4
+
5
+ from agent import DDQNAgent, DDQNAgentWithStepDecay, MetricLogger
6
+ from wrappers import make_lunar
7
+ import os
8
+ from train import train, fill_memory
9
+ from params import hyperparams
10
+
11
+ env = make_lunar()
12
+
13
+ use_cuda = torch.cuda.is_available()
14
+ print(f"Using CUDA: {use_cuda}\n")
15
+
16
+ checkpoint = None
17
+ # checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
18
+
19
+ path = "checkpoints/lunar-lander-ddqn-rc"
20
+ save_dir = Path(path)
21
+
22
+ isExist = os.path.exists(path)
23
+ if not isExist:
24
+ os.makedirs(path)
25
+
26
+ logger = MetricLogger(save_dir)
27
+
28
+ print("Training DDQN Agent!")
29
+ agent = DDQNAgentWithStepDecay(
30
+ state_dim=8,
31
+ action_dim=env.action_space.n,
32
+ save_dir=save_dir,
33
+ checkpoint=checkpoint,
34
+ **hyperparams
35
+ )
36
+ # agent = DDQNAgent(
37
+ # state_dim=8,
38
+ # action_dim=env.action_space.n,
39
+ # save_dir=save_dir,
40
+ # checkpoint=checkpoint,
41
+ # **hyperparams
42
+ # )
43
+
44
+ # fill_memory(agent, env, 5000)
45
+ train(agent, env, logger)
src/lunar-lander/run-lunar-dqn.py ADDED
@@ -0,0 +1,46 @@
1
+ import os
2
+ import torch
3
+ from pathlib import Path
4
+
5
+ from agent import DQNAgent, DQNAgentWithStepDecay, MetricLogger
6
+ from wrappers import make_lunar
7
+ import os
8
+ from train import train, fill_memory
9
+ from params import hyperparams
10
+
11
+ env = make_lunar()
12
+
13
+ use_cuda = torch.cuda.is_available()
14
+ print(f"Using CUDA: {use_cuda}\n")
15
+
16
+ checkpoint = None
17
+ # checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
18
+
19
+ path = "checkpoints/lunar-lander-dqn-rc"
20
+ save_dir = Path(path)
21
+
22
+ isExist = os.path.exists(path)
23
+ if not isExist:
24
+ os.makedirs(path)
25
+
26
+ logger = MetricLogger(save_dir)
27
+
28
+ print("Training Vanilla DQN Agent with decay!")
29
+ agent = DQNAgentWithStepDecay(
30
+ state_dim=8,
31
+ action_dim=env.action_space.n,
32
+ save_dir=save_dir,
33
+ checkpoint=checkpoint,
34
+ **hyperparams
35
+ )
36
+ # print("Training Vanilla DQN Agent!")
37
+ # agent = DQNAgent(
38
+ # state_dim=8,
39
+ # action_dim=env.action_space.n,
40
+ # save_dir=save_dir,
41
+ # checkpoint=checkpoint,
42
+ # **hyperparams
43
+ # )
44
+
45
+ # fill_memory(agent, env, 5000)
46
+ train(agent, env, logger)
src/lunar-lander/run-lunar-dueling-ddqn.py ADDED
@@ -0,0 +1,47 @@
1
+ import os
2
+ import torch
3
+ from pathlib import Path
4
+
5
+ from agent import DuelingDDQNAgent, DuelingDDQNAgentWithStepDecay,MetricLogger
6
+ from wrappers import make_lunar
7
+ import os
8
+ from train import train, fill_memory
9
+ from params import hyperparams
10
+
11
+
12
+ env = make_lunar()
13
+
14
+ use_cuda = torch.cuda.is_available()
15
+ print(f"Using CUDA: {use_cuda}\n")
16
+
17
+ checkpoint = None
18
+ # checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
19
+
20
+ path = "checkpoints/lunar-lander-dueling-ddqn-rc"
21
+ save_dir = Path(path)
22
+
23
+ isExist = os.path.exists(path)
24
+ if not isExist:
25
+ os.makedirs(path)
26
+
27
+ logger = MetricLogger(save_dir)
28
+
29
+ print("Training Dueling DDQN Agent with step decay!")
30
+ agent = DuelingDDQNAgentWithStepDecay(
31
+ state_dim=8,
32
+ action_dim=env.action_space.n,
33
+ save_dir=save_dir,
34
+ checkpoint=checkpoint,
35
+ **hyperparams
36
+ )
37
+ # print("Training Dueling DDQN Agent!")
38
+ # agent = DuelingDDQNAgent(
39
+ # state_dim=8,
40
+ # action_dim=env.action_space.n,
41
+ # save_dir=save_dir,
42
+ # checkpoint=checkpoint,
43
+ # **hyperparams
44
+ # )
45
+
46
+ # fill_memory(agent, env, 5000)
47
+ train(agent, env, logger)
src/lunar-lander/run-lunar-dueling-dqn.py ADDED
@@ -0,0 +1,46 @@
1
+ import os
2
+ import torch
3
+ from pathlib import Path
4
+
5
+ from agent import DuelingDQNAgent, DuelingDQNAgentWithStepDecay, MetricLogger
6
+ from wrappers import make_lunar
7
+ import os
8
+ from train import train, fill_memory
9
+ from params import hyperparams
10
+
11
+ env = make_lunar()
12
+
13
+ use_cuda = torch.cuda.is_available()
14
+ print(f"Using CUDA: {use_cuda}\n")
15
+
16
+ checkpoint = None
17
+ # checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
18
+
19
+ path = "checkpoints/lunar-lander-dueling-dqn-rc"
20
+ save_dir = Path(path)
21
+
22
+ isExist = os.path.exists(path)
23
+ if not isExist:
24
+ os.makedirs(path)
25
+
26
+ logger = MetricLogger(save_dir)
27
+
28
+ print("Training Dueling DQN Agent with step decay!")
29
+ agent = DuelingDQNAgentWithStepDecay(
30
+ state_dim=8,
31
+ action_dim=env.action_space.n,
32
+ save_dir=save_dir,
33
+ checkpoint=checkpoint,
34
+ **hyperparams
35
+ )
36
+ # print("Training Dueling DQN Agent!")
37
+ # agent = DuelingDQNAgent(
38
+ # state_dim=8,
39
+ # action_dim=env.action_space.n,
40
+ # save_dir=save_dir,
41
+ # checkpoint=checkpoint,
42
+ # **hyperparams
43
+ # )
44
+
45
+ # fill_memory(agent, env, 5000)
46
+ train(agent, env, logger)
src/lunar-lander/train.py ADDED
@@ -0,0 +1,84 @@
1
+ from tqdm import trange
2
+
3
+ def fill_memory(agent, env, num_episodes=500):
4
+ print("Filling up memory....")
5
+ for _ in trange(num_episodes):
6
+ state = env.reset()
7
+ done = False
8
+ while not done:
9
+ action = agent.act(state)
10
+ next_state, reward, done, _ = env.step(action)
11
+ agent.cache(state, next_state, action, reward, done)
12
+ state = next_state
13
+
14
+
15
+ # def train(agent, env, logger):
16
+ # episodes = 5000
17
+ # for e in range(episodes):
18
+
19
+ # state = env.reset()
20
+ # # Play the game!
21
+ # while True:
22
+
23
+ # # Run agent on the state
24
+ # action = agent.act(state)
25
+
26
+ # # Agent performs action
27
+ # next_state, reward, done, info = env.step(action)
28
+
29
+ # # Remember
30
+ # agent.cache(state, next_state, action, reward, done)
31
+
32
+ # # Learn
33
+ # q, loss = agent.learn()
34
+
35
+ # # Logging
36
+ # logger.log_step(reward, loss, q)
37
+
38
+ # # Update state
39
+ # state = next_state
40
+
41
+ # # Check if end of game
42
+ # if done:
43
+ # break
44
+
45
+ # logger.log_episode(e)
46
+
47
+ # if e % 20 == 0:
48
+ # logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
49
+
50
+
51
+ def train(agent, env, logger):
52
+ episodes = 5000
53
+ for e in range(episodes):
54
+
55
+ state = env.reset()
56
+ # Play the game!
57
+ for i in range(1000):
58
+
59
+ # Run agent on the state
60
+ action = agent.act(state)
61
+ env.render()
62
+ # Agent performs action
63
+ next_state, reward, done, info = env.step(action)
64
+
65
+ # Remember
66
+ agent.cache(state, next_state, action, reward, done, i)
67
+
68
+ # Learn
69
+ q, loss = agent.learn()
70
+
71
+ # Logging
72
+ logger.log_step(reward, loss, q)
73
+
74
+ # Update state
75
+ state = next_state
76
+
77
+ # Check if end of game
78
+ if done:
79
+ break
80
+
81
+ logger.log_episode(e)
82
+
83
+ if e % 20 == 0:
84
+ logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
src/lunar-lander/wrappers.py ADDED
@@ -0,0 +1,193 @@
1
+ import numpy as np
2
+ import os
3
+ from collections import deque
4
+ import gym
5
+ from gym import spaces
6
+ import cv2
7
+ import math
8
+
9
+ '''
10
+ Atari Wrapper copied from https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
11
+ '''
12
+
13
+
14
+ class LazyFrames(object):
15
+ def __init__(self, frames):
16
+ """This object ensures that common frames between the observations are only stored once.
17
+ It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
18
+ buffers.
19
+ This object should only be converted to numpy array before being passed to the model.
20
+ You'd not believe how complex the previous solution was."""
21
+ self._frames = frames
22
+ self._out = None
23
+
24
+ def _force(self):
25
+ if self._out is None:
26
+ self._out = np.concatenate(self._frames, axis=2)
27
+ self._frames = None
28
+ return self._out
29
+
30
+ def __array__(self, dtype=None):
31
+ out = self._force()
32
+ if dtype is not None:
33
+ out = out.astype(dtype)
34
+ return out
35
+
36
+ def __len__(self):
37
+ return len(self._force())
38
+
39
+ def __getitem__(self, i):
40
+ return self._force()[i]
41
+
42
+ class FireResetEnv(gym.Wrapper):
43
+ def __init__(self, env):
44
+ """Take action on reset for environments that are fixed until firing."""
45
+ gym.Wrapper.__init__(self, env)
46
+ assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
47
+ assert len(env.unwrapped.get_action_meanings()) >= 3
48
+
49
+ def reset(self, **kwargs):
50
+ self.env.reset(**kwargs)
51
+ obs, _, done, _ = self.env.step(1)
52
+ if done:
53
+ self.env.reset(**kwargs)
54
+ obs, _, done, _ = self.env.step(2)
55
+ if done:
56
+ self.env.reset(**kwargs)
57
+ return obs
58
+
59
+ def step(self, ac):
60
+ return self.env.step(ac)
61
+
62
+
63
+ class MaxAndSkipEnv(gym.Wrapper):
64
+ def __init__(self, env, skip=4):
65
+ """Return only every `skip`-th frame"""
66
+ gym.Wrapper.__init__(self, env)
67
+ # most recent raw observations (for max pooling across time steps)
68
+ self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
69
+ self._skip = skip
70
+
71
+ def step(self, action):
72
+ """Repeat action, sum reward, and max over last observations."""
73
+ total_reward = 0.0
74
+ done = None
75
+ for i in range(self._skip):
76
+ obs, reward, done, info = self.env.step(action)
77
+ if i == self._skip - 2: self._obs_buffer[0] = obs
78
+ if i == self._skip - 1: self._obs_buffer[1] = obs
79
+ total_reward += reward
80
+ if done:
81
+ break
82
+ # Note that the observation on the done=True frame
83
+ # doesn't matter
84
+ max_frame = self._obs_buffer.max(axis=0)
85
+
86
+ return max_frame, total_reward, done, info
87
+
88
+ def reset(self, **kwargs):
89
+ return self.env.reset(**kwargs)
90
+
91
+
92
+
93
+ class WarpFrame(gym.ObservationWrapper):
94
+ def __init__(self, env):
95
+ """Warp frames to 84x84 as done in the Nature paper and later work."""
96
+ gym.ObservationWrapper.__init__(self, env)
97
+ self.width = 84
98
+ self.height = 84
99
+ self.observation_space = spaces.Box(low=0, high=255,
100
+ shape=(self.height, self.width, 1), dtype=np.uint8)
101
+
102
+ def observation(self, frame):
103
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
104
+ frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
105
+ return frame[:, :, None]
106
+
107
+ class WarpFrameNoResize(gym.ObservationWrapper):
108
+ def __init__(self, env):
109
+ """Warp frames to 84x84 as done in the Nature paper and later work."""
110
+ gym.ObservationWrapper.__init__(self, env)
111
+
112
+ def observation(self, frame):
113
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
114
+ # frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
115
+ return frame[:, :, None]
116
+
117
+
118
+
119
+ class FrameStack(gym.Wrapper):
120
+ def __init__(self, env, k):
121
+ """Stack k last frames.
122
+ Returns lazy array, which is much more memory efficient.
123
+ See Also
124
+ --------
125
+ baselines.common.atari_wrappers.LazyFrames
126
+ """
127
+ gym.Wrapper.__init__(self, env)
128
+ self.k = k
129
+ self.frames = deque([], maxlen=k)
130
+ shp = env.observation_space.shape
131
+ self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)
132
+
133
+ def reset(self):
134
+ ob = self.env.reset()
135
+ for _ in range(self.k):
136
+ self.frames.append(ob)
137
+ return self._get_ob()
138
+
139
+ def step(self, action):
140
+ ob, reward, done, info = self.env.step(action)
141
+ self.frames.append(ob)
142
+ return self._get_ob(), reward, done, info
143
+
144
+ def _get_ob(self):
145
+ assert len(self.frames) == self.k
146
+ return LazyFrames(list(self.frames))
147
+
148
+
149
+ class ImageToPyTorch(gym.ObservationWrapper):
150
+ def __init__(self, env):
151
+ super(ImageToPyTorch, self).__init__(env)
152
+ old_shape = self.observation_space.shape
153
+ self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32)
154
+
155
+ def observation(self, observation):
156
+ return np.moveaxis(observation, 2, 0)
157
+
158
+
159
+ class ScaledFloatFrame(gym.ObservationWrapper):
160
+ def __init__(self, env):
161
+ gym.ObservationWrapper.__init__(self, env)
162
+ self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)
163
+
164
+ def observation(self, observation):
165
+ # careful! This undoes the memory optimization, use
166
+ # with smaller replay buffers only.
167
+ return np.array(observation).astype(np.float32) / 255.0
168
+
169
+ class ClipRewardEnv(gym.RewardWrapper):
170
+ def __init__(self, env):
171
+ gym.RewardWrapper.__init__(self, env)
172
+
173
+ def reward(self, reward):
174
+ """Bin reward to {+1, 0, -1} by its sign."""
175
+ return np.sign(reward)
176
+
177
+ class TanRewardClipperEnv(gym.RewardWrapper):
178
+ def __init__(self, env):
179
+ gym.RewardWrapper.__init__(self, env)
180
+
181
+ def reward(self, reward):
182
+ """Bin reward to {+1, 0, -1} by its sign."""
183
+ return 10 * math.tanh(float(reward)/30.)
184
+
185
+
186
+ def make_lunar(render=False):
187
+ print("Environment: Lunar Lander")
188
+ env = gym.make("LunarLander-v2")
189
+ # env = TanRewardClipperEnv(env)
190
+ # env = WarpFrameNoResize(env) ## Reshape image
191
+ # env = ImageToPyTorch(env) ## Invert shape
192
+ # env = FrameStack(env, 4) ## Stack last 4 frames
193
+ return env
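+ 
+ # A minimal usage sketch (assumption: gym==0.21.0 with Box2D installed for LunarLander-v2):
+ #   env = make_lunar()
+ #   state = env.reset()
+ #   next_state, reward, done, info = env.step(env.action_space.sample())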
src/procgen/agent.py ADDED
@@ -0,0 +1,664 @@
1
+ import torch
2
+ import numpy as np
3
+ import random
4
+ import torch.nn as nn
5
+ import copy
6
+ import time, datetime
7
+ import matplotlib.pyplot as plt
8
+ from collections import deque
9
+ from torch.utils.tensorboard import SummaryWriter
10
+
11
+
12
+ class DQNet(nn.Module):
13
+ """mini cnn structure
14
+ input -> (conv2d + relu) x 3 -> flatten -> (dense + relu) x 2 -> output
15
+ """
16
+
17
+ def __init__(self, input_dim, output_dim):
18
+ super().__init__()
19
+ print("#################################")
20
+ print("#################################")
21
+ print(input_dim)
22
+ print(output_dim)
23
+ print("#################################")
24
+ print("#################################")
25
+ c, h, w = input_dim
26
+
27
+
28
+ self.online = nn.Sequential(
29
+ nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
30
+ nn.ReLU(),
31
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
32
+ nn.ReLU(),
33
+ nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
34
+ nn.ReLU(),
35
+ nn.Flatten(),
36
+ nn.Linear(7168, 512),
37
+ nn.ReLU(),
38
+ nn.Linear(512, output_dim),
39
+ )
40
+
41
+
42
+ self.target = copy.deepcopy(self.online)
43
+
44
+ # Q_target parameters are frozen.
45
+ for p in self.target.parameters():
46
+ p.requires_grad = False
47
+
48
+ def forward(self, input, model):
49
+ if model == "online":
50
+ return self.online(input)
51
+ elif model == "target":
52
+ return self.target(input)
53
+
54
+
55
+
56
+ class MetricLogger:
57
+ def __init__(self, save_dir):
58
+ self.writer = SummaryWriter(log_dir=save_dir)
59
+ self.save_log = save_dir / "log"
60
+ with open(self.save_log, "w") as f:
61
+ f.write(
62
+ f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
63
+ f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
64
+ f"{'TimeDelta':>15}{'Time':>20}\n"
65
+ )
66
+ self.ep_rewards_plot = save_dir / "reward_plot.jpg"
67
+ self.ep_lengths_plot = save_dir / "length_plot.jpg"
68
+ self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
69
+ self.ep_avg_qs_plot = save_dir / "q_plot.jpg"
70
+
71
+ # History metrics
72
+ self.ep_rewards = []
73
+ self.ep_lengths = []
74
+ self.ep_avg_losses = []
75
+ self.ep_avg_qs = []
76
+
77
+ # Moving averages, added for every call to record()
78
+ self.moving_avg_ep_rewards = []
79
+ self.moving_avg_ep_lengths = []
80
+ self.moving_avg_ep_avg_losses = []
81
+ self.moving_avg_ep_avg_qs = []
82
+
83
+ # Current episode metric
84
+ self.init_episode()
85
+
86
+ # Timing
87
+ self.record_time = time.time()
88
+
89
+ def log_step(self, reward, loss, q):
90
+ self.curr_ep_reward += reward
91
+ self.curr_ep_length += 1
92
+ if loss:
93
+ self.curr_ep_loss += loss
94
+ self.curr_ep_q += q
95
+ self.curr_ep_loss_length += 1
96
+
97
+ def log_episode(self, episode_number):
98
+ "Mark end of episode"
99
+ self.ep_rewards.append(self.curr_ep_reward)
100
+ self.ep_lengths.append(self.curr_ep_length)
101
+ if self.curr_ep_loss_length == 0:
102
+ ep_avg_loss = 0
103
+ ep_avg_q = 0
104
+ else:
105
+ ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
106
+ ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
107
+ self.ep_avg_losses.append(ep_avg_loss)
108
+ self.ep_avg_qs.append(ep_avg_q)
109
+ self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
110
+ self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
111
+ self.writer.flush()
112
+ self.init_episode()
113
+
114
+ def init_episode(self):
115
+ self.curr_ep_reward = 0.0
116
+ self.curr_ep_length = 0
117
+ self.curr_ep_loss = 0.0
118
+ self.curr_ep_q = 0.0
119
+ self.curr_ep_loss_length = 0
120
+
121
+ def record(self, episode, epsilon, step):
122
+ mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
123
+ mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
124
+ mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
125
+ mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
126
+ self.moving_avg_ep_rewards.append(mean_ep_reward)
127
+ self.moving_avg_ep_lengths.append(mean_ep_length)
128
+ self.moving_avg_ep_avg_losses.append(mean_ep_loss)
129
+ self.moving_avg_ep_avg_qs.append(mean_ep_q)
130
+
131
+ last_record_time = self.record_time
132
+ self.record_time = time.time()
133
+ time_since_last_record = np.round(self.record_time - last_record_time, 3)
134
+
135
+ print(
136
+ f"Episode {episode} - "
137
+ f"Step {step} - "
138
+ f"Epsilon {epsilon} - "
139
+ f"Mean Reward {mean_ep_reward} - "
140
+ f"Mean Length {mean_ep_length} - "
141
+ f"Mean Loss {mean_ep_loss} - "
142
+ f"Mean Q Value {mean_ep_q} - "
143
+ f"Time Delta {time_since_last_record} - "
144
+ f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
145
+ )
146
+ self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
147
+ self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
148
+ self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
149
+ self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
150
+ self.writer.add_scalar("Epsilon value", epsilon, episode)
151
+ self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
152
+ self.writer.flush()
153
+ with open(self.save_log, "a") as f:
154
+ f.write(
155
+ f"{episode:8d}{step:8d}{epsilon:10.3f}"
156
+ f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
157
+ f"{time_since_last_record:15.3f}"
158
+ f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
159
+ )
160
+
161
+ for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
162
+ plt.plot(getattr(self, f"moving_avg_{metric}"))
163
+ plt.savefig(getattr(self, f"{metric}_plot"))
164
+ plt.clf()
165
+
166
+
167
+ class DQNAgent:
168
+ def __init__(self,
169
+ state_dim,
170
+ action_dim,
171
+ save_dir,
172
+ checkpoint=None,
173
+ learning_rate=0.00025,
174
+ max_memory_size=100000,
175
+ batch_size=32,
176
+ exploration_rate=1,
177
+ exploration_rate_decay=0.9999999,
178
+ exploration_rate_min=0.1,
179
+ training_frequency=1,
180
+ learning_starts=1000,
181
+ target_network_sync_frequency=500,
182
+ reset_exploration_rate=False,
183
+ save_frequency=100000,
184
+ gamma=0.9,
185
+ load_replay_buffer=True):
186
+ self.state_dim = state_dim
187
+ self.action_dim = action_dim
188
+ self.max_memory_size = max_memory_size
189
+ self.memory = deque(maxlen=max_memory_size)
190
+ self.batch_size = batch_size
191
+
192
+ self.exploration_rate = exploration_rate
193
+ self.exploration_rate_decay = exploration_rate_decay
194
+ self.exploration_rate_min = exploration_rate_min
195
+ self.gamma = gamma
196
+
197
+ self.curr_step = 0
198
+ self.learning_starts = learning_starts # min. experiences before training
199
+
200
+ self.training_frequency = training_frequency # no. of experiences between updates to Q_online
201
+ self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
202
+
203
+ self.save_every = save_frequency # no. of experiences between saving the network
204
+ self.save_dir = save_dir
205
+
206
+ self.use_cuda = torch.cuda.is_available()
207
+
208
+ self.net = DQNet(self.state_dim, self.action_dim).float()
209
+ if self.use_cuda:
210
+ self.net = self.net.to(device='cuda')
211
+ if checkpoint:
212
+ self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
213
+
214
+ self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
215
+ self.loss_fn = torch.nn.SmoothL1Loss()
216
+
217
+
218
+ def act(self, state):
219
+ """
220
+ Given a state, choose an epsilon-greedy action and update value of step.
221
+
222
+ Inputs:
223
+ state(LazyFrame): A single observation of the current state, dimension is (state_dim)
224
+ Outputs:
225
+ action_idx (int): An integer representing which action the agent will perform
226
+ """
227
+ # EXPLORE
228
+ if np.random.rand() < self.exploration_rate:
229
+ action_idx = np.random.randint(self.action_dim)
230
+
231
+ # EXPLOIT
232
+ else:
233
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
234
+ state = state.unsqueeze(0)
235
+ action_values = self.net(state, model='online')
236
+ action_idx = torch.argmax(action_values, axis=1).item()
237
+
238
+ # decrease exploration_rate
239
+ self.exploration_rate *= self.exploration_rate_decay
240
+ self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
241
+
242
+ # increment step
243
+ self.curr_step += 1
244
+ return action_idx
245
+
246
+ def cache(self, state, next_state, action, reward, done):
247
+ """
248
+ Store the experience to self.memory (replay buffer)
249
+
250
+ Inputs:
251
+ state (LazyFrame),
252
+ next_state (LazyFrame),
253
+ action (int),
254
+ reward (float),
255
+ done(bool))
256
+ """
257
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
258
+ next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
259
+ action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
260
+ reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
261
+ done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
262
+
263
+ self.memory.append( (state, next_state, action, reward, done,) )
264
+
265
+
266
+ def recall(self):
267
+ """
268
+ Retrieve a batch of experiences from memory
269
+ """
270
+ batch = random.sample(self.memory, self.batch_size)
271
+ state, next_state, action, reward, done = map(torch.stack, zip(*batch))
272
+ return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
273
+
274
+
275
+ def td_estimate(self, states, actions):
276
+ actions = actions.reshape(-1, 1)
277
+ predicted_qs = self.net(states, model='online')# Q_online(s,a)
278
+ predicted_qs = predicted_qs.gather(1, actions)
279
+ return predicted_qs
280
+
281
+
282
+ @torch.no_grad()
283
+ def td_target(self, rewards, next_states, dones):
284
+ rewards = rewards.reshape(-1, 1)
285
+ dones = dones.reshape(-1, 1)
286
+ target_qs = self.net(next_states, model='target')
287
+ target_qs = torch.max(target_qs, dim=1).values
288
+ target_qs = target_qs.reshape(-1, 1)
289
+ target_qs[dones] = 0.0
290
+ return (rewards + (self.gamma * target_qs))
291
+
292
+ def update_Q_online(self, td_estimate, td_target) :
293
+ loss = self.loss_fn(td_estimate, td_target)
294
+ self.optimizer.zero_grad()
295
+ loss.backward()
296
+ self.optimizer.step()
297
+ return loss.item()
298
+
299
+
300
+ def sync_Q_target(self):
301
+ self.net.target.load_state_dict(self.net.online.state_dict())
302
+
303
+
304
+ def learn(self):
305
+ if self.curr_step % self.target_network_sync_frequency == 0:
306
+ self.sync_Q_target()
307
+
308
+ if self.curr_step % self.save_every == 0:
309
+ self.save()
310
+
311
+ if self.curr_step < self.learning_starts:
312
+ return None, None
313
+
314
+ if self.curr_step % self.training_frequency != 0:
315
+ return None, None
316
+
317
+ # Sample from memory
318
+ state, next_state, action, reward, done = self.recall()
319
+
320
+ # Get TD Estimate
321
+ td_est = self.td_estimate(state, action)
322
+
323
+ # Get TD Target
324
+ td_tgt = self.td_target(reward, next_state, done)
325
+
326
+ # Backpropagate loss through Q_online
327
+ loss = self.update_Q_online(td_est, td_tgt)
328
+
329
+ return (td_est.mean().item(), loss)
330
+
331
+
332
+ def save(self):
333
+ save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
334
+ torch.save(
335
+ dict(
336
+ model=self.net.state_dict(),
337
+ exploration_rate=self.exploration_rate,
338
+ replay_memory=self.memory
339
+ ),
340
+ save_path
341
+ )
342
+
343
+ print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
344
+
345
+
346
+ def load(self, load_path, reset_exploration_rate, load_replay_buffer):
347
+ if not load_path.exists():
348
+ raise ValueError(f"{load_path} does not exist")
349
+
350
+ ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
351
+ exploration_rate = ckp.get('exploration_rate')
352
+ state_dict = ckp.get('model')
353
+
354
+
355
+ print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
356
+ self.net.load_state_dict(state_dict)
357
+
358
+ if load_replay_buffer:
359
+ replay_memory = ckp.get('replay_memory')
360
+ print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
361
+ self.memory = replay_memory if replay_memory else self.memory
362
+
363
+ if reset_exploration_rate:
364
+ print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
365
+ else:
366
+ print(f"Setting exploration rate to {exploration_rate} not loaded.")
367
+ self.exploration_rate = exploration_rate
368
+
369
+
370
+ class DDQNAgent(DQNAgent):
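+ # Double DQN: the online network selects the greedy next action and the target network
+ # evaluates it, which reduces the Q-value overestimation of vanilla DQN's max operator.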
371
+ @torch.no_grad()
372
+ def td_target(self, rewards, next_states, dones):
373
+ rewards = rewards.reshape(-1, 1)
374
+ dones = dones.reshape(-1, 1)
375
+ q_vals = self.net(next_states, model='online')
376
+ target_actions = torch.argmax(q_vals, axis=1)
377
+ target_actions = target_actions.reshape(-1, 1)
378
+
379
+ target_qs = self.net(next_states, model='target')
380
+ target_qs = target_qs.gather(1, target_actions)
381
+ target_qs = target_qs.reshape(-1, 1)
382
+ target_qs[dones] = 0.0
383
+ return (rewards + (self.gamma * target_qs))
384
+
385
+
386
+ class DuelingDQNet(nn.Module):
387
+ def __init__(self, input_dim, output_dim):
388
+ super().__init__()
389
+ print("#################################")
390
+ print("#################################")
391
+ print(input_dim)
392
+ print(output_dim)
393
+ print("#################################")
394
+ print("#################################")
395
+ c, h, w = input_dim
396
+
397
+
398
+ self.conv_layer = nn.Sequential(
399
+ nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
400
+ nn.ReLU(),
401
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
402
+ nn.ReLU(),
403
+ nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
404
+ nn.ReLU(),
405
+
406
+ )
407
+
408
+
409
+ self.value_layer = nn.Sequential(
410
+ nn.Linear(7168, 128),
411
+ nn.ReLU(),
412
+ nn.Linear(128, 1)
413
+ )
414
+
415
+ self.advantage_layer = nn.Sequential(
416
+ nn.Linear(7168, 128),
417
+ nn.ReLU(),
418
+ nn.Linear(128, output_dim)
419
+ )
420
+
421
+ def forward(self, state):
422
+ conv_output = self.conv_layer(state)
423
+ conv_output = conv_output.view(conv_output.size(0), -1)
424
+ value = self.value_layer(conv_output)
425
+ advantage = self.advantage_layer(conv_output)
426
+ # Dueling aggregation: Q(s,a) = V(s) + (A(s,a) - mean_a A(s,a)); average over the action dimension only.
+ q_value = value + (advantage - advantage.mean(dim=1, keepdim=True))
427
+
428
+ return q_value
429
+
430
+
431
+ class DuelingDQNAgent:
432
+ def __init__(self,
433
+ state_dim,
434
+ action_dim,
435
+ save_dir,
436
+ checkpoint=None,
437
+ learning_rate=0.00025,
438
+ max_memory_size=100000,
439
+ batch_size=32,
440
+ exploration_rate=1,
441
+ exploration_rate_decay=0.9999999,
442
+ exploration_rate_min=0.1,
443
+ training_frequency=1,
444
+ learning_starts=1000,
445
+ target_network_sync_frequency=500,
446
+ reset_exploration_rate=False,
447
+ save_frequency=100000,
448
+ gamma=0.9,
449
+ load_replay_buffer=True):
450
+ self.state_dim = state_dim
451
+ self.action_dim = action_dim
452
+ self.max_memory_size = max_memory_size
453
+ self.memory = deque(maxlen=max_memory_size)
454
+ self.batch_size = batch_size
455
+
456
+ self.exploration_rate = exploration_rate
457
+ self.exploration_rate_decay = exploration_rate_decay
458
+ self.exploration_rate_min = exploration_rate_min
459
+ self.gamma = gamma
460
+
461
+ self.curr_step = 0
462
+ self.learning_starts = learning_starts # min. experiences before training
463
+
464
+ self.training_frequency = training_frequency # no. of experiences between updates to Q_online
465
+ self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
466
+
467
+ self.save_every = save_frequency # no. of experiences between saving the network
468
+ self.save_dir = save_dir
469
+
470
+ self.use_cuda = torch.cuda.is_available()
471
+
472
+
473
+ self.online_net = DuelingDQNet(self.state_dim, self.action_dim).float()
474
+ self.target_net = copy.deepcopy(self.online_net)
475
+ # Q_target parameters are frozen.
476
+ for p in self.target_net.parameters():
477
+ p.requires_grad = False
478
+
479
+ if self.use_cuda:
480
+ self.online_net = self.online_net.to(device='cuda')
481
+ self.target_net = self.target_net.to(device='cuda')
482
+ if checkpoint:
483
+ self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
484
+
485
+ self.optimizer = torch.optim.AdamW(self.online_net.parameters(), lr=learning_rate, amsgrad=True)
486
+ self.loss_fn = torch.nn.SmoothL1Loss()
487
+
488
+
489
+ def act(self, state):
490
+ """
491
+ Given a state, choose an epsilon-greedy action and update value of step.
492
+
493
+ Inputs:
494
+ state(LazyFrame): A single observation of the current state, dimension is (state_dim)
495
+ Outputs:
496
+ action_idx (int): An integer representing which action the agent will perform
497
+ """
498
+ # EXPLORE
499
+ if np.random.rand() < self.exploration_rate:
500
+ action_idx = np.random.randint(self.action_dim)
501
+
502
+ # EXPLOIT
503
+ else:
504
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
505
+ state = state.unsqueeze(0)
506
+ action_values = self.online_net(state)
507
+ action_idx = torch.argmax(action_values, axis=1).item()
508
+
509
+ # decrease exploration_rate
510
+ self.exploration_rate *= self.exploration_rate_decay
511
+ self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
512
+
513
+ # increment step
514
+ self.curr_step += 1
515
+ return action_idx
516
+
517
+ def cache(self, state, next_state, action, reward, done):
518
+ """
519
+ Store the experience to self.memory (replay buffer)
520
+
521
+ Inputs:
522
+ state (LazyFrame),
523
+ next_state (LazyFrame),
524
+ action (int),
525
+ reward (float),
526
+ done(bool))
527
+ """
528
+ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
529
+ next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
530
+ action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
531
+ reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
532
+ done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
533
+
534
+ self.memory.append( (state, next_state, action, reward, done,) )
535
+
536
+
537
+ def recall(self):
538
+ """
539
+ Retrieve a batch of experiences from memory
540
+ """
541
+ batch = random.sample(self.memory, self.batch_size)
542
+ state, next_state, action, reward, done = map(torch.stack, zip(*batch))
543
+ return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
544
+
545
+
546
+ def td_estimate(self, states, actions):
547
+ actions = actions.reshape(-1, 1)
548
+ predicted_qs = self.online_net(states)# Q_online(s,a)
549
+ predicted_qs = predicted_qs.gather(1, actions)
550
+ return predicted_qs
551
+
552
+
553
+ @torch.no_grad()
554
+ def td_target(self, rewards, next_states, dones):
555
+ rewards = rewards.reshape(-1, 1)
556
+ dones = dones.reshape(-1, 1)
557
+ target_qs = self.target_net.forward(next_states)
558
+ target_qs = torch.max(target_qs, dim=1).values
559
+ target_qs = target_qs.reshape(-1, 1)
560
+ target_qs[dones] = 0.0
561
+ return (rewards + (self.gamma * target_qs))
562
+
563
+ def update_Q_online(self, td_estimate, td_target) :
564
+ loss = self.loss_fn(td_estimate, td_target)
565
+ self.optimizer.zero_grad()
566
+ loss.backward()
567
+ self.optimizer.step()
568
+ return loss.item()
569
+
570
+
571
+ def sync_Q_target(self):
572
+ self.target_net.load_state_dict(self.online_net.state_dict())
573
+
574
+
575
+ def learn(self):
576
+ if self.curr_step % self.target_network_sync_frequency == 0:
577
+ self.sync_Q_target()
578
+
579
+ if self.curr_step % self.save_every == 0:
580
+ self.save()
581
+
582
+ if self.curr_step < self.learning_starts:
583
+ return None, None
584
+
585
+ if self.curr_step % self.training_frequency != 0:
586
+ return None, None
587
+
588
+ # Sample from memory
589
+ state, next_state, action, reward, done = self.recall()
590
+
591
+ # Get TD Estimate
592
+ td_est = self.td_estimate(state, action)
593
+
594
+ # Get TD Target
595
+ td_tgt = self.td_target(reward, next_state, done)
596
+
597
+ # Backpropagate loss through Q_online
598
+ loss = self.update_Q_online(td_est, td_tgt)
599
+
600
+ return (td_est.mean().item(), loss)
601
+
602
+
603
+ def save(self):
604
+ save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
605
+ torch.save(
606
+ dict(
607
+ model=self.online_net.state_dict(),
608
+ exploration_rate=self.exploration_rate,
609
+ replay_memory=self.memory
610
+ ),
611
+ save_path
612
+ )
613
+
614
+ print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
615
+
616
+
617
+ def load(self, load_path, reset_exploration_rate, load_replay_buffer):
618
+ if not load_path.exists():
619
+ raise ValueError(f"{load_path} does not exist")
620
+
621
+ ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
622
+ exploration_rate = ckp.get('exploration_rate')
623
+ state_dict = ckp.get('model')
624
+
625
+
626
+ print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
627
+ self.online_net.load_state_dict(state_dict)
628
+ self.target_net = copy.deepcopy(self.online_net)
629
+ self.sync_Q_target()
630
+
631
+ if load_replay_buffer:
632
+ replay_memory = ckp.get('replay_memory')
633
+ print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
634
+ self.memory = replay_memory if replay_memory else self.memory
635
+
636
+ if reset_exploration_rate:
637
+ print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
638
+ else:
639
+ print(f"Setting exploration rate to {exploration_rate} not loaded.")
640
+ self.exploration_rate = exploration_rate
641
+
642
+
643
+
644
+
645
+ class DuelingDDQNAgent(DuelingDQNAgent):
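+ # Dueling + Double DQN: keeps the dueling value/advantage network but overrides td_target
+ # so the online network chooses the next action and the target network scores it.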
646
+ @torch.no_grad()
647
+ def td_target(self, rewards, next_states, dones):
648
+ rewards = rewards.reshape(-1, 1)
649
+ dones = dones.reshape(-1, 1)
650
+ q_vals = self.online_net.forward(next_states)
651
+ target_actions = torch.argmax(q_vals, axis=1)
652
+ target_actions = target_actions.reshape(-1, 1)
653
+
654
+ target_qs = self.target_net.forward(next_states)
655
+ target_qs = target_qs.gather(1, target_actions)
656
+ target_qs = target_qs.reshape(-1, 1)
657
+ target_qs[dones] = 0.0
658
+ return (rewards + (self.gamma * target_qs))
659
+
660
+
661
+
662
+
663
+
664
+
src/procgen/run-starpilot-ddqn.py ADDED
@@ -0,0 +1,45 @@
1
+ import os
2
+ import torch
3
+ from pathlib import Path
4
+
5
+ from agent import DDQNAgent, MetricLogger
6
+ from wrappers import make_starpilot
7
+ import os
8
+ from train import train, fill_memory
9
+
10
+
11
+ env = make_starpilot()
12
+
13
+ use_cuda = torch.cuda.is_available()
14
+ print(f"Using CUDA: {use_cuda}\n")
15
+
16
+ checkpoint = None
17
+ # checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
18
+
19
+ path = "checkpoints/procgen-starpilot-ddqn"
20
+ save_dir = Path(path)
21
+
22
+ isExist = os.path.exists(path)
23
+ if not isExist:
24
+ os.makedirs(path)
25
+
26
+ logger = MetricLogger(save_dir)
27
+
28
+ print("Training DDQN Agent!")
29
+ agent = DDQNAgent(
30
+ state_dim=(1, 64, 64),
31
+ action_dim=env.action_space.n,
32
+ save_dir=save_dir,
33
+ batch_size=256,
34
+ checkpoint=checkpoint,
35
+ exploration_rate_decay=0.999995,
36
+ exploration_rate_min=0.05,
37
+ training_frequency=1,
38
+ target_network_sync_frequency=200,
39
+ max_memory_size=50000,
40
+ learning_rate=0.0005,
41
+
42
+ )
43
+
44
+ fill_memory(agent, env, 300)
45
+ train(agent, env, logger)
src/procgen/run-starpilot-dqn.py ADDED
@@ -0,0 +1,45 @@
1
+ import os
2
+ import torch
3
+ from pathlib import Path
4
+
5
+ from agent import DQNAgent, MetricLogger
6
+ from wrappers import make_starpilot
7
+ import os
8
+ from train import train, fill_memory
9
+
10
+
11
+ env = make_starpilot()
12
+
13
+ use_cuda = torch.cuda.is_available()
14
+ print(f"Using CUDA: {use_cuda}\n")
15
+
16
+ checkpoint = None
17
+ # checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
18
+
19
+ path = "checkpoints/procgen-starpilot-dqn"
20
+ save_dir = Path(path)
21
+
22
+ isExist = os.path.exists(path)
23
+ if not isExist:
24
+ os.makedirs(path)
25
+
26
+ logger = MetricLogger(save_dir)
27
+
28
+ print("Training Vanilla DQN Agent!")
29
+ agent = DQNAgent(
30
+ state_dim=(1, 64, 64),
31
+ action_dim=env.action_space.n,
32
+ save_dir=save_dir,
33
+ batch_size=256,
34
+ checkpoint=checkpoint,
35
+ exploration_rate_decay=0.999995,
36
+ exploration_rate_min=0.05,
37
+ training_frequency=1,
38
+ target_network_sync_frequency=200,
39
+ max_memory_size=50000,
40
+ learning_rate=0.0005,
41
+
42
+ )
43
+
44
+ fill_memory(agent, env, 300)
45
+ train(agent, env, logger)
src/procgen/run-starpilot-dueling-ddqn.py ADDED
@@ -0,0 +1,45 @@
1
+ import os
2
+ import torch
3
+ from pathlib import Path
4
+
5
+ from agent import DuelingDDQNAgent, MetricLogger
6
+ from wrappers import make_starpilot
7
+ import os
8
+ from train import train, fill_memory
9
+
10
+
11
+ env = make_starpilot()
12
+
13
+ use_cuda = torch.cuda.is_available()
14
+ print(f"Using CUDA: {use_cuda}\n")
15
+
16
+ checkpoint = None
17
+ # checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
18
+
19
+ path = "checkpoints/procgen-starpilot-dueling-ddqn"
20
+ save_dir = Path(path)
21
+
22
+ isExist = os.path.exists(path)
23
+ if not isExist:
24
+ os.makedirs(path)
25
+
26
+ logger = MetricLogger(save_dir)
27
+
28
+ print("Training Dueling Double DQN Agent!")
29
+ agent = DuelingDDQNAgent(
30
+ state_dim=(1, 64, 64),
31
+ action_dim=env.action_space.n,
32
+ save_dir=save_dir,
33
+ batch_size=256,
34
+ checkpoint=checkpoint,
35
+ exploration_rate_decay=0.999995,
36
+ exploration_rate_min=0.05,
37
+ training_frequency=1,
38
+ target_network_sync_frequency=200,
39
+ max_memory_size=50000,
40
+ learning_rate=0.0005,
41
+
42
+ )
43
+
44
+ # fill_memory(agent, env, 300)
45
+ train(agent, env, logger)
src/procgen/run-starpilot-dueling-dqn.py ADDED
@@ -0,0 +1,45 @@
1
+ import os
2
+ import torch
3
+ from pathlib import Path
4
+
5
+ from agent import DuelingDQNAgent, MetricLogger
6
+ from wrappers import make_starpilot
7
+ import os
8
+ from train import train, fill_memory
9
+
10
+
11
+ env = make_starpilot()
12
+
13
+ use_cuda = torch.cuda.is_available()
14
+ print(f"Using CUDA: {use_cuda}\n")
15
+
16
+ checkpoint = None
17
+ # checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
18
+
19
+ path = "checkpoints/procgen-starpilot-dueling-dqn"
20
+ save_dir = Path(path)
21
+
22
+ isExist = os.path.exists(path)
23
+ if not isExist:
24
+ os.makedirs(path)
25
+
26
+ logger = MetricLogger(save_dir)
27
+
28
+ print("Training Dueling DQN Agent!")
29
+ agent = DuelingDQNAgent(
30
+ state_dim=(1, 64, 64),
31
+ action_dim=env.action_space.n,
32
+ save_dir=save_dir,
33
+ batch_size=256,
34
+ checkpoint=checkpoint,
35
+ exploration_rate_decay=0.999995,
36
+ exploration_rate_min=0.05,
37
+ training_frequency=1,
38
+ target_network_sync_frequency=200,
39
+ max_memory_size=50000,
40
+ learning_rate=0.0005,
41
+
42
+ )
43
+
44
+ # fill_memory(agent, env, 300)
45
+ train(agent, env, logger)
src/procgen/test-procgen.py ADDED
@@ -0,0 +1,12 @@
1
+ import gym
2
+ env = gym.make("procgen:procgen-starpilot-v0")
3
+
4
+ obs = env.reset()
5
+ step = 0
6
+ while True:
7
+ obs, rew, done, info = env.step(env.action_space.sample())
8
+ print(info)
9
+ print(f"step {step} reward {rew} done {done}")
10
+ step += 1
11
+ if done:
12
+ break
src/procgen/train.py ADDED
@@ -0,0 +1,48 @@
1
+ from tqdm import trange
2
+
3
+ def fill_memory(agent, env, num_episodes=500):
4
+ print("Filling up memory....")
5
+ for _ in trange(num_episodes):
6
+ state = env.reset()
7
+ done = False
8
+ while not done:
9
+ action = agent.act(state)
10
+ next_state, reward, done, _ = env.step(action)
11
+ agent.cache(state, next_state, action, reward, done)
12
+ state = next_state
13
+
14
+
15
+ def train(agent, env, logger):
16
+ episodes = 5000
17
+ for e in range(episodes):
18
+
19
+ state = env.reset()
20
+ # Play the game!
21
+ while True:
22
+
23
+ # Run agent on the state
24
+ action = agent.act(state)
25
+
26
+ # Agent performs action
27
+ next_state, reward, done, info = env.step(action)
28
+
29
+ # Remember
30
+ agent.cache(state, next_state, action, reward, done)
31
+
32
+ # Learn
33
+ q, loss = agent.learn()
34
+
35
+ # Logging
36
+ logger.log_step(reward, loss, q)
37
+
38
+ # Update state
39
+ state = next_state
40
+
41
+ # Check if end of game
42
+ if done:
43
+ break
44
+
45
+ logger.log_episode(e)
46
+
47
+ if e % 20 == 0:
48
+ logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
src/procgen/wrappers.py ADDED
@@ -0,0 +1,187 @@
1
+ import numpy as np
2
+ import os
3
+ from collections import deque
4
+ import gym
5
+ from gym import spaces
6
+ import cv2
7
+
8
+
9
+ '''
10
+ Atari Wrapper copied from https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
11
+ '''
12
+
13
+
14
+ class LazyFrames(object):
15
+ def __init__(self, frames):
16
+ """This object ensures that common frames between the observations are only stored once.
17
+ It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
18
+ buffers.
19
+ This object should only be converted to numpy array before being passed to the model.
20
+ You'd not believe how complex the previous solution was."""
21
+ self._frames = frames
22
+ self._out = None
23
+
24
+ def _force(self):
25
+ if self._out is None:
26
+ self._out = np.concatenate(self._frames, axis=2)
27
+ self._frames = None
28
+ return self._out
29
+
30
+ def __array__(self, dtype=None):
31
+ out = self._force()
32
+ if dtype is not None:
33
+ out = out.astype(dtype)
34
+ return out
35
+
36
+ def __len__(self):
37
+ return len(self._force())
38
+
39
+ def __getitem__(self, i):
40
+ return self._force()[i]
41
+
42
+ class FireResetEnv(gym.Wrapper):
43
+ def __init__(self, env):
44
+ """Take action on reset for environments that are fixed until firing."""
45
+ gym.Wrapper.__init__(self, env)
46
+ assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
47
+ assert len(env.unwrapped.get_action_meanings()) >= 3
48
+
49
+ def reset(self, **kwargs):
50
+ self.env.reset(**kwargs)
51
+ obs, _, done, _ = self.env.step(1)
52
+ if done:
53
+ self.env.reset(**kwargs)
54
+ obs, _, done, _ = self.env.step(2)
55
+ if done:
56
+ self.env.reset(**kwargs)
57
+ return obs
58
+
59
+ def step(self, ac):
60
+ return self.env.step(ac)
61
+
62
+
63
+ class MaxAndSkipEnv(gym.Wrapper):
64
+ def __init__(self, env, skip=4):
65
+ """Return only every `skip`-th frame"""
66
+ gym.Wrapper.__init__(self, env)
67
+ # most recent raw observations (for max pooling across time steps)
68
+ self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
69
+ self._skip = skip
70
+
71
+ def step(self, action):
72
+ """Repeat action, sum reward, and max over last observations."""
73
+ total_reward = 0.0
74
+ done = None
75
+ for i in range(self._skip):
76
+ obs, reward, done, info = self.env.step(action)
77
+ if i == self._skip - 2: self._obs_buffer[0] = obs
78
+ if i == self._skip - 1: self._obs_buffer[1] = obs
79
+ total_reward += reward
80
+ if done:
81
+ break
82
+ # Note that the observation on the done=True frame
83
+ # doesn't matter
84
+ max_frame = self._obs_buffer.max(axis=0)
85
+
86
+ return max_frame, total_reward, done, info
87
+
88
+ def reset(self, **kwargs):
89
+ return self.env.reset(**kwargs)
90
+
91
+
92
+
93
+ class WarpFrame(gym.ObservationWrapper):
94
+ def __init__(self, env):
95
+ """Warp frames to 84x84 as done in the Nature paper and later work."""
96
+ gym.ObservationWrapper.__init__(self, env)
97
+ self.width = 84
98
+ self.height = 84
99
+ self.observation_space = spaces.Box(low=0, high=255,
100
+ shape=(self.height, self.width, 1), dtype=np.uint8)
101
+
102
+ def observation(self, frame):
103
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
104
+ frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
105
+ return frame[:, :, None]
106
+
107
+ class WarpFrameNoResize(gym.ObservationWrapper):
108
+ def __init__(self, env):
109
+ """Warp frames to 84x84 as done in the Nature paper and later work."""
110
+ gym.ObservationWrapper.__init__(self, env)
111
+
112
+ def observation(self, frame):
113
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
114
+ # frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
115
+ return frame[:, :, None]
116
+
117
+
118
+
119
+ class FrameStack(gym.Wrapper):
120
+ def __init__(self, env, k):
121
+ """Stack k last frames.
122
+ Returns lazy array, which is much more memory efficient.
123
+ See Also
124
+ --------
125
+ baselines.common.atari_wrappers.LazyFrames
126
+ """
127
+ gym.Wrapper.__init__(self, env)
128
+ self.k = k
129
+ self.frames = deque([], maxlen=k)
130
+ shp = env.observation_space.shape
131
+ self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)
132
+
133
+ def reset(self):
134
+ ob = self.env.reset()
135
+ for _ in range(self.k):
136
+ self.frames.append(ob)
137
+ return self._get_ob()
138
+
139
+ def step(self, action):
140
+ ob, reward, done, info = self.env.step(action)
141
+ self.frames.append(ob)
142
+ return self._get_ob(), reward, done, info
143
+
144
+ def _get_ob(self):
145
+ assert len(self.frames) == self.k
146
+ return LazyFrames(list(self.frames))
147
+
148
+
149
+ class ImageToPyTorch(gym.ObservationWrapper):
150
+ def __init__(self, env):
151
+ super(ImageToPyTorch, self).__init__(env)
152
+ old_shape = self.observation_space.shape
153
+ self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32)
154
+
155
+ def observation(self, observation):
156
+ return np.moveaxis(observation, 2, 0)
157
+
158
+
159
+ class ScaledFloatFrame(gym.ObservationWrapper):
160
+ def __init__(self, env):
161
+ gym.ObservationWrapper.__init__(self, env)
162
+ self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)
163
+
164
+ def observation(self, observation):
165
+ # careful! This undoes the memory optimization, use
166
+ # with smaller replay buffers only.
167
+ return np.array(observation).astype(np.float32) / 255.0
168
+
169
+ class ClipRewardEnv(gym.RewardWrapper):
170
+ def __init__(self, env):
171
+ gym.RewardWrapper.__init__(self, env)
172
+
173
+ def reward(self, reward):
174
+ """Bin reward to {+1, 0, -1} by its sign."""
175
+ return np.sign(reward)
176
+
177
+
178
+ def make_starpilot(render=False):
179
+ print("Environment: Starpilot")
180
+ if render:
181
+ env = gym.make("procgen:procgen-starpilot-v0", distribution_mode="easy", render_mode="human")
182
+ else:
183
+ env = gym.make("procgen:procgen-starpilot-v0", distribution_mode="easy")
184
+ env = WarpFrameNoResize(env) ## Reshape image
185
+ env = ImageToPyTorch(env) ## Invert shape
186
+ env = FrameStack(env, 4) ## Stack last 4 frames
187
+ return env
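+ 
+ # A minimal usage sketch (assumption: the procgen package is installed alongside gym==0.21.0):
+ #   env = make_starpilot()
+ #   state = env.reset()
+ #   next_state, reward, done, info = env.step(env.action_space.sample())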
troubleshooting.md ADDED
@@ -0,0 +1,37 @@
1
+ # ml-reinforcement-learning
2
+
3
+ Python version: 3.7.3
4
+
5
+
6
+ ## Troubleshooting
7
+
8
+
9
+ - RuntimeError: Polyfit sanity test emitted a warning, most likely due to using a buggy Accelerate backend. If you compiled yourself, more information is available at https://numpy.org/doc/stable/user/building.html#accelerated-blas-lapack-libraries Otherwise report this to the vendor that provided NumPy.
10
+ RankWarning: Polyfit may be poorly conditioned
11
+
12
+ ```
13
+ $ pip uninstall numpy
14
+ $ export OPENBLAS=$(brew --prefix openblas)
15
+ $ pip install --no-cache-dir numpy
16
+ ```
17
+
18
+
19
+ During grpcio installation 👇
20
+ distutils.errors.CompileError: command 'clang' failed with exit status 1
21
+ ```
22
+ CFLAGS="-I/Library/Developer/CommandLineTools/usr/include/c++/v1 -I/opt/homebrew/opt/openssl/include" LDFLAGS="-L/opt/homebrew/opt/openssl/lib" pip3 install grpcio
23
+ ```
24
+
25
+
26
+ ModuleNotFoundError: No module named 'gym.envs.classic_control.rendering'
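+ 
+ A likely cause is a gym release newer than 0.21, where this rendering module was removed (this is an assumption about the installed version). Reinstalling the version pinned in the setup below restores it:
+ 
+ ```
+ pip install "gym[atari]==0.21.0"
+ ```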
27
+
28
+
29
+ # Setup
30
+
31
+ ```
32
+ conda install pytorch torchvision -c pytorch
33
+ pip install gym-retro
34
+ conda install numpy
35
+ pip install "gym[atari]==0.21.0"
36
+ pip install importlib-metadata==4.13.0
37
+ ```