FrederikKl committed on
Commit
a9831f4
•
1 Parent(s): 200f6a5

Add files via upload

Files changed (1)
  1. training_pipeline-copy.ipynb +1058 -0
training_pipeline-copy.ipynb ADDED
@@ -0,0 +1,1058 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "data": {
10
+ "text/plain": [
11
+ "False"
12
+ ]
13
+ },
14
+ "execution_count": 1,
15
+ "metadata": {},
16
+ "output_type": "execute_result"
17
+ }
18
+ ],
19
+ "source": [
20
+ "import hopsworks\n",
21
+ "from dotenv import load_dotenv\n",
22
+ "import os\n",
23
+ "import pandas as pd\n",
24
+ "from sklearn.preprocessing import OneHotEncoder\n",
25
+ "from sklearn.preprocessing import MinMaxScaler\n",
26
+ "from hsml.schema import Schema\n",
27
+ "from hsml.model_schema import ModelSchema\n",
28
+ "\n",
29
+ "\n",
30
+ "load_dotenv()"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 2,
36
+ "metadata": {},
37
+ "outputs": [
38
+ {
39
+ "name": "stdout",
40
+ "output_type": "stream",
41
+ "text": [
42
+ "Connected. Call `.close()` to terminate connection gracefully.\n",
43
+ "\n",
44
+ "Logged in to project, explore it here https://c.app.hopsworks.ai:443/p/693399\n",
45
+ "Connected. Call `.close()` to terminate connection gracefully.\n"
46
+ ]
47
+ }
48
+ ],
49
+ "source": [
50
+ "api_key = os.environ.get('hopsworks_api')\n",
51
+ "project = hopsworks.login(api_key_value=api_key)\n",
52
+ "fs = project.get_feature_store()"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": 3,
58
+ "metadata": {},
59
+ "outputs": [
60
+ {
61
+ "name": "stdout",
62
+ "output_type": "stream",
63
+ "text": [
64
+ "Connected. Call `.close()` to terminate connection gracefully.\n"
65
+ ]
66
+ }
67
+ ],
68
+ "source": [
69
+ "import hsfs\n",
70
+ "\n",
71
+ "# Connection setup\n",
72
+ "# Connect to Hopsworks\n",
73
+ "api_key = os.getenv('hopsworks_api')\n",
74
+ "connection = hsfs.connection()\n",
75
+ "fs = connection.get_feature_store()\n",
76
+ "\n",
77
+ "# Get feature view\n",
78
+ "\n"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": 4,
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "feature_view = fs.get_feature_view(\n",
88
+ " name='tesla_stocks_fv',\n",
89
+ " version=1\n",
90
+ ")"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": 5,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "train_start = \"2022-06-22\"\n",
100
+ "train_end = \"2023-12-31\"\n",
101
+ "\n",
102
+ "test_start = '2024-01-01'\n",
103
+ "test_end = \"2024-05-03\"\n"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": 6,
109
+ "metadata": {},
110
+ "outputs": [
111
+ {
112
+ "name": "stdout",
113
+ "output_type": "stream",
114
+ "text": [
115
+ "Training dataset job started successfully, you can follow the progress at \n",
116
+ "https://c.app.hopsworks.ai/p/693399/jobs/named/tesla_stocks_fv_1_create_fv_td_06052024212158/executions\n",
117
+ "2024-05-06 23:23:21,130 WARNING: VersionWarning: Incremented version to `5`.\n",
118
+ "\n"
119
+ ]
120
+ },
121
+ {
122
+ "data": {
123
+ "text/plain": [
124
+ "(5, <hsfs.core.job.Job at 0x194e21067d0>)"
125
+ ]
126
+ },
127
+ "execution_count": 6,
128
+ "metadata": {},
129
+ "output_type": "execute_result"
130
+ }
131
+ ],
132
+ "source": [
133
+ "feature_view.create_train_test_split(\n",
134
+ " train_start=train_start,\n",
135
+ " train_end=train_end,\n",
136
+ " test_start=test_start,\n",
137
+ " test_end=test_end,\n",
138
+ " data_format='csv',\n",
139
+ " coalesce= True,\n",
140
+ " statistics_config={'histogram':True,'correlations':True})"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": 7,
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "X_train, X_test, y_train, y_test = feature_view.get_train_test_split(5)"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": 8,
155
+ "metadata": {},
156
+ "outputs": [
157
+ {
158
+ "data": {
159
+ "text/html": [
160
+ "<div>\n",
161
+ "<style scoped>\n",
162
+ " .dataframe tbody tr th:only-of-type {\n",
163
+ " vertical-align: middle;\n",
164
+ " }\n",
165
+ "\n",
166
+ " .dataframe tbody tr th {\n",
167
+ " vertical-align: top;\n",
168
+ " }\n",
169
+ "\n",
170
+ " .dataframe thead th {\n",
171
+ " text-align: right;\n",
172
+ " }\n",
173
+ "</style>\n",
174
+ "<table border=\"1\" class=\"dataframe\">\n",
175
+ " <thead>\n",
176
+ " <tr style=\"text-align: right;\">\n",
177
+ " <th></th>\n",
178
+ " <th>date</th>\n",
179
+ " <th>ticker</th>\n",
180
+ " <th>sentiment</th>\n",
181
+ " </tr>\n",
182
+ " </thead>\n",
183
+ " <tbody>\n",
184
+ " <tr>\n",
185
+ " <th>0</th>\n",
186
+ " <td>2022-12-14T00:00:00.000Z</td>\n",
187
+ " <td>TSLA</td>\n",
188
+ " <td>0.102207</td>\n",
189
+ " </tr>\n",
190
+ " <tr>\n",
191
+ " <th>1</th>\n",
192
+ " <td>2023-02-21T00:00:00.000Z</td>\n",
193
+ " <td>TSLA</td>\n",
194
+ " <td>0.155833</td>\n",
195
+ " </tr>\n",
196
+ " <tr>\n",
197
+ " <th>2</th>\n",
198
+ " <td>2023-08-17T00:00:00.000Z</td>\n",
199
+ " <td>TSLA</td>\n",
200
+ " <td>0.024046</td>\n",
201
+ " </tr>\n",
202
+ " <tr>\n",
203
+ " <th>3</th>\n",
204
+ " <td>2022-09-16T00:00:00.000Z</td>\n",
205
+ " <td>TSLA</td>\n",
206
+ " <td>0.087306</td>\n",
207
+ " </tr>\n",
208
+ " <tr>\n",
209
+ " <th>4</th>\n",
210
+ " <td>2023-08-28T00:00:00.000Z</td>\n",
211
+ " <td>TSLA</td>\n",
212
+ " <td>0.024046</td>\n",
213
+ " </tr>\n",
214
+ " <tr>\n",
215
+ " <th>...</th>\n",
216
+ " <td>...</td>\n",
217
+ " <td>...</td>\n",
218
+ " <td>...</td>\n",
219
+ " </tr>\n",
220
+ " <tr>\n",
221
+ " <th>378</th>\n",
222
+ " <td>2023-02-10T00:00:00.000Z</td>\n",
223
+ " <td>TSLA</td>\n",
224
+ " <td>0.155833</td>\n",
225
+ " </tr>\n",
226
+ " <tr>\n",
227
+ " <th>379</th>\n",
228
+ " <td>2023-05-08T00:00:00.000Z</td>\n",
229
+ " <td>TSLA</td>\n",
230
+ " <td>0.141296</td>\n",
231
+ " </tr>\n",
232
+ " <tr>\n",
233
+ " <th>380</th>\n",
234
+ " <td>2022-09-08T00:00:00.000Z</td>\n",
235
+ " <td>TSLA</td>\n",
236
+ " <td>0.087306</td>\n",
237
+ " </tr>\n",
238
+ " <tr>\n",
239
+ " <th>381</th>\n",
240
+ " <td>2023-07-06T00:00:00.000Z</td>\n",
241
+ " <td>TSLA</td>\n",
242
+ " <td>0.119444</td>\n",
243
+ " </tr>\n",
244
+ " <tr>\n",
245
+ " <th>382</th>\n",
246
+ " <td>2023-10-27T00:00:00.000Z</td>\n",
247
+ " <td>TSLA</td>\n",
248
+ " <td>0.164868</td>\n",
249
+ " </tr>\n",
250
+ " </tbody>\n",
251
+ "</table>\n",
252
+ "<p>383 rows Γ— 3 columns</p>\n",
253
+ "</div>"
254
+ ],
255
+ "text/plain": [
256
+ " date ticker sentiment\n",
257
+ "0 2022-12-14T00:00:00.000Z TSLA 0.102207\n",
258
+ "1 2023-02-21T00:00:00.000Z TSLA 0.155833\n",
259
+ "2 2023-08-17T00:00:00.000Z TSLA 0.024046\n",
260
+ "3 2022-09-16T00:00:00.000Z TSLA 0.087306\n",
261
+ "4 2023-08-28T00:00:00.000Z TSLA 0.024046\n",
262
+ ".. ... ... ...\n",
263
+ "378 2023-02-10T00:00:00.000Z TSLA 0.155833\n",
264
+ "379 2023-05-08T00:00:00.000Z TSLA 0.141296\n",
265
+ "380 2022-09-08T00:00:00.000Z TSLA 0.087306\n",
266
+ "381 2023-07-06T00:00:00.000Z TSLA 0.119444\n",
267
+ "382 2023-10-27T00:00:00.000Z TSLA 0.164868\n",
268
+ "\n",
269
+ "[383 rows x 3 columns]"
270
+ ]
271
+ },
272
+ "execution_count": 8,
273
+ "metadata": {},
274
+ "output_type": "execute_result"
275
+ }
276
+ ],
277
+ "source": [
278
+ "X_train"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 9,
284
+ "metadata": {},
285
+ "outputs": [
286
+ {
287
+ "data": {
288
+ "text/html": [
289
+ "<div>\n",
290
+ "<style scoped>\n",
291
+ " .dataframe tbody tr th:only-of-type {\n",
292
+ " vertical-align: middle;\n",
293
+ " }\n",
294
+ "\n",
295
+ " .dataframe tbody tr th {\n",
296
+ " vertical-align: top;\n",
297
+ " }\n",
298
+ "\n",
299
+ " .dataframe thead th {\n",
300
+ " text-align: right;\n",
301
+ " }\n",
302
+ "</style>\n",
303
+ "<table border=\"1\" class=\"dataframe\">\n",
304
+ " <thead>\n",
305
+ " <tr style=\"text-align: right;\">\n",
306
+ " <th></th>\n",
307
+ " <th>date</th>\n",
308
+ " <th>ticker</th>\n",
309
+ " <th>sentiment</th>\n",
310
+ " </tr>\n",
311
+ " </thead>\n",
312
+ " <tbody>\n",
313
+ " <tr>\n",
314
+ " <th>0</th>\n",
315
+ " <td>2024-04-16T00:00:00.000Z</td>\n",
316
+ " <td>TSLA</td>\n",
317
+ " <td>0.018769</td>\n",
318
+ " </tr>\n",
319
+ " <tr>\n",
320
+ " <th>1</th>\n",
321
+ " <td>2024-02-22T00:00:00.000Z</td>\n",
322
+ " <td>TSLA</td>\n",
323
+ " <td>0.212963</td>\n",
324
+ " </tr>\n",
325
+ " <tr>\n",
326
+ " <th>2</th>\n",
327
+ " <td>2024-02-13T00:00:00.000Z</td>\n",
328
+ " <td>TSLA</td>\n",
329
+ " <td>0.099363</td>\n",
330
+ " </tr>\n",
331
+ " <tr>\n",
332
+ " <th>3</th>\n",
333
+ " <td>2024-01-17T00:00:00.000Z</td>\n",
334
+ " <td>TSLA</td>\n",
335
+ " <td>0.099363</td>\n",
336
+ " </tr>\n",
337
+ " <tr>\n",
338
+ " <th>4</th>\n",
339
+ " <td>2024-02-16T00:00:00.000Z</td>\n",
340
+ " <td>TSLA</td>\n",
341
+ " <td>0.099363</td>\n",
342
+ " </tr>\n",
343
+ " </tbody>\n",
344
+ "</table>\n",
345
+ "</div>"
346
+ ],
347
+ "text/plain": [
348
+ " date ticker sentiment\n",
349
+ "0 2024-04-16T00:00:00.000Z TSLA 0.018769\n",
350
+ "1 2024-02-22T00:00:00.000Z TSLA 0.212963\n",
351
+ "2 2024-02-13T00:00:00.000Z TSLA 0.099363\n",
352
+ "3 2024-01-17T00:00:00.000Z TSLA 0.099363\n",
353
+ "4 2024-02-16T00:00:00.000Z TSLA 0.099363"
354
+ ]
355
+ },
356
+ "execution_count": 9,
357
+ "metadata": {},
358
+ "output_type": "execute_result"
359
+ }
360
+ ],
361
+ "source": [
362
+ "X_test.head()"
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "code",
367
+ "execution_count": 10,
368
+ "metadata": {},
369
+ "outputs": [
370
+ {
371
+ "data": {
372
+ "text/html": [
373
+ "<div>\n",
374
+ "<style scoped>\n",
375
+ " .dataframe tbody tr th:only-of-type {\n",
376
+ " vertical-align: middle;\n",
377
+ " }\n",
378
+ "\n",
379
+ " .dataframe tbody tr th {\n",
380
+ " vertical-align: top;\n",
381
+ " }\n",
382
+ "\n",
383
+ " .dataframe thead th {\n",
384
+ " text-align: right;\n",
385
+ " }\n",
386
+ "</style>\n",
387
+ "<table border=\"1\" class=\"dataframe\">\n",
388
+ " <thead>\n",
389
+ " <tr style=\"text-align: right;\">\n",
390
+ " <th></th>\n",
391
+ " <th>date</th>\n",
392
+ " <th>ticker</th>\n",
393
+ " <th>sentiment</th>\n",
394
+ " </tr>\n",
395
+ " </thead>\n",
396
+ " <tbody>\n",
397
+ " <tr>\n",
398
+ " <th>80</th>\n",
399
+ " <td>2024-05-02T00:00:00.000Z</td>\n",
400
+ " <td>TSLA</td>\n",
401
+ " <td>0.001443</td>\n",
402
+ " </tr>\n",
403
+ " <tr>\n",
404
+ " <th>81</th>\n",
405
+ " <td>2024-04-02T00:00:00.000Z</td>\n",
406
+ " <td>TSLA</td>\n",
407
+ " <td>0.080911</td>\n",
408
+ " </tr>\n",
409
+ " <tr>\n",
410
+ " <th>82</th>\n",
411
+ " <td>2024-03-22T00:00:00.000Z</td>\n",
412
+ " <td>TSLA</td>\n",
413
+ " <td>0.080911</td>\n",
414
+ " </tr>\n",
415
+ " <tr>\n",
416
+ " <th>83</th>\n",
417
+ " <td>2024-01-02T00:00:00.000Z</td>\n",
418
+ " <td>TSLA</td>\n",
419
+ " <td>-0.122579</td>\n",
420
+ " </tr>\n",
421
+ " <tr>\n",
422
+ " <th>84</th>\n",
423
+ " <td>2024-02-26T00:00:00.000Z</td>\n",
424
+ " <td>TSLA</td>\n",
425
+ " <td>0.152764</td>\n",
426
+ " </tr>\n",
427
+ " </tbody>\n",
428
+ "</table>\n",
429
+ "</div>"
430
+ ],
431
+ "text/plain": [
432
+ " date ticker sentiment\n",
433
+ "80 2024-05-02T00:00:00.000Z TSLA 0.001443\n",
434
+ "81 2024-04-02T00:00:00.000Z TSLA 0.080911\n",
435
+ "82 2024-03-22T00:00:00.000Z TSLA 0.080911\n",
436
+ "83 2024-01-02T00:00:00.000Z TSLA -0.122579\n",
437
+ "84 2024-02-26T00:00:00.000Z TSLA 0.152764"
438
+ ]
439
+ },
440
+ "execution_count": 10,
441
+ "metadata": {},
442
+ "output_type": "execute_result"
443
+ }
444
+ ],
445
+ "source": [
446
+ "X_test.tail()"
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "code",
451
+ "execution_count": 11,
452
+ "metadata": {},
453
+ "outputs": [],
454
+ "source": [
455
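+ "# Strip the time component, then cast back to datetime so the dates are normalized to midnight\n",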
+ "X_train['date'] = pd.to_datetime(X_train['date']).dt.date\n",
456
+ "X_test['date'] = pd.to_datetime(X_test['date']).dt.date\n",
457
+ "X_train['date'] = pd.to_datetime(X_train['date'])\n",
458
+ "X_test['date'] = pd.to_datetime(X_test['date'])"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "code",
463
+ "execution_count": 12,
464
+ "metadata": {},
465
+ "outputs": [
466
+ {
467
+ "data": {
468
+ "text/html": [
469
+ "<div>\n",
470
+ "<style scoped>\n",
471
+ " .dataframe tbody tr th:only-of-type {\n",
472
+ " vertical-align: middle;\n",
473
+ " }\n",
474
+ "\n",
475
+ " .dataframe tbody tr th {\n",
476
+ " vertical-align: top;\n",
477
+ " }\n",
478
+ "\n",
479
+ " .dataframe thead th {\n",
480
+ " text-align: right;\n",
481
+ " }\n",
482
+ "</style>\n",
483
+ "<table border=\"1\" class=\"dataframe\">\n",
484
+ " <thead>\n",
485
+ " <tr style=\"text-align: right;\">\n",
486
+ " <th></th>\n",
487
+ " <th>date</th>\n",
488
+ " <th>ticker</th>\n",
489
+ " <th>sentiment</th>\n",
490
+ " </tr>\n",
491
+ " </thead>\n",
492
+ " <tbody>\n",
493
+ " <tr>\n",
494
+ " <th>0</th>\n",
495
+ " <td>2022-12-14</td>\n",
496
+ " <td>TSLA</td>\n",
497
+ " <td>0.102207</td>\n",
498
+ " </tr>\n",
499
+ " <tr>\n",
500
+ " <th>1</th>\n",
501
+ " <td>2023-02-21</td>\n",
502
+ " <td>TSLA</td>\n",
503
+ " <td>0.155833</td>\n",
504
+ " </tr>\n",
505
+ " <tr>\n",
506
+ " <th>2</th>\n",
507
+ " <td>2023-08-17</td>\n",
508
+ " <td>TSLA</td>\n",
509
+ " <td>0.024046</td>\n",
510
+ " </tr>\n",
511
+ " <tr>\n",
512
+ " <th>3</th>\n",
513
+ " <td>2022-09-16</td>\n",
514
+ " <td>TSLA</td>\n",
515
+ " <td>0.087306</td>\n",
516
+ " </tr>\n",
517
+ " <tr>\n",
518
+ " <th>4</th>\n",
519
+ " <td>2023-08-28</td>\n",
520
+ " <td>TSLA</td>\n",
521
+ " <td>0.024046</td>\n",
522
+ " </tr>\n",
523
+ " </tbody>\n",
524
+ "</table>\n",
525
+ "</div>"
526
+ ],
527
+ "text/plain": [
528
+ " date ticker sentiment\n",
529
+ "0 2022-12-14 TSLA 0.102207\n",
530
+ "1 2023-02-21 TSLA 0.155833\n",
531
+ "2 2023-08-17 TSLA 0.024046\n",
532
+ "3 2022-09-16 TSLA 0.087306\n",
533
+ "4 2023-08-28 TSLA 0.024046"
534
+ ]
535
+ },
536
+ "execution_count": 12,
537
+ "metadata": {},
538
+ "output_type": "execute_result"
539
+ }
540
+ ],
541
+ "source": [
542
+ "X_train.head()"
543
+ ]
544
+ },
545
+ {
546
+ "cell_type": "code",
547
+ "execution_count": 13,
548
+ "metadata": {},
549
+ "outputs": [],
550
+ "source": [
551
+ "# Extract the 'ticker' column\n",
552
+ "tickers = X_train[['ticker']]\n",
553
+ "\n",
554
+ "# Initialize OneHotEncoder\n",
555
+ "encoder = OneHotEncoder()\n",
556
+ "\n",
557
+ "# Fit and transform the 'ticker' column\n",
558
+ "ticker_encoded = encoder.fit_transform(tickers)\n",
559
+ "\n",
560
+ "# Convert the encoded column into a DataFrame\n",
561
+ "ticker_encoded_df = pd.DataFrame(ticker_encoded.toarray(), columns=encoder.get_feature_names_out(['ticker']))\n",
562
+ "\n",
563
+ "# Concatenate the encoded DataFrame with the original DataFrame\n",
564
+ "X_train = pd.concat([X_train, ticker_encoded_df], axis=1)\n",
565
+ "\n",
566
+ "# Drop the original 'ticker' column\n",
567
+ "X_train.drop('ticker', axis=1, inplace=True)"
568
+ ]
569
+ },
570
+ {
571
+ "cell_type": "code",
572
+ "execution_count": 14,
573
+ "metadata": {},
574
+ "outputs": [
575
+ {
576
+ "data": {
577
+ "text/html": [
578
+ "<div>\n",
579
+ "<style scoped>\n",
580
+ " .dataframe tbody tr th:only-of-type {\n",
581
+ " vertical-align: middle;\n",
582
+ " }\n",
583
+ "\n",
584
+ " .dataframe tbody tr th {\n",
585
+ " vertical-align: top;\n",
586
+ " }\n",
587
+ "\n",
588
+ " .dataframe thead th {\n",
589
+ " text-align: right;\n",
590
+ " }\n",
591
+ "</style>\n",
592
+ "<table border=\"1\" class=\"dataframe\">\n",
593
+ " <thead>\n",
594
+ " <tr style=\"text-align: right;\">\n",
595
+ " <th></th>\n",
596
+ " <th>date</th>\n",
597
+ " <th>sentiment</th>\n",
598
+ " <th>ticker_TSLA</th>\n",
599
+ " </tr>\n",
600
+ " </thead>\n",
601
+ " <tbody>\n",
602
+ " <tr>\n",
603
+ " <th>0</th>\n",
604
+ " <td>2022-12-14</td>\n",
605
+ " <td>0.102207</td>\n",
606
+ " <td>1.0</td>\n",
607
+ " </tr>\n",
608
+ " <tr>\n",
609
+ " <th>1</th>\n",
610
+ " <td>2023-02-21</td>\n",
611
+ " <td>0.155833</td>\n",
612
+ " <td>1.0</td>\n",
613
+ " </tr>\n",
614
+ " <tr>\n",
615
+ " <th>2</th>\n",
616
+ " <td>2023-08-17</td>\n",
617
+ " <td>0.024046</td>\n",
618
+ " <td>1.0</td>\n",
619
+ " </tr>\n",
620
+ " <tr>\n",
621
+ " <th>3</th>\n",
622
+ " <td>2022-09-16</td>\n",
623
+ " <td>0.087306</td>\n",
624
+ " <td>1.0</td>\n",
625
+ " </tr>\n",
626
+ " <tr>\n",
627
+ " <th>4</th>\n",
628
+ " <td>2023-08-28</td>\n",
629
+ " <td>0.024046</td>\n",
630
+ " <td>1.0</td>\n",
631
+ " </tr>\n",
632
+ " </tbody>\n",
633
+ "</table>\n",
634
+ "</div>"
635
+ ],
636
+ "text/plain": [
637
+ " date sentiment ticker_TSLA\n",
638
+ "0 2022-12-14 0.102207 1.0\n",
639
+ "1 2023-02-21 0.155833 1.0\n",
640
+ "2 2023-08-17 0.024046 1.0\n",
641
+ "3 2022-09-16 0.087306 1.0\n",
642
+ "4 2023-08-28 0.024046 1.0"
643
+ ]
644
+ },
645
+ "execution_count": 14,
646
+ "metadata": {},
647
+ "output_type": "execute_result"
648
+ }
649
+ ],
650
+ "source": [
651
+ "X_train.head()"
652
+ ]
653
+ },
654
+ {
655
+ "cell_type": "code",
656
+ "execution_count": 15,
657
+ "metadata": {},
658
+ "outputs": [],
659
+ "source": [
660
+ "tickers = X_test[['ticker']]\n",
661
+ "\n",
662
+ "# Initialize OneHotEncoder\n",
663
+ "encoder = OneHotEncoder()\n",
664
+ "\n",
665
+ "# Fit and transform the 'ticker' column\n",
666
+ "ticker_encoded_test = encoder.fit_transform(tickers)\n",
667
+ "\n",
668
+ "# Convert the encoded column into a DataFrame\n",
669
+ "ticker_encoded_df_test = pd.DataFrame(ticker_encoded_test.toarray(), columns=encoder.get_feature_names_out(['ticker']))\n",
670
+ "\n",
671
+ "# Concatenate the encoded DataFrame with the original DataFrame\n",
672
+ "X_test = pd.concat([X_test, ticker_encoded_df_test], axis=1)\n",
673
+ "\n",
674
+ "# Drop the original 'ticker' column\n",
675
+ "X_test.drop('ticker', axis=1, inplace=True)"
676
+ ]
677
+ },
678
+ {
679
+ "cell_type": "code",
680
+ "execution_count": 16,
681
+ "metadata": {},
682
+ "outputs": [],
683
+ "source": [
684
+ "scaler = MinMaxScaler()\n",
685
+ "\n",
686
+ "# Fit and transform the 'open' column\n",
687
+ "y_train['open_scaled'] = scaler.fit_transform(y_train[['open']])\n",
688
+ "y_train.drop('open', axis=1, inplace=True)"
689
+ ]
690
+ },
691
+ {
692
+ "cell_type": "code",
693
+ "execution_count": 17,
694
+ "metadata": {},
695
+ "outputs": [],
696
+ "source": [
697
+ "y_test['open_scaled'] = scaler.fit_transform(y_test[['open']])\n",
698
+ "y_test.drop('open', axis=1, inplace=True)"
699
+ ]
700
+ },
701
+ {
702
+ "cell_type": "code",
703
+ "execution_count": 18,
704
+ "metadata": {},
705
+ "outputs": [],
706
+ "source": [
707
+ "from tensorflow.keras.models import Sequential\n",
708
+ "from tensorflow.keras.layers import Input, LSTM, Dense, Dropout\n",
709
+ "from sklearn.preprocessing import StandardScaler # Import StandardScaler from scikit-learn\n",
710
+ "\n",
711
+ "def create_model(input_shape,\n",
712
+ " LSTM_filters=64,\n",
713
+ " dropout=0.1,\n",
714
+ " recurrent_dropout=0.1,\n",
715
+ " dense_dropout=0.5,\n",
716
+ " activation='relu',\n",
717
+ " depth=1):\n",
718
+ "\n",
719
+ " model = Sequential()\n",
720
+ "\n",
721
+ " # Input layer\n",
722
+ " model.add(Input(shape=input_shape))\n",
723
+ "\n",
724
+ " if depth > 1:\n",
725
+ " for i in range(1, depth):\n",
726
+ " # Recurrent layer\n",
727
+ " model.add(LSTM(LSTM_filters, return_sequences=True, dropout=dropout, recurrent_dropout=recurrent_dropout))\n",
728
+ "\n",
729
+ " # Recurrent layer\n",
730
+ " model.add(LSTM(LSTM_filters, return_sequences=False, dropout=dropout, recurrent_dropout=recurrent_dropout))\n",
731
+ "\n",
732
+ " # Fully connected layer\n",
733
+ " if activation == 'relu':\n",
734
+ " model.add(Dense(LSTM_filters, activation='relu'))\n",
735
+ " elif activation == 'leaky_relu':\n",
736
+ " model.add(Dense(LSTM_filters))\n",
737
+ " model.add(tf.keras.layers.LeakyReLU(alpha=0.1))\n",
738
+ "\n",
739
+ " # Dropout for regularization\n",
740
+ " model.add(Dropout(dense_dropout))\n",
741
+ "\n",
742
+ " # Output layer for predicting one day forward\n",
743
+ " model.add(Dense(1, activation='linear'))\n",
744
+ "\n",
745
+ " # Compile the model\n",
746
+ " model.compile(optimizer='adam', loss='mse')\n",
747
+ "\n",
748
+ " return model"
749
+ ]
750
+ },
751
+ {
752
+ "cell_type": "code",
753
+ "execution_count": 19,
754
+ "metadata": {},
755
+ "outputs": [
756
+ {
757
+ "name": "stdout",
758
+ "output_type": "stream",
759
+ "text": [
760
+ "2024-05-06 23:23:33,215 WARNING: DeprecationWarning: np.find_common_type is deprecated. Please use `np.result_type` or `np.promote_types`.\n",
761
+ "See https://numpy.org/devdocs/release/1.25.0-notes.html and the docs for more information. (Deprecated NumPy 1.25)\n",
762
+ "\n"
763
+ ]
764
+ }
765
+ ],
766
+ "source": [
767
+ "import numpy as np\n",
768
+ "\n",
769
+ "# Assuming X_train['date'] column exists and is in datetime format\n",
770
+ "X_train['year'] = X_train['date'].dt.year\n",
771
+ "X_train['month'] = X_train['date'].dt.month\n",
772
+ "X_train['day'] = X_train['date'].dt.day\n",
773
+ "\n",
774
+ "# Drop the original date column\n",
775
+ "X_train.drop(columns=['date'], inplace=True)\n",
776
+ "\n",
777
+ "# Convert dataframe to numpy array\n",
778
+ "X_train_array = X_train.to_numpy()\n",
779
+ "\n",
780
+ "# Reshape the array to have a shape suitable for LSTM\n",
781
+ "# Assuming each row represents a sample and each column represents a feature\n",
782
+ "# Reshape to [samples, timesteps, features]\n",
783
+ "X_train_array = np.expand_dims(X_train_array, axis=1)\n"
784
+ ]
785
+ },
786
+ {
787
+ "cell_type": "code",
788
+ "execution_count": 20,
789
+ "metadata": {},
790
+ "outputs": [],
791
+ "source": [
792
+ "import numpy as np\n",
793
+ "\n",
794
+ "# Convert DataFrame to numpy array\n",
795
+ "X_train_array = X_train.values\n",
796
+ "\n",
797
+ "# Reshape X_train_array to add a time step dimension\n",
798
+ "X_train_reshaped = X_train_array.reshape(X_train_array.shape[0], 1, X_train_array.shape[1])\n",
799
+ "\n",
800
+ "# Assuming X_train_reshaped shape is now (374, 1, 5)\n",
801
+ "input_shape = X_train_reshaped.shape[1:]\n",
802
+ "\n",
803
+ "# Create the model\n",
804
+ "model = create_model(input_shape=input_shape)"
805
+ ]
806
+ },
807
+ {
808
+ "cell_type": "code",
809
+ "execution_count": 21,
810
+ "metadata": {},
811
+ "outputs": [
812
+ {
813
+ "name": "stdout",
814
+ "output_type": "stream",
815
+ "text": [
816
+ "\u001b[1m12/12\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 3ms/step - loss: 0.5165\n"
817
+ ]
818
+ },
819
+ {
820
+ "data": {
821
+ "text/plain": [
822
+ "<keras.src.callbacks.history.History at 0x194ec2a4550>"
823
+ ]
824
+ },
825
+ "execution_count": 21,
826
+ "metadata": {},
827
+ "output_type": "execute_result"
828
+ }
829
+ ],
830
+ "source": [
831
+ "model.fit(X_train_reshaped, y_train)"
832
+ ]
833
+ },
834
+ {
835
+ "cell_type": "code",
836
+ "execution_count": 22,
837
+ "metadata": {},
838
+ "outputs": [
839
+ {
840
+ "name": "stdout",
841
+ "output_type": "stream",
842
+ "text": [
843
+ "2024-05-06 23:23:37,549 WARNING: DeprecationWarning: np.find_common_type is deprecated. Please use `np.result_type` or `np.promote_types`.\n",
844
+ "See https://numpy.org/devdocs/release/1.25.0-notes.html and the docs for more information. (Deprecated NumPy 1.25)\n",
845
+ "\n"
846
+ ]
847
+ }
848
+ ],
849
+ "source": [
850
+ "# Assuming X_test['date'] column exists and is in datetime format\n",
851
+ "X_test['year'] = X_test['date'].dt.year\n",
852
+ "X_test['month'] = X_test['date'].dt.month\n",
853
+ "X_test['day'] = X_test['date'].dt.day\n",
854
+ "\n",
855
+ "# Drop the original date column\n",
856
+ "X_test.drop(columns=['date'], inplace=True)\n",
857
+ "\n",
858
+ "# Convert dataframe to numpy array\n",
859
+ "X_test_array = X_test.to_numpy()\n",
860
+ "\n",
861
+ "# Reshape the array to have a shape suitable for LSTM\n",
862
+ "# Assuming each row represents a sample and each column represents a feature\n",
863
+ "# Reshape to [samples, timesteps, features]\n",
864
+ "X_test_array = np.expand_dims(X_test_array, axis=1)\n"
865
+ ]
866
+ },
867
+ {
868
+ "cell_type": "code",
869
+ "execution_count": 30,
870
+ "metadata": {},
871
+ "outputs": [
872
+ {
873
+ "name": "stdout",
874
+ "output_type": "stream",
875
+ "text": [
876
+ "\u001b[1m3/3\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 237ms/step\n"
877
+ ]
878
+ }
879
+ ],
880
+ "source": [
881
+ "y_pred = model.predict(X_test_array)"
882
+ ]
883
+ },
884
+ {
885
+ "cell_type": "code",
886
+ "execution_count": 31,
887
+ "metadata": {},
888
+ "outputs": [
889
+ {
890
+ "name": "stdout",
891
+ "output_type": "stream",
892
+ "text": [
893
+ "Connected. Call `.close()` to terminate connection gracefully.\n"
894
+ ]
895
+ }
896
+ ],
897
+ "source": [
898
+ "mr = project.get_model_registry()"
899
+ ]
900
+ },
901
+ {
902
+ "cell_type": "code",
903
+ "execution_count": 37,
904
+ "metadata": {},
905
+ "outputs": [
906
+ {
907
+ "data": {
908
+ "text/plain": [
909
+ "['LSTM_model.keras']"
910
+ ]
911
+ },
912
+ "execution_count": 37,
913
+ "metadata": {},
914
+ "output_type": "execute_result"
915
+ }
916
+ ],
917
+ "source": [
918
+ "import joblib\n",
919
+ "joblib.dump(model, 'LSTM_model.keras')"
920
+ ]
921
+ },
922
+ {
923
+ "cell_type": "code",
924
+ "execution_count": 32,
925
+ "metadata": {},
926
+ "outputs": [
927
+ {
928
+ "data": {
929
+ "text/plain": [
930
+ "{'RMSE': 0.40675989895763576}"
931
+ ]
932
+ },
933
+ "execution_count": 32,
934
+ "metadata": {},
935
+ "output_type": "execute_result"
936
+ }
937
+ ],
938
+ "source": [
939
+ "from sklearn.metrics import mean_squared_error\n",
940
+ "import numpy as np\n",
941
+ "\n",
942
+ "# Compute RMSE\n",
943
+ "rmse = np.sqrt(mean_squared_error(y_test, y_pred))\n",
944
+ "rmse_metrics = {\"RMSE\": rmse}\n",
945
+ "rmse_metrics\n"
946
+ ]
947
+ },
948
+ {
949
+ "cell_type": "code",
950
+ "execution_count": 33,
951
+ "metadata": {},
952
+ "outputs": [],
953
+ "source": [
954
+ "input_schema = Schema(X_train)\n",
955
+ "output_schema = Schema(y_train)\n",
956
+ "model_schema = ModelSchema(input_schema, output_schema)"
957
+ ]
958
+ },
959
+ {
960
+ "cell_type": "code",
961
+ "execution_count": 38,
962
+ "metadata": {},
963
+ "outputs": [
964
+ {
965
+ "data": {
966
+ "application/vnd.jupyter.widget-view+json": {
967
+ "model_id": "1dd08fa9a7c144638a9f5c4600df04fa",
968
+ "version_major": 2,
969
+ "version_minor": 0
970
+ },
971
+ "text/plain": [
972
+ " 0%| | 0/6 [00:00<?, ?it/s]"
973
+ ]
974
+ },
975
+ "metadata": {},
976
+ "output_type": "display_data"
977
+ },
978
+ {
979
+ "data": {
980
+ "application/vnd.jupyter.widget-view+json": {
981
+ "model_id": "a02a72f26d7b433599f80f8b7d3ad72c",
982
+ "version_major": 2,
983
+ "version_minor": 0
984
+ },
985
+ "text/plain": [
986
+ "Uploading: 0.000%| | 0/291253 elapsed<00:00 remaining<?"
987
+ ]
988
+ },
989
+ "metadata": {},
990
+ "output_type": "display_data"
991
+ },
992
+ {
993
+ "data": {
994
+ "application/vnd.jupyter.widget-view+json": {
995
+ "model_id": "ac84dbdca58f4648b1eb54452812b563",
996
+ "version_major": 2,
997
+ "version_minor": 0
998
+ },
999
+ "text/plain": [
1000
+ "Uploading: 0.000%| | 0/561 elapsed<00:00 remaining<?"
1001
+ ]
1002
+ },
1003
+ "metadata": {},
1004
+ "output_type": "display_data"
1005
+ },
1006
+ {
1007
+ "name": "stdout",
1008
+ "output_type": "stream",
1009
+ "text": [
1010
+ "Model created, explore it at https://c.app.hopsworks.ai:443/p/693399/models/stock_pred_model/4\n"
1011
+ ]
1012
+ },
1013
+ {
1014
+ "data": {
1015
+ "text/plain": [
1016
+ "Model(name: 'stock_pred_model', version: 4)"
1017
+ ]
1018
+ },
1019
+ "execution_count": 38,
1020
+ "metadata": {},
1021
+ "output_type": "execute_result"
1022
+ }
1023
+ ],
1024
+ "source": [
1025
+ "stock_pred_model = mr.tensorflow.create_model(\n",
1026
+ " name=\"stock_pred_model\",\n",
1027
+ " metrics= rmse_metrics,\n",
1028
+ " model_schema=model_schema,\n",
1029
+ " description=\"Stock Market TSLA Predictor from News Sentiment\",\n",
1030
+ " )\n",
1031
+ "\n",
1032
+ "stock_pred_model.save('LSTM_model.keras')"
1033
+ ]
1034
+ }
1035
+ ],
1036
+ "metadata": {
1037
+ "kernelspec": {
1038
+ "display_name": "base",
1039
+ "language": "python",
1040
+ "name": "python3"
1041
+ },
1042
+ "language_info": {
1043
+ "codemirror_mode": {
1044
+ "name": "ipython",
1045
+ "version": 3
1046
+ },
1047
+ "file_extension": ".py",
1048
+ "mimetype": "text/x-python",
1049
+ "name": "python",
1050
+ "nbconvert_exporter": "python",
1051
+ "pygments_lexer": "ipython3",
1052
+ "version": "3.11.9"
1053
+ },
1054
+ "orig_nbformat": 4
1055
+ },
1056
+ "nbformat": 4,
1057
+ "nbformat_minor": 2
1058
+ }