assignment-2
- assignment-2/README.md +0 -0
- assignment-2/assignment_2/GML.py +305 -0
- assignment-2/assignment_2/Gaussian Maximum Likelihood.ipynb +1682 -0
- assignment-2/assignment_2/__init__.py +0 -0
- assignment-2/data/01_raw/nyc_bb_bicyclist_counts.csv +215 -0
- assignment-2/data/nyc_bb_bicyclist_counts.csv +215 -0
- assignment-2/poetry.lock +0 -0
- assignment-2/pyproject.toml +23 -0
- assignment-2/tests/__init__.py +0 -0
assignment-2/README.md
ADDED
File without changes
assignment-2/assignment_2/GML.py
ADDED
@@ -0,0 +1,305 @@
## imports
import numpy as np
import pandas as pd
from scipy.optimize import minimize
from scipy.stats import norm
import math


## Problem 1
data = [4, 5, 7, 8, 8, 9, 10, 5, 2, 3, 5, 4, 8, 9]

data_mean = np.mean(data)
data_variance = np.var(data)


mu = 0.5
sigma = 0.5
w = np.array([mu, sigma])

w_star = np.array([data_mean, data_variance])
mu_star = data_mean
sigma_star = np.sqrt(data_variance)
offset = 10 * np.random.random(2)

w1p = w_star + 0.5 * offset
w1n = w_star - 0.5 * offset
w2p = w_star + 0.25 * offset
w2n = w_star - 0.25 * offset

# Negative Log Likelihood is defined as follows:
# $-\ln(\frac{1}{\sqrt{2\pi\sigma^2}}\exp(-\frac{1}{2}(\frac{x-\mu}{\sigma})^2))$.
# Ignoring the contribution of the constant, we find that
# $\frac{\delta}{\delta\mu} \mathcal{N} = \frac{\mu-x}{\sigma^2}$ and
# $\frac{\delta}{\delta\sigma} \mathcal{N} = \frac{\sigma^2-(\mu-x)^2}{\sigma^3}$.
# We apply the negatives of these, summed over the data, as the step functions for our SGD.

loss = lambda mu, sigma, x: np.sum(
    [-np.log(norm.pdf(xi, loc=mu, scale=sigma)) for xi in x]
)

# closed-form Gaussian log likelihood (logged as "Loss 2 Alternative" below)
loss_2_alternative = lambda mu, sigma, x: -len(x) / 2 * np.log(
    2 * np.pi * sigma**2
) - 1 / (2 * sigma**2) * np.sum((x - mu) ** 2)


# negatives of the NLL gradients derived above, summed over the data
dmu = lambda mu, sigma, x: -np.sum([mu - xi for xi in x]) / (sigma**2)
dsigma = lambda mu, sigma, x: -len(x) / sigma + np.sum([(mu - xi) ** 2 for xi in x]) / (
    sigma**3
)

log = []


def SGD_problem1(mu, sigma, x, learning_rate=0.01, n_epochs=1000):
    global log
    log = []
    for epoch in range(n_epochs):
        mu += learning_rate * dmu(mu, sigma, x)
        sigma += learning_rate * dsigma(mu, sigma, x)

        # print(f"Epoch {epoch}, Loss: {loss(mu, sigma, x)}, New mu: {mu}, New sigma: {sigma}")
        log.append(
            {
                "Epoch": epoch,
                "Loss": loss(mu, sigma, x),
                "Loss 2 Alternative": loss_2_alternative(mu, sigma, x),
                "New mu": mu,
                "New sigma": sigma,
            }
        )
    return np.array([mu, sigma])


def debug_SGD_1(wnn, data):
    print("SGD Problem 1")
    print("wnn", SGD_problem1(*wnn, data))
    dflog = pd.DataFrame(log)
    dflog["mu_star"] = mu_star
    dflog["mu_std"] = sigma_star
    print(f"mu diff at start {dflog.iloc[0]['New mu'] - dflog.iloc[0]['mu_star']}")
    print(f"mu diff at end {dflog.iloc[-1]['New mu'] - dflog.iloc[-1]['mu_star']}")
    if np.abs(dflog.iloc[-1]["New mu"] - dflog.iloc[-1]["mu_star"]) < np.abs(
        dflog.iloc[0]["New mu"] - dflog.iloc[0]["mu_star"]
    ):
        print("mu is improving")
    else:
        print("mu is not improving")

    print(f"sigma diff at start {dflog.iloc[0]['New sigma'] - dflog.iloc[0]['mu_std']}")
    print(f"sigma diff at end {dflog.iloc[-1]['New sigma'] - dflog.iloc[-1]['mu_std']}")
    if np.abs(dflog.iloc[-1]["New sigma"] - dflog.iloc[-1]["mu_std"]) < np.abs(
        dflog.iloc[0]["New sigma"] - dflog.iloc[0]["mu_std"]
    ):
        print("sigma is improving")
    else:
        print("sigma is not improving")

    return dflog


# _ = debug_SGD_1(w1p, data)
# _ = debug_SGD_1(w1n, data)
# _ = debug_SGD_1(w2p, data)
# _ = debug_SGD_1(w2n, data)


# Why `+=` works in SGD_problem1: dmu and dsigma are the *negatives* of the NLL
# gradients derived above, so adding learning_rate * dmu (resp. dsigma) moves the
# parameters downhill on the loss; it is an ordinary gradient-descent step written
# with pre-negated gradients. A small sanity check against scipy follows.


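# A minimal sanity check (added sketch, not part of the original assignment code): the
# Gaussian MLE has the closed form mu* = sample mean and sigma* = sqrt of the (biased)
# sample variance, and scipy.optimize.minimize on the same NLL should agree with both
# the closed form and the SGD result. The abs() guard on sigma is an assumption to keep
# the optimizer away from non-positive scales.
def check_problem1():
    res = minimize(
        lambda v: loss(v[0], abs(v[1]) + 1e-9, data),
        x0=np.array([1.0, 1.0]),
        method="Nelder-Mead",
    )
    print("closed form:", mu_star, sigma_star)
    print("minimize   :", res.x[0], abs(res.x[1]) + 1e-9)
    print("SGD        :", SGD_problem1(1.0, 1.0, data))

# check_problem1()

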
## Problem 2
x = np.array([8, 16, 22, 33, 50, 51])
y = np.array([5, 20, 14, 32, 42, 58])

# Negatives of the NLL gradients of y ~ N(mx + c, sigma^2), so the += updates below
# descend, as in Problem 1. (A least-squares sanity check follows the SGD run below.)
# $-\frac{n}{\sigma}+\frac{1}{\sigma^3}\sum_{i=1}^n(y_i - (mx_i+c))^2$
dsigma = lambda sigma, c, m, x, y: -len(x) / sigma + np.sum(
    [(yi - (m * xi + c)) ** 2 for xi, yi in zip(x, y)]
) / (sigma**3)
# $\frac{1}{\sigma^2}\sum_{i=1}^n(y_i - (mx_i+c))$
dc = lambda sigma, c, m, x, y: np.sum(
    [yi - (m * xi + c) for xi, yi in zip(x, y)]
) / (sigma**2)
# $\frac{1}{\sigma^2}\sum_{i=1}^n x_i(y_i - (mx_i+c))$
dm = lambda sigma, c, m, x, y: np.sum(
    [xi * (yi - (m * xi + c)) for xi, yi in zip(x, y)]
) / (sigma**2)

# conditional NLL of y given x under y ~ N(m*x + c, sigma^2)
loss2 = lambda sigma, c, m, x, y: np.sum(
    [-np.log(norm.pdf(yi, loc=m * xi + c, scale=sigma)) for xi, yi in zip(x, y)]
)

# note: with these un-scaled x values the default learning_rate=0.01 can overshoot,
# so a smaller step size may be needed for the updates below to converge.

log2 = []


def SGD_problem2(
    sigma: float,
    c: float,
    m: float,
    x: np.array,
    y: np.array,
    learning_rate=0.01,
    n_epochs=1000,
):
    global log2
    log2 = []
    for epoch in range(n_epochs):
        sigma += learning_rate * dsigma(sigma, c, m, x, y)
        c += learning_rate * dc(sigma, c, m, x, y)
        m += learning_rate * dm(sigma, c, m, x, y)

        log2.append(
            {
                "Epoch": epoch,
                "New sigma": sigma,
                "New c": c,
                "New m": m,
                "dc": dc(sigma, c, m, x, y),
                "dm": dm(sigma, c, m, x, y),
                "dsigma": dsigma(sigma, c, m, x, y),
                "Loss": loss2(sigma, c, m, x, y),
            }
        )
        print(f"Epoch {epoch}, Loss: {loss2(sigma, c, m, x, y)}")
    return np.array([sigma, c, m])


# def debug_SGD_2(wnn, data):
#     print("SGD Problem 2")
#     print("wnn", SGD_problem2(*wnn, data))
#     dflog = pd.DataFrame(log)
#     dflog["m_star"] = m_star
#     dflog["c_star"] = c_star
#     dflog["sigma_star"] = sigma_star
#     print(f"m diff at start {dflog.iloc[0]['New m'] - dflog.iloc[0]['m_star']}")
#     print(f"m diff at end {dflog.iloc[-1]['New m'] - dflog.iloc[-1]['m_star']}")
#     if np.abs(dflog.iloc[-1]["New m"] - dflog.iloc[-1]["m_star"]) < np.abs(
#         dflog.iloc[0]["New m"] - dflog.iloc[0]["m_star"]
#     ):
#         print("m is improving")
#     else:
#         print("m is not improving")
#     print(f"c diff at start {dflog.iloc[0]['New c'] - dflog.iloc[0]['c_star']}")
#     print(f"c diff at end {dflog.iloc[-1]['New c'] - dflog.iloc[-1]['c_star']}")
#     if np.abs(dflog.iloc[-1]["New c"] - dflog.iloc[-1]["c_star"]) < np.abs(
#         dflog.iloc[0]["New c"] - dflog.iloc[0]["c_star"]
#     ):
#         print("c is improving")
#     else:
#         print("c is not improving")
#     print(f"sigma diff at start {dflog.iloc[0]['New sigma'] - dflog.iloc[0]['sigma_star']}")
#     print(f"sigma diff at end {dflog.iloc[-1]['New sigma'] - dflog.iloc[-1]['sigma_star']}")
#     if np.abs(dflog.iloc[-1]["New sigma"] - dflog.iloc[-1]["sigma_star"]) < np.abs(
#         dflog.iloc[0]["New sigma"] - dflog.iloc[0]["sigma_star"]
#     ):
#         print("sigma is improving")
#     else:
#         print("sigma is not improving")
#     return dflog

result = SGD_problem2(0.5, 0.5, 0.5, x, y)
print(f"final parameters: m={result[2]}, c={result[1]}, sigma={result[0]}")

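# A minimal sanity check (added sketch, not in the original commit): under a Gaussian
# noise model the MLE of (m, c) coincides with ordinary least squares, and the MLE of
# sigma is the RMS residual, so np.polyfit gives a reference point for SGD_problem2.
def check_problem2():
    m_ref, c_ref = np.polyfit(x, y, 1)         # least-squares slope and intercept
    resid = y - (m_ref * x + c_ref)
    sigma_ref = np.sqrt(np.mean(resid**2))     # closed-form (biased) MLE of sigma
    print("reference   :", m_ref, c_ref, sigma_ref)
    print("SGD estimate:", result[2], result[1], result[0])

# check_problem2()

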
## pset2
# Knowing that the Poisson pmf is $P(k) = \frac{\lambda^k e^{-\lambda}}{k!}$, we can find the
# negative log likelihood of the data as $-\log(\prod_{i=1}^n P(k_i)) = -\sum_{i=1}^n
# \log(\frac{\lambda^{k_i} e^{-\lambda}}{k_i!}) = \sum_{i=1}^n (-k_i \ln(\lambda) + \ln(k_i!) + \lambda)$,
# which simplifies to $n\lambda + \sum_{i=1}^n \ln(k_i!) - \sum_{i=1}^n k_i \ln(\lambda)$.
# Differentiating with respect to $\lambda$ gives $n - \sum_{i=1}^n \frac{k_i}{\lambda}$,
# which is our desired $\frac{\partial L}{\partial \lambda}$!
# (A quick closed-form check follows the SGD runs below.)

import pandas as pd

df = pd.read_csv("../data/01_raw/nyc_bb_bicyclist_counts.csv")

dlambda = lambda l, k: len(k) - np.sum([ki / l for ki in k])


def SGD_problem3(
    l: float,
    k: np.array,
    learning_rate=0.01,
    n_epochs=1000,
):
    global log3
    log3 = []
    for epoch in range(n_epochs):
        l -= learning_rate * dlambda(l, k)
        # $n\lambda + \sum_{i=1}^n \ln(k_i!) - \sum_{i=1}^n k_i \ln(\lambda)$
        # the rest of the loss function is commented out because it's a
        # constant and was causing overflows. It is unnecessary, and a useless
        # pain.
        loss = len(k) * l - np.sum(
            [ki * np.log(l) for ki in k]
        )  # + np.sum([np.log(np.math.factorial(ki)) for ki in k])

        log3.append(
            {
                "Epoch": epoch,
                "New lambda": l,
                "dlambda": dlambda(l, k),
                "Loss": loss,
            }
        )
        print(f"Epoch {epoch}", f"Loss: {loss}")
    return np.array([l])


l_star = df["BB_COUNT"].mean()


def debug_SGD_3(data, l=1000):
    print("SGD Problem 3")
    print(f"l: {SGD_problem3(l, data)}")
    dflog = pd.DataFrame(log3)
    dflog["l_star"] = l_star
    print(f"l diff at start {dflog.iloc[0]['New lambda'] - dflog.iloc[0]['l_star']}")
    print(f"l diff at end {dflog.iloc[-1]['New lambda'] - dflog.iloc[-1]['l_star']}")
    if np.abs(dflog.iloc[-1]["New lambda"] - dflog.iloc[-1]["l_star"]) < np.abs(
        dflog.iloc[0]["New lambda"] - dflog.iloc[0]["l_star"]
    ):
        print("l is improving")
    else:
        print("l is not improving")
    return dflog


debug_SGD_3(data=df["BB_COUNT"].values, l=l_star + 1000)
debug_SGD_3(data=df["BB_COUNT"].values, l=l_star - 1000)

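# A minimal sanity check (added sketch, not in the original commit): the Poisson MLE has
# a closed form, lambda* = mean(k), which is exactly where the derivative
# n - sum(k_i)/lambda derived above vanishes, so dlambda should be ~0 at l_star.
def check_problem3():
    k = df["BB_COUNT"].values
    print("lambda* (sample mean):", l_star)
    print("dlambda at lambda*   :", dlambda(l_star, k))  # expected to be ~0

# check_problem3()

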
## pset 4

# dw = lambda w, x: len(x) * np.exp(np.dot(x, w)) * x - np.sum()

primitive = lambda xi, wi: (x.shape[0] * np.exp(wi * xi) * xi) - (xi**2)
# NOTE: the pairing below was left blank in the original ("for xi, wi in ]");
# zip(xi, w) is an assumed completion so that the file at least parses.
# (A cleaner Poisson-regression gradient sketch follows at the end of this file.)
p_dw = lambda w, xi: np.array([primitive(xi_, wi) for xi_, wi in zip(xi, w)])

def SGD_problem4(
    w: np.array,
    x: np.array,
    learning_rate=0.01,
    n_epochs=1000,
):
    global log4
    log4 = []
    for epoch in range(n_epochs):
        w -= learning_rate * p_dw(w, x)
        # custom
        # loss = x.shape[0] * np.exp(np.dot(x, w))
        loss_fn = lambda k, l: len(k) * l - np.sum(
            [ki * np.log(l) for ki in k]
        )  # + np.sum([np.log(np.math.factorial(ki)) for ki in k])
        loss = loss_fn(x, np.exp(np.dot(x, w)))
        log4.append(
            {
                "Epoch": epoch,
                "New w": w,
                "dw": p_dw(w, x),  # dw itself is commented out above, so log the step function instead
                "Loss": loss,
            }
        )
        print(f"Epoch {epoch}", f"Loss: {loss}")
    return w

def debug_SGD_4(data, w=np.array([1, 1])):
    print("SGD Problem 4")
    print(f"w: {SGD_problem4(w, data)}")
    dflog = pd.DataFrame(log4)
    return dflog


_ = debug_SGD_4(
    data=df[["HIGH_T", "LOW_T", "PRECIP"]].to_numpy(),
    w=np.array([1.0, 1.0, 1.0]),
)
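
# A minimal sketch of the gradient this section appears to be working toward (added,
# not part of the original commit): for a Poisson regression with rate
# lambda_i = exp(x_i . w), the NLL is sum_i exp(x_i . w) - k_i (x_i . w) + const,
# so its gradient is sum_i x_i (exp(x_i . w) - k_i). The feature scaling shown in the
# commented usage is an assumption, only to keep exp() from overflowing.
def poisson_nll_grad(w, X, k):
    rate = np.exp(X @ w)      # lambda_i for every row
    return X.T @ (rate - k)   # sum_i x_i (lambda_i - k_i)

def sgd_poisson(X, k, learning_rate=1e-3, n_epochs=1000):
    w = np.zeros(X.shape[1])
    for _ in range(n_epochs):
        # small step along the averaged negative gradient
        w -= learning_rate * poisson_nll_grad(w, X, k) / len(k)
    return w

# X_demo = df[["HIGH_T", "LOW_T", "PRECIP"]].to_numpy()
# X_demo = (X_demo - X_demo.mean(axis=0)) / X_demo.std(axis=0)
# print(sgd_poisson(X_demo, df["BB_COUNT"].values))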
assignment-2/assignment_2/Gaussian Maximum Likelihood.ipynb
ADDED
@@ -0,0 +1,1682 @@
The notebook (1,682 lines of JSON) mirrors GML.py cell by cell:

Cell 1 (markdown): # MLE of a Gaussian $p_{model}(x|w)$

Cell 2 (code): the imports and Problem 1 setup from GML.py above (data, data_mean, data_variance, the initial mu = sigma = 0.5, w_star, mu_star, sigma_star, the random offset, and the four starting points w1p, w1n, w2p, w2n).

Cell 3 (markdown): Negative Log Likelihood is defined as follows: $-\ln(\frac{1}{\sqrt{2\pi\sigma^2}}\exp(-\frac{1}{2}(\frac{x-\mu}{\sigma})^2))$. Ignoring the contribution of the constant, we find that $\frac{\delta}{\delta\mu} \mathcal{N} = \frac{\mu-x}{\sigma^2}$ and $\frac{\delta}{\delta\sigma} \mathcal{N} = \frac{\sigma^2-(\mu-x)^2}{\sigma^3}$. We apply these as our step functions for our SGD.

Cell 4 (code): loss, loss_2_alternative, dmu, dsigma, SGD_problem1 and debug_SGD_1, identical to the definitions in GML.py above.

Cell 5 (code): runs debug_SGD_1 on w1p, w1n, w2p and w2n. Output:
SGD Problem 1
wnn [6.2142858 2.42541812]
mu diff at start 0.27610721776969527
mu diff at end 8.978893806244059e-08
mu is improving
sigma diff at start 8.134860851821205
sigma diff at end 1.7124079931818414e-12
sigma is improving
SGD Problem 1
wnn [6.21428571 2.42541812]
mu diff at start -0.24923650064862635
mu diff at end -6.602718372050731e-12
mu is improving
sigma diff at start -0.859536014291925
sigma diff at end -3.552713678800501e-15
sigma is improving
SGD Problem 1
wnn [6.21428572 2.42541812]
mu diff at start 0.13794086144778994
mu diff at end 1.0008935902305893e-09
mu is improving
sigma diff at start 5.786783512688555
sigma diff at end 4.440892098500626e-15
sigma is improving
SGD Problem 1
wnn [6.21428571 2.42541812]
mu diff at start -0.13668036978891251
mu diff at end -8.528289185960602e-12
mu is improving
sigma diff at start 1.091241177336173
sigma diff at end 4.440892098500626e-15
sigma is improving
Cell 6 (markdown): # MLE of Conditional Gaussian

Cell 7 (markdown):
dsigma = $-\frac{n}{\sigma}+\frac{1}{\sigma^3}\sum_{i=1}^n(y_i - (mx_i+c))^2$
dc = $-\frac{1}{\sigma^2}\sum_{i=1}^n(y_i - (mx_i+c))$
dm = $-\frac{1}{\sigma^2}\sum_{i=1}^n x_i(y_i - (mx_i+c))$
Cell 8 (code, long output collapsed): runs SGD_problem2 on the Problem 2 data and prints the loss every epoch. The printed loss decreases very slowly, from "Epoch 0, Loss: 297.82677563555086" to "Epoch 642, Loss: 297.82633133022796", where the diff preview is truncated.
|
865 |
+
"Epoch 643, Loss: 297.82633066196473\n",
|
866 |
+
"Epoch 644, Loss: 297.82632999378137\n",
|
867 |
+
"Epoch 645, Loss: 297.8263293256778\n",
|
868 |
+
"Epoch 646, Loss: 297.8263286576541\n",
|
869 |
+
"Epoch 647, Loss: 297.8263279897103\n",
|
870 |
+
"Epoch 648, Loss: 297.82632732184646\n",
|
871 |
+
"Epoch 649, Loss: 297.8263266540626\n",
|
872 |
+
"Epoch 650, Loss: 297.8263259863587\n",
|
873 |
+
"Epoch 651, Loss: 297.8263253187347\n",
|
874 |
+
"Epoch 652, Loss: 297.82632465119093\n",
|
875 |
+
"Epoch 653, Loss: 297.82632398372704\n",
|
876 |
+
"Epoch 654, Loss: 297.8263233163434\n",
|
877 |
+
"Epoch 655, Loss: 297.8263226490398\n",
|
878 |
+
"Epoch 656, Loss: 297.82632198181636\n",
|
879 |
+
"Epoch 657, Loss: 297.82632131467324\n",
|
880 |
+
"Epoch 658, Loss: 297.82632064761026\n",
|
881 |
+
"Epoch 659, Loss: 297.82631998062743\n",
|
882 |
+
"Epoch 660, Loss: 297.8263193137249\n",
|
883 |
+
"Epoch 661, Loss: 297.8263186469027\n",
|
884 |
+
"Epoch 662, Loss: 297.82631798016087\n",
|
885 |
+
"Epoch 663, Loss: 297.8263173134994\n",
|
886 |
+
"Epoch 664, Loss: 297.82631664691814\n",
|
887 |
+
"Epoch 665, Loss: 297.82631598041746\n",
|
888 |
+
"Epoch 666, Loss: 297.8263153139972\n",
|
889 |
+
"Epoch 667, Loss: 297.82631464765734\n",
|
890 |
+
"Epoch 668, Loss: 297.8263139813981\n",
|
891 |
+
"Epoch 669, Loss: 297.82631331521924\n",
|
892 |
+
"Epoch 670, Loss: 297.82631264912106\n",
|
893 |
+
"Epoch 671, Loss: 297.82631198310344\n",
|
894 |
+
"Epoch 672, Loss: 297.8263113171664\n",
|
895 |
+
"Epoch 673, Loss: 297.82631065131005\n",
|
896 |
+
"Epoch 674, Loss: 297.8263099855343\n",
|
897 |
+
"Epoch 675, Loss: 297.82630931983937\n",
|
898 |
+
"Epoch 676, Loss: 297.82630865422504\n",
|
899 |
+
"Epoch 677, Loss: 297.8263079886915\n",
|
900 |
+
"Epoch 678, Loss: 297.8263073232387\n",
|
901 |
+
"Epoch 679, Loss: 297.8263066578669\n",
|
902 |
+
"Epoch 680, Loss: 297.82630599257584\n",
|
903 |
+
"Epoch 681, Loss: 297.8263053273656\n",
|
904 |
+
"Epoch 682, Loss: 297.82630466223634\n",
|
905 |
+
"Epoch 683, Loss: 297.8263039971879\n",
|
906 |
+
"Epoch 684, Loss: 297.82630333222056\n",
|
907 |
+
"Epoch 685, Loss: 297.8263026673342\n",
|
908 |
+
"Epoch 686, Loss: 297.8263020025288\n",
|
909 |
+
"Epoch 687, Loss: 297.82630133780435\n",
|
910 |
+
"Epoch 688, Loss: 297.82630067316114\n",
|
911 |
+
"Epoch 689, Loss: 297.826300008599\n",
|
912 |
+
"Epoch 690, Loss: 297.8262993441179\n",
|
913 |
+
"Epoch 691, Loss: 297.8262986797181\n",
|
914 |
+
"Epoch 692, Loss: 297.8262980153994\n",
|
915 |
+
"Epoch 693, Loss: 297.82629735116194\n",
|
916 |
+
"Epoch 694, Loss: 297.82629668700577\n",
|
917 |
+
"Epoch 695, Loss: 297.8262960229308\n",
|
918 |
+
"Epoch 696, Loss: 297.8262953589371\n",
|
919 |
+
"Epoch 697, Loss: 297.82629469502484\n",
|
920 |
+
"Epoch 698, Loss: 297.826294031194\n",
|
921 |
+
"Epoch 699, Loss: 297.8262933674444\n",
|
922 |
+
"Epoch 700, Loss: 297.8262927037763\n",
|
923 |
+
"Epoch 701, Loss: 297.82629204018974\n",
|
924 |
+
"Epoch 702, Loss: 297.8262913766845\n",
|
925 |
+
"Epoch 703, Loss: 297.82629071326073\n",
|
926 |
+
"Epoch 704, Loss: 297.82629004991867\n",
|
927 |
+
"Epoch 705, Loss: 297.82628938665823\n",
|
928 |
+
"Epoch 706, Loss: 297.8262887234792\n",
|
929 |
+
"Epoch 707, Loss: 297.8262880603819\n",
|
930 |
+
"Epoch 708, Loss: 297.82628739736623\n",
|
931 |
+
"Epoch 709, Loss: 297.8262867344323\n",
|
932 |
+
"Epoch 710, Loss: 297.82628607158\n",
|
933 |
+
"Epoch 711, Loss: 297.8262854088095\n",
|
934 |
+
"Epoch 712, Loss: 297.8262847461207\n",
|
935 |
+
"Epoch 713, Loss: 297.82628408351377\n",
|
936 |
+
"Epoch 714, Loss: 297.8262834209886\n",
|
937 |
+
"Epoch 715, Loss: 297.8262827585453\n",
|
938 |
+
"Epoch 716, Loss: 297.8262820961839\n",
|
939 |
+
"Epoch 717, Loss: 297.8262814339045\n",
|
940 |
+
"Epoch 718, Loss: 297.826280771707\n",
|
941 |
+
"Epoch 719, Loss: 297.8262801095915\n",
|
942 |
+
"Epoch 720, Loss: 297.82627944755797\n",
|
943 |
+
"Epoch 721, Loss: 297.8262787856065\n",
|
944 |
+
"Epoch 722, Loss: 297.82627812373715\n",
|
945 |
+
"Epoch 723, Loss: 297.8262774619497\n",
|
946 |
+
"Epoch 724, Loss: 297.82627680024444\n",
|
947 |
+
"Epoch 725, Loss: 297.8262761386215\n",
|
948 |
+
"Epoch 726, Loss: 297.8262754770806\n",
|
949 |
+
"Epoch 727, Loss: 297.826274815622\n",
|
950 |
+
"Epoch 728, Loss: 297.82627415424554\n",
|
951 |
+
"Epoch 729, Loss: 297.82627349295143\n",
|
952 |
+
"Epoch 730, Loss: 297.8262728317395\n",
|
953 |
+
"Epoch 731, Loss: 297.82627217061\n",
|
954 |
+
"Epoch 732, Loss: 297.82627150956284\n",
|
955 |
+
"Epoch 733, Loss: 297.8262708485981\n",
|
956 |
+
"Epoch 734, Loss: 297.82627018771575\n",
|
957 |
+
"Epoch 735, Loss: 297.8262695269158\n",
|
958 |
+
"Epoch 736, Loss: 297.8262688661983\n",
|
959 |
+
"Epoch 737, Loss: 297.82626820556334\n",
|
960 |
+
"Epoch 738, Loss: 297.826267545011\n",
|
961 |
+
"Epoch 739, Loss: 297.8262668845411\n",
|
962 |
+
"Epoch 740, Loss: 297.82626622415387\n",
|
963 |
+
"Epoch 741, Loss: 297.82626556384935\n",
|
964 |
+
"Epoch 742, Loss: 297.82626490362736\n",
|
965 |
+
"Epoch 743, Loss: 297.826264243488\n",
|
966 |
+
"Epoch 744, Loss: 297.82626358343146\n",
|
967 |
+
"Epoch 745, Loss: 297.82626292345765\n",
|
968 |
+
"Epoch 746, Loss: 297.82626226356655\n",
|
969 |
+
"Epoch 747, Loss: 297.8262616037582\n",
|
970 |
+
"Epoch 748, Loss: 297.8262609440329\n",
|
971 |
+
"Epoch 749, Loss: 297.8262602843903\n",
|
972 |
+
"Epoch 750, Loss: 297.82625962483047\n",
|
973 |
+
"Epoch 751, Loss: 297.82625896535376\n",
|
974 |
+
"Epoch 752, Loss: 297.82625830595987\n",
|
975 |
+
"Epoch 753, Loss: 297.8262576466491\n",
|
976 |
+
"Epoch 754, Loss: 297.82625698742123\n",
|
977 |
+
"Epoch 755, Loss: 297.82625632827643\n",
|
978 |
+
"Epoch 756, Loss: 297.8262556692147\n",
|
979 |
+
"Epoch 757, Loss: 297.82625501023597\n",
|
980 |
+
"Epoch 758, Loss: 297.8262543513405\n",
|
981 |
+
"Epoch 759, Loss: 297.8262536925281\n",
|
982 |
+
"Epoch 760, Loss: 297.82625303379893\n",
|
983 |
+
"Epoch 761, Loss: 297.8262523751529\n",
|
984 |
+
"Epoch 762, Loss: 297.82625171659015\n",
|
985 |
+
"Epoch 763, Loss: 297.82625105811076\n",
|
986 |
+
"Epoch 764, Loss: 297.8262503997146\n",
|
987 |
+
"Epoch 765, Loss: 297.82624974140174\n",
|
988 |
+
"Epoch 766, Loss: 297.82624908317234\n",
|
989 |
+
"Epoch 767, Loss: 297.8262484250262\n",
|
990 |
+
"Epoch 768, Loss: 297.82624776696355\n",
|
991 |
+
"Epoch 769, Loss: 297.8262471089844\n",
|
992 |
+
"Epoch 770, Loss: 297.82624645108865\n",
|
993 |
+
"Epoch 771, Loss: 297.8262457932765\n",
|
994 |
+
"Epoch 772, Loss: 297.82624513554777\n",
|
995 |
+
"Epoch 773, Loss: 297.8262444779027\n",
|
996 |
+
"Epoch 774, Loss: 297.8262438203412\n",
|
997 |
+
"Epoch 775, Loss: 297.8262431628633\n",
|
998 |
+
"Epoch 776, Loss: 297.8262425054691\n",
|
999 |
+
"Epoch 777, Loss: 297.82624184815865\n",
|
1000 |
+
"Epoch 778, Loss: 297.82624119093174\n",
|
1001 |
+
"Epoch 779, Loss: 297.82624053378873\n",
|
1002 |
+
"Epoch 780, Loss: 297.82623987672946\n",
|
1003 |
+
"Epoch 781, Loss: 297.826239219754\n",
|
1004 |
+
"Epoch 782, Loss: 297.8262385628624\n",
|
1005 |
+
"Epoch 783, Loss: 297.82623790605464\n",
|
1006 |
+
"Epoch 784, Loss: 297.82623724933075\n",
|
1007 |
+
"Epoch 785, Loss: 297.82623659269086\n",
|
1008 |
+
"Epoch 786, Loss: 297.8262359361348\n",
|
1009 |
+
"Epoch 787, Loss: 297.8262352796629\n",
|
1010 |
+
"Epoch 788, Loss: 297.8262346232749\n",
|
1011 |
+
"Epoch 789, Loss: 297.8262339669709\n",
|
1012 |
+
"Epoch 790, Loss: 297.8262333107511\n",
|
1013 |
+
"Epoch 791, Loss: 297.82623265461535\n",
|
1014 |
+
"Epoch 792, Loss: 297.8262319985638\n",
|
1015 |
+
"Epoch 793, Loss: 297.8262313425964\n",
|
1016 |
+
"Epoch 794, Loss: 297.8262306867131\n",
|
1017 |
+
"Epoch 795, Loss: 297.8262300309141\n",
|
1018 |
+
"Epoch 796, Loss: 297.82622937519943\n",
|
1019 |
+
"Epoch 797, Loss: 297.826228719569\n",
|
1020 |
+
"Epoch 798, Loss: 297.82622806402287\n",
|
1021 |
+
"Epoch 799, Loss: 297.82622740856107\n",
|
1022 |
+
"Epoch 800, Loss: 297.8262267531837\n",
|
1023 |
+
"Epoch 801, Loss: 297.82622609789064\n",
|
1024 |
+
"Epoch 802, Loss: 297.826225442682\n",
|
1025 |
+
"Epoch 803, Loss: 297.826224787558\n",
|
1026 |
+
"Epoch 804, Loss: 297.8262241325183\n",
|
1027 |
+
"Epoch 805, Loss: 297.8262234775633\n",
|
1028 |
+
"Epoch 806, Loss: 297.8262228226928\n",
|
1029 |
+
"Epoch 807, Loss: 297.82622216790685\n",
|
1030 |
+
"Epoch 808, Loss: 297.8262215132056\n",
|
1031 |
+
"Epoch 809, Loss: 297.8262208585888\n",
|
1032 |
+
"Epoch 810, Loss: 297.8262202040569\n",
|
1033 |
+
"Epoch 811, Loss: 297.8262195496096\n",
|
1034 |
+
"Epoch 812, Loss: 297.82621889524705\n",
|
1035 |
+
"Epoch 813, Loss: 297.82621824096935\n",
|
1036 |
+
"Epoch 814, Loss: 297.8262175867764\n",
|
1037 |
+
"Epoch 815, Loss: 297.8262169326682\n",
|
1038 |
+
"Epoch 816, Loss: 297.82621627864495\n",
|
1039 |
+
"Epoch 817, Loss: 297.82621562470666\n",
|
1040 |
+
"Epoch 818, Loss: 297.8262149708532\n",
|
1041 |
+
"Epoch 819, Loss: 297.82621431708463\n",
|
1042 |
+
"Epoch 820, Loss: 297.8262136634011\n",
|
1043 |
+
"Epoch 821, Loss: 297.82621300980264\n",
|
1044 |
+
"Epoch 822, Loss: 297.82621235628915\n",
|
1045 |
+
"Epoch 823, Loss: 297.82621170286075\n",
|
1046 |
+
"Epoch 824, Loss: 297.8262110495175\n",
|
1047 |
+
"Epoch 825, Loss: 297.8262103962594\n",
|
1048 |
+
"Epoch 826, Loss: 297.8262097430863\n",
|
1049 |
+
"Epoch 827, Loss: 297.8262090899985\n",
|
1050 |
+
"Epoch 828, Loss: 297.82620843699596\n",
|
1051 |
+
"Epoch 829, Loss: 297.8262077840786\n",
|
1052 |
+
"Epoch 830, Loss: 297.8262071312466\n",
|
1053 |
+
"Epoch 831, Loss: 297.82620647849984\n",
|
1054 |
+
"Epoch 832, Loss: 297.8262058258384\n",
|
1055 |
+
"Epoch 833, Loss: 297.82620517326245\n",
|
1056 |
+
"Epoch 834, Loss: 297.8262045207719\n",
|
1057 |
+
"Epoch 835, Loss: 297.82620386836675\n",
|
1058 |
+
"Epoch 836, Loss: 297.82620321604696\n",
|
1059 |
+
"Epoch 837, Loss: 297.8262025638128\n",
|
1060 |
+
"Epoch 838, Loss: 297.82620191166416\n",
|
1061 |
+
"Epoch 839, Loss: 297.8262012596011\n",
|
1062 |
+
"Epoch 840, Loss: 297.82620060762355\n",
|
1063 |
+
"Epoch 841, Loss: 297.8261999557316\n",
|
1064 |
+
"Epoch 842, Loss: 297.8261993039254\n",
|
1065 |
+
"Epoch 843, Loss: 297.8261986522048\n",
|
1066 |
+
"Epoch 844, Loss: 297.82619800056995\n",
|
1067 |
+
"Epoch 845, Loss: 297.8261973490208\n",
|
1068 |
+
"Epoch 846, Loss: 297.82619669755746\n",
|
1069 |
+
"Epoch 847, Loss: 297.8261960461799\n",
|
1070 |
+
"Epoch 848, Loss: 297.82619539488826\n",
|
1071 |
+
"Epoch 849, Loss: 297.82619474368244\n",
|
1072 |
+
"Epoch 850, Loss: 297.82619409256245\n",
|
1073 |
+
"Epoch 851, Loss: 297.8261934415284\n",
|
1074 |
+
"Epoch 852, Loss: 297.82619279058036\n",
|
1075 |
+
"Epoch 853, Loss: 297.82619213971833\n",
|
1076 |
+
"Epoch 854, Loss: 297.8261914889422\n",
|
1077 |
+
"Epoch 855, Loss: 297.82619083825216\n",
|
1078 |
+
"Epoch 856, Loss: 297.82619018764836\n",
|
1079 |
+
"Epoch 857, Loss: 297.8261895371304\n",
|
1080 |
+
"Epoch 858, Loss: 297.82618888669873\n",
|
1081 |
+
"Epoch 859, Loss: 297.8261882363531\n",
|
1082 |
+
"Epoch 860, Loss: 297.8261875860939\n",
|
1083 |
+
"Epoch 861, Loss: 297.8261869359208\n",
|
1084 |
+
"Epoch 862, Loss: 297.826186285834\n",
|
1085 |
+
"Epoch 863, Loss: 297.82618563583344\n",
|
1086 |
+
"Epoch 864, Loss: 297.8261849859192\n",
|
1087 |
+
"Epoch 865, Loss: 297.8261843360914\n",
|
1088 |
+
"Epoch 866, Loss: 297.82618368634996\n",
|
1089 |
+
"Epoch 867, Loss: 297.8261830366949\n",
|
1090 |
+
"Epoch 868, Loss: 297.82618238712627\n",
|
1091 |
+
"Epoch 869, Loss: 297.82618173764416\n",
|
1092 |
+
"Epoch 870, Loss: 297.8261810882486\n",
|
1093 |
+
"Epoch 871, Loss: 297.82618043893956\n",
|
1094 |
+
"Epoch 872, Loss: 297.826179789717\n",
|
1095 |
+
"Epoch 873, Loss: 297.8261791405811\n",
|
1096 |
+
"Epoch 874, Loss: 297.82617849153183\n",
|
1097 |
+
"Epoch 875, Loss: 297.82617784256917\n",
|
1098 |
+
"Epoch 876, Loss: 297.8261771936933\n",
|
1099 |
+
"Epoch 877, Loss: 297.8261765449042\n",
|
1100 |
+
"Epoch 878, Loss: 297.8261758962016\n",
|
1101 |
+
"Epoch 879, Loss: 297.826175247586\n",
|
1102 |
+
"Epoch 880, Loss: 297.82617459905725\n",
|
1103 |
+
"Epoch 881, Loss: 297.82617395061516\n",
|
1104 |
+
"Epoch 882, Loss: 297.8261733022601\n",
|
1105 |
+
"Epoch 883, Loss: 297.82617265399193\n",
|
1106 |
+
"Epoch 884, Loss: 297.82617200581075\n",
|
1107 |
+
"Epoch 885, Loss: 297.8261713577164\n",
|
1108 |
+
"Epoch 886, Loss: 297.8261707097091\n",
|
1109 |
+
"Epoch 887, Loss: 297.82617006178884\n",
|
1110 |
+
"Epoch 888, Loss: 297.8261694139557\n",
|
1111 |
+
"Epoch 889, Loss: 297.82616876620966\n",
|
1112 |
+
"Epoch 890, Loss: 297.82616811855064\n",
|
1113 |
+
"Epoch 891, Loss: 297.8261674709788\n",
|
1114 |
+
"Epoch 892, Loss: 297.82616682349425\n",
|
1115 |
+
"Epoch 893, Loss: 297.8261661760969\n",
|
1116 |
+
"Epoch 894, Loss: 297.82616552878676\n",
|
1117 |
+
"Epoch 895, Loss: 297.82616488156384\n",
|
1118 |
+
"Epoch 896, Loss: 297.8261642344284\n",
|
1119 |
+
"Epoch 897, Loss: 297.8261635873801\n",
|
1120 |
+
"Epoch 898, Loss: 297.82616294041935\n",
|
1121 |
+
"Epoch 899, Loss: 297.8261622935459\n",
|
1122 |
+
"Epoch 900, Loss: 297.82616164676\n",
|
1123 |
+
"Epoch 901, Loss: 297.8261610000614\n",
|
1124 |
+
"Epoch 902, Loss: 297.8261603534504\n",
|
1125 |
+
"Epoch 903, Loss: 297.82615970692694\n",
|
1126 |
+
"Epoch 904, Loss: 297.826159060491\n",
|
1127 |
+
"Epoch 905, Loss: 297.82615841414264\n",
|
1128 |
+
"Epoch 906, Loss: 297.82615776788197\n",
|
1129 |
+
"Epoch 907, Loss: 297.826157121709\n",
|
1130 |
+
"Epoch 908, Loss: 297.82615647562363\n",
|
1131 |
+
"Epoch 909, Loss: 297.8261558296259\n",
|
1132 |
+
"Epoch 910, Loss: 297.8261551837161\n",
|
1133 |
+
"Epoch 911, Loss: 297.826154537894\n",
|
1134 |
+
"Epoch 912, Loss: 297.82615389215977\n",
|
1135 |
+
"Epoch 913, Loss: 297.82615324651323\n",
|
1136 |
+
"Epoch 914, Loss: 297.8261526009547\n",
|
1137 |
+
"Epoch 915, Loss: 297.826151955484\n",
|
1138 |
+
"Epoch 916, Loss: 297.8261513101013\n",
|
1139 |
+
"Epoch 917, Loss: 297.8261506648065\n",
|
1140 |
+
"Epoch 918, Loss: 297.8261500195997\n",
|
1141 |
+
"Epoch 919, Loss: 297.826149374481\n",
|
1142 |
+
"Epoch 920, Loss: 297.82614872945044\n",
|
1143 |
+
"Epoch 921, Loss: 297.8261480845078\n",
|
1144 |
+
"Epoch 922, Loss: 297.8261474396533\n",
|
1145 |
+
"Epoch 923, Loss: 297.8261467948871\n",
|
1146 |
+
"Epoch 924, Loss: 297.826146150209\n",
|
1147 |
+
"Epoch 925, Loss: 297.82614550561914\n",
|
1148 |
+
"Epoch 926, Loss: 297.8261448611175\n",
|
1149 |
+
"Epoch 927, Loss: 297.82614421670417\n",
|
1150 |
+
"Epoch 928, Loss: 297.8261435723791\n",
|
1151 |
+
"Epoch 929, Loss: 297.8261429281424\n",
|
1152 |
+
"Epoch 930, Loss: 297.8261422839941\n",
|
1153 |
+
"Epoch 931, Loss: 297.82614163993424\n",
|
1154 |
+
"Epoch 932, Loss: 297.8261409959628\n",
|
1155 |
+
"Epoch 933, Loss: 297.82614035207973\n",
|
1156 |
+
"Epoch 934, Loss: 297.82613970828527\n",
|
1157 |
+
"Epoch 935, Loss: 297.8261390645793\n",
|
1158 |
+
"Epoch 936, Loss: 297.826138420962\n",
|
1159 |
+
"Epoch 937, Loss: 297.8261377774331\n",
|
1160 |
+
"Epoch 938, Loss: 297.82613713399303\n",
|
1161 |
+
"Epoch 939, Loss: 297.82613649064143\n",
|
1162 |
+
"Epoch 940, Loss: 297.8261358473787\n",
|
1163 |
+
"Epoch 941, Loss: 297.8261352042046\n",
|
1164 |
+
"Epoch 942, Loss: 297.8261345611193\n",
|
1165 |
+
"Epoch 943, Loss: 297.8261339181227\n",
|
1166 |
+
"Epoch 944, Loss: 297.826133275215\n",
|
1167 |
+
"Epoch 945, Loss: 297.82613263239614\n",
|
1168 |
+
"Epoch 946, Loss: 297.8261319896662\n",
|
1169 |
+
"Epoch 947, Loss: 297.82613134702507\n",
|
1170 |
+
"Epoch 948, Loss: 297.826130704473\n",
|
1171 |
+
"Epoch 949, Loss: 297.8261300620098\n",
|
1172 |
+
"Epoch 950, Loss: 297.8261294196356\n",
|
1173 |
+
"Epoch 951, Loss: 297.8261287773505\n",
|
1174 |
+
"Epoch 952, Loss: 297.8261281351545\n",
|
1175 |
+
"Epoch 953, Loss: 297.8261274930475\n",
|
1176 |
+
"Epoch 954, Loss: 297.82612685102976\n",
|
1177 |
+
"Epoch 955, Loss: 297.8261262091011\n",
|
1178 |
+
"Epoch 956, Loss: 297.82612556726167\n",
|
1179 |
+
"Epoch 957, Loss: 297.8261249255115\n",
|
1180 |
+
"Epoch 958, Loss: 297.8261242838506\n",
|
1181 |
+
"Epoch 959, Loss: 297.8261236422789\n",
|
1182 |
+
"Epoch 960, Loss: 297.8261230007966\n",
|
1183 |
+
"Epoch 961, Loss: 297.8261223594037\n",
|
1184 |
+
"Epoch 962, Loss: 297.82612171810007\n",
|
1185 |
+
"Epoch 963, Loss: 297.82612107688584\n",
|
1186 |
+
"Epoch 964, Loss: 297.8261204357612\n",
|
1187 |
+
"Epoch 965, Loss: 297.82611979472597\n",
|
1188 |
+
"Epoch 966, Loss: 297.8261191537804\n",
|
1189 |
+
"Epoch 967, Loss: 297.82611851292415\n",
|
1190 |
+
"Epoch 968, Loss: 297.8261178721576\n",
|
1191 |
+
"Epoch 969, Loss: 297.8261172314807\n",
|
1192 |
+
"Epoch 970, Loss: 297.8261165908933\n",
|
1193 |
+
"Epoch 971, Loss: 297.82611595039566\n",
|
1194 |
+
"Epoch 972, Loss: 297.8261153099878\n",
|
1195 |
+
"Epoch 973, Loss: 297.82611466966955\n",
|
1196 |
+
"Epoch 974, Loss: 297.8261140294412\n",
|
1197 |
+
"Epoch 975, Loss: 297.8261133893026\n",
|
1198 |
+
"Epoch 976, Loss: 297.82611274925375\n",
|
1199 |
+
"Epoch 977, Loss: 297.82611210929485\n",
|
1200 |
+
"Epoch 978, Loss: 297.82611146942594\n",
|
1201 |
+
"Epoch 979, Loss: 297.8261108296469\n",
|
1202 |
+
"Epoch 980, Loss: 297.8261101899578\n",
|
1203 |
+
"Epoch 981, Loss: 297.8261095503587\n",
|
1204 |
+
"Epoch 982, Loss: 297.8261089108496\n",
|
1205 |
+
"Epoch 983, Loss: 297.82610827143054\n",
|
1206 |
+
"Epoch 984, Loss: 297.8261076321016\n",
|
1207 |
+
"Epoch 985, Loss: 297.8261069928628\n",
|
1208 |
+
"Epoch 986, Loss: 297.8261063537141\n",
|
1209 |
+
"Epoch 987, Loss: 297.8261057146557\n",
|
1210 |
+
"Epoch 988, Loss: 297.8261050756874\n",
|
1211 |
+
"Epoch 989, Loss: 297.82610443680943\n",
|
1212 |
+
"Epoch 990, Loss: 297.82610379802173\n",
|
1213 |
+
"Epoch 991, Loss: 297.82610315932436\n",
|
1214 |
+
"Epoch 992, Loss: 297.8261025207173\n",
|
1215 |
+
"Epoch 993, Loss: 297.8261018822007\n",
|
1216 |
+
"Epoch 994, Loss: 297.8261012437744\n",
|
1217 |
+
"Epoch 995, Loss: 297.8261006054387\n",
|
1218 |
+
"Epoch 996, Loss: 297.82609996719333\n",
|
1219 |
+
"Epoch 997, Loss: 297.8260993290385\n",
|
1220 |
+
"Epoch 998, Loss: 297.8260986909742\n",
|
1221 |
+
"Epoch 999, Loss: 297.8260980530006\n",
|
1222 |
+
"final parameters: m=0.45136980910052144, c=0.49775672565271384, sigma=1562.2616856027405\n"
|
1223 |
+
]
|
1224 |
+
}
|
1225 |
+
],
|
1226 |
+
"source": [
|
1227 |
+
"## Problem 2\n",
|
1228 |
+
"x = np.array([8, 16, 22, 33, 50, 51])\n",
|
1229 |
+
"y = np.array([5, 20, 14, 32, 42, 58])\n",
|
1230 |
+
"\n",
|
1231 |
+
"# $-\\frac{n}{\\sigma}+\\frac{1}{\\sigma^3}\\sum_{i=1}^n(y_i - (mx+c))^2$\n",
|
1232 |
+
"dsigma = lambda sigma, c, m, x: -len(x) / sigma + np.sum(\n",
|
1233 |
+
" [(xi - (m * x + c)) ** 2 for xi in x]\n",
|
1234 |
+
") / (sigma**3)\n",
|
1235 |
+
"# $-\\frac{1}{\\sigma^2}\\sum_{i=1}^n(y_i - (mx+c))$\n",
|
1236 |
+
"dc = lambda sigma, c, m, x: -np.sum([xi - (m * x + c) for xi in x]) / (sigma**2)\n",
|
1237 |
+
"# $-\\frac{1}{\\sigma^2}\\sum_{i=1}^n(x_i(y_i - (mx+c)))$\n",
|
1238 |
+
"dm = lambda sigma, c, m, x: -np.sum([x * (xi - (m * x + c)) for xi in x]) / (sigma**2)\n",
|
1239 |
+
"\n",
|
1240 |
+
"\n",
|
1241 |
+
"log2 = []\n",
|
1242 |
+
"\n",
|
1243 |
+
"\n",
|
1244 |
+
"def SGD_problem2(\n",
|
1245 |
+
" sigma: float,\n",
|
1246 |
+
" c: float,\n",
|
1247 |
+
" m: float,\n",
|
1248 |
+
" x: np.array,\n",
|
1249 |
+
" y: np.array,\n",
|
1250 |
+
" learning_rate=0.01,\n",
|
1251 |
+
" n_epochs=1000,\n",
|
1252 |
+
"):\n",
|
1253 |
+
" global log2\n",
|
1254 |
+
" log2 = []\n",
|
1255 |
+
" for epoch in range(n_epochs):\n",
|
1256 |
+
" sigma += learning_rate * dsigma(sigma, c, m, x)\n",
|
1257 |
+
" c += learning_rate * dc(sigma, c, m, x)\n",
|
1258 |
+
" m += learning_rate * dm(sigma, c, m, x)\n",
|
1259 |
+
"\n",
|
1260 |
+
" log2.append(\n",
|
1261 |
+
" {\n",
|
1262 |
+
" \"Epoch\": epoch,\n",
|
1263 |
+
" \"New sigma\": sigma,\n",
|
1264 |
+
" \"New c\": c,\n",
|
1265 |
+
" \"New m\": m,\n",
|
1266 |
+
" \"dc\": dc(sigma, c, m, x),\n",
|
1267 |
+
" \"dm\": dm(sigma, c, m, x),\n",
|
1268 |
+
" \"dsigma\": dsigma(sigma, c, m, x),\n",
|
1269 |
+
" \"Loss\": loss((m * x + c), sigma, y),\n",
|
1270 |
+
" }\n",
|
1271 |
+
" )\n",
|
1272 |
+
" print(f\"Epoch {epoch}, Loss: {loss((m * x + c), sigma, y)}\")\n",
|
1273 |
+
" return np.array([sigma, c, m])\n",
|
1274 |
+
"\n",
|
1275 |
+
"\n",
|
1276 |
+
"result = SGD_problem2(0.5, 0.5, 0.5, x, y)\n",
|
1277 |
+
"print(f\"final parameters: m={result[2]}, c={result[1]}, sigma={result[0]}\")"
|
1278 |
+
]
|
1279 |
+
},
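As a cross-check on the gradients above, here is a minimal sketch (not one of the notebook's cells) that fits the same model y_i ~ N(m*x_i + c, sigma^2) by minimizing the negative log likelihood with scipy.optimize.minimize, with the analytic gradients written over explicit (x_i, y_i) pairs. The log-sigma parametrization and the starting point are assumptions.

```python
import numpy as np
from scipy.optimize import minimize

x = np.array([8, 16, 22, 33, 50, 51], dtype=float)
y = np.array([5, 20, 14, 32, 42, 58], dtype=float)

def nll(params, x, y):
    # negative log likelihood of y_i ~ N(m*x_i + c, sigma^2)
    m, c, log_sigma = params
    sigma = np.exp(log_sigma)                 # keeps sigma positive (an assumption)
    r = y - (m * x + c)                       # residuals y_i - (m*x_i + c)
    return len(x) / 2 * np.log(2 * np.pi * sigma**2) + np.sum(r**2) / (2 * sigma**2)

def grad_nll(params, x, y):
    m, c, log_sigma = params
    sigma = np.exp(log_sigma)
    r = y - (m * x + c)
    dm = -np.sum(x * r) / sigma**2            # dNLL/dm
    dc = -np.sum(r) / sigma**2                # dNLL/dc
    dlog_sigma = len(x) - np.sum(r**2) / sigma**2   # dNLL/dlog(sigma)
    return np.array([dm, dc, dlog_sigma])

result = minimize(nll, x0=np.array([0.5, 0.5, 0.0]), args=(x, y), jac=grad_nll)
m_hat, c_hat, sigma_hat = result.x[0], result.x[1], np.exp(result.x[2])
# the MLE should track the least-squares line and the residual standard deviation
print(m_hat, c_hat, sigma_hat)
```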
|
1280 |
+
{
|
1281 |
+
"cell_type": "markdown",
|
1282 |
+
"id": "0562b012-f4ca-47de-bc76-e0eb2bf1e509",
|
1283 |
+
"metadata": {},
|
1284 |
+
"source": [
|
1285 |
+
"loss appears to be decreasing. Uncollapse cell for output"
|
1286 |
+
]
|
1287 |
+
},
|
1288 |
+
{
|
1289 |
+
"cell_type": "markdown",
|
1290 |
+
"id": "bed9f3ce-c15c-4f30-8906-26f3e51acf30",
|
1291 |
+
"metadata": {},
|
1292 |
+
"source": [
|
1293 |
+
"# Bike Rides and the Poisson Model"
|
1294 |
+
]
|
1295 |
+
},
|
1296 |
+
{
|
1297 |
+
"cell_type": "markdown",
|
1298 |
+
"id": "975e2ef5-f5d5-45a3-b635-8faef035906f",
|
1299 |
+
"metadata": {},
|
1300 |
+
"source": [
|
1301 |
+
"Knowing that the poisson pdf is $P(k) = \\frac{\\lambda^k e^{-\\lambda}}{k!}$, we can find the negative log likelihood of the data as $-\\log(\\Pi_{i=1}^n P(k_i)) = -\\sum_{i=1}^n \\log(\\frac{\\lambda^k_i e^{-\\lambda}}{k_i!}) = \\sum_{i=1}^n -\\ln(\\lambda) k_i + \\ln(k_i!) + \\lambda$. Which simplified, gives $n\\lambda + \\sum_{i=1}^n \\ln(k_i!) - \\sum_{i=1}^n k_i \\ln(\\lambda)$. Differentiating with respect to $\\lambda$ gives $n - \\sum_{i=1}^n \\frac{k_i}{\\lambda}$. Which is our desired $\\frac{\\partial L}{\\partial \\lambda}$!"
|
1302 |
+
]
|
1303 |
+
},
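A small numerical check of the derivation (a sketch, not a notebook cell): setting dL/dlambda = n - sum(k_i)/lambda to zero gives lambda_hat = mean(k). The counts below are illustrative, not the bicyclist data.

```python
import numpy as np

k = np.array([2, 5, 3, 7, 4, 6])             # illustrative counts (an assumption)
lambda_hat = k.mean()                         # root of n - sum(k)/lambda

# NLL with the constant log(k_i!) terms dropped, as in the derivation above
nll = lambda lam: len(k) * lam - np.sum(k) * np.log(lam)

grid = np.linspace(0.5, 10.0, 1000)
assert abs(grid[np.argmin(nll(grid))] - lambda_hat) < 0.02   # minimum sits at mean(k)
```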
|
1304 |
+
{
|
1305 |
+
"cell_type": "code",
|
1306 |
+
"execution_count": 10,
|
1307 |
+
"id": "3877723c-179e-4759-bed5-9eb70110ded2",
|
1308 |
+
"metadata": {
|
1309 |
+
"tags": []
|
1310 |
+
},
|
1311 |
+
"outputs": [
|
1312 |
+
{
|
1313 |
+
"name": "stdout",
|
1314 |
+
"output_type": "stream",
|
1315 |
+
"text": [
|
1316 |
+
"SGD Problem 3\n",
|
1317 |
+
"l: [3215.17703224]\n",
|
1318 |
+
"l diff at start 999.4184849065878\n",
|
1319 |
+
"l diff at end 535.134976163929\n",
|
1320 |
+
"l is improving\n",
|
1321 |
+
"SGD Problem 3\n",
|
1322 |
+
"l: [2326.70336987]\n",
|
1323 |
+
"l diff at start -998.7262223631474\n",
|
1324 |
+
"l diff at end -353.33868620734074\n",
|
1325 |
+
"l is improving\n"
|
1326 |
+
]
|
1327 |
+
},
|
1328 |
+
{
|
1329 |
+
"data": {
|
1330 |
+
"text/html": [
|
1331 |
+
"<div>\n",
|
1332 |
+
"<style scoped>\n",
|
1333 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
1334 |
+
" vertical-align: middle;\n",
|
1335 |
+
" }\n",
|
1336 |
+
"\n",
|
1337 |
+
" .dataframe tbody tr th {\n",
|
1338 |
+
" vertical-align: top;\n",
|
1339 |
+
" }\n",
|
1340 |
+
"\n",
|
1341 |
+
" .dataframe thead th {\n",
|
1342 |
+
" text-align: right;\n",
|
1343 |
+
" }\n",
|
1344 |
+
"</style>\n",
|
1345 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
1346 |
+
" <thead>\n",
|
1347 |
+
" <tr style=\"text-align: right;\">\n",
|
1348 |
+
" <th></th>\n",
|
1349 |
+
" <th>Epoch</th>\n",
|
1350 |
+
" <th>New lambda</th>\n",
|
1351 |
+
" <th>dlambda</th>\n",
|
1352 |
+
" <th>Loss</th>\n",
|
1353 |
+
" <th>l_star</th>\n",
|
1354 |
+
" </tr>\n",
|
1355 |
+
" </thead>\n",
|
1356 |
+
" <tbody>\n",
|
1357 |
+
" <tr>\n",
|
1358 |
+
" <th>0</th>\n",
|
1359 |
+
" <td>0</td>\n",
|
1360 |
+
" <td>1681.315834</td>\n",
|
1361 |
+
" <td>-127.119133</td>\n",
|
1362 |
+
" <td>-3.899989e+06</td>\n",
|
1363 |
+
" <td>2680.042056</td>\n",
|
1364 |
+
" </tr>\n",
|
1365 |
+
" <tr>\n",
|
1366 |
+
" <th>1</th>\n",
|
1367 |
+
" <td>1</td>\n",
|
1368 |
+
" <td>1682.587025</td>\n",
|
1369 |
+
" <td>-126.861418</td>\n",
|
1370 |
+
" <td>-3.900150e+06</td>\n",
|
1371 |
+
" <td>2680.042056</td>\n",
|
1372 |
+
" </tr>\n",
|
1373 |
+
" <tr>\n",
|
1374 |
+
" <th>2</th>\n",
|
1375 |
+
" <td>2</td>\n",
|
1376 |
+
" <td>1683.855639</td>\n",
|
1377 |
+
" <td>-126.604614</td>\n",
|
1378 |
+
" <td>-3.900311e+06</td>\n",
|
1379 |
+
" <td>2680.042056</td>\n",
|
1380 |
+
" </tr>\n",
|
1381 |
+
" <tr>\n",
|
1382 |
+
" <th>3</th>\n",
|
1383 |
+
" <td>3</td>\n",
|
1384 |
+
" <td>1685.121685</td>\n",
|
1385 |
+
" <td>-126.348715</td>\n",
|
1386 |
+
" <td>-3.900471e+06</td>\n",
|
1387 |
+
" <td>2680.042056</td>\n",
|
1388 |
+
" </tr>\n",
|
1389 |
+
" <tr>\n",
|
1390 |
+
" <th>4</th>\n",
|
1391 |
+
" <td>4</td>\n",
|
1392 |
+
" <td>1686.385173</td>\n",
|
1393 |
+
" <td>-126.093716</td>\n",
|
1394 |
+
" <td>-3.900631e+06</td>\n",
|
1395 |
+
" <td>2680.042056</td>\n",
|
1396 |
+
" </tr>\n",
|
1397 |
+
" <tr>\n",
|
1398 |
+
" <th>...</th>\n",
|
1399 |
+
" <td>...</td>\n",
|
1400 |
+
" <td>...</td>\n",
|
1401 |
+
" <td>...</td>\n",
|
1402 |
+
" <td>...</td>\n",
|
1403 |
+
" <td>...</td>\n",
|
1404 |
+
" </tr>\n",
|
1405 |
+
" <tr>\n",
|
1406 |
+
" <th>995</th>\n",
|
1407 |
+
" <td>995</td>\n",
|
1408 |
+
" <td>2325.399976</td>\n",
|
1409 |
+
" <td>-32.636710</td>\n",
|
1410 |
+
" <td>-3.948159e+06</td>\n",
|
1411 |
+
" <td>2680.042056</td>\n",
|
1412 |
+
" </tr>\n",
|
1413 |
+
" <tr>\n",
|
1414 |
+
" <th>996</th>\n",
|
1415 |
+
" <td>996</td>\n",
|
1416 |
+
" <td>2325.726343</td>\n",
|
1417 |
+
" <td>-32.602100</td>\n",
|
1418 |
+
" <td>-3.948170e+06</td>\n",
|
1419 |
+
" <td>2680.042056</td>\n",
|
1420 |
+
" </tr>\n",
|
1421 |
+
" <tr>\n",
|
1422 |
+
" <th>997</th>\n",
|
1423 |
+
" <td>997</td>\n",
|
1424 |
+
" <td>2326.052364</td>\n",
|
1425 |
+
" <td>-32.567536</td>\n",
|
1426 |
+
" <td>-3.948180e+06</td>\n",
|
1427 |
+
" <td>2680.042056</td>\n",
|
1428 |
+
" </tr>\n",
|
1429 |
+
" <tr>\n",
|
1430 |
+
" <th>998</th>\n",
|
1431 |
+
" <td>998</td>\n",
|
1432 |
+
" <td>2326.378040</td>\n",
|
1433 |
+
" <td>-32.533018</td>\n",
|
1434 |
+
" <td>-3.948191e+06</td>\n",
|
1435 |
+
" <td>2680.042056</td>\n",
|
1436 |
+
" </tr>\n",
|
1437 |
+
" <tr>\n",
|
1438 |
+
" <th>999</th>\n",
|
1439 |
+
" <td>999</td>\n",
|
1440 |
+
" <td>2326.703370</td>\n",
|
1441 |
+
" <td>-32.498547</td>\n",
|
1442 |
+
" <td>-3.948201e+06</td>\n",
|
1443 |
+
" <td>2680.042056</td>\n",
|
1444 |
+
" </tr>\n",
|
1445 |
+
" </tbody>\n",
|
1446 |
+
"</table>\n",
|
1447 |
+
"<p>1000 rows × 5 columns</p>\n",
|
1448 |
+
"</div>"
|
1449 |
+
],
|
1450 |
+
"text/plain": [
|
1451 |
+
" Epoch New lambda dlambda Loss l_star\n",
|
1452 |
+
"0 0 1681.315834 -127.119133 -3.899989e+06 2680.042056\n",
|
1453 |
+
"1 1 1682.587025 -126.861418 -3.900150e+06 2680.042056\n",
|
1454 |
+
"2 2 1683.855639 -126.604614 -3.900311e+06 2680.042056\n",
|
1455 |
+
"3 3 1685.121685 -126.348715 -3.900471e+06 2680.042056\n",
|
1456 |
+
"4 4 1686.385173 -126.093716 -3.900631e+06 2680.042056\n",
|
1457 |
+
".. ... ... ... ... ...\n",
|
1458 |
+
"995 995 2325.399976 -32.636710 -3.948159e+06 2680.042056\n",
|
1459 |
+
"996 996 2325.726343 -32.602100 -3.948170e+06 2680.042056\n",
|
1460 |
+
"997 997 2326.052364 -32.567536 -3.948180e+06 2680.042056\n",
|
1461 |
+
"998 998 2326.378040 -32.533018 -3.948191e+06 2680.042056\n",
|
1462 |
+
"999 999 2326.703370 -32.498547 -3.948201e+06 2680.042056\n",
|
1463 |
+
"\n",
|
1464 |
+
"[1000 rows x 5 columns]"
|
1465 |
+
]
|
1466 |
+
},
|
1467 |
+
"execution_count": 10,
|
1468 |
+
"metadata": {},
|
1469 |
+
"output_type": "execute_result"
|
1470 |
+
}
|
1471 |
+
],
|
1472 |
+
"source": [
|
1473 |
+
"import pandas as pd\n",
|
1474 |
+
"\n",
|
1475 |
+
"df = pd.read_csv(\"../data/01_raw/nyc_bb_bicyclist_counts.csv\")\n",
|
1476 |
+
"\n",
|
1477 |
+
"dlambda = lambda l, k: len(k) - np.sum([ki / l for ki in k])\n",
|
1478 |
+
"\n",
|
1479 |
+
"\n",
|
1480 |
+
"def SGD_problem3(\n",
|
1481 |
+
" l: float,\n",
|
1482 |
+
" k: np.array,\n",
|
1483 |
+
" learning_rate=0.01,\n",
|
1484 |
+
" n_epochs=1000,\n",
|
1485 |
+
"):\n",
|
1486 |
+
" global log3\n",
|
1487 |
+
" log3 = []\n",
|
1488 |
+
" for epoch in range(n_epochs):\n",
|
1489 |
+
" l -= learning_rate * dlambda(l, k)\n",
|
1490 |
+
" # $n\\lambda + \\sum_{i=1}^n \\ln(k_i!) - \\sum_{i=1}^n k_i \\ln(\\lambda)$\n",
|
1491 |
+
" # the rest of the loss function is commented out because it's a\n",
|
1492 |
+
" # constant and was causing overflows. It is unnecessary, and a useless\n",
|
1493 |
+
" # pain.\n",
|
1494 |
+
" loss = len(k) * l - np.sum(\n",
|
1495 |
+
" [ki * np.log(l) for ki in k]\n",
|
1496 |
+
" ) # + np.sum([np.log(np.math.factorial(ki)) for ki in k])\n",
|
1497 |
+
"\n",
|
1498 |
+
" log3.append(\n",
|
1499 |
+
" {\n",
|
1500 |
+
" \"Epoch\": epoch,\n",
|
1501 |
+
" \"New lambda\": l,\n",
|
1502 |
+
" \"dlambda\": dlambda(l, k),\n",
|
1503 |
+
" \"Loss\": loss,\n",
|
1504 |
+
" }\n",
|
1505 |
+
" )\n",
|
1506 |
+
" # print(f\"Epoch {epoch}\", f\"Loss: {loss}\")\n",
|
1507 |
+
" return np.array([l])\n",
|
1508 |
+
"\n",
|
1509 |
+
"\n",
|
1510 |
+
"l_star = df[\"BB_COUNT\"].mean()\n",
|
1511 |
+
"\n",
|
1512 |
+
"\n",
|
1513 |
+
"def debug_SGD_3(data, l=1000):\n",
|
1514 |
+
" print(\"SGD Problem 3\")\n",
|
1515 |
+
" print(f\"l: {SGD_problem3(l, data)}\")\n",
|
1516 |
+
" dflog = pd.DataFrame(log3)\n",
|
1517 |
+
" dflog[\"l_star\"] = l_star\n",
|
1518 |
+
" print(f\"l diff at start {dflog.iloc[0]['New lambda'] - dflog.iloc[0]['l_star']}\")\n",
|
1519 |
+
" print(f\"l diff at end {dflog.iloc[-1]['New lambda'] - dflog.iloc[-1]['l_star']}\")\n",
|
1520 |
+
" if np.abs(dflog.iloc[-1][\"New lambda\"] - dflog.iloc[-1][\"l_star\"]) < np.abs(\n",
|
1521 |
+
" dflog.iloc[0][\"New lambda\"] - dflog.iloc[0][\"l_star\"]\n",
|
1522 |
+
" ):\n",
|
1523 |
+
" print(\"l is improving\")\n",
|
1524 |
+
" else:\n",
|
1525 |
+
" print(\"l is not improving\")\n",
|
1526 |
+
" return dflog\n",
|
1527 |
+
"\n",
|
1528 |
+
"\n",
|
1529 |
+
"debug_SGD_3(data=df[\"BB_COUNT\"].values, l=l_star + 1000)\n",
|
1530 |
+
"debug_SGD_3(data=df[\"BB_COUNT\"].values, l=l_star - 1000)"
|
1531 |
+
]
|
1532 |
+
},
|
1533 |
+
{
|
1534 |
+
"cell_type": "markdown",
|
1535 |
+
"id": "c05192f9-78ae-4bdb-9df5-cac91006d79f",
|
1536 |
+
"metadata": {},
|
1537 |
+
"source": [
|
1538 |
+
"l approaches the l_star and decreases the loss function."
|
1539 |
+
]
|
1540 |
+
},
|
1541 |
+
{
|
1542 |
+
"cell_type": "markdown",
|
1543 |
+
"id": "4955b868-7f67-4760-bf86-39f6edd55871",
|
1544 |
+
"metadata": {},
|
1545 |
+
"source": [
|
1546 |
+
"## Maximum Likelihood II"
|
1547 |
+
]
|
1548 |
+
},
|
1549 |
+
{
|
1550 |
+
"cell_type": "code",
|
1551 |
+
"execution_count": 7,
|
1552 |
+
"id": "7c8b167d-c397-4155-93f3-d826c279fbb2",
|
1553 |
+
"metadata": {
|
1554 |
+
"tags": []
|
1555 |
+
},
|
1556 |
+
"outputs": [
|
1557 |
+
{
|
1558 |
+
"ename": "SyntaxError",
|
1559 |
+
"evalue": "invalid syntax (3400372070.py, line 66)",
|
1560 |
+
"output_type": "error",
|
1561 |
+
"traceback": [
|
1562 |
+
"\u001b[0;36m Cell \u001b[0;32mIn[7], line 66\u001b[0;36m\u001b[0m\n\u001b[0;31m p_dw = lambda w, xi: np.array([primitive(xi, wi) for xi, wi in ])\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
|
1563 |
+
]
|
1564 |
+
}
|
1565 |
+
],
|
1566 |
+
"source": [
|
1567 |
+
"## pset 4\n",
|
1568 |
+
"\n",
|
1569 |
+
"# dw = lambda w, x: len(x) * np.exp(np.dot(x, w)) * x - np.sum()\n",
|
1570 |
+
"\n",
|
1571 |
+
"primitive = lambda xi, wi: (x.shape[0] * np.exp(wi * xi) * xi) - (xi**2)\n",
|
1572 |
+
"p_dw = lambda w, xi: np.array([primitive(xi, wi) for xi, wi in ])\n",
|
1573 |
+
"\n",
|
1574 |
+
"\n",
|
1575 |
+
"def SGD_problem4(\n",
|
1576 |
+
" w: np.array,\n",
|
1577 |
+
" x: np.array,\n",
|
1578 |
+
" learning_rate=0.01,\n",
|
1579 |
+
" n_epochs=1000,\n",
|
1580 |
+
"):\n",
|
1581 |
+
" global log4\n",
|
1582 |
+
" log4 = []\n",
|
1583 |
+
" for epoch in range(n_epochs):\n",
|
1584 |
+
" w -= learning_rate * p_dw(w, x)\n",
|
1585 |
+
" # custom\n",
|
1586 |
+
" # loss = x.shape[0] * np.exp(np.dot(x, w))\n",
|
1587 |
+
" loss_fn = lambda k, l: len(k) * l - np.sum(\n",
|
1588 |
+
" [ki * np.log(l) for ki in k]\n",
|
1589 |
+
" ) # + np.sum([np.log(np.math.factorial(ki)) for ki in k])\n",
|
1590 |
+
" loss = loss_fn(x, np.exp(np.dot(x, w)))\n",
|
1591 |
+
" log4.append(\n",
|
1592 |
+
" {\n",
|
1593 |
+
" \"Epoch\": epoch,\n",
|
1594 |
+
" \"New w\": w,\n",
|
1595 |
+
" \"dw\": dw(w, x),\n",
|
1596 |
+
" \"Loss\": loss,\n",
|
1597 |
+
" }\n",
|
1598 |
+
" )\n",
|
1599 |
+
" print(f\"Epoch {epoch}\", f\"Loss: {loss}\")\n",
|
1600 |
+
" return w\n",
|
1601 |
+
"\n",
|
1602 |
+
"\n",
|
1603 |
+
"def debug_SGD_3(data, w=np.array([1, 1])):\n",
|
1604 |
+
" print(\"SGD Problem 4\")\n",
|
1605 |
+
" print(f\"w: {SGD_problem4(w, data)}\")\n",
|
1606 |
+
" dflog = pd.DataFrame(log4)\n",
|
1607 |
+
" return dflog\n",
|
1608 |
+
"\n",
|
1609 |
+
"\n",
|
1610 |
+
"_ = debug_SGD_3(\n",
|
1611 |
+
" data=df[[\"HIGH_T\", \"LOW_T\", \"PRECIP\"]].to_numpy(),\n",
|
1612 |
+
" w=np.array([1.0, 1.0, 1.0]),\n",
|
1613 |
+
")"
|
1614 |
+
]
|
1615 |
+
},
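For reference, a hedged sketch of the Poisson regression this cell appears to be reaching for: lambda_i = exp(w . x_i), fit by gradient descent on the negative log likelihood. This is not the notebook's (unfinished) code; the feature standardization, intercept column, learning rate, and epoch count are assumptions chosen to keep the exponential link numerically stable, and the CSV path follows the earlier cell.

```python
import numpy as np
import pandas as pd

df = pd.read_csv("../data/01_raw/nyc_bb_bicyclist_counts.csv")
y = df["BB_COUNT"].to_numpy(dtype=float)

# standardize the regressors and add an intercept column (assumptions)
X = df[["HIGH_T", "LOW_T", "PRECIP"]].to_numpy(dtype=float)
X = (X - X.mean(axis=0)) / X.std(axis=0)
X = np.column_stack([np.ones(len(X)), X])

# NLL with the constant log(y_i!) terms dropped: sum_i exp(w.x_i) - y_i * (w.x_i)
nll = lambda w: np.sum(np.exp(X @ w) - y * (X @ w))
# gradient: sum_i (exp(w.x_i) - y_i) x_i, averaged here for a stable step size
grad = lambda w: X.T @ (np.exp(X @ w) - y) / len(y)

w = np.array([np.log(y.mean()), 0.0, 0.0, 0.0])   # start at the intercept-only MLE
for epoch in range(5000):
    w -= 5e-5 * grad(w)

print(w, nll(w))
```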
|
1616 |
+
{
|
1617 |
+
"cell_type": "markdown",
|
1618 |
+
"id": "69e9148b-70fb-46e3-bc29-a08f471cccab",
|
1619 |
+
"metadata": {},
|
1620 |
+
"source": [
|
1621 |
+
"Seek to maximize likelihood, which is equivalent to minimizing log likelihood, which is equivalent to maximizing MSE.\n",
|
1622 |
+
"\n",
|
1623 |
+
"-> Find MSE"
|
1624 |
+
]
|
1625 |
+
},
|
1626 |
+
{
|
1627 |
+
"cell_type": "code",
|
1628 |
+
"execution_count": 2,
|
1629 |
+
"id": "fb9d2e20-8a02-4f78-a3d3-6fc171f17af6",
|
1630 |
+
"metadata": {},
|
1631 |
+
"outputs": [
|
1632 |
+
{
|
1633 |
+
"ename": "SyntaxError",
|
1634 |
+
"evalue": "incomplete input (3680490224.py, line 3)",
|
1635 |
+
"output_type": "error",
|
1636 |
+
"traceback": [
|
1637 |
+
"\u001b[0;36m Cell \u001b[0;32mIn[2], line 3\u001b[0;36m\u001b[0m\n\u001b[0;31m def MSE(x: List):\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m incomplete input\n"
|
1638 |
+
]
|
1639 |
+
}
|
1640 |
+
],
|
1641 |
+
"source": [
|
1642 |
+
"from typing import List\n",
|
1643 |
+
"from numpy import \n",
|
1644 |
+
"\n",
|
1645 |
+
"w = np.array(2)\n",
|
1646 |
+
"\n",
|
1647 |
+
"def partial_of_Loss(w, x, y):\n",
|
1648 |
+
" \n",
|
1649 |
+
" "
|
1650 |
+
]
|
1651 |
+
},
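A minimal sketch of the MSE route described above (not the notebook's unfinished cell), under the assumption of the Problem 2 x and y arrays and a linear model y_hat = m*x + c; np.polyfit supplies the closed-form least-squares fit as a check.

```python
import numpy as np

x = np.array([8, 16, 22, 33, 50, 51], dtype=float)
y = np.array([5, 20, 14, 32, 42, 58], dtype=float)

def mse(m, c, x, y):
    r = y - (m * x + c)
    return np.mean(r**2)

def mse_gradients(m, c, x, y):
    r = y - (m * x + c)
    return -2 * np.mean(x * r), -2 * np.mean(r)   # dMSE/dm, dMSE/dc

# Minimizing the MSE gives the same (m, c) as maximizing the Gaussian likelihood
# with fixed variance; polyfit returns the closed-form least-squares solution.
m_ls, c_ls = np.polyfit(x, y, deg=1)
print(m_ls, c_ls, mse(m_ls, c_ls, x, y))
```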
|
1652 |
+
{
|
1653 |
+
"cell_type": "code",
|
1654 |
+
"execution_count": null,
|
1655 |
+
"id": "7c00197d-873d-41b0-a458-dc8478b40f52",
|
1656 |
+
"metadata": {},
|
1657 |
+
"outputs": [],
|
1658 |
+
"source": []
|
1659 |
+
}
|
1660 |
+
],
|
1661 |
+
"metadata": {
|
1662 |
+
"kernelspec": {
|
1663 |
+
"display_name": "Python 3 (ipykernel)",
|
1664 |
+
"language": "python",
|
1665 |
+
"name": "python3"
|
1666 |
+
},
|
1667 |
+
"language_info": {
|
1668 |
+
"codemirror_mode": {
|
1669 |
+
"name": "ipython",
|
1670 |
+
"version": 3
|
1671 |
+
},
|
1672 |
+
"file_extension": ".py",
|
1673 |
+
"mimetype": "text/x-python",
|
1674 |
+
"name": "python",
|
1675 |
+
"nbconvert_exporter": "python",
|
1676 |
+
"pygments_lexer": "ipython3",
|
1677 |
+
"version": "3.10.4"
|
1678 |
+
}
|
1679 |
+
},
|
1680 |
+
"nbformat": 4,
|
1681 |
+
"nbformat_minor": 5
|
1682 |
+
}
|
assignment-2/assignment_2/__init__.py
ADDED
File without changes
|
assignment-2/data/01_raw/nyc_bb_bicyclist_counts.csv
ADDED
@@ -0,0 +1,215 @@
1 |
+
Date,HIGH_T,LOW_T,PRECIP,BB_COUNT
|
2 |
+
1-Apr-17,46.00,37.00,0.00,606
|
3 |
+
2-Apr-17,62.10,41.00,0.00,2021
|
4 |
+
3-Apr-17,63.00,50.00,0.03,2470
|
5 |
+
4-Apr-17,51.10,46.00,1.18,723
|
6 |
+
5-Apr-17,63.00,46.00,0.00,2807
|
7 |
+
6-Apr-17,48.90,41.00,0.73,461
|
8 |
+
7-Apr-17,48.00,43.00,0.01,1222
|
9 |
+
8-Apr-17,55.90,39.90,0.00,1674
|
10 |
+
9-Apr-17,66.00,45.00,0.00,2375
|
11 |
+
10-Apr-17,73.90,55.00,0.00,3324
|
12 |
+
11-Apr-17,80.10,62.10,0.00,3887
|
13 |
+
12-Apr-17,73.90,57.90,0.02,2565
|
14 |
+
13-Apr-17,64.00,48.90,0.00,3353
|
15 |
+
14-Apr-17,64.90,48.90,0.00,2942
|
16 |
+
15-Apr-17,64.90,52.00,0.00,2253
|
17 |
+
16-Apr-17,84.90,62.10,0.01,2877
|
18 |
+
17-Apr-17,73.90,64.00,0.01,3152
|
19 |
+
18-Apr-17,66.00,50.00,0.00,3415
|
20 |
+
19-Apr-17,52.00,45.00,0.01,1965
|
21 |
+
20-Apr-17,64.90,50.00,0.17,1567
|
22 |
+
21-Apr-17,53.10,48.00,0.29,1426
|
23 |
+
22-Apr-17,55.90,52.00,0.11,1318
|
24 |
+
23-Apr-17,64.90,46.90,0.00,2520
|
25 |
+
24-Apr-17,60.10,50.00,0.01,2544
|
26 |
+
25-Apr-17,54.00,50.00,0.91,611
|
27 |
+
26-Apr-17,59.00,54.00,0.34,1247
|
28 |
+
27-Apr-17,68.00,59.00,0.00,2959
|
29 |
+
28-Apr-17,82.90,57.90,0.00,3679
|
30 |
+
29-Apr-17,84.00,64.00,0.06,3315
|
31 |
+
30-Apr-17,64.00,54.00,0.00,2225
|
32 |
+
1-May-17,72.00,50.00,0.00,3084
|
33 |
+
2-May-17,73.90,66.90,0.00,3423
|
34 |
+
3-May-17,64.90,57.90,0.00,3342
|
35 |
+
4-May-17,63.00,50.00,0.00,3019
|
36 |
+
5-May-17,59.00,52.00,3.02,513
|
37 |
+
6-May-17,64.90,57.00,0.18,1892
|
38 |
+
7-May-17,54.00,48.90,0.01,3539
|
39 |
+
8-May-17,57.00,45.00,0.00,2886
|
40 |
+
9-May-17,61.00,48.00,0.00,2718
|
41 |
+
10-May-17,70.00,51.10,0.00,2810
|
42 |
+
11-May-17,61.00,51.80,0.00,2657
|
43 |
+
12-May-17,62.10,51.10,0.00,2640
|
44 |
+
13-May-17,51.10,45.00,1.31,151
|
45 |
+
14-May-17,64.90,46.00,0.02,1452
|
46 |
+
15-May-17,66.90,55.90,0.00,2685
|
47 |
+
16-May-17,78.10,57.90,0.00,3666
|
48 |
+
17-May-17,90.00,66.00,0.00,3535
|
49 |
+
18-May-17,91.90,75.00,0.00,3190
|
50 |
+
19-May-17,90.00,75.90,0.00,2952
|
51 |
+
20-May-17,64.00,55.90,0.01,2161
|
52 |
+
21-May-17,66.90,55.00,0.00,2612
|
53 |
+
22-May-17,61.00,54.00,0.59,768
|
54 |
+
23-May-17,68.00,57.90,0.00,3174
|
55 |
+
24-May-17,66.90,57.00,0.04,2969
|
56 |
+
25-May-17,57.90,55.90,0.58,488
|
57 |
+
26-May-17,73.00,55.90,0.10,2590
|
58 |
+
27-May-17,71.10,61.00,0.00,2609
|
59 |
+
28-May-17,71.10,59.00,0.00,2640
|
60 |
+
29-May-17,57.90,55.90,0.13,836
|
61 |
+
30-May-17,59.00,55.90,0.06,2301
|
62 |
+
31-May-17,75.00,57.90,0.03,2689
|
63 |
+
1-Jun-17,78.10,62.10,0.00,3468
|
64 |
+
2-Jun-17,73.90,60.10,0.01,3271
|
65 |
+
3-Jun-17,72.00,55.00,0.01,2589
|
66 |
+
4-Jun-17,68.00,60.10,0.09,1805
|
67 |
+
5-Jun-17,66.90,60.10,0.02,2171
|
68 |
+
6-Jun-17,55.90,53.10,0.06,1193
|
69 |
+
7-Jun-17,66.90,54.00,0.00,3211
|
70 |
+
8-Jun-17,68.00,59.00,0.00,3253
|
71 |
+
9-Jun-17,80.10,59.00,0.00,3401
|
72 |
+
10-Jun-17,84.00,68.00,0.00,3066
|
73 |
+
11-Jun-17,90.00,73.00,0.00,2465
|
74 |
+
12-Jun-17,91.90,77.00,0.00,2854
|
75 |
+
13-Jun-17,93.90,78.10,0.01,2882
|
76 |
+
14-Jun-17,84.00,69.10,0.29,2596
|
77 |
+
15-Jun-17,75.00,66.00,0.00,3510
|
78 |
+
16-Jun-17,68.00,66.00,0.00,2054
|
79 |
+
17-Jun-17,73.00,66.90,1.39,1399
|
80 |
+
18-Jun-17,84.00,72.00,0.01,2199
|
81 |
+
19-Jun-17,87.10,70.00,1.35,1648
|
82 |
+
20-Jun-17,82.00,72.00,0.03,3407
|
83 |
+
21-Jun-17,82.00,72.00,0.00,3304
|
84 |
+
22-Jun-17,82.00,70.00,0.00,3368
|
85 |
+
23-Jun-17,82.90,75.90,0.04,2283
|
86 |
+
24-Jun-17,82.90,71.10,1.29,2307
|
87 |
+
25-Jun-17,82.00,69.10,0.00,2625
|
88 |
+
26-Jun-17,78.10,66.00,0.00,3386
|
89 |
+
27-Jun-17,75.90,61.00,0.18,3182
|
90 |
+
28-Jun-17,78.10,62.10,0.00,3766
|
91 |
+
29-Jun-17,81.00,68.00,0.00,3356
|
92 |
+
30-Jun-17,88.00,73.90,0.01,2687
|
93 |
+
1-Jul-17,84.90,72.00,0.23,1848
|
94 |
+
2-Jul-17,87.10,73.00,0.00,2467
|
95 |
+
3-Jul-17,87.10,71.10,0.45,2714
|
96 |
+
4-Jul-17,82.90,70.00,0.00,2296
|
97 |
+
5-Jul-17,84.90,71.10,0.00,3170
|
98 |
+
6-Jul-17,75.00,71.10,0.01,3065
|
99 |
+
7-Jul-17,79.00,68.00,1.78,1513
|
100 |
+
8-Jul-17,82.90,70.00,0.00,2718
|
101 |
+
9-Jul-17,81.00,69.10,0.00,3048
|
102 |
+
10-Jul-17,82.90,71.10,0.00,3506
|
103 |
+
11-Jul-17,84.00,75.00,0.00,2929
|
104 |
+
12-Jul-17,87.10,77.00,0.00,2860
|
105 |
+
13-Jul-17,89.10,77.00,0.00,2563
|
106 |
+
14-Jul-17,69.10,64.90,0.35,907
|
107 |
+
15-Jul-17,82.90,68.00,0.00,2853
|
108 |
+
16-Jul-17,84.90,70.00,0.00,2917
|
109 |
+
17-Jul-17,84.90,73.90,0.00,3264
|
110 |
+
18-Jul-17,87.10,75.90,0.00,3507
|
111 |
+
19-Jul-17,91.00,77.00,0.00,3114
|
112 |
+
20-Jul-17,93.00,78.10,0.01,2840
|
113 |
+
21-Jul-17,91.00,77.00,0.00,2751
|
114 |
+
22-Jul-17,91.00,78.10,0.57,2301
|
115 |
+
23-Jul-17,78.10,73.00,0.06,2321
|
116 |
+
24-Jul-17,69.10,63.00,0.74,1576
|
117 |
+
25-Jul-17,71.10,64.00,0.00,3191
|
118 |
+
26-Jul-17,75.90,66.00,0.00,3821
|
119 |
+
27-Jul-17,77.00,66.90,0.01,3287
|
120 |
+
28-Jul-17,84.90,73.00,0.00,3123
|
121 |
+
29-Jul-17,75.90,68.00,0.00,2074
|
122 |
+
30-Jul-17,81.00,64.90,0.00,3331
|
123 |
+
31-Jul-17,88.00,66.90,0.00,3560
|
124 |
+
1-Aug-17,91.00,72.00,0.00,3492
|
125 |
+
2-Aug-17,86.00,69.10,0.09,2637
|
126 |
+
3-Aug-17,86.00,70.00,0.00,3346
|
127 |
+
4-Aug-17,82.90,70.00,0.15,2400
|
128 |
+
5-Aug-17,77.00,70.00,0.30,3409
|
129 |
+
6-Aug-17,75.90,64.00,0.00,3130
|
130 |
+
7-Aug-17,71.10,64.90,0.76,804
|
131 |
+
8-Aug-17,77.00,66.00,0.00,3598
|
132 |
+
9-Aug-17,82.90,66.00,0.00,3893
|
133 |
+
10-Aug-17,82.90,69.10,0.00,3423
|
134 |
+
11-Aug-17,81.00,70.00,0.01,3148
|
135 |
+
12-Aug-17,75.90,64.90,0.11,4146
|
136 |
+
13-Aug-17,82.00,71.10,0.00,3274
|
137 |
+
14-Aug-17,80.10,70.00,0.00,3291
|
138 |
+
15-Aug-17,73.00,69.10,0.45,2149
|
139 |
+
16-Aug-17,84.90,70.00,0.00,3685
|
140 |
+
17-Aug-17,82.00,71.10,0.00,3637
|
141 |
+
18-Aug-17,81.00,73.00,0.88,1064
|
142 |
+
19-Aug-17,84.90,73.00,0.00,4693
|
143 |
+
20-Aug-17,81.00,70.00,0.00,2822
|
144 |
+
21-Aug-17,84.90,73.00,0.00,3088
|
145 |
+
22-Aug-17,88.00,75.00,0.30,2983
|
146 |
+
23-Aug-17,80.10,71.10,0.01,2994
|
147 |
+
24-Aug-17,79.00,66.00,0.00,3688
|
148 |
+
25-Aug-17,78.10,64.00,0.00,3144
|
149 |
+
26-Aug-17,77.00,62.10,0.00,2710
|
150 |
+
27-Aug-17,77.00,63.00,0.00,2676
|
151 |
+
28-Aug-17,75.00,63.00,0.00,3332
|
152 |
+
29-Aug-17,68.00,62.10,0.10,1472
|
153 |
+
30-Aug-17,75.90,61.00,0.01,3468
|
154 |
+
31-Aug-17,81.00,64.00,0.00,3279
|
155 |
+
1-Sep-17,70.00,55.00,0.00,2945
|
156 |
+
2-Sep-17,66.90,54.00,0.53,1876
|
157 |
+
3-Sep-17,69.10,60.10,0.74,1004
|
158 |
+
4-Sep-17,79.00,62.10,0.00,2866
|
159 |
+
5-Sep-17,84.00,70.00,0.01,3244
|
160 |
+
6-Sep-17,70.00,62.10,0.42,1232
|
161 |
+
7-Sep-17,71.10,59.00,0.01,3249
|
162 |
+
8-Sep-17,70.00,59.00,0.00,3234
|
163 |
+
9-Sep-17,69.10,55.00,0.00,2609
|
164 |
+
10-Sep-17,72.00,57.00,0.00,4960
|
165 |
+
11-Sep-17,75.90,55.00,0.00,3657
|
166 |
+
12-Sep-17,78.10,61.00,0.00,3497
|
167 |
+
13-Sep-17,82.00,64.90,0.06,2994
|
168 |
+
14-Sep-17,81.00,70.00,0.02,3013
|
169 |
+
15-Sep-17,81.00,66.90,0.00,3344
|
170 |
+
16-Sep-17,82.00,70.00,0.00,2560
|
171 |
+
17-Sep-17,80.10,70.00,0.00,2676
|
172 |
+
18-Sep-17,73.00,69.10,0.00,2673
|
173 |
+
19-Sep-17,78.10,69.10,0.22,2012
|
174 |
+
20-Sep-17,78.10,71.10,0.00,3296
|
175 |
+
21-Sep-17,80.10,71.10,0.00,3317
|
176 |
+
22-Sep-17,82.00,66.00,0.00,3297
|
177 |
+
23-Sep-17,86.00,68.00,0.00,2810
|
178 |
+
24-Sep-17,90.00,69.10,0.00,2543
|
179 |
+
25-Sep-17,87.10,72.00,0.00,3276
|
180 |
+
26-Sep-17,82.00,69.10,0.00,3157
|
181 |
+
27-Sep-17,84.90,71.10,0.00,3216
|
182 |
+
28-Sep-17,78.10,66.00,0.00,3421
|
183 |
+
29-Sep-17,66.90,55.00,0.00,2988
|
184 |
+
30-Sep-17,64.00,55.90,0.00,1903
|
185 |
+
1-Oct-17,66.90,50.00,0.00,2297
|
186 |
+
2-Oct-17,72.00,52.00,0.00,3387
|
187 |
+
3-Oct-17,70.00,57.00,0.00,3386
|
188 |
+
4-Oct-17,75.00,55.90,0.00,3412
|
189 |
+
5-Oct-17,82.00,64.90,0.00,3312
|
190 |
+
6-Oct-17,81.00,69.10,0.00,2982
|
191 |
+
7-Oct-17,80.10,66.00,0.00,2750
|
192 |
+
8-Oct-17,77.00,72.00,0.22,1235
|
193 |
+
9-Oct-17,75.90,72.00,0.26,898
|
194 |
+
10-Oct-17,80.10,66.00,0.00,3922
|
195 |
+
11-Oct-17,75.00,64.90,0.06,2721
|
196 |
+
12-Oct-17,63.00,55.90,0.07,2411
|
197 |
+
13-Oct-17,64.90,52.00,0.00,2839
|
198 |
+
14-Oct-17,71.10,62.10,0.08,2021
|
199 |
+
15-Oct-17,72.00,66.00,0.01,2169
|
200 |
+
16-Oct-17,60.10,52.00,0.01,2751
|
201 |
+
17-Oct-17,57.90,43.00,0.00,2869
|
202 |
+
18-Oct-17,71.10,50.00,0.00,3264
|
203 |
+
19-Oct-17,70.00,55.90,0.00,3265
|
204 |
+
20-Oct-17,73.00,57.90,0.00,3169
|
205 |
+
21-Oct-17,78.10,57.00,0.00,2538
|
206 |
+
22-Oct-17,75.90,57.00,0.00,2744
|
207 |
+
23-Oct-17,73.90,64.00,0.00,3189
|
208 |
+
24-Oct-17,73.00,66.90,0.20,954
|
209 |
+
25-Oct-17,64.90,57.90,0.00,3367
|
210 |
+
26-Oct-17,57.00,53.10,0.00,2565
|
211 |
+
27-Oct-17,62.10,48.00,0.00,3150
|
212 |
+
28-Oct-17,68.00,55.90,0.00,2245
|
213 |
+
29-Oct-17,64.90,61.00,3.03,183
|
214 |
+
30-Oct-17,55.00,46.00,0.25,1428
|
215 |
+
31-Oct-17,54.00,44.00,0.00,2727
|
assignment-2/data/nyc_bb_bicyclist_counts.csv
ADDED
@@ -0,0 +1,215 @@
1 |
+
Date,HIGH_T,LOW_T,PRECIP,BB_COUNT
|
2 |
+
1-Apr-17,46.00,37.00,0.00,606
|
3 |
+
2-Apr-17,62.10,41.00,0.00,2021
|
4 |
+
3-Apr-17,63.00,50.00,0.03,2470
|
5 |
+
4-Apr-17,51.10,46.00,1.18,723
|
6 |
+
5-Apr-17,63.00,46.00,0.00,2807
|
7 |
+
6-Apr-17,48.90,41.00,0.73,461
|
8 |
+
7-Apr-17,48.00,43.00,0.01,1222
|
9 |
+
8-Apr-17,55.90,39.90,0.00,1674
|
10 |
+
9-Apr-17,66.00,45.00,0.00,2375
|
11 |
+
10-Apr-17,73.90,55.00,0.00,3324
|
12 |
+
11-Apr-17,80.10,62.10,0.00,3887
|
13 |
+
12-Apr-17,73.90,57.90,0.02,2565
|
14 |
+
13-Apr-17,64.00,48.90,0.00,3353
|
15 |
+
14-Apr-17,64.90,48.90,0.00,2942
|
16 |
+
15-Apr-17,64.90,52.00,0.00,2253
|
17 |
+
16-Apr-17,84.90,62.10,0.01,2877
|
18 |
+
17-Apr-17,73.90,64.00,0.01,3152
|
19 |
+
18-Apr-17,66.00,50.00,0.00,3415
|
20 |
+
19-Apr-17,52.00,45.00,0.01,1965
|
21 |
+
20-Apr-17,64.90,50.00,0.17,1567
|
22 |
+
21-Apr-17,53.10,48.00,0.29,1426
|
23 |
+
22-Apr-17,55.90,52.00,0.11,1318
|
24 |
+
23-Apr-17,64.90,46.90,0.00,2520
|
25 |
+
24-Apr-17,60.10,50.00,0.01,2544
|
26 |
+
25-Apr-17,54.00,50.00,0.91,611
|
27 |
+
26-Apr-17,59.00,54.00,0.34,1247
|
28 |
+
27-Apr-17,68.00,59.00,0.00,2959
|
29 |
+
28-Apr-17,82.90,57.90,0.00,3679
|
30 |
+
29-Apr-17,84.00,64.00,0.06,3315
|
31 |
+
30-Apr-17,64.00,54.00,0.00,2225
|
32 |
+
1-May-17,72.00,50.00,0.00,3084
|
33 |
+
2-May-17,73.90,66.90,0.00,3423
|
34 |
+
3-May-17,64.90,57.90,0.00,3342
|
35 |
+
4-May-17,63.00,50.00,0.00,3019
|
36 |
+
5-May-17,59.00,52.00,3.02,513
|
37 |
+
6-May-17,64.90,57.00,0.18,1892
|
38 |
+
7-May-17,54.00,48.90,0.01,3539
|
39 |
+
8-May-17,57.00,45.00,0.00,2886
|
40 |
+
9-May-17,61.00,48.00,0.00,2718
|
41 |
+
10-May-17,70.00,51.10,0.00,2810
|
42 |
+
11-May-17,61.00,51.80,0.00,2657
|
43 |
+
12-May-17,62.10,51.10,0.00,2640
|
44 |
+
13-May-17,51.10,45.00,1.31,151
|
45 |
+
14-May-17,64.90,46.00,0.02,1452
|
46 |
+
15-May-17,66.90,55.90,0.00,2685
|
47 |
+
16-May-17,78.10,57.90,0.00,3666
|
48 |
+
17-May-17,90.00,66.00,0.00,3535
|
49 |
+
18-May-17,91.90,75.00,0.00,3190
|
50 |
+
19-May-17,90.00,75.90,0.00,2952
|
51 |
+
20-May-17,64.00,55.90,0.01,2161
|
52 |
+
21-May-17,66.90,55.00,0.00,2612
|
53 |
+
22-May-17,61.00,54.00,0.59,768
|
54 |
+
23-May-17,68.00,57.90,0.00,3174
|
55 |
+
24-May-17,66.90,57.00,0.04,2969
|
56 |
+
25-May-17,57.90,55.90,0.58,488
|
57 |
+
26-May-17,73.00,55.90,0.10,2590
|
58 |
+
27-May-17,71.10,61.00,0.00,2609
|
59 |
+
28-May-17,71.10,59.00,0.00,2640
|
60 |
+
29-May-17,57.90,55.90,0.13,836
|
61 |
+
30-May-17,59.00,55.90,0.06,2301
|
62 |
+
31-May-17,75.00,57.90,0.03,2689
|
63 |
+
1-Jun-17,78.10,62.10,0.00,3468
|
64 |
+
2-Jun-17,73.90,60.10,0.01,3271
|
65 |
+
3-Jun-17,72.00,55.00,0.01,2589
|
66 |
+
4-Jun-17,68.00,60.10,0.09,1805
|
67 |
+
5-Jun-17,66.90,60.10,0.02,2171
|
68 |
+
6-Jun-17,55.90,53.10,0.06,1193
|
69 |
+
7-Jun-17,66.90,54.00,0.00,3211
|
70 |
+
8-Jun-17,68.00,59.00,0.00,3253
|
71 |
+
9-Jun-17,80.10,59.00,0.00,3401
|
72 |
+
10-Jun-17,84.00,68.00,0.00,3066
|
73 |
+
11-Jun-17,90.00,73.00,0.00,2465
|
74 |
+
12-Jun-17,91.90,77.00,0.00,2854
|
75 |
+
13-Jun-17,93.90,78.10,0.01,2882
|
76 |
+
14-Jun-17,84.00,69.10,0.29,2596
|
77 |
+
15-Jun-17,75.00,66.00,0.00,3510
|
78 |
+
16-Jun-17,68.00,66.00,0.00,2054
|
79 |
+
17-Jun-17,73.00,66.90,1.39,1399
|
80 |
+
18-Jun-17,84.00,72.00,0.01,2199
|
81 |
+
19-Jun-17,87.10,70.00,1.35,1648
|
82 |
+
20-Jun-17,82.00,72.00,0.03,3407
|
83 |
+
21-Jun-17,82.00,72.00,0.00,3304
|
84 |
+
22-Jun-17,82.00,70.00,0.00,3368
|
85 |
+
23-Jun-17,82.90,75.90,0.04,2283
|
86 |
+
24-Jun-17,82.90,71.10,1.29,2307
|
87 |
+
25-Jun-17,82.00,69.10,0.00,2625
|
88 |
+
26-Jun-17,78.10,66.00,0.00,3386
|
89 |
+
27-Jun-17,75.90,61.00,0.18,3182
|
90 |
+
28-Jun-17,78.10,62.10,0.00,3766
|
91 |
+
29-Jun-17,81.00,68.00,0.00,3356
|
92 |
+
30-Jun-17,88.00,73.90,0.01,2687
|
93 |
+
1-Jul-17,84.90,72.00,0.23,1848
|
94 |
+
2-Jul-17,87.10,73.00,0.00,2467
|
95 |
+
3-Jul-17,87.10,71.10,0.45,2714
|
96 |
+
4-Jul-17,82.90,70.00,0.00,2296
|
97 |
+
5-Jul-17,84.90,71.10,0.00,3170
|
98 |
+
6-Jul-17,75.00,71.10,0.01,3065
|
99 |
+
7-Jul-17,79.00,68.00,1.78,1513
|
100 |
+
8-Jul-17,82.90,70.00,0.00,2718
|
101 |
+
9-Jul-17,81.00,69.10,0.00,3048
|
102 |
+
10-Jul-17,82.90,71.10,0.00,3506
|
103 |
+
11-Jul-17,84.00,75.00,0.00,2929
|
104 |
+
12-Jul-17,87.10,77.00,0.00,2860
|
105 |
+
13-Jul-17,89.10,77.00,0.00,2563
|
106 |
+
14-Jul-17,69.10,64.90,0.35,907
|
107 |
+
15-Jul-17,82.90,68.00,0.00,2853
|
108 |
+
16-Jul-17,84.90,70.00,0.00,2917
|
109 |
+
17-Jul-17,84.90,73.90,0.00,3264
|
110 |
+
18-Jul-17,87.10,75.90,0.00,3507
|
111 |
+
19-Jul-17,91.00,77.00,0.00,3114
|
112 |
+
20-Jul-17,93.00,78.10,0.01,2840
|
113 |
+
21-Jul-17,91.00,77.00,0.00,2751
|
114 |
+
22-Jul-17,91.00,78.10,0.57,2301
|
115 |
+
23-Jul-17,78.10,73.00,0.06,2321
|
116 |
+
24-Jul-17,69.10,63.00,0.74,1576
|
117 |
+
25-Jul-17,71.10,64.00,0.00,3191
|
118 |
+
26-Jul-17,75.90,66.00,0.00,3821
|
119 |
+
27-Jul-17,77.00,66.90,0.01,3287
|
120 |
+
28-Jul-17,84.90,73.00,0.00,3123
|
121 |
+
29-Jul-17,75.90,68.00,0.00,2074
|
122 |
+
30-Jul-17,81.00,64.90,0.00,3331
|
123 |
+
31-Jul-17,88.00,66.90,0.00,3560
|
124 |
+
1-Aug-17,91.00,72.00,0.00,3492
|
125 |
+
2-Aug-17,86.00,69.10,0.09,2637
|
126 |
+
3-Aug-17,86.00,70.00,0.00,3346
|
127 |
+
4-Aug-17,82.90,70.00,0.15,2400
|
128 |
+
5-Aug-17,77.00,70.00,0.30,3409
|
129 |
+
6-Aug-17,75.90,64.00,0.00,3130
|
130 |
+
7-Aug-17,71.10,64.90,0.76,804
|
131 |
+
8-Aug-17,77.00,66.00,0.00,3598
|
132 |
+
9-Aug-17,82.90,66.00,0.00,3893
|
133 |
+
10-Aug-17,82.90,69.10,0.00,3423
|
134 |
+
11-Aug-17,81.00,70.00,0.01,3148
|
135 |
+
12-Aug-17,75.90,64.90,0.11,4146
|
136 |
+
13-Aug-17,82.00,71.10,0.00,3274
|
137 |
+
14-Aug-17,80.10,70.00,0.00,3291
|
138 |
+
15-Aug-17,73.00,69.10,0.45,2149
|
139 |
+
16-Aug-17,84.90,70.00,0.00,3685
|
140 |
+
17-Aug-17,82.00,71.10,0.00,3637
|
141 |
+
18-Aug-17,81.00,73.00,0.88,1064
|
142 |
+
19-Aug-17,84.90,73.00,0.00,4693
|
143 |
+
20-Aug-17,81.00,70.00,0.00,2822
|
144 |
+
21-Aug-17,84.90,73.00,0.00,3088
|
145 |
+
22-Aug-17,88.00,75.00,0.30,2983
|
146 |
+
23-Aug-17,80.10,71.10,0.01,2994
|
147 |
+
24-Aug-17,79.00,66.00,0.00,3688
|
148 |
+
25-Aug-17,78.10,64.00,0.00,3144
|
149 |
+
26-Aug-17,77.00,62.10,0.00,2710
|
150 |
+
27-Aug-17,77.00,63.00,0.00,2676
|
151 |
+
28-Aug-17,75.00,63.00,0.00,3332
|
152 |
+
29-Aug-17,68.00,62.10,0.10,1472
|
153 |
+
30-Aug-17,75.90,61.00,0.01,3468
|
154 |
+
31-Aug-17,81.00,64.00,0.00,3279
|
155 |
+
1-Sep-17,70.00,55.00,0.00,2945
|
156 |
+
2-Sep-17,66.90,54.00,0.53,1876
|
157 |
+
3-Sep-17,69.10,60.10,0.74,1004
|
158 |
+
4-Sep-17,79.00,62.10,0.00,2866
|
159 |
+
5-Sep-17,84.00,70.00,0.01,3244
|
160 |
+
6-Sep-17,70.00,62.10,0.42,1232
|
161 |
+
7-Sep-17,71.10,59.00,0.01,3249
|
162 |
+
8-Sep-17,70.00,59.00,0.00,3234
|
163 |
+
9-Sep-17,69.10,55.00,0.00,2609
|
164 |
+
10-Sep-17,72.00,57.00,0.00,4960
|
165 |
+
11-Sep-17,75.90,55.00,0.00,3657
|
166 |
+
12-Sep-17,78.10,61.00,0.00,3497
|
167 |
+
13-Sep-17,82.00,64.90,0.06,2994
|
168 |
+
14-Sep-17,81.00,70.00,0.02,3013
|
169 |
+
15-Sep-17,81.00,66.90,0.00,3344
|
170 |
+
16-Sep-17,82.00,70.00,0.00,2560
|
171 |
+
17-Sep-17,80.10,70.00,0.00,2676
|
172 |
+
18-Sep-17,73.00,69.10,0.00,2673
|
173 |
+
19-Sep-17,78.10,69.10,0.22,2012
|
174 |
+
20-Sep-17,78.10,71.10,0.00,3296
|
175 |
+
21-Sep-17,80.10,71.10,0.00,3317
|
176 |
+
22-Sep-17,82.00,66.00,0.00,3297
|
177 |
+
23-Sep-17,86.00,68.00,0.00,2810
|
178 |
+
24-Sep-17,90.00,69.10,0.00,2543
|
179 |
+
25-Sep-17,87.10,72.00,0.00,3276
|
180 |
+
26-Sep-17,82.00,69.10,0.00,3157
|
181 |
+
27-Sep-17,84.90,71.10,0.00,3216
|
182 |
+
28-Sep-17,78.10,66.00,0.00,3421
|
183 |
+
29-Sep-17,66.90,55.00,0.00,2988
|
184 |
+
30-Sep-17,64.00,55.90,0.00,1903
|
185 |
+
1-Oct-17,66.90,50.00,0.00,2297
|
186 |
+
2-Oct-17,72.00,52.00,0.00,3387
|
187 |
+
3-Oct-17,70.00,57.00,0.00,3386
|
188 |
+
4-Oct-17,75.00,55.90,0.00,3412
|
189 |
+
5-Oct-17,82.00,64.90,0.00,3312
|
190 |
+
6-Oct-17,81.00,69.10,0.00,2982
|
191 |
+
7-Oct-17,80.10,66.00,0.00,2750
|
192 |
+
8-Oct-17,77.00,72.00,0.22,1235
|
193 |
+
9-Oct-17,75.90,72.00,0.26,898
|
194 |
+
10-Oct-17,80.10,66.00,0.00,3922
|
195 |
+
11-Oct-17,75.00,64.90,0.06,2721
|
196 |
+
12-Oct-17,63.00,55.90,0.07,2411
|
197 |
+
13-Oct-17,64.90,52.00,0.00,2839
|
198 |
+
14-Oct-17,71.10,62.10,0.08,2021
|
199 |
+
15-Oct-17,72.00,66.00,0.01,2169
|
200 |
+
16-Oct-17,60.10,52.00,0.01,2751
|
201 |
+
17-Oct-17,57.90,43.00,0.00,2869
|
202 |
+
18-Oct-17,71.10,50.00,0.00,3264
|
203 |
+
19-Oct-17,70.00,55.90,0.00,3265
|
204 |
+
20-Oct-17,73.00,57.90,0.00,3169
|
205 |
+
21-Oct-17,78.10,57.00,0.00,2538
|
206 |
+
22-Oct-17,75.90,57.00,0.00,2744
|
207 |
+
23-Oct-17,73.90,64.00,0.00,3189
|
208 |
+
24-Oct-17,73.00,66.90,0.20,954
|
209 |
+
25-Oct-17,64.90,57.90,0.00,3367
|
210 |
+
26-Oct-17,57.00,53.10,0.00,2565
|
211 |
+
27-Oct-17,62.10,48.00,0.00,3150
|
212 |
+
28-Oct-17,68.00,55.90,0.00,2245
|
213 |
+
29-Oct-17,64.90,61.00,3.03,183
|
214 |
+
30-Oct-17,55.00,46.00,0.25,1428
|
215 |
+
31-Oct-17,54.00,44.00,0.00,2727
|
assignment-2/poetry.lock
ADDED
The diff for this file is too large to render.
See raw diff
|
|
assignment-2/pyproject.toml
ADDED
@@ -0,0 +1,23 @@
1 |
+
[tool.poetry]
|
2 |
+
name = "assignment-2"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = ""
|
5 |
+
authors = ["Artur Janik <aj2614@nyu.edu>"]
|
6 |
+
readme = "README.md"
|
7 |
+
packages = [{include = "assignment_2"}]
|
8 |
+
|
9 |
+
[tool.poetry.dependencies]
|
10 |
+
python = "^3.10"
|
11 |
+
black = "^23.1.0"
|
12 |
+
jupyterlab = "^3.6.1"
|
13 |
+
ipython = "^8.10.0"
|
14 |
+
numpy = "^1.24.2"
|
15 |
+
pandas = "^1.5.3"
|
16 |
+
jax = "^0.4.4"
|
17 |
+
seaborn = "^0.12.2"
|
18 |
+
matplotlib = "^3.7.0"
|
19 |
+
|
20 |
+
|
21 |
+
[build-system]
|
22 |
+
requires = ["poetry-core"]
|
23 |
+
build-backend = "poetry.core.masonry.api"
|
assignment-2/tests/__init__.py
ADDED
File without changes
|