jsebdev commited on
Commit
b46bd9c
1 Parent(s): 59c1eb1

Created using Colaboratory

Browse files
Files changed (1) hide show
  1. stock_predictor.ipynb +788 -123
stock_predictor.ipynb CHANGED
@@ -4,7 +4,10 @@
4
  "metadata": {
5
  "colab": {
6
  "provenance": [],
7
- "authorship_tag": "ABX9TyPcuRkmq64yTPWXIBG7lLf0",
 
 
 
8
  "include_colab_link": true
9
  },
10
  "kernelspec": {
@@ -40,9 +43,9 @@
40
  "base_uri": "https://localhost:8080/"
41
  },
42
  "id": "Xr3Qozgfktoc",
43
- "outputId": "28119a16-7e41-437a-969b-3713f019548e"
44
  },
45
- "execution_count": 1,
46
  "outputs": [
47
  {
48
  "output_type": "stream",
@@ -54,157 +57,409 @@
54
  }
55
  ]
56
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  {
58
  "cell_type": "code",
59
  "source": [
60
- "# install dotenv\n",
61
- "!pip install python-dotenv"
62
  ],
63
  "metadata": {
64
  "colab": {
65
  "base_uri": "https://localhost:8080/"
66
  },
67
- "id": "E0itUkoVeKYn",
68
- "outputId": "bc2a7293-a9f0-4f4d-d42f-f7ecfab7e5c5"
69
  },
70
- "execution_count": 2,
71
  "outputs": [
72
  {
73
  "output_type": "stream",
74
  "name": "stdout",
75
  "text": [
76
- "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
77
- "Collecting python-dotenv\n",
78
- " Downloading python_dotenv-1.0.0-py3-none-any.whl (19 kB)\n",
79
- "Installing collected packages: python-dotenv\n",
80
- "Successfully installed python-dotenv-1.0.0\n"
81
  ]
82
  }
83
  ]
84
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  {
86
  "cell_type": "code",
87
  "source": [
88
- "# install polygon client\n",
89
- "!pip install polygon-api-client"
 
 
 
 
 
 
 
 
 
 
 
90
  ],
91
  "metadata": {
92
  "colab": {
93
  "base_uri": "https://localhost:8080/"
94
  },
95
- "id": "2bylenpXc1oB",
96
- "outputId": "74b2587b-2b58-42a1-f5bf-c3866c13b8a1"
97
  },
98
- "execution_count": 3,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  "outputs": [
100
  {
101
  "output_type": "stream",
102
  "name": "stdout",
103
  "text": [
104
- "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
105
- "Collecting polygon-api-client\n",
106
- " Downloading polygon_api_client-1.8.5-py3-none-any.whl (38 kB)\n",
107
- "Collecting websockets<11.0,>=10.3\n",
108
- " Downloading websockets-10.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (106 kB)\n",
109
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m106.5/106.5 KB\u001b[0m \u001b[31m11.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
110
- "\u001b[?25hRequirement already satisfied: urllib3<2.0.0,>=1.26.9 in /usr/local/lib/python3.9/dist-packages (from polygon-api-client) (1.26.15)\n",
111
- "Requirement already satisfied: certifi<2023.0.0,>=2022.5.18 in /usr/local/lib/python3.9/dist-packages (from polygon-api-client) (2022.12.7)\n",
112
- "Installing collected packages: websockets, polygon-api-client\n",
113
- "Successfully installed polygon-api-client-1.8.5 websockets-10.4\n"
114
  ]
115
  }
116
  ]
117
  },
118
  {
119
- "cell_type": "code",
120
- "execution_count": 4,
 
 
121
  "metadata": {
122
- "id": "e8SQqogMQYLh"
123
- },
124
- "outputs": [],
 
 
125
  "source": [
126
- "import numpy as np\n",
127
- "import matplotlib.pyplot as plt\n",
128
- "import pandas as pd\n",
129
- "import pandas_datareader as web\n",
130
- "import datetime as dt\n",
131
- "import yfinance as yfin\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  "\n",
133
- "from sklearn.preprocessing import MinMaxScaler\n",
134
- "from tensorflow.keras.models import Sequential\n",
135
- "from tensorflow.keras.layers import Dense, Dropout, LSTM\n",
136
- "from dotenv import dotenv_values\n",
137
- "from polygon import RESTClient\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  ]
139
  },
140
  {
141
  "cell_type": "code",
142
  "source": [
143
- "# geting poligon api key\n",
144
- "config = dotenv_values(\"env_stock_predictor\")\n",
145
- "POLYGON_API_KEY = config['POLYGON_API_KEY']"
146
  ],
147
  "metadata": {
148
- "id": "MwIQIS6GeSJr"
149
  },
150
- "execution_count": 18,
151
  "outputs": []
152
  },
 
 
 
 
 
 
 
 
 
153
  {
154
  "cell_type": "code",
155
  "source": [
156
- "# Select a company for now\n",
157
- "ticker = 'AAPL'\n",
158
- "\n",
159
- "data_sources = {'pandas': 'pandas-datareader',\n",
160
- " 'polygon':'polygon'}\n",
161
- "source = data_sources['polygon']\n",
162
- "# source = data_sources['pandas']\n",
163
  "\n",
164
- "start = dt.datetime(2013,1,1)\n",
165
- "end = dt.date.today()"
 
166
  ],
167
  "metadata": {
168
- "id": "O6dtJpJwS5Eg"
169
  },
170
- "execution_count": 19,
171
  "outputs": []
172
  },
 
 
 
 
 
 
 
 
 
173
  {
174
  "cell_type": "code",
175
  "source": [
176
- "if source == data_sources['pandas']:\n",
177
- " yfin.pdr_override()\n",
178
- " data = web.data.get_data_yahoo(ticker, start, end)\n",
179
- "elif source == data_sources['polygon']:\n",
180
- " # using the poligon API\n",
181
- " poligon_client = RESTClient(api_key=POLYGON_API_KEY)\n",
182
- " # bars = poligon_client.get_aggs(ticker=ticker, multiplier=1, timespan=\"day\", from_=\"2023-01-09\", to=\"2023-01-15\")\n",
183
- " # bars = poligon_client.get_aggs(ticker=ticker, multiplier=1, timespan=\"day\", from_=start, to=end)\n",
184
- " bars = poligon_client.get_aggs(ticker=ticker, multiplier=1, timespan=\"hour\", from_=dt.datetime.now() - dt.timedelta(days=5), to=dt.datetime.now())\n",
185
- " print(len(bars))\n",
186
- " for bar in bars[-2:]:\n",
187
- " print(type(bar))\n",
188
- " print(bar)\n",
189
- " print(bar.timestamp)\n",
190
- " print(dt.date.fromtimestamp(bar.timestamp/1000))\n",
191
- " print(dt.datetime.fromtimestamp(bar.timestamp/1000))"
192
- ],
193
- "metadata": {
194
- "id": "LwPyk8Uh-Zz_"
195
  },
196
- "execution_count": 36,
197
- "outputs": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  },
199
  {
200
  "cell_type": "code",
201
- "source": [],
 
 
202
  "metadata": {
203
  "colab": {
204
  "base_uri": "https://localhost:8080/"
205
  },
206
- "id": "IX_o3NTggblq",
207
- "outputId": "27d4d43b-e063-4651-db16-f5ecf819860b"
208
  },
209
  "execution_count": 37,
210
  "outputs": [
@@ -212,46 +467,99 @@
212
  "output_type": "stream",
213
  "name": "stdout",
214
  "text": [
215
- "41\n",
216
- "<class 'polygon.rest.models.aggs.Agg'>\n",
217
- "Agg(open=165.57, high=165.68, low=165.53, close=165.64, volume=11712, vwap=165.6067, timestamp=1680645600000, transactions=258, otc=None)\n",
218
- "1680645600000\n",
219
- "2023-04-04\n",
220
- "2023-04-04 22:00:00\n",
221
- "<class 'polygon.rest.models.aggs.Agg'>\n",
222
- "Agg(open=165.6, high=165.85, low=165.6, close=165.79, volume=28951, vwap=165.7385, timestamp=1680649200000, transactions=533, otc=None)\n",
223
- "1680649200000\n",
224
- "2023-04-04\n",
225
- "2023-04-04 23:00:00\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  ]
 
 
 
 
 
 
 
 
 
 
227
  }
228
  ]
229
  },
 
 
 
 
 
 
 
 
 
230
  {
231
  "cell_type": "code",
232
  "source": [
233
- "print(type(spy))\n",
234
- "print(spy.head())"
235
  ],
236
  "metadata": {
237
  "colab": {
238
- "base_uri": "https://localhost:8080/",
239
- "height": 187
240
  },
241
- "id": "EMoXLT5vd8Ex",
242
- "outputId": "74416af4-da65-4d12-ed3a-27806b8f0965"
243
  },
244
- "execution_count": 10,
245
  "outputs": [
246
  {
247
- "output_type": "error",
248
- "ename": "NameError",
249
- "evalue": "ignored",
250
- "traceback": [
251
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
252
- "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
253
- "\u001b[0;32m<ipython-input-10-dab045b648a5>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mspy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mspy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhead\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
254
- "\u001b[0;31mNameError\u001b[0m: name 'spy' is not defined"
255
  ]
256
  }
257
  ]
@@ -259,33 +567,390 @@
259
  {
260
  "cell_type": "code",
261
  "source": [
262
- "df = web.DataReader('GE', 'yahoo', start='2019-09-10', end='2019-10-09')\n",
263
- "print(start)\n",
264
- "print(end)"
265
  ],
266
  "metadata": {
267
- "id": "THGxnQbSUgvw"
268
  },
269
- "execution_count": null,
270
  "outputs": []
271
  },
272
  {
273
  "cell_type": "code",
274
  "source": [
275
- "scaler = MinMaxScaler(feature_range=(0,1))\n",
276
- "scaled_data = scaler.fit_transform(data['Close'].values.reshape(-1,1))\n",
277
- "prediction_days = 60\n",
278
  "\n",
279
- "x_train = []\n",
280
- "y_train = []\n",
 
 
 
 
 
 
281
  "\n",
282
- "for x in range()"
 
 
 
 
283
  ],
284
  "metadata": {
285
- "id": "ccV59ukvXaNF"
 
 
 
 
286
  },
287
- "execution_count": null,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  "outputs": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  }
290
  ]
291
  }
 
4
  "metadata": {
5
  "colab": {
6
  "provenance": [],
7
+ "collapsed_sections": [
8
+ "Z3N2WMYNV-qX"
9
+ ],
10
+ "authorship_tag": "ABX9TyOuk8MIfThoeWnRbBQlPf+h",
11
  "include_colab_link": true
12
  },
13
  "kernelspec": {
 
43
  "base_uri": "https://localhost:8080/"
44
  },
45
  "id": "Xr3Qozgfktoc",
46
+ "outputId": "e80033fb-a41f-438f-fc90-60bc0317d5d3"
47
  },
48
+ "execution_count": 2,
49
  "outputs": [
50
  {
51
  "output_type": "stream",
 
57
  }
58
  ]
59
  },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 3,
63
+ "metadata": {
64
+ "id": "e8SQqogMQYLh"
65
+ },
66
+ "outputs": [],
67
+ "source": [
68
+ "import numpy as np\n",
69
+ "import matplotlib.pyplot as plt\n",
70
+ "import pandas as pd\n",
71
+ "import pandas_datareader as web\n",
72
+ "import datetime as dt\n",
73
+ "import yfinance as yfin\n",
74
+ "import tensorflow as tf\n",
75
+ "import os\n",
76
+ "import re\n",
77
+ "\n",
78
+ "from sklearn.preprocessing import MinMaxScaler\n",
79
+ "from tensorflow.keras.models import Sequential\n",
80
+ "from tensorflow.keras.layers import Dense, Dropout, LSTM\n"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "markdown",
85
+ "source": [
86
+ "# Get Data"
87
+ ],
88
+ "metadata": {
89
+ "id": "5vO8pty3VwkG"
90
+ }
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "source": [
95
+ "# Select a company for now\n",
96
+ "ticker = 'AAPL'\n",
97
+ "\n",
98
+ "start = dt.datetime(2013,1,1)\n",
99
+ "end = dt.datetime(2023,4,5)"
100
+ ],
101
+ "metadata": {
102
+ "id": "O6dtJpJwS5Eg"
103
+ },
104
+ "execution_count": 93,
105
+ "outputs": []
106
+ },
107
  {
108
  "cell_type": "code",
109
  "source": [
110
+ "yfin.pdr_override()\n",
111
+ "data = web.data.get_data_yahoo(ticker, start, end)\n"
112
  ],
113
  "metadata": {
114
  "colab": {
115
  "base_uri": "https://localhost:8080/"
116
  },
117
+ "id": "LwPyk8Uh-Zz_",
118
+ "outputId": "63953807-ca2e-4e18-c571-a6bcc4f8db5d"
119
  },
120
+ "execution_count": 5,
121
  "outputs": [
122
  {
123
  "output_type": "stream",
124
  "name": "stdout",
125
  "text": [
126
+ "\r[*********************100%***********************] 1 of 1 completed\n"
 
 
 
 
127
  ]
128
  }
129
  ]
130
  },
131
+ {
132
+ "cell_type": "markdown",
133
+ "source": [
134
+ "# Preprocess_data"
135
+ ],
136
+ "metadata": {
137
+ "id": "SSuS9OONV5-a"
138
+ }
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "source": [
143
+ "def normalize_data(data, relative_to_previous=True, scaler=None):\n",
144
+ " def substract_to_values(data, value):\n",
145
+ " df_copy = pd.DataFrame.copy(data)\n",
146
+ " df_copy[['Open', 'High', 'Low', 'Close', 'Adj Close']] = df_copy[['Open', 'High', 'Low', 'Close', 'Adj Close']] - value\n",
147
+ " return df_copy\n",
148
+ " if relative_to_previous:\n",
149
+ " the_data = pd.DataFrame(substract_to_values(data.iloc[0], data.iloc[0]['Open'])).T\n",
150
+ " # the_data = substract_to_values(data.iloc[0], data.iloc[0]['Open']).to_frame().T # This is the same as the previous line\n",
151
+ " for i in range(1,len(data)):\n",
152
+ " the_data = pd.concat((the_data, substract_to_values(data.iloc[i], data.iloc[i-1]['Close']).to_frame().T))\n",
153
+ " else:\n",
154
+ " the_data = pd.DataFrame.copy(data)\n",
155
+ " \n",
156
+ " if scaler is None:\n",
157
+ " # Create the scaler\n",
158
+ " values = the_data.values\n",
159
+ " # print('values')\n",
160
+ " # print(values)\n",
161
+ " max_value = np.max(values[:,:-1])\n",
162
+ " # print(max_value)\n",
163
+ " min_value = np.min(values[:,:-1])\n",
164
+ " # print(min_value)\n",
165
+ " max_volume = np.max(values[:,-1])\n",
166
+ " min_volume = np.min(values[:,-1])\n",
167
+ " # print(max_volume, min_volume)\n",
168
+ " def scaler(data):\n",
169
+ " values = data.values\n",
170
+ " # print(values)\n",
171
+ " values[:,:-1] = (values[:,:-1] - min_value) / (max_value-min_value) * 2 - 1\n",
172
+ " values[:,-1] = (values[:,-1] - min_volume) / (max_volume-min_volume) * 2 - 1\n",
173
+ " # print(values)\n",
174
+ " return data\n",
175
+ " def anti_scaler(values):\n",
176
+ " decoded_values = (values + 1) * (max_value-min_value) / 2 + min_value \n",
177
+ " return decoded_values\n",
178
+ " \n",
179
+ " normalized_data = scaler(the_data)\n",
180
+ "\n",
181
+ " return normalized_data, scaler, anti_scaler\n",
182
+ "\n",
183
+ "\n"
184
+ ],
185
+ "metadata": {
186
+ "id": "v9RoqzBvtrOb"
187
+ },
188
+ "execution_count": 111,
189
+ "outputs": []
190
+ },
191
  {
192
  "cell_type": "code",
193
  "source": [
194
+ "norm_data, the_scaler, the_decoder = normalize_data(data, relative_to_previous=True)\n",
195
+ "#todo: save the_scaler somehow to use in new runtimes"
196
+ ],
197
+ "metadata": {
198
+ "id": "-kgo__Q3hw1_"
199
+ },
200
+ "execution_count": 112,
201
+ "outputs": []
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "source": [
206
+ "len(norm_data)"
207
  ],
208
  "metadata": {
209
  "colab": {
210
  "base_uri": "https://localhost:8080/"
211
  },
212
+ "id": "A1L8giqcsutX",
213
+ "outputId": "0aaf515b-3835-432c-b882-c2111a221ed4"
214
  },
215
+ "execution_count": 41,
216
+ "outputs": [
217
+ {
218
+ "output_type": "execute_result",
219
+ "data": {
220
+ "text/plain": [
221
+ "2583"
222
+ ]
223
+ },
224
+ "metadata": {},
225
+ "execution_count": 41
226
+ }
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "source": [
232
+ "prediction_days = 100\n",
233
+ "\n",
234
+ "x_train_list = []\n",
235
+ "y_train_list = []\n",
236
+ "\n",
237
+ "for i in range(prediction_days, len(norm_data)):\n",
238
+ " x_train_list.append(norm_data[i-prediction_days:i])\n",
239
+ " y_train_list.append(norm_data.iloc[i].values[0:4])\n",
240
+ "\n",
241
+ "x_train = np.array(x_train_list)\n",
242
+ "y_train = np.array(y_train_list)"
243
+ ],
244
+ "metadata": {
245
+ "id": "jMXkRAYFomHM"
246
+ },
247
+ "execution_count": 9,
248
+ "outputs": []
249
+ },
250
+ {
251
+ "cell_type": "code",
252
+ "source": [
253
+ "print(x_train.shape)\n",
254
+ "print(y_train.shape)\n",
255
+ "print(x_train.shape[1:])"
256
+ ],
257
+ "metadata": {
258
+ "colab": {
259
+ "base_uri": "https://localhost:8080/"
260
+ },
261
+ "id": "G7oMd1fRyOYt",
262
+ "outputId": "2094c403-096d-4f3a-9b15-bae0fbb7bf11"
263
+ },
264
+ "execution_count": 10,
265
  "outputs": [
266
  {
267
  "output_type": "stream",
268
  "name": "stdout",
269
  "text": [
270
+ "(2483, 100, 6)\n",
271
+ "(2483, 4)\n",
272
+ "(100, 6)\n"
 
 
 
 
 
 
 
273
  ]
274
  }
275
  ]
276
  },
277
  {
278
+ "cell_type": "markdown",
279
+ "source": [
280
+ "# Model"
281
+ ],
282
  "metadata": {
283
+ "id": "Z3N2WMYNV-qX"
284
+ }
285
+ },
286
+ {
287
+ "cell_type": "markdown",
288
  "source": [
289
+ "## Create Model"
290
+ ],
291
+ "metadata": {
292
+ "id": "emDyvzVUp5KJ"
293
+ }
294
+ },
295
+ {
296
+ "cell_type": "code",
297
+ "source": [
298
+ "def create_model():\n",
299
+ " model = Sequential()\n",
300
+ " # model.add(LSTM(units=112, return_sequences=True, input_shape=(x_train.shape[1:])))\n",
301
+ " model.add(LSTM(units=112, return_sequences=True, input_shape=(None,x_train.shape[-1],)))\n",
302
+ " model.add(Dropout(0.2))\n",
303
+ " model.add(LSTM(units=112, return_sequences=True))\n",
304
+ " model.add(Dropout(0.2))\n",
305
+ " model.add(LSTM(units=50))\n",
306
+ " model.add(Dropout(0.2))\n",
307
+ " model.add(Dense(units=4))\n",
308
+ " return model\n",
309
  "\n",
310
+ "model = create_model()\n",
311
+ "print(model.summary())"
312
+ ],
313
+ "metadata": {
314
+ "colab": {
315
+ "base_uri": "https://localhost:8080/"
316
+ },
317
+ "id": "GXhYAKzXVfku",
318
+ "outputId": "c54da788-6e82-4679-df1f-d3e89a20d228"
319
+ },
320
+ "execution_count": 66,
321
+ "outputs": [
322
+ {
323
+ "output_type": "stream",
324
+ "name": "stdout",
325
+ "text": [
326
+ "Model: \"sequential_1\"\n",
327
+ "_________________________________________________________________\n",
328
+ " Layer (type) Output Shape Param # \n",
329
+ "=================================================================\n",
330
+ " lstm_3 (LSTM) (None, None, 112) 53312 \n",
331
+ " \n",
332
+ " dropout_3 (Dropout) (None, None, 112) 0 \n",
333
+ " \n",
334
+ " lstm_4 (LSTM) (None, None, 112) 100800 \n",
335
+ " \n",
336
+ " dropout_4 (Dropout) (None, None, 112) 0 \n",
337
+ " \n",
338
+ " lstm_5 (LSTM) (None, 50) 32600 \n",
339
+ " \n",
340
+ " dropout_5 (Dropout) (None, 50) 0 \n",
341
+ " \n",
342
+ " dense_1 (Dense) (None, 4) 204 \n",
343
+ " \n",
344
+ "=================================================================\n",
345
+ "Total params: 186,916\n",
346
+ "Trainable params: 186,916\n",
347
+ "Non-trainable params: 0\n",
348
+ "_________________________________________________________________\n",
349
+ "None\n"
350
+ ]
351
+ }
352
  ]
353
  },
354
  {
355
  "cell_type": "code",
356
  "source": [
357
+ "model.compile(optimizer='adam', loss='mean_squared_error')"
 
 
358
  ],
359
  "metadata": {
360
+ "id": "ZhoWj_XeXQws"
361
  },
362
+ "execution_count": 12,
363
  "outputs": []
364
  },
365
+ {
366
+ "cell_type": "markdown",
367
+ "source": [
368
+ "## Create checkpoint callback"
369
+ ],
370
+ "metadata": {
371
+ "id": "XU0vc4n8p92L"
372
+ }
373
+ },
374
  {
375
  "cell_type": "code",
376
  "source": [
377
+ "# Directory where the checkpoints will be saved\n",
378
+ "checkpoint_dir = './training_checkpoints_'+dt.datetime.now().strftime(\"%Y%m%d%H%M%S\")\n",
379
+ "# Name of the checkpoint files\n",
380
+ "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_epoch{epoch}_loss{loss}\")\n",
 
 
 
381
  "\n",
382
+ "checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(\n",
383
+ " filepath=checkpoint_prefix,\n",
384
+ " save_weights_only=True)"
385
  ],
386
  "metadata": {
387
+ "id": "M5MBAB1-qCZr"
388
  },
389
+ "execution_count": 35,
390
  "outputs": []
391
  },
392
+ {
393
+ "cell_type": "markdown",
394
+ "source": [
395
+ "## Model Train"
396
+ ],
397
+ "metadata": {
398
+ "id": "65QbfffusPoJ"
399
+ }
400
+ },
401
  {
402
  "cell_type": "code",
403
  "source": [
404
+ "print(x_train.shape)\n",
405
+ "print(y_train.shape)"
406
+ ],
407
+ "metadata": {
408
+ "colab": {
409
+ "base_uri": "https://localhost:8080/"
410
+ },
411
+ "id": "HDT9XPXHvqyN",
412
+ "outputId": "60938333-8afe-4b80-9af3-37bca3d67f83"
 
 
 
 
 
 
 
 
 
 
413
  },
414
+ "execution_count": 15,
415
+ "outputs": [
416
+ {
417
+ "output_type": "stream",
418
+ "name": "stdout",
419
+ "text": [
420
+ "(2483, 100, 6)\n",
421
+ "(2483, 4)\n"
422
+ ]
423
+ }
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "code",
428
+ "source": [
429
+ "y_train[-2]"
430
+ ],
431
+ "metadata": {
432
+ "colab": {
433
+ "base_uri": "https://localhost:8080/"
434
+ },
435
+ "id": "F1wZkJMh3XNH",
436
+ "outputId": "37a023db-0727-434a-85be-141c3c377907"
437
+ },
438
+ "execution_count": 40,
439
+ "outputs": [
440
+ {
441
+ "output_type": "execute_result",
442
+ "data": {
443
+ "text/plain": [
444
+ "array([ 0.02002301, 0.0391905 , -0.09898045, -0.05744885])"
445
+ ]
446
+ },
447
+ "metadata": {},
448
+ "execution_count": 40
449
+ }
450
+ ]
451
  },
452
  {
453
  "cell_type": "code",
454
+ "source": [
455
+ "model.fit(x_train, y_train, epochs=25, batch_size=32, callbacks=[checkpoint_callback])\n"
456
+ ],
457
  "metadata": {
458
  "colab": {
459
  "base_uri": "https://localhost:8080/"
460
  },
461
+ "id": "9Ccc_Ej2TmYO",
462
+ "outputId": "235efc3b-616b-4e57-fb87-07efcb377e8e"
463
  },
464
  "execution_count": 37,
465
  "outputs": [
 
467
  "output_type": "stream",
468
  "name": "stdout",
469
  "text": [
470
+ "Epoch 1/25\n",
471
+ "78/78 [==============================] - 31s 395ms/step - loss: 0.0117\n",
472
+ "Epoch 2/25\n",
473
+ "78/78 [==============================] - 31s 394ms/step - loss: 0.0111\n",
474
+ "Epoch 3/25\n",
475
+ "78/78 [==============================] - 33s 429ms/step - loss: 0.0109\n",
476
+ "Epoch 4/25\n",
477
+ "78/78 [==============================] - 31s 396ms/step - loss: 0.0109\n",
478
+ "Epoch 5/25\n",
479
+ "78/78 [==============================] - 31s 398ms/step - loss: 0.0108\n",
480
+ "Epoch 6/25\n",
481
+ "78/78 [==============================] - 31s 400ms/step - loss: 0.0108\n",
482
+ "Epoch 7/25\n",
483
+ "78/78 [==============================] - 32s 405ms/step - loss: 0.0108\n",
484
+ "Epoch 8/25\n",
485
+ "78/78 [==============================] - 31s 394ms/step - loss: 0.0108\n",
486
+ "Epoch 9/25\n",
487
+ "78/78 [==============================] - 30s 385ms/step - loss: 0.0108\n",
488
+ "Epoch 10/25\n",
489
+ "78/78 [==============================] - 30s 385ms/step - loss: 0.0108\n",
490
+ "Epoch 11/25\n",
491
+ "78/78 [==============================] - 29s 373ms/step - loss: 0.0108\n",
492
+ "Epoch 12/25\n",
493
+ "78/78 [==============================] - 29s 375ms/step - loss: 0.0107\n",
494
+ "Epoch 13/25\n",
495
+ "78/78 [==============================] - 30s 383ms/step - loss: 0.0107\n",
496
+ "Epoch 14/25\n",
497
+ "78/78 [==============================] - 30s 388ms/step - loss: 0.0107\n",
498
+ "Epoch 15/25\n",
499
+ "78/78 [==============================] - 31s 396ms/step - loss: 0.0108\n",
500
+ "Epoch 16/25\n",
501
+ "78/78 [==============================] - 30s 379ms/step - loss: 0.0107\n",
502
+ "Epoch 17/25\n",
503
+ "78/78 [==============================] - 30s 386ms/step - loss: 0.0107\n",
504
+ "Epoch 18/25\n",
505
+ "78/78 [==============================] - 30s 383ms/step - loss: 0.0108\n",
506
+ "Epoch 19/25\n",
507
+ "78/78 [==============================] - 30s 382ms/step - loss: 0.0107\n",
508
+ "Epoch 20/25\n",
509
+ "78/78 [==============================] - 31s 397ms/step - loss: 0.0107\n",
510
+ "Epoch 21/25\n",
511
+ "78/78 [==============================] - 30s 384ms/step - loss: 0.0107\n",
512
+ "Epoch 22/25\n",
513
+ "78/78 [==============================] - 30s 381ms/step - loss: 0.0106\n",
514
+ "Epoch 23/25\n",
515
+ "78/78 [==============================] - 30s 380ms/step - loss: 0.0106\n",
516
+ "Epoch 24/25\n",
517
+ "78/78 [==============================] - 30s 385ms/step - loss: 0.0107\n",
518
+ "Epoch 25/25\n",
519
+ "78/78 [==============================] - 30s 383ms/step - loss: 0.0106\n"
520
  ]
521
+ },
522
+ {
523
+ "output_type": "execute_result",
524
+ "data": {
525
+ "text/plain": [
526
+ "<keras.callbacks.History at 0x7f5203d70cd0>"
527
+ ]
528
+ },
529
+ "metadata": {},
530
+ "execution_count": 37
531
  }
532
  ]
533
  },
534
+ {
535
+ "cell_type": "markdown",
536
+ "source": [
537
+ "# Testing a model"
538
+ ],
539
+ "metadata": {
540
+ "id": "dbSKl47vZvpe"
541
+ }
542
+ },
543
  {
544
  "cell_type": "code",
545
  "source": [
546
+ "#print trainings directories to pick one\n",
547
+ "!ls -d training_checkpoints_*/"
548
  ],
549
  "metadata": {
550
  "colab": {
551
+ "base_uri": "https://localhost:8080/"
 
552
  },
553
+ "id": "59CDDB0i4yTx",
554
+ "outputId": "497ae253-e3ac-47d0-d066-8e508f55782c"
555
  },
556
+ "execution_count": 49,
557
  "outputs": [
558
  {
559
+ "output_type": "stream",
560
+ "name": "stdout",
561
+ "text": [
562
+ "training_checkpoints_20230406041748/\n"
 
 
 
 
563
  ]
564
  }
565
  ]
 
567
  {
568
  "cell_type": "code",
569
  "source": [
570
+ "test_model = create_model()"
 
 
571
  ],
572
  "metadata": {
573
+ "id": "tpmru7nG9kbW"
574
  },
575
+ "execution_count": 72,
576
  "outputs": []
577
  },
578
  {
579
  "cell_type": "code",
580
  "source": [
581
+ "checkpoint_dir = 'training_checkpoints_20230406041748'\n",
 
 
582
  "\n",
583
+ "def load_weights(epoch=None):\n",
584
+ " if epoch is None:\n",
585
+ " weights_file = tf.train.latest_checkpoint(checkpoint_dir)\n",
586
+ " else:\n",
587
+ " with os.scandir(checkpoint_dir) as entries:\n",
588
+ " for entry in entries:\n",
589
+ " if re.search(f'^ckpt_epoch{epoch}_.*\\.index', entry.name):\n",
590
+ " weights_file = checkpoint_dir + '/'+ entry.name[:-6]\n",
591
  "\n",
592
+ " print(weights_file)\n",
593
+ " test_model.load_weights(weights_file)\n",
594
+ " return test_model\n",
595
+ "\n",
596
+ "test_model = load_weights()"
597
  ],
598
  "metadata": {
599
+ "colab": {
600
+ "base_uri": "https://localhost:8080/"
601
+ },
602
+ "id": "wQ0JTXsp4VKF",
603
+ "outputId": "d4b794c9-7a89-4867-d17c-de1f20b9b607"
604
  },
605
+ "execution_count": 87,
606
+ "outputs": [
607
+ {
608
+ "output_type": "stream",
609
+ "name": "stdout",
610
+ "text": [
611
+ "training_checkpoints_20230406041748/ckpt_epoch25_loss0.01064301934093237\n"
612
+ ]
613
+ }
614
+ ]
615
+ },
616
+ {
617
+ "cell_type": "code",
618
+ "source": [
619
+ "test_start = dt.date.today() - dt.timedelta(days=200)\n",
620
+ "test_end = dt.date.today()\n",
621
+ "\n",
622
+ "yfin.pdr_override()\n",
623
+ "test_data = web.data.get_data_yahoo(ticker, test_start, test_end)"
624
+ ],
625
+ "metadata": {
626
+ "colab": {
627
+ "base_uri": "https://localhost:8080/"
628
+ },
629
+ "id": "Mf4q97pfaSCA",
630
+ "outputId": "4317ef63-be5e-49ca-fdca-1d5760efbba1"
631
+ },
632
+ "execution_count": 99,
633
+ "outputs": [
634
+ {
635
+ "output_type": "stream",
636
+ "name": "stdout",
637
+ "text": [
638
+ "\r[*********************100%***********************] 1 of 1 completed\n"
639
+ ]
640
+ }
641
+ ]
642
+ },
643
+ {
644
+ "cell_type": "code",
645
+ "source": [
646
+ "test_data_norm, _ = normalize_data(test_data, scaler=the_scaler)"
647
+ ],
648
+ "metadata": {
649
+ "id": "xEG2yEdKC8uy"
650
+ },
651
+ "execution_count": 100,
652
  "outputs": []
653
+ },
654
+ {
655
+ "cell_type": "code",
656
+ "source": [
657
+ "print(type(test_data_norm))"
658
+ ],
659
+ "metadata": {
660
+ "colab": {
661
+ "base_uri": "https://localhost:8080/"
662
+ },
663
+ "id": "mhbqRZ6cDhd6",
664
+ "outputId": "8b40a738-e143-4920-de03-8e8572f4389a"
665
+ },
666
+ "execution_count": 102,
667
+ "outputs": [
668
+ {
669
+ "output_type": "stream",
670
+ "name": "stdout",
671
+ "text": [
672
+ "<class 'pandas.core.frame.DataFrame'>\n"
673
+ ]
674
+ }
675
+ ]
676
+ },
677
+ {
678
+ "cell_type": "code",
679
+ "source": [
680
+ "input_data = np.expand_dims(test_data_norm.values, axis=0)\n",
681
+ "print(input_data.shape)"
682
+ ],
683
+ "metadata": {
684
+ "colab": {
685
+ "base_uri": "https://localhost:8080/"
686
+ },
687
+ "id": "F2bnofchD0xv",
688
+ "outputId": "0b2261fb-056d-4ec2-a98b-82517d7806f1"
689
+ },
690
+ "execution_count": 104,
691
+ "outputs": [
692
+ {
693
+ "output_type": "stream",
694
+ "name": "stdout",
695
+ "text": [
696
+ "(1, 138, 6)\n"
697
+ ]
698
+ }
699
+ ]
700
+ },
701
+ {
702
+ "cell_type": "code",
703
+ "source": [
704
+ "results = test_model.predict(input_data, batch_size=1)"
705
+ ],
706
+ "metadata": {
707
+ "colab": {
708
+ "base_uri": "https://localhost:8080/"
709
+ },
710
+ "id": "AVYFQZnqEqhx",
711
+ "outputId": "958d1669-c8bc-4eff-bb66-f25eb4dde011"
712
+ },
713
+ "execution_count": 105,
714
+ "outputs": [
715
+ {
716
+ "output_type": "stream",
717
+ "name": "stdout",
718
+ "text": [
719
+ "1/1 [==============================] - 1s 1s/step\n"
720
+ ]
721
+ }
722
+ ]
723
+ },
724
+ {
725
+ "cell_type": "code",
726
+ "source": [
727
+ "print(results)\n",
728
+ "print(the_decoder(results))"
729
+ ],
730
+ "metadata": {
731
+ "colab": {
732
+ "base_uri": "https://localhost:8080/"
733
+ },
734
+ "id": "FbdX4ulhExsX",
735
+ "outputId": "14a763ca-0983-41ec-e88b-da796fa4b51a"
736
+ },
737
+ "execution_count": 113,
738
+ "outputs": [
739
+ {
740
+ "output_type": "stream",
741
+ "name": "stdout",
742
+ "text": [
743
+ "[[-0.01962117 0.09634934 -0.10176479 -0.00849891]]\n",
744
+ "[[-0.06636524 1.3856668 -1.0948591 0.0728941 ]]\n"
745
+ ]
746
+ }
747
+ ]
748
+ },
749
+ {
750
+ "cell_type": "code",
751
+ "source": [
752
+ "test_data.head()"
753
+ ],
754
+ "metadata": {
755
+ "colab": {
756
+ "base_uri": "https://localhost:8080/",
757
+ "height": 237
758
+ },
759
+ "id": "m0k7toG3E2_9",
760
+ "outputId": "38ab6e43-1321-4028-9482-8e6687802a7d"
761
+ },
762
+ "execution_count": 107,
763
+ "outputs": [
764
+ {
765
+ "output_type": "execute_result",
766
+ "data": {
767
+ "text/plain": [
768
+ " Open High Low Close Adj Close \\\n",
769
+ "Date \n",
770
+ "2022-09-19 149.309998 154.559998 149.100006 154.479996 153.989029 \n",
771
+ "2022-09-20 153.399994 158.080002 153.080002 156.899994 156.401352 \n",
772
+ "2022-09-21 157.339996 158.740005 153.600006 153.720001 153.231461 \n",
773
+ "2022-09-22 152.380005 154.470001 150.910004 152.740005 152.254578 \n",
774
+ "2022-09-23 151.190002 151.470001 148.559998 150.429993 149.951904 \n",
775
+ "\n",
776
+ " Volume \n",
777
+ "Date \n",
778
+ "2022-09-19 81474200 \n",
779
+ "2022-09-20 107689800 \n",
780
+ "2022-09-21 101696800 \n",
781
+ "2022-09-22 86652500 \n",
782
+ "2022-09-23 96029900 "
783
+ ],
784
+ "text/html": [
785
+ "\n",
786
+ " <div id=\"df-51b6b5ba-2841-4317-9ce2-b32b40e2e9fc\">\n",
787
+ " <div class=\"colab-df-container\">\n",
788
+ " <div>\n",
789
+ "<style scoped>\n",
790
+ " .dataframe tbody tr th:only-of-type {\n",
791
+ " vertical-align: middle;\n",
792
+ " }\n",
793
+ "\n",
794
+ " .dataframe tbody tr th {\n",
795
+ " vertical-align: top;\n",
796
+ " }\n",
797
+ "\n",
798
+ " .dataframe thead th {\n",
799
+ " text-align: right;\n",
800
+ " }\n",
801
+ "</style>\n",
802
+ "<table border=\"1\" class=\"dataframe\">\n",
803
+ " <thead>\n",
804
+ " <tr style=\"text-align: right;\">\n",
805
+ " <th></th>\n",
806
+ " <th>Open</th>\n",
807
+ " <th>High</th>\n",
808
+ " <th>Low</th>\n",
809
+ " <th>Close</th>\n",
810
+ " <th>Adj Close</th>\n",
811
+ " <th>Volume</th>\n",
812
+ " </tr>\n",
813
+ " <tr>\n",
814
+ " <th>Date</th>\n",
815
+ " <th></th>\n",
816
+ " <th></th>\n",
817
+ " <th></th>\n",
818
+ " <th></th>\n",
819
+ " <th></th>\n",
820
+ " <th></th>\n",
821
+ " </tr>\n",
822
+ " </thead>\n",
823
+ " <tbody>\n",
824
+ " <tr>\n",
825
+ " <th>2022-09-19</th>\n",
826
+ " <td>149.309998</td>\n",
827
+ " <td>154.559998</td>\n",
828
+ " <td>149.100006</td>\n",
829
+ " <td>154.479996</td>\n",
830
+ " <td>153.989029</td>\n",
831
+ " <td>81474200</td>\n",
832
+ " </tr>\n",
833
+ " <tr>\n",
834
+ " <th>2022-09-20</th>\n",
835
+ " <td>153.399994</td>\n",
836
+ " <td>158.080002</td>\n",
837
+ " <td>153.080002</td>\n",
838
+ " <td>156.899994</td>\n",
839
+ " <td>156.401352</td>\n",
840
+ " <td>107689800</td>\n",
841
+ " </tr>\n",
842
+ " <tr>\n",
843
+ " <th>2022-09-21</th>\n",
844
+ " <td>157.339996</td>\n",
845
+ " <td>158.740005</td>\n",
846
+ " <td>153.600006</td>\n",
847
+ " <td>153.720001</td>\n",
848
+ " <td>153.231461</td>\n",
849
+ " <td>101696800</td>\n",
850
+ " </tr>\n",
851
+ " <tr>\n",
852
+ " <th>2022-09-22</th>\n",
853
+ " <td>152.380005</td>\n",
854
+ " <td>154.470001</td>\n",
855
+ " <td>150.910004</td>\n",
856
+ " <td>152.740005</td>\n",
857
+ " <td>152.254578</td>\n",
858
+ " <td>86652500</td>\n",
859
+ " </tr>\n",
860
+ " <tr>\n",
861
+ " <th>2022-09-23</th>\n",
862
+ " <td>151.190002</td>\n",
863
+ " <td>151.470001</td>\n",
864
+ " <td>148.559998</td>\n",
865
+ " <td>150.429993</td>\n",
866
+ " <td>149.951904</td>\n",
867
+ " <td>96029900</td>\n",
868
+ " </tr>\n",
869
+ " </tbody>\n",
870
+ "</table>\n",
871
+ "</div>\n",
872
+ " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-51b6b5ba-2841-4317-9ce2-b32b40e2e9fc')\"\n",
873
+ " title=\"Convert this dataframe to an interactive table.\"\n",
874
+ " style=\"display:none;\">\n",
875
+ " \n",
876
+ " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
877
+ " width=\"24px\">\n",
878
+ " <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
879
+ " <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
880
+ " </svg>\n",
881
+ " </button>\n",
882
+ " \n",
883
+ " <style>\n",
884
+ " .colab-df-container {\n",
885
+ " display:flex;\n",
886
+ " flex-wrap:wrap;\n",
887
+ " gap: 12px;\n",
888
+ " }\n",
889
+ "\n",
890
+ " .colab-df-convert {\n",
891
+ " background-color: #E8F0FE;\n",
892
+ " border: none;\n",
893
+ " border-radius: 50%;\n",
894
+ " cursor: pointer;\n",
895
+ " display: none;\n",
896
+ " fill: #1967D2;\n",
897
+ " height: 32px;\n",
898
+ " padding: 0 0 0 0;\n",
899
+ " width: 32px;\n",
900
+ " }\n",
901
+ "\n",
902
+ " .colab-df-convert:hover {\n",
903
+ " background-color: #E2EBFA;\n",
904
+ " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
905
+ " fill: #174EA6;\n",
906
+ " }\n",
907
+ "\n",
908
+ " [theme=dark] .colab-df-convert {\n",
909
+ " background-color: #3B4455;\n",
910
+ " fill: #D2E3FC;\n",
911
+ " }\n",
912
+ "\n",
913
+ " [theme=dark] .colab-df-convert:hover {\n",
914
+ " background-color: #434B5C;\n",
915
+ " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
916
+ " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
917
+ " fill: #FFFFFF;\n",
918
+ " }\n",
919
+ " </style>\n",
920
+ "\n",
921
+ " <script>\n",
922
+ " const buttonEl =\n",
923
+ " document.querySelector('#df-51b6b5ba-2841-4317-9ce2-b32b40e2e9fc button.colab-df-convert');\n",
924
+ " buttonEl.style.display =\n",
925
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
926
+ "\n",
927
+ " async function convertToInteractive(key) {\n",
928
+ " const element = document.querySelector('#df-51b6b5ba-2841-4317-9ce2-b32b40e2e9fc');\n",
929
+ " const dataTable =\n",
930
+ " await google.colab.kernel.invokeFunction('convertToInteractive',\n",
931
+ " [key], {});\n",
932
+ " if (!dataTable) return;\n",
933
+ "\n",
934
+ " const docLinkHtml = 'Like what you see? Visit the ' +\n",
935
+ " '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
936
+ " + ' to learn more about interactive tables.';\n",
937
+ " element.innerHTML = '';\n",
938
+ " dataTable['output_type'] = 'display_data';\n",
939
+ " await google.colab.output.renderOutput(dataTable, element);\n",
940
+ " const docLink = document.createElement('div');\n",
941
+ " docLink.innerHTML = docLinkHtml;\n",
942
+ " element.appendChild(docLink);\n",
943
+ " }\n",
944
+ " </script>\n",
945
+ " </div>\n",
946
+ " </div>\n",
947
+ " "
948
+ ]
949
+ },
950
+ "metadata": {},
951
+ "execution_count": 107
952
+ }
953
+ ]
954
  }
955
  ]
956
  }