merve HF staff committed on
Commit
382424a
1 Parent(s): 10b396f

Upload ddpg_pendulum.ipynb

Files changed (1)
  1. ddpg_pendulum.ipynb +1077 -0
ddpg_pendulum.ipynb ADDED
@@ -0,0 +1,1077 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "W5ut4-Uo_wFL"
7
+ },
8
+ "source": [
9
+ "# Deep Deterministic Policy Gradient (DDPG)\n",
10
+ "\n",
11
+ "**Author:** [amifunny](https://github.com/amifunny)<br>\n",
12
+ "**Date created:** 2020/06/04<br>\n",
13
+ "**Last modified:** 2020/09/21<br>\n",
14
+ "**Description:** Implementing DDPG algorithm on the Inverted Pendulum Problem."
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "markdown",
19
+ "metadata": {
20
+ "id": "1eX-gAYp_wFP"
21
+ },
22
+ "source": [
23
+ "## Introduction\n",
24
+ "\n",
25
+ "**Deep Deterministic Policy Gradient (DDPG)** is a model-free off-policy algorithm for\n",
26
+ "learning continuous actions.\n",
27
+ "\n",
28
+ "It combines ideas from DPG (Deterministic Policy Gradient) and DQN (Deep Q-Network).\n",
29
+ "It uses Experience Replay and slow-learning target networks from DQN, and it is based on\n",
30
+ "DPG,\n",
31
+ "which can operate over continuous action spaces.\n",
32
+ "\n",
33
+ "This tutorial closely follows this paper:\n",
34
+ "[Continuous control with deep reinforcement learning](https://arxiv.org/pdf/1509.02971.pdf)\n",
35
+ "\n",
36
+ "## Problem\n",
37
+ "\n",
38
+ "We are trying to solve the classic **Inverted Pendulum** control problem.\n",
39
+ "In this setting, we can act in only two directions: swing left or swing right.\n",
40
+ "\n",
41
+ "What makes this problem challenging for Q-learning algorithms is that the actions\n",
42
+ "are **continuous** instead of being **discrete**. That is, instead of using two\n",
43
+ "discrete actions like `-1` or `+1`, we have to select from infinite actions\n",
44
+ "ranging from `-2` to `+2`.\n",
45
+ "\n",
46
+ "## Quick theory\n",
47
+ "\n",
48
+ "Just like the Actor-Critic method, we have two networks:\n",
49
+ "\n",
50
+ "1. Actor - It proposes an action given a state.\n",
51
+ "2. Critic - It predicts if the action is good (positive value) or bad (negative value)\n",
52
+ "given a state and an action.\n",
53
+ "\n",
54
+ "DDPG uses two more techniques not present in the original DQN:\n",
55
+ "\n",
56
+ "**First, it uses two Target networks.**\n",
57
+ "\n",
58
+ "**Why?** Because it adds stability to training. In short, we are learning from estimated\n",
59
+ "targets, and the Target networks are updated slowly, which keeps our estimated targets\n",
60
+ "stable.\n",
61
+ "\n",
62
+ "Conceptually, this is like saying, \"I have an idea of how to play this well,\n",
63
+ "I'm going to try it out for a bit until I find something better\",\n",
64
+ "as opposed to saying \"I'm going to re-learn how to play this entire game after every\n",
65
+ "move\".\n",
66
+ "See this [StackOverflow answer](https://stackoverflow.com/a/54238556/13475679).\n",
67
+ "\n",
68
+ "**Second, it uses Experience Replay.**\n",
69
+ "\n",
70
+ "We store a list of tuples `(state, action, reward, next_state)`, and instead of\n",
71
+ "learning only from recent experience, we learn from sampling all of our experience\n",
72
+ "accumulated so far.\n",
73
+ "\n",
74
+ "Now, let's see how it is implemented."
75
+ ]
76
+ },
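+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Before diving in, here is a minimal NumPy sketch of the soft target update described above (illustrative only; `tau` and the weight shapes are assumptions, not part of this tutorial's code):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "tau = 0.005  # assumed small update rate\n",
+ "online_w = np.random.randn(4)  # stand-in for an online-network weight\n",
+ "target_w = np.zeros(4)  # stand-in for the matching target weight\n",
+ "\n",
+ "# The target tracks the online weights slowly, keeping learning targets stable.\n",
+ "target_w = tau * online_w + (1 - tau) * target_w"
+ ]
+ },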
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": 1,
80
+ "metadata": {
81
+ "id": "EhtEA5C1_wFR"
82
+ },
83
+ "outputs": [],
84
+ "source": [
85
+ "import gym\n",
86
+ "import tensorflow as tf\n",
87
+ "from tensorflow.keras import layers\n",
88
+ "import numpy as np\n",
89
+ "import matplotlib.pyplot as plt"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "markdown",
94
+ "metadata": {
95
+ "id": "vvhqTnJ8_wFT"
96
+ },
97
+ "source": [
98
+ "We use [OpenAI Gym](http://gym.openai.com/docs) to create the environment.\n",
99
+ "We will use the `upper_bound` parameter to scale our actions later."
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": 2,
105
+ "metadata": {
106
+ "id": "6limWVE-_wFU",
107
+ "outputId": "8d672186-664b-40c5-ce82-e3450bf42221",
108
+ "colab": {
109
+ "base_uri": "https://localhost:8080/"
110
+ }
111
+ },
112
+ "outputs": [
113
+ {
114
+ "output_type": "stream",
115
+ "name": "stdout",
116
+ "text": [
117
+ "Size of State Space -> 3\n",
118
+ "Size of Action Space -> 1\n",
119
+ "Max Value of Action -> 2.0\n",
120
+ "Min Value of Action -> -2.0\n"
121
+ ]
122
+ }
123
+ ],
124
+ "source": [
125
+ "problem = \"Pendulum-v0\"\n",
126
+ "env = gym.make(problem)\n",
127
+ "\n",
128
+ "num_states = env.observation_space.shape[0]\n",
129
+ "print(\"Size of State Space -> {}\".format(num_states))\n",
130
+ "num_actions = env.action_space.shape[0]\n",
131
+ "print(\"Size of Action Space -> {}\".format(num_actions))\n",
132
+ "\n",
133
+ "upper_bound = env.action_space.high[0]\n",
134
+ "lower_bound = env.action_space.low[0]\n",
135
+ "\n",
136
+ "print(\"Max Value of Action -> {}\".format(upper_bound))\n",
137
+ "print(\"Min Value of Action -> {}\".format(lower_bound))"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "markdown",
142
+ "metadata": {
143
+ "id": "SxQKZi35_wFU"
144
+ },
145
+ "source": [
146
+ "To implement better exploration by the Actor network, we use noisy perturbations,\n",
147
+ "specifically\n",
148
+ "an **Ornstein-Uhlenbeck process** for generating noise, as described in the paper.\n",
149
+ "It samples noise from a correlated normal distribution."
150
+ ]
151
+ },
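+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For reference, the discretized (Euler-Maruyama) step implemented in the class below is\n",
+ "\n",
+ "$x_{t+1} = x_t + \\theta\\,(\\mu - x_t)\\,\\Delta t + \\sigma \\sqrt{\\Delta t}\\;\\varepsilon, \\quad \\varepsilon \\sim \\mathcal{N}(0, I),$\n",
+ "\n",
+ "where $\\mu$ is `mean`, $\\sigma$ is `std_dev`, and $\\Delta t$ is `dt`."
+ ]
+ },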
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": 3,
155
+ "metadata": {
156
+ "id": "0u9tVI2J_wFV"
157
+ },
158
+ "outputs": [],
159
+ "source": [
160
+ "\n",
161
+ "class OUActionNoise:\n",
162
+ " def __init__(self, mean, std_deviation, theta=0.15, dt=1e-2, x_initial=None):\n",
163
+ " self.theta = theta\n",
164
+ " self.mean = mean\n",
165
+ " self.std_dev = std_deviation\n",
166
+ " self.dt = dt\n",
167
+ " self.x_initial = x_initial\n",
168
+ " self.reset()\n",
169
+ "\n",
170
+ " def __call__(self):\n",
171
+ " # Formula taken from https://www.wikipedia.org/wiki/Ornstein-Uhlenbeck_process.\n",
172
+ " x = (\n",
173
+ " self.x_prev\n",
174
+ " + self.theta * (self.mean - self.x_prev) * self.dt\n",
175
+ " + self.std_dev * np.sqrt(self.dt) * np.random.normal(size=self.mean.shape)\n",
176
+ " )\n",
177
+ " # Store x into x_prev\n",
178
+ " # Makes next noise dependent on current one\n",
179
+ " self.x_prev = x\n",
180
+ " return x\n",
181
+ "\n",
182
+ " def reset(self):\n",
183
+ " if self.x_initial is not None:\n",
184
+ " self.x_prev = self.x_initial\n",
185
+ " else:\n",
186
+ " self.x_prev = np.zeros_like(self.mean)\n"
187
+ ]
188
+ },
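+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A quick sanity check of the noise process (a sketch; the parameters are arbitrary): consecutive samples drift from one another instead of being independent."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "check_noise = OUActionNoise(mean=np.zeros(1), std_deviation=0.2 * np.ones(1))\n",
+ "# Each call uses the previous sample as its starting point.\n",
+ "print([float(check_noise()) for _ in range(5)])"
+ ]
+ },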
189
+ {
190
+ "cell_type": "markdown",
191
+ "metadata": {
192
+ "id": "aiaIXtYc_wFW"
193
+ },
194
+ "source": [
195
+ "The `Buffer` class implements Experience Replay.\n",
196
+ "\n",
197
+ "---\n",
198
+ "![Algorithm](https://i.imgur.com/mS6iGyJ.jpg)\n",
199
+ "---\n",
200
+ "\n",
201
+ "\n",
202
+ "**Critic loss** - Mean Squared Error of `y - Q(s, a)`\n",
203
+ "where `y` is the expected return as seen by the Target network,\n",
204
+ "and `Q(s, a)` is action value predicted by the Critic network. `y` is a moving target\n",
205
+ "that the critic model tries to achieve; we make this target\n",
206
+ "stable by updating the Target model slowly.\n",
207
+ "\n",
208
+ "**Actor loss** - This is computed using the mean of the value given by the Critic network\n",
209
+ "for the actions taken by the Actor network. We seek to maximize this quantity.\n",
210
+ "\n",
211
+ "Hence we update the Actor network so that it produces actions that get\n",
212
+ "the maximum predicted value as seen by the Critic, for a given state."
213
+ ]
214
+ },
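+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In symbols, the two losses implemented in `update()` below are\n",
+ "\n",
+ "$L_{critic} = \\mathbb{E}\\big[(y - Q(s, a))^2\\big], \\qquad y = r + \\gamma \\, Q'(s', \\mu'(s')),$\n",
+ "\n",
+ "$L_{actor} = -\\,\\mathbb{E}\\big[Q(s, \\mu(s))\\big],$\n",
+ "\n",
+ "where $\\mu'$ and $Q'$ denote the target Actor and Critic."
+ ]
+ },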
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": 4,
218
+ "metadata": {
219
+ "id": "HmrqnrR3_wFX"
220
+ },
221
+ "outputs": [],
222
+ "source": [
223
+ "\n",
224
+ "class Buffer:\n",
225
+ " def __init__(self, buffer_capacity=100000, batch_size=64):\n",
226
+ " # Number of \"experiences\" to store at max\n",
227
+ " self.buffer_capacity = buffer_capacity\n",
228
+ " # Num of tuples to train on.\n",
229
+ " self.batch_size = batch_size\n",
230
+ "\n",
231
+ " # It tells us the number of times record() was called.\n",
232
+ " self.buffer_counter = 0\n",
233
+ "\n",
234
+ " # Instead of a list of tuples, as in the classic experience-replay setup,\n",
235
+ " # we use a separate np.array for each tuple element\n",
236
+ " self.state_buffer = np.zeros((self.buffer_capacity, num_states))\n",
237
+ " self.action_buffer = np.zeros((self.buffer_capacity, num_actions))\n",
238
+ " self.reward_buffer = np.zeros((self.buffer_capacity, 1))\n",
239
+ " self.next_state_buffer = np.zeros((self.buffer_capacity, num_states))\n",
240
+ "\n",
241
+ " # Takes an (s, a, r, s') observation tuple as input\n",
242
+ " def record(self, obs_tuple):\n",
243
+ " # Set index to zero if buffer_capacity is exceeded,\n",
244
+ " # replacing old records\n",
245
+ " index = self.buffer_counter % self.buffer_capacity\n",
246
+ "\n",
247
+ " self.state_buffer[index] = obs_tuple[0]\n",
248
+ " self.action_buffer[index] = obs_tuple[1]\n",
249
+ " self.reward_buffer[index] = obs_tuple[2]\n",
250
+ " self.next_state_buffer[index] = obs_tuple[3]\n",
251
+ "\n",
252
+ " self.buffer_counter += 1\n",
253
+ "\n",
254
+ " # Eager execution is turned on by default in TensorFlow 2. Decorating with tf.function allows\n",
255
+ " # TensorFlow to build a static graph out of the logic and computations in our function.\n",
256
+ " # This provides a large speed up for blocks of code that contain many small TensorFlow operations such as this one.\n",
257
+ " @tf.function\n",
258
+ " def update(\n",
259
+ " self, state_batch, action_batch, reward_batch, next_state_batch,\n",
260
+ " ):\n",
261
+ " # Training and updating Actor & Critic networks.\n",
262
+ " # See Pseudo Code.\n",
263
+ " with tf.GradientTape() as tape:\n",
264
+ " target_actions = target_actor(next_state_batch, training=True)\n",
265
+ " y = reward_batch + gamma * target_critic(\n",
266
+ " [next_state_batch, target_actions], training=True\n",
267
+ " )\n",
268
+ " critic_value = critic_model([state_batch, action_batch], training=True)\n",
269
+ " critic_loss = tf.math.reduce_mean(tf.math.square(y - critic_value))\n",
270
+ "\n",
271
+ " critic_grad = tape.gradient(critic_loss, critic_model.trainable_variables)\n",
272
+ " critic_optimizer.apply_gradients(\n",
273
+ " zip(critic_grad, critic_model.trainable_variables)\n",
274
+ " )\n",
275
+ "\n",
276
+ " with tf.GradientTape() as tape:\n",
277
+ " actions = actor_model(state_batch, training=True)\n",
278
+ " critic_value = critic_model([state_batch, actions], training=True)\n",
279
+ " # Used `-value` as we want to maximize the value given\n",
280
+ " # by the critic for our actions\n",
281
+ " actor_loss = -tf.math.reduce_mean(critic_value)\n",
282
+ "\n",
283
+ " actor_grad = tape.gradient(actor_loss, actor_model.trainable_variables)\n",
284
+ " actor_optimizer.apply_gradients(\n",
285
+ " zip(actor_grad, actor_model.trainable_variables)\n",
286
+ " )\n",
287
+ "\n",
288
+ " # We compute the loss and update parameters\n",
289
+ " def learn(self):\n",
290
+ " # Get sampling range\n",
291
+ " record_range = min(self.buffer_counter, self.buffer_capacity)\n",
292
+ " # Randomly sample indices\n",
293
+ " batch_indices = np.random.choice(record_range, self.batch_size)\n",
294
+ "\n",
295
+ " # Convert to tensors\n",
296
+ " state_batch = tf.convert_to_tensor(self.state_buffer[batch_indices])\n",
297
+ " action_batch = tf.convert_to_tensor(self.action_buffer[batch_indices])\n",
298
+ " reward_batch = tf.convert_to_tensor(self.reward_buffer[batch_indices])\n",
299
+ " reward_batch = tf.cast(reward_batch, dtype=tf.float32)\n",
300
+ " next_state_batch = tf.convert_to_tensor(self.next_state_buffer[batch_indices])\n",
301
+ "\n",
302
+ " self.update(state_batch, action_batch, reward_batch, next_state_batch)\n",
303
+ "\n",
304
+ "\n",
305
+ "# This updates the target parameters slowly,\n",
306
+ "# based on rate `tau`, which is much less than one.\n",
307
+ "@tf.function\n",
308
+ "def update_target(target_weights, weights, tau):\n",
309
+ " for (a, b) in zip(target_weights, weights):\n",
310
+ " a.assign(b * tau + a * (1 - tau))\n"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "markdown",
315
+ "metadata": {
316
+ "id": "yuatLEJ3_wFY"
317
+ },
318
+ "source": [
319
+ "Here we define the Actor and Critic networks. These are basic Dense models\n",
320
+ "with `ReLU` activation.\n",
321
+ "\n",
322
+ "Note: We need the initialization for last layer of the Actor to be between\n",
323
+ "`-0.003` and `0.003` as this prevents us from getting `1` or `-1` output values in\n",
324
+ "the initial stages, which would squash our gradients to zero,\n",
325
+ "as we use the `tanh` activation."
326
+ ]
327
+ },
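+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Since `tanh` outputs lie in `[-1, 1]`, multiplying by `upper_bound` maps them onto the valid action range. A one-line sketch (the bound `2.0` is the Pendulum value used below):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "print(np.tanh(0.5) * 2.0)  # lands in [-2, 2], the Pendulum action range"
+ ]
+ },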
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": 5,
331
+ "metadata": {
332
+ "id": "OCCV2VAQ_wFY"
333
+ },
334
+ "outputs": [],
335
+ "source": [
336
+ "\n",
337
+ "def get_actor():\n",
338
+ " # Initialize weights between -3e-3 and 3e-3\n",
339
+ " last_init = tf.random_uniform_initializer(minval=-0.003, maxval=0.003)\n",
340
+ "\n",
341
+ " inputs = layers.Input(shape=(num_states,))\n",
342
+ " out = layers.Dense(256, activation=\"relu\")(inputs)\n",
343
+ " out = layers.Dense(256, activation=\"relu\")(out)\n",
344
+ " outputs = layers.Dense(1, activation=\"tanh\", kernel_initializer=last_init)(out)\n",
345
+ "\n",
346
+ " # Our upper bound is 2.0 for Pendulum.\n",
347
+ " outputs = outputs * upper_bound\n",
348
+ " model = tf.keras.Model(inputs, outputs)\n",
349
+ " return model\n",
350
+ "\n",
351
+ "\n",
352
+ "def get_critic():\n",
353
+ " # State as input\n",
354
+ " state_input = layers.Input(shape=(num_states,))\n",
355
+ " state_out = layers.Dense(16, activation=\"relu\")(state_input)\n",
356
+ " state_out = layers.Dense(32, activation=\"relu\")(state_out)\n",
357
+ "\n",
358
+ " # Action as input\n",
359
+ " action_input = layers.Input(shape=(num_actions,))\n",
360
+ " action_out = layers.Dense(32, activation=\"relu\")(action_input)\n",
361
+ "\n",
362
+ " # Both are passed through separate layers before concatenating\n",
363
+ " concat = layers.Concatenate()([state_out, action_out])\n",
364
+ "\n",
365
+ " out = layers.Dense(256, activation=\"relu\")(concat)\n",
366
+ " out = layers.Dense(256, activation=\"relu\")(out)\n",
367
+ " outputs = layers.Dense(1)(out)\n",
368
+ "\n",
369
+ " # Outputs a single value for a given state-action pair\n",
370
+ " model = tf.keras.Model([state_input, action_input], outputs)\n",
371
+ "\n",
372
+ " return model\n"
373
+ ]
374
+ },
375
+ {
376
+ "cell_type": "markdown",
377
+ "metadata": {
378
+ "id": "gkg29m65_wFZ"
379
+ },
380
+ "source": [
381
+ "`policy()` returns an action sampled from our Actor network plus some noise for\n",
382
+ "exploration."
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": 6,
388
+ "metadata": {
389
+ "id": "KmHbyy8l_wFZ"
390
+ },
391
+ "outputs": [],
392
+ "source": [
393
+ "\n",
394
+ "def policy(state, noise_object):\n",
395
+ " sampled_actions = tf.squeeze(actor_model(state))\n",
396
+ " noise = noise_object()\n",
397
+ " # Adding noise to action\n",
398
+ " sampled_actions = sampled_actions.numpy() + noise\n",
399
+ "\n",
400
+ " # We make sure action is within bounds\n",
401
+ " legal_action = np.clip(sampled_actions, lower_bound, upper_bound)\n",
402
+ "\n",
403
+ " return [np.squeeze(legal_action)]\n"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "markdown",
408
+ "metadata": {
409
+ "id": "r2EUVRZA_wFa"
410
+ },
411
+ "source": [
412
+ "## Training hyperparameters"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": 7,
418
+ "metadata": {
419
+ "id": "8FELtxWr_wFa"
420
+ },
421
+ "outputs": [],
422
+ "source": [
423
+ "std_dev = 0.2\n",
424
+ "ou_noise = OUActionNoise(mean=np.zeros(1), std_deviation=float(std_dev) * np.ones(1))\n",
425
+ "\n",
426
+ "actor_model = get_actor()\n",
427
+ "critic_model = get_critic()\n",
428
+ "\n",
429
+ "target_actor = get_actor()\n",
430
+ "target_critic = get_critic()\n",
431
+ "\n",
432
+ "# Making the weights equal initially\n",
433
+ "target_actor.set_weights(actor_model.get_weights())\n",
434
+ "target_critic.set_weights(critic_model.get_weights())\n",
435
+ "\n",
436
+ "# Learning rate for actor-critic models\n",
437
+ "critic_lr = 0.002\n",
438
+ "actor_lr = 0.001\n",
439
+ "\n",
440
+ "critic_optimizer = tf.keras.optimizers.Adam(critic_lr)\n",
441
+ "actor_optimizer = tf.keras.optimizers.Adam(actor_lr)\n",
442
+ "\n",
443
+ "total_episodes = 100\n",
444
+ "# Discount factor for future rewards\n",
445
+ "gamma = 0.99\n",
446
+ "# Used to update target networks\n",
447
+ "tau = 0.005\n",
448
+ "\n",
449
+ "buffer = Buffer(50000, 64)"
450
+ ]
451
+ },
452
+ {
453
+ "cell_type": "markdown",
454
+ "metadata": {
455
+ "id": "4RDsrs-U_wFa"
456
+ },
457
+ "source": [
458
+ "Now we implement our main training loop, and iterate over episodes.\n",
459
+ "We sample actions using `policy()` and train with `learn()` at each time step,\n",
460
+ "along with updating the Target networks at a rate `tau`."
461
+ ]
462
+ },
463
+ {
464
+ "cell_type": "code",
465
+ "execution_count": 8,
466
+ "metadata": {
467
+ "id": "ytHAvwkZ_wFb",
468
+ "outputId": "3bb12bba-bd57-4e17-d456-f3812dbab54c",
469
+ "colab": {
470
+ "base_uri": "https://localhost:8080/",
471
+ "height": 1000
472
+ }
473
+ },
474
+ "outputs": [
475
+ {
476
+ "output_type": "stream",
477
+ "name": "stdout",
478
+ "text": [
479
+ "Episode * 0 * Avg Reward is ==> -1562.2726156328647\n",
480
+ "Episode * 1 * Avg Reward is ==> -1483.9644029814049\n",
481
+ "Episode * 2 * Avg Reward is ==> -1505.4752484518995\n",
482
+ "Episode * 3 * Avg Reward is ==> -1502.139312848798\n",
483
+ "Episode * 4 * Avg Reward is ==> -1499.098591143217\n",
484
+ "Episode * 5 * Avg Reward is ==> -1498.9744436801384\n",
485
+ "Episode * 6 * Avg Reward is ==> -1513.1644515271164\n",
486
+ "Episode * 7 * Avg Reward is ==> -1468.1716524312988\n",
487
+ "Episode * 8 * Avg Reward is ==> -1427.4514428693592\n",
488
+ "Episode * 9 * Avg Reward is ==> -1386.0628554198065\n",
489
+ "Episode * 10 * Avg Reward is ==> -1330.6460640495982\n",
490
+ "Episode * 11 * Avg Reward is ==> -1296.9259472439123\n",
491
+ "Episode * 12 * Avg Reward is ==> -1266.648683865007\n",
492
+ "Episode * 13 * Avg Reward is ==> -1217.2723090173156\n",
493
+ "Episode * 14 * Avg Reward is ==> -1163.0702894847986\n",
494
+ "Episode * 15 * Avg Reward is ==> -1105.9657062118963\n",
495
+ "Episode * 16 * Avg Reward is ==> -1056.4251588131688\n",
496
+ "Episode * 17 * Avg Reward is ==> -1004.7175789706645\n",
497
+ "Episode * 18 * Avg Reward is ==> -958.4439292235802\n",
498
+ "Episode * 19 * Avg Reward is ==> -916.8559819842148\n",
499
+ "Episode * 20 * Avg Reward is ==> -879.2971938851208\n",
500
+ "Episode * 21 * Avg Reward is ==> -839.4309276948444\n",
501
+ "Episode * 22 * Avg Reward is ==> -813.1273589718702\n",
502
+ "Episode * 23 * Avg Reward is ==> -784.5041398737862\n",
503
+ "Episode * 24 * Avg Reward is ==> -765.0508430770639\n",
504
+ "Episode * 25 * Avg Reward is ==> -740.464676744745\n",
505
+ "Episode * 26 * Avg Reward is ==> -721.947957211692\n",
506
+ "Episode * 27 * Avg Reward is ==> -705.225509729946\n",
507
+ "Episode * 28 * Avg Reward is ==> -685.144228863127\n",
508
+ "Episode * 29 * Avg Reward is ==> -670.6879188788478\n",
509
+ "Episode * 30 * Avg Reward is ==> -653.0154864082411\n",
510
+ "Episode * 31 * Avg Reward is ==> -643.4128610660125\n",
511
+ "Episode * 32 * Avg Reward is ==> -635.5798183939222\n",
512
+ "Episode * 33 * Avg Reward is ==> -623.9639787229108\n",
513
+ "Episode * 34 * Avg Reward is ==> -616.205090622738\n",
514
+ "Episode * 35 * Avg Reward is ==> -606.1140412258295\n",
515
+ "Episode * 36 * Avg Reward is ==> -603.6670876160974\n",
516
+ "Episode * 37 * Avg Reward is ==> -600.9921602699909\n",
517
+ "Episode * 38 * Avg Reward is ==> -591.8512444239832\n",
518
+ "Episode * 39 * Avg Reward is ==> -580.1600306375576\n",
519
+ "Episode * 40 * Avg Reward is ==> -553.821931002297\n",
520
+ "Episode * 41 * Avg Reward is ==> -521.7188034600143\n",
521
+ "Episode * 42 * Avg Reward is ==> -486.36319375176225\n",
522
+ "Episode * 43 * Avg Reward is ==> -453.6710442310697\n",
523
+ "Episode * 44 * Avg Reward is ==> -425.8450281985057\n",
524
+ "Episode * 45 * Avg Reward is ==> -400.74408779723456\n",
525
+ "Episode * 46 * Avg Reward is ==> -366.61738270546164\n",
526
+ "Episode * 47 * Avg Reward is ==> -345.13626004307355\n",
527
+ "Episode * 48 * Avg Reward is ==> -323.61757746366766\n",
528
+ "Episode * 49 * Avg Reward is ==> -301.23857698979566\n",
529
+ "Episode * 50 * Avg Reward is ==> -284.8999331286917\n",
530
+ "Episode * 51 * Avg Reward is ==> -264.84457621322116\n",
531
+ "Episode * 52 * Avg Reward is ==> -248.26764695916563\n",
532
+ "Episode * 53 * Avg Reward is ==> -237.25723863370771\n",
533
+ "Episode * 54 * Avg Reward is ==> -230.53260988021324\n",
534
+ "Episode * 55 * Avg Reward is ==> -236.8247039675385\n",
535
+ "Episode * 56 * Avg Reward is ==> -242.88089725188564\n",
536
+ "Episode * 57 * Avg Reward is ==> -249.99625421933737\n",
537
+ "Episode * 58 * Avg Reward is ==> -256.24104876179\n",
538
+ "Episode * 59 * Avg Reward is ==> -259.44539205532294\n",
539
+ "Episode * 60 * Avg Reward is ==> -259.771282922727\n",
540
+ "Episode * 61 * Avg Reward is ==> -266.78264262398795\n",
541
+ "Episode * 62 * Avg Reward is ==> -264.49970490719676\n",
542
+ "Episode * 63 * Avg Reward is ==> -264.7907401035075\n",
543
+ "Episode * 64 * Avg Reward is ==> -263.65884770574297\n",
544
+ "Episode * 65 * Avg Reward is ==> -263.8187150138804\n",
545
+ "Episode * 66 * Avg Reward is ==> -263.88096070288253\n",
546
+ "Episode * 67 * Avg Reward is ==> -263.68977140982696\n",
547
+ "Episode * 68 * Avg Reward is ==> -272.91279743733804\n",
548
+ "Episode * 69 * Avg Reward is ==> -272.777443352942\n",
549
+ "Episode * 70 * Avg Reward is ==> -283.87400325047287\n",
550
+ "Episode * 71 * Avg Reward is ==> -278.45777238816385\n",
551
+ "Episode * 72 * Avg Reward is ==> -272.09964609736335\n",
552
+ "Episode * 73 * Avg Reward is ==> -269.45733302243724\n",
553
+ "Episode * 74 * Avg Reward is ==> -263.91679852075515\n",
554
+ "Episode * 75 * Avg Reward is ==> -264.0434345954452\n",
555
+ "Episode * 76 * Avg Reward is ==> -260.102765623681\n",
556
+ "Episode * 77 * Avg Reward is ==> -253.51808301808424\n",
557
+ "Episode * 78 * Avg Reward is ==> -250.83738958549662\n",
558
+ "Episode * 79 * Avg Reward is ==> -254.1812329126542\n",
559
+ "Episode * 80 * Avg Reward is ==> -250.125238569467\n",
560
+ "Episode * 81 * Avg Reward is ==> -250.27037014579363\n",
561
+ "Episode * 82 * Avg Reward is ==> -250.1389180516676\n",
562
+ "Episode * 83 * Avg Reward is ==> -245.8142236436616\n",
563
+ "Episode * 84 * Avg Reward is ==> -245.1797777642314\n",
564
+ "Episode * 85 * Avg Reward is ==> -236.02398746977263\n",
565
+ "Episode * 86 * Avg Reward is ==> -239.18889403843315\n",
566
+ "Episode * 87 * Avg Reward is ==> -247.41187644346664\n",
567
+ "Episode * 88 * Avg Reward is ==> -247.82499330593242\n",
568
+ "Episode * 89 * Avg Reward is ==> -250.9072749126738\n",
569
+ "Episode * 90 * Avg Reward is ==> -263.1470922715929\n",
570
+ "Episode * 91 * Avg Reward is ==> -278.58573644976707\n",
571
+ "Episode * 92 * Avg Reward is ==> -280.6476742795351\n",
572
+ "Episode * 93 * Avg Reward is ==> -280.70748492063154\n",
573
+ "Episode * 94 * Avg Reward is ==> -280.565226725522\n",
574
+ "Episode * 95 * Avg Reward is ==> -268.37926234836465\n",
575
+ "Episode * 96 * Avg Reward is ==> -256.03979865280746\n",
576
+ "Episode * 97 * Avg Reward is ==> -255.01505149822543\n",
577
+ "Episode * 98 * Avg Reward is ==> -245.98157845584518\n",
578
+ "Episode * 99 * Avg Reward is ==> -245.54148137920984\n"
579
+ ]
580
+ },
581
+ {
582
+ "output_type": "display_data",
583
+ "data": {
584
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZMAAAEGCAYAAACgt3iRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd3gc1fX/8feR5N57k3uvuAjTuwHTYorpxIQklAA/0kiAkATS+CakQiAkDpgaYgjVYMBgmik2Ri64GwtXucpNtmWr7vn9saOgGEleS9qdlfR5Pc8+2rkzu3uGMXv2lrnX3B0REZHqSAk7ABERqf2UTEREpNqUTEREpNqUTEREpNqUTEREpNrSwg4gLO3bt/devXqFHYaISK0yb9687e7e4eDyeptMevXqRWZmZthhiIjUKma2rrxyNXOJiEi1JV0yMbPfm9kKM1tkZi+aWesy++4wsywzW2lmZ5YpHx+UZZnZ7eFELiJSfyVdMgHeAoa5+wjgc+AOADMbAlwGDAXGA38zs1QzSwUeBM4ChgCXB8eKiEiCJF0ycfc33b042JwDpAfPJwBT3b3A3dcAWcDY4JHl7qvdvRCYGhwrIiIJknTJ5CDfBF4PnncDNpTZlx2UVVT+FWZ2nZllmllmTk5OHMIVEamfQhnNZWYzgc7l7LrT3V8OjrkTKAb+VVOf6+6TgckAGRkZmuFSRKSGhJJM3H1cZfvN7BvAucBp/uW0xhuB7mUOSw/KqKRcREQSIOnuMzGz8cCPgZPcfX+ZXdOAp83sT0BXoD8wFzCgv5n1JppELgOuSGzUIuHJ3V/EOyu3kl8U4eSBHejSqknYIUk9lHTJBHgAaAS8ZWYAc9z9BndfambPAsuINn/d5O4lAGZ2MzADSAWmuPvScEIXSYxIxHl18Waem5fNx1nbKY582Wo7rFtLThvUiVMHdWR4t1akpFiIkdZt+UUlTFu4iQ+zttO1dRP6dmjGwM4tGN6tFcH3V71h9XVxrIyMDNcd8FLbuDvvrczhd2+sYMWWvXRv24Szh3Vh/LDONG+Uxszl25i5fCvz1+/CHdo3b8Spgzpw7oiuHNu3HWmpyT7mJnltyc1nzfY8CopLKCiOsGzTHp6as44deYV0aNGI3fsLKSqJfp/2ateUS4/swcQx6XRo0SjUuAuLIyzbvIeCohKKI05hSYSjerelacOq1SXMbJ67Z3ylXMlEpHbYsa+A7z2zkA9Wbadnu6b88IyBnDu8S7k1j515hbz/+TbeWZHDeyu2sbegmLbNGjJ+WGdOH9KJY/q0o3GD1BDOIvH25hexdvt+Nuzaz8ZdB2jaKJX+HVvQr2Nz8gqKWbIxlyWbctmZV0SKgRkYRsQdB3L2FvDZht1s21vwlfc+dVBHvn18b47p246SiLNh1wHmr9vFM5kbmLtmJwApBmmpKTRKSyG9TVP6tG9Gnw7NuGh0Or3aN4vbee/YV8DTn6zniTnryDko9pk/OIl+HZtX6X2VTA6iZCK1yeLsXK5/MpMdeYXccdYgrjy6Jw1irGXkF5Xw/uc5vPLZJt5evo0DRSU0bZjKMX3aMbBzC/p0aE7Pdk1plJZCaorRtGEavdo1rdXNNJtzDzBz+TZmLNnCnNU7/qcZsDypKUabpg0BJ+LRGmBqigFGqyZpHJHemhHprejfqQWNG6TSKC2F9s0b0blV4wrfM2vbPmYs3cL+wmKKS5z8ohI27DrAmu15rN+5n1Qzrj2xNzed0q/KtYRSRSURnpy9jrdXbGVffjF5hSWs37mfwuIIJw7owCUZ6bRp2pC0FKNBWgpDurSs8o8JJZODKJlIbfHSgo3c9vwi2jVryD++nsHw9FZVfq/8ohJmr97B28u3MvuLHazbsb/cL9r0Nk04Z0QXzhjSmYg7m3Pz2bYnn5JI9Es2NcXIL4qwN7+IfQXFRNxJMSMlSEClv+4dJxKJfkH3bt+MiRnptGzcoMrxl9qbX8Qnq3cyb/0uiksipJhRHHGytu1j6aY9bN8X/SXep30zzhzWmSPSW9O9bRPS2zRlb34Rq7bt44tt+2jSMJVhXVsxsHOLhNbUtu3J57evr+CFBRvp0qox91wwnFMGdazSe83+Ygd3TVvC51v3MbhLSzq2aETzRml0adWYy8Z2p1/HFjUau5LJQZRMpDZ4YvZafv7yUo7q3Za/XTmads1rtv29qCTChp372bDrAEXFEYojzs68Qt5ctoUPV22P6Rd980ZppKVEm4VKItGmIRycaFJJMcMMdu8vonmjNC7J6M74YZ1JsegxnVo0pke7phV+xtY9+cz+YkcQ536ytu1jUXYuxREnLcVomJZCxB3D6NW+GUO7tmRIl5ac0L89/To2T+oaVubandz54hJWbt3L9Sf24dYzB8Zc49y0+wD3vLacVxdtJr1NE35+7hBOH9Ip7uerZHIQJRNJdg9/sJpfT1/O6UM68cAVo2iUltg+jl15hcxevYNmwa/cTi0bk5ZilLhTUuI0aRht7on1y2txdi6PfLiaVxdt/p8kZQYXjOzG908fQPe2TSmJOEs35fL+yhzeWr6VRdm5/z22Y4tG9GjblLG923J8//aM7tGm1vf95BeV8KtXl/GvT9Yzukdrrj+pL51aNqZji0Y0K23+MoIE7RRHnGc+3cAD72QRcef6k/py48l9E/bfQcnkIEomkqzcnQfeyeKPb33OOSO68JdLR8b8a7U22JKbz4ote/5bY/lw1XYe/XgtOIzt3ZZF2bvZkx+dnm9Uj9aMG9yJUwZ2pE+HZrU+cVTm1UWbuOP5xewtKD70wcD4oZ2585zBdG9bca0uHpRMDqJkIslof2Extz+/mGmfbeLCUd24d+KIejGcd9PuA/xl5ufMX7+bMT3acGy/dhzbt33ow2oTbW9+Eet27Gfrnny27ikgv6gEJ/oDw8yio82AwV1aclSfdqHEqGRyECUTSTZrt+dxw1PzWLl1L7eeMZDvnNRXNxxK0qkomSTjHfAi9c6yTXu4/J9zMIPHrhnLSQO+ssS2SFJTMhEJ2drteUyaMpemDVOZet3R9GwXvxvZROKl7jfGiiSxrXvyueqRTyiJRHjyW2OVSKTWUjIRCcmuvEImPTKXXXmFPHbN2Bq/uUwkkdTMJRKCPflFTJoylzU78njsG0dyRPfWYYckUi2qmYgkWF5BMdc8+ikrtuzhH1eN4dh+7cMOSaTaVDMRSaD8ohK+/XgmCzfs5sErRlV5PiaRZKOaiUiCFBSXcN2T85izZgd/uuQIxg/rEnZIIjVGyUQkAYpKItz89AJmfZ7D7y4cwYSR3cIOSaRGKZmIxFlJxPneMwt5a9lWfjlhKJcc2T3skERqnJKJSJxN+XAN0xdt5idnD2LSMb3CDkckLpRMROJoc250AsPTBnXkuhP7hh2OSNwkbTIxsx+amZtZ+2DbzOx+M8sys0VmNrrMsVeb2argcXV4UYv8r1+/upziiHP314aGHYpIXCXl0GAz6w6cAawvU3wW0D94HAU8BBxlZm2Bu4AMogu3zTO
zae6+K7FRi/yvWZ/nMH3xZn4YLPokUpcla83kz8CPiSaHUhOAJzxqDtDazLoAZwJvufvOIIG8BYxPeMQiZRQUl3DXtKX0bt+M607qE3Y4InGXdMnEzCYAG939s4N2dQM2lNnODsoqKi/vva8zs0wzy8zJyanBqEX+12+mL2fN9jx+8bWhCV9uVyQMoTRzmdlMoHM5u+4EfkK0iavGuftkYDJEF8eKx2eIPDVnHU/MXsd1J/bhRK1LIvVEKMnE3ceVV25mw4HewGdmBpAOzDezscBGoOwA/fSgbCNw8kHl79V40CIxmP3FDu6etpRTBnbgtvGDwg5HJGGSqpnL3Re7e0d37+XuvYg2WY129y3ANGBSMKrraCDX3TcDM4AzzKyNmbUhWquZEdY5SP21Yed+vvOvefRq34z7Lh9FqpbclXokKUdzVeA14GwgC9gPXAPg7jvN7FfAp8Fxv3T3neGEKPVVQXEJN/5rPpGI8/CkDFo2bhB2SCIJldTJJKidlD534KYKjpsCTElQWCJf8dvXV7B4Yy6Tvz6GXu21WqLUP0nVzCVSG81YuoVHP1rLNcf14oyh5Y0rEan7lExEqiF7135+9J/PGJHeijvOGhx2OCKhUTIRqSJ357bnFxFx+Ovlo2iYpv+dpP7Sv36RKnph/kY+ytrB7WcNomc79ZNI/aZkIlIFO/YV8OvpyxjTsw1XjO0RdjgioVMyEamC30xfzr6CYv7vwuGk6H4SESUTkcP1waocXliwke+c1JcBnVqEHY5IUlAyETkMhcUR7no5Ohvwjaf0CzsckaShZCJyGB79aA2rt+fx8/OG0LiBZgMWKaVkIhKjbXvyuf/tVZw2qCOnDOwYdjgiSUXJRCRGv31jBUUlzs/OHRJ2KCJJR8lEJAbz1u3ihfkb+dYJvTX3lkg5lExEDqGoJMLPXlpCp5aNuFmd7iLlSupZg0WSweRZq1m2eQ9/v2oMzRrpfxmR8qhmIlKJL3L2cd/bqzhrWGfGD9OMwCIVUTIRqUAk4tzx/GIap6XwiwlDww5HJKkpmYhU4F9z1zN37U5+es4QOrZoHHY4IkmtwgZgM/sr4BXtd/db4hKRSBJYsjGXX7+6jBP6t+fijPSwwxFJepXVTDKBeUBjYDSwKniMBBrGPzSRcOzeX8gNT82jTdOG/PnSkZhpIkeRQ6mwZuLujwOY2XeA4929ONj+O/BBYsITSaySiHPL1IVs21PAM9cfTfvmjcIOSaRWiKXPpA3Qssx286BMpE7JLyrhl68sZdbnOdz1tSGM6qF/5iKxiiWZ/BZYYGaPmdnjwHzgnngGZWb/z8xWmNlSM7u3TPkdZpZlZivN7Mwy5eODsiwzuz2esUnd4+68vHAjp/3xfR6fvY6rj+mpBa9EDlOld2CZWQqwEjgqeADc5u5b4hWQmZ0CTACOcPcCM+sYlA8BLgOGAl2BmWY2IHjZg8DpQDbwqZlNc/dl8YpR6oaikgivLd7Mwx+sYfHGXIZ0acnvLx7BsX3bhx2aSK1TaTJx94iZPejuo4CXExTTd4DfuntBEMO2oHwCMDUoX2NmWcDYYF+Wu68GMLOpwbFKJvJfkYizbW8BG3btZ8PO/XyRs48X5m9kc24+vds3496JI7hodDqpWjVRpEpimRvibTO7CHjB3SscKlyDBgAnmNlvgHzgVnf/FOgGzClzXHZQBrDhoPKjKIeZXQdcB9Cjh5ox6jJ3572VOTz28VpWb9/Hltx8ikr+95/vsX3b8evzh3HKwI5aelekmmJJJtcDPwCKzSwfMMDdvWXlL6uYmc0Eypub4s4gprbA0cCRwLNm1qeqn1WWu08GJgNkZGQkIjFKghWXRJi5fBsPvpvF4o25dG3VmCN7t6Vr6yZ0bdWY9LZN6d6mKeltmmhxK5EadMhk4u41vsi1u4+raF8wFLm0FjTXzCJAe2Aj0L3MoelBGZWUSz2xautenpuXzYsLNrJtbwE92zXl3otGcP6objRM00QPIvEW0xSoZtYG6E/0BkYA3H1WnGJ6CTgFeDfoYG8IbAemAU+b2Z+IdsD3B+YSrSn1N7PeRJPIZcAVcYpNksyGnfv57RsrmL5oM6kpxikDOzJxTDrjBnckLVVJRCRRDplMzOzbwHeJ/uJfSLT5aTZwapximgJMMbMlQCFwdVBLWWpmzxLtWC8GbnL3kiDGm4EZQCowxd2Xxik2SRJbcvOZ8tEaHvtoLSkpcMup/fj6Mb3o0EI3GYqEwQ7Vp25mi4n2Xcxx95FmNgi4x90vTESA8ZKRkeGZmZlhhyGHobA4wpvLtvCfzGw+WJWDAxeNTufWMwbSuZUmYhRJBDOb5+4ZB5fH0syV7+75ZoaZNXL3FWY2MA4xipRrX0ExU+eu55EP17A5N58urRpz48n9mDgmXUvoiiSJWJJJtpm1JtqX8ZaZ7QLWxTcskajn5mXzy1eWsie/mKP7tOWeC4Zz4oAOuh9EJMnEMprrguDp3Wb2LtAKeCOuUUm9F4k4f3hzJX977wuO6t2WO84ezMjurcMOS0QqEEsH/K+AWcDH7v5+/EOS+i6/qIQf/uczpi/azOVju/PLCcNooJFZIkktlmau1cDlwP1mtpfo9POz3D1R06tIPbJ88x6+/8xCVm7dy0/OHsS1J/TReiIitUAszVyPAo+aWWfgEuBWolOS1PjNjFJ/lUSchz9YzR/f/JxWTRvw6DeO5OSBHcMOS0RiFEsz18PAEGAr0VrJRKLT0IvUiPyiEm5+egEzl29l/NDO3HPhcNo202KeIrVJLM1c7YjeDLgb2AlsL111UaS6DhSWcN2TmXywajt3nTeEbxzbS81aIrVQzKO5zGwwcCbRaU5S3T093sFJ3ZZXUMy3H89kzpod3HvRCC45svuhXyQiSSmWZq5zgROAE4HWwDtoDXipps25B7jhyXks2bSHP18ykvNHdTv0i0QkacXSzDWeaPK4z903xTkeqQcy1+7khqfmc6CwmL9fNYbTh3QKOyQRqaZYmrluNrOeRDvhN5lZEyDN3ffGPTqpc575dD0/fWkJ3Vo34d/XHkX/ThoUKFIXxNLMdS3RocBtgb5EZw/+O3BafEOTusTd+fPMVdz/9ipOHNCBv142ilZNG4QdlojUkFiauW4iutb6JwDuvsrMdAOAxKy4JMKdLy7hmcwNXJKRzj0XDNdaIyJ1TCzJpMDdC0uHa5pZGqAlbyUmJRHnxn/N581lW7nl1H58//QBGvorUgfFkkzeN7OfAE3M7HTgRuCV+IYldcVvpi/nzWVb+fm5Q/jm8b3DDkdE4iSWtobbgRxgMXA98Jq73xnXqKROeGrOOqZ8tIZrjuulRCJSxx0ymbh7xN3/6e4Xu/tEYJ2ZvZWA2KQW+3DVdu6atpRTBnbgp+cMCTscEYmzCpOJmZ1qZp+b2T4ze8rMhptZJvB/wEOJC1Fqm09W7+CGp+bRr0Nz7r98lBayEqkHKquZ/JHokOB2wHPAbOAxdx/j7i8kIjipfd5dsY1JU+bSqWUjHvvmkbRorOG/IvVBZR3w7u7vBc9fMrON7v5AAmKSWmr6os
18d+oCBnVpwePXjKVd80ZhhyQiCVJZzaS1mV1Y+gDSDtqOCzMbaWZzzGyhmWWa2dig3MzsfjPLMrNFZja6zGuuNrNVwePqeMUmFZu/fhe3TF3AqB6tefrao5VIROqZymom7wPnldmeVWbbgXg1dd0L/MLdXzezs4Ptk4GzgP7B4yii/TZHmVlb4C4gI4hrnplNc/ddcYpPDrInv4jvTl1A55aNefjqI2mppi2ReqfCZOLu1yQykLIfDbQMnrcCSieXnAA84e4OzDGz1mbWhWiiecvddwIEI83GA/9OaNT1lLvzs5eWsGl3Ps9efzStmiiRiNRHsdy0mGjfA2aY2R+INsMdG5R3AzaUOS47KKuo/CvM7Dqigwro0aNHzUZdT70wfyMvL9zED04fwJiebcMOR0RCEkoyMbOZQOdydt1JdALJ77v782Z2CfAIMK4mPtfdJwOTATIyMjQlTDVlbdvLz19ewtjebbnplH5hhyMiIQolmbh7hcnBzJ4Avhts/gd4OHi+ESi7FF96ULaRaFNX2fL3aihUqcDe/CKue3IejRukct9lI3UviUg9d8g74M3sJjNrXWa7jZndGMeYNgEnBc9PBVYFz6cBk4JRXUcDue6+GZgBnBHE1QY4IyiTOHF3bv3PZ6zbsZ8HrhhNl1ZNwg5JREIWS83kWnd/sHTD3XcFa5z8LU4xXQvcF8xOnE/QxwG8BpwNZAH7gWuCeHaa2a+AT4PjflnaGS/x8dD7XzBj6VZ+es5gjunbLuxwRCQJxJJMUs3MglFUmFkq0DBeAbn7h8CYcsqd6Noq5b1mCjAlXjHJl5Zt2sMfZqzk3BFd+JYmbxSRQCzJ5A3gGTP7R7B9fVAm9dAf31xJ80Zp/OaC4VqXRET+K5ZkchvRBPKdYPstvuwUl3pk/vpdvL1iGz86c6DuJxGR/3HIZOLuEaJ3m2um4HruDzNW0r55Q75xbK+wQxGRJFNhMjGzZ939EjNbTDnL9Lr7iLhGJknl46ztfPzFDn527hCaNUrGe11FJEyVfSuU3utxbiICkeTl7vz+zZV0adWYK4/SzAEi8lWVzc21Ofi7LnHhSDKauXwbC9bv5p4LhtO4QWrY4YhIEqqsmWsv5TRvlXL3lhXtk7qjqCTC/72+nD4dmnFxRnrY4YhIkqqsZtICILghcDPwJGDAlUCXhEQnoZv66QZW5+Txz0kZNEg95IQJIlJPxfLt8DV3/5u773X3Pe7+ENHp4KWO21dQzH0zP2ds77aMG9wx7HBEJInFkkzyzOxKM0s1sxQzuxLIi3dgEr5/vP8F2/cVcufZg3WDoohUKpZkcgVwCbAV2AZcHJRJHbZ1Tz7//GA15x3RlSO6tz70C0SkXovlpsW1qFmr3vnbu1kUlzg/OmNg2KGISC0QyxT06Wb2opltCx7Pm5mG9dRhW3Lz+ffcDUwck06Pdk3DDkdEaoFYmrkeJbqWSNfg8UpQJnXU39//goi7Vk8UkZjFkkw6uPuj7l4cPB4DOsQ5LgnJ1j35PD13PRPHpNO9rWolIhKbWJLJDjO7KhjNlWpmVwE74h2YhOOh974gElGtREQOTyzJ5JtER3NtIXrz4kSCVQ6lbimtlVw0WrUSETk8sYzmWgd8LQGxSMimfLiG4pKIaiUictgqm5vrx+5+r5n9lfKnoL8lrpFJQu0rKObpues5a3gXjeASkcNWWc1kefA3MxGBSLie/XQDe/OLufaEPmGHIiK1UIV9Ju7+SvD38dIH0ckeXwyeV5mZXWxmS80sYmYZB+27w8yyzGylmZ1Zpnx8UJZlZreXKe9tZp8E5c+YWcPqxFYflUScKR+tIaNnG0bqbncRqYJYblp82sxamlkzYAmwzMx+VM3PXQJcCMw66LOGAJcBQ4HxwN9KR5EBDwJnAUOAy4NjAX4H/Nnd+wG7gG9VM7Z6Z8bSLWTvOsC3VSsRkSqKZTTXEHffA5wPvA70Br5enQ919+XuvrKcXROAqe5e4O5rgCxgbPDIcvfV7l4ITAUmWHT2wVOB54LXPx7EKYfhnx+spme7ppw+pFPYoYhILRVLMmlgZg2IfklPc/ciKlk0q5q6ARvKbGcHZRWVtwN2u3vxQeXlMrPrzCzTzDJzcnJqNPDaat66XSxYv5tvHteb1BTNDCwiVRNLMvkHsBZoBswys57AnkO9yMxmmtmSch6hTRrp7pPdPcPdMzp00E38AE/OXkuLRmlMHKPp1kSk6mK5z+R+4P4yRevM7JQYXjeuCvFsBLqX2U4PyqigfAfQ2szSgtpJ2ePlEHblFfLaki1cdmR3mjU65D8FEZEKxdIB387M7jez+WY2z8zuA1rFKZ5pwGVm1sjMegP9gbnAp0D/YORWQ6Kd9NPc3YF3id6VD3A18HKcYqtznp+fTWFxhCuO6hF2KCJSy8XSzDUVyAEuIvqlnQM8U50PNbMLzCwbOAaYbmYzANx9KfAssAx4A7jJ3UuCWsfNwAyi9788GxwLcBvwAzPLItqH8kh1Yqsv3J2n565nVI/WDOrcMuxwRKSWs+iP+0oOMFvi7sMOKlvs7sPjGlmcZWRkeGZm/b0f85PVO7h08hx+P3EEF2d0P/QLREQAM5vn7hkHl8dSM3nTzC4L1n9PMbNLiNYQpBZ7eu56WjRO49wRXcMORUTqgFiSybXA00BB8JgKXG9me83skKO6JPnsyivk9cVbuHBUN5o0TA07HBGpA2IZzdUiEYFI4jw/P5vCkgiXq+NdRGpIhTWTYBGs0ufHHbTv5ngGJfETiThPzllHRs826ngXkRpTWTPXD8o8/+tB+74Zh1gkAWatymHdjv18/ZieYYciInVIZcnEKnhe3rbUEk/OXkf75o04a1iXsEMRkTqksmTiFTwvb1tqgQ079/POym1cPrY7DdNiGXshIhKbyjrgB5nZIqK1kL7Bc4JtzVVeCz31yTpSzHTHu4jUuMqSyeCERSFxl19UwrOfbuD0wZ3o0qpJ2OGISB1TYTJx93WJDETi69VFm9m1v4hJ6ngXkThQw3k98WzmBvq0b8YxfduFHYqI1EFKJvXAhp37mbtmJxeO7kZ0cUoRkZqlZFIPvLwwusTLhJEVLkIpIlItVUomZnZ3DcchceLuvLBgI2N7t6V726ZhhyMidVRVaybzajQKiZtF2bmszsnjglGqlYhI/FQpmbj7KzUdiMTHiws20jAthbOH6453EYmfQ84abGb3l1OcC2S6u5bITWJFJRFe+WwT4wZ3pFWTBmGHIyJ1WCw1k8bASGBV8BgBpAPfMrO/xDE2qaYPVuWwI6+QC0alhx2KiNRxh6yZEE0ex7l7CYCZPQR8ABwPLI5jbFJNz8/fSJumDThpQIewQxGROi6WmkkboHmZ7WZA2yC5FMQlKqm23fsLeWvpViaM7KZJHUUk7mKpmdwLLDSz94hO8ngicI+ZNQNmxjE2qYaXF26isCTCJRndww5FROqBQ/5kdfdHgGOBl4AXgePd/WF3z3P3H1XlQ83sYjNbamYRM8soU366mc0zs8XB31PL7BsTlGeZ2f0W3MptZm3N7C0zWxX8bVOVm
OqaZzM3MKxbS4Z01WqKIhJ/h0wmZvYKcDIw091fdvdNNfC5S4ALgVkHlW8HznP34cDVwJNl9j0EXAv0Dx7jg/LbgbfdvT/wdrBdry3ZmMvSTXtUKxGRhImlMf0PwAnAMjN7zswmmlnj6nyouy9395XllC8ok6yWAk3MrJGZdQFauvscd3fgCeD84LgJwOPB88fLlNdbz83LpmFqCl87omvYoYhIPRFLM9f77n4j0QWx/gFcAmyLd2DARcB8dy8AugHZZfZlB2UAndx9c/B8C9Cpojc0s+vMLNPMMnNycuIRc+jyi0p4ccFGzhjaidZNG4YdjojUE7F0wGNmTYDzgEuB0XxZE6jsNTOBzuXsuvNQNzua2VDgd8AZscRXyt3dzCpcUtjdJwOTATIyMurk0sMzl28l90CRmrhEJKFiuQP+WWAs8AbwAPC+u0cO9Tp3H1eVgMwsnWhH/yR3/yIo3rihxVYAAA7vSURBVEj0RslS6UEZwFYz6+Lum4PmsETUmpLWM59uoGurxhzXr33YoYhIPRJLn8kjQF93v8Hd3wWONbMH4xGMmbUGpgO3u/tHpeVBM9YeMzs6GMU1CSit3Uwj2llP8LfeTvGyOmcfH6zazqVH9iA1ReuWiEjixNJnMgMYYWb3mtla4FfAiup8qJldYGbZwDHAdDObEey6GegH/NzMFgaPjsG+G4GHgSzgC+D1oPy3wOlmtgoYF2zXS0/NWU+DVOPyo9TEJSKJVWEzl5kNAC4PHtuBZwBz91Oq+6Hu/iLRpqyDy38N/LqC12QCw8op3wGcVt2Yarv9hcX8Z94Gxg/rQscW1RpsJyJy2CrrM1lBdA6uc909C8DMvp+QqOSwvbRgE3vzi5l0TM+wQxGReqiyZq4Lgc3Au2b2TzM7jeh0KpJk3J0nZq9lcJeWZPTUBAAikngVJhN3f8ndLwMGAe8C3wM6mtlDZnZYQ3Ylvj5du4sVW/Yy6ZieBLPMiIgkVCwd8Hnu/rS7n0d0SO4C4La4RyYxe+zjNbRonMaEkbrjXUTCcVhzk7v7Lnef7O71vsM7Wby9fCuvLd7CNcf2omnDmO5BFRGpcVroohbbmVfIbc8vZlDnFtx0ar+wwxGRekw/ZWspd+dnLy0h90AhT3xzLI3SUsMOSUTqMdVMaqlpn21i+uLNfG/cAK1ZIiKhUzKphXbsK+CuaUsZ2b0115/YJ+xwRESUTGqj/3t9Bfvyi7l34gjSUnUJRSR8+iaqZeau2clz87L59gl9GNCpRdjhiIgASia1SlFJhJ++tJhurZtwy2kavSUiyUOjuWqRRz5cw+db9/HPSRm6p0REkopqJrVEXkExf317FeMGd+T0IRWuTCwiEgolk1ritcWbySss4Tsn9w07FBGRr1AyqSWem5dN7/bNGN1DswKLSPJRMqkFNuzczydrdnLR6G6aFVhEkpKSSS3w/PxszOCC0elhhyIiUi4lkyQXiTjPz8/muL7t6da6SdjhiIiUS8kkyc1du5MNOw8wcYxqJSKSvJRMktxz87Jp3iiNM4d2DjsUEZEKhZJMzOxiM1tqZhEzyyhnfw8z22dmt5YpG29mK80sy8xuL1Pe28w+CcqfMbOGiTqPeMs9UMTrizdzzvAuNGmoKeZFJHmFVTNZAlwIzKpg/5+A10s3zCwVeBA4CxgCXG5mQ4LdvwP+7O79gF3At+IVdKJN+XANeYUlTDq2Z9ihiIhUKpRk4u7L3X1lefvM7HxgDbC0TPFYIMvdV7t7ITAVmGDRcbKnAs8Fxz0OnB+/yBMn90ARUz5aw5lDOzG0a6uwwxERqVRS9ZmYWXPgNuAXB+3qBmwos50dlLUDdrt78UHlFb3/dWaWaWaZOTk5NRd4HEz5cA1784u55bT+YYciInJIcUsmZjbTzJaU85hQycvuJtpktS8eMbn7ZHfPcPeMDh06xOMjaoRqJSJS28Rt6ll3H1eFlx0FTDSze4HWQMTM8oF5QPcyx6UDG4EdQGszSwtqJ6XltZpqJSJS2yTVPObufkLpczO7G9jn7g+YWRrQ38x6E00WlwFXuLub2bvARKL9KFcDLyc+8pqTu1+1EhGpfcIaGnyBmWUDxwDTzWxGZccHtY6bgRnAcuBZdy/toL8N+IGZZRHtQ3kkfpHH3yMfrmZvfjHfPW1A2KGIiMQslJqJu78IvHiIY+4+aPs14LVyjltNdLRXrbcrr5ApH63l7OGdGdK1ZdjhiIjELKlGc9V3//xgNXmFxXxvnGolIlK7KJkkiR37Cnjs47WcN6IrAzq1CDscEZHDomSSJP4xazX5RSUawSUitZKSSRLI2VvAE7PXcv7IbvTr2DzscEREDpuSSRJ4/OO1FBRHuPnUfmGHIiJSJUom1bBh537mrtlZrffIKyjmidlrOXNIZ/p0UK1ERGonJZNq+OlLS7h08mxemJ9d5feY+ukG9uQXc/1JfWowMhGRxEqqO+Brk4LiEuau2UmD1BRu/c9nNExL4dwRXf/nGHcnZ28Bm3LzKYk44DRrlMagztF7SIpKIjzywWrG9m7LqB5tQjgLEZGaoWRSRQvW7+ZAUQn3XTaSp+as43tTF7Jp9wHyiyJkbdvHFzn7WLs9j7zCkq+8dsLIrvzya8N4Z+VWNuXm8+sLhoVwBiIiNUfJpIo+ytpOisEpgzpy6qCOXPXIXO55bQUA6W2a0LdDc47s1ZY+HZrRtVUT0lINM2Peul387d0s5qzeQaO0VAZ0as7JAzqGfDYiItWjZFJFH2Vt54jurWnZuAEA/7n+GNZsz6N72yY0bVjxf9aTBnTg9MGd+MGzC1m1bR+/nziClBRLVNgiInGhZFIFe/KL+Cw7l++c1Pe/ZQ3TUhjYObY714ent+KV/3c8n67dyXF928crTBGRhFEyqYJPVu+kJOIc16/qiaBxg1RO6J+8C3SJiBwODQ2ugo+yttO4QQqje7YOOxQRkaSgZFIFH2Vt58hebWmUlhp2KCIiSUHJ5DBt3ZPPqm37OL4aTVwiInWNkslh+ihrO0C1+ktEROoaJZPD9FHWDto0bcCQLloJUUSklEZzHaa+HZvRoUUP3RsiIlKGkslhuvFkTRMvInKwUJq5zOxiM1tqZhEzyzho3wgzmx3sX2xmjYPyMcF2lpndb2YWlLc1s7fMbFXwVzMmiogkWFh9JkuAC4FZZQvNLA14CrjB3YcCJwNFwe6HgGuB/sFjfFB+O/C2u/cH3g62RUQkgUJJJu6+3N1XlrPrDGCRu38WHLfD3UvMrAvQ0t3nuLsDTwDnB6+ZADwePH+8TLmIiCRIso3mGgC4mc0ws/lm9uOgvBtQdgWq7KAMoJO7bw6ebwE6JSZUEREpFbcOeDObCXQuZ9ed7v5yJfEcDxwJ7AfeNrN5QG4sn+nubmZeSUzXAdcB9OjRI5a3FBGRGMQtmbj7uCq8LBuY5e7bAczsNWA00X6U9DLHpQMbg+dbzayLu28OmsO2VRLTZGAyQEZGRoVJR0REDk+yNXPNAIabWdOgM/4kYFnQjLXHzI4ORnFNAkprN9OAq4PnV5cpFxGRBAlraPAFZpYN
HANMN7MZAO6+C/gT8CmwEJjv7tODl90IPAxkAV8ArwflvwVON7NVwLhgW0REEsiig6PqHzPLAdZV8eXtge01GE5tUR/Puz6eM9TP89Y5x6anu39lMaZ6m0yqw8wy3T3j0EfWLfXxvOvjOUP9PG+dc/UkW5+JiIjUQkomIiJSbUomVTM57ABCUh/Puz6eM9TP89Y5V4P6TEREpNpUMxERkWpTMhERkWpTMjlMZjbezFYG66rUyenuzay7mb1rZsuCdWW+G5TX+bVjzCzVzBaY2avBdm8z+yS43s+YWcOwY6xpZtbazJ4zsxVmttzMjqnr19rMvh/8215iZv82s8Z18Vqb2RQz22ZmS8qUlXttLer+4PwXmdnow/ksJZPDYGapwIPAWcAQ4HIzGxJuVHFRDPzQ3YcARwM3BedZH9aO+S6wvMz274A/u3s/YBfwrVCiiq/7gDfcfRBwBNHzr7PX2sy6AbcAGe4+DEgFLqNuXuvH+HLtp1IVXduz+HK9qOuIriEVMyWTwzMWyHL31e5eCEwlup5KneLum919fvB8L9Evl27U8bVjzCwdOIfotD0E88CdCjwXHFIXz7kVcCLwCIC7F7r7bur4tSY6yW2TYA7ApsBm6uC1dvdZwM6Diiu6thOAJzxqDtA6mDw3Jkomh6cbsKHMdtl1VeokM+sFjAI+oe6vHfMX4MdAJNhuB+x29+Jguy5e795ADvBo0Lz3sJk1ow5fa3ffCPwBWE80ieQC86j717pURde2Wt9vSiZSITNrDjwPfM/d95TdF6x4WWfGlZvZucA2d58XdiwJlkZ0mYeH3H0UkMdBTVp18Fq3IforvDfQFWjGV5uC6oWavLZKJodnI9C9zHbZdVXqFDNrQDSR/MvdXwiKt5ZWew+1dkwtdBzwNTNbS7T58lSifQmtg6YQqJvXOxvIdvdPgu3niCaXunytxwFr3D3H3YuAF4he/7p+rUtVdG2r9f2mZHJ4PgX6B6M+GhLttJsWckw1LugreARY7u5/KrOrzq4d4+53uHu6u/ciel3fcfcrgXeBicFhdeqcAdx9C7DBzAYGRacBy6jD15po89bRwbpJxpfnXKevdRkVXdtpwKRgVNfRQG6Z5rBD0h3wh8nMzibatp4KTHH334QcUo0zs+OBD4DFfNl/8BOi/SbPAj2ITt9/ibsf3LlX65nZycCt7n6umfUhWlNpCywArnL3gjDjq2lmNpLooIOGwGrgGqI/NOvstTazXwCXEh25uAD4NtH+gTp1rc3s38DJRKea3wrcBbxEOdc2SKwPEG3y2w9c4+6ZMX+WkomIiFSXmrlERKTalExERKTalExERKTalExERKTalExERKTalExEaoiZlZjZwjKPSidHNLMbzGxSDXzuWjNrX933EakODQ0WqSFmts/dm4fwuWuJzoC7PdGfLVJKNROROAtqDvea2WIzm2tm/YLyu83s1uD5LcH6MYvMbGpQ1tbMXgrK5pjZiKC8nZm9GazH8TBgZT7rquAzFprZP4JlE0TiTslEpOY0OaiZ69Iy+3LdfTjRO4z/Us5rbwdGufsI4Iag7BfAgqDsJ8ATQfldwIfuPhR4keidzJjZYKJ3dR/n7iOBEuDKmj1FkfKlHfoQEYnRgeBLvDz/LvP3z+XsXwT8y8xeIjrdBcDxwEUA7v5OUCNpSXT9kQuD8ulmtis4/jRgDPBpdGYMmlC3JmiUJKZkIpIYXsHzUucQTRLnAXea2fAqfIYBj7v7HVV4rUi1qJlLJDEuLfN3dtkdZpYCdHf3d4HbgFZAc6KTbV4ZHHMysD1YV2YWcEVQfhZQuj7728BEM+sY7GtrZj3jeE4i/6WaiUjNaWJmC8tsv+HupcOD25jZIqAAuPyg16UCTwVL6Bpwv7vvNrO7gSnB6/bz5bThvwD+bWZLgY+JTqmOuy8zs58CbwYJqgi4iejMsCJxpaHBInGmobtSH6iZS0REqk01ExERqTbVTEREpNqUTEREpNqUTEREpNqUTEREpNqUTEREpNr+PzuwM5S9NvHoAAAAAElFTkSuQmCC\n",
585
+ "text/plain": [
586
+ "<Figure size 432x288 with 1 Axes>"
587
+ ]
588
+ },
589
+ "metadata": {
590
+ "needs_background": "light"
591
+ }
592
+ }
593
+ ],
594
+ "source": [
595
+ "# To store reward history of each episode\n",
596
+ "ep_reward_list = []\n",
597
+ "# To store average reward history of last few episodes\n",
598
+ "avg_reward_list = []\n",
599
+ "\n",
600
+ "# Takes about 4 min to train\n",
601
+ "for ep in range(total_episodes):\n",
602
+ "\n",
603
+ " prev_state = env.reset()\n",
604
+ " episodic_reward = 0\n",
605
+ "\n",
606
+ " while True:\n",
607
+ " # Uncomment this to see the Actor in action\n",
608
+ " # But not in a python notebook.\n",
609
+ " # env.render()\n",
610
+ "\n",
611
+ " tf_prev_state = tf.expand_dims(tf.convert_to_tensor(prev_state), 0)\n",
612
+ "\n",
613
+ " action = policy(tf_prev_state, ou_noise)\n",
614
+ " # Receive state and reward from the environment.\n",
615
+ " state, reward, done, info = env.step(action)\n",
616
+ "\n",
617
+ " buffer.record((prev_state, action, reward, state))\n",
618
+ " episodic_reward += reward\n",
619
+ "\n",
620
+ " buffer.learn()\n",
621
+ " update_target(target_actor.variables, actor_model.variables, tau)\n",
622
+ " update_target(target_critic.variables, critic_model.variables, tau)\n",
623
+ "\n",
624
+ " # End this episode when `done` is True\n",
625
+ " if done:\n",
626
+ " break\n",
627
+ "\n",
628
+ " prev_state = state\n",
629
+ "\n",
630
+ " ep_reward_list.append(episodic_reward)\n",
631
+ "\n",
632
+ " # Mean of last 40 episodes\n",
633
+ " avg_reward = np.mean(ep_reward_list[-40:])\n",
634
+ " print(\"Episode * {} * Avg Reward is ==> {}\".format(ep, avg_reward))\n",
635
+ " avg_reward_list.append(avg_reward)\n",
636
+ "\n",
637
+ "# Plotting graph\n",
638
+ "# Episodes versus Avg. Rewards\n",
639
+ "plt.plot(avg_reward_list)\n",
640
+ "plt.xlabel(\"Episode\")\n",
641
+ "plt.ylabel(\"Avg. Episodic Reward\")\n",
642
+ "plt.show()"
643
+ ]
644
+ },
645
+ {
646
+ "cell_type": "markdown",
647
+ "metadata": {
648
+ "id": "XY85n6_l_wFb"
649
+ },
650
+ "source": [
651
+ "If training proceeds correctly, the average episodic reward will increase with time.\n",
652
+ "\n",
653
+ "Feel free to try different learning rates, `tau` values, and architectures for the\n",
654
+ "Actor and Critic networks.\n",
655
+ "\n",
656
+ "The Inverted Pendulum problem has low complexity, but DDPG works great on many other\n",
657
+ "problems.\n",
658
+ "\n",
659
+ "Another great environment to try this on is `LunarLanderContinuous-v2`, but it will take\n",
660
+ "more episodes to obtain good results."
661
+ ]
662
+ },
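+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Swapping environments only changes the `gym.make` call. A sketch (note that `LunarLanderContinuous-v2` additionally requires the Box2D extra, and `num_states`, `num_actions`, and the action bounds would be re-derived from the new env):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Left commented so the notebook runs without Box2D installed:\n",
+ "# env = gym.make(\"LunarLanderContinuous-v2\")"
+ ]
+ },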
663
+ {
664
+ "cell_type": "code",
665
+ "execution_count": 9,
666
+ "metadata": {
667
+ "id": "fDayimW0_wFb"
668
+ },
669
+ "outputs": [],
670
+ "source": [
671
+ "# Save the weights\n",
672
+ "actor_model.save_weights(\"pendulum_actor.h5\")\n",
673
+ "critic_model.save_weights(\"pendulum_critic.h5\")\n",
674
+ "\n",
675
+ "target_actor.save_weights(\"pendulum_target_actor.h5\")\n",
676
+ "target_critic.save_weights(\"pendulum_target_critic.h5\")"
677
+ ]
678
+ },
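+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To restore a saved agent later (a sketch; it assumes models rebuilt with the same `get_actor()` / `get_critic()` architectures):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "actor_model.load_weights(\"pendulum_actor.h5\")\n",
+ "critic_model.load_weights(\"pendulum_critic.h5\")"
+ ]
+ },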
679
+ {
680
+ "cell_type": "markdown",
681
+ "metadata": {
682
+ "id": "hYiCdLyE_wFb"
683
+ },
684
+ "source": [
685
+ "Before Training:\n",
686
+ "\n",
687
+ "![before_img](https://i.imgur.com/ox6b9rC.gif)"
688
+ ]
689
+ },
690
+ {
691
+ "cell_type": "markdown",
692
+ "metadata": {
693
+ "id": "D1lklgTJ_wFc"
694
+ },
695
+ "source": [
696
+ "After 100 episodes:\n",
697
+ "\n",
698
+ "![after_img](https://i.imgur.com/eEH8Cz6.gif)"
699
+ ]
700
+ },
701
+ {
702
+ "cell_type": "code",
703
+ "source": [
704
+ "!pip install huggingface-hub\n",
705
+ "!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash\n",
706
+ "!sudo apt-get install git-lfs\n",
707
+ "!git-lfs install"
708
+ ],
709
+ "metadata": {
710
+ "id": "c6Ao1vi4_zwE",
711
+ "outputId": "a2aa4ade-a162-432f-92d3-1d1358c2ead6",
712
+ "colab": {
713
+ "base_uri": "https://localhost:8080/"
714
+ }
715
+ },
716
+ "execution_count": 10,
717
+ "outputs": [
718
+ {
719
+ "output_type": "stream",
720
+ "name": "stdout",
721
+ "text": [
722
+ "Collecting huggingface-hub\n",
723
+ " Downloading huggingface_hub-0.2.1-py3-none-any.whl (61 kB)\n",
724
+ "\u001b[?25l\r\u001b[K |█████▎ | 10 kB 21.9 MB/s eta 0:00:01\r\u001b[K |██████████▋ | 20 kB 14.7 MB/s eta 0:00:01\r\u001b[K |███████████████▉ | 30 kB 11.1 MB/s eta 0:00:01\r\u001b[K |█████████████████████▏ | 40 kB 9.7 MB/s eta 0:00:01\r\u001b[K |██████████████████████████▌ | 51 kB 5.2 MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▊| 61 kB 5.8 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 61 kB 446 kB/s \n",
725
+ "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from huggingface-hub) (3.4.0)\n",
726
+ "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub) (21.3)\n",
727
+ "Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from huggingface-hub) (3.13)\n",
728
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub) (3.10.0.2)\n",
729
+ "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from huggingface-hub) (4.8.2)\n",
730
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from huggingface-hub) (4.62.3)\n",
731
+ "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from huggingface-hub) (2.23.0)\n",
732
+ "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.9->huggingface-hub) (3.0.6)\n",
733
+ "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->huggingface-hub) (3.6.0)\n",
734
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->huggingface-hub) (2021.10.8)\n",
735
+ "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->huggingface-hub) (2.10)\n",
736
+ "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->huggingface-hub) (3.0.4)\n",
737
+ "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->huggingface-hub) (1.24.3)\n",
738
+ "Installing collected packages: huggingface-hub\n",
739
+ "Successfully installed huggingface-hub-0.2.1\n",
740
+ "Detected operating system as Ubuntu/bionic.\n",
741
+ "Checking for curl...\n",
742
+ "Detected curl...\n",
743
+ "Checking for gpg...\n",
744
+ "Detected gpg...\n",
745
+ "Running apt-get update... done.\n",
746
+ "Installing apt-transport-https... done.\n",
747
+ "Installing /etc/apt/sources.list.d/github_git-lfs.list...done.\n",
748
+ "Importing packagecloud gpg key... done.\n",
749
+ "Running apt-get update... done.\n",
750
+ "\n",
751
+ "The repository is setup! You can now install packages.\n",
752
+ "Reading package lists... Done\n",
753
+ "Building dependency tree \n",
754
+ "Reading state information... Done\n",
755
+ "The following NEW packages will be installed:\n",
756
+ " git-lfs\n",
757
+ "0 upgraded, 1 newly installed, 0 to remove and 67 not upgraded.\n",
758
+ "Need to get 6,526 kB of archives.\n",
759
+ "After this operation, 14.7 MB of additional disk space will be used.\n",
760
+ "Get:1 https://packagecloud.io/github/git-lfs/ubuntu bionic/main amd64 git-lfs amd64 3.0.2 [6,526 kB]\n",
761
+ "Fetched 6,526 kB in 1s (6,123 kB/s)\n",
762
+ "debconf: unable to initialize frontend: Dialog\n",
763
+ "debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 76, <> line 1.)\n",
764
+ "debconf: falling back to frontend: Readline\n",
765
+ "debconf: unable to initialize frontend: Readline\n",
766
+ "debconf: (This frontend requires a controlling tty.)\n",
767
+ "debconf: falling back to frontend: Teletype\n",
768
+ "dpkg-preconfigure: unable to re-open stdin: \n",
769
+ "Selecting previously unselected package git-lfs.\n",
770
+ "(Reading database ... 155226 files and directories currently installed.)\n",
771
+ "Preparing to unpack .../git-lfs_3.0.2_amd64.deb ...\n",
772
+ "Unpacking git-lfs (3.0.2) ...\n",
773
+ "Setting up git-lfs (3.0.2) ...\n",
774
+ "Git LFS initialized.\n",
775
+ "Processing triggers for man-db (2.8.3-2ubuntu0.1) ...\n",
776
+ "Git LFS initialized.\n"
777
+ ]
778
+ }
779
+ ]
780
+ },
781
+ {
782
+ "cell_type": "code",
783
+ "source": [
784
+ "!huggingface-cli login"
785
+ ],
786
+ "metadata": {
787
+ "id": "mBqbC9OLBIzY",
788
+ "outputId": "e213d779-fd78-49d6-affd-cfb0869e5624",
789
+ "colab": {
790
+ "base_uri": "https://localhost:8080/"
791
+ }
792
+ },
793
+ "execution_count": 11,
794
+ "outputs": [
795
+ {
796
+ "output_type": "stream",
797
+ "name": "stdout",
798
+ "text": [
799
+ "\n",
800
+ " _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|\n",
801
+ " _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|\n",
802
+ " _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|\n",
803
+ " _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|\n",
804
+ " _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|\n",
805
+ "\n",
806
+ " To login, `huggingface_hub` now requires a token generated from https://huggingface.co/settings/token.\n",
807
+ " (Deprecated, will be removed in v0.3.0) To login with username and password instead, interrupt with Ctrl+C.\n",
808
+ " \n",
809
+ "Token: \n",
810
+ "Login successful\n",
811
+ "Your token has been saved to /root/.huggingface/token\n",
812
+ "\u001b[1m\u001b[31mAuthenticated through git-credential store but this isn't the helper defined on your machine.\n",
813
+ "You might have to re-authenticate when pushing to the Hugging Face Hub. Run the following command in your terminal in case you want to set this credential helper as the default\n",
814
+ "\n",
815
+ "git config --global credential.helper store\u001b[0m\n"
816
+ ]
817
+ }
818
+ ]
819
+ },
820
+ {
821
+ "cell_type": "code",
822
+ "source": [
823
+ "\n",
824
+ "from huggingface_hub.keras_mixin import push_to_hub_keras\n",
825
+ "push_to_hub_keras(model = actor_model, repo_url = \"https://huggingface.co/keras-io/deep-deterministic-policy-gradient\", organization = \"keras-io\")"
826
+ ],
827
+ "metadata": {
828
+ "id": "B6pop1vc_4yZ",
829
+ "outputId": "f1635dd0-ac6c-4375-8054-3c873f259ef5",
830
+ "colab": {
831
+ "base_uri": "https://localhost:8080/",
832
+ "height": 141
833
+ }
834
+ },
835
+ "execution_count": 12,
836
+ "outputs": [
837
+ {
838
+ "output_type": "stream",
839
+ "name": "stderr",
840
+ "text": [
841
+ "Cloning https://huggingface.co/keras-io/deep-deterministic-policy-gradient into local empty directory.\n"
842
+ ]
843
+ },
844
+ {
845
+ "output_type": "stream",
846
+ "name": "stdout",
847
+ "text": [
848
+ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
849
+ "INFO:tensorflow:Assets written to: deep-deterministic-policy-gradient/assets\n"
850
+ ]
851
+ },
852
+ {
853
+ "output_type": "stream",
854
+ "name": "stderr",
855
+ "text": [
856
+ "To https://huggingface.co/keras-io/deep-deterministic-policy-gradient\n",
857
+ " 0e015ab..e37f692 main -> main\n",
858
+ "\n"
859
+ ]
860
+ },
861
+ {
862
+ "output_type": "execute_result",
863
+ "data": {
864
+ "application/vnd.google.colaboratory.intrinsic+json": {
865
+ "type": "string"
866
+ },
867
+ "text/plain": [
868
+ "'https://huggingface.co/keras-io/deep-deterministic-policy-gradient/commit/e37f69227324cae395ac0075b8bee416685d2c54'"
869
+ ]
870
+ },
871
+ "metadata": {},
872
+ "execution_count": 12
873
+ }
874
+ ]
875
+ },
876
+ {
877
+ "cell_type": "code",
878
+ "source": [
879
+ "push_to_hub_keras(model = critic_model, repo_url = \"https://huggingface.co/keras-io/deep-deterministic-policy-gradient\", organization = \"keras-io\")"
880
+ ],
881
+ "metadata": {
882
+ "id": "89Cj-m50BQKv",
883
+ "outputId": "56899efb-3dda-4ca6-8656-f9060741b9b3",
884
+ "colab": {
885
+ "base_uri": "https://localhost:8080/",
886
+ "height": 161
887
+ }
888
+ },
889
+ "execution_count": 13,
890
+ "outputs": [
891
+ {
892
+ "output_type": "stream",
893
+ "name": "stderr",
894
+ "text": [
895
+ "/content/deep-deterministic-policy-gradient is already a clone of https://huggingface.co/keras-io/deep-deterministic-policy-gradient. Make sure you pull the latest changes with `repo.git_pull()`.\n"
896
+ ]
897
+ },
898
+ {
899
+ "output_type": "stream",
900
+ "name": "stdout",
901
+ "text": [
902
+ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
903
+ "INFO:tensorflow:Assets written to: deep-deterministic-policy-gradient/assets\n"
904
+ ]
905
+ },
906
+ {
907
+ "output_type": "stream",
908
+ "name": "stderr",
909
+ "text": [
910
+ "To https://huggingface.co/keras-io/deep-deterministic-policy-gradient\n",
911
+ " e37f692..fc4c3b0 main -> main\n",
912
+ "\n"
913
+ ]
914
+ },
915
+ {
916
+ "output_type": "execute_result",
917
+ "data": {
918
+ "application/vnd.google.colaboratory.intrinsic+json": {
919
+ "type": "string"
920
+ },
921
+ "text/plain": [
922
+ "'https://huggingface.co/keras-io/deep-deterministic-policy-gradient/commit/fc4c3b0eadf2d9d2e6ff7a59f4e1f99763d973fe'"
923
+ ]
924
+ },
925
+ "metadata": {},
926
+ "execution_count": 13
927
+ }
928
+ ]
929
+ },
930
+ {
931
+ "cell_type": "code",
932
+ "source": [
933
+ "push_to_hub_keras(model = target_actor, repo_url = \"https://huggingface.co/keras-io/deep-deterministic-policy-gradient\", organization = \"keras-io\")"
934
+ ],
935
+ "metadata": {
936
+ "id": "wv-epAixBYAJ",
937
+ "outputId": "9c62ad0a-1523-4ba3-ced9-b6d8d3e1cbdc",
938
+ "colab": {
939
+ "base_uri": "https://localhost:8080/",
940
+ "height": 161
941
+ }
942
+ },
943
+ "execution_count": 14,
944
+ "outputs": [
945
+ {
946
+ "output_type": "stream",
947
+ "name": "stderr",
948
+ "text": [
949
+ "/content/deep-deterministic-policy-gradient is already a clone of https://huggingface.co/keras-io/deep-deterministic-policy-gradient. Make sure you pull the latest changes with `repo.git_pull()`.\n"
950
+ ]
951
+ },
952
+ {
953
+ "output_type": "stream",
954
+ "name": "stdout",
955
+ "text": [
956
+ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
957
+ "INFO:tensorflow:Assets written to: deep-deterministic-policy-gradient/assets\n"
958
+ ]
959
+ },
960
+ {
961
+ "output_type": "stream",
962
+ "name": "stderr",
963
+ "text": [
964
+ "To https://huggingface.co/keras-io/deep-deterministic-policy-gradient\n",
965
+ " fc4c3b0..e34067a main -> main\n",
966
+ "\n"
967
+ ]
968
+ },
969
+ {
970
+ "output_type": "execute_result",
971
+ "data": {
972
+ "application/vnd.google.colaboratory.intrinsic+json": {
973
+ "type": "string"
974
+ },
975
+ "text/plain": [
976
+ "'https://huggingface.co/keras-io/deep-deterministic-policy-gradient/commit/e34067a57c76c29bf60d924f352e7d72708bec82'"
977
+ ]
978
+ },
979
+ "metadata": {},
980
+ "execution_count": 14
981
+ }
982
+ ]
983
+ },
984
+ {
985
+ "cell_type": "code",
986
+ "source": [
987
+ "push_to_hub_keras(model = target_critic, repo_url = \"https://huggingface.co/keras-io/deep-deterministic-policy-gradient\", organization = \"keras-io\")"
988
+ ],
989
+ "metadata": {
990
+ "id": "3LVvvq2hBcfv",
991
+ "outputId": "c9b37d03-3f98-46d3-ee91-37e271bb5fbe",
992
+ "colab": {
993
+ "base_uri": "https://localhost:8080/",
994
+ "height": 161
995
+ }
996
+ },
997
+ "execution_count": 15,
998
+ "outputs": [
999
+ {
1000
+ "output_type": "stream",
1001
+ "name": "stderr",
1002
+ "text": [
1003
+ "/content/deep-deterministic-policy-gradient is already a clone of https://huggingface.co/keras-io/deep-deterministic-policy-gradient. Make sure you pull the latest changes with `repo.git_pull()`.\n"
1004
+ ]
1005
+ },
1006
+ {
1007
+ "output_type": "stream",
1008
+ "name": "stdout",
1009
+ "text": [
1010
+ "WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n",
1011
+ "INFO:tensorflow:Assets written to: deep-deterministic-policy-gradient/assets\n"
1012
+ ]
1013
+ },
1014
+ {
1015
+ "output_type": "stream",
1016
+ "name": "stderr",
1017
+ "text": [
1018
+ "To https://huggingface.co/keras-io/deep-deterministic-policy-gradient\n",
1019
+ " e34067a..10b396f main -> main\n",
1020
+ "\n"
1021
+ ]
1022
+ },
1023
+ {
1024
+ "output_type": "execute_result",
1025
+ "data": {
1026
+ "application/vnd.google.colaboratory.intrinsic+json": {
1027
+ "type": "string"
1028
+ },
1029
+ "text/plain": [
1030
+ "'https://huggingface.co/keras-io/deep-deterministic-policy-gradient/commit/10b396f3c297b2359d5b5e96f2b78a03943ec833'"
1031
+ ]
1032
+ },
1033
+ "metadata": {},
1034
+ "execution_count": 15
1035
+ }
1036
+ ]
1037
+ },
1038
+ {
1039
+ "cell_type": "code",
1040
+ "source": [
1041
+ ""
1042
+ ],
1043
+ "metadata": {
1044
+ "id": "yzwDvkqZBfFJ"
1045
+ },
1046
+ "execution_count": null,
1047
+ "outputs": []
1048
+ }
1049
+ ],
1050
+ "metadata": {
1051
+ "colab": {
1052
+ "collapsed_sections": [],
1053
+ "name": "ddpg_pendulum",
1054
+ "provenance": [],
1055
+ "toc_visible": true
1056
+ },
1057
+ "kernelspec": {
1058
+ "display_name": "Python 3",
1059
+ "language": "python",
1060
+ "name": "python3"
1061
+ },
1062
+ "language_info": {
1063
+ "codemirror_mode": {
1064
+ "name": "ipython",
1065
+ "version": 3
1066
+ },
1067
+ "file_extension": ".py",
1068
+ "mimetype": "text/x-python",
1069
+ "name": "python",
1070
+ "nbconvert_exporter": "python",
1071
+ "pygments_lexer": "ipython3",
1072
+ "version": "3.7.0"
1073
+ }
1074
+ },
1075
+ "nbformat": 4,
1076
+ "nbformat_minor": 0
1077
+ }