NativeVex commited on
Commit
92601fa
1 Parent(s): 2e0ece4

assignment-2

Browse files
assignment-2/README.md ADDED
File without changes
assignment-2/assignment_2/GML.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## imports
2
+ import numpy as np
3
+ import pandas as pd
4
+ from scipy.optimize import minimize
5
+ from scipy.stats import norm
6
+ import math
7
+
8
+
9
+ ## Problem 1
10
+ data = [4, 5, 7, 8, 8, 9, 10, 5, 2, 3, 5, 4, 8, 9]
11
+
12
+ data_mean = np.mean(data)
13
+ data_variance = np.var(data)
14
+
15
+
16
+ mu = 0.5
17
+ sigma = 0.5
18
+ w = np.array([mu, sigma])
19
+
20
+ w_star = np.array([data_mean, data_variance])
21
+ mu_star = data_mean
22
+ sigma_star = np.sqrt(data_variance)
23
+ offset = 10 * np.random.random(2)
24
+
25
+ w1p = w_star + 0.5 * offset
26
+ w1n = w_star - 0.5 * offset
27
+ w2p = w_star + 0.25 * offset
28
+ w2n = w_star - 0.25 * offset
29
+
30
+ # Negative Log Likelihood is defined as follows:
31
+ # $-\ln(\frac{1}{\sqrt{2\pi\sigma^2}}\exp(-\frac{1}{2}\frac{(x-\mu)}{\sigma}^2))$.
32
+ # Ignoring the contribution of the constant, we find that $\frac{\delta}{\delta
33
+ # \mu} \mathcal{N} = \frac{\mu-x}{\sigma^2}$ and $\frac{\delta}{\delta \sigma}
34
+ # \mathcal{N} = \frac{\sigma^2 + (\mu-x)^2 - \sigma^2}{\sigma^3}$. We apply these as our step functions for our SGD.
35
+
36
+ loss = lambda mu, sigma, x: np.sum(
37
+ [-np.log(norm.pdf(xi, loc=mu, scale=sigma)) for xi in x]
38
+ )
39
+
40
+ loss_2_electric_boogaloo = lambda mu, sigma, x: -len(x) / 2 * np.log(
41
+ 2 * np.pi * sigma**2
42
+ ) - 1 / (2 * sigma**2) * np.sum((x - mu) ** 2)
43
+
44
+
45
+ dmu = lambda mu, sigma, x: -np.sum([mu - xi for xi in x]) / (sigma**2)
46
+ dsigma = lambda mu, sigma, x: -len(x) / sigma + np.sum([(mu - xi) ** 2 for xi in x]) / (
47
+ sigma**3
48
+ )
49
+
50
+ log = []
51
+
52
+
53
+ def SGD_problem1(mu, sigma, x, learning_rate=0.01, n_epochs=1000):
54
+ global log
55
+ log = []
56
+ for epoch in range(n_epochs):
57
+ mu += learning_rate * dmu(mu, sigma, x)
58
+ sigma += learning_rate * dsigma(mu, sigma, x)
59
+
60
+ # print(f"Epoch {epoch}, Loss: {loss(mu, sigma, x)}, New mu: {mu}, New sigma: {sigma}")
61
+ log.append(
62
+ {
63
+ "Epoch": epoch,
64
+ "Loss": loss(mu, sigma, x),
65
+ "Loss 2 Alternative": loss_2_alternative(mu, sigma, x),
66
+ "New mu": mu,
67
+ "New sigma": sigma,
68
+ }
69
+ )
70
+ return np.array([mu, sigma])
71
+
72
+
73
+ def debug_SGD_1(wnn, data):
74
+ print("SGD Problem 1")
75
+ print("wnn", SGD_problem1(*wnn, data))
76
+ dflog = pd.DataFrame(log)
77
+ dflog["mu_star"] = mu_star
78
+ dflog["mu_std"] = sigma_star
79
+ print(f"mu diff at start {dflog.iloc[0]['New mu'] - dflog.iloc[0]['mu_star']}")
80
+ print(f"mu diff at end {dflog.iloc[-1]['New mu'] - dflog.iloc[-1]['mu_star']}")
81
+ if np.abs(dflog.iloc[-1]["New mu"] - dflog.iloc[-1]["mu_star"]) < np.abs(
82
+ dflog.iloc[0]["New mu"] - dflog.iloc[0]["mu_star"]
83
+ ):
84
+ print("mu is improving")
85
+ else:
86
+ print("mu is not improving")
87
+
88
+ print(f"sigma diff at start {dflog.iloc[0]['New sigma'] - dflog.iloc[0]['mu_std']}")
89
+ print(f"sigma diff at end {dflog.iloc[-1]['New sigma'] - dflog.iloc[-1]['mu_std']}")
90
+ if np.abs(dflog.iloc[-1]["New sigma"] - dflog.iloc[-1]["mu_std"]) < np.abs(
91
+ dflog.iloc[0]["New sigma"] - dflog.iloc[0]["mu_std"]
92
+ ):
93
+ print("sigma is improving")
94
+ else:
95
+ print("sigma is not improving")
96
+
97
+ return dflog
98
+
99
+
100
+ # _ = debug_SGD_1(w1p, data)
101
+ # _ = debug_SGD_1(w1n, data)
102
+ # _ = debug_SGD_1(w2p, data)
103
+ # _ = debug_SGD_1(w2n, data)
104
+
105
+
106
+ # TODO EXPLAIN WHY += WORKS HERE.
107
+
108
+
109
+ ## Problem 2
110
+ x = np.array([8, 16, 22, 33, 50, 51])
111
+ y = np.array([5, 20, 14, 32, 42, 58])
112
+
113
+ # $-\frac{n}{\sigma}+\frac{1}{\sigma^3}\sum_{i=1}^n(y_i - (mx+c))^2$
114
+ dsigma = lambda sigma, c, m, x: -len(x) / sigma + np.sum(
115
+ [(xi - (m * x + c)) ** 2 for xi in x]
116
+ ) / (sigma**3)
117
+ # $-\frac{1}{\sigma^2}\sum_{i=1}^n(y_i - (mx+c))$
118
+ dc = lambda sigma, c, m, x: -np.sum([xi - (m * x + c) for xi in x]) / (sigma**2)
119
+ # $-\frac{1}{\sigma^2}\sum_{i=1}^n(x_i(y_i - (mx+c)))$
120
+ dm = lambda sigma, c, m, x: -np.sum([x * (xi - (m * x + c)) for xi in x]) / (sigma**2)
121
+
122
+
123
+ log2 = []
124
+
125
+
126
+ def SGD_problem2(
127
+ sigma: float,
128
+ c: float,
129
+ m: float,
130
+ x: np.array,
131
+ y: np.array,
132
+ learning_rate=0.01,
133
+ n_epochs=1000,
134
+ ):
135
+ global log2
136
+ log2 = []
137
+ for epoch in range(n_epochs):
138
+ sigma += learning_rate * dsigma(sigma, c, m, x)
139
+ c += learning_rate * dc(sigma, c, m, x)
140
+ m += learning_rate * dm(sigma, c, m, x)
141
+
142
+ log2.append(
143
+ {
144
+ "Epoch": epoch,
145
+ "New sigma": sigma,
146
+ "New c": c,
147
+ "New m": m,
148
+ "dc": dc(sigma, c, m, x),
149
+ "dm": dm(sigma, c, m, x),
150
+ "dsigma": dsigma(sigma, c, m, x),
151
+ "Loss": loss((m * x + c), sigma, y),
152
+ }
153
+ )
154
+ print(f"Epoch {epoch}, Loss: {loss((m * x + c), sigma, y)}")
155
+ return np.array([sigma, c, m])
156
+
157
+
158
+ # def debug_SGD_2(wnn, data):
159
+ # print("SGD Problem 2")
160
+ # print("wnn", SGD_problem2(*wnn, data))
161
+ # dflog = pd.DataFrame(log)
162
+ # dflog["m_star"] = m_star
163
+ # dflog["c_star"] = c_star
164
+ # dflog["sigma_star"] = sigma_star
165
+ # print(f"m diff at start {dflog.iloc[0]['New m'] - dflog.iloc[0]['m_star']}")
166
+ # print(f"m diff at end {dflog.iloc[-1]['New m'] - dflog.iloc[-1]['m_star']}")
167
+ # if np.abs(dflog.iloc[-1]["New m"] - dflog.iloc[-1]["m_star"]) < np.abs(
168
+ # dflog.iloc[0]["New m"] - dflog.iloc[0]["m_star"]
169
+ # ):
170
+ # print("m is improving")
171
+ # else:
172
+ # print("m is not improving")
173
+ # print(f"c diff at start {dflog.iloc[0]['New c'] - dflog.iloc[0]['c_star']}")
174
+ # print(f"c diff at end {dflog.iloc[-1]['New c'] - dflog.iloc[-1]['c_star']}")
175
+ # if np.abs(dflog.iloc[-1]["New c"] - dflog.iloc[-1]["c_star"]) < np.abs(
176
+ # dflog.iloc[0]["New c"] - dflog.iloc[0]["c_star"]
177
+ # ):
178
+ # print("c is improving")
179
+ # else:
180
+ # print("c is not improving")
181
+ # print(f"sigma diff at start {dflog.iloc[0]['New sigma'] - dflog.iloc[0]['sigma_star']}")
182
+ # print(f"sigma diff at end {dflog.iloc[-1]['New sigma'] - dflog.iloc[-1]['sigma_star']}")
183
+ # if np.abs(dflog.iloc[-1]["New sigma"] - dflog.iloc[-1]["sigma_star"]) < np.abs(
184
+ # dflog.iloc[0]["New sigma"] - dflog.iloc[0]["sigma_star"]
185
+ # ):
186
+ # print("sigma is improving")
187
+ # else:
188
+ # print("sigma is not improving")
189
+ # return dflog
190
+
191
+ result = SGD_problem2(0.5, 0.5, 0.5, x, y)
192
+ print(f"final parameters: m={result[2]}, c={result[1]}, sigma={result[0]}")
193
+
194
+
195
+ ## pset2
196
+ # Knowing that the poisson pdf is $P(k) = \frac{\lambda^k e^{-\lambda}}{k!}$, we can find the negative log likelihood of the data as $-\log(\Pi_{i=1}^n P(k_i)) = -\sum_{i=1}^n \log(\frac{\lambda^k_i e^{-\lambda}}{k_i!}) = \sum_{i=1}^n -\ln(\lambda) k_i + \ln(k_i!) + \lambda$. Which simplified, gives $n\lambda + \sum_{i=1}^n \ln(k_i!) - \sum_{i=1}^n k_i \ln(\lambda)$. Differentiating with respect to $\lambda$ gives $n - \sum_{i=1}^n \frac{k_i}{\lambda}$. Which is our desired $\frac{\partial L}{\partial \lambda}$!
197
+
198
+
199
+ import pandas as pd
200
+
201
+ df = pd.read_csv("../data/01_raw/nyc_bb_bicyclist_counts.csv")
202
+
203
+ dlambda = lambda l, k: len(k) - np.sum([ki / l for ki in k])
204
+
205
+
206
+ def SGD_problem3(
207
+ l: float,
208
+ k: np.array,
209
+ learning_rate=0.01,
210
+ n_epochs=1000,
211
+ ):
212
+ global log3
213
+ log3 = []
214
+ for epoch in range(n_epochs):
215
+ l -= learning_rate * dlambda(l, k)
216
+ # $n\lambda + \sum_{i=1}^n \ln(k_i!) - \sum_{i=1}^n k_i \ln(\lambda)$
217
+ # the rest of the loss function is commented out because it's a
218
+ # constant and was causing overflows. It is unnecessary, and a useless
219
+ # pain.
220
+ loss = len(k) * l - np.sum(
221
+ [ki * np.log(l) for ki in k]
222
+ ) # + np.sum([np.log(np.math.factorial(ki)) for ki in k])
223
+
224
+ log3.append(
225
+ {
226
+ "Epoch": epoch,
227
+ "New lambda": l,
228
+ "dlambda": dlambda(l, k),
229
+ "Loss": loss,
230
+ }
231
+ )
232
+ print(f"Epoch {epoch}", f"Loss: {loss}")
233
+ return np.array([l])
234
+
235
+
236
+ l_star = df["BB_COUNT"].mean()
237
+
238
+
239
+ def debug_SGD_3(data, l=1000):
240
+ print("SGD Problem 3")
241
+ print(f"l: {SGD_problem3(l, data)}")
242
+ dflog = pd.DataFrame(log3)
243
+ dflog["l_star"] = l_star
244
+ print(f"l diff at start {dflog.iloc[0]['New lambda'] - dflog.iloc[0]['l_star']}")
245
+ print(f"l diff at end {dflog.iloc[-1]['New lambda'] - dflog.iloc[-1]['l_star']}")
246
+ if np.abs(dflog.iloc[-1]["New lambda"] - dflog.iloc[-1]["l_star"]) < np.abs(
247
+ dflog.iloc[0]["New lambda"] - dflog.iloc[0]["l_star"]
248
+ ):
249
+ print("l is improving")
250
+ else:
251
+ print("l is not improving")
252
+ return dflog
253
+
254
+
255
+ debug_SGD_3(data=df["BB_COUNT"].values, l=l_star + 1000)
256
+ debug_SGD_3(data=df["BB_COUNT"].values, l=l_star - 1000)
257
+
258
+
259
+ ## pset 4
260
+
261
+ # dw = lambda w, x: len(x) * np.exp(np.dot(x, w)) * x - np.sum()
262
+
263
+ primitive = lambda xi, wi: (x.shape[0] * np.exp(wi * xi) * xi) - (xi**2)
264
+ p_dw = lambda w, xi: np.array([primitive(xi, wi) for xi, wi in ])
265
+
266
+
267
+ def SGD_problem4(
268
+ w: np.array,
269
+ x: np.array,
270
+ learning_rate=0.01,
271
+ n_epochs=1000,
272
+ ):
273
+ global log4
274
+ log4 = []
275
+ for epoch in range(n_epochs):
276
+ w -= learning_rate * p_dw(w, x)
277
+ # custom
278
+ # loss = x.shape[0] * np.exp(np.dot(x, w))
279
+ loss_fn = lambda k, l: len(k) * l - np.sum(
280
+ [ki * np.log(l) for ki in k]
281
+ ) # + np.sum([np.log(np.math.factorial(ki)) for ki in k])
282
+ loss = loss_fn(x, np.exp(np.dot(x, w)))
283
+ log4.append(
284
+ {
285
+ "Epoch": epoch,
286
+ "New w": w,
287
+ "dw": dw(w, x),
288
+ "Loss": loss,
289
+ }
290
+ )
291
+ print(f"Epoch {epoch}", f"Loss: {loss}")
292
+ return w
293
+
294
+
295
+ def debug_SGD_3(data, w=np.array([1, 1])):
296
+ print("SGD Problem 4")
297
+ print(f"w: {SGD_problem4(w, data)}")
298
+ dflog = pd.DataFrame(log4)
299
+ return dflog
300
+
301
+
302
+ _ = debug_SGD_3(
303
+ data=df[["HIGH_T", "LOW_T", "PRECIP"]].to_numpy(),
304
+ w=np.array([1.0, 1.0, 1.0]),
305
+ )
assignment-2/assignment_2/Gaussian Maximum Likelihood.ipynb ADDED
@@ -0,0 +1,1682 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "fce8933f-4594-4bb5-bffe-86fcb9ddd684",
6
+ "metadata": {},
7
+ "source": [
8
+ "# MLE of a Gaussian $p_{model}(x|w)$"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 4,
14
+ "id": "f6cd23f0-e755-48af-be5e-aaee83dda1e7",
15
+ "metadata": {
16
+ "tags": []
17
+ },
18
+ "outputs": [],
19
+ "source": [
20
+ "import numpy as np\n",
21
+ "\n",
22
+ "data = [4, 5, 7, 8, 8, 9, 10, 5, 2, 3, 5, 4, 8, 9]\n",
23
+ "\n",
24
+ "\n",
25
+ "## imports\n",
26
+ "import numpy as np\n",
27
+ "import pandas as pd\n",
28
+ "from scipy.optimize import minimize\n",
29
+ "from scipy.stats import norm\n",
30
+ "import math\n",
31
+ "\n",
32
+ "\n",
33
+ "## Problem 1\n",
34
+ "data = [4, 5, 7, 8, 8, 9, 10, 5, 2, 3, 5, 4, 8, 9]\n",
35
+ "\n",
36
+ "data_mean = np.mean(data)\n",
37
+ "data_variance = np.var(data)\n",
38
+ "\n",
39
+ "\n",
40
+ "mu = 0.5\n",
41
+ "sigma = 0.5\n",
42
+ "w = np.array([mu, sigma])\n",
43
+ "\n",
44
+ "w_star = np.array([data_mean, data_variance])\n",
45
+ "mu_star = data_mean\n",
46
+ "sigma_star = np.sqrt(data_variance)\n",
47
+ "offset = 10 * np.random.random(2)\n",
48
+ "\n",
49
+ "w1p = w_star + 0.5 * offset\n",
50
+ "w1n = w_star - 0.5 * offset\n",
51
+ "w2p = w_star + 0.25 * offset\n",
52
+ "w2n = w_star - 0.25 * offset"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "id": "f3d8587b-3862-4e98-bbcc-99d57bb313c1",
58
+ "metadata": {},
59
+ "source": [
60
+ "Negative Log Likelihood is defined as follows: $-\\ln(\\frac{1}{\\sqrt{2\\pi\\sigma^2}}\\exp(-\\frac{1}{2}\\frac{(x-\\mu)}{\\sigma}^2))$. Ignoring the contribution of the constant, we find that $\\frac{\\delta}{\\delta \\mu} \\mathcal{N} = \\frac{\\mu-x}{\\sigma^2}$ and $\\frac{\\delta}{\\delta \\sigma} \\mathcal{N} = \\frac{\\sigma^2 + (\\mu-x)^2 - \\sigma^2}{\\sigma^3}$. We apply these as our step functions for our SGD. "
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": 5,
66
+ "id": "27bf27ad-031e-4b65-a44d-53c5c1a09d91",
67
+ "metadata": {
68
+ "tags": []
69
+ },
70
+ "outputs": [],
71
+ "source": [
72
+ "loss = lambda mu, sigma, x: np.sum(\n",
73
+ " [-np.log(norm.pdf(xi, loc=mu, scale=sigma)) for xi in x]\n",
74
+ ")\n",
75
+ "\n",
76
+ "loss_2_alternative = lambda mu, sigma, x: -len(x) / 2 * np.log(\n",
77
+ " 2 * np.pi * sigma**2\n",
78
+ ") - 1 / (2 * sigma**2) * np.sum((x - mu) ** 2)\n",
79
+ "\n",
80
+ "\n",
81
+ "dmu = lambda mu, sigma, x: -np.sum([mu - xi for xi in x]) / (sigma**2)\n",
82
+ "dsigma = lambda mu, sigma, x: -len(x) / sigma + np.sum([(mu - xi) ** 2 for xi in x]) / (sigma**3)\n",
83
+ "\n",
84
+ "log = []\n",
85
+ "def SGD_problem1(mu, sigma, x, learning_rate=0.01, n_epochs=1000):\n",
86
+ " global log\n",
87
+ " log = []\n",
88
+ " for epoch in range(n_epochs):\n",
89
+ " mu += learning_rate * dmu(mu, sigma, x)\n",
90
+ " sigma += learning_rate * dsigma(mu, sigma, x)\n",
91
+ "\n",
92
+ " # print(f\"Epoch {epoch}, Loss: {loss(mu, sigma, x)}, New mu: {mu}, New sigma: {sigma}\")\n",
93
+ " log.append(\n",
94
+ " {\n",
95
+ " \"Epoch\": epoch,\n",
96
+ " \"Loss\": loss(mu, sigma, x),\n",
97
+ " \"Loss 2 Alternative\": loss_2_alternative(mu, sigma, x),\n",
98
+ " \"New mu\": mu,\n",
99
+ " \"New sigma\": sigma,\n",
100
+ " }\n",
101
+ " )\n",
102
+ " return np.array([mu, sigma])\n",
103
+ "\n",
104
+ "\n",
105
+ "def debug_SGD_1(wnn, data):\n",
106
+ " print(\"SGD Problem 1\")\n",
107
+ " print(\"wnn\", SGD_problem1(*wnn, data))\n",
108
+ " dflog = pd.DataFrame(log)\n",
109
+ " dflog[\"mu_star\"] = mu_star\n",
110
+ " dflog[\"mu_std\"] = sigma_star\n",
111
+ " print(f\"mu diff at start {dflog.iloc[0]['New mu'] - dflog.iloc[0]['mu_star']}\")\n",
112
+ " print(f\"mu diff at end {dflog.iloc[-1]['New mu'] - dflog.iloc[-1]['mu_star']}\")\n",
113
+ " if np.abs(dflog.iloc[-1][\"New mu\"] - dflog.iloc[-1][\"mu_star\"]) < np.abs(\n",
114
+ " dflog.iloc[0][\"New mu\"] - dflog.iloc[0][\"mu_star\"]\n",
115
+ " ):\n",
116
+ " print(\"mu is improving\")\n",
117
+ " else:\n",
118
+ " print(\"mu is not improving\")\n",
119
+ "\n",
120
+ " print(f\"sigma diff at start {dflog.iloc[0]['New sigma'] - dflog.iloc[0]['mu_std']}\")\n",
121
+ " print(f\"sigma diff at end {dflog.iloc[-1]['New sigma'] - dflog.iloc[-1]['mu_std']}\")\n",
122
+ " if np.abs(dflog.iloc[-1][\"New sigma\"] - dflog.iloc[-1][\"mu_std\"]) < np.abs(\n",
123
+ " dflog.iloc[0][\"New sigma\"] - dflog.iloc[0][\"mu_std\"]\n",
124
+ " ):\n",
125
+ " print(\"sigma is improving\")\n",
126
+ " else:\n",
127
+ " print(\"sigma is not improving\")\n",
128
+ "\n",
129
+ " return dflog"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": 6,
135
+ "id": "27dd3bc6-b96e-4f8b-9118-01ad344dfd6a",
136
+ "metadata": {
137
+ "tags": []
138
+ },
139
+ "outputs": [
140
+ {
141
+ "name": "stdout",
142
+ "output_type": "stream",
143
+ "text": [
144
+ "SGD Problem 1\n",
145
+ "wnn [6.2142858 2.42541812]\n",
146
+ "mu diff at start 0.27610721776969527\n",
147
+ "mu diff at end 8.978893806244059e-08\n",
148
+ "mu is improving\n",
149
+ "sigma diff at start 8.134860851821205\n",
150
+ "sigma diff at end 1.7124079931818414e-12\n",
151
+ "sigma is improving\n",
152
+ "SGD Problem 1\n",
153
+ "wnn [6.21428571 2.42541812]\n",
154
+ "mu diff at start -0.24923650064862635\n",
155
+ "mu diff at end -6.602718372050731e-12\n",
156
+ "mu is improving\n",
157
+ "sigma diff at start -0.859536014291925\n",
158
+ "sigma diff at end -3.552713678800501e-15\n",
159
+ "sigma is improving\n",
160
+ "SGD Problem 1\n",
161
+ "wnn [6.21428572 2.42541812]\n",
162
+ "mu diff at start 0.13794086144778994\n",
163
+ "mu diff at end 1.0008935902305893e-09\n",
164
+ "mu is improving\n",
165
+ "sigma diff at start 5.786783512688555\n",
166
+ "sigma diff at end 4.440892098500626e-15\n",
167
+ "sigma is improving\n",
168
+ "SGD Problem 1\n",
169
+ "wnn [6.21428571 2.42541812]\n",
170
+ "mu diff at start -0.13668036978891251\n",
171
+ "mu diff at end -8.528289185960602e-12\n",
172
+ "mu is improving\n",
173
+ "sigma diff at start 1.091241177336173\n",
174
+ "sigma diff at end 4.440892098500626e-15\n",
175
+ "sigma is improving\n"
176
+ ]
177
+ }
178
+ ],
179
+ "source": [
180
+ "_ = debug_SGD_1(w1p, data)\n",
181
+ "_ = debug_SGD_1(w1n, data)\n",
182
+ "_ = debug_SGD_1(w2p, data)\n",
183
+ "_ = debug_SGD_1(w2n, data)"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "markdown",
188
+ "id": "30096401-0bd5-4cf6-b093-a688476e16f1",
189
+ "metadata": {
190
+ "tags": []
191
+ },
192
+ "source": [
193
+ "# MLE of Conditional Gaussian"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "markdown",
198
+ "id": "101a3c5e-1e02-41e6-9eab-aba65c39627a",
199
+ "metadata": {},
200
+ "source": [
201
+ "dsigma = $-\\frac{n}{\\sigma}+\\frac{1}{\\sigma^3}\\sum_{i=1}^n(y_i - (mx+c))^2$ \n",
202
+ "dc = $-\\frac{1}{\\sigma^2}\\sum_{i=1}^n(y_i - (mx+c))$ \n",
203
+ "dm = $-\\frac{1}{\\sigma^2}\\sum_{i=1}^n(x_i(y_i - (mx+c)))$ "
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": 8,
209
+ "id": "21969012-f81b-43d4-975d-13411e975f8f",
210
+ "metadata": {
211
+ "collapsed": true,
212
+ "jupyter": {
213
+ "outputs_hidden": true
214
+ },
215
+ "tags": []
216
+ },
217
+ "outputs": [
218
+ {
219
+ "name": "stdout",
220
+ "output_type": "stream",
221
+ "text": [
222
+ "Epoch 0, Loss: 297.82677563555086\n",
223
+ "Epoch 1, Loss: 297.8267749215061\n",
224
+ "Epoch 2, Loss: 297.82677420752475\n",
225
+ "Epoch 3, Loss: 297.82677349360694\n",
226
+ "Epoch 4, Loss: 297.8267727797526\n",
227
+ "Epoch 5, Loss: 297.8267720659618\n",
228
+ "Epoch 6, Loss: 297.82677135223446\n",
229
+ "Epoch 7, Loss: 297.82677063857074\n",
230
+ "Epoch 8, Loss: 297.8267699249706\n",
231
+ "Epoch 9, Loss: 297.826769211434\n",
232
+ "Epoch 10, Loss: 297.8267684979611\n",
233
+ "Epoch 11, Loss: 297.82676778455175\n",
234
+ "Epoch 12, Loss: 297.82676707120623\n",
235
+ "Epoch 13, Loss: 297.82676635792427\n",
236
+ "Epoch 14, Loss: 297.8267656447061\n",
237
+ "Epoch 15, Loss: 297.8267649315517\n",
238
+ "Epoch 16, Loss: 297.82676421846105\n",
239
+ "Epoch 17, Loss: 297.82676350543414\n",
240
+ "Epoch 18, Loss: 297.82676279247113\n",
241
+ "Epoch 19, Loss: 297.82676207957184\n",
242
+ "Epoch 20, Loss: 297.82676136673643\n",
243
+ "Epoch 21, Loss: 297.826760653965\n",
244
+ "Epoch 22, Loss: 297.82675994125736\n",
245
+ "Epoch 23, Loss: 297.8267592286137\n",
246
+ "Epoch 24, Loss: 297.8267585160339\n",
247
+ "Epoch 25, Loss: 297.8267578035182\n",
248
+ "Epoch 26, Loss: 297.82675709106655\n",
249
+ "Epoch 27, Loss: 297.8267563786788\n",
250
+ "Epoch 28, Loss: 297.82675566635504\n",
251
+ "Epoch 29, Loss: 297.82675495409546\n",
252
+ "Epoch 30, Loss: 297.82675424189983\n",
253
+ "Epoch 31, Loss: 297.82675352976844\n",
254
+ "Epoch 32, Loss: 297.8267528177011\n",
255
+ "Epoch 33, Loss: 297.82675210569795\n",
256
+ "Epoch 34, Loss: 297.826751393759\n",
257
+ "Epoch 35, Loss: 297.8267506818843\n",
258
+ "Epoch 36, Loss: 297.8267499700737\n",
259
+ "Epoch 37, Loss: 297.8267492583274\n",
260
+ "Epoch 38, Loss: 297.82674854664543\n",
261
+ "Epoch 39, Loss: 297.8267478350277\n",
262
+ "Epoch 40, Loss: 297.8267471234743\n",
263
+ "Epoch 41, Loss: 297.82674641198514\n",
264
+ "Epoch 42, Loss: 297.8267457005605\n",
265
+ "Epoch 43, Loss: 297.82674498920017\n",
266
+ "Epoch 44, Loss: 297.82674427790425\n",
267
+ "Epoch 45, Loss: 297.82674356667275\n",
268
+ "Epoch 46, Loss: 297.82674285550564\n",
269
+ "Epoch 47, Loss: 297.8267421444032\n",
270
+ "Epoch 48, Loss: 297.82674143336516\n",
271
+ "Epoch 49, Loss: 297.8267407223916\n",
272
+ "Epoch 50, Loss: 297.8267400114827\n",
273
+ "Epoch 51, Loss: 297.82673930063817\n",
274
+ "Epoch 52, Loss: 297.8267385898584\n",
275
+ "Epoch 53, Loss: 297.8267378791432\n",
276
+ "Epoch 54, Loss: 297.8267371684925\n",
277
+ "Epoch 55, Loss: 297.8267364579067\n",
278
+ "Epoch 56, Loss: 297.8267357473855\n",
279
+ "Epoch 57, Loss: 297.82673503692894\n",
280
+ "Epoch 58, Loss: 297.82673432653723\n",
281
+ "Epoch 59, Loss: 297.8267336162103\n",
282
+ "Epoch 60, Loss: 297.826732905948\n",
283
+ "Epoch 61, Loss: 297.82673219575054\n",
284
+ "Epoch 62, Loss: 297.826731485618\n",
285
+ "Epoch 63, Loss: 297.82673077555023\n",
286
+ "Epoch 64, Loss: 297.82673006554734\n",
287
+ "Epoch 65, Loss: 297.8267293556093\n",
288
+ "Epoch 66, Loss: 297.82672864573624\n",
289
+ "Epoch 67, Loss: 297.82672793592815\n",
290
+ "Epoch 68, Loss: 297.82672722618497\n",
291
+ "Epoch 69, Loss: 297.8267265165068\n",
292
+ "Epoch 70, Loss: 297.8267258068936\n",
293
+ "Epoch 71, Loss: 297.8267250973455\n",
294
+ "Epoch 72, Loss: 297.8267243878624\n",
295
+ "Epoch 73, Loss: 297.8267236784444\n",
296
+ "Epoch 74, Loss: 297.82672296909146\n",
297
+ "Epoch 75, Loss: 297.8267222598038\n",
298
+ "Epoch 76, Loss: 297.8267215505811\n",
299
+ "Epoch 77, Loss: 297.8267208414237\n",
300
+ "Epoch 78, Loss: 297.82672013233156\n",
301
+ "Epoch 79, Loss: 297.8267194233045\n",
302
+ "Epoch 80, Loss: 297.8267187143427\n",
303
+ "Epoch 81, Loss: 297.8267180054462\n",
304
+ "Epoch 82, Loss: 297.82671729661496\n",
305
+ "Epoch 83, Loss: 297.82671658784903\n",
306
+ "Epoch 84, Loss: 297.8267158791485\n",
307
+ "Epoch 85, Loss: 297.82671517051335\n",
308
+ "Epoch 86, Loss: 297.8267144619435\n",
309
+ "Epoch 87, Loss: 297.8267137534391\n",
310
+ "Epoch 88, Loss: 297.82671304500013\n",
311
+ "Epoch 89, Loss: 297.82671233662654\n",
312
+ "Epoch 90, Loss: 297.82671162831855\n",
313
+ "Epoch 91, Loss: 297.8267109200759\n",
314
+ "Epoch 92, Loss: 297.82671021189896\n",
315
+ "Epoch 93, Loss: 297.8267095037876\n",
316
+ "Epoch 94, Loss: 297.8267087957417\n",
317
+ "Epoch 95, Loss: 297.82670808776135\n",
318
+ "Epoch 96, Loss: 297.82670737984665\n",
319
+ "Epoch 97, Loss: 297.8267066719976\n",
320
+ "Epoch 98, Loss: 297.82670596421417\n",
321
+ "Epoch 99, Loss: 297.8267052564966\n",
322
+ "Epoch 100, Loss: 297.82670454884465\n",
323
+ "Epoch 101, Loss: 297.82670384125834\n",
324
+ "Epoch 102, Loss: 297.8267031337379\n",
325
+ "Epoch 103, Loss: 297.8267024262833\n",
326
+ "Epoch 104, Loss: 297.82670171889436\n",
327
+ "Epoch 105, Loss: 297.8267010115713\n",
328
+ "Epoch 106, Loss: 297.82670030431416\n",
329
+ "Epoch 107, Loss: 297.82669959712285\n",
330
+ "Epoch 108, Loss: 297.82669888999743\n",
331
+ "Epoch 109, Loss: 297.8266981829379\n",
332
+ "Epoch 110, Loss: 297.8266974759444\n",
333
+ "Epoch 111, Loss: 297.8266967690169\n",
334
+ "Epoch 112, Loss: 297.8266960621552\n",
335
+ "Epoch 113, Loss: 297.8266953553597\n",
336
+ "Epoch 114, Loss: 297.82669464863017\n",
337
+ "Epoch 115, Loss: 297.82669394196677\n",
338
+ "Epoch 116, Loss: 297.8266932353694\n",
339
+ "Epoch 117, Loss: 297.82669252883824\n",
340
+ "Epoch 118, Loss: 297.8266918223731\n",
341
+ "Epoch 119, Loss: 297.8266911159742\n",
342
+ "Epoch 120, Loss: 297.82669040964146\n",
343
+ "Epoch 121, Loss: 297.8266897033749\n",
344
+ "Epoch 122, Loss: 297.8266889971746\n",
345
+ "Epoch 123, Loss: 297.82668829104057\n",
346
+ "Epoch 124, Loss: 297.8266875849728\n",
347
+ "Epoch 125, Loss: 297.8266868789714\n",
348
+ "Epoch 126, Loss: 297.8266861730362\n",
349
+ "Epoch 127, Loss: 297.8266854671674\n",
350
+ "Epoch 128, Loss: 297.826684761365\n",
351
+ "Epoch 129, Loss: 297.82668405562896\n",
352
+ "Epoch 130, Loss: 297.8266833499594\n",
353
+ "Epoch 131, Loss: 297.8266826443563\n",
354
+ "Epoch 132, Loss: 297.8266819388196\n",
355
+ "Epoch 133, Loss: 297.82668123334935\n",
356
+ "Epoch 134, Loss: 297.8266805279458\n",
357
+ "Epoch 135, Loss: 297.8266798226087\n",
358
+ "Epoch 136, Loss: 297.82667911733813\n",
359
+ "Epoch 137, Loss: 297.8266784121342\n",
360
+ "Epoch 138, Loss: 297.82667770699686\n",
361
+ "Epoch 139, Loss: 297.82667700192616\n",
362
+ "Epoch 140, Loss: 297.8266762969221\n",
363
+ "Epoch 141, Loss: 297.8266755919847\n",
364
+ "Epoch 142, Loss: 297.82667488711405\n",
365
+ "Epoch 143, Loss: 297.8266741823101\n",
366
+ "Epoch 144, Loss: 297.82667347757297\n",
367
+ "Epoch 145, Loss: 297.82667277290255\n",
368
+ "Epoch 146, Loss: 297.82667206829893\n",
369
+ "Epoch 147, Loss: 297.8266713637622\n",
370
+ "Epoch 148, Loss: 297.8266706592923\n",
371
+ "Epoch 149, Loss: 297.8266699548893\n",
372
+ "Epoch 150, Loss: 297.8266692505533\n",
373
+ "Epoch 151, Loss: 297.82666854628405\n",
374
+ "Epoch 152, Loss: 297.82666784208175\n",
375
+ "Epoch 153, Loss: 297.8266671379465\n",
376
+ "Epoch 154, Loss: 297.8266664338782\n",
377
+ "Epoch 155, Loss: 297.82666572987705\n",
378
+ "Epoch 156, Loss: 297.8266650259427\n",
379
+ "Epoch 157, Loss: 297.82666432207566\n",
380
+ "Epoch 158, Loss: 297.82666361827546\n",
381
+ "Epoch 159, Loss: 297.8266629145426\n",
382
+ "Epoch 160, Loss: 297.8266622108768\n",
383
+ "Epoch 161, Loss: 297.8266615072782\n",
384
+ "Epoch 162, Loss: 297.8266608037467\n",
385
+ "Epoch 163, Loss: 297.8266601002824\n",
386
+ "Epoch 164, Loss: 297.8266593968855\n",
387
+ "Epoch 165, Loss: 297.8266586935557\n",
388
+ "Epoch 166, Loss: 297.82665799029326\n",
389
+ "Epoch 167, Loss: 297.82665728709804\n",
390
+ "Epoch 168, Loss: 297.82665658397036\n",
391
+ "Epoch 169, Loss: 297.8266558809098\n",
392
+ "Epoch 170, Loss: 297.82665517791673\n",
393
+ "Epoch 171, Loss: 297.8266544749911\n",
394
+ "Epoch 172, Loss: 297.82665377213283\n",
395
+ "Epoch 173, Loss: 297.826653069342\n",
396
+ "Epoch 174, Loss: 297.8266523666187\n",
397
+ "Epoch 175, Loss: 297.82665166396293\n",
398
+ "Epoch 176, Loss: 297.82665096137464\n",
399
+ "Epoch 177, Loss: 297.8266502588538\n",
400
+ "Epoch 178, Loss: 297.82664955640064\n",
401
+ "Epoch 179, Loss: 297.82664885401516\n",
402
+ "Epoch 180, Loss: 297.82664815169716\n",
403
+ "Epoch 181, Loss: 297.8266474494469\n",
404
+ "Epoch 182, Loss: 297.8266467472642\n",
405
+ "Epoch 183, Loss: 297.8266460451493\n",
406
+ "Epoch 184, Loss: 297.82664534310203\n",
407
+ "Epoch 185, Loss: 297.82664464112264\n",
408
+ "Epoch 186, Loss: 297.82664393921095\n",
409
+ "Epoch 187, Loss: 297.82664323736697\n",
410
+ "Epoch 188, Loss: 297.8266425355909\n",
411
+ "Epoch 189, Loss: 297.8266418338827\n",
412
+ "Epoch 190, Loss: 297.82664113224223\n",
413
+ "Epoch 191, Loss: 297.8266404306698\n",
414
+ "Epoch 192, Loss: 297.82663972916515\n",
415
+ "Epoch 193, Loss: 297.82663902772856\n",
416
+ "Epoch 194, Loss: 297.82663832635984\n",
417
+ "Epoch 195, Loss: 297.8266376250591\n",
418
+ "Epoch 196, Loss: 297.8266369238264\n",
419
+ "Epoch 197, Loss: 297.82663622266176\n",
420
+ "Epoch 198, Loss: 297.8266355215652\n",
421
+ "Epoch 199, Loss: 297.8266348205367\n",
422
+ "Epoch 200, Loss: 297.8266341195762\n",
423
+ "Epoch 201, Loss: 297.82663341868397\n",
424
+ "Epoch 202, Loss: 297.82663271785987\n",
425
+ "Epoch 203, Loss: 297.82663201710386\n",
426
+ "Epoch 204, Loss: 297.8266313164161\n",
427
+ "Epoch 205, Loss: 297.82663061579655\n",
428
+ "Epoch 206, Loss: 297.8266299152454\n",
429
+ "Epoch 207, Loss: 297.8266292147624\n",
430
+ "Epoch 208, Loss: 297.8266285143477\n",
431
+ "Epoch 209, Loss: 297.82662781400137\n",
432
+ "Epoch 210, Loss: 297.82662711372336\n",
433
+ "Epoch 211, Loss: 297.8266264135138\n",
434
+ "Epoch 212, Loss: 297.8266257133725\n",
435
+ "Epoch 213, Loss: 297.8266250132998\n",
436
+ "Epoch 214, Loss: 297.82662431329544\n",
437
+ "Epoch 215, Loss: 297.8266236133595\n",
438
+ "Epoch 216, Loss: 297.8266229134922\n",
439
+ "Epoch 217, Loss: 297.8266222136934\n",
440
+ "Epoch 218, Loss: 297.8266215139631\n",
441
+ "Epoch 219, Loss: 297.8266208143014\n",
442
+ "Epoch 220, Loss: 297.8266201147082\n",
443
+ "Epoch 221, Loss: 297.8266194151838\n",
444
+ "Epoch 222, Loss: 297.82661871572793\n",
445
+ "Epoch 223, Loss: 297.82661801634066\n",
446
+ "Epoch 224, Loss: 297.82661731702217\n",
447
+ "Epoch 225, Loss: 297.82661661777234\n",
448
+ "Epoch 226, Loss: 297.8266159185914\n",
449
+ "Epoch 227, Loss: 297.8266152194791\n",
450
+ "Epoch 228, Loss: 297.82661452043567\n",
451
+ "Epoch 229, Loss: 297.82661382146097\n",
452
+ "Epoch 230, Loss: 297.8266131225552\n",
453
+ "Epoch 231, Loss: 297.82661242371825\n",
454
+ "Epoch 232, Loss: 297.82661172495017\n",
455
+ "Epoch 233, Loss: 297.82661102625104\n",
456
+ "Epoch 234, Loss: 297.8266103276209\n",
457
+ "Epoch 235, Loss: 297.8266096290597\n",
458
+ "Epoch 236, Loss: 297.82660893056743\n",
459
+ "Epoch 237, Loss: 297.8266082321442\n",
460
+ "Epoch 238, Loss: 297.82660753379\n",
461
+ "Epoch 239, Loss: 297.8266068355049\n",
462
+ "Epoch 240, Loss: 297.82660613728893\n",
463
+ "Epoch 241, Loss: 297.82660543914204\n",
464
+ "Epoch 242, Loss: 297.8266047410642\n",
465
+ "Epoch 243, Loss: 297.82660404305557\n",
466
+ "Epoch 244, Loss: 297.82660334511615\n",
467
+ "Epoch 245, Loss: 297.826602647246\n",
468
+ "Epoch 246, Loss: 297.826601949445\n",
469
+ "Epoch 247, Loss: 297.8266012517133\n",
470
+ "Epoch 248, Loss: 297.82660055405086\n",
471
+ "Epoch 249, Loss: 297.82659985645773\n",
472
+ "Epoch 250, Loss: 297.82659915893396\n",
473
+ "Epoch 251, Loss: 297.8265984614796\n",
474
+ "Epoch 252, Loss: 297.8265977640945\n",
475
+ "Epoch 253, Loss: 297.8265970667789\n",
476
+ "Epoch 254, Loss: 297.8265963695328\n",
477
+ "Epoch 255, Loss: 297.82659567235606\n",
478
+ "Epoch 256, Loss: 297.8265949752488\n",
479
+ "Epoch 257, Loss: 297.8265942782112\n",
480
+ "Epoch 258, Loss: 297.82659358124295\n",
481
+ "Epoch 259, Loss: 297.82659288434434\n",
482
+ "Epoch 260, Loss: 297.82659218751525\n",
483
+ "Epoch 261, Loss: 297.82659149075585\n",
484
+ "Epoch 262, Loss: 297.826590794066\n",
485
+ "Epoch 263, Loss: 297.8265900974459\n",
486
+ "Epoch 264, Loss: 297.8265894008955\n",
487
+ "Epoch 265, Loss: 297.82658870441463\n",
488
+ "Epoch 266, Loss: 297.8265880080037\n",
489
+ "Epoch 267, Loss: 297.8265873116626\n",
490
+ "Epoch 268, Loss: 297.82658661539097\n",
491
+ "Epoch 269, Loss: 297.8265859191893\n",
492
+ "Epoch 270, Loss: 297.8265852230575\n",
493
+ "Epoch 271, Loss: 297.8265845269956\n",
494
+ "Epoch 272, Loss: 297.82658383100346\n",
495
+ "Epoch 273, Loss: 297.82658313508136\n",
496
+ "Epoch 274, Loss: 297.8265824392291\n",
497
+ "Epoch 275, Loss: 297.8265817434468\n",
498
+ "Epoch 276, Loss: 297.82658104773463\n",
499
+ "Epoch 277, Loss: 297.82658035209226\n",
500
+ "Epoch 278, Loss: 297.82657965652004\n",
501
+ "Epoch 279, Loss: 297.8265789610178\n",
502
+ "Epoch 280, Loss: 297.82657826558574\n",
503
+ "Epoch 281, Loss: 297.8265775702237\n",
504
+ "Epoch 282, Loss: 297.82657687493185\n",
505
+ "Epoch 283, Loss: 297.8265761797103\n",
506
+ "Epoch 284, Loss: 297.82657548455876\n",
507
+ "Epoch 285, Loss: 297.82657478947743\n",
508
+ "Epoch 286, Loss: 297.82657409446637\n",
509
+ "Epoch 287, Loss: 297.8265733995255\n",
510
+ "Epoch 288, Loss: 297.826572704655\n",
511
+ "Epoch 289, Loss: 297.8265720098549\n",
512
+ "Epoch 290, Loss: 297.82657131512497\n",
513
+ "Epoch 291, Loss: 297.8265706204655\n",
514
+ "Epoch 292, Loss: 297.82656992587636\n",
515
+ "Epoch 293, Loss: 297.82656923135755\n",
516
+ "Epoch 294, Loss: 297.82656853690935\n",
517
+ "Epoch 295, Loss: 297.8265678425316\n",
518
+ "Epoch 296, Loss: 297.82656714822417\n",
519
+ "Epoch 297, Loss: 297.82656645398737\n",
520
+ "Epoch 298, Loss: 297.8265657598211\n",
521
+ "Epoch 299, Loss: 297.82656506572545\n",
522
+ "Epoch 300, Loss: 297.8265643717003\n",
523
+ "Epoch 301, Loss: 297.8265636777458\n",
524
+ "Epoch 302, Loss: 297.8265629838619\n",
525
+ "Epoch 303, Loss: 297.8265622900487\n",
526
+ "Epoch 304, Loss: 297.8265615963061\n",
527
+ "Epoch 305, Loss: 297.82656090263424\n",
528
+ "Epoch 306, Loss: 297.8265602090333\n",
529
+ "Epoch 307, Loss: 297.82655951550294\n",
530
+ "Epoch 308, Loss: 297.82655882204335\n",
531
+ "Epoch 309, Loss: 297.8265581286547\n",
532
+ "Epoch 310, Loss: 297.82655743533684\n",
533
+ "Epoch 311, Loss: 297.8265567420898\n",
534
+ "Epoch 312, Loss: 297.8265560489138\n",
535
+ "Epoch 313, Loss: 297.8265553558085\n",
536
+ "Epoch 314, Loss: 297.82655466277436\n",
537
+ "Epoch 315, Loss: 297.8265539698111\n",
538
+ "Epoch 316, Loss: 297.8265532769188\n",
539
+ "Epoch 317, Loss: 297.8265525840975\n",
540
+ "Epoch 318, Loss: 297.8265518913472\n",
541
+ "Epoch 319, Loss: 297.8265511986681\n",
542
+ "Epoch 320, Loss: 297.82655050606\n",
543
+ "Epoch 321, Loss: 297.8265498135231\n",
544
+ "Epoch 322, Loss: 297.8265491210572\n",
545
+ "Epoch 323, Loss: 297.82654842866265\n",
546
+ "Epoch 324, Loss: 297.8265477363392\n",
547
+ "Epoch 325, Loss: 297.826547044087\n",
548
+ "Epoch 326, Loss: 297.826546351906\n",
549
+ "Epoch 327, Loss: 297.82654565979635\n",
550
+ "Epoch 328, Loss: 297.8265449677579\n",
551
+ "Epoch 329, Loss: 297.8265442757908\n",
552
+ "Epoch 330, Loss: 297.8265435838951\n",
553
+ "Epoch 331, Loss: 297.82654289207085\n",
554
+ "Epoch 332, Loss: 297.8265422003178\n",
555
+ "Epoch 333, Loss: 297.82654150863635\n",
556
+ "Epoch 334, Loss: 297.8265408170263\n",
557
+ "Epoch 335, Loss: 297.8265401254876\n",
558
+ "Epoch 336, Loss: 297.8265394340205\n",
559
+ "Epoch 337, Loss: 297.826538742625\n",
560
+ "Epoch 338, Loss: 297.8265380513009\n",
561
+ "Epoch 339, Loss: 297.8265373600486\n",
562
+ "Epoch 340, Loss: 297.8265366688676\n",
563
+ "Epoch 341, Loss: 297.82653597775845\n",
564
+ "Epoch 342, Loss: 297.82653528672085\n",
565
+ "Epoch 343, Loss: 297.826534595755\n",
566
+ "Epoch 344, Loss: 297.82653390486087\n",
567
+ "Epoch 345, Loss: 297.8265332140384\n",
568
+ "Epoch 346, Loss: 297.8265325232878\n",
569
+ "Epoch 347, Loss: 297.8265318326088\n",
570
+ "Epoch 348, Loss: 297.8265311420018\n",
571
+ "Epoch 349, Loss: 297.8265304514666\n",
572
+ "Epoch 350, Loss: 297.82652976100314\n",
573
+ "Epoch 351, Loss: 297.82652907061174\n",
574
+ "Epoch 352, Loss: 297.82652838029213\n",
575
+ "Epoch 353, Loss: 297.8265276900445\n",
576
+ "Epoch 354, Loss: 297.82652699986875\n",
577
+ "Epoch 355, Loss: 297.82652630976514\n",
578
+ "Epoch 356, Loss: 297.82652561973345\n",
579
+ "Epoch 357, Loss: 297.82652492977377\n",
580
+ "Epoch 358, Loss: 297.82652423988617\n",
581
+ "Epoch 359, Loss: 297.82652355007065\n",
582
+ "Epoch 360, Loss: 297.8265228603273\n",
583
+ "Epoch 361, Loss: 297.8265221706561\n",
584
+ "Epoch 362, Loss: 297.82652148105706\n",
585
+ "Epoch 363, Loss: 297.8265207915302\n",
586
+ "Epoch 364, Loss: 297.8265201020755\n",
587
+ "Epoch 365, Loss: 297.8265194126932\n",
588
+ "Epoch 366, Loss: 297.82651872338306\n",
589
+ "Epoch 367, Loss: 297.8265180341453\n",
590
+ "Epoch 368, Loss: 297.8265173449798\n",
591
+ "Epoch 369, Loss: 297.8265166558866\n",
592
+ "Epoch 370, Loss: 297.8265159668659\n",
593
+ "Epoch 371, Loss: 297.82651527791745\n",
594
+ "Epoch 372, Loss: 297.8265145890415\n",
595
+ "Epoch 373, Loss: 297.82651390023807\n",
596
+ "Epoch 374, Loss: 297.82651321150695\n",
597
+ "Epoch 375, Loss: 297.82651252284853\n",
598
+ "Epoch 376, Loss: 297.8265118342626\n",
599
+ "Epoch 377, Loss: 297.8265111457491\n",
600
+ "Epoch 378, Loss: 297.82651045730825\n",
601
+ "Epoch 379, Loss: 297.8265097689401\n",
602
+ "Epoch 380, Loss: 297.8265090806445\n",
603
+ "Epoch 381, Loss: 297.8265083924216\n",
604
+ "Epoch 382, Loss: 297.82650770427136\n",
605
+ "Epoch 383, Loss: 297.82650701619394\n",
606
+ "Epoch 384, Loss: 297.82650632818905\n",
607
+ "Epoch 385, Loss: 297.8265056402571\n",
608
+ "Epoch 386, Loss: 297.8265049523977\n",
609
+ "Epoch 387, Loss: 297.82650426461134\n",
610
+ "Epoch 388, Loss: 297.82650357689784\n",
611
+ "Epoch 389, Loss: 297.82650288925714\n",
612
+ "Epoch 390, Loss: 297.82650220168927\n",
613
+ "Epoch 391, Loss: 297.82650151419443\n",
614
+ "Epoch 392, Loss: 297.8265008267725\n",
615
+ "Epoch 393, Loss: 297.8265001394235\n",
616
+ "Epoch 394, Loss: 297.8264994521476\n",
617
+ "Epoch 395, Loss: 297.8264987649447\n",
618
+ "Epoch 396, Loss: 297.82649807781473\n",
619
+ "Epoch 397, Loss: 297.82649739075794\n",
620
+ "Epoch 398, Loss: 297.82649670377424\n",
621
+ "Epoch 399, Loss: 297.82649601686364\n",
622
+ "Epoch 400, Loss: 297.8264953300262\n",
623
+ "Epoch 401, Loss: 297.826494643262\n",
624
+ "Epoch 402, Loss: 297.82649395657097\n",
625
+ "Epoch 403, Loss: 297.8264932699532\n",
626
+ "Epoch 404, Loss: 297.8264925834086\n",
627
+ "Epoch 405, Loss: 297.8264918969374\n",
628
+ "Epoch 406, Loss: 297.8264912105394\n",
629
+ "Epoch 407, Loss: 297.8264905242148\n",
630
+ "Epoch 408, Loss: 297.82648983796355\n",
631
+ "Epoch 409, Loss: 297.8264891517858\n",
632
+ "Epoch 410, Loss: 297.82648846568134\n",
633
+ "Epoch 411, Loss: 297.8264877796504\n",
634
+ "Epoch 412, Loss: 297.8264870936928\n",
635
+ "Epoch 413, Loss: 297.82648640780883\n",
636
+ "Epoch 414, Loss: 297.8264857219984\n",
637
+ "Epoch 415, Loss: 297.8264850362614\n",
638
+ "Epoch 416, Loss: 297.82648435059804\n",
639
+ "Epoch 417, Loss: 297.8264836650082\n",
640
+ "Epoch 418, Loss: 297.8264829794921\n",
641
+ "Epoch 419, Loss: 297.82648229404964\n",
642
+ "Epoch 420, Loss: 297.8264816086808\n",
643
+ "Epoch 421, Loss: 297.8264809233857\n",
644
+ "Epoch 422, Loss: 297.82648023816444\n",
645
+ "Epoch 423, Loss: 297.8264795530168\n",
646
+ "Epoch 424, Loss: 297.82647886794297\n",
647
+ "Epoch 425, Loss: 297.82647818294294\n",
648
+ "Epoch 426, Loss: 297.82647749801686\n",
649
+ "Epoch 427, Loss: 297.8264768131645\n",
650
+ "Epoch 428, Loss: 297.8264761283861\n",
651
+ "Epoch 429, Loss: 297.82647544368155\n",
652
+ "Epoch 430, Loss: 297.82647475905094\n",
653
+ "Epoch 431, Loss: 297.82647407449446\n",
654
+ "Epoch 432, Loss: 297.8264733900118\n",
655
+ "Epoch 433, Loss: 297.82647270560335\n",
656
+ "Epoch 434, Loss: 297.8264720212687\n",
657
+ "Epoch 435, Loss: 297.82647133700834\n",
658
+ "Epoch 436, Loss: 297.8264706528221\n",
659
+ "Epoch 437, Loss: 297.82646996870983\n",
660
+ "Epoch 438, Loss: 297.8264692846718\n",
661
+ "Epoch 439, Loss: 297.8264686007079\n",
662
+ "Epoch 440, Loss: 297.8264679168184\n",
663
+ "Epoch 441, Loss: 297.8264672330029\n",
664
+ "Epoch 442, Loss: 297.8264665492618\n",
665
+ "Epoch 443, Loss: 297.82646586559486\n",
666
+ "Epoch 444, Loss: 297.8264651820023\n",
667
+ "Epoch 445, Loss: 297.82646449848403\n",
668
+ "Epoch 446, Loss: 297.8264638150402\n",
669
+ "Epoch 447, Loss: 297.8264631316708\n",
670
+ "Epoch 448, Loss: 297.8264624483758\n",
671
+ "Epoch 449, Loss: 297.8264617651551\n",
672
+ "Epoch 450, Loss: 297.8264610820091\n",
673
+ "Epoch 451, Loss: 297.82646039893746\n",
674
+ "Epoch 452, Loss: 297.82645971594036\n",
675
+ "Epoch 453, Loss: 297.8264590330179\n",
676
+ "Epoch 454, Loss: 297.8264583501699\n",
677
+ "Epoch 455, Loss: 297.82645766739654\n",
678
+ "Epoch 456, Loss: 297.82645698469787\n",
679
+ "Epoch 457, Loss: 297.8264563020737\n",
680
+ "Epoch 458, Loss: 297.82645561952444\n",
681
+ "Epoch 459, Loss: 297.8264549370498\n",
682
+ "Epoch 460, Loss: 297.82645425464983\n",
683
+ "Epoch 461, Loss: 297.8264535723248\n",
684
+ "Epoch 462, Loss: 297.82645289007445\n",
685
+ "Epoch 463, Loss: 297.8264522078989\n",
686
+ "Epoch 464, Loss: 297.8264515257983\n",
687
+ "Epoch 465, Loss: 297.8264508437724\n",
688
+ "Epoch 466, Loss: 297.8264501618215\n",
689
+ "Epoch 467, Loss: 297.82644947994555\n",
690
+ "Epoch 468, Loss: 297.82644879814444\n",
691
+ "Epoch 469, Loss: 297.8264481164186\n",
692
+ "Epoch 470, Loss: 297.82644743476743\n",
693
+ "Epoch 471, Loss: 297.82644675319153\n",
694
+ "Epoch 472, Loss: 297.82644607169055\n",
695
+ "Epoch 473, Loss: 297.8264453902647\n",
696
+ "Epoch 474, Loss: 297.8264447089139\n",
697
+ "Epoch 475, Loss: 297.82644402763833\n",
698
+ "Epoch 476, Loss: 297.82644334643805\n",
699
+ "Epoch 477, Loss: 297.82644266531275\n",
700
+ "Epoch 478, Loss: 297.8264419842628\n",
701
+ "Epoch 479, Loss: 297.82644130328805\n",
702
+ "Epoch 480, Loss: 297.8264406223886\n",
703
+ "Epoch 481, Loss: 297.8264399415644\n",
704
+ "Epoch 482, Loss: 297.8264392608157\n",
705
+ "Epoch 483, Loss: 297.8264385801421\n",
706
+ "Epoch 484, Loss: 297.82643789954403\n",
707
+ "Epoch 485, Loss: 297.82643721902133\n",
708
+ "Epoch 486, Loss: 297.8264365385741\n",
709
+ "Epoch 487, Loss: 297.8264358582022\n",
710
+ "Epoch 488, Loss: 297.82643517790603\n",
711
+ "Epoch 489, Loss: 297.8264344976852\n",
712
+ "Epoch 490, Loss: 297.82643381753996\n",
713
+ "Epoch 491, Loss: 297.8264331374703\n",
714
+ "Epoch 492, Loss: 297.8264324574763\n",
715
+ "Epoch 493, Loss: 297.82643177755784\n",
716
+ "Epoch 494, Loss: 297.82643109771504\n",
717
+ "Epoch 495, Loss: 297.82643041794796\n",
718
+ "Epoch 496, Loss: 297.8264297382566\n",
719
+ "Epoch 497, Loss: 297.82642905864094\n",
720
+ "Epoch 498, Loss: 297.82642837910106\n",
721
+ "Epoch 499, Loss: 297.82642769963695\n",
722
+ "Epoch 500, Loss: 297.8264270202487\n",
723
+ "Epoch 501, Loss: 297.8264263409362\n",
724
+ "Epoch 502, Loss: 297.82642566169966\n",
725
+ "Epoch 503, Loss: 297.826424982539\n",
726
+ "Epoch 504, Loss: 297.8264243034543\n",
727
+ "Epoch 505, Loss: 297.82642362444545\n",
728
+ "Epoch 506, Loss: 297.8264229455126\n",
729
+ "Epoch 507, Loss: 297.8264222666558\n",
730
+ "Epoch 508, Loss: 297.82642158787496\n",
731
+ "Epoch 509, Loss: 297.82642090917034\n",
732
+ "Epoch 510, Loss: 297.8264202305416\n",
733
+ "Epoch 511, Loss: 297.82641955198915\n",
734
+ "Epoch 512, Loss: 297.82641887351275\n",
735
+ "Epoch 513, Loss: 297.8264181951125\n",
736
+ "Epoch 514, Loss: 297.8264175167885\n",
737
+ "Epoch 515, Loss: 297.82641683854075\n",
738
+ "Epoch 516, Loss: 297.82641616036915\n",
739
+ "Epoch 517, Loss: 297.82641548227383\n",
740
+ "Epoch 518, Loss: 297.82641480425497\n",
741
+ "Epoch 519, Loss: 297.8264141263123\n",
742
+ "Epoch 520, Loss: 297.82641344844603\n",
743
+ "Epoch 521, Loss: 297.8264127706562\n",
744
+ "Epoch 522, Loss: 297.82641209294275\n",
745
+ "Epoch 523, Loss: 297.82641141530576\n",
746
+ "Epoch 524, Loss: 297.8264107377451\n",
747
+ "Epoch 525, Loss: 297.826410060261\n",
748
+ "Epoch 526, Loss: 297.82640938285346\n",
749
+ "Epoch 527, Loss: 297.8264087055225\n",
750
+ "Epoch 528, Loss: 297.826408028268\n",
751
+ "Epoch 529, Loss: 297.82640735109027\n",
752
+ "Epoch 530, Loss: 297.8264066739891\n",
753
+ "Epoch 531, Loss: 297.8264059969646\n",
754
+ "Epoch 532, Loss: 297.8264053200167\n",
755
+ "Epoch 533, Loss: 297.8264046431456\n",
756
+ "Epoch 534, Loss: 297.82640396635117\n",
757
+ "Epoch 535, Loss: 297.8264032896337\n",
758
+ "Epoch 536, Loss: 297.82640261299275\n",
759
+ "Epoch 537, Loss: 297.8264019364288\n",
760
+ "Epoch 538, Loss: 297.82640125994163\n",
761
+ "Epoch 539, Loss: 297.8264005835315\n",
762
+ "Epoch 540, Loss: 297.8263999071981\n",
763
+ "Epoch 541, Loss: 297.82639923094166\n",
764
+ "Epoch 542, Loss: 297.8263985547622\n",
765
+ "Epoch 543, Loss: 297.82639787865975\n",
766
+ "Epoch 544, Loss: 297.82639720263427\n",
767
+ "Epoch 545, Loss: 297.8263965266858\n",
768
+ "Epoch 546, Loss: 297.82639585081455\n",
769
+ "Epoch 547, Loss: 297.82639517502025\n",
770
+ "Epoch 548, Loss: 297.82639449930315\n",
771
+ "Epoch 549, Loss: 297.8263938236632\n",
772
+ "Epoch 550, Loss: 297.8263931481004\n",
773
+ "Epoch 551, Loss: 297.8263924726149\n",
774
+ "Epoch 552, Loss: 297.8263917972065\n",
775
+ "Epoch 553, Loss: 297.8263911218754\n",
776
+ "Epoch 554, Loss: 297.82639044662164\n",
777
+ "Epoch 555, Loss: 297.8263897714451\n",
778
+ "Epoch 556, Loss: 297.8263890963461\n",
779
+ "Epoch 557, Loss: 297.82638842132434\n",
780
+ "Epoch 558, Loss: 297.82638774638\n",
781
+ "Epoch 559, Loss: 297.82638707151307\n",
782
+ "Epoch 560, Loss: 297.82638639672365\n",
783
+ "Epoch 561, Loss: 297.8263857220117\n",
784
+ "Epoch 562, Loss: 297.8263850473772\n",
785
+ "Epoch 563, Loss: 297.8263843728203\n",
786
+ "Epoch 564, Loss: 297.826383698341\n",
787
+ "Epoch 565, Loss: 297.82638302393923\n",
788
+ "Epoch 566, Loss: 297.82638234961513\n",
789
+ "Epoch 567, Loss: 297.8263816753686\n",
790
+ "Epoch 568, Loss: 297.82638100119976\n",
791
+ "Epoch 569, Loss: 297.82638032710867\n",
792
+ "Epoch 570, Loss: 297.82637965309533\n",
793
+ "Epoch 571, Loss: 297.82637897915964\n",
794
+ "Epoch 572, Loss: 297.8263783053019\n",
795
+ "Epoch 573, Loss: 297.8263776315219\n",
796
+ "Epoch 574, Loss: 297.82637695781983\n",
797
+ "Epoch 575, Loss: 297.8263762841955\n",
798
+ "Epoch 576, Loss: 297.8263756106491\n",
799
+ "Epoch 577, Loss: 297.8263749371807\n",
800
+ "Epoch 578, Loss: 297.8263742637902\n",
801
+ "Epoch 579, Loss: 297.82637359047766\n",
802
+ "Epoch 580, Loss: 297.8263729172433\n",
803
+ "Epoch 581, Loss: 297.82637224408677\n",
804
+ "Epoch 582, Loss: 297.8263715710084\n",
805
+ "Epoch 583, Loss: 297.826370898008\n",
806
+ "Epoch 584, Loss: 297.8263702250859\n",
807
+ "Epoch 585, Loss: 297.82636955224183\n",
808
+ "Epoch 586, Loss: 297.82636887947604\n",
809
+ "Epoch 587, Loss: 297.8263682067884\n",
810
+ "Epoch 588, Loss: 297.826367534179\n",
811
+ "Epoch 589, Loss: 297.82636686164784\n",
812
+ "Epoch 590, Loss: 297.8263661891951\n",
813
+ "Epoch 591, Loss: 297.82636551682054\n",
814
+ "Epoch 592, Loss: 297.82636484452433\n",
815
+ "Epoch 593, Loss: 297.8263641723065\n",
816
+ "Epoch 594, Loss: 297.8263635001672\n",
817
+ "Epoch 595, Loss: 297.82636282810614\n",
818
+ "Epoch 596, Loss: 297.8263621561237\n",
819
+ "Epoch 597, Loss: 297.8263614842196\n",
820
+ "Epoch 598, Loss: 297.82636081239417\n",
821
+ "Epoch 599, Loss: 297.8263601406472\n",
822
+ "Epoch 600, Loss: 297.8263594689787\n",
823
+ "Epoch 601, Loss: 297.826358797389\n",
824
+ "Epoch 602, Loss: 297.82635812587785\n",
825
+ "Epoch 603, Loss: 297.82635745444526\n",
826
+ "Epoch 604, Loss: 297.82635678309146\n",
827
+ "Epoch 605, Loss: 297.8263561118164\n",
828
+ "Epoch 606, Loss: 297.82635544062\n",
829
+ "Epoch 607, Loss: 297.8263547695023\n",
830
+ "Epoch 608, Loss: 297.82635409846347\n",
831
+ "Epoch 609, Loss: 297.8263534275034\n",
832
+ "Epoch 610, Loss: 297.8263527566223\n",
833
+ "Epoch 611, Loss: 297.8263520858201\n",
834
+ "Epoch 612, Loss: 297.82635141509684\n",
835
+ "Epoch 613, Loss: 297.8263507444523\n",
836
+ "Epoch 614, Loss: 297.8263500738869\n",
837
+ "Epoch 615, Loss: 297.8263494034004\n",
838
+ "Epoch 616, Loss: 297.8263487329929\n",
839
+ "Epoch 617, Loss: 297.8263480626645\n",
840
+ "Epoch 618, Loss: 297.8263473924152\n",
841
+ "Epoch 619, Loss: 297.82634672224503\n",
842
+ "Epoch 620, Loss: 297.8263460521538\n",
843
+ "Epoch 621, Loss: 297.82634538214194\n",
844
+ "Epoch 622, Loss: 297.8263447122093\n",
845
+ "Epoch 623, Loss: 297.82634404235574\n",
846
+ "Epoch 624, Loss: 297.8263433725814\n",
847
+ "Epoch 625, Loss: 297.8263427028865\n",
848
+ "Epoch 626, Loss: 297.82634203327075\n",
849
+ "Epoch 627, Loss: 297.8263413637344\n",
850
+ "Epoch 628, Loss: 297.82634069427735\n",
851
+ "Epoch 629, Loss: 297.82634002489976\n",
852
+ "Epoch 630, Loss: 297.82633935560165\n",
853
+ "Epoch 631, Loss: 297.8263386863829\n",
854
+ "Epoch 632, Loss: 297.8263380172436\n",
855
+ "Epoch 633, Loss: 297.82633734818376\n",
856
+ "Epoch 634, Loss: 297.8263366792035\n",
857
+ "Epoch 635, Loss: 297.8263360103027\n",
858
+ "Epoch 636, Loss: 297.8263353414817\n",
859
+ "Epoch 637, Loss: 297.82633467274024\n",
860
+ "Epoch 638, Loss: 297.8263340040784\n",
861
+ "Epoch 639, Loss: 297.8263333354962\n",
862
+ "Epoch 640, Loss: 297.8263326669936\n",
863
+ "Epoch 641, Loss: 297.826331998571\n",
864
+ "Epoch 642, Loss: 297.82633133022796\n",
865
+ "Epoch 643, Loss: 297.82633066196473\n",
866
+ "Epoch 644, Loss: 297.82632999378137\n",
867
+ "Epoch 645, Loss: 297.8263293256778\n",
868
+ "Epoch 646, Loss: 297.8263286576541\n",
869
+ "Epoch 647, Loss: 297.8263279897103\n",
870
+ "Epoch 648, Loss: 297.82632732184646\n",
871
+ "Epoch 649, Loss: 297.8263266540626\n",
872
+ "Epoch 650, Loss: 297.8263259863587\n",
873
+ "Epoch 651, Loss: 297.8263253187347\n",
874
+ "Epoch 652, Loss: 297.82632465119093\n",
875
+ "Epoch 653, Loss: 297.82632398372704\n",
876
+ "Epoch 654, Loss: 297.8263233163434\n",
877
+ "Epoch 655, Loss: 297.8263226490398\n",
878
+ "Epoch 656, Loss: 297.82632198181636\n",
879
+ "Epoch 657, Loss: 297.82632131467324\n",
880
+ "Epoch 658, Loss: 297.82632064761026\n",
881
+ "Epoch 659, Loss: 297.82631998062743\n",
882
+ "Epoch 660, Loss: 297.8263193137249\n",
883
+ "Epoch 661, Loss: 297.8263186469027\n",
884
+ "Epoch 662, Loss: 297.82631798016087\n",
885
+ "Epoch 663, Loss: 297.8263173134994\n",
886
+ "Epoch 664, Loss: 297.82631664691814\n",
887
+ "Epoch 665, Loss: 297.82631598041746\n",
888
+ "Epoch 666, Loss: 297.8263153139972\n",
889
+ "Epoch 667, Loss: 297.82631464765734\n",
890
+ "Epoch 668, Loss: 297.8263139813981\n",
891
+ "Epoch 669, Loss: 297.82631331521924\n",
892
+ "Epoch 670, Loss: 297.82631264912106\n",
893
+ "Epoch 671, Loss: 297.82631198310344\n",
894
+ "Epoch 672, Loss: 297.8263113171664\n",
895
+ "Epoch 673, Loss: 297.82631065131005\n",
896
+ "Epoch 674, Loss: 297.8263099855343\n",
897
+ "Epoch 675, Loss: 297.82630931983937\n",
898
+ "Epoch 676, Loss: 297.82630865422504\n",
899
+ "Epoch 677, Loss: 297.8263079886915\n",
900
+ "Epoch 678, Loss: 297.8263073232387\n",
901
+ "Epoch 679, Loss: 297.8263066578669\n",
902
+ "Epoch 680, Loss: 297.82630599257584\n",
903
+ "Epoch 681, Loss: 297.8263053273656\n",
904
+ "Epoch 682, Loss: 297.82630466223634\n",
905
+ "Epoch 683, Loss: 297.8263039971879\n",
906
+ "Epoch 684, Loss: 297.82630333222056\n",
907
+ "Epoch 685, Loss: 297.8263026673342\n",
908
+ "Epoch 686, Loss: 297.8263020025288\n",
909
+ "Epoch 687, Loss: 297.82630133780435\n",
910
+ "Epoch 688, Loss: 297.82630067316114\n",
911
+ "Epoch 689, Loss: 297.826300008599\n",
912
+ "Epoch 690, Loss: 297.8262993441179\n",
913
+ "Epoch 691, Loss: 297.8262986797181\n",
914
+ "Epoch 692, Loss: 297.8262980153994\n",
915
+ "Epoch 693, Loss: 297.82629735116194\n",
916
+ "Epoch 694, Loss: 297.82629668700577\n",
917
+ "Epoch 695, Loss: 297.8262960229308\n",
918
+ "Epoch 696, Loss: 297.8262953589371\n",
919
+ "Epoch 697, Loss: 297.82629469502484\n",
920
+ "Epoch 698, Loss: 297.826294031194\n",
921
+ "Epoch 699, Loss: 297.8262933674444\n",
922
+ "Epoch 700, Loss: 297.8262927037763\n",
923
+ "Epoch 701, Loss: 297.82629204018974\n",
924
+ "Epoch 702, Loss: 297.8262913766845\n",
925
+ "Epoch 703, Loss: 297.82629071326073\n",
926
+ "Epoch 704, Loss: 297.82629004991867\n",
927
+ "Epoch 705, Loss: 297.82628938665823\n",
928
+ "Epoch 706, Loss: 297.8262887234792\n",
929
+ "Epoch 707, Loss: 297.8262880603819\n",
930
+ "Epoch 708, Loss: 297.82628739736623\n",
931
+ "Epoch 709, Loss: 297.8262867344323\n",
932
+ "Epoch 710, Loss: 297.82628607158\n",
933
+ "Epoch 711, Loss: 297.8262854088095\n",
934
+ "Epoch 712, Loss: 297.8262847461207\n",
935
+ "Epoch 713, Loss: 297.82628408351377\n",
936
+ "Epoch 714, Loss: 297.8262834209886\n",
937
+ "Epoch 715, Loss: 297.8262827585453\n",
938
+ "Epoch 716, Loss: 297.8262820961839\n",
939
+ "Epoch 717, Loss: 297.8262814339045\n",
940
+ "Epoch 718, Loss: 297.826280771707\n",
941
+ "Epoch 719, Loss: 297.8262801095915\n",
942
+ "Epoch 720, Loss: 297.82627944755797\n",
943
+ "Epoch 721, Loss: 297.8262787856065\n",
944
+ "Epoch 722, Loss: 297.82627812373715\n",
945
+ "Epoch 723, Loss: 297.8262774619497\n",
946
+ "Epoch 724, Loss: 297.82627680024444\n",
947
+ "Epoch 725, Loss: 297.8262761386215\n",
948
+ "Epoch 726, Loss: 297.8262754770806\n",
949
+ "Epoch 727, Loss: 297.826274815622\n",
950
+ "Epoch 728, Loss: 297.82627415424554\n",
951
+ "Epoch 729, Loss: 297.82627349295143\n",
952
+ "Epoch 730, Loss: 297.8262728317395\n",
953
+ "Epoch 731, Loss: 297.82627217061\n",
954
+ "Epoch 732, Loss: 297.82627150956284\n",
955
+ "Epoch 733, Loss: 297.8262708485981\n",
956
+ "Epoch 734, Loss: 297.82627018771575\n",
957
+ "Epoch 735, Loss: 297.8262695269158\n",
958
+ "Epoch 736, Loss: 297.8262688661983\n",
959
+ "Epoch 737, Loss: 297.82626820556334\n",
960
+ "Epoch 738, Loss: 297.826267545011\n",
961
+ "Epoch 739, Loss: 297.8262668845411\n",
962
+ "Epoch 740, Loss: 297.82626622415387\n",
963
+ "Epoch 741, Loss: 297.82626556384935\n",
964
+ "Epoch 742, Loss: 297.82626490362736\n",
965
+ "Epoch 743, Loss: 297.826264243488\n",
966
+ "Epoch 744, Loss: 297.82626358343146\n",
967
+ "Epoch 745, Loss: 297.82626292345765\n",
968
+ "Epoch 746, Loss: 297.82626226356655\n",
969
+ "Epoch 747, Loss: 297.8262616037582\n",
970
+ "Epoch 748, Loss: 297.8262609440329\n",
971
+ "Epoch 749, Loss: 297.8262602843903\n",
972
+ "Epoch 750, Loss: 297.82625962483047\n",
973
+ "Epoch 751, Loss: 297.82625896535376\n",
974
+ "Epoch 752, Loss: 297.82625830595987\n",
975
+ "Epoch 753, Loss: 297.8262576466491\n",
976
+ "Epoch 754, Loss: 297.82625698742123\n",
977
+ "Epoch 755, Loss: 297.82625632827643\n",
978
+ "Epoch 756, Loss: 297.8262556692147\n",
979
+ "Epoch 757, Loss: 297.82625501023597\n",
980
+ "Epoch 758, Loss: 297.8262543513405\n",
981
+ "Epoch 759, Loss: 297.8262536925281\n",
982
+ "Epoch 760, Loss: 297.82625303379893\n",
983
+ "Epoch 761, Loss: 297.8262523751529\n",
984
+ "Epoch 762, Loss: 297.82625171659015\n",
985
+ "Epoch 763, Loss: 297.82625105811076\n",
986
+ "Epoch 764, Loss: 297.8262503997146\n",
987
+ "Epoch 765, Loss: 297.82624974140174\n",
988
+ "Epoch 766, Loss: 297.82624908317234\n",
989
+ "Epoch 767, Loss: 297.8262484250262\n",
990
+ "Epoch 768, Loss: 297.82624776696355\n",
991
+ "Epoch 769, Loss: 297.8262471089844\n",
992
+ "Epoch 770, Loss: 297.82624645108865\n",
993
+ "Epoch 771, Loss: 297.8262457932765\n",
994
+ "Epoch 772, Loss: 297.82624513554777\n",
995
+ "Epoch 773, Loss: 297.8262444779027\n",
996
+ "Epoch 774, Loss: 297.8262438203412\n",
997
+ "Epoch 775, Loss: 297.8262431628633\n",
998
+ "Epoch 776, Loss: 297.8262425054691\n",
999
+ "Epoch 777, Loss: 297.82624184815865\n",
1000
+ "Epoch 778, Loss: 297.82624119093174\n",
1001
+ "Epoch 779, Loss: 297.82624053378873\n",
1002
+ "Epoch 780, Loss: 297.82623987672946\n",
1003
+ "Epoch 781, Loss: 297.826239219754\n",
1004
+ "Epoch 782, Loss: 297.8262385628624\n",
1005
+ "Epoch 783, Loss: 297.82623790605464\n",
1006
+ "Epoch 784, Loss: 297.82623724933075\n",
1007
+ "Epoch 785, Loss: 297.82623659269086\n",
1008
+ "Epoch 786, Loss: 297.8262359361348\n",
1009
+ "Epoch 787, Loss: 297.8262352796629\n",
1010
+ "Epoch 788, Loss: 297.8262346232749\n",
1011
+ "Epoch 789, Loss: 297.8262339669709\n",
1012
+ "Epoch 790, Loss: 297.8262333107511\n",
1013
+ "Epoch 791, Loss: 297.82623265461535\n",
1014
+ "Epoch 792, Loss: 297.8262319985638\n",
1015
+ "Epoch 793, Loss: 297.8262313425964\n",
1016
+ "Epoch 794, Loss: 297.8262306867131\n",
1017
+ "Epoch 795, Loss: 297.8262300309141\n",
1018
+ "Epoch 796, Loss: 297.82622937519943\n",
1019
+ "Epoch 797, Loss: 297.826228719569\n",
1020
+ "Epoch 798, Loss: 297.82622806402287\n",
1021
+ "Epoch 799, Loss: 297.82622740856107\n",
1022
+ "Epoch 800, Loss: 297.8262267531837\n",
1023
+ "Epoch 801, Loss: 297.82622609789064\n",
1024
+ "Epoch 802, Loss: 297.826225442682\n",
1025
+ "Epoch 803, Loss: 297.826224787558\n",
1026
+ "Epoch 804, Loss: 297.8262241325183\n",
1027
+ "Epoch 805, Loss: 297.8262234775633\n",
1028
+ "Epoch 806, Loss: 297.8262228226928\n",
1029
+ "Epoch 807, Loss: 297.82622216790685\n",
1030
+ "Epoch 808, Loss: 297.8262215132056\n",
1031
+ "Epoch 809, Loss: 297.8262208585888\n",
1032
+ "Epoch 810, Loss: 297.8262202040569\n",
1033
+ "Epoch 811, Loss: 297.8262195496096\n",
1034
+ "Epoch 812, Loss: 297.82621889524705\n",
1035
+ "Epoch 813, Loss: 297.82621824096935\n",
1036
+ "Epoch 814, Loss: 297.8262175867764\n",
1037
+ "Epoch 815, Loss: 297.8262169326682\n",
1038
+ "Epoch 816, Loss: 297.82621627864495\n",
1039
+ "Epoch 817, Loss: 297.82621562470666\n",
1040
+ "Epoch 818, Loss: 297.8262149708532\n",
1041
+ "Epoch 819, Loss: 297.82621431708463\n",
1042
+ "Epoch 820, Loss: 297.8262136634011\n",
1043
+ "Epoch 821, Loss: 297.82621300980264\n",
1044
+ "Epoch 822, Loss: 297.82621235628915\n",
1045
+ "Epoch 823, Loss: 297.82621170286075\n",
1046
+ "Epoch 824, Loss: 297.8262110495175\n",
1047
+ "Epoch 825, Loss: 297.8262103962594\n",
1048
+ "Epoch 826, Loss: 297.8262097430863\n",
1049
+ "Epoch 827, Loss: 297.8262090899985\n",
1050
+ "Epoch 828, Loss: 297.82620843699596\n",
1051
+ "Epoch 829, Loss: 297.8262077840786\n",
1052
+ "Epoch 830, Loss: 297.8262071312466\n",
1053
+ "Epoch 831, Loss: 297.82620647849984\n",
1054
+ "Epoch 832, Loss: 297.8262058258384\n",
1055
+ "Epoch 833, Loss: 297.82620517326245\n",
1056
+ "Epoch 834, Loss: 297.8262045207719\n",
1057
+ "Epoch 835, Loss: 297.82620386836675\n",
1058
+ "Epoch 836, Loss: 297.82620321604696\n",
1059
+ "Epoch 837, Loss: 297.8262025638128\n",
1060
+ "Epoch 838, Loss: 297.82620191166416\n",
1061
+ "Epoch 839, Loss: 297.8262012596011\n",
1062
+ "Epoch 840, Loss: 297.82620060762355\n",
1063
+ "Epoch 841, Loss: 297.8261999557316\n",
1064
+ "Epoch 842, Loss: 297.8261993039254\n",
1065
+ "Epoch 843, Loss: 297.8261986522048\n",
1066
+ "Epoch 844, Loss: 297.82619800056995\n",
1067
+ "Epoch 845, Loss: 297.8261973490208\n",
1068
+ "Epoch 846, Loss: 297.82619669755746\n",
1069
+ "Epoch 847, Loss: 297.8261960461799\n",
1070
+ "Epoch 848, Loss: 297.82619539488826\n",
1071
+ "Epoch 849, Loss: 297.82619474368244\n",
1072
+ "Epoch 850, Loss: 297.82619409256245\n",
1073
+ "Epoch 851, Loss: 297.8261934415284\n",
1074
+ "Epoch 852, Loss: 297.82619279058036\n",
1075
+ "Epoch 853, Loss: 297.82619213971833\n",
1076
+ "Epoch 854, Loss: 297.8261914889422\n",
1077
+ "Epoch 855, Loss: 297.82619083825216\n",
1078
+ "Epoch 856, Loss: 297.82619018764836\n",
1079
+ "Epoch 857, Loss: 297.8261895371304\n",
1080
+ "Epoch 858, Loss: 297.82618888669873\n",
1081
+ "Epoch 859, Loss: 297.8261882363531\n",
1082
+ "Epoch 860, Loss: 297.8261875860939\n",
1083
+ "Epoch 861, Loss: 297.8261869359208\n",
1084
+ "Epoch 862, Loss: 297.826186285834\n",
1085
+ "Epoch 863, Loss: 297.82618563583344\n",
1086
+ "Epoch 864, Loss: 297.8261849859192\n",
1087
+ "Epoch 865, Loss: 297.8261843360914\n",
1088
+ "Epoch 866, Loss: 297.82618368634996\n",
1089
+ "Epoch 867, Loss: 297.8261830366949\n",
1090
+ "Epoch 868, Loss: 297.82618238712627\n",
1091
+ "Epoch 869, Loss: 297.82618173764416\n",
1092
+ "Epoch 870, Loss: 297.8261810882486\n",
1093
+ "Epoch 871, Loss: 297.82618043893956\n",
1094
+ "Epoch 872, Loss: 297.826179789717\n",
1095
+ "Epoch 873, Loss: 297.8261791405811\n",
1096
+ "Epoch 874, Loss: 297.82617849153183\n",
1097
+ "Epoch 875, Loss: 297.82617784256917\n",
1098
+ "Epoch 876, Loss: 297.8261771936933\n",
1099
+ "Epoch 877, Loss: 297.8261765449042\n",
1100
+ "Epoch 878, Loss: 297.8261758962016\n",
1101
+ "Epoch 879, Loss: 297.826175247586\n",
1102
+ "Epoch 880, Loss: 297.82617459905725\n",
1103
+ "Epoch 881, Loss: 297.82617395061516\n",
1104
+ "Epoch 882, Loss: 297.8261733022601\n",
1105
+ "Epoch 883, Loss: 297.82617265399193\n",
1106
+ "Epoch 884, Loss: 297.82617200581075\n",
1107
+ "Epoch 885, Loss: 297.8261713577164\n",
1108
+ "Epoch 886, Loss: 297.8261707097091\n",
1109
+ "Epoch 887, Loss: 297.82617006178884\n",
1110
+ "Epoch 888, Loss: 297.8261694139557\n",
1111
+ "Epoch 889, Loss: 297.82616876620966\n",
1112
+ "Epoch 890, Loss: 297.82616811855064\n",
1113
+ "Epoch 891, Loss: 297.8261674709788\n",
1114
+ "Epoch 892, Loss: 297.82616682349425\n",
1115
+ "Epoch 893, Loss: 297.8261661760969\n",
1116
+ "Epoch 894, Loss: 297.82616552878676\n",
1117
+ "Epoch 895, Loss: 297.82616488156384\n",
1118
+ "Epoch 896, Loss: 297.8261642344284\n",
1119
+ "Epoch 897, Loss: 297.8261635873801\n",
1120
+ "Epoch 898, Loss: 297.82616294041935\n",
1121
+ "Epoch 899, Loss: 297.8261622935459\n",
1122
+ "Epoch 900, Loss: 297.82616164676\n",
1123
+ "Epoch 901, Loss: 297.8261610000614\n",
1124
+ "Epoch 902, Loss: 297.8261603534504\n",
1125
+ "Epoch 903, Loss: 297.82615970692694\n",
1126
+ "Epoch 904, Loss: 297.826159060491\n",
1127
+ "Epoch 905, Loss: 297.82615841414264\n",
1128
+ "Epoch 906, Loss: 297.82615776788197\n",
1129
+ "Epoch 907, Loss: 297.826157121709\n",
1130
+ "Epoch 908, Loss: 297.82615647562363\n",
1131
+ "Epoch 909, Loss: 297.8261558296259\n",
1132
+ "Epoch 910, Loss: 297.8261551837161\n",
1133
+ "Epoch 911, Loss: 297.826154537894\n",
1134
+ "Epoch 912, Loss: 297.82615389215977\n",
1135
+ "Epoch 913, Loss: 297.82615324651323\n",
1136
+ "Epoch 914, Loss: 297.8261526009547\n",
1137
+ "Epoch 915, Loss: 297.826151955484\n",
1138
+ "Epoch 916, Loss: 297.8261513101013\n",
1139
+ "Epoch 917, Loss: 297.8261506648065\n",
1140
+ "Epoch 918, Loss: 297.8261500195997\n",
1141
+ "Epoch 919, Loss: 297.826149374481\n",
1142
+ "Epoch 920, Loss: 297.82614872945044\n",
1143
+ "Epoch 921, Loss: 297.8261480845078\n",
1144
+ "Epoch 922, Loss: 297.8261474396533\n",
1145
+ "Epoch 923, Loss: 297.8261467948871\n",
1146
+ "Epoch 924, Loss: 297.826146150209\n",
1147
+ "Epoch 925, Loss: 297.82614550561914\n",
1148
+ "Epoch 926, Loss: 297.8261448611175\n",
1149
+ "Epoch 927, Loss: 297.82614421670417\n",
1150
+ "Epoch 928, Loss: 297.8261435723791\n",
1151
+ "Epoch 929, Loss: 297.8261429281424\n",
1152
+ "Epoch 930, Loss: 297.8261422839941\n",
1153
+ "Epoch 931, Loss: 297.82614163993424\n",
1154
+ "Epoch 932, Loss: 297.8261409959628\n",
1155
+ "Epoch 933, Loss: 297.82614035207973\n",
1156
+ "Epoch 934, Loss: 297.82613970828527\n",
1157
+ "Epoch 935, Loss: 297.8261390645793\n",
1158
+ "Epoch 936, Loss: 297.826138420962\n",
1159
+ "Epoch 937, Loss: 297.8261377774331\n",
1160
+ "Epoch 938, Loss: 297.82613713399303\n",
1161
+ "Epoch 939, Loss: 297.82613649064143\n",
1162
+ "Epoch 940, Loss: 297.8261358473787\n",
1163
+ "Epoch 941, Loss: 297.8261352042046\n",
1164
+ "Epoch 942, Loss: 297.8261345611193\n",
1165
+ "Epoch 943, Loss: 297.8261339181227\n",
1166
+ "Epoch 944, Loss: 297.826133275215\n",
1167
+ "Epoch 945, Loss: 297.82613263239614\n",
1168
+ "Epoch 946, Loss: 297.8261319896662\n",
1169
+ "Epoch 947, Loss: 297.82613134702507\n",
1170
+ "Epoch 948, Loss: 297.826130704473\n",
1171
+ "Epoch 949, Loss: 297.8261300620098\n",
1172
+ "Epoch 950, Loss: 297.8261294196356\n",
1173
+ "Epoch 951, Loss: 297.8261287773505\n",
1174
+ "Epoch 952, Loss: 297.8261281351545\n",
1175
+ "Epoch 953, Loss: 297.8261274930475\n",
1176
+ "Epoch 954, Loss: 297.82612685102976\n",
1177
+ "Epoch 955, Loss: 297.8261262091011\n",
1178
+ "Epoch 956, Loss: 297.82612556726167\n",
1179
+ "Epoch 957, Loss: 297.8261249255115\n",
1180
+ "Epoch 958, Loss: 297.8261242838506\n",
1181
+ "Epoch 959, Loss: 297.8261236422789\n",
1182
+ "Epoch 960, Loss: 297.8261230007966\n",
1183
+ "Epoch 961, Loss: 297.8261223594037\n",
1184
+ "Epoch 962, Loss: 297.82612171810007\n",
1185
+ "Epoch 963, Loss: 297.82612107688584\n",
1186
+ "Epoch 964, Loss: 297.8261204357612\n",
1187
+ "Epoch 965, Loss: 297.82611979472597\n",
1188
+ "Epoch 966, Loss: 297.8261191537804\n",
1189
+ "Epoch 967, Loss: 297.82611851292415\n",
1190
+ "Epoch 968, Loss: 297.8261178721576\n",
1191
+ "Epoch 969, Loss: 297.8261172314807\n",
1192
+ "Epoch 970, Loss: 297.8261165908933\n",
1193
+ "Epoch 971, Loss: 297.82611595039566\n",
1194
+ "Epoch 972, Loss: 297.8261153099878\n",
1195
+ "Epoch 973, Loss: 297.82611466966955\n",
1196
+ "Epoch 974, Loss: 297.8261140294412\n",
1197
+ "Epoch 975, Loss: 297.8261133893026\n",
1198
+ "Epoch 976, Loss: 297.82611274925375\n",
1199
+ "Epoch 977, Loss: 297.82611210929485\n",
1200
+ "Epoch 978, Loss: 297.82611146942594\n",
1201
+ "Epoch 979, Loss: 297.8261108296469\n",
1202
+ "Epoch 980, Loss: 297.8261101899578\n",
1203
+ "Epoch 981, Loss: 297.8261095503587\n",
1204
+ "Epoch 982, Loss: 297.8261089108496\n",
1205
+ "Epoch 983, Loss: 297.82610827143054\n",
1206
+ "Epoch 984, Loss: 297.8261076321016\n",
1207
+ "Epoch 985, Loss: 297.8261069928628\n",
1208
+ "Epoch 986, Loss: 297.8261063537141\n",
1209
+ "Epoch 987, Loss: 297.8261057146557\n",
1210
+ "Epoch 988, Loss: 297.8261050756874\n",
1211
+ "Epoch 989, Loss: 297.82610443680943\n",
1212
+ "Epoch 990, Loss: 297.82610379802173\n",
1213
+ "Epoch 991, Loss: 297.82610315932436\n",
1214
+ "Epoch 992, Loss: 297.8261025207173\n",
1215
+ "Epoch 993, Loss: 297.8261018822007\n",
1216
+ "Epoch 994, Loss: 297.8261012437744\n",
1217
+ "Epoch 995, Loss: 297.8261006054387\n",
1218
+ "Epoch 996, Loss: 297.82609996719333\n",
1219
+ "Epoch 997, Loss: 297.8260993290385\n",
1220
+ "Epoch 998, Loss: 297.8260986909742\n",
1221
+ "Epoch 999, Loss: 297.8260980530006\n",
1222
+ "final parameters: m=0.45136980910052144, c=0.49775672565271384, sigma=1562.2616856027405\n"
1223
+ ]
1224
+ }
1225
+ ],
1226
+ "source": [
1227
+ "## Problem 2\n",
1228
+ "x = np.array([8, 16, 22, 33, 50, 51])\n",
1229
+ "y = np.array([5, 20, 14, 32, 42, 58])\n",
1230
+ "\n",
1231
+ "# $-\\frac{n}{\\sigma}+\\frac{1}{\\sigma^3}\\sum_{i=1}^n(y_i - (mx+c))^2$\n",
1232
+ "dsigma = lambda sigma, c, m, x: -len(x) / sigma + np.sum(\n",
1233
+ " [(xi - (m * x + c)) ** 2 for xi in x]\n",
1234
+ ") / (sigma**3)\n",
1235
+ "# $-\\frac{1}{\\sigma^2}\\sum_{i=1}^n(y_i - (mx+c))$\n",
1236
+ "dc = lambda sigma, c, m, x: -np.sum([xi - (m * x + c) for xi in x]) / (sigma**2)\n",
1237
+ "# $-\\frac{1}{\\sigma^2}\\sum_{i=1}^n(x_i(y_i - (mx+c)))$\n",
1238
+ "dm = lambda sigma, c, m, x: -np.sum([x * (xi - (m * x + c)) for xi in x]) / (sigma**2)\n",
1239
+ "\n",
1240
+ "\n",
1241
+ "log2 = []\n",
1242
+ "\n",
1243
+ "\n",
1244
+ "def SGD_problem2(\n",
1245
+ " sigma: float,\n",
1246
+ " c: float,\n",
1247
+ " m: float,\n",
1248
+ " x: np.array,\n",
1249
+ " y: np.array,\n",
1250
+ " learning_rate=0.01,\n",
1251
+ " n_epochs=1000,\n",
1252
+ "):\n",
1253
+ " global log2\n",
1254
+ " log2 = []\n",
1255
+ " for epoch in range(n_epochs):\n",
1256
+ " sigma += learning_rate * dsigma(sigma, c, m, x)\n",
1257
+ " c += learning_rate * dc(sigma, c, m, x)\n",
1258
+ " m += learning_rate * dm(sigma, c, m, x)\n",
1259
+ "\n",
1260
+ " log2.append(\n",
1261
+ " {\n",
1262
+ " \"Epoch\": epoch,\n",
1263
+ " \"New sigma\": sigma,\n",
1264
+ " \"New c\": c,\n",
1265
+ " \"New m\": m,\n",
1266
+ " \"dc\": dc(sigma, c, m, x),\n",
1267
+ " \"dm\": dm(sigma, c, m, x),\n",
1268
+ " \"dsigma\": dsigma(sigma, c, m, x),\n",
1269
+ " \"Loss\": loss((m * x + c), sigma, y),\n",
1270
+ " }\n",
1271
+ " )\n",
1272
+ " print(f\"Epoch {epoch}, Loss: {loss((m * x + c), sigma, y)}\")\n",
1273
+ " return np.array([sigma, c, m])\n",
1274
+ "\n",
1275
+ "\n",
1276
+ "result = SGD_problem2(0.5, 0.5, 0.5, x, y)\n",
1277
+ "print(f\"final parameters: m={result[2]}, c={result[1]}, sigma={result[0]}\")"
1278
+ ]
1279
+ },
1280
+ {
1281
+ "cell_type": "markdown",
1282
+ "id": "0562b012-f4ca-47de-bc76-e0eb2bf1e509",
1283
+ "metadata": {},
1284
+ "source": [
1285
+ "loss appears to be decreasing. Uncollapse cell for output"
1286
+ ]
1287
+ },
1288
+ {
1289
+ "cell_type": "markdown",
1290
+ "id": "bed9f3ce-c15c-4f30-8906-26f3e51acf30",
1291
+ "metadata": {},
1292
+ "source": [
1293
+ "# Bike Rides and the Poisson Model"
1294
+ ]
1295
+ },
1296
+ {
1297
+ "cell_type": "markdown",
1298
+ "id": "975e2ef5-f5d5-45a3-b635-8faef035906f",
1299
+ "metadata": {},
1300
+ "source": [
1301
+ "Knowing that the poisson pdf is $P(k) = \\frac{\\lambda^k e^{-\\lambda}}{k!}$, we can find the negative log likelihood of the data as $-\\log(\\Pi_{i=1}^n P(k_i)) = -\\sum_{i=1}^n \\log(\\frac{\\lambda^k_i e^{-\\lambda}}{k_i!}) = \\sum_{i=1}^n -\\ln(\\lambda) k_i + \\ln(k_i!) + \\lambda$. Which simplified, gives $n\\lambda + \\sum_{i=1}^n \\ln(k_i!) - \\sum_{i=1}^n k_i \\ln(\\lambda)$. Differentiating with respect to $\\lambda$ gives $n - \\sum_{i=1}^n \\frac{k_i}{\\lambda}$. Which is our desired $\\frac{\\partial L}{\\partial \\lambda}$!"
1302
+ ]
1303
+ },
1304
+ {
1305
+ "cell_type": "code",
1306
+ "execution_count": 10,
1307
+ "id": "3877723c-179e-4759-bed5-9eb70110ded2",
1308
+ "metadata": {
1309
+ "tags": []
1310
+ },
1311
+ "outputs": [
1312
+ {
1313
+ "name": "stdout",
1314
+ "output_type": "stream",
1315
+ "text": [
1316
+ "SGD Problem 3\n",
1317
+ "l: [3215.17703224]\n",
1318
+ "l diff at start 999.4184849065878\n",
1319
+ "l diff at end 535.134976163929\n",
1320
+ "l is improving\n",
1321
+ "SGD Problem 3\n",
1322
+ "l: [2326.70336987]\n",
1323
+ "l diff at start -998.7262223631474\n",
1324
+ "l diff at end -353.33868620734074\n",
1325
+ "l is improving\n"
1326
+ ]
1327
+ },
1328
+ {
1329
+ "data": {
1330
+ "text/html": [
1331
+ "<div>\n",
1332
+ "<style scoped>\n",
1333
+ " .dataframe tbody tr th:only-of-type {\n",
1334
+ " vertical-align: middle;\n",
1335
+ " }\n",
1336
+ "\n",
1337
+ " .dataframe tbody tr th {\n",
1338
+ " vertical-align: top;\n",
1339
+ " }\n",
1340
+ "\n",
1341
+ " .dataframe thead th {\n",
1342
+ " text-align: right;\n",
1343
+ " }\n",
1344
+ "</style>\n",
1345
+ "<table border=\"1\" class=\"dataframe\">\n",
1346
+ " <thead>\n",
1347
+ " <tr style=\"text-align: right;\">\n",
1348
+ " <th></th>\n",
1349
+ " <th>Epoch</th>\n",
1350
+ " <th>New lambda</th>\n",
1351
+ " <th>dlambda</th>\n",
1352
+ " <th>Loss</th>\n",
1353
+ " <th>l_star</th>\n",
1354
+ " </tr>\n",
1355
+ " </thead>\n",
1356
+ " <tbody>\n",
1357
+ " <tr>\n",
1358
+ " <th>0</th>\n",
1359
+ " <td>0</td>\n",
1360
+ " <td>1681.315834</td>\n",
1361
+ " <td>-127.119133</td>\n",
1362
+ " <td>-3.899989e+06</td>\n",
1363
+ " <td>2680.042056</td>\n",
1364
+ " </tr>\n",
1365
+ " <tr>\n",
1366
+ " <th>1</th>\n",
1367
+ " <td>1</td>\n",
1368
+ " <td>1682.587025</td>\n",
1369
+ " <td>-126.861418</td>\n",
1370
+ " <td>-3.900150e+06</td>\n",
1371
+ " <td>2680.042056</td>\n",
1372
+ " </tr>\n",
1373
+ " <tr>\n",
1374
+ " <th>2</th>\n",
1375
+ " <td>2</td>\n",
1376
+ " <td>1683.855639</td>\n",
1377
+ " <td>-126.604614</td>\n",
1378
+ " <td>-3.900311e+06</td>\n",
1379
+ " <td>2680.042056</td>\n",
1380
+ " </tr>\n",
1381
+ " <tr>\n",
1382
+ " <th>3</th>\n",
1383
+ " <td>3</td>\n",
1384
+ " <td>1685.121685</td>\n",
1385
+ " <td>-126.348715</td>\n",
1386
+ " <td>-3.900471e+06</td>\n",
1387
+ " <td>2680.042056</td>\n",
1388
+ " </tr>\n",
1389
+ " <tr>\n",
1390
+ " <th>4</th>\n",
1391
+ " <td>4</td>\n",
1392
+ " <td>1686.385173</td>\n",
1393
+ " <td>-126.093716</td>\n",
1394
+ " <td>-3.900631e+06</td>\n",
1395
+ " <td>2680.042056</td>\n",
1396
+ " </tr>\n",
1397
+ " <tr>\n",
1398
+ " <th>...</th>\n",
1399
+ " <td>...</td>\n",
1400
+ " <td>...</td>\n",
1401
+ " <td>...</td>\n",
1402
+ " <td>...</td>\n",
1403
+ " <td>...</td>\n",
1404
+ " </tr>\n",
1405
+ " <tr>\n",
1406
+ " <th>995</th>\n",
1407
+ " <td>995</td>\n",
1408
+ " <td>2325.399976</td>\n",
1409
+ " <td>-32.636710</td>\n",
1410
+ " <td>-3.948159e+06</td>\n",
1411
+ " <td>2680.042056</td>\n",
1412
+ " </tr>\n",
1413
+ " <tr>\n",
1414
+ " <th>996</th>\n",
1415
+ " <td>996</td>\n",
1416
+ " <td>2325.726343</td>\n",
1417
+ " <td>-32.602100</td>\n",
1418
+ " <td>-3.948170e+06</td>\n",
1419
+ " <td>2680.042056</td>\n",
1420
+ " </tr>\n",
1421
+ " <tr>\n",
1422
+ " <th>997</th>\n",
1423
+ " <td>997</td>\n",
1424
+ " <td>2326.052364</td>\n",
1425
+ " <td>-32.567536</td>\n",
1426
+ " <td>-3.948180e+06</td>\n",
1427
+ " <td>2680.042056</td>\n",
1428
+ " </tr>\n",
1429
+ " <tr>\n",
1430
+ " <th>998</th>\n",
1431
+ " <td>998</td>\n",
1432
+ " <td>2326.378040</td>\n",
1433
+ " <td>-32.533018</td>\n",
1434
+ " <td>-3.948191e+06</td>\n",
1435
+ " <td>2680.042056</td>\n",
1436
+ " </tr>\n",
1437
+ " <tr>\n",
1438
+ " <th>999</th>\n",
1439
+ " <td>999</td>\n",
1440
+ " <td>2326.703370</td>\n",
1441
+ " <td>-32.498547</td>\n",
1442
+ " <td>-3.948201e+06</td>\n",
1443
+ " <td>2680.042056</td>\n",
1444
+ " </tr>\n",
1445
+ " </tbody>\n",
1446
+ "</table>\n",
1447
+ "<p>1000 rows × 5 columns</p>\n",
1448
+ "</div>"
1449
+ ],
1450
+ "text/plain": [
1451
+ " Epoch New lambda dlambda Loss l_star\n",
1452
+ "0 0 1681.315834 -127.119133 -3.899989e+06 2680.042056\n",
1453
+ "1 1 1682.587025 -126.861418 -3.900150e+06 2680.042056\n",
1454
+ "2 2 1683.855639 -126.604614 -3.900311e+06 2680.042056\n",
1455
+ "3 3 1685.121685 -126.348715 -3.900471e+06 2680.042056\n",
1456
+ "4 4 1686.385173 -126.093716 -3.900631e+06 2680.042056\n",
1457
+ ".. ... ... ... ... ...\n",
1458
+ "995 995 2325.399976 -32.636710 -3.948159e+06 2680.042056\n",
1459
+ "996 996 2325.726343 -32.602100 -3.948170e+06 2680.042056\n",
1460
+ "997 997 2326.052364 -32.567536 -3.948180e+06 2680.042056\n",
1461
+ "998 998 2326.378040 -32.533018 -3.948191e+06 2680.042056\n",
1462
+ "999 999 2326.703370 -32.498547 -3.948201e+06 2680.042056\n",
1463
+ "\n",
1464
+ "[1000 rows x 5 columns]"
1465
+ ]
1466
+ },
1467
+ "execution_count": 10,
1468
+ "metadata": {},
1469
+ "output_type": "execute_result"
1470
+ }
1471
+ ],
1472
+ "source": [
1473
+ "import pandas as pd\n",
1474
+ "\n",
1475
+ "df = pd.read_csv(\"../data/01_raw/nyc_bb_bicyclist_counts.csv\")\n",
1476
+ "\n",
1477
+ "dlambda = lambda l, k: len(k) - np.sum([ki / l for ki in k])\n",
1478
+ "\n",
1479
+ "\n",
1480
+ "def SGD_problem3(\n",
1481
+ " l: float,\n",
1482
+ " k: np.array,\n",
1483
+ " learning_rate=0.01,\n",
1484
+ " n_epochs=1000,\n",
1485
+ "):\n",
1486
+ " global log3\n",
1487
+ " log3 = []\n",
1488
+ " for epoch in range(n_epochs):\n",
1489
+ " l -= learning_rate * dlambda(l, k)\n",
1490
+ " # $n\\lambda + \\sum_{i=1}^n \\ln(k_i!) - \\sum_{i=1}^n k_i \\ln(\\lambda)$\n",
1491
+ " # the rest of the loss function is commented out because it's a\n",
1492
+ " # constant and was causing overflows. It is unnecessary, and a useless\n",
1493
+ " # pain.\n",
1494
+ " loss = len(k) * l - np.sum(\n",
1495
+ " [ki * np.log(l) for ki in k]\n",
1496
+ " ) # + np.sum([np.log(np.math.factorial(ki)) for ki in k])\n",
1497
+ "\n",
1498
+ " log3.append(\n",
1499
+ " {\n",
1500
+ " \"Epoch\": epoch,\n",
1501
+ " \"New lambda\": l,\n",
1502
+ " \"dlambda\": dlambda(l, k),\n",
1503
+ " \"Loss\": loss,\n",
1504
+ " }\n",
1505
+ " )\n",
1506
+ " # print(f\"Epoch {epoch}\", f\"Loss: {loss}\")\n",
1507
+ " return np.array([l])\n",
1508
+ "\n",
1509
+ "\n",
1510
+ "l_star = df[\"BB_COUNT\"].mean()\n",
1511
+ "\n",
1512
+ "\n",
1513
+ "def debug_SGD_3(data, l=1000):\n",
1514
+ " print(\"SGD Problem 3\")\n",
1515
+ " print(f\"l: {SGD_problem3(l, data)}\")\n",
1516
+ " dflog = pd.DataFrame(log3)\n",
1517
+ " dflog[\"l_star\"] = l_star\n",
1518
+ " print(f\"l diff at start {dflog.iloc[0]['New lambda'] - dflog.iloc[0]['l_star']}\")\n",
1519
+ " print(f\"l diff at end {dflog.iloc[-1]['New lambda'] - dflog.iloc[-1]['l_star']}\")\n",
1520
+ " if np.abs(dflog.iloc[-1][\"New lambda\"] - dflog.iloc[-1][\"l_star\"]) < np.abs(\n",
1521
+ " dflog.iloc[0][\"New lambda\"] - dflog.iloc[0][\"l_star\"]\n",
1522
+ " ):\n",
1523
+ " print(\"l is improving\")\n",
1524
+ " else:\n",
1525
+ " print(\"l is not improving\")\n",
1526
+ " return dflog\n",
1527
+ "\n",
1528
+ "\n",
1529
+ "debug_SGD_3(data=df[\"BB_COUNT\"].values, l=l_star + 1000)\n",
1530
+ "debug_SGD_3(data=df[\"BB_COUNT\"].values, l=l_star - 1000)"
1531
+ ]
1532
+ },
1533
+ {
1534
+ "cell_type": "markdown",
1535
+ "id": "c05192f9-78ae-4bdb-9df5-cac91006d79f",
1536
+ "metadata": {},
1537
+ "source": [
1538
+ "l approaches the l_star and decreases the loss function."
1539
+ ]
1540
+ },
1541
+ {
1542
+ "cell_type": "markdown",
1543
+ "id": "4955b868-7f67-4760-bf86-39f6edd55871",
1544
+ "metadata": {},
1545
+ "source": [
1546
+ "## Maximum Likelihood II"
1547
+ ]
1548
+ },
1549
+ {
1550
+ "cell_type": "code",
1551
+ "execution_count": 7,
1552
+ "id": "7c8b167d-c397-4155-93f3-d826c279fbb2",
1553
+ "metadata": {
1554
+ "tags": []
1555
+ },
1556
+ "outputs": [
1557
+ {
1558
+ "ename": "SyntaxError",
1559
+ "evalue": "invalid syntax (3400372070.py, line 66)",
1560
+ "output_type": "error",
1561
+ "traceback": [
1562
+ "\u001b[0;36m Cell \u001b[0;32mIn[7], line 66\u001b[0;36m\u001b[0m\n\u001b[0;31m p_dw = lambda w, xi: np.array([primitive(xi, wi) for xi, wi in ])\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
1563
+ ]
1564
+ }
1565
+ ],
1566
+ "source": [
1567
+ "## pset 4\n",
1568
+ "\n",
1569
+ "# dw = lambda w, x: len(x) * np.exp(np.dot(x, w)) * x - np.sum()\n",
1570
+ "\n",
1571
+ "primitive = lambda xi, wi: (x.shape[0] * np.exp(wi * xi) * xi) - (xi**2)\n",
1572
+ "p_dw = lambda w, xi: np.array([primitive(xi, wi) for xi, wi in ])\n",
1573
+ "\n",
1574
+ "\n",
1575
+ "def SGD_problem4(\n",
1576
+ " w: np.array,\n",
1577
+ " x: np.array,\n",
1578
+ " learning_rate=0.01,\n",
1579
+ " n_epochs=1000,\n",
1580
+ "):\n",
1581
+ " global log4\n",
1582
+ " log4 = []\n",
1583
+ " for epoch in range(n_epochs):\n",
1584
+ " w -= learning_rate * p_dw(w, x)\n",
1585
+ " # custom\n",
1586
+ " # loss = x.shape[0] * np.exp(np.dot(x, w))\n",
1587
+ " loss_fn = lambda k, l: len(k) * l - np.sum(\n",
1588
+ " [ki * np.log(l) for ki in k]\n",
1589
+ " ) # + np.sum([np.log(np.math.factorial(ki)) for ki in k])\n",
1590
+ " loss = loss_fn(x, np.exp(np.dot(x, w)))\n",
1591
+ " log4.append(\n",
1592
+ " {\n",
1593
+ " \"Epoch\": epoch,\n",
1594
+ " \"New w\": w,\n",
1595
+ " \"dw\": dw(w, x),\n",
1596
+ " \"Loss\": loss,\n",
1597
+ " }\n",
1598
+ " )\n",
1599
+ " print(f\"Epoch {epoch}\", f\"Loss: {loss}\")\n",
1600
+ " return w\n",
1601
+ "\n",
1602
+ "\n",
1603
+ "def debug_SGD_3(data, w=np.array([1, 1])):\n",
1604
+ " print(\"SGD Problem 4\")\n",
1605
+ " print(f\"w: {SGD_problem4(w, data)}\")\n",
1606
+ " dflog = pd.DataFrame(log4)\n",
1607
+ " return dflog\n",
1608
+ "\n",
1609
+ "\n",
1610
+ "_ = debug_SGD_3(\n",
1611
+ " data=df[[\"HIGH_T\", \"LOW_T\", \"PRECIP\"]].to_numpy(),\n",
1612
+ " w=np.array([1.0, 1.0, 1.0]),\n",
1613
+ ")"
1614
+ ]
1615
+ },
1616
+ {
1617
+ "cell_type": "markdown",
1618
+ "id": "69e9148b-70fb-46e3-bc29-a08f471cccab",
1619
+ "metadata": {},
1620
+ "source": [
1621
+ "Seek to maximize likelihood, which is equivalent to minimizing log likelihood, which is equivalent to maximizing MSE.\n",
1622
+ "\n",
1623
+ "-> Find MSE"
1624
+ ]
1625
+ },
1626
+ {
1627
+ "cell_type": "code",
1628
+ "execution_count": 2,
1629
+ "id": "fb9d2e20-8a02-4f78-a3d3-6fc171f17af6",
1630
+ "metadata": {},
1631
+ "outputs": [
1632
+ {
1633
+ "ename": "SyntaxError",
1634
+ "evalue": "incomplete input (3680490224.py, line 3)",
1635
+ "output_type": "error",
1636
+ "traceback": [
1637
+ "\u001b[0;36m Cell \u001b[0;32mIn[2], line 3\u001b[0;36m\u001b[0m\n\u001b[0;31m def MSE(x: List):\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m incomplete input\n"
1638
+ ]
1639
+ }
1640
+ ],
1641
+ "source": [
1642
+ "from typing import List\n",
1643
+ "from numpy import \n",
1644
+ "\n",
1645
+ "w = np.array(2)\n",
1646
+ "\n",
1647
+ "def partial_of_Loss(w, x, y):\n",
1648
+ " \n",
1649
+ " "
1650
+ ]
1651
+ },
1652
+ {
1653
+ "cell_type": "code",
1654
+ "execution_count": null,
1655
+ "id": "7c00197d-873d-41b0-a458-dc8478b40f52",
1656
+ "metadata": {},
1657
+ "outputs": [],
1658
+ "source": []
1659
+ }
1660
+ ],
1661
+ "metadata": {
1662
+ "kernelspec": {
1663
+ "display_name": "Python 3 (ipykernel)",
1664
+ "language": "python",
1665
+ "name": "python3"
1666
+ },
1667
+ "language_info": {
1668
+ "codemirror_mode": {
1669
+ "name": "ipython",
1670
+ "version": 3
1671
+ },
1672
+ "file_extension": ".py",
1673
+ "mimetype": "text/x-python",
1674
+ "name": "python",
1675
+ "nbconvert_exporter": "python",
1676
+ "pygments_lexer": "ipython3",
1677
+ "version": "3.10.4"
1678
+ }
1679
+ },
1680
+ "nbformat": 4,
1681
+ "nbformat_minor": 5
1682
+ }
assignment-2/assignment_2/__init__.py ADDED
File without changes
assignment-2/data/01_raw/nyc_bb_bicyclist_counts.csv ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Date,HIGH_T,LOW_T,PRECIP,BB_COUNT
2
+ 1-Apr-17,46.00,37.00,0.00,606
3
+ 2-Apr-17,62.10,41.00,0.00,2021
4
+ 3-Apr-17,63.00,50.00,0.03,2470
5
+ 4-Apr-17,51.10,46.00,1.18,723
6
+ 5-Apr-17,63.00,46.00,0.00,2807
7
+ 6-Apr-17,48.90,41.00,0.73,461
8
+ 7-Apr-17,48.00,43.00,0.01,1222
9
+ 8-Apr-17,55.90,39.90,0.00,1674
10
+ 9-Apr-17,66.00,45.00,0.00,2375
11
+ 10-Apr-17,73.90,55.00,0.00,3324
12
+ 11-Apr-17,80.10,62.10,0.00,3887
13
+ 12-Apr-17,73.90,57.90,0.02,2565
14
+ 13-Apr-17,64.00,48.90,0.00,3353
15
+ 14-Apr-17,64.90,48.90,0.00,2942
16
+ 15-Apr-17,64.90,52.00,0.00,2253
17
+ 16-Apr-17,84.90,62.10,0.01,2877
18
+ 17-Apr-17,73.90,64.00,0.01,3152
19
+ 18-Apr-17,66.00,50.00,0.00,3415
20
+ 19-Apr-17,52.00,45.00,0.01,1965
21
+ 20-Apr-17,64.90,50.00,0.17,1567
22
+ 21-Apr-17,53.10,48.00,0.29,1426
23
+ 22-Apr-17,55.90,52.00,0.11,1318
24
+ 23-Apr-17,64.90,46.90,0.00,2520
25
+ 24-Apr-17,60.10,50.00,0.01,2544
26
+ 25-Apr-17,54.00,50.00,0.91,611
27
+ 26-Apr-17,59.00,54.00,0.34,1247
28
+ 27-Apr-17,68.00,59.00,0.00,2959
29
+ 28-Apr-17,82.90,57.90,0.00,3679
30
+ 29-Apr-17,84.00,64.00,0.06,3315
31
+ 30-Apr-17,64.00,54.00,0.00,2225
32
+ 1-May-17,72.00,50.00,0.00,3084
33
+ 2-May-17,73.90,66.90,0.00,3423
34
+ 3-May-17,64.90,57.90,0.00,3342
35
+ 4-May-17,63.00,50.00,0.00,3019
36
+ 5-May-17,59.00,52.00,3.02,513
37
+ 6-May-17,64.90,57.00,0.18,1892
38
+ 7-May-17,54.00,48.90,0.01,3539
39
+ 8-May-17,57.00,45.00,0.00,2886
40
+ 9-May-17,61.00,48.00,0.00,2718
41
+ 10-May-17,70.00,51.10,0.00,2810
42
+ 11-May-17,61.00,51.80,0.00,2657
43
+ 12-May-17,62.10,51.10,0.00,2640
44
+ 13-May-17,51.10,45.00,1.31,151
45
+ 14-May-17,64.90,46.00,0.02,1452
46
+ 15-May-17,66.90,55.90,0.00,2685
47
+ 16-May-17,78.10,57.90,0.00,3666
48
+ 17-May-17,90.00,66.00,0.00,3535
49
+ 18-May-17,91.90,75.00,0.00,3190
50
+ 19-May-17,90.00,75.90,0.00,2952
51
+ 20-May-17,64.00,55.90,0.01,2161
52
+ 21-May-17,66.90,55.00,0.00,2612
53
+ 22-May-17,61.00,54.00,0.59,768
54
+ 23-May-17,68.00,57.90,0.00,3174
55
+ 24-May-17,66.90,57.00,0.04,2969
56
+ 25-May-17,57.90,55.90,0.58,488
57
+ 26-May-17,73.00,55.90,0.10,2590
58
+ 27-May-17,71.10,61.00,0.00,2609
59
+ 28-May-17,71.10,59.00,0.00,2640
60
+ 29-May-17,57.90,55.90,0.13,836
61
+ 30-May-17,59.00,55.90,0.06,2301
62
+ 31-May-17,75.00,57.90,0.03,2689
63
+ 1-Jun-17,78.10,62.10,0.00,3468
64
+ 2-Jun-17,73.90,60.10,0.01,3271
65
+ 3-Jun-17,72.00,55.00,0.01,2589
66
+ 4-Jun-17,68.00,60.10,0.09,1805
67
+ 5-Jun-17,66.90,60.10,0.02,2171
68
+ 6-Jun-17,55.90,53.10,0.06,1193
69
+ 7-Jun-17,66.90,54.00,0.00,3211
70
+ 8-Jun-17,68.00,59.00,0.00,3253
71
+ 9-Jun-17,80.10,59.00,0.00,3401
72
+ 10-Jun-17,84.00,68.00,0.00,3066
73
+ 11-Jun-17,90.00,73.00,0.00,2465
74
+ 12-Jun-17,91.90,77.00,0.00,2854
75
+ 13-Jun-17,93.90,78.10,0.01,2882
76
+ 14-Jun-17,84.00,69.10,0.29,2596
77
+ 15-Jun-17,75.00,66.00,0.00,3510
78
+ 16-Jun-17,68.00,66.00,0.00,2054
79
+ 17-Jun-17,73.00,66.90,1.39,1399
80
+ 18-Jun-17,84.00,72.00,0.01,2199
81
+ 19-Jun-17,87.10,70.00,1.35,1648
82
+ 20-Jun-17,82.00,72.00,0.03,3407
83
+ 21-Jun-17,82.00,72.00,0.00,3304
84
+ 22-Jun-17,82.00,70.00,0.00,3368
85
+ 23-Jun-17,82.90,75.90,0.04,2283
86
+ 24-Jun-17,82.90,71.10,1.29,2307
87
+ 25-Jun-17,82.00,69.10,0.00,2625
88
+ 26-Jun-17,78.10,66.00,0.00,3386
89
+ 27-Jun-17,75.90,61.00,0.18,3182
90
+ 28-Jun-17,78.10,62.10,0.00,3766
91
+ 29-Jun-17,81.00,68.00,0.00,3356
92
+ 30-Jun-17,88.00,73.90,0.01,2687
93
+ 1-Jul-17,84.90,72.00,0.23,1848
94
+ 2-Jul-17,87.10,73.00,0.00,2467
95
+ 3-Jul-17,87.10,71.10,0.45,2714
96
+ 4-Jul-17,82.90,70.00,0.00,2296
97
+ 5-Jul-17,84.90,71.10,0.00,3170
98
+ 6-Jul-17,75.00,71.10,0.01,3065
99
+ 7-Jul-17,79.00,68.00,1.78,1513
100
+ 8-Jul-17,82.90,70.00,0.00,2718
101
+ 9-Jul-17,81.00,69.10,0.00,3048
102
+ 10-Jul-17,82.90,71.10,0.00,3506
103
+ 11-Jul-17,84.00,75.00,0.00,2929
104
+ 12-Jul-17,87.10,77.00,0.00,2860
105
+ 13-Jul-17,89.10,77.00,0.00,2563
106
+ 14-Jul-17,69.10,64.90,0.35,907
107
+ 15-Jul-17,82.90,68.00,0.00,2853
108
+ 16-Jul-17,84.90,70.00,0.00,2917
109
+ 17-Jul-17,84.90,73.90,0.00,3264
110
+ 18-Jul-17,87.10,75.90,0.00,3507
111
+ 19-Jul-17,91.00,77.00,0.00,3114
112
+ 20-Jul-17,93.00,78.10,0.01,2840
113
+ 21-Jul-17,91.00,77.00,0.00,2751
114
+ 22-Jul-17,91.00,78.10,0.57,2301
115
+ 23-Jul-17,78.10,73.00,0.06,2321
116
+ 24-Jul-17,69.10,63.00,0.74,1576
117
+ 25-Jul-17,71.10,64.00,0.00,3191
118
+ 26-Jul-17,75.90,66.00,0.00,3821
119
+ 27-Jul-17,77.00,66.90,0.01,3287
120
+ 28-Jul-17,84.90,73.00,0.00,3123
121
+ 29-Jul-17,75.90,68.00,0.00,2074
122
+ 30-Jul-17,81.00,64.90,0.00,3331
123
+ 31-Jul-17,88.00,66.90,0.00,3560
124
+ 1-Aug-17,91.00,72.00,0.00,3492
125
+ 2-Aug-17,86.00,69.10,0.09,2637
126
+ 3-Aug-17,86.00,70.00,0.00,3346
127
+ 4-Aug-17,82.90,70.00,0.15,2400
128
+ 5-Aug-17,77.00,70.00,0.30,3409
129
+ 6-Aug-17,75.90,64.00,0.00,3130
130
+ 7-Aug-17,71.10,64.90,0.76,804
131
+ 8-Aug-17,77.00,66.00,0.00,3598
132
+ 9-Aug-17,82.90,66.00,0.00,3893
133
+ 10-Aug-17,82.90,69.10,0.00,3423
134
+ 11-Aug-17,81.00,70.00,0.01,3148
135
+ 12-Aug-17,75.90,64.90,0.11,4146
136
+ 13-Aug-17,82.00,71.10,0.00,3274
137
+ 14-Aug-17,80.10,70.00,0.00,3291
138
+ 15-Aug-17,73.00,69.10,0.45,2149
139
+ 16-Aug-17,84.90,70.00,0.00,3685
140
+ 17-Aug-17,82.00,71.10,0.00,3637
141
+ 18-Aug-17,81.00,73.00,0.88,1064
142
+ 19-Aug-17,84.90,73.00,0.00,4693
143
+ 20-Aug-17,81.00,70.00,0.00,2822
144
+ 21-Aug-17,84.90,73.00,0.00,3088
145
+ 22-Aug-17,88.00,75.00,0.30,2983
146
+ 23-Aug-17,80.10,71.10,0.01,2994
147
+ 24-Aug-17,79.00,66.00,0.00,3688
148
+ 25-Aug-17,78.10,64.00,0.00,3144
149
+ 26-Aug-17,77.00,62.10,0.00,2710
150
+ 27-Aug-17,77.00,63.00,0.00,2676
151
+ 28-Aug-17,75.00,63.00,0.00,3332
152
+ 29-Aug-17,68.00,62.10,0.10,1472
153
+ 30-Aug-17,75.90,61.00,0.01,3468
154
+ 31-Aug-17,81.00,64.00,0.00,3279
155
+ 1-Sep-17,70.00,55.00,0.00,2945
156
+ 2-Sep-17,66.90,54.00,0.53,1876
157
+ 3-Sep-17,69.10,60.10,0.74,1004
158
+ 4-Sep-17,79.00,62.10,0.00,2866
159
+ 5-Sep-17,84.00,70.00,0.01,3244
160
+ 6-Sep-17,70.00,62.10,0.42,1232
161
+ 7-Sep-17,71.10,59.00,0.01,3249
162
+ 8-Sep-17,70.00,59.00,0.00,3234
163
+ 9-Sep-17,69.10,55.00,0.00,2609
164
+ 10-Sep-17,72.00,57.00,0.00,4960
165
+ 11-Sep-17,75.90,55.00,0.00,3657
166
+ 12-Sep-17,78.10,61.00,0.00,3497
167
+ 13-Sep-17,82.00,64.90,0.06,2994
168
+ 14-Sep-17,81.00,70.00,0.02,3013
169
+ 15-Sep-17,81.00,66.90,0.00,3344
170
+ 16-Sep-17,82.00,70.00,0.00,2560
171
+ 17-Sep-17,80.10,70.00,0.00,2676
172
+ 18-Sep-17,73.00,69.10,0.00,2673
173
+ 19-Sep-17,78.10,69.10,0.22,2012
174
+ 20-Sep-17,78.10,71.10,0.00,3296
175
+ 21-Sep-17,80.10,71.10,0.00,3317
176
+ 22-Sep-17,82.00,66.00,0.00,3297
177
+ 23-Sep-17,86.00,68.00,0.00,2810
178
+ 24-Sep-17,90.00,69.10,0.00,2543
179
+ 25-Sep-17,87.10,72.00,0.00,3276
180
+ 26-Sep-17,82.00,69.10,0.00,3157
181
+ 27-Sep-17,84.90,71.10,0.00,3216
182
+ 28-Sep-17,78.10,66.00,0.00,3421
183
+ 29-Sep-17,66.90,55.00,0.00,2988
184
+ 30-Sep-17,64.00,55.90,0.00,1903
185
+ 1-Oct-17,66.90,50.00,0.00,2297
186
+ 2-Oct-17,72.00,52.00,0.00,3387
187
+ 3-Oct-17,70.00,57.00,0.00,3386
188
+ 4-Oct-17,75.00,55.90,0.00,3412
189
+ 5-Oct-17,82.00,64.90,0.00,3312
190
+ 6-Oct-17,81.00,69.10,0.00,2982
191
+ 7-Oct-17,80.10,66.00,0.00,2750
192
+ 8-Oct-17,77.00,72.00,0.22,1235
193
+ 9-Oct-17,75.90,72.00,0.26,898
194
+ 10-Oct-17,80.10,66.00,0.00,3922
195
+ 11-Oct-17,75.00,64.90,0.06,2721
196
+ 12-Oct-17,63.00,55.90,0.07,2411
197
+ 13-Oct-17,64.90,52.00,0.00,2839
198
+ 14-Oct-17,71.10,62.10,0.08,2021
199
+ 15-Oct-17,72.00,66.00,0.01,2169
200
+ 16-Oct-17,60.10,52.00,0.01,2751
201
+ 17-Oct-17,57.90,43.00,0.00,2869
202
+ 18-Oct-17,71.10,50.00,0.00,3264
203
+ 19-Oct-17,70.00,55.90,0.00,3265
204
+ 20-Oct-17,73.00,57.90,0.00,3169
205
+ 21-Oct-17,78.10,57.00,0.00,2538
206
+ 22-Oct-17,75.90,57.00,0.00,2744
207
+ 23-Oct-17,73.90,64.00,0.00,3189
208
+ 24-Oct-17,73.00,66.90,0.20,954
209
+ 25-Oct-17,64.90,57.90,0.00,3367
210
+ 26-Oct-17,57.00,53.10,0.00,2565
211
+ 27-Oct-17,62.10,48.00,0.00,3150
212
+ 28-Oct-17,68.00,55.90,0.00,2245
213
+ 29-Oct-17,64.90,61.00,3.03,183
214
+ 30-Oct-17,55.00,46.00,0.25,1428
215
+ 31-Oct-17,54.00,44.00,0.00,2727
assignment-2/data/nyc_bb_bicyclist_counts.csv ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Date,HIGH_T,LOW_T,PRECIP,BB_COUNT
2
+ 1-Apr-17,46.00,37.00,0.00,606
3
+ 2-Apr-17,62.10,41.00,0.00,2021
4
+ 3-Apr-17,63.00,50.00,0.03,2470
5
+ 4-Apr-17,51.10,46.00,1.18,723
6
+ 5-Apr-17,63.00,46.00,0.00,2807
7
+ 6-Apr-17,48.90,41.00,0.73,461
8
+ 7-Apr-17,48.00,43.00,0.01,1222
9
+ 8-Apr-17,55.90,39.90,0.00,1674
10
+ 9-Apr-17,66.00,45.00,0.00,2375
11
+ 10-Apr-17,73.90,55.00,0.00,3324
12
+ 11-Apr-17,80.10,62.10,0.00,3887
13
+ 12-Apr-17,73.90,57.90,0.02,2565
14
+ 13-Apr-17,64.00,48.90,0.00,3353
15
+ 14-Apr-17,64.90,48.90,0.00,2942
16
+ 15-Apr-17,64.90,52.00,0.00,2253
17
+ 16-Apr-17,84.90,62.10,0.01,2877
18
+ 17-Apr-17,73.90,64.00,0.01,3152
19
+ 18-Apr-17,66.00,50.00,0.00,3415
20
+ 19-Apr-17,52.00,45.00,0.01,1965
21
+ 20-Apr-17,64.90,50.00,0.17,1567
22
+ 21-Apr-17,53.10,48.00,0.29,1426
23
+ 22-Apr-17,55.90,52.00,0.11,1318
24
+ 23-Apr-17,64.90,46.90,0.00,2520
25
+ 24-Apr-17,60.10,50.00,0.01,2544
26
+ 25-Apr-17,54.00,50.00,0.91,611
27
+ 26-Apr-17,59.00,54.00,0.34,1247
28
+ 27-Apr-17,68.00,59.00,0.00,2959
29
+ 28-Apr-17,82.90,57.90,0.00,3679
30
+ 29-Apr-17,84.00,64.00,0.06,3315
31
+ 30-Apr-17,64.00,54.00,0.00,2225
32
+ 1-May-17,72.00,50.00,0.00,3084
33
+ 2-May-17,73.90,66.90,0.00,3423
34
+ 3-May-17,64.90,57.90,0.00,3342
35
+ 4-May-17,63.00,50.00,0.00,3019
36
+ 5-May-17,59.00,52.00,3.02,513
37
+ 6-May-17,64.90,57.00,0.18,1892
38
+ 7-May-17,54.00,48.90,0.01,3539
39
+ 8-May-17,57.00,45.00,0.00,2886
40
+ 9-May-17,61.00,48.00,0.00,2718
41
+ 10-May-17,70.00,51.10,0.00,2810
42
+ 11-May-17,61.00,51.80,0.00,2657
43
+ 12-May-17,62.10,51.10,0.00,2640
44
+ 13-May-17,51.10,45.00,1.31,151
45
+ 14-May-17,64.90,46.00,0.02,1452
46
+ 15-May-17,66.90,55.90,0.00,2685
47
+ 16-May-17,78.10,57.90,0.00,3666
48
+ 17-May-17,90.00,66.00,0.00,3535
49
+ 18-May-17,91.90,75.00,0.00,3190
50
+ 19-May-17,90.00,75.90,0.00,2952
51
+ 20-May-17,64.00,55.90,0.01,2161
52
+ 21-May-17,66.90,55.00,0.00,2612
53
+ 22-May-17,61.00,54.00,0.59,768
54
+ 23-May-17,68.00,57.90,0.00,3174
55
+ 24-May-17,66.90,57.00,0.04,2969
56
+ 25-May-17,57.90,55.90,0.58,488
57
+ 26-May-17,73.00,55.90,0.10,2590
58
+ 27-May-17,71.10,61.00,0.00,2609
59
+ 28-May-17,71.10,59.00,0.00,2640
60
+ 29-May-17,57.90,55.90,0.13,836
61
+ 30-May-17,59.00,55.90,0.06,2301
62
+ 31-May-17,75.00,57.90,0.03,2689
63
+ 1-Jun-17,78.10,62.10,0.00,3468
64
+ 2-Jun-17,73.90,60.10,0.01,3271
65
+ 3-Jun-17,72.00,55.00,0.01,2589
66
+ 4-Jun-17,68.00,60.10,0.09,1805
67
+ 5-Jun-17,66.90,60.10,0.02,2171
68
+ 6-Jun-17,55.90,53.10,0.06,1193
69
+ 7-Jun-17,66.90,54.00,0.00,3211
70
+ 8-Jun-17,68.00,59.00,0.00,3253
71
+ 9-Jun-17,80.10,59.00,0.00,3401
72
+ 10-Jun-17,84.00,68.00,0.00,3066
73
+ 11-Jun-17,90.00,73.00,0.00,2465
74
+ 12-Jun-17,91.90,77.00,0.00,2854
75
+ 13-Jun-17,93.90,78.10,0.01,2882
76
+ 14-Jun-17,84.00,69.10,0.29,2596
77
+ 15-Jun-17,75.00,66.00,0.00,3510
78
+ 16-Jun-17,68.00,66.00,0.00,2054
79
+ 17-Jun-17,73.00,66.90,1.39,1399
80
+ 18-Jun-17,84.00,72.00,0.01,2199
81
+ 19-Jun-17,87.10,70.00,1.35,1648
82
+ 20-Jun-17,82.00,72.00,0.03,3407
83
+ 21-Jun-17,82.00,72.00,0.00,3304
84
+ 22-Jun-17,82.00,70.00,0.00,3368
85
+ 23-Jun-17,82.90,75.90,0.04,2283
86
+ 24-Jun-17,82.90,71.10,1.29,2307
87
+ 25-Jun-17,82.00,69.10,0.00,2625
88
+ 26-Jun-17,78.10,66.00,0.00,3386
89
+ 27-Jun-17,75.90,61.00,0.18,3182
90
+ 28-Jun-17,78.10,62.10,0.00,3766
91
+ 29-Jun-17,81.00,68.00,0.00,3356
92
+ 30-Jun-17,88.00,73.90,0.01,2687
93
+ 1-Jul-17,84.90,72.00,0.23,1848
94
+ 2-Jul-17,87.10,73.00,0.00,2467
95
+ 3-Jul-17,87.10,71.10,0.45,2714
96
+ 4-Jul-17,82.90,70.00,0.00,2296
97
+ 5-Jul-17,84.90,71.10,0.00,3170
98
+ 6-Jul-17,75.00,71.10,0.01,3065
99
+ 7-Jul-17,79.00,68.00,1.78,1513
100
+ 8-Jul-17,82.90,70.00,0.00,2718
101
+ 9-Jul-17,81.00,69.10,0.00,3048
102
+ 10-Jul-17,82.90,71.10,0.00,3506
103
+ 11-Jul-17,84.00,75.00,0.00,2929
104
+ 12-Jul-17,87.10,77.00,0.00,2860
105
+ 13-Jul-17,89.10,77.00,0.00,2563
106
+ 14-Jul-17,69.10,64.90,0.35,907
107
+ 15-Jul-17,82.90,68.00,0.00,2853
108
+ 16-Jul-17,84.90,70.00,0.00,2917
109
+ 17-Jul-17,84.90,73.90,0.00,3264
110
+ 18-Jul-17,87.10,75.90,0.00,3507
111
+ 19-Jul-17,91.00,77.00,0.00,3114
112
+ 20-Jul-17,93.00,78.10,0.01,2840
113
+ 21-Jul-17,91.00,77.00,0.00,2751
114
+ 22-Jul-17,91.00,78.10,0.57,2301
115
+ 23-Jul-17,78.10,73.00,0.06,2321
116
+ 24-Jul-17,69.10,63.00,0.74,1576
117
+ 25-Jul-17,71.10,64.00,0.00,3191
118
+ 26-Jul-17,75.90,66.00,0.00,3821
119
+ 27-Jul-17,77.00,66.90,0.01,3287
120
+ 28-Jul-17,84.90,73.00,0.00,3123
121
+ 29-Jul-17,75.90,68.00,0.00,2074
122
+ 30-Jul-17,81.00,64.90,0.00,3331
123
+ 31-Jul-17,88.00,66.90,0.00,3560
124
+ 1-Aug-17,91.00,72.00,0.00,3492
125
+ 2-Aug-17,86.00,69.10,0.09,2637
126
+ 3-Aug-17,86.00,70.00,0.00,3346
127
+ 4-Aug-17,82.90,70.00,0.15,2400
128
+ 5-Aug-17,77.00,70.00,0.30,3409
129
+ 6-Aug-17,75.90,64.00,0.00,3130
130
+ 7-Aug-17,71.10,64.90,0.76,804
131
+ 8-Aug-17,77.00,66.00,0.00,3598
132
+ 9-Aug-17,82.90,66.00,0.00,3893
133
+ 10-Aug-17,82.90,69.10,0.00,3423
134
+ 11-Aug-17,81.00,70.00,0.01,3148
135
+ 12-Aug-17,75.90,64.90,0.11,4146
136
+ 13-Aug-17,82.00,71.10,0.00,3274
137
+ 14-Aug-17,80.10,70.00,0.00,3291
138
+ 15-Aug-17,73.00,69.10,0.45,2149
139
+ 16-Aug-17,84.90,70.00,0.00,3685
140
+ 17-Aug-17,82.00,71.10,0.00,3637
141
+ 18-Aug-17,81.00,73.00,0.88,1064
142
+ 19-Aug-17,84.90,73.00,0.00,4693
143
+ 20-Aug-17,81.00,70.00,0.00,2822
144
+ 21-Aug-17,84.90,73.00,0.00,3088
145
+ 22-Aug-17,88.00,75.00,0.30,2983
146
+ 23-Aug-17,80.10,71.10,0.01,2994
147
+ 24-Aug-17,79.00,66.00,0.00,3688
148
+ 25-Aug-17,78.10,64.00,0.00,3144
149
+ 26-Aug-17,77.00,62.10,0.00,2710
150
+ 27-Aug-17,77.00,63.00,0.00,2676
151
+ 28-Aug-17,75.00,63.00,0.00,3332
152
+ 29-Aug-17,68.00,62.10,0.10,1472
153
+ 30-Aug-17,75.90,61.00,0.01,3468
154
+ 31-Aug-17,81.00,64.00,0.00,3279
155
+ 1-Sep-17,70.00,55.00,0.00,2945
156
+ 2-Sep-17,66.90,54.00,0.53,1876
157
+ 3-Sep-17,69.10,60.10,0.74,1004
158
+ 4-Sep-17,79.00,62.10,0.00,2866
159
+ 5-Sep-17,84.00,70.00,0.01,3244
160
+ 6-Sep-17,70.00,62.10,0.42,1232
161
+ 7-Sep-17,71.10,59.00,0.01,3249
162
+ 8-Sep-17,70.00,59.00,0.00,3234
163
+ 9-Sep-17,69.10,55.00,0.00,2609
164
+ 10-Sep-17,72.00,57.00,0.00,4960
165
+ 11-Sep-17,75.90,55.00,0.00,3657
166
+ 12-Sep-17,78.10,61.00,0.00,3497
167
+ 13-Sep-17,82.00,64.90,0.06,2994
168
+ 14-Sep-17,81.00,70.00,0.02,3013
169
+ 15-Sep-17,81.00,66.90,0.00,3344
170
+ 16-Sep-17,82.00,70.00,0.00,2560
171
+ 17-Sep-17,80.10,70.00,0.00,2676
172
+ 18-Sep-17,73.00,69.10,0.00,2673
173
+ 19-Sep-17,78.10,69.10,0.22,2012
174
+ 20-Sep-17,78.10,71.10,0.00,3296
175
+ 21-Sep-17,80.10,71.10,0.00,3317
176
+ 22-Sep-17,82.00,66.00,0.00,3297
177
+ 23-Sep-17,86.00,68.00,0.00,2810
178
+ 24-Sep-17,90.00,69.10,0.00,2543
179
+ 25-Sep-17,87.10,72.00,0.00,3276
180
+ 26-Sep-17,82.00,69.10,0.00,3157
181
+ 27-Sep-17,84.90,71.10,0.00,3216
182
+ 28-Sep-17,78.10,66.00,0.00,3421
183
+ 29-Sep-17,66.90,55.00,0.00,2988
184
+ 30-Sep-17,64.00,55.90,0.00,1903
185
+ 1-Oct-17,66.90,50.00,0.00,2297
186
+ 2-Oct-17,72.00,52.00,0.00,3387
187
+ 3-Oct-17,70.00,57.00,0.00,3386
188
+ 4-Oct-17,75.00,55.90,0.00,3412
189
+ 5-Oct-17,82.00,64.90,0.00,3312
190
+ 6-Oct-17,81.00,69.10,0.00,2982
191
+ 7-Oct-17,80.10,66.00,0.00,2750
192
+ 8-Oct-17,77.00,72.00,0.22,1235
193
+ 9-Oct-17,75.90,72.00,0.26,898
194
+ 10-Oct-17,80.10,66.00,0.00,3922
195
+ 11-Oct-17,75.00,64.90,0.06,2721
196
+ 12-Oct-17,63.00,55.90,0.07,2411
197
+ 13-Oct-17,64.90,52.00,0.00,2839
198
+ 14-Oct-17,71.10,62.10,0.08,2021
199
+ 15-Oct-17,72.00,66.00,0.01,2169
200
+ 16-Oct-17,60.10,52.00,0.01,2751
201
+ 17-Oct-17,57.90,43.00,0.00,2869
202
+ 18-Oct-17,71.10,50.00,0.00,3264
203
+ 19-Oct-17,70.00,55.90,0.00,3265
204
+ 20-Oct-17,73.00,57.90,0.00,3169
205
+ 21-Oct-17,78.10,57.00,0.00,2538
206
+ 22-Oct-17,75.90,57.00,0.00,2744
207
+ 23-Oct-17,73.90,64.00,0.00,3189
208
+ 24-Oct-17,73.00,66.90,0.20,954
209
+ 25-Oct-17,64.90,57.90,0.00,3367
210
+ 26-Oct-17,57.00,53.10,0.00,2565
211
+ 27-Oct-17,62.10,48.00,0.00,3150
212
+ 28-Oct-17,68.00,55.90,0.00,2245
213
+ 29-Oct-17,64.90,61.00,3.03,183
214
+ 30-Oct-17,55.00,46.00,0.25,1428
215
+ 31-Oct-17,54.00,44.00,0.00,2727
assignment-2/poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
assignment-2/pyproject.toml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "assignment-2"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = ["Artur Janik <aj2614@nyu.edu>"]
6
+ readme = "README.md"
7
+ packages = [{include = "assignment_2"}]
8
+
9
+ [tool.poetry.dependencies]
10
+ python = "^3.10"
11
+ black = "^23.1.0"
12
+ jupyterlab = "^3.6.1"
13
+ ipython = "^8.10.0"
14
+ numpy = "^1.24.2"
15
+ pandas = "^1.5.3"
16
+ jax = "^0.4.4"
17
+ seaborn = "^0.12.2"
18
+ matplotlib = "^3.7.0"
19
+
20
+
21
+ [build-system]
22
+ requires = ["poetry-core"]
23
+ build-backend = "poetry.core.masonry.api"
assignment-2/tests/__init__.py ADDED
File without changes