aj2614@nyu.edu committed on
Commit
3d7c786
•
1 Parent(s): 54556be

who knows if this is a good idea

Files changed (30)
  1. language-models-project/Dockerfile → Dockerfile +0 -0
  2. language-models-project/Pipfile → Pipfile +0 -0
  3. language-models-project/Pipfile.lock → Pipfile.lock +0 -0
  4. language-models-project/README.md → README.md +0 -0
  5. language-models-project/Screenshot 2023-03-26 235302.png → Screenshot 2023-03-26 235302.png +0 -0
  6. assignment-2/README.md +0 -0
  7. assignment-2/assignment_2/GML.py +0 -305
  8. assignment-2/assignment_2/Gaussian Maximum Likelihood.ipynb +0 -1676
  9. assignment-2/data/01_raw/nyc_bb_bicyclist_counts.csv +0 -215
  10. assignment-2/data/nyc_bb_bicyclist_counts.csv +0 -215
  11. assignment-2/poetry.lock +0 -0
  12. assignment-2/pyproject.toml +0 -23
  13. language-models-project/docker-compose.yml → docker-compose.yml +0 -0
  14. language-models-project/language_models_project/__init__.py +0 -0
  15. language-models-project/tests/__init__.py +0 -0
  16. {assignment-2/assignment_2 → language_models_project}/__init__.py +0 -0
  17. language_models_project/app.py +47 -0
  18. {language-models-project/language_models_project → language_models_project}/main.py +0 -0
  19. midterm/Pipfile +0 -19
  20. midterm/Pipfile.lock +0 -0
  21. midterm/README.md +0 -0
  22. midterm/data/01_raw/CBC_data.csv +0 -0
  23. midterm/midterm/__init__.py +0 -0
  24. midterm/midterm/take_at_home_(1).ipynb +0 -0
  25. midterm/pyproject.toml +0 -22
  26. midterm/tests/__init__.py +0 -0
  27. language-models-project/poetry.lock → poetry.lock +0 -0
  28. language-models-project/pyproject.toml → pyproject.toml +0 -0
  29. {assignment-2/tests → tests}/__init__.py +0 -0
  30. {language-models-project/tests → tests}/test_classifier.py +0 -0
language-models-project/Dockerfile → Dockerfile RENAMED
File without changes
language-models-project/Pipfile → Pipfile RENAMED
File without changes
language-models-project/Pipfile.lock → Pipfile.lock RENAMED
File without changes
language-models-project/README.md → README.md RENAMED
File without changes
language-models-project/Screenshot 2023-03-26 235302.png → Screenshot 2023-03-26 235302.png RENAMED
File without changes
assignment-2/README.md DELETED
File without changes
assignment-2/assignment_2/GML.py DELETED
@@ -1,305 +0,0 @@
- ## imports
- import numpy as np
- import pandas as pd
- from scipy.optimize import minimize
- from scipy.stats import norm
- import math
-
-
- ## Problem 1
- data = [4, 5, 7, 8, 8, 9, 10, 5, 2, 3, 5, 4, 8, 9]
-
- data_mean = np.mean(data)
- data_variance = np.var(data)
-
-
- mu = 0.5
- sigma = 0.5
- w = np.array([mu, sigma])
-
- w_star = np.array([data_mean, data_variance])
- mu_star = data_mean
- sigma_star = np.sqrt(data_variance)
- offset = 10 * np.random.random(2)
-
- w1p = w_star + 0.5 * offset
- w1n = w_star - 0.5 * offset
- w2p = w_star + 0.25 * offset
- w2n = w_star - 0.25 * offset
-
- # Negative Log Likelihood is defined as follows:
- # $-\ln(\frac{1}{\sqrt{2\pi\sigma^2}}\exp(-\frac{1}{2}(\frac{x-\mu}{\sigma})^2))$.
- # Ignoring the contribution of the constant, we find that
- # $\frac{\delta}{\delta\mu} \mathcal{N} = \frac{\mu-x}{\sigma^2}$ and
- # $\frac{\delta}{\delta\sigma} \mathcal{N} = \frac{\sigma^2 - (\mu-x)^2}{\sigma^3}$.
- # We apply the negatives of these as our step directions in the SGD (hence `+=`).
-
- loss = lambda mu, sigma, x: np.sum(
-     [-np.log(norm.pdf(xi, loc=mu, scale=sigma)) for xi in x]
- )
-
- loss_2_alternative = lambda mu, sigma, x: -len(x) / 2 * np.log(
-     2 * np.pi * sigma**2
- ) - 1 / (2 * sigma**2) * np.sum((x - mu) ** 2)
-
-
- dmu = lambda mu, sigma, x: -np.sum([mu - xi for xi in x]) / (sigma**2)
- dsigma = lambda mu, sigma, x: -len(x) / sigma + np.sum([(mu - xi) ** 2 for xi in x]) / (
-     sigma**3
- )
-
- log = []
-
-
- def SGD_problem1(mu, sigma, x, learning_rate=0.01, n_epochs=1000):
-     global log
-     log = []
-     for epoch in range(n_epochs):
-         mu += learning_rate * dmu(mu, sigma, x)
-         sigma += learning_rate * dsigma(mu, sigma, x)
-
-         # print(f"Epoch {epoch}, Loss: {loss(mu, sigma, x)}, New mu: {mu}, New sigma: {sigma}")
-         log.append(
-             {
-                 "Epoch": epoch,
-                 "Loss": loss(mu, sigma, x),
-                 "Loss 2 Alternative": loss_2_alternative(mu, sigma, x),
-                 "New mu": mu,
-                 "New sigma": sigma,
-             }
-         )
-     return np.array([mu, sigma])
-
-
- def debug_SGD_1(wnn, data):
-     print("SGD Problem 1")
-     print("wnn", SGD_problem1(*wnn, data))
-     dflog = pd.DataFrame(log)
-     dflog["mu_star"] = mu_star
-     dflog["mu_std"] = sigma_star
-     print(f"mu diff at start {dflog.iloc[0]['New mu'] - dflog.iloc[0]['mu_star']}")
-     print(f"mu diff at end {dflog.iloc[-1]['New mu'] - dflog.iloc[-1]['mu_star']}")
-     if np.abs(dflog.iloc[-1]["New mu"] - dflog.iloc[-1]["mu_star"]) < np.abs(
-         dflog.iloc[0]["New mu"] - dflog.iloc[0]["mu_star"]
-     ):
-         print("mu is improving")
-     else:
-         print("mu is not improving")
-
-     print(f"sigma diff at start {dflog.iloc[0]['New sigma'] - dflog.iloc[0]['mu_std']}")
-     print(f"sigma diff at end {dflog.iloc[-1]['New sigma'] - dflog.iloc[-1]['mu_std']}")
-     if np.abs(dflog.iloc[-1]["New sigma"] - dflog.iloc[-1]["mu_std"]) < np.abs(
-         dflog.iloc[0]["New sigma"] - dflog.iloc[0]["mu_std"]
-     ):
-         print("sigma is improving")
-     else:
-         print("sigma is not improving")
-
-     return dflog
-
-
- # _ = debug_SGD_1(w1p, data)
- # _ = debug_SGD_1(w1n, data)
- # _ = debug_SGD_1(w2p, data)
- # _ = debug_SGD_1(w2n, data)
-
-
- # TODO EXPLAIN WHY += WORKS HERE.
-
-
- ## Problem 2
- x = np.array([8, 16, 22, 33, 50, 51])
- y = np.array([5, 20, 14, 32, 42, 58])
-
- # $-\frac{n}{\sigma}+\frac{1}{\sigma^3}\sum_{i=1}^n(y_i - (mx+c))^2$
- dsigma = lambda sigma, c, m, x: -len(x) / sigma + np.sum(
-     [(xi - (m * x + c)) ** 2 for xi in x]
- ) / (sigma**3)
- # $-\frac{1}{\sigma^2}\sum_{i=1}^n(y_i - (mx+c))$
- dc = lambda sigma, c, m, x: -np.sum([xi - (m * x + c) for xi in x]) / (sigma**2)
- # $-\frac{1}{\sigma^2}\sum_{i=1}^n(x_i(y_i - (mx+c)))$
- dm = lambda sigma, c, m, x: -np.sum([x * (xi - (m * x + c)) for xi in x]) / (sigma**2)
-
-
- log2 = []
-
-
- def SGD_problem2(
-     sigma: float,
-     c: float,
-     m: float,
-     x: np.array,
-     y: np.array,
-     learning_rate=0.01,
-     n_epochs=1000,
- ):
-     global log2
-     log2 = []
-     for epoch in range(n_epochs):
-         sigma += learning_rate * dsigma(sigma, c, m, x)
-         c += learning_rate * dc(sigma, c, m, x)
-         m += learning_rate * dm(sigma, c, m, x)
-
-         log2.append(
-             {
-                 "Epoch": epoch,
-                 "New sigma": sigma,
-                 "New c": c,
-                 "New m": m,
-                 "dc": dc(sigma, c, m, x),
-                 "dm": dm(sigma, c, m, x),
-                 "dsigma": dsigma(sigma, c, m, x),
-                 "Loss": loss((m * x + c), sigma, y),
-             }
-         )
-         print(f"Epoch {epoch}, Loss: {loss((m * x + c), sigma, y)}")
-     return np.array([sigma, c, m])
-
-
- # def debug_SGD_2(wnn, data):
- #     print("SGD Problem 2")
- #     print("wnn", SGD_problem2(*wnn, data))
- #     dflog = pd.DataFrame(log)
- #     dflog["m_star"] = m_star
- #     dflog["c_star"] = c_star
- #     dflog["sigma_star"] = sigma_star
- #     print(f"m diff at start {dflog.iloc[0]['New m'] - dflog.iloc[0]['m_star']}")
- #     print(f"m diff at end {dflog.iloc[-1]['New m'] - dflog.iloc[-1]['m_star']}")
- #     if np.abs(dflog.iloc[-1]["New m"] - dflog.iloc[-1]["m_star"]) < np.abs(
- #         dflog.iloc[0]["New m"] - dflog.iloc[0]["m_star"]
- #     ):
- #         print("m is improving")
- #     else:
- #         print("m is not improving")
- #     print(f"c diff at start {dflog.iloc[0]['New c'] - dflog.iloc[0]['c_star']}")
- #     print(f"c diff at end {dflog.iloc[-1]['New c'] - dflog.iloc[-1]['c_star']}")
- #     if np.abs(dflog.iloc[-1]["New c"] - dflog.iloc[-1]["c_star"]) < np.abs(
- #         dflog.iloc[0]["New c"] - dflog.iloc[0]["c_star"]
- #     ):
- #         print("c is improving")
- #     else:
- #         print("c is not improving")
- #     print(f"sigma diff at start {dflog.iloc[0]['New sigma'] - dflog.iloc[0]['sigma_star']}")
- #     print(f"sigma diff at end {dflog.iloc[-1]['New sigma'] - dflog.iloc[-1]['sigma_star']}")
- #     if np.abs(dflog.iloc[-1]["New sigma"] - dflog.iloc[-1]["sigma_star"]) < np.abs(
- #         dflog.iloc[0]["New sigma"] - dflog.iloc[0]["sigma_star"]
- #     ):
- #         print("sigma is improving")
- #     else:
- #         print("sigma is not improving")
- #     return dflog
-
- result = SGD_problem2(0.5, 0.5, 0.5, x, y)
- print(f"final parameters: m={result[2]}, c={result[1]}, sigma={result[0]}")
-
-
- ## pset2
- # Knowing that the Poisson pmf is $P(k) = \frac{\lambda^k e^{-\lambda}}{k!}$, we can find the negative log likelihood of the data as $-\log(\prod_{i=1}^n P(k_i)) = -\sum_{i=1}^n \log(\frac{\lambda^{k_i} e^{-\lambda}}{k_i!}) = \sum_{i=1}^n -\ln(\lambda) k_i + \ln(k_i!) + \lambda$. Simplified, this gives $n\lambda + \sum_{i=1}^n \ln(k_i!) - \sum_{i=1}^n k_i \ln(\lambda)$. Differentiating with respect to $\lambda$ gives $n - \sum_{i=1}^n \frac{k_i}{\lambda}$, which is our desired $\frac{\partial L}{\partial \lambda}$!
-
-
- import pandas as pd
-
- df = pd.read_csv("../data/01_raw/nyc_bb_bicyclist_counts.csv")
-
- dlambda = lambda l, k: len(k) - np.sum([ki / l for ki in k])
-
-
- def SGD_problem3(
-     l: float,
-     k: np.array,
-     learning_rate=0.01,
-     n_epochs=1000,
- ):
-     global log3
-     log3 = []
-     for epoch in range(n_epochs):
-         l -= learning_rate * dlambda(l, k)
-         # $n\lambda + \sum_{i=1}^n \ln(k_i!) - \sum_{i=1}^n k_i \ln(\lambda)$
-         # the rest of the loss function is commented out because it's a
-         # constant and was causing overflows. It is unnecessary, and a useless
-         # pain.
-         loss = len(k) * l - np.sum(
-             [ki * np.log(l) for ki in k]
-         )  # + np.sum([np.log(np.math.factorial(ki)) for ki in k])
-
-         log3.append(
-             {
-                 "Epoch": epoch,
-                 "New lambda": l,
-                 "dlambda": dlambda(l, k),
-                 "Loss": loss,
-             }
-         )
-         print(f"Epoch {epoch}", f"Loss: {loss}")
-     return np.array([l])
-
-
- l_star = df["BB_COUNT"].mean()
-
-
- def debug_SGD_3(data, l=1000):
-     print("SGD Problem 3")
-     print(f"l: {SGD_problem3(l, data)}")
-     dflog = pd.DataFrame(log3)
-     dflog["l_star"] = l_star
-     print(f"l diff at start {dflog.iloc[0]['New lambda'] - dflog.iloc[0]['l_star']}")
-     print(f"l diff at end {dflog.iloc[-1]['New lambda'] - dflog.iloc[-1]['l_star']}")
-     if np.abs(dflog.iloc[-1]["New lambda"] - dflog.iloc[-1]["l_star"]) < np.abs(
-         dflog.iloc[0]["New lambda"] - dflog.iloc[0]["l_star"]
-     ):
-         print("l is improving")
-     else:
-         print("l is not improving")
-     return dflog
-
-
- debug_SGD_3(data=df["BB_COUNT"].values, l=l_star + 1000)
- debug_SGD_3(data=df["BB_COUNT"].values, l=l_star - 1000)
-
-
- ## pset 4
-
- # dw = lambda w, x: len(x) * np.exp(np.dot(x, w)) * x - np.sum()
-
- primitive = lambda xi, wi: (x.shape[0] * np.exp(wi * xi) * xi) - (xi**2)
- p_dw = lambda w, x: np.array([primitive(xi, wi) for xi, wi in zip(x, w)])
-
-
- def SGD_problem4(
-     w: np.array,
-     x: np.array,
-     learning_rate=0.01,
-     n_epochs=1000,
- ):
-     global log4
-     log4 = []
-     for epoch in range(n_epochs):
-         w -= learning_rate * p_dw(w, x)
-         # custom
-         # loss = x.shape[0] * np.exp(np.dot(x, w))
-         loss_fn = lambda k, l: len(k) * l - np.sum(
-             [ki * np.log(l) for ki in k]
-         )  # + np.sum([np.log(np.math.factorial(ki)) for ki in k])
-         loss = loss_fn(x, np.exp(np.dot(x, w)))
-         log4.append(
-             {
-                 "Epoch": epoch,
-                 "New w": w,
-                 "dw": p_dw(w, x),
-                 "Loss": loss,
-             }
-         )
-         print(f"Epoch {epoch}", f"Loss: {loss}")
-     return w
-
-
- def debug_SGD_3(data, w=np.array([1, 1])):
-     print("SGD Problem 4")
-     print(f"w: {SGD_problem4(w, data)}")
-     dflog = pd.DataFrame(log4)
-     return dflog
-
-
- _ = debug_SGD_3(
-     data=df[["HIGH_T", "LOW_T", "PRECIP"]].to_numpy(),
-     w=np.array([1.0, 1.0, 1.0]),
- )
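
The deleted GML.py above fits the Gaussian by gradient steps on the negative log-likelihood; the minimizer has a closed form, the sample mean and (uncorrected) standard deviation, which is what mu_star and sigma_star record. Below is a minimal sketch, assuming only numpy and scipy; nll, d_nll_dmu and d_nll_dsigma are hypothetical helper names, not part of the repository. It cross-checks the gradient expressions used above against numerical derivatives.

import numpy as np
from scipy.stats import norm

data = np.array([4, 5, 7, 8, 8, 9, 10, 5, 2, 3, 5, 4, 8, 9])

# Negative log-likelihood of i.i.d. Gaussian data.
def nll(mu, sigma, x):
    return -np.sum(norm.logpdf(x, loc=mu, scale=sigma))

# Analytic gradients of the NLL. GML.py steps with the *negatives* of these,
# which is why its updates use `+=`.
def d_nll_dmu(mu, sigma, x):
    return np.sum(mu - x) / sigma**2

def d_nll_dsigma(mu, sigma, x):
    return len(x) / sigma - np.sum((x - mu) ** 2) / sigma**3

mu, sigma, eps = 5.0, 2.0, 1e-6
num_dmu = (nll(mu + eps, sigma, data) - nll(mu - eps, sigma, data)) / (2 * eps)
num_dsigma = (nll(mu, sigma + eps, data) - nll(mu, sigma - eps, data)) / (2 * eps)
print(np.isclose(num_dmu, d_nll_dmu(mu, sigma, data)))        # expect True
print(np.isclose(num_dsigma, d_nll_dsigma(mu, sigma, data)))  # expect True

# Closed-form MLE the SGD should converge to (mu_star, sigma_star above):
print(data.mean(), data.std())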
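
Similarly, the pset2/pset3 section of GML.py derives dL/dlambda = n - sum(k_i)/lambda for the Poisson likelihood and compares the SGD estimate against l_star, the sample mean of BB_COUNT. A small sketch (plain numpy, with synthetic counts standing in for the CSV column, which is not shown in this diff) confirming that this gradient vanishes at the sample mean:

import numpy as np

# Synthetic stand-in for the BB_COUNT column.
rng = np.random.default_rng(0)
k = rng.poisson(lam=2500, size=200)

# dL/dlambda for the Poisson NLL, as derived in GML.py: n - sum(k_i) / lambda.
dlambda = lambda l, k: len(k) - np.sum(k) / l

# The MLE is the sample mean, where the gradient is (numerically) zero.
l_star = k.mean()
print(np.isclose(dlambda(l_star, k), 0.0, atol=1e-6))  # expect True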
assignment-2/assignment_2/Gaussian Maximum Likelihood.ipynb DELETED
@@ -1,1676 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "fce8933f-4594-4bb5-bffe-86fcb9ddd684",
6
- "metadata": {},
7
- "source": [
8
- "# MLE of a Gaussian $p_{model}(x|w)$"
9
- ]
10
- },
11
- {
12
- "cell_type": "code",
13
- "execution_count": 4,
14
- "id": "f6cd23f0-e755-48af-be5e-aaee83dda1e7",
15
- "metadata": {
16
- "tags": []
17
- },
18
- "outputs": [],
19
- "source": [
20
- "import numpy as np\n",
21
- "\n",
22
- "data = [4, 5, 7, 8, 8, 9, 10, 5, 2, 3, 5, 4, 8, 9]\n",
23
- "\n",
24
- "\n",
25
- "## imports\n",
26
- "import numpy as np\n",
27
- "import pandas as pd\n",
28
- "from scipy.optimize import minimize\n",
29
- "from scipy.stats import norm\n",
30
- "import math\n",
31
- "\n",
32
- "\n",
33
- "## Problem 1\n",
34
- "data = [4, 5, 7, 8, 8, 9, 10, 5, 2, 3, 5, 4, 8, 9]\n",
35
- "\n",
36
- "data_mean = np.mean(data)\n",
37
- "data_variance = np.var(data)\n",
38
- "\n",
39
- "\n",
40
- "mu = 0.5\n",
41
- "sigma = 0.5\n",
42
- "w = np.array([mu, sigma])\n",
43
- "\n",
44
- "w_star = np.array([data_mean, data_variance])\n",
45
- "mu_star = data_mean\n",
46
- "sigma_star = np.sqrt(data_variance)\n",
47
- "offset = 10 * np.random.random(2)\n",
48
- "\n",
49
- "w1p = w_star + 0.5 * offset\n",
50
- "w1n = w_star - 0.5 * offset\n",
51
- "w2p = w_star + 0.25 * offset\n",
52
- "w2n = w_star - 0.25 * offset"
53
- ]
54
- },
55
- {
56
- "cell_type": "markdown",
57
- "id": "f3d8587b-3862-4e98-bbcc-99d57bb313c1",
58
- "metadata": {},
59
- "source": [
60
- "Negative Log Likelihood is defined as follows: $-\\ln(\\frac{1}{\\sqrt{2\\pi\\sigma^2}}\\exp(-\\frac{1}{2}\\frac{(x-\\mu)}{\\sigma}^2))$. Ignoring the contribution of the constant, we find that $\\frac{\\delta}{\\delta \\mu} \\mathcal{N} = \\frac{\\mu-x}{\\sigma^2}$ and $\\frac{\\delta}{\\delta \\sigma} \\mathcal{N} = \\frac{\\sigma^2 + (\\mu-x)^2 - \\sigma^2}{\\sigma^3}$. We apply these as our step functions for our SGD. "
61
- ]
62
- },
63
- {
64
- "cell_type": "code",
65
- "execution_count": 5,
66
- "id": "27bf27ad-031e-4b65-a44d-53c5c1a09d91",
67
- "metadata": {
68
- "tags": []
69
- },
70
- "outputs": [],
71
- "source": [
72
- "loss = lambda mu, sigma, x: np.sum(\n",
73
- " [-np.log(norm.pdf(xi, loc=mu, scale=sigma)) for xi in x]\n",
74
- ")\n",
75
- "\n",
76
- "loss_2_alternative = lambda mu, sigma, x: -len(x) / 2 * np.log(\n",
77
- " 2 * np.pi * sigma**2\n",
78
- ") - 1 / (2 * sigma**2) * np.sum((x - mu) ** 2)\n",
79
- "\n",
80
- "\n",
81
- "dmu = lambda mu, sigma, x: -np.sum([mu - xi for xi in x]) / (sigma**2)\n",
82
- "dsigma = lambda mu, sigma, x: -len(x) / sigma + np.sum([(mu - xi) ** 2 for xi in x]) / (sigma**3)\n",
83
- "\n",
84
- "log = []\n",
85
- "def SGD_problem1(mu, sigma, x, learning_rate=0.01, n_epochs=1000):\n",
86
- " global log\n",
87
- " log = []\n",
88
- " for epoch in range(n_epochs):\n",
89
- " mu += learning_rate * dmu(mu, sigma, x)\n",
90
- " sigma += learning_rate * dsigma(mu, sigma, x)\n",
91
- "\n",
92
- " # print(f\"Epoch {epoch}, Loss: {loss(mu, sigma, x)}, New mu: {mu}, New sigma: {sigma}\")\n",
93
- " log.append(\n",
94
- " {\n",
95
- " \"Epoch\": epoch,\n",
96
- " \"Loss\": loss(mu, sigma, x),\n",
97
- " \"Loss 2 Alternative\": loss_2_alternative(mu, sigma, x),\n",
98
- " \"New mu\": mu,\n",
99
- " \"New sigma\": sigma,\n",
100
- " }\n",
101
- " )\n",
102
- " return np.array([mu, sigma])\n",
103
- "\n",
104
- "\n",
105
- "def debug_SGD_1(wnn, data):\n",
106
- " print(\"SGD Problem 1\")\n",
107
- " print(\"wnn\", SGD_problem1(*wnn, data))\n",
108
- " dflog = pd.DataFrame(log)\n",
109
- " dflog[\"mu_star\"] = mu_star\n",
110
- " dflog[\"mu_std\"] = sigma_star\n",
111
- " print(f\"mu diff at start {dflog.iloc[0]['New mu'] - dflog.iloc[0]['mu_star']}\")\n",
112
- " print(f\"mu diff at end {dflog.iloc[-1]['New mu'] - dflog.iloc[-1]['mu_star']}\")\n",
113
- " if np.abs(dflog.iloc[-1][\"New mu\"] - dflog.iloc[-1][\"mu_star\"]) < np.abs(\n",
114
- " dflog.iloc[0][\"New mu\"] - dflog.iloc[0][\"mu_star\"]\n",
115
- " ):\n",
116
- " print(\"mu is improving\")\n",
117
- " else:\n",
118
- " print(\"mu is not improving\")\n",
119
- "\n",
120
- " print(f\"sigma diff at start {dflog.iloc[0]['New sigma'] - dflog.iloc[0]['mu_std']}\")\n",
121
- " print(f\"sigma diff at end {dflog.iloc[-1]['New sigma'] - dflog.iloc[-1]['mu_std']}\")\n",
122
- " if np.abs(dflog.iloc[-1][\"New sigma\"] - dflog.iloc[-1][\"mu_std\"]) < np.abs(\n",
123
- " dflog.iloc[0][\"New sigma\"] - dflog.iloc[0][\"mu_std\"]\n",
124
- " ):\n",
125
- " print(\"sigma is improving\")\n",
126
- " else:\n",
127
- " print(\"sigma is not improving\")\n",
128
- "\n",
129
- " return dflog"
130
- ]
131
- },
132
- {
133
- "cell_type": "code",
134
- "execution_count": 6,
135
- "id": "27dd3bc6-b96e-4f8b-9118-01ad344dfd6a",
136
- "metadata": {
137
- "tags": []
138
- },
139
- "outputs": [
140
- {
141
- "name": "stdout",
142
- "output_type": "stream",
143
- "text": [
144
- "SGD Problem 1\n",
145
- "wnn [6.2142858 2.42541812]\n",
146
- "mu diff at start 0.27610721776969527\n",
147
- "mu diff at end 8.978893806244059e-08\n",
148
- "mu is improving\n",
149
- "sigma diff at start 8.134860851821205\n",
150
- "sigma diff at end 1.7124079931818414e-12\n",
151
- "sigma is improving\n",
152
- "SGD Problem 1\n",
153
- "wnn [6.21428571 2.42541812]\n",
154
- "mu diff at start -0.24923650064862635\n",
155
- "mu diff at end -6.602718372050731e-12\n",
156
- "mu is improving\n",
157
- "sigma diff at start -0.859536014291925\n",
158
- "sigma diff at end -3.552713678800501e-15\n",
159
- "sigma is improving\n",
160
- "SGD Problem 1\n",
161
- "wnn [6.21428572 2.42541812]\n",
162
- "mu diff at start 0.13794086144778994\n",
163
- "mu diff at end 1.0008935902305893e-09\n",
164
- "mu is improving\n",
165
- "sigma diff at start 5.786783512688555\n",
166
- "sigma diff at end 4.440892098500626e-15\n",
167
- "sigma is improving\n",
168
- "SGD Problem 1\n",
169
- "wnn [6.21428571 2.42541812]\n",
170
- "mu diff at start -0.13668036978891251\n",
171
- "mu diff at end -8.528289185960602e-12\n",
172
- "mu is improving\n",
173
- "sigma diff at start 1.091241177336173\n",
174
- "sigma diff at end 4.440892098500626e-15\n",
175
- "sigma is improving\n"
176
- ]
177
- }
178
- ],
179
- "source": [
180
- "_ = debug_SGD_1(w1p, data)\n",
181
- "_ = debug_SGD_1(w1n, data)\n",
182
- "_ = debug_SGD_1(w2p, data)\n",
183
- "_ = debug_SGD_1(w2n, data)"
184
- ]
185
- },
186
- {
187
- "cell_type": "markdown",
188
- "id": "30096401-0bd5-4cf6-b093-a688476e16f1",
189
- "metadata": {
190
- "tags": []
191
- },
192
- "source": [
193
- "# MLE of Conditional Gaussian"
194
- ]
195
- },
196
- {
197
- "cell_type": "markdown",
198
- "id": "101a3c5e-1e02-41e6-9eab-aba65c39627a",
199
- "metadata": {},
200
- "source": [
201
- "dsigma = $-\\frac{n}{\\sigma}+\\frac{1}{\\sigma^3}\\sum_{i=1}^n(y_i - (mx+c))^2$ \n",
202
- "dc = $-\\frac{1}{\\sigma^2}\\sum_{i=1}^n(y_i - (mx+c))$ \n",
203
- "dm = $-\\frac{1}{\\sigma^2}\\sum_{i=1}^n(x_i(y_i - (mx+c)))$ "
204
- ]
205
- },
206
- {
207
- "cell_type": "code",
208
- "execution_count": 8,
209
- "id": "21969012-f81b-43d4-975d-13411e975f8f",
210
- "metadata": {
211
- "collapsed": true,
212
- "jupyter": {
213
- "outputs_hidden": true
214
- },
215
- "tags": []
216
- },
217
- "outputs": [
218
- {
219
- "name": "stdout",
220
- "output_type": "stream",
221
- "text": [
222
- "Epoch 0, Loss: 297.82677563555086\n",
 ... [epochs 1-860: per-epoch "Epoch N, Loss: ..." lines, with the loss decreasing slowly from 297.82677 to 297.82619] ...
- "Epoch 861, Loss: 297.8261869359208\n",
1084
- "Epoch 862, Loss: 297.826186285834\n",
1085
- "Epoch 863, Loss: 297.82618563583344\n",
1086
- "Epoch 864, Loss: 297.8261849859192\n",
1087
- "Epoch 865, Loss: 297.8261843360914\n",
1088
- "Epoch 866, Loss: 297.82618368634996\n",
1089
- "Epoch 867, Loss: 297.8261830366949\n",
1090
- "Epoch 868, Loss: 297.82618238712627\n",
1091
- "Epoch 869, Loss: 297.82618173764416\n",
1092
- "Epoch 870, Loss: 297.8261810882486\n",
1093
- "Epoch 871, Loss: 297.82618043893956\n",
1094
- "Epoch 872, Loss: 297.826179789717\n",
1095
- "Epoch 873, Loss: 297.8261791405811\n",
1096
- "Epoch 874, Loss: 297.82617849153183\n",
1097
- "Epoch 875, Loss: 297.82617784256917\n",
1098
- "Epoch 876, Loss: 297.8261771936933\n",
1099
- "Epoch 877, Loss: 297.8261765449042\n",
1100
- "Epoch 878, Loss: 297.8261758962016\n",
1101
- "Epoch 879, Loss: 297.826175247586\n",
1102
- "Epoch 880, Loss: 297.82617459905725\n",
1103
- "Epoch 881, Loss: 297.82617395061516\n",
1104
- "Epoch 882, Loss: 297.8261733022601\n",
1105
- "Epoch 883, Loss: 297.82617265399193\n",
1106
- "Epoch 884, Loss: 297.82617200581075\n",
1107
- "Epoch 885, Loss: 297.8261713577164\n",
1108
- "Epoch 886, Loss: 297.8261707097091\n",
1109
- "Epoch 887, Loss: 297.82617006178884\n",
1110
- "Epoch 888, Loss: 297.8261694139557\n",
1111
- "Epoch 889, Loss: 297.82616876620966\n",
1112
- "Epoch 890, Loss: 297.82616811855064\n",
1113
- "Epoch 891, Loss: 297.8261674709788\n",
1114
- "Epoch 892, Loss: 297.82616682349425\n",
1115
- "Epoch 893, Loss: 297.8261661760969\n",
1116
- "Epoch 894, Loss: 297.82616552878676\n",
1117
- "Epoch 895, Loss: 297.82616488156384\n",
1118
- "Epoch 896, Loss: 297.8261642344284\n",
1119
- "Epoch 897, Loss: 297.8261635873801\n",
1120
- "Epoch 898, Loss: 297.82616294041935\n",
1121
- "Epoch 899, Loss: 297.8261622935459\n",
1122
- "Epoch 900, Loss: 297.82616164676\n",
1123
- "Epoch 901, Loss: 297.8261610000614\n",
1124
- "Epoch 902, Loss: 297.8261603534504\n",
1125
- "Epoch 903, Loss: 297.82615970692694\n",
1126
- "Epoch 904, Loss: 297.826159060491\n",
1127
- "Epoch 905, Loss: 297.82615841414264\n",
1128
- "Epoch 906, Loss: 297.82615776788197\n",
1129
- "Epoch 907, Loss: 297.826157121709\n",
1130
- "Epoch 908, Loss: 297.82615647562363\n",
1131
- "Epoch 909, Loss: 297.8261558296259\n",
1132
- "Epoch 910, Loss: 297.8261551837161\n",
1133
- "Epoch 911, Loss: 297.826154537894\n",
1134
- "Epoch 912, Loss: 297.82615389215977\n",
1135
- "Epoch 913, Loss: 297.82615324651323\n",
1136
- "Epoch 914, Loss: 297.8261526009547\n",
1137
- "Epoch 915, Loss: 297.826151955484\n",
1138
- "Epoch 916, Loss: 297.8261513101013\n",
1139
- "Epoch 917, Loss: 297.8261506648065\n",
1140
- "Epoch 918, Loss: 297.8261500195997\n",
1141
- "Epoch 919, Loss: 297.826149374481\n",
1142
- "Epoch 920, Loss: 297.82614872945044\n",
1143
- "Epoch 921, Loss: 297.8261480845078\n",
1144
- "Epoch 922, Loss: 297.8261474396533\n",
1145
- "Epoch 923, Loss: 297.8261467948871\n",
1146
- "Epoch 924, Loss: 297.826146150209\n",
1147
- "Epoch 925, Loss: 297.82614550561914\n",
1148
- "Epoch 926, Loss: 297.8261448611175\n",
1149
- "Epoch 927, Loss: 297.82614421670417\n",
1150
- "Epoch 928, Loss: 297.8261435723791\n",
1151
- "Epoch 929, Loss: 297.8261429281424\n",
1152
- "Epoch 930, Loss: 297.8261422839941\n",
1153
- "Epoch 931, Loss: 297.82614163993424\n",
1154
- "Epoch 932, Loss: 297.8261409959628\n",
1155
- "Epoch 933, Loss: 297.82614035207973\n",
1156
- "Epoch 934, Loss: 297.82613970828527\n",
1157
- "Epoch 935, Loss: 297.8261390645793\n",
1158
- "Epoch 936, Loss: 297.826138420962\n",
1159
- "Epoch 937, Loss: 297.8261377774331\n",
1160
- "Epoch 938, Loss: 297.82613713399303\n",
1161
- "Epoch 939, Loss: 297.82613649064143\n",
1162
- "Epoch 940, Loss: 297.8261358473787\n",
1163
- "Epoch 941, Loss: 297.8261352042046\n",
1164
- "Epoch 942, Loss: 297.8261345611193\n",
1165
- "Epoch 943, Loss: 297.8261339181227\n",
1166
- "Epoch 944, Loss: 297.826133275215\n",
1167
- "Epoch 945, Loss: 297.82613263239614\n",
1168
- "Epoch 946, Loss: 297.8261319896662\n",
1169
- "Epoch 947, Loss: 297.82613134702507\n",
1170
- "Epoch 948, Loss: 297.826130704473\n",
1171
- "Epoch 949, Loss: 297.8261300620098\n",
1172
- "Epoch 950, Loss: 297.8261294196356\n",
1173
- "Epoch 951, Loss: 297.8261287773505\n",
1174
- "Epoch 952, Loss: 297.8261281351545\n",
1175
- "Epoch 953, Loss: 297.8261274930475\n",
1176
- "Epoch 954, Loss: 297.82612685102976\n",
1177
- "Epoch 955, Loss: 297.8261262091011\n",
1178
- "Epoch 956, Loss: 297.82612556726167\n",
1179
- "Epoch 957, Loss: 297.8261249255115\n",
1180
- "Epoch 958, Loss: 297.8261242838506\n",
1181
- "Epoch 959, Loss: 297.8261236422789\n",
1182
- "Epoch 960, Loss: 297.8261230007966\n",
1183
- "Epoch 961, Loss: 297.8261223594037\n",
1184
- "Epoch 962, Loss: 297.82612171810007\n",
1185
- "Epoch 963, Loss: 297.82612107688584\n",
1186
- "Epoch 964, Loss: 297.8261204357612\n",
1187
- "Epoch 965, Loss: 297.82611979472597\n",
1188
- "Epoch 966, Loss: 297.8261191537804\n",
1189
- "Epoch 967, Loss: 297.82611851292415\n",
1190
- "Epoch 968, Loss: 297.8261178721576\n",
1191
- "Epoch 969, Loss: 297.8261172314807\n",
1192
- "Epoch 970, Loss: 297.8261165908933\n",
1193
- "Epoch 971, Loss: 297.82611595039566\n",
1194
- "Epoch 972, Loss: 297.8261153099878\n",
1195
- "Epoch 973, Loss: 297.82611466966955\n",
1196
- "Epoch 974, Loss: 297.8261140294412\n",
1197
- "Epoch 975, Loss: 297.8261133893026\n",
1198
- "Epoch 976, Loss: 297.82611274925375\n",
1199
- "Epoch 977, Loss: 297.82611210929485\n",
1200
- "Epoch 978, Loss: 297.82611146942594\n",
1201
- "Epoch 979, Loss: 297.8261108296469\n",
1202
- "Epoch 980, Loss: 297.8261101899578\n",
1203
- "Epoch 981, Loss: 297.8261095503587\n",
1204
- "Epoch 982, Loss: 297.8261089108496\n",
1205
- "Epoch 983, Loss: 297.82610827143054\n",
1206
- "Epoch 984, Loss: 297.8261076321016\n",
1207
- "Epoch 985, Loss: 297.8261069928628\n",
1208
- "Epoch 986, Loss: 297.8261063537141\n",
1209
- "Epoch 987, Loss: 297.8261057146557\n",
1210
- "Epoch 988, Loss: 297.8261050756874\n",
1211
- "Epoch 989, Loss: 297.82610443680943\n",
1212
- "Epoch 990, Loss: 297.82610379802173\n",
1213
- "Epoch 991, Loss: 297.82610315932436\n",
1214
- "Epoch 992, Loss: 297.8261025207173\n",
1215
- "Epoch 993, Loss: 297.8261018822007\n",
1216
- "Epoch 994, Loss: 297.8261012437744\n",
1217
- "Epoch 995, Loss: 297.8261006054387\n",
1218
- "Epoch 996, Loss: 297.82609996719333\n",
1219
- "Epoch 997, Loss: 297.8260993290385\n",
1220
- "Epoch 998, Loss: 297.8260986909742\n",
1221
- "Epoch 999, Loss: 297.8260980530006\n",
1222
- "final parameters: m=0.45136980910052144, c=0.49775672565271384, sigma=1562.2616856027405\n"
1223
- ]
1224
- }
1225
- ],
1226
- "source": [
1227
- "## Problem 2\n",
1228
- "x = np.array([8, 16, 22, 33, 50, 51])\n",
1229
- "y = np.array([5, 20, 14, 32, 42, 58])\n",
1230
- "\n",
1231
- "# $-\\frac{n}{\\sigma}+\\frac{1}{\\sigma^3}\\sum_{i=1}^n(y_i - (mx+c))^2$\n",
1232
- "dsigma = lambda sigma, c, m, x: -len(x) / sigma + np.sum(\n",
1233
- " [(xi - (m * x + c)) ** 2 for xi in x]\n",
1234
- ") / (sigma**3)\n",
1235
- "# $-\\frac{1}{\\sigma^2}\\sum_{i=1}^n(y_i - (mx+c))$\n",
1236
- "dc = lambda sigma, c, m, x: -np.sum([xi - (m * x + c) for xi in x]) / (sigma**2)\n",
1237
- "# $-\\frac{1}{\\sigma^2}\\sum_{i=1}^n(x_i(y_i - (mx+c)))$\n",
1238
- "dm = lambda sigma, c, m, x: -np.sum([x * (xi - (m * x + c)) for xi in x]) / (sigma**2)\n",
1239
- "\n",
1240
- "\n",
1241
- "log2 = []\n",
1242
- "\n",
1243
- "\n",
1244
- "def SGD_problem2(\n",
1245
- " sigma: float,\n",
1246
- " c: float,\n",
1247
- " m: float,\n",
1248
- " x: np.array,\n",
1249
- " y: np.array,\n",
1250
- " learning_rate=0.01,\n",
1251
- " n_epochs=1000,\n",
1252
- "):\n",
1253
- " global log2\n",
1254
- " log2 = []\n",
1255
- " for epoch in range(n_epochs):\n",
1256
- " sigma += learning_rate * dsigma(sigma, c, m, x)\n",
1257
- " c += learning_rate * dc(sigma, c, m, x)\n",
1258
- " m += learning_rate * dm(sigma, c, m, x)\n",
1259
- "\n",
1260
- " log2.append(\n",
1261
- " {\n",
1262
- " \"Epoch\": epoch,\n",
1263
- " \"New sigma\": sigma,\n",
1264
- " \"New c\": c,\n",
1265
- " \"New m\": m,\n",
1266
- " \"dc\": dc(sigma, c, m, x),\n",
1267
- " \"dm\": dm(sigma, c, m, x),\n",
1268
- " \"dsigma\": dsigma(sigma, c, m, x),\n",
1269
- " \"Loss\": loss((m * x + c), sigma, y),\n",
1270
- " }\n",
1271
- " )\n",
1272
- " print(f\"Epoch {epoch}, Loss: {loss((m * x + c), sigma, y)}\")\n",
1273
- " return np.array([sigma, c, m])\n",
1274
- "\n",
1275
- "\n",
1276
- "result = SGD_problem2(0.5, 0.5, 0.5, x, y)\n",
1277
- "print(f\"final parameters: m={result[2]}, c={result[1]}, sigma={result[0]}\")"
1278
- ]
1279
- },
1280
- {
1281
- "cell_type": "markdown",
1282
- "id": "0562b012-f4ca-47de-bc76-e0eb2bf1e509",
1283
- "metadata": {},
1284
- "source": [
1285
- "loss appears to be decreasing. Uncollapse cell for output"
1286
- ]
1287
- },
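A quick way to sanity-check the SGD fit above (not part of the original notebook): for the Gaussian model $y_i \sim \mathcal{N}(m x_i + c, \sigma^2)$, the maximum-likelihood $(m, c)$ is just the ordinary least-squares fit and the MLE of $\sigma$ is the RMS residual. A minimal sketch, reusing the x and y arrays from the cell above (helper names like `m_hat` are illustrative):

```python
import numpy as np

x = np.array([8, 16, 22, 33, 50, 51])
y = np.array([5, 20, 14, 32, 42, 58])

# Closed-form MLE: least-squares slope/intercept; sigma is the RMS residual (divides by n).
m_hat, c_hat = np.polyfit(x, y, 1)
sigma_hat = np.sqrt(np.mean((y - (m_hat * x + c_hat)) ** 2))
print(f"closed form: m={m_hat:.4f}, c={c_hat:.4f}, sigma={sigma_hat:.4f}")
```

If the SGD were converging, its m, c, and sigma would approach these values; the very large final sigma printed above suggests the update step and the gradient lambdas (which index x where y is intended) deserve another look.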
1288
- {
1289
- "cell_type": "markdown",
1290
- "id": "bed9f3ce-c15c-4f30-8906-26f3e51acf30",
1291
- "metadata": {},
1292
- "source": [
1293
- "# Bike Rides and the Poisson Model"
1294
- ]
1295
- },
1296
- {
1297
- "cell_type": "markdown",
1298
- "id": "975e2ef5-f5d5-45a3-b635-8faef035906f",
1299
- "metadata": {},
1300
- "source": [
1301
- "Knowing that the poisson pdf is $P(k) = \\frac{\\lambda^k e^{-\\lambda}}{k!}$, we can find the negative log likelihood of the data as $-\\log(\\Pi_{i=1}^n P(k_i)) = -\\sum_{i=1}^n \\log(\\frac{\\lambda^k_i e^{-\\lambda}}{k_i!}) = \\sum_{i=1}^n -\\ln(\\lambda) k_i + \\ln(k_i!) + \\lambda$. Which simplified, gives $n\\lambda + \\sum_{i=1}^n \\ln(k_i!) - \\sum_{i=1}^n k_i \\ln(\\lambda)$. Differentiating with respect to $\\lambda$ gives $n - \\sum_{i=1}^n \\frac{k_i}{\\lambda}$. Which is our desired $\\frac{\\partial L}{\\partial \\lambda}$!"
1302
- ]
1303
- },
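Setting that derivative to zero gives a closed-form check on the SGD that follows: $n - \sum_i k_i/\lambda = 0$ implies $\lambda^{*} = \frac{1}{n}\sum_i k_i$, i.e. the sample mean. A minimal sketch (not part of the original notebook, assuming the same CSV path used in the cell below):

```python
import numpy as np
import pandas as pd

df = pd.read_csv("../data/01_raw/nyc_bb_bicyclist_counts.csv")
k = df["BB_COUNT"].to_numpy()

# Root of dL/dlambda = n - sum(k_i)/lambda: the MLE is the sample mean of the counts.
lambda_star = k.mean()
print(f"closed-form MLE lambda* = {lambda_star:.2f}")
```

The `l_star` computed in the next cell is exactly this quantity, so the SGD there can be judged by how close it gets to it.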
1304
- {
1305
- "cell_type": "code",
1306
- "execution_count": 10,
1307
- "id": "3877723c-179e-4759-bed5-9eb70110ded2",
1308
- "metadata": {
1309
- "tags": []
1310
- },
1311
- "outputs": [
1312
- {
1313
- "name": "stdout",
1314
- "output_type": "stream",
1315
- "text": [
1316
- "SGD Problem 3\n",
1317
- "l: [3215.17703224]\n",
1318
- "l diff at start 999.4184849065878\n",
1319
- "l diff at end 535.134976163929\n",
1320
- "l is improving\n",
1321
- "SGD Problem 3\n",
1322
- "l: [2326.70336987]\n",
1323
- "l diff at start -998.7262223631474\n",
1324
- "l diff at end -353.33868620734074\n",
1325
- "l is improving\n"
1326
- ]
1327
- },
1328
- {
1329
- "data": {
1330
- "text/html": [
1331
- "<div>\n",
1332
- "<style scoped>\n",
1333
- " .dataframe tbody tr th:only-of-type {\n",
1334
- " vertical-align: middle;\n",
1335
- " }\n",
1336
- "\n",
1337
- " .dataframe tbody tr th {\n",
1338
- " vertical-align: top;\n",
1339
- " }\n",
1340
- "\n",
1341
- " .dataframe thead th {\n",
1342
- " text-align: right;\n",
1343
- " }\n",
1344
- "</style>\n",
1345
- "<table border=\"1\" class=\"dataframe\">\n",
1346
- " <thead>\n",
1347
- " <tr style=\"text-align: right;\">\n",
1348
- " <th></th>\n",
1349
- " <th>Epoch</th>\n",
1350
- " <th>New lambda</th>\n",
1351
- " <th>dlambda</th>\n",
1352
- " <th>Loss</th>\n",
1353
- " <th>l_star</th>\n",
1354
- " </tr>\n",
1355
- " </thead>\n",
1356
- " <tbody>\n",
1357
- " <tr>\n",
1358
- " <th>0</th>\n",
1359
- " <td>0</td>\n",
1360
- " <td>1681.315834</td>\n",
1361
- " <td>-127.119133</td>\n",
1362
- " <td>-3.899989e+06</td>\n",
1363
- " <td>2680.042056</td>\n",
1364
- " </tr>\n",
1365
- " <tr>\n",
1366
- " <th>1</th>\n",
1367
- " <td>1</td>\n",
1368
- " <td>1682.587025</td>\n",
1369
- " <td>-126.861418</td>\n",
1370
- " <td>-3.900150e+06</td>\n",
1371
- " <td>2680.042056</td>\n",
1372
- " </tr>\n",
1373
- " <tr>\n",
1374
- " <th>2</th>\n",
1375
- " <td>2</td>\n",
1376
- " <td>1683.855639</td>\n",
1377
- " <td>-126.604614</td>\n",
1378
- " <td>-3.900311e+06</td>\n",
1379
- " <td>2680.042056</td>\n",
1380
- " </tr>\n",
1381
- " <tr>\n",
1382
- " <th>3</th>\n",
1383
- " <td>3</td>\n",
1384
- " <td>1685.121685</td>\n",
1385
- " <td>-126.348715</td>\n",
1386
- " <td>-3.900471e+06</td>\n",
1387
- " <td>2680.042056</td>\n",
1388
- " </tr>\n",
1389
- " <tr>\n",
1390
- " <th>4</th>\n",
1391
- " <td>4</td>\n",
1392
- " <td>1686.385173</td>\n",
1393
- " <td>-126.093716</td>\n",
1394
- " <td>-3.900631e+06</td>\n",
1395
- " <td>2680.042056</td>\n",
1396
- " </tr>\n",
1397
- " <tr>\n",
1398
- " <th>...</th>\n",
1399
- " <td>...</td>\n",
1400
- " <td>...</td>\n",
1401
- " <td>...</td>\n",
1402
- " <td>...</td>\n",
1403
- " <td>...</td>\n",
1404
- " </tr>\n",
1405
- " <tr>\n",
1406
- " <th>995</th>\n",
1407
- " <td>995</td>\n",
1408
- " <td>2325.399976</td>\n",
1409
- " <td>-32.636710</td>\n",
1410
- " <td>-3.948159e+06</td>\n",
1411
- " <td>2680.042056</td>\n",
1412
- " </tr>\n",
1413
- " <tr>\n",
1414
- " <th>996</th>\n",
1415
- " <td>996</td>\n",
1416
- " <td>2325.726343</td>\n",
1417
- " <td>-32.602100</td>\n",
1418
- " <td>-3.948170e+06</td>\n",
1419
- " <td>2680.042056</td>\n",
1420
- " </tr>\n",
1421
- " <tr>\n",
1422
- " <th>997</th>\n",
1423
- " <td>997</td>\n",
1424
- " <td>2326.052364</td>\n",
1425
- " <td>-32.567536</td>\n",
1426
- " <td>-3.948180e+06</td>\n",
1427
- " <td>2680.042056</td>\n",
1428
- " </tr>\n",
1429
- " <tr>\n",
1430
- " <th>998</th>\n",
1431
- " <td>998</td>\n",
1432
- " <td>2326.378040</td>\n",
1433
- " <td>-32.533018</td>\n",
1434
- " <td>-3.948191e+06</td>\n",
1435
- " <td>2680.042056</td>\n",
1436
- " </tr>\n",
1437
- " <tr>\n",
1438
- " <th>999</th>\n",
1439
- " <td>999</td>\n",
1440
- " <td>2326.703370</td>\n",
1441
- " <td>-32.498547</td>\n",
1442
- " <td>-3.948201e+06</td>\n",
1443
- " <td>2680.042056</td>\n",
1444
- " </tr>\n",
1445
- " </tbody>\n",
1446
- "</table>\n",
1447
- "<p>1000 rows Γ— 5 columns</p>\n",
1448
- "</div>"
1449
- ],
1450
- "text/plain": [
1451
- " Epoch New lambda dlambda Loss l_star\n",
1452
- "0 0 1681.315834 -127.119133 -3.899989e+06 2680.042056\n",
1453
- "1 1 1682.587025 -126.861418 -3.900150e+06 2680.042056\n",
1454
- "2 2 1683.855639 -126.604614 -3.900311e+06 2680.042056\n",
1455
- "3 3 1685.121685 -126.348715 -3.900471e+06 2680.042056\n",
1456
- "4 4 1686.385173 -126.093716 -3.900631e+06 2680.042056\n",
1457
- ".. ... ... ... ... ...\n",
1458
- "995 995 2325.399976 -32.636710 -3.948159e+06 2680.042056\n",
1459
- "996 996 2325.726343 -32.602100 -3.948170e+06 2680.042056\n",
1460
- "997 997 2326.052364 -32.567536 -3.948180e+06 2680.042056\n",
1461
- "998 998 2326.378040 -32.533018 -3.948191e+06 2680.042056\n",
1462
- "999 999 2326.703370 -32.498547 -3.948201e+06 2680.042056\n",
1463
- "\n",
1464
- "[1000 rows x 5 columns]"
1465
- ]
1466
- },
1467
- "execution_count": 10,
1468
- "metadata": {},
1469
- "output_type": "execute_result"
1470
- }
1471
- ],
1472
- "source": [
1473
- "import pandas as pd\n",
1474
- "\n",
1475
- "df = pd.read_csv(\"../data/01_raw/nyc_bb_bicyclist_counts.csv\")\n",
1476
- "\n",
1477
- "dlambda = lambda l, k: len(k) - np.sum([ki / l for ki in k])\n",
1478
- "\n",
1479
- "\n",
1480
- "def SGD_problem3(\n",
1481
- " l: float,\n",
1482
- " k: np.array,\n",
1483
- " learning_rate=0.01,\n",
1484
- " n_epochs=1000,\n",
1485
- "):\n",
1486
- " global log3\n",
1487
- " log3 = []\n",
1488
- " for epoch in range(n_epochs):\n",
1489
- " l -= learning_rate * dlambda(l, k)\n",
1490
- " # $n\\lambda + \\sum_{i=1}^n \\ln(k_i!) - \\sum_{i=1}^n k_i \\ln(\\lambda)$\n",
1491
- " # the rest of the loss function is commented out because it's a\n",
1492
- " # constant and was causing overflows. It is unnecessary, and a useless\n",
1493
- " # pain.\n",
1494
- " loss = len(k) * l - np.sum(\n",
1495
- " [ki * np.log(l) for ki in k]\n",
1496
- " ) # + np.sum([np.log(np.math.factorial(ki)) for ki in k])\n",
1497
- "\n",
1498
- " log3.append(\n",
1499
- " {\n",
1500
- " \"Epoch\": epoch,\n",
1501
- " \"New lambda\": l,\n",
1502
- " \"dlambda\": dlambda(l, k),\n",
1503
- " \"Loss\": loss,\n",
1504
- " }\n",
1505
- " )\n",
1506
- " # print(f\"Epoch {epoch}\", f\"Loss: {loss}\")\n",
1507
- " return np.array([l])\n",
1508
- "\n",
1509
- "\n",
1510
- "l_star = df[\"BB_COUNT\"].mean()\n",
1511
- "\n",
1512
- "\n",
1513
- "def debug_SGD_3(data, l=1000):\n",
1514
- " print(\"SGD Problem 3\")\n",
1515
- " print(f\"l: {SGD_problem3(l, data)}\")\n",
1516
- " dflog = pd.DataFrame(log3)\n",
1517
- " dflog[\"l_star\"] = l_star\n",
1518
- " print(f\"l diff at start {dflog.iloc[0]['New lambda'] - dflog.iloc[0]['l_star']}\")\n",
1519
- " print(f\"l diff at end {dflog.iloc[-1]['New lambda'] - dflog.iloc[-1]['l_star']}\")\n",
1520
- " if np.abs(dflog.iloc[-1][\"New lambda\"] - dflog.iloc[-1][\"l_star\"]) < np.abs(\n",
1521
- " dflog.iloc[0][\"New lambda\"] - dflog.iloc[0][\"l_star\"]\n",
1522
- " ):\n",
1523
- " print(\"l is improving\")\n",
1524
- " else:\n",
1525
- " print(\"l is not improving\")\n",
1526
- " return dflog\n",
1527
- "\n",
1528
- "\n",
1529
- "debug_SGD_3(data=df[\"BB_COUNT\"].values, l=l_star + 1000)\n",
1530
- "debug_SGD_3(data=df[\"BB_COUNT\"].values, l=l_star - 1000)"
1531
- ]
1532
- },
1533
- {
1534
- "cell_type": "markdown",
1535
- "id": "c05192f9-78ae-4bdb-9df5-cac91006d79f",
1536
- "metadata": {},
1537
- "source": [
1538
- "l approaches the l_star and decreases the loss function."
1539
- ]
1540
- },
1541
- {
1542
- "cell_type": "markdown",
1543
- "id": "4955b868-7f67-4760-bf86-39f6edd55871",
1544
- "metadata": {},
1545
- "source": [
1546
- "## Maximum Likelihood II"
1547
- ]
1548
- },
1549
- {
1550
- "cell_type": "markdown",
1551
- "id": "cd7e6e62-3f64-43e5-bf2c-3cb514411446",
1552
- "metadata": {},
1553
- "source": [
1554
- "The partial of the poisson was found to be $nE^{w.x}*x - \\sum_{i=1}^{n}k x.x$"
1555
- ]
1556
- },
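Written out for the design matrix actually used below (rows $x_i$ of HIGH_T, LOW_T, PRECIP and counts $k_i$), that gradient is a vector with one entry per weight. The `dw` lambda in the next cell appears to collapse everything to a scalar with `np.sum`, and the loss there reuses the scalar-lambda Poisson loss on the whole feature matrix, which is what raises the broadcast error in its output. A hedged sketch of the intended quantities, with illustrative helper names:

```python
import numpy as np

def poisson_nll_grad(w: np.ndarray, X: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Gradient of the Poisson NLL for lambda_i = exp(w . x_i): sum_i (lambda_i - k_i) x_i."""
    lam = np.exp(X @ w)          # per-row rate lambda_i, shape (n,)
    return X.T @ (lam - k)       # shape (d,), one entry per weight

def poisson_nll(w: np.ndarray, X: np.ndarray, k: np.ndarray) -> float:
    """NLL up to the constant sum(log k_i!)."""
    eta = X @ w
    return float(np.sum(np.exp(eta) - k * eta))
```

With raw temperatures around 80, `exp(X @ w)` overflows for weights near 1, so in practice the features would need rescaling (or a much smaller starting `w`) before running SGD with this gradient.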
1557
- {
1558
- "cell_type": "code",
1559
- "execution_count": 18,
1560
- "id": "7c8b167d-c397-4155-93f3-d826c279fbb2",
1561
- "metadata": {
1562
- "tags": []
1563
- },
1564
- "outputs": [
1565
- {
1566
- "name": "stdout",
1567
- "output_type": "stream",
1568
- "text": [
1569
- "SGD Problem 4\n"
1570
- ]
1571
- },
1572
- {
1573
- "name": "stderr",
1574
- "output_type": "stream",
1575
- "text": [
1576
- "/tmp/ipykernel_615396/2481416868.py:22: RuntimeWarning: divide by zero encountered in log\n",
1577
- " [ki * np.log(l) for ki in k]\n"
1578
- ]
1579
- },
1580
- {
1581
- "ename": "ValueError",
1582
- "evalue": "operands could not be broadcast together with shapes (3,) (214,) ",
1583
- "output_type": "error",
1584
- "traceback": [
1585
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1586
- "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
1587
- "Cell \u001b[0;32mIn[18], line 44\u001b[0m\n\u001b[1;32m 40\u001b[0m dflog \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mDataFrame(log4)\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m dflog\n\u001b[0;32m---> 44\u001b[0m _ \u001b[38;5;241m=\u001b[39m \u001b[43mdebug_SGD_3\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 45\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdf\u001b[49m\u001b[43m[\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mHIGH_T\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mLOW_T\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mPRECIP\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_numpy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 46\u001b[0m \u001b[43m \u001b[49m\u001b[43mw\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1.0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1.0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1.0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 47\u001b[0m \u001b[43m)\u001b[49m\n",
1588
- "Cell \u001b[0;32mIn[18], line 39\u001b[0m, in \u001b[0;36mdebug_SGD_3\u001b[0;34m(data, w)\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdebug_SGD_3\u001b[39m(data, w\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39marray([\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m])):\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSGD Problem 4\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 39\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mw: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[43mSGD_problem4\u001b[49m\u001b[43m(\u001b[49m\u001b[43mw\u001b[49m\u001b[43m,\u001b[49m\u001b[38;5;250;43m \u001b[39;49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 40\u001b[0m dflog \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mDataFrame(log4)\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m dflog\n",
1589
- "Cell \u001b[0;32mIn[18], line 24\u001b[0m, in \u001b[0;36mSGD_problem4\u001b[0;34m(w, x, learning_rate, n_epochs)\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# custom\u001b[39;00m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# loss = x.shape[0] * np.exp(np.dot(x, w))\u001b[39;00m\n\u001b[1;32m 21\u001b[0m loss_fn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m k, l: \u001b[38;5;28mlen\u001b[39m(k) \u001b[38;5;241m*\u001b[39m l \u001b[38;5;241m-\u001b[39m np\u001b[38;5;241m.\u001b[39msum(\n\u001b[1;32m 22\u001b[0m [ki \u001b[38;5;241m*\u001b[39m np\u001b[38;5;241m.\u001b[39mlog(l) \u001b[38;5;28;01mfor\u001b[39;00m ki \u001b[38;5;129;01min\u001b[39;00m k]\n\u001b[1;32m 23\u001b[0m ) \u001b[38;5;66;03m# + np.sum([np.log(np.math.factorial(ki)) for ki in k])\u001b[39;00m\n\u001b[0;32m---> 24\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mloss_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdot\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mw\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 25\u001b[0m log4\u001b[38;5;241m.\u001b[39mappend(\n\u001b[1;32m 26\u001b[0m {\n\u001b[1;32m 27\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch\u001b[39m\u001b[38;5;124m\"\u001b[39m: epoch,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 31\u001b[0m }\n\u001b[1;32m 32\u001b[0m )\n\u001b[1;32m 33\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLoss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mloss\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
1590
- "Cell \u001b[0;32mIn[18], line 22\u001b[0m, in \u001b[0;36mSGD_problem4.<locals>.<lambda>\u001b[0;34m(k, l)\u001b[0m\n\u001b[1;32m 18\u001b[0m w \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m learning_rate \u001b[38;5;241m*\u001b[39m dw(w, x)\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# custom\u001b[39;00m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# loss = x.shape[0] * np.exp(np.dot(x, w))\u001b[39;00m\n\u001b[1;32m 21\u001b[0m loss_fn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m k, l: \u001b[38;5;28mlen\u001b[39m(k) \u001b[38;5;241m*\u001b[39m l \u001b[38;5;241m-\u001b[39m np\u001b[38;5;241m.\u001b[39msum(\n\u001b[0;32m---> 22\u001b[0m [ki \u001b[38;5;241m*\u001b[39m np\u001b[38;5;241m.\u001b[39mlog(l) \u001b[38;5;28;01mfor\u001b[39;00m ki \u001b[38;5;129;01min\u001b[39;00m k]\n\u001b[1;32m 23\u001b[0m ) \u001b[38;5;66;03m# + np.sum([np.log(np.math.factorial(ki)) for ki in k])\u001b[39;00m\n\u001b[1;32m 24\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss_fn(x, np\u001b[38;5;241m.\u001b[39mexp(np\u001b[38;5;241m.\u001b[39mdot(x, w)))\n\u001b[1;32m 25\u001b[0m log4\u001b[38;5;241m.\u001b[39mappend(\n\u001b[1;32m 26\u001b[0m {\n\u001b[1;32m 27\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch\u001b[39m\u001b[38;5;124m\"\u001b[39m: epoch,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 31\u001b[0m }\n\u001b[1;32m 32\u001b[0m )\n",
1591
- "Cell \u001b[0;32mIn[18], line 22\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 18\u001b[0m w \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m learning_rate \u001b[38;5;241m*\u001b[39m dw(w, x)\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# custom\u001b[39;00m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# loss = x.shape[0] * np.exp(np.dot(x, w))\u001b[39;00m\n\u001b[1;32m 21\u001b[0m loss_fn \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m k, l: \u001b[38;5;28mlen\u001b[39m(k) \u001b[38;5;241m*\u001b[39m l \u001b[38;5;241m-\u001b[39m np\u001b[38;5;241m.\u001b[39msum(\n\u001b[0;32m---> 22\u001b[0m [\u001b[43mki\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog\u001b[49m\u001b[43m(\u001b[49m\u001b[43ml\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m ki \u001b[38;5;129;01min\u001b[39;00m k]\n\u001b[1;32m 23\u001b[0m ) \u001b[38;5;66;03m# + np.sum([np.log(np.math.factorial(ki)) for ki in k])\u001b[39;00m\n\u001b[1;32m 24\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss_fn(x, np\u001b[38;5;241m.\u001b[39mexp(np\u001b[38;5;241m.\u001b[39mdot(x, w)))\n\u001b[1;32m 25\u001b[0m log4\u001b[38;5;241m.\u001b[39mappend(\n\u001b[1;32m 26\u001b[0m {\n\u001b[1;32m 27\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch\u001b[39m\u001b[38;5;124m\"\u001b[39m: epoch,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 31\u001b[0m }\n\u001b[1;32m 32\u001b[0m )\n",
1592
- "\u001b[0;31mValueError\u001b[0m: operands could not be broadcast together with shapes (3,) (214,) "
1593
- ]
1594
- }
1595
- ],
1596
- "source": [
1597
- "## pset 4\n",
1598
- "\n",
1599
- "dw = lambda w, x: np.sum([len(x) * np.exp(np.dot(xi, w)) * x - np.sum(np.dot(x.T,x)) for xi in x])\n",
1600
- "\n",
1601
- "#primitive = lambda xi, wi: (x.shape[0] * np.exp(wi * xi) * xi) - (xi**2)\n",
1602
- "#p_dw = lambda w, xi: np.array([primitive(xi, wi) for xi, wi in ])\n",
1603
- "\n",
1604
- "\n",
1605
- "def SGD_problem4(\n",
1606
- " w: np.array,\n",
1607
- " x: np.array,\n",
1608
- " learning_rate=0.01,\n",
1609
- " n_epochs=1000,\n",
1610
- "):\n",
1611
- " global log4\n",
1612
- " log4 = []\n",
1613
- " for epoch in range(n_epochs):\n",
1614
- " w -= learning_rate * dw(w, x)\n",
1615
- " # custom\n",
1616
- " # loss = x.shape[0] * np.exp(np.dot(x, w))\n",
1617
- " loss_fn = lambda k, l: len(k) * l - np.sum(\n",
1618
- " [ki * np.log(l) for ki in k]\n",
1619
- " ) # + np.sum([np.log(np.math.factorial(ki)) for ki in k])\n",
1620
- " loss = loss_fn(x, np.exp(np.dot(x, w)))\n",
1621
- " log4.append(\n",
1622
- " {\n",
1623
- " \"Epoch\": epoch,\n",
1624
- " \"New w\": w,\n",
1625
- " \"dw\": dw(w, x),\n",
1626
- " \"Loss\": loss,\n",
1627
- " }\n",
1628
- " )\n",
1629
- " print(f\"Epoch {epoch}\", f\"Loss: {loss}\")\n",
1630
- " return w\n",
1631
- "\n",
1632
- "\n",
1633
- "def debug_SGD_3(data, w=np.array([1, 1])):\n",
1634
- " print(\"SGD Problem 4\")\n",
1635
- " print(f\"w: {SGD_problem4(w, data)}\")\n",
1636
- " dflog = pd.DataFrame(log4)\n",
1637
- " return dflog\n",
1638
- "\n",
1639
- "\n",
1640
- "_ = debug_SGD_3(\n",
1641
- " data=df[[\"HIGH_T\", \"LOW_T\", \"PRECIP\"]].to_numpy(),\n",
1642
- " w=np.array([1.0, 1.0, 1.0]),\n",
1643
- ")"
1644
- ]
1645
- },
1646
- {
1647
- "cell_type": "code",
1648
- "execution_count": null,
1649
- "id": "7c00197d-873d-41b0-a458-dc8478b40f52",
1650
- "metadata": {},
1651
- "outputs": [],
1652
- "source": []
1653
- }
1654
- ],
1655
- "metadata": {
1656
- "kernelspec": {
1657
- "display_name": "Python 3 (ipykernel)",
1658
- "language": "python",
1659
- "name": "python3"
1660
- },
1661
- "language_info": {
1662
- "codemirror_mode": {
1663
- "name": "ipython",
1664
- "version": 3
1665
- },
1666
- "file_extension": ".py",
1667
- "mimetype": "text/x-python",
1668
- "name": "python",
1669
- "nbconvert_exporter": "python",
1670
- "pygments_lexer": "ipython3",
1671
- "version": "3.10.4"
1672
- }
1673
- },
1674
- "nbformat": 4,
1675
- "nbformat_minor": 5
1676
- }
 
assignment-2/data/01_raw/nyc_bb_bicyclist_counts.csv DELETED
@@ -1,215 +0,0 @@
1
- Date,HIGH_T,LOW_T,PRECIP,BB_COUNT
2
- 1-Apr-17,46.00,37.00,0.00,606
3
- 2-Apr-17,62.10,41.00,0.00,2021
4
- 3-Apr-17,63.00,50.00,0.03,2470
5
- 4-Apr-17,51.10,46.00,1.18,723
6
- 5-Apr-17,63.00,46.00,0.00,2807
7
- 6-Apr-17,48.90,41.00,0.73,461
8
- 7-Apr-17,48.00,43.00,0.01,1222
9
- 8-Apr-17,55.90,39.90,0.00,1674
10
- 9-Apr-17,66.00,45.00,0.00,2375
11
- 10-Apr-17,73.90,55.00,0.00,3324
12
- 11-Apr-17,80.10,62.10,0.00,3887
13
- 12-Apr-17,73.90,57.90,0.02,2565
14
- 13-Apr-17,64.00,48.90,0.00,3353
15
- 14-Apr-17,64.90,48.90,0.00,2942
16
- 15-Apr-17,64.90,52.00,0.00,2253
17
- 16-Apr-17,84.90,62.10,0.01,2877
18
- 17-Apr-17,73.90,64.00,0.01,3152
19
- 18-Apr-17,66.00,50.00,0.00,3415
20
- 19-Apr-17,52.00,45.00,0.01,1965
21
- 20-Apr-17,64.90,50.00,0.17,1567
22
- 21-Apr-17,53.10,48.00,0.29,1426
23
- 22-Apr-17,55.90,52.00,0.11,1318
24
- 23-Apr-17,64.90,46.90,0.00,2520
25
- 24-Apr-17,60.10,50.00,0.01,2544
26
- 25-Apr-17,54.00,50.00,0.91,611
27
- 26-Apr-17,59.00,54.00,0.34,1247
28
- 27-Apr-17,68.00,59.00,0.00,2959
29
- 28-Apr-17,82.90,57.90,0.00,3679
30
- 29-Apr-17,84.00,64.00,0.06,3315
31
- 30-Apr-17,64.00,54.00,0.00,2225
32
- 1-May-17,72.00,50.00,0.00,3084
33
- 2-May-17,73.90,66.90,0.00,3423
34
- 3-May-17,64.90,57.90,0.00,3342
35
- 4-May-17,63.00,50.00,0.00,3019
36
- 5-May-17,59.00,52.00,3.02,513
37
- 6-May-17,64.90,57.00,0.18,1892
38
- 7-May-17,54.00,48.90,0.01,3539
39
- 8-May-17,57.00,45.00,0.00,2886
40
- 9-May-17,61.00,48.00,0.00,2718
41
- 10-May-17,70.00,51.10,0.00,2810
42
- 11-May-17,61.00,51.80,0.00,2657
43
- 12-May-17,62.10,51.10,0.00,2640
44
- 13-May-17,51.10,45.00,1.31,151
45
- 14-May-17,64.90,46.00,0.02,1452
46
- 15-May-17,66.90,55.90,0.00,2685
47
- 16-May-17,78.10,57.90,0.00,3666
48
- 17-May-17,90.00,66.00,0.00,3535
49
- 18-May-17,91.90,75.00,0.00,3190
50
- 19-May-17,90.00,75.90,0.00,2952
51
- 20-May-17,64.00,55.90,0.01,2161
52
- 21-May-17,66.90,55.00,0.00,2612
53
- 22-May-17,61.00,54.00,0.59,768
54
- 23-May-17,68.00,57.90,0.00,3174
55
- 24-May-17,66.90,57.00,0.04,2969
56
- 25-May-17,57.90,55.90,0.58,488
57
- 26-May-17,73.00,55.90,0.10,2590
58
- 27-May-17,71.10,61.00,0.00,2609
59
- 28-May-17,71.10,59.00,0.00,2640
60
- 29-May-17,57.90,55.90,0.13,836
61
- 30-May-17,59.00,55.90,0.06,2301
62
- 31-May-17,75.00,57.90,0.03,2689
63
- 1-Jun-17,78.10,62.10,0.00,3468
64
- 2-Jun-17,73.90,60.10,0.01,3271
65
- 3-Jun-17,72.00,55.00,0.01,2589
66
- 4-Jun-17,68.00,60.10,0.09,1805
67
- 5-Jun-17,66.90,60.10,0.02,2171
68
- 6-Jun-17,55.90,53.10,0.06,1193
69
- 7-Jun-17,66.90,54.00,0.00,3211
70
- 8-Jun-17,68.00,59.00,0.00,3253
71
- 9-Jun-17,80.10,59.00,0.00,3401
72
- 10-Jun-17,84.00,68.00,0.00,3066
73
- 11-Jun-17,90.00,73.00,0.00,2465
74
- 12-Jun-17,91.90,77.00,0.00,2854
75
- 13-Jun-17,93.90,78.10,0.01,2882
76
- 14-Jun-17,84.00,69.10,0.29,2596
77
- 15-Jun-17,75.00,66.00,0.00,3510
78
- 16-Jun-17,68.00,66.00,0.00,2054
79
- 17-Jun-17,73.00,66.90,1.39,1399
80
- 18-Jun-17,84.00,72.00,0.01,2199
81
- 19-Jun-17,87.10,70.00,1.35,1648
82
- 20-Jun-17,82.00,72.00,0.03,3407
83
- 21-Jun-17,82.00,72.00,0.00,3304
84
- 22-Jun-17,82.00,70.00,0.00,3368
85
- 23-Jun-17,82.90,75.90,0.04,2283
86
- 24-Jun-17,82.90,71.10,1.29,2307
87
- 25-Jun-17,82.00,69.10,0.00,2625
88
- 26-Jun-17,78.10,66.00,0.00,3386
89
- 27-Jun-17,75.90,61.00,0.18,3182
90
- 28-Jun-17,78.10,62.10,0.00,3766
91
- 29-Jun-17,81.00,68.00,0.00,3356
92
- 30-Jun-17,88.00,73.90,0.01,2687
93
- 1-Jul-17,84.90,72.00,0.23,1848
94
- 2-Jul-17,87.10,73.00,0.00,2467
95
- 3-Jul-17,87.10,71.10,0.45,2714
96
- 4-Jul-17,82.90,70.00,0.00,2296
97
- 5-Jul-17,84.90,71.10,0.00,3170
98
- 6-Jul-17,75.00,71.10,0.01,3065
99
- 7-Jul-17,79.00,68.00,1.78,1513
100
- 8-Jul-17,82.90,70.00,0.00,2718
101
- 9-Jul-17,81.00,69.10,0.00,3048
102
- 10-Jul-17,82.90,71.10,0.00,3506
103
- 11-Jul-17,84.00,75.00,0.00,2929
104
- 12-Jul-17,87.10,77.00,0.00,2860
105
- 13-Jul-17,89.10,77.00,0.00,2563
106
- 14-Jul-17,69.10,64.90,0.35,907
107
- 15-Jul-17,82.90,68.00,0.00,2853
108
- 16-Jul-17,84.90,70.00,0.00,2917
109
- 17-Jul-17,84.90,73.90,0.00,3264
110
- 18-Jul-17,87.10,75.90,0.00,3507
111
- 19-Jul-17,91.00,77.00,0.00,3114
112
- 20-Jul-17,93.00,78.10,0.01,2840
113
- 21-Jul-17,91.00,77.00,0.00,2751
114
- 22-Jul-17,91.00,78.10,0.57,2301
115
- 23-Jul-17,78.10,73.00,0.06,2321
116
- 24-Jul-17,69.10,63.00,0.74,1576
117
- 25-Jul-17,71.10,64.00,0.00,3191
118
- 26-Jul-17,75.90,66.00,0.00,3821
119
- 27-Jul-17,77.00,66.90,0.01,3287
120
- 28-Jul-17,84.90,73.00,0.00,3123
121
- 29-Jul-17,75.90,68.00,0.00,2074
122
- 30-Jul-17,81.00,64.90,0.00,3331
123
- 31-Jul-17,88.00,66.90,0.00,3560
124
- 1-Aug-17,91.00,72.00,0.00,3492
125
- 2-Aug-17,86.00,69.10,0.09,2637
126
- 3-Aug-17,86.00,70.00,0.00,3346
127
- 4-Aug-17,82.90,70.00,0.15,2400
128
- 5-Aug-17,77.00,70.00,0.30,3409
129
- 6-Aug-17,75.90,64.00,0.00,3130
130
- 7-Aug-17,71.10,64.90,0.76,804
131
- 8-Aug-17,77.00,66.00,0.00,3598
132
- 9-Aug-17,82.90,66.00,0.00,3893
133
- 10-Aug-17,82.90,69.10,0.00,3423
134
- 11-Aug-17,81.00,70.00,0.01,3148
135
- 12-Aug-17,75.90,64.90,0.11,4146
136
- 13-Aug-17,82.00,71.10,0.00,3274
137
- 14-Aug-17,80.10,70.00,0.00,3291
138
- 15-Aug-17,73.00,69.10,0.45,2149
139
- 16-Aug-17,84.90,70.00,0.00,3685
140
- 17-Aug-17,82.00,71.10,0.00,3637
141
- 18-Aug-17,81.00,73.00,0.88,1064
142
- 19-Aug-17,84.90,73.00,0.00,4693
143
- 20-Aug-17,81.00,70.00,0.00,2822
144
- 21-Aug-17,84.90,73.00,0.00,3088
145
- 22-Aug-17,88.00,75.00,0.30,2983
146
- 23-Aug-17,80.10,71.10,0.01,2994
147
- 24-Aug-17,79.00,66.00,0.00,3688
148
- 25-Aug-17,78.10,64.00,0.00,3144
149
- 26-Aug-17,77.00,62.10,0.00,2710
150
- 27-Aug-17,77.00,63.00,0.00,2676
151
- 28-Aug-17,75.00,63.00,0.00,3332
152
- 29-Aug-17,68.00,62.10,0.10,1472
153
- 30-Aug-17,75.90,61.00,0.01,3468
154
- 31-Aug-17,81.00,64.00,0.00,3279
155
- 1-Sep-17,70.00,55.00,0.00,2945
156
- 2-Sep-17,66.90,54.00,0.53,1876
157
- 3-Sep-17,69.10,60.10,0.74,1004
158
- 4-Sep-17,79.00,62.10,0.00,2866
159
- 5-Sep-17,84.00,70.00,0.01,3244
160
- 6-Sep-17,70.00,62.10,0.42,1232
161
- 7-Sep-17,71.10,59.00,0.01,3249
162
- 8-Sep-17,70.00,59.00,0.00,3234
163
- 9-Sep-17,69.10,55.00,0.00,2609
164
- 10-Sep-17,72.00,57.00,0.00,4960
165
- 11-Sep-17,75.90,55.00,0.00,3657
166
- 12-Sep-17,78.10,61.00,0.00,3497
167
- 13-Sep-17,82.00,64.90,0.06,2994
168
- 14-Sep-17,81.00,70.00,0.02,3013
169
- 15-Sep-17,81.00,66.90,0.00,3344
170
- 16-Sep-17,82.00,70.00,0.00,2560
171
- 17-Sep-17,80.10,70.00,0.00,2676
172
- 18-Sep-17,73.00,69.10,0.00,2673
173
- 19-Sep-17,78.10,69.10,0.22,2012
174
- 20-Sep-17,78.10,71.10,0.00,3296
175
- 21-Sep-17,80.10,71.10,0.00,3317
176
- 22-Sep-17,82.00,66.00,0.00,3297
177
- 23-Sep-17,86.00,68.00,0.00,2810
178
- 24-Sep-17,90.00,69.10,0.00,2543
179
- 25-Sep-17,87.10,72.00,0.00,3276
180
- 26-Sep-17,82.00,69.10,0.00,3157
181
- 27-Sep-17,84.90,71.10,0.00,3216
182
- 28-Sep-17,78.10,66.00,0.00,3421
183
- 29-Sep-17,66.90,55.00,0.00,2988
184
- 30-Sep-17,64.00,55.90,0.00,1903
185
- 1-Oct-17,66.90,50.00,0.00,2297
186
- 2-Oct-17,72.00,52.00,0.00,3387
187
- 3-Oct-17,70.00,57.00,0.00,3386
188
- 4-Oct-17,75.00,55.90,0.00,3412
189
- 5-Oct-17,82.00,64.90,0.00,3312
190
- 6-Oct-17,81.00,69.10,0.00,2982
191
- 7-Oct-17,80.10,66.00,0.00,2750
192
- 8-Oct-17,77.00,72.00,0.22,1235
193
- 9-Oct-17,75.90,72.00,0.26,898
194
- 10-Oct-17,80.10,66.00,0.00,3922
195
- 11-Oct-17,75.00,64.90,0.06,2721
196
- 12-Oct-17,63.00,55.90,0.07,2411
197
- 13-Oct-17,64.90,52.00,0.00,2839
198
- 14-Oct-17,71.10,62.10,0.08,2021
199
- 15-Oct-17,72.00,66.00,0.01,2169
200
- 16-Oct-17,60.10,52.00,0.01,2751
201
- 17-Oct-17,57.90,43.00,0.00,2869
202
- 18-Oct-17,71.10,50.00,0.00,3264
203
- 19-Oct-17,70.00,55.90,0.00,3265
204
- 20-Oct-17,73.00,57.90,0.00,3169
205
- 21-Oct-17,78.10,57.00,0.00,2538
206
- 22-Oct-17,75.90,57.00,0.00,2744
207
- 23-Oct-17,73.90,64.00,0.00,3189
208
- 24-Oct-17,73.00,66.90,0.20,954
209
- 25-Oct-17,64.90,57.90,0.00,3367
210
- 26-Oct-17,57.00,53.10,0.00,2565
211
- 27-Oct-17,62.10,48.00,0.00,3150
212
- 28-Oct-17,68.00,55.90,0.00,2245
213
- 29-Oct-17,64.90,61.00,3.03,183
214
- 30-Oct-17,55.00,46.00,0.25,1428
215
- 31-Oct-17,54.00,44.00,0.00,2727
 
assignment-2/data/nyc_bb_bicyclist_counts.csv DELETED
@@ -1,215 +0,0 @@
1
- Date,HIGH_T,LOW_T,PRECIP,BB_COUNT
2
- 1-Apr-17,46.00,37.00,0.00,606
3
- 2-Apr-17,62.10,41.00,0.00,2021
4
- 3-Apr-17,63.00,50.00,0.03,2470
5
- 4-Apr-17,51.10,46.00,1.18,723
6
- 5-Apr-17,63.00,46.00,0.00,2807
7
- 6-Apr-17,48.90,41.00,0.73,461
8
- 7-Apr-17,48.00,43.00,0.01,1222
9
- 8-Apr-17,55.90,39.90,0.00,1674
10
- 9-Apr-17,66.00,45.00,0.00,2375
11
- 10-Apr-17,73.90,55.00,0.00,3324
12
- 11-Apr-17,80.10,62.10,0.00,3887
13
- 12-Apr-17,73.90,57.90,0.02,2565
14
- 13-Apr-17,64.00,48.90,0.00,3353
15
- 14-Apr-17,64.90,48.90,0.00,2942
16
- 15-Apr-17,64.90,52.00,0.00,2253
17
- 16-Apr-17,84.90,62.10,0.01,2877
18
- 17-Apr-17,73.90,64.00,0.01,3152
19
- 18-Apr-17,66.00,50.00,0.00,3415
20
- 19-Apr-17,52.00,45.00,0.01,1965
21
- 20-Apr-17,64.90,50.00,0.17,1567
22
- 21-Apr-17,53.10,48.00,0.29,1426
23
- 22-Apr-17,55.90,52.00,0.11,1318
24
- 23-Apr-17,64.90,46.90,0.00,2520
25
- 24-Apr-17,60.10,50.00,0.01,2544
26
- 25-Apr-17,54.00,50.00,0.91,611
27
- 26-Apr-17,59.00,54.00,0.34,1247
28
- 27-Apr-17,68.00,59.00,0.00,2959
29
- 28-Apr-17,82.90,57.90,0.00,3679
30
- 29-Apr-17,84.00,64.00,0.06,3315
31
- 30-Apr-17,64.00,54.00,0.00,2225
32
- 1-May-17,72.00,50.00,0.00,3084
33
- 2-May-17,73.90,66.90,0.00,3423
34
- 3-May-17,64.90,57.90,0.00,3342
35
- 4-May-17,63.00,50.00,0.00,3019
36
- 5-May-17,59.00,52.00,3.02,513
37
- 6-May-17,64.90,57.00,0.18,1892
38
- 7-May-17,54.00,48.90,0.01,3539
39
- 8-May-17,57.00,45.00,0.00,2886
40
- 9-May-17,61.00,48.00,0.00,2718
41
- 10-May-17,70.00,51.10,0.00,2810
42
- 11-May-17,61.00,51.80,0.00,2657
43
- 12-May-17,62.10,51.10,0.00,2640
44
- 13-May-17,51.10,45.00,1.31,151
45
- 14-May-17,64.90,46.00,0.02,1452
46
- 15-May-17,66.90,55.90,0.00,2685
47
- 16-May-17,78.10,57.90,0.00,3666
48
- 17-May-17,90.00,66.00,0.00,3535
49
- 18-May-17,91.90,75.00,0.00,3190
50
- 19-May-17,90.00,75.90,0.00,2952
51
- 20-May-17,64.00,55.90,0.01,2161
52
- 21-May-17,66.90,55.00,0.00,2612
53
- 22-May-17,61.00,54.00,0.59,768
54
- 23-May-17,68.00,57.90,0.00,3174
55
- 24-May-17,66.90,57.00,0.04,2969
56
- 25-May-17,57.90,55.90,0.58,488
57
- 26-May-17,73.00,55.90,0.10,2590
58
- 27-May-17,71.10,61.00,0.00,2609
59
- 28-May-17,71.10,59.00,0.00,2640
60
- 29-May-17,57.90,55.90,0.13,836
61
- 30-May-17,59.00,55.90,0.06,2301
62
- 31-May-17,75.00,57.90,0.03,2689
63
- 1-Jun-17,78.10,62.10,0.00,3468
64
- 2-Jun-17,73.90,60.10,0.01,3271
65
- 3-Jun-17,72.00,55.00,0.01,2589
66
- 4-Jun-17,68.00,60.10,0.09,1805
67
- 5-Jun-17,66.90,60.10,0.02,2171
68
- 6-Jun-17,55.90,53.10,0.06,1193
69
- 7-Jun-17,66.90,54.00,0.00,3211
70
- 8-Jun-17,68.00,59.00,0.00,3253
71
- 9-Jun-17,80.10,59.00,0.00,3401
72
- 10-Jun-17,84.00,68.00,0.00,3066
73
- 11-Jun-17,90.00,73.00,0.00,2465
74
- 12-Jun-17,91.90,77.00,0.00,2854
75
- 13-Jun-17,93.90,78.10,0.01,2882
76
- 14-Jun-17,84.00,69.10,0.29,2596
77
- 15-Jun-17,75.00,66.00,0.00,3510
78
- 16-Jun-17,68.00,66.00,0.00,2054
79
- 17-Jun-17,73.00,66.90,1.39,1399
80
- 18-Jun-17,84.00,72.00,0.01,2199
81
- 19-Jun-17,87.10,70.00,1.35,1648
82
- 20-Jun-17,82.00,72.00,0.03,3407
83
- 21-Jun-17,82.00,72.00,0.00,3304
84
- 22-Jun-17,82.00,70.00,0.00,3368
85
- 23-Jun-17,82.90,75.90,0.04,2283
86
- 24-Jun-17,82.90,71.10,1.29,2307
87
- 25-Jun-17,82.00,69.10,0.00,2625
88
- 26-Jun-17,78.10,66.00,0.00,3386
89
- 27-Jun-17,75.90,61.00,0.18,3182
90
- 28-Jun-17,78.10,62.10,0.00,3766
91
- 29-Jun-17,81.00,68.00,0.00,3356
92
- 30-Jun-17,88.00,73.90,0.01,2687
93
- 1-Jul-17,84.90,72.00,0.23,1848
94
- 2-Jul-17,87.10,73.00,0.00,2467
95
- 3-Jul-17,87.10,71.10,0.45,2714
96
- 4-Jul-17,82.90,70.00,0.00,2296
97
- 5-Jul-17,84.90,71.10,0.00,3170
98
- 6-Jul-17,75.00,71.10,0.01,3065
99
- 7-Jul-17,79.00,68.00,1.78,1513
100
- 8-Jul-17,82.90,70.00,0.00,2718
101
- 9-Jul-17,81.00,69.10,0.00,3048
102
- 10-Jul-17,82.90,71.10,0.00,3506
103
- 11-Jul-17,84.00,75.00,0.00,2929
104
- 12-Jul-17,87.10,77.00,0.00,2860
105
- 13-Jul-17,89.10,77.00,0.00,2563
106
- 14-Jul-17,69.10,64.90,0.35,907
107
- 15-Jul-17,82.90,68.00,0.00,2853
108
- 16-Jul-17,84.90,70.00,0.00,2917
109
- 17-Jul-17,84.90,73.90,0.00,3264
110
- 18-Jul-17,87.10,75.90,0.00,3507
111
- 19-Jul-17,91.00,77.00,0.00,3114
112
- 20-Jul-17,93.00,78.10,0.01,2840
113
- 21-Jul-17,91.00,77.00,0.00,2751
114
- 22-Jul-17,91.00,78.10,0.57,2301
115
- 23-Jul-17,78.10,73.00,0.06,2321
116
- 24-Jul-17,69.10,63.00,0.74,1576
117
- 25-Jul-17,71.10,64.00,0.00,3191
118
- 26-Jul-17,75.90,66.00,0.00,3821
119
- 27-Jul-17,77.00,66.90,0.01,3287
120
- 28-Jul-17,84.90,73.00,0.00,3123
121
- 29-Jul-17,75.90,68.00,0.00,2074
122
- 30-Jul-17,81.00,64.90,0.00,3331
123
- 31-Jul-17,88.00,66.90,0.00,3560
124
- 1-Aug-17,91.00,72.00,0.00,3492
125
- 2-Aug-17,86.00,69.10,0.09,2637
126
- 3-Aug-17,86.00,70.00,0.00,3346
127
- 4-Aug-17,82.90,70.00,0.15,2400
128
- 5-Aug-17,77.00,70.00,0.30,3409
129
- 6-Aug-17,75.90,64.00,0.00,3130
130
- 7-Aug-17,71.10,64.90,0.76,804
131
- 8-Aug-17,77.00,66.00,0.00,3598
132
- 9-Aug-17,82.90,66.00,0.00,3893
133
- 10-Aug-17,82.90,69.10,0.00,3423
134
- 11-Aug-17,81.00,70.00,0.01,3148
135
- 12-Aug-17,75.90,64.90,0.11,4146
136
- 13-Aug-17,82.00,71.10,0.00,3274
137
- 14-Aug-17,80.10,70.00,0.00,3291
138
- 15-Aug-17,73.00,69.10,0.45,2149
139
- 16-Aug-17,84.90,70.00,0.00,3685
140
- 17-Aug-17,82.00,71.10,0.00,3637
141
- 18-Aug-17,81.00,73.00,0.88,1064
142
- 19-Aug-17,84.90,73.00,0.00,4693
143
- 20-Aug-17,81.00,70.00,0.00,2822
144
- 21-Aug-17,84.90,73.00,0.00,3088
145
- 22-Aug-17,88.00,75.00,0.30,2983
146
- 23-Aug-17,80.10,71.10,0.01,2994
147
- 24-Aug-17,79.00,66.00,0.00,3688
148
- 25-Aug-17,78.10,64.00,0.00,3144
149
- 26-Aug-17,77.00,62.10,0.00,2710
150
- 27-Aug-17,77.00,63.00,0.00,2676
151
- 28-Aug-17,75.00,63.00,0.00,3332
152
- 29-Aug-17,68.00,62.10,0.10,1472
153
- 30-Aug-17,75.90,61.00,0.01,3468
154
- 31-Aug-17,81.00,64.00,0.00,3279
155
- 1-Sep-17,70.00,55.00,0.00,2945
156
- 2-Sep-17,66.90,54.00,0.53,1876
157
- 3-Sep-17,69.10,60.10,0.74,1004
158
- 4-Sep-17,79.00,62.10,0.00,2866
159
- 5-Sep-17,84.00,70.00,0.01,3244
160
- 6-Sep-17,70.00,62.10,0.42,1232
161
- 7-Sep-17,71.10,59.00,0.01,3249
162
- 8-Sep-17,70.00,59.00,0.00,3234
163
- 9-Sep-17,69.10,55.00,0.00,2609
164
- 10-Sep-17,72.00,57.00,0.00,4960
165
- 11-Sep-17,75.90,55.00,0.00,3657
166
- 12-Sep-17,78.10,61.00,0.00,3497
167
- 13-Sep-17,82.00,64.90,0.06,2994
168
- 14-Sep-17,81.00,70.00,0.02,3013
169
- 15-Sep-17,81.00,66.90,0.00,3344
170
- 16-Sep-17,82.00,70.00,0.00,2560
171
- 17-Sep-17,80.10,70.00,0.00,2676
172
- 18-Sep-17,73.00,69.10,0.00,2673
173
- 19-Sep-17,78.10,69.10,0.22,2012
174
- 20-Sep-17,78.10,71.10,0.00,3296
175
- 21-Sep-17,80.10,71.10,0.00,3317
176
- 22-Sep-17,82.00,66.00,0.00,3297
177
- 23-Sep-17,86.00,68.00,0.00,2810
178
- 24-Sep-17,90.00,69.10,0.00,2543
179
- 25-Sep-17,87.10,72.00,0.00,3276
180
- 26-Sep-17,82.00,69.10,0.00,3157
181
- 27-Sep-17,84.90,71.10,0.00,3216
182
- 28-Sep-17,78.10,66.00,0.00,3421
183
- 29-Sep-17,66.90,55.00,0.00,2988
184
- 30-Sep-17,64.00,55.90,0.00,1903
185
- 1-Oct-17,66.90,50.00,0.00,2297
186
- 2-Oct-17,72.00,52.00,0.00,3387
187
- 3-Oct-17,70.00,57.00,0.00,3386
188
- 4-Oct-17,75.00,55.90,0.00,3412
189
- 5-Oct-17,82.00,64.90,0.00,3312
190
- 6-Oct-17,81.00,69.10,0.00,2982
191
- 7-Oct-17,80.10,66.00,0.00,2750
192
- 8-Oct-17,77.00,72.00,0.22,1235
193
- 9-Oct-17,75.90,72.00,0.26,898
194
- 10-Oct-17,80.10,66.00,0.00,3922
195
- 11-Oct-17,75.00,64.90,0.06,2721
196
- 12-Oct-17,63.00,55.90,0.07,2411
197
- 13-Oct-17,64.90,52.00,0.00,2839
198
- 14-Oct-17,71.10,62.10,0.08,2021
199
- 15-Oct-17,72.00,66.00,0.01,2169
200
- 16-Oct-17,60.10,52.00,0.01,2751
201
- 17-Oct-17,57.90,43.00,0.00,2869
202
- 18-Oct-17,71.10,50.00,0.00,3264
203
- 19-Oct-17,70.00,55.90,0.00,3265
204
- 20-Oct-17,73.00,57.90,0.00,3169
205
- 21-Oct-17,78.10,57.00,0.00,2538
206
- 22-Oct-17,75.90,57.00,0.00,2744
207
- 23-Oct-17,73.90,64.00,0.00,3189
208
- 24-Oct-17,73.00,66.90,0.20,954
209
- 25-Oct-17,64.90,57.90,0.00,3367
210
- 26-Oct-17,57.00,53.10,0.00,2565
211
- 27-Oct-17,62.10,48.00,0.00,3150
212
- 28-Oct-17,68.00,55.90,0.00,2245
213
- 29-Oct-17,64.90,61.00,3.03,183
214
- 30-Oct-17,55.00,46.00,0.25,1428
215
- 31-Oct-17,54.00,44.00,0.00,2727
 
assignment-2/poetry.lock DELETED
The diff for this file is too large to render. See raw diff
 
assignment-2/pyproject.toml DELETED
@@ -1,23 +0,0 @@
1
- [tool.poetry]
2
- name = "assignment-2"
3
- version = "0.1.0"
4
- description = ""
5
- authors = ["Artur Janik <aj2614@nyu.edu>"]
6
- readme = "README.md"
7
- packages = [{include = "assignment_2"}]
8
-
9
- [tool.poetry.dependencies]
10
- python = "^3.10"
11
- black = "^23.1.0"
12
- jupyterlab = "^3.6.1"
13
- ipython = "^8.10.0"
14
- numpy = "^1.24.2"
15
- pandas = "^1.5.3"
16
- jax = "^0.4.4"
17
- seaborn = "^0.12.2"
18
- matplotlib = "^3.7.0"
19
-
20
-
21
- [build-system]
22
- requires = ["poetry-core"]
23
- build-backend = "poetry.core.masonry.api"
 
language-models-project/docker-compose.yml β†’ docker-compose.yml RENAMED
File without changes
language-models-project/language_models_project/__init__.py DELETED
File without changes
language-models-project/tests/__init__.py DELETED
File without changes
{assignment-2/assignment_2 β†’ language_models_project}/__init__.py RENAMED
File without changes
language_models_project/app.py ADDED
@@ -0,0 +1,47 @@
1
+ import easyocr as ocr #OCR
2
+ import streamlit as st #Web App
3
+ from PIL import Image #Image Processing
4
+ import numpy as np #Image Processing
5
+
6
+ #title
7
+ st.title("Easy OCR - Extract Text from Images")
8
+
9
+ #subtitle
10
+ st.markdown("## Optical Character Recognition - Using `easyocr`, `streamlit` - hosted on πŸ€— Spaces")
11
+
12
+ st.markdown("Link to the app - [image-to-text-app on πŸ€— Spaces](https://huggingface.co/spaces/Amrrs/image-to-text-app)")
13
+
14
+ #image uploader
15
+ image = st.file_uploader(label = "Upload your image here",type=['png','jpg','jpeg'])
16
+
17
+
18
+ @st.cache
19
+ def load_model():
20
+ reader = ocr.Reader(['en'],model_storage_directory='.')
21
+ return reader
22
+
23
+ reader = load_model() #load model
24
+
25
+ if image is not None:
26
+
27
+ input_image = Image.open(image) #read image
28
+ st.image(input_image) #display image
29
+
30
+ with st.spinner("πŸ€– AI is at Work! "):
31
+
32
+
33
+ result = reader.readtext(np.array(input_image))
34
+
35
+ result_text = [] #empty list for results
36
+
37
+
38
+ for text in result:
39
+ result_text.append(text[1])
40
+
41
+ st.write(result_text)
42
+ #st.success("Here you go!")
43
+ st.balloons()
44
+ else:
45
+ st.write("Upload an Image")
46
+
47
+ st.caption("Made with ❀️ by @1littlecoder. Credits to πŸ€— Spaces for Hosting this ")
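The app can be served locally with `streamlit run language_models_project/app.py`. One hedged note on the caching above: `@st.cache` is deprecated in recent Streamlit releases (roughly 1.18 onwards) in favour of `st.cache_resource` for non-serialisable objects such as an `easyocr.Reader`. A minimal sketch of the newer form, assuming such a Streamlit version:

```python
import easyocr as ocr
import streamlit as st

@st.cache_resource  # keeps one Reader instance across reruns instead of @st.cache
def load_model():
    return ocr.Reader(["en"], model_storage_directory=".")
```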
{language-models-project/language_models_project β†’ language_models_project}/main.py RENAMED
File without changes
midterm/Pipfile DELETED
@@ -1,19 +0,0 @@
1
- [[source]]
2
- url = "https://pypi.org/simple"
3
- verify_ssl = true
4
- name = "pypi"
5
-
6
- [packages]
7
- black = "*"
8
- jupyterlab = "*"
9
- pandas = "*"
10
- scikit-learn = "*"
11
- numpy = "*"
12
- ipython = "*"
13
- seaborn = "*"
14
- imblearn = "*"
15
-
16
- [dev-packages]
17
-
18
- [requires]
19
- python_version = "3.10"
 
midterm/Pipfile.lock DELETED
The diff for this file is too large to render. See raw diff
 
midterm/README.md DELETED
File without changes
midterm/data/01_raw/CBC_data.csv DELETED
The diff for this file is too large to render. See raw diff
 
midterm/midterm/__init__.py DELETED
File without changes
midterm/midterm/take_at_home_(1).ipynb DELETED
The diff for this file is too large to render. See raw diff
 
midterm/pyproject.toml DELETED
@@ -1,22 +0,0 @@
1
- [tool.poetry]
2
- name = "midterm"
3
- version = "0.1.0"
4
- description = ""
5
- authors = ["Your Name <you@example.com>"]
6
- readme = "README.md"
7
-
8
- [tool.poetry.dependencies]
9
- python = "^3.10"
10
- black = "^23.1.0"
11
- jupyterlab = "^3.6.1"
12
- ipython = "^8.10.0"
13
- numpy = "^1.24.2"
14
- pandas = "^1.5.3"
15
- jax = "^0.4.4"
16
- seaborn = "^0.12.2"
17
- matplotlib = "^3.7.0"
18
-
19
-
20
- [build-system]
21
- requires = ["poetry-core"]
22
- build-backend = "poetry.core.masonry.api"
 
midterm/tests/__init__.py DELETED
File without changes
language-models-project/poetry.lock β†’ poetry.lock RENAMED
File without changes
language-models-project/pyproject.toml β†’ pyproject.toml RENAMED
File without changes
{assignment-2/tests β†’ tests}/__init__.py RENAMED
File without changes
{language-models-project/tests β†’ tests}/test_classifier.py RENAMED
File without changes