Spaces:
Runtime error
Runtime error
carlfeynman
commited on
Commit
β’
338bbe8
1
Parent(s):
6c5990e
removed batchnorm2d
Browse files- README.md +1 -1
- __pycache__/mnist_classifier.cpython-39.pyc +0 -0
- __pycache__/server.cpython-39.pyc +0 -0
- classifier.pth +0 -0
- mnist_classifier.ipynb +236 -67
- mnist_classifier.py +18 -8
- server.py +1 -0
README.md
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
# <img src="/static/favicon.png" alt="Logo" style="float: left; margin-right: 10px; border-radius:100%;margin-top:5px" /> MNIST CLASSIFIER
|
2 |
MNIST classifier from scratch
|
3 |
* Model: CNN
|
4 |
-
* Accuracy:
|
5 |
|
6 |
* Training Notebook: mnist_classifier.ipynb
|
7 |
* Cleaned Python Inference Version: mnist_classifier.py
|
|
|
1 |
# <img src="/static/favicon.png" alt="Logo" style="float: left; margin-right: 10px; border-radius:100%;margin-top:5px" /> MNIST CLASSIFIER
|
2 |
MNIST classifier from scratch
|
3 |
* Model: CNN
|
4 |
+
* Accuracy: 94%
|
5 |
|
6 |
* Training Notebook: mnist_classifier.ipynb
|
7 |
* Cleaned Python Inference Version: mnist_classifier.py
|
__pycache__/mnist_classifier.cpython-39.pyc
CHANGED
Binary files a/__pycache__/mnist_classifier.cpython-39.pyc and b/__pycache__/mnist_classifier.cpython-39.pyc differ
|
|
__pycache__/server.cpython-39.pyc
CHANGED
Binary files a/__pycache__/server.cpython-39.pyc and b/__pycache__/server.cpython-39.pyc differ
|
|
classifier.pth
CHANGED
Binary files a/classifier.pth and b/classifier.pth differ
|
|
mnist_classifier.ipynb
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"metadata": {},
|
7 |
"outputs": [],
|
8 |
"source": [
|
@@ -15,13 +15,12 @@
|
|
15 |
"import matplotlib as mpl\n",
|
16 |
"import torchvision.transforms.functional as TF\n",
|
17 |
"from torch.utils.data import default_collate, DataLoader\n",
|
18 |
-
"import torch.optim as optim\n"
|
19 |
-
"import pickle\n"
|
20 |
]
|
21 |
},
|
22 |
{
|
23 |
"cell_type": "code",
|
24 |
-
"execution_count":
|
25 |
"metadata": {
|
26 |
"tags": [
|
27 |
"exclude"
|
@@ -35,7 +34,7 @@
|
|
35 |
},
|
36 |
{
|
37 |
"cell_type": "code",
|
38 |
-
"execution_count":
|
39 |
"metadata": {
|
40 |
"tags": [
|
41 |
"exclude"
|
@@ -47,7 +46,7 @@
|
|
47 |
"output_type": "stream",
|
48 |
"text": [
|
49 |
"Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
|
50 |
-
"100%|ββββββββββ| 2/2 [00:00<00:00,
|
51 |
]
|
52 |
}
|
53 |
],
|
@@ -59,7 +58,7 @@
|
|
59 |
},
|
60 |
{
|
61 |
"cell_type": "code",
|
62 |
-
"execution_count":
|
63 |
"metadata": {},
|
64 |
"outputs": [],
|
65 |
"source": [
|
@@ -70,40 +69,42 @@
|
|
70 |
},
|
71 |
{
|
72 |
"cell_type": "code",
|
73 |
-
"execution_count":
|
74 |
"metadata": {
|
75 |
"tags": [
|
76 |
"exclude"
|
77 |
]
|
78 |
},
|
79 |
-
"outputs": [],
|
80 |
-
"source": [
|
81 |
-
"dst = ds.with_transform(transform_ds)\n",
|
82 |
-
"plt.imshow(dst['train'][0]['image'].permute(1,2,0));"
|
83 |
-
]
|
84 |
-
},
|
85 |
-
{
|
86 |
-
"cell_type": "code",
|
87 |
-
"execution_count": 103,
|
88 |
-
"metadata": {},
|
89 |
"outputs": [
|
90 |
{
|
91 |
"data": {
|
|
|
92 |
"text/plain": [
|
93 |
-
"
|
94 |
]
|
95 |
},
|
96 |
-
"
|
97 |
-
|
98 |
-
|
|
|
99 |
}
|
100 |
],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
"source": [
|
102 |
"bs = 1024\n",
|
103 |
"class DataLoaders:\n",
|
104 |
" def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n",
|
105 |
" self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n",
|
106 |
-
" self.valid = DataLoader(valid_ds, batch_size=bs
|
107 |
"\n",
|
108 |
"def collate_fn(b):\n",
|
109 |
" collate = default_collate(b)\n",
|
@@ -112,13 +113,24 @@
|
|
112 |
},
|
113 |
{
|
114 |
"cell_type": "code",
|
115 |
-
"execution_count":
|
116 |
"metadata": {
|
117 |
"tags": [
|
118 |
"exclude"
|
119 |
]
|
120 |
},
|
121 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
"source": [
|
123 |
"dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n",
|
124 |
"xb,yb = next(iter(dls.train))\n",
|
@@ -127,7 +139,7 @@
|
|
127 |
},
|
128 |
{
|
129 |
"cell_type": "code",
|
130 |
-
"execution_count":
|
131 |
"metadata": {},
|
132 |
"outputs": [],
|
133 |
"source": [
|
@@ -142,7 +154,7 @@
|
|
142 |
},
|
143 |
{
|
144 |
"cell_type": "code",
|
145 |
-
"execution_count":
|
146 |
"metadata": {},
|
147 |
"outputs": [],
|
148 |
"source": [
|
@@ -174,17 +186,28 @@
|
|
174 |
},
|
175 |
{
|
176 |
"cell_type": "code",
|
177 |
-
"execution_count":
|
178 |
"metadata": {},
|
179 |
"outputs": [],
|
180 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
"def cnn_classifier():\n",
|
182 |
" return nn.Sequential(\n",
|
183 |
-
" ResBlock(1, 8,
|
184 |
-
" ResBlock(8, 16,
|
185 |
-
" ResBlock(16, 32,
|
186 |
-
" ResBlock(32, 64,
|
187 |
-
" ResBlock(64, 64,
|
188 |
" conv(64, 10, act=False),\n",
|
189 |
" nn.Flatten(),\n",
|
190 |
" )"
|
@@ -192,7 +215,7 @@
|
|
192 |
},
|
193 |
{
|
194 |
"cell_type": "code",
|
195 |
-
"execution_count":
|
196 |
"metadata": {},
|
197 |
"outputs": [],
|
198 |
"source": [
|
@@ -203,7 +226,7 @@
|
|
203 |
},
|
204 |
{
|
205 |
"cell_type": "code",
|
206 |
-
"execution_count":
|
207 |
"metadata": {
|
208 |
"tags": [
|
209 |
"exclude"
|
@@ -214,16 +237,16 @@
|
|
214 |
"name": "stdout",
|
215 |
"output_type": "stream",
|
216 |
"text": [
|
217 |
-
"train, epoch:1, loss:
|
218 |
-
"eval, epoch:1, loss: 0.
|
219 |
-
"train, epoch:2, loss: 0.
|
220 |
-
"eval, epoch:2, loss: 0.
|
221 |
-
"train, epoch:3, loss: 0.
|
222 |
-
"eval, epoch:3, loss: 0.
|
223 |
-
"train, epoch:4, loss: 0.
|
224 |
-
"eval, epoch:4, loss: 0.
|
225 |
-
"train, epoch:5, loss: 0.
|
226 |
-
"eval, epoch:5, loss: 0.
|
227 |
]
|
228 |
}
|
229 |
],
|
@@ -238,6 +261,7 @@
|
|
238 |
"for epoch in range(epochs):\n",
|
239 |
" for train in (True, False):\n",
|
240 |
" accuracy = 0\n",
|
|
|
241 |
" dl = dls.train if train else dls.valid\n",
|
242 |
" for xb,yb in dl:\n",
|
243 |
" preds = model(xb)\n",
|
@@ -248,15 +272,90 @@
|
|
248 |
" opt.zero_grad()\n",
|
249 |
" with torch.no_grad():\n",
|
250 |
" accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
|
|
|
251 |
" if train:\n",
|
252 |
" sched.step()\n",
|
253 |
" accuracy /= len(dl)\n",
|
254 |
-
"
|
|
|
255 |
]
|
256 |
},
|
257 |
{
|
258 |
"cell_type": "code",
|
259 |
-
"execution_count":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
"metadata": {
|
261 |
"tags": [
|
262 |
"exclude"
|
@@ -269,18 +368,88 @@
|
|
269 |
},
|
270 |
{
|
271 |
"cell_type": "code",
|
272 |
-
"execution_count":
|
273 |
"metadata": {},
|
274 |
"outputs": [],
|
275 |
"source": [
|
276 |
"loaded_model = cnn_classifier()\n",
|
277 |
-
"loaded_model.load_state_dict(torch.load('classifier.pth'))
|
278 |
"loaded_model.eval();"
|
279 |
]
|
280 |
},
|
281 |
{
|
282 |
"cell_type": "code",
|
283 |
-
"execution_count":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
"metadata": {},
|
285 |
"outputs": [],
|
286 |
"source": [
|
@@ -296,7 +465,7 @@
|
|
296 |
},
|
297 |
{
|
298 |
"cell_type": "code",
|
299 |
-
"execution_count":
|
300 |
"metadata": {
|
301 |
"tags": [
|
302 |
"exclude"
|
@@ -307,32 +476,32 @@
|
|
307 |
"name": "stdout",
|
308 |
"output_type": "stream",
|
309 |
"text": [
|
310 |
-
"tensor(
|
311 |
]
|
312 |
},
|
313 |
{
|
314 |
"data": {
|
315 |
"text/plain": [
|
316 |
-
"[{'digit': 0, 'prob': '
|
317 |
-
" {'digit':
|
318 |
-
" {'digit':
|
319 |
-
" {'digit':
|
320 |
-
" {'digit':
|
321 |
-
" {'digit':
|
322 |
-
" {'digit':
|
323 |
-
" {'digit':
|
324 |
-
" {'digit':
|
325 |
-
" {'digit':
|
326 |
]
|
327 |
},
|
328 |
-
"execution_count":
|
329 |
"metadata": {},
|
330 |
"output_type": "execute_result"
|
331 |
}
|
332 |
],
|
333 |
"source": [
|
334 |
-
"img = xb[
|
335 |
-
"print(yb[
|
336 |
"predict(img)"
|
337 |
]
|
338 |
},
|
@@ -349,7 +518,7 @@
|
|
349 |
},
|
350 |
{
|
351 |
"cell_type": "code",
|
352 |
-
"execution_count":
|
353 |
"metadata": {
|
354 |
"tags": [
|
355 |
"exclude"
|
@@ -361,7 +530,7 @@
|
|
361 |
"output_type": "stream",
|
362 |
"text": [
|
363 |
"[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
|
364 |
-
"[NbConvertApp] Writing
|
365 |
]
|
366 |
}
|
367 |
],
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
"metadata": {},
|
7 |
"outputs": [],
|
8 |
"source": [
|
|
|
15 |
"import matplotlib as mpl\n",
|
16 |
"import torchvision.transforms.functional as TF\n",
|
17 |
"from torch.utils.data import default_collate, DataLoader\n",
|
18 |
+
"import torch.optim as optim\n"
|
|
|
19 |
]
|
20 |
},
|
21 |
{
|
22 |
"cell_type": "code",
|
23 |
+
"execution_count": 2,
|
24 |
"metadata": {
|
25 |
"tags": [
|
26 |
"exclude"
|
|
|
34 |
},
|
35 |
{
|
36 |
"cell_type": "code",
|
37 |
+
"execution_count": 3,
|
38 |
"metadata": {
|
39 |
"tags": [
|
40 |
"exclude"
|
|
|
46 |
"output_type": "stream",
|
47 |
"text": [
|
48 |
"Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
|
49 |
+
"100%|ββββββββββ| 2/2 [00:00<00:00, 112.23it/s]\n"
|
50 |
]
|
51 |
}
|
52 |
],
|
|
|
58 |
},
|
59 |
{
|
60 |
"cell_type": "code",
|
61 |
+
"execution_count": 4,
|
62 |
"metadata": {},
|
63 |
"outputs": [],
|
64 |
"source": [
|
|
|
69 |
},
|
70 |
{
|
71 |
"cell_type": "code",
|
72 |
+
"execution_count": 5,
|
73 |
"metadata": {
|
74 |
"tags": [
|
75 |
"exclude"
|
76 |
]
|
77 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
"outputs": [
|
79 |
{
|
80 |
"data": {
|
81 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIsUlEQVR4nO3df2yU9R0H8PfHtrQroFJBVrGjHVRAweHWCASCJBuumiXOLAyYWTbjQiYy58Y2fmzZ5oILJgsJMjSRrCsmig7mAjFsZBIlLkNGdeBgrOWnWqnFwkDmUNrrZ3/0bPu59cfTz3P33NPr+5WQu89zd32+MW+/z/eeu+dzoqogGqgrsj0AGpwYHHJhcMiFwSEXBodcGBxyCRUcEakWkXoROSYiK9M1KIo/8Z7HEZE8AA0A5gNoBLAfwGJV/Wf6hkdxlR/itbcCOKaqJwBARJ4FcBeAXoMzTAq1CMND7JKidhH/blHVManbwwRnHIC3u9WNAGb09YIiDMcM+XyIXVLUXtRtb/a0PUxwpIdt/3fcE5ElAJYAQBGKQ+yO4iTM4rgRQFm3+noAp1OfpKpPqmqVqlYVoDDE7ihOwgRnP4BKEakQkWEAFgHYkZ5hUdy5D1Wq2iYiywDsApAHoEZVD6dtZBRrYdY4UNWdAHamaSw0iPDMMbkwOOTC4JALg0MuDA65MDjkwuCQC4NDLgwOuTA45MLgkAuDQy6hPuQcSiTf/qfKGzM68Gvrf1Bu6kRxu6nHTzhj6uKl9jty764bZurXq54zdUviA1PP2Lq88/7E778aeJwDwRmHXBgccmFwyGXIrHHyplSaWgsLTH36tqtNfWmmXTeUXGXrVz5j1xlh/PG/I0396K+rTb1v2jOmPtl6ydRrm+eb+rpXMt/ziDMOuTA45MLgkEvOrnES8z5r6nW1G019Q4E9NxKlVk2Y+qcbvmnq/A/sGmXW1mWmHvlOm6kLW+yap7huX8gR9o8zDrkwOOTC4JBLzq5xCuvtZeyvfVhm6hsKmtO2r+VNM0194j/2c6zaCdtMfaHdrmHGPvbXUPvPRqdqzjjkwuCQC4NDLjm7xmlretfUGx5dYOpHqu1nT3lvjDD1waUb+vz7a1pu7rx/7Au2YVTifJOpvzZrqalPPWj/VgUO9rmvOOKMQy79BkdEakTkjIgc6ratRET+LCJHk7ejMjtMipsgM04tgOqUbSsB7FbVSgC7kzUNIYH6HItIOYAXVHVqsq4HME9Vm0SkFMDLqjqpv79zpZRoXLqO5o2+xtSJs+dMffKZm019eG6NqW/95Xc671+7Mdx5mDh7Ube9pqpVqdu9a5yxqtoEAMnba8MMjgafjL+rYrva3OSdcZqThygkb8/09kS2q81N3hlnB4BvAFibvN2ethFFJNFyts/HW9/v+/s6N93T9csD7z2RZx9sTyDXBXk7vgXAXgCTRKRRRO5DR2Dmi8hRdPwIyNrMDpPipt8ZR1UX9/JQPN4eUVbwzDG55OxnVWFNWdFg6nun2Qn2t+N3d96/bcED5rGRz2Xmeu044YxDLgwOuTA45MI1Ti8S5y+Y+uz9U0z91o6ua5lWrnnKPLbqq3ebWv9+lanLHtlrd+b8XdRs4oxDLgwOufBQFVD7wSOmXvTwDzvvP/2zX5nHDsy0hy7Yq2dw03B7SW/lJvtV07YTp3yDjBBnHHJhcMiFwSGXQF8dTZc4fXU0nXT2dFNfubbR1Fs+vavP109+6VumnvSwPRWQOHrCP7iQ0v3VURriGBxyYXDIhWucDMgbay/6OL1woqn3rVhv6itS/v+95+Ttpr4wp++vuWYS1ziUVgwOuTA45MLPqjIg0WwvMxv7mK0//JFtN1ss9lKcTeUvmPpLdz9kn/+HzLej7Q9nHHJhcMiFwSEXrnHSoH3OdFMfX1Bk6qnTT5k6dU2TasO5W+zzt9e5x5YpnHHIhcEhFwaHXLjGCUiqppq64cGudcqm2ZvNY3OLLg/ob3+kraZ+9VyFfUK7/U5yHHDGIZcg/XHKROQlETkiIodF5LvJ7WxZO4QFmXHaACxX1SnouNDjARG5EWxZO6QFaazUBODjDqMXReQIgHEA7gIwL/m0zQBeBrAiI6OMQH7FeFMfv/c6U/984bOm/sqIFve+Vjfbr7fsWW8vvBq1OeUS4Rga0Bon2e/4FgD7wJa1Q1rg4IjICAC/B/CQqr4/gNctEZE6EalrxUeeMVIMBQqOiBSgIzRPq+rzyc2BWtayXW1u6neNIyIC4DcAjqjqum4PDaqWtfnlnzL1hc+VmnrhL/5k6m9f/Ty8Un9qce/jdk1TUvs3U49qj/+aJlWQE4CzAXwdwD9E5EBy22p0BOZ3yfa1bwFY0PPLKRcFeVf1FwDSy8O5f8kC9YhnjsklZz6ryi/9pKnP1Qw39f0Ve0y9eGS4n49e9s6czvuvPzHdPDZ62yFTl1wcfGuY/nDGIRcGh1wYHHIZVGucy1/sOh9y+Xv2pxBXT9xp6ts/YX8eeqCaE5dMPXfHclNP/sm/Ou+XnLdrmPZQex4cOOOQC4NDLoPqUHXqy105b5i2dUCv3Xh+gqnX77GtRCRhz3FOXnPS1JXN9rLb3P8NvL5xxiEXBodcGBxyYSs36hNbuVFaMTjkwuCQC4NDLgwOuTA45MLgkAuDQy4MDrkwOOTC4JBLpJ9Vich7AN4EMBqAv09IZnFs1nhVHZO6MdLgdO5UpK6nD87igGMLhocqcmFwyCVbwXkyS/sNgmMLICtrHBr8eKgil0iDIyLVIlIvIsdEJKvtbUWkRkTOiMihbtti0bt5MPSWjiw4IpIHYCOAOwDcCGBxsl9yttQCqE7ZFpfezfHvLa2qkfwDMAvArm71KgCrotp/L2MqB3CoW10PoDR5vxRAfTbH121c2wHMj9P4ojxUjQPwdre6MbktTmLXuzmuvaWjDE5PfQT5lq4P3t7SUYgyOI0AyrrV1wM4HeH+gwjUuzkKYXpLRyHK4OwHUCkiFSIyDMAidPRKjpOPezcDWezdHKC3NJDt3tIRL/LuBNAA4DiAH2d5wbkFHT9u0oqO2fA+ANeg493K0eRtSZbGNgcdh/E3ABxI/rszLuNTVZ45Jh+eOSYXBodcGBxyYXDIhcEhFwaHXBgccmFwyOV/atVD7hyCzrEAAAAASUVORK5CYII=",
|
82 |
"text/plain": [
|
83 |
+
"<Figure size 144x144 with 1 Axes>"
|
84 |
]
|
85 |
},
|
86 |
+
"metadata": {
|
87 |
+
"needs_background": "light"
|
88 |
+
},
|
89 |
+
"output_type": "display_data"
|
90 |
}
|
91 |
],
|
92 |
+
"source": [
|
93 |
+
"dst = ds.with_transform(transform_ds)\n",
|
94 |
+
"plt.imshow(dst['train'][0]['image'].permute(1,2,0));"
|
95 |
+
]
|
96 |
+
},
|
97 |
+
{
|
98 |
+
"cell_type": "code",
|
99 |
+
"execution_count": 6,
|
100 |
+
"metadata": {},
|
101 |
+
"outputs": [],
|
102 |
"source": [
|
103 |
"bs = 1024\n",
|
104 |
"class DataLoaders:\n",
|
105 |
" def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n",
|
106 |
" self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n",
|
107 |
+
" self.valid = DataLoader(valid_ds, batch_size=bs, shuffle=False, collate_fn=collate_fn, **kwargs)\n",
|
108 |
"\n",
|
109 |
"def collate_fn(b):\n",
|
110 |
" collate = default_collate(b)\n",
|
|
|
113 |
},
|
114 |
{
|
115 |
"cell_type": "code",
|
116 |
+
"execution_count": 7,
|
117 |
"metadata": {
|
118 |
"tags": [
|
119 |
"exclude"
|
120 |
]
|
121 |
},
|
122 |
+
"outputs": [
|
123 |
+
{
|
124 |
+
"data": {
|
125 |
+
"text/plain": [
|
126 |
+
"(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))"
|
127 |
+
]
|
128 |
+
},
|
129 |
+
"execution_count": 7,
|
130 |
+
"metadata": {},
|
131 |
+
"output_type": "execute_result"
|
132 |
+
}
|
133 |
+
],
|
134 |
"source": [
|
135 |
"dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n",
|
136 |
"xb,yb = next(iter(dls.train))\n",
|
|
|
139 |
},
|
140 |
{
|
141 |
"cell_type": "code",
|
142 |
+
"execution_count": 147,
|
143 |
"metadata": {},
|
144 |
"outputs": [],
|
145 |
"source": [
|
|
|
154 |
},
|
155 |
{
|
156 |
"cell_type": "code",
|
157 |
+
"execution_count": 148,
|
158 |
"metadata": {},
|
159 |
"outputs": [],
|
160 |
"source": [
|
|
|
186 |
},
|
187 |
{
|
188 |
"cell_type": "code",
|
189 |
+
"execution_count": 149,
|
190 |
"metadata": {},
|
191 |
"outputs": [],
|
192 |
"source": [
|
193 |
+
"# def cnn_classifier():\n",
|
194 |
+
"# return nn.Sequential(\n",
|
195 |
+
"# ResBlock(1, 8, norm=nn.BatchNorm2d(8)),\n",
|
196 |
+
"# ResBlock(8, 16, norm=nn.BatchNorm2d(16)),\n",
|
197 |
+
"# ResBlock(16, 32, norm=nn.BatchNorm2d(32)),\n",
|
198 |
+
"# ResBlock(32, 64, norm=nn.BatchNorm2d(64)),\n",
|
199 |
+
"# ResBlock(64, 64, norm=nn.BatchNorm2d(64)),\n",
|
200 |
+
"# conv(64, 10, act=False),\n",
|
201 |
+
"# nn.Flatten(),\n",
|
202 |
+
"# )\n",
|
203 |
+
"\n",
|
204 |
"def cnn_classifier():\n",
|
205 |
" return nn.Sequential(\n",
|
206 |
+
" ResBlock(1, 8,),\n",
|
207 |
+
" ResBlock(8, 16, ),\n",
|
208 |
+
" ResBlock(16, 32,),\n",
|
209 |
+
" ResBlock(32, 64, ),\n",
|
210 |
+
" ResBlock(64, 64,),\n",
|
211 |
" conv(64, 10, act=False),\n",
|
212 |
" nn.Flatten(),\n",
|
213 |
" )"
|
|
|
215 |
},
|
216 |
{
|
217 |
"cell_type": "code",
|
218 |
+
"execution_count": 150,
|
219 |
"metadata": {},
|
220 |
"outputs": [],
|
221 |
"source": [
|
|
|
226 |
},
|
227 |
{
|
228 |
"cell_type": "code",
|
229 |
+
"execution_count": 151,
|
230 |
"metadata": {
|
231 |
"tags": [
|
232 |
"exclude"
|
|
|
237 |
"name": "stdout",
|
238 |
"output_type": "stream",
|
239 |
"text": [
|
240 |
+
"train, epoch:1, loss: 1.3684, accuracy: 0.5153\n",
|
241 |
+
"eval, epoch:1, loss: 0.4238, accuracy: 0.8648\n",
|
242 |
+
"train, epoch:2, loss: 0.2660, accuracy: 0.9162\n",
|
243 |
+
"eval, epoch:2, loss: 0.1468, accuracy: 0.9552\n",
|
244 |
+
"train, epoch:3, loss: 0.1479, accuracy: 0.9545\n",
|
245 |
+
"eval, epoch:3, loss: 0.1101, accuracy: 0.9647\n",
|
246 |
+
"train, epoch:4, loss: 0.1149, accuracy: 0.9650\n",
|
247 |
+
"eval, epoch:4, loss: 0.0997, accuracy: 0.9705\n",
|
248 |
+
"train, epoch:5, loss: 0.2118, accuracy: 0.9399\n",
|
249 |
+
"eval, epoch:5, loss: 0.1625, accuracy: 0.9478\n"
|
250 |
]
|
251 |
}
|
252 |
],
|
|
|
261 |
"for epoch in range(epochs):\n",
|
262 |
" for train in (True, False):\n",
|
263 |
" accuracy = 0\n",
|
264 |
+
" total_loss = 0\n",
|
265 |
" dl = dls.train if train else dls.valid\n",
|
266 |
" for xb,yb in dl:\n",
|
267 |
" preds = model(xb)\n",
|
|
|
272 |
" opt.zero_grad()\n",
|
273 |
" with torch.no_grad():\n",
|
274 |
" accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
|
275 |
+
" total_loss += loss.item()\n",
|
276 |
" if train:\n",
|
277 |
" sched.step()\n",
|
278 |
" accuracy /= len(dl)\n",
|
279 |
+
" total_loss /= len(dl)\n",
|
280 |
+
" print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {total_loss:.4f}, accuracy: {accuracy:.4f}\")"
|
281 |
]
|
282 |
},
|
283 |
{
|
284 |
"cell_type": "code",
|
285 |
+
"execution_count": 152,
|
286 |
+
"metadata": {
|
287 |
+
"tags": [
|
288 |
+
"exclude"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
"outputs": [
|
292 |
+
{
|
293 |
+
"name": "stdout",
|
294 |
+
"output_type": "stream",
|
295 |
+
"text": [
|
296 |
+
"eval, epoch:1, loss: 0.1625, accuracy: 0.9478\n",
|
297 |
+
"eval, epoch:2, loss: 0.1625, accuracy: 0.9478\n",
|
298 |
+
"eval, epoch:3, loss: 0.1625, accuracy: 0.9478\n",
|
299 |
+
"eval, epoch:4, loss: 0.1625, accuracy: 0.9478\n",
|
300 |
+
"eval, epoch:5, loss: 0.1625, accuracy: 0.9478\n"
|
301 |
+
]
|
302 |
+
}
|
303 |
+
],
|
304 |
+
"source": [
|
305 |
+
"for epoch in range(epochs):\n",
|
306 |
+
" train = False\n",
|
307 |
+
" accuracy = 0\n",
|
308 |
+
" total_loss = 0\n",
|
309 |
+
" dl = dls.valid\n",
|
310 |
+
" for xb,yb in dl:\n",
|
311 |
+
" preds = model(xb)\n",
|
312 |
+
" loss = F.cross_entropy(preds, yb)\n",
|
313 |
+
" with torch.no_grad():\n",
|
314 |
+
" accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
|
315 |
+
" total_loss += loss.item()\n",
|
316 |
+
" accuracy /= len(dl)\n",
|
317 |
+
" total_loss /= len(dl)\n",
|
318 |
+
" print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {total_loss:.4f}, accuracy: {accuracy:.4f}\")"
|
319 |
+
]
|
320 |
+
},
|
321 |
+
{
|
322 |
+
"cell_type": "code",
|
323 |
+
"execution_count": 153,
|
324 |
+
"metadata": {
|
325 |
+
"tags": [
|
326 |
+
"exclude"
|
327 |
+
]
|
328 |
+
},
|
329 |
+
"outputs": [
|
330 |
+
{
|
331 |
+
"data": {
|
332 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAAB+CAYAAADLN3DXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZ7UlEQVR4nO3dd5hURboG8PfrCQw5wwgMmSEpYBYRxF1XDKiYBVcERZdVUbyoXF12r3K9qyurrmAEV8QArGJEdFEXAxJWZUUQyQKSk8QhzUzX/aN7qk61c3rO9Mx0Ou/veXj4qqtO6KnpM9VVdeqIUgpEREREfhFI9AkQERERxRMbP0REROQrbPwQERGRr7DxQ0RERL7Cxg8RERH5Chs/RERE5CsJb/yIyHoROTfR5xEvIvKAiLya6POoKqzP9MG6TC+sz/TBuqy4hDd+ykNC/iIiu8P/HhURqeJjDhGRL6vyGBHH6yEii0TkUPj/HvE6dryJyEgR+VFE9ovIFhF5QkQyq/iYcatPEWkkIvPCv6t7RWSBiPSKx7HjTUQ+FJGDjn/HRGRpFR8znnXZO+L9HRQRJSJXxOP48cZrbfrgdbZ0ldr4qeofKIBbAAwA0B1ANwD9Afyuio8ZNyKSDeBdAK8CqA9gCoB3w68n4nyquj5nAjhJKVUHwPEI1esdVXzMeDoI4EYAjRGqz78AmBmHn+svVPUxlVIXKKVqlfwDMB/AG1V5zHhSSs2NeH/9EarffybifHitrZhkutbyOlthMV1ny2z8hLvX7hORH0Rkj4hMFpGccF5fEdkkIqNFZBuAySISEJH/FpG14ZbY6yLSwLG/60VkQzjvD+V8kzcAeEwptUkptRnAYwCGeNlQRNqJyJzwcXeJyGsiUs+Rnycib4nIznCZp0SkM4DnAPQMf9PbGy77mYgMc2xrtXJF5EkR2RhuaS8Skd4e319fAJkA/qaUOqqUGg9AAPzK4/ZlSqb6VEqtVUrtLdkVgCCA9h7fR9LXp1LqiFJqpVIqGH5/xQh9OBtE39KbZKrLiPNqDaA3gFc8lk/6uizFDQBmKKUKYtz+F5KsPnmtrYBkqkteZ0vntefnOgD9ALQDkA9gjCMvN3yQVgh9W7gDoW8MZwNoBmAPgKfDb64LgGcBXB/OawigRcmOROSskh+Ui64AvnOkvwu/5oUAeDh83M4A8gA8ED5uBoD3AWwA0BpAcwDTlVLLAQwHsCD8ja+ex2N9DaAHQj+XqQDeKPnF/8VJiSwRkUHhZFcAS5T9zJEl8P4evUqW+oSIDBKR/QB2IfSN5HmP7yEV6lO/BuAIgPcAvKCU2uHxuF4kTV06DAYwVym1zmP5lKnL8Os1AFyJUG9BZUuW+uS1tuKSpS55nS2NUirqPwDrAQx3pC8EsDYc9wVwDECOI385gF870scBKESolf2n8A+nJK9mePtzyzqPcPliAJ0c6Q4AFADxsn3EvgYA+DYc9wSwE0BmKeWGAPgy4rXPAAyLViai/B4A3cPxAwBedSn3R+fPJ/zaawAeKO/7S4X6jDivDgD+F0BujO8r6eozYpscAAMB3OCDulwDYEgF3ley1+X1ANYhhutOqtQneK1Nm7qMOC9eZ8P/vI41bnTEGxBqBZbYqZQ64ki3AvC2iAQdrxUDaBreTu9LKVUgIrs9ngMQGtur40jXAXBQhd91NCLSBMB4hLrjayPU67UnnJ0HYINSqqgc5xLtWKMADEPo/arweTbysGnk+0M4faAyzsshWepTU0qtFpFlAJ4BcHlZ5VOkPrXwz3SaiCwXkcVKqe/K3MibpKpLETkLoW+1M8qxTUrVJUJDQi97ue7EIFnqk9faikuWutR4nTW8DnvlOeKWALY4jxdRdiOAC5RS9Rz/clRo3Hirc1/h7uOGHs8BAJYh1GVXonv4NS8eDp9rNxWa+PVbhLr0Ss65pZQ+Qaq0D3sBgBqOdG5JEB6nHA3gagD1VajLb5/jWNEsA9BNxLqrohu8v0evkqU+I2Ui1EXsRSrUZ2myALSNcdvSJFtd3gDgLaXUwXJskzJ1KSJ5CH1zf9nrNuWULPXJa23FJUtdRuJ1Ft4bP7eJSIvwBKz7AfwjStnnAPyfiLQCABFpLCKXhvNmAOgfHqPMBjC2HOcAhC44/yUizUWkGYBRAF4qyQxPqHrAZdvaCLX294pIcwD3OPK+QugX7BERqSkiOWJuldsOoIXYdwEsBnC5iNQQkfYAboo4ThHC3YEi8if88huGm88Qau3fISLVROT28OtzPG7vVVLUp4gMC3+zKBnXvg/Avxz5KV2fInJGyc9GRKqLyGiEvsn928v2HiVFXYb3Vx3AVXB8Jh15KV2XDtcDmK+UWlvO7bxKlvrktbbikqIueZ0tndcf4FQAHwH4MfzvoShln0RowtFHInIAwEIApwOAUmoZgNvC+9uKUPfZJseb6C0i0b4xPo/QbXtLAXwPYBbsiVt5AOa5bPsggJMQak3OAvBWSYZSqhjAxQjNgP8pfE7XhLPnIPRtYJuI7Aq/9gRCY67bEZr0+JrjOLMBfAhgFUJdnUdgd39aRGSZiFwXPo9jCI2pDgawF6Hb9waEX69MyVKfvQAsFZECAB+E/93vyE/p+gRQDaFJi7sBbEZo3P8ipdQWt+1jkCx1CYR+d/cB+LSUvFSvyxKDUTUTnUskS33yWltxyVKXvM6Wto+yhnBFZD1Ck5Q+iVowwUSkBYA3lFI9E30uyYz1mT5Yl+mF9Zk+WJfJL+6LrVUVpdQmhGafUxpgfaYP1mV6YX2mDz/XZUo93oKIiIioosoc9iIiIiJKJ+z5ISIiIl9h44eIiIh8xdOE598EruLYWIJ9HHwj1sWeLKzLxKusugRYn8mAn830wc9meolWn+z5ISIiIl9h44eIiIh8hY0fIiIi8hU2foiIiMhX2PghIiIiX2Hjh4iIiHyFjR8iIiLyFTZ+iIiIyFfY+CEiIiJfYeOHiIiIfIWNHyIiIvIVNn6IiIjIVzw92DQVZXTJ1/HagQ2tvNZnbtTxqnW5Vt6aCyaafYjdNixWQR13/PxGK6/doMUxnyvFTk7uquP97WtbeVv7mvqqlXvQylPKPO+u+sw6Vl6DyQsq8xSJ0ouYz87O352h41vueM8q9quaq3Tc78O7rLz8WxeZRLC4kk+Qqppk2k2HQF1zDVXHCq08deyYKVerpus+1eEjVjp46FBFTrFM7PkhIiIiX2Hjh4iIiHwlbYe9NgxopOPvb5zgXrCTnQw6Y+XeHbuw99NWenCH63VcvPpHT+dIhurZ3UrvPNF0j9a4ZJuOH85/yyrXOvNLHR+XUd11/wGIlQ5C6XjTyYetvMur36vjJs/Mj3baROnJMbR19MJTrKzc+9fqeHabZ6PsxHyG110y0coZ2qO3jreOaG3lqa+XluNEKRF+HHuqlV56w3gdv32wiZX38paeJi//Xdd9dvrn7610/k3fVOQUy8SeHyIiIvIVNn6IiIjIV9j4ISIiIl9Jqzk/ewebscXZwx915LjPBYlV3UCOlT7zzR90PP/yzlZe8Zp1lX78dHP/q69Y6V45haWW++XcHVO3m4rsuTsbiuxb2J1aZe7XcYtM+/dj8j1P6Hj0M6e77oMopQUydFjcx55zd/Be8/lY2GOSledc8uOnInM78opj9a1yCwo66Ph/Gv9g5U1uOVfHPcfac0TqXWaurcEj9u3PlDiZrVvqeOwV013LXVZrh52OMs/H6dzjl1vpn8pxbrFgzw8RERH5Chs/RERE5CspPey1c3hPK/3FGDNcUU0qf6grmtENl+n49ql29+/60+J6Kmnt1s29rPQXs07Ucav391l5apGpk4z2bay84ESz6ujMjvbKtF2zU/pjETeZbVpZ6S0XNfe0XUGvAh0v7zPZtVy0FdYj8zrPM0tNtL5+tZXHoZPSZbQz9Td76ouu5cbsOMFKf/z4WTqu97L7augZjczK+ie9aO/jP6f8Q8cLur9p5Z34j2t13OTSFa77p/jaNKCFjiOHtirDUy0+s9KX4NTSC1YS9vwQERGRr7DxQ0RERL7Cxg8RERH5SspNbthx65k6nnLv41ZeNcmK9+mUqnpG6bdpk7txvX5jpUde3a7Uck3H24+baAmTVhFli351so7rjl1v5U1r87GO9wTtOSHnj71bxw2Rnk94D9SoYeI6tT1vt26YqZfzLvvKynsn963I4mUKRsuL8niZyLwlZ76k40trn28X5pyfUm1/3P3y3295fx1nXGMvIVFvl7fPRPGu3TrOvcn+nn3TTDNv6O8tv7TyXu1u5oGN7H2rlReY+62nY1Pla/6SmUPZ9bSbrbzebc0jTzrV3GbljWxgL3Pg5v2ChmUXqkTs+SEiIiJfYeOHiIiIfCXlhr2m3vtXHbfPqlbp+++xcLCOc+sesPL+2fntSj8ehRRt226lm47f7lLSFqhthmw232zfTjvnrnE6rh+wlz74+LBJjx1jd603nJ5+Q13OYS4A2PW6uW11/onT4n06lACBmjWt9IgOn7mWvaPVJzoev6tThY9dvHOnlV43xgxJF0753Mrrmm0+mztG2UOWuXNBCVK81ywl0m7QYitvw7mmPneOqWXleR32enPnyRGv7CnX+ZUXe36IiIjIV9j4ISIiIl9h44eIiIh8Jenn/Kwebz9Vu23W1+Xex/CNZ1vpb6Z103Hz11ZaeXkH1uh43R9PsndkP6zd1ZL7eljpLHzjbUPSMps3M4lss4TB9l83s8rVuMrcVrno+AkRe8mBmzsXXaPj1tMXxnaSKWT1g/ZTu5ef+FSlH2PRUROvLWziXjCKTtlbddwtOyNKSdu3xxw3zQfdb5H3M2nW1EoPqTMvQWcCZH2ySMe3bupj5U3KM+d1Uu4mK2+L40n0rGd3GR3bW+ntZzfWcZPXllh5wa5tdfzz8Wa+TqM3vrfK7Rh4vI4PNxEr781hZi5u2yzvS870W3aVjve/fZyV17iKlxlhzw8RERH5Chs/RERE5CtJOezl7LJ75PzpVl7AY3vtwhUDdJw9xF77N3ejWRU4suM0o0u+jkdd8a6nY0XKLCiKaTu/cT5pffl/26t7zv7Nkzpuk2mGrwKwu1uDv1jX2Zvin2qWXSiNdBhrd2FfdqpZwfftDu9bec5h4s8WHg+v2rx7TMcZn/6nvKcIAFj/0NU6/n6o96G5oZPu1HGL3fOjlPSv5fd4X0H33u+u0HELLItSsuL+taKj/YJj2GtyS/ve9gvrnqPj4j1Veyt0qtn4R/P0gwlDnrfyzsoxSwbMGJVr5XXI/kzH3bPN6zNG2+WurDUnytG9DXVN3tfaSh+dbI7ReFp8lxhhzw8RERH5Chs/RERE5Cts/BAREZGvJOWcnzVDzG15l9X82fN216w1T3POumK/joscy3KXZdc4M4dkaJ2Nnrej8pv5+ZtRcqtHyTMi5wA5/euwefzJr6sftfJWDnpax91+vt3Ka/Fw+s0ZCR6wH9USuNY87uLivCFWXsY2M5ei/caqXQZAqtmPqDmWW1ilx/ObjHp1dTyn3xMRubXgRr6q65pHySGjQ1srfe91M3TsnOMT6cpa21zznHNqo5Urj65T79Bx/qNrrLw6OxO3zAh7foiIiMhX2PghIiIiX0mKYa/I1SiHXzzb03YfHqptpY9cYm61LS7HUFcsCpV9k/zJk0bquNVX9orOsd2M7S/RblmfdsCsTPvEhKutvGazNkUW11TBYR2PeamelTevh1lCobBbgdfTTBtF27abhDMGEM+FGjKa26u6rrrgeZeSFJMMsyJymyz3Ya5FR49Z6ZbPm9vbuY5yclpzo71i98DamxN0JsD7BWYZhUnXXWrltf/WPJWhuCh5loFhzw8RERH5Chs/RERE5CtJMey1/kr7IYgj6q/2tN0902+w0q33ln+FyOK+9sNLH+r4iqft9gUjuokfNHcIcZjLm96jbtXx1r5BKy/vQxPX+tLcIdAkYvVer52oAXG/e6XwSFJ8DIgS5kDQfghwVU8boIqTiPFI51SMauL94aJOWWKGSQvL8YfszyvNndZND9l/G5NpqMuJPT9ERETkK2z8EBERka+w8UNERES+krDJDvsHnqHjr37/eESu+3jlhD0ddNx2xl4rL4jyO9Q020qfU919ZUyqXLWnL3TE7uVivdU20K2Tjud1n2rlOX9X2v+dN/Mmyorbc8suVIqX9ze30i0+OeBS0t+K95i5O23eucXKWzdgoo5rB+zrXkb9+o59JO7p6S/tt+eDqsOHXUr6T+sx9hzXvuvv0vHPfewV7QM7zN+57L1R+jwcC+Y37WPfOj+l42smL8NemX3+SSbvvDa3WXk5y5CU2PNDREREvsLGDxEREflKwoa9tvczDzAsz215T390no7bf+ftoWiBGjWs9Iqnuuj41b7PeD6204TdZ8a0HcXP6tHuD0eddcjc+p69ZquVl5w3Zqan009fGdN2D83rb6XzI1ZVp7CgGdLtMnaDnTfAhCdXs4f/VzyYr+MOd/y70k8ro2EDHU/oNdW13DvbT7TSwSPbXUpSwxcWOOLK3//lN92j43ljx7uWO9Igw0rnuJRLNPb8EBERka+w8UNERES+wsYPERER+Urc5vxkdO1opf90xsyY9tPxz+bRF9FuUJZq5la8lY+cYOWt6hfbPB+nD148y0o3xXyXkhQvO26152Et7DPOkbLn//xh8mAdt9jGuosnyTLzSzIDsS0z0PkJ+9Z2LlZQuWrmVe3SAYfOaKfji2rMcS23YW99K90E6TfnRzLtP8NFZ3VzLZtZYObKqq+XVtk5AYCcav/dbPLbDS4lbYeaipWuV1knVMnY80NERES+wsYPERER+Urchr0KG9i3m19Xe6tLyei2X57vmnewpYkl/6COV/aq+DAXAHx4qLaOG33PlaATIaNxYyv949NmheBPz3jUyqsbMENdd27paeXljTO3Rpfj4cVUCTbfeYqOZ7ackMAzITePnPCWjsejU5SSVavoywZlF0pxm0eeZqW/uetJ17JrCs1CHIP+NsrKq7fW2yIdR3//s5VuUP2QjgNiroaDm71nlbu05i7Xff7tZ7N8TMtX1lp5ybp0CHt+iIiIyFfY+CEiIiJfYeOHiIiIfCV+j7d4wH28sDxe/4O5fTnLvqMOx2W4P86gMjx21291nPPpV1V6rEQ40t+MPefsipjTtHBJXM8lUNvMr9owuZWOJ5w4zSrXJ+eYI+Ve/wsnnmSlGxYucClJVSGztZmQN/zG2Ja5WFfk+J0s4s3t5RXcb9++fvPGXjqelDfPyjuveoGOR/7Fni/XdnTFPzs7hro/nX3ivmY6zptsP/4kHWu9++U/eC7bPsv8yf7qHve5QbEKOPpDggi6lnPOPQKAGY+fq+MG21Lj2sqeHyIiIvIVNn6IiIjIV+I27LV5Tp79Qox3T7bMrNqhrVWFZhjlusfsWwmbLTRdsOnY/fr5xIk6LlZ2l2f794bruMOrx6w8mbdYx86nNQPA3nPdlyY4cO1+HefWsbvkP+j0btknHDq6jt4ssFeDfXTcIB03nJQaXbHpau9z5lJzS931Me3joul367jtStZneQUPHbLSW65soeNlc+1hqK7Z5jr70bXjrLz+e+/VcYuHva2OvuVue/X1pT2fcqTs7+DPruqj4ya7Vnjafyr7eWgjKz1mmpl+8FDT5JxeMXDxjVb6uMmp93lkzw8RERH5Chs/RERE5Cts/BAREZGvxG3Oz5FmyTlL5r2IeSLP3XyFjpt+bo9nJ+c7qDxLjplbiTtnZVl5qy55Vse7LrLnB6wuqqXjmmLPB+qW/bHr8QKO+TrBiIdMuN1kednq/lZ635PmFuraX6y28hruTr1xaDKGbzzbSuc/u1nHybpkfiop2rhJxyNuGWHlTZhoHjvSNbuWlbfkdjNfZ+4w8yfktsUDrXJN6phHDC3u8pSVlyHme7fzugMAjcbllHnu6aR49TorvfjOHjqeMtF+iv0Ndbw9Wb0yjNpylpWetai7jjvfZ19rU/FvI3t+iIiIyFfY+CEiIiJfiduwV/4LB630vPPNsEqvnMIqPfbf97W00n+ddYmOO0zZY+UFvv+2Ss8lmd35u9t1vHmIXScTT3tFx12y7W7qXtXMIFUQGZ6PN/eI+6/f0C+G6rjx59k6rv+SPZRVA1t1nIpdr+lK9exupQe1nFPuffw0uoOVDqz372ezqmV99I2VHjHMXAuefmGCldc5u4aO+1Y3n/1lPV+LcgT7e/bQn3rreMvINnbJhYvLOt30ErSvXIG55vf8nT5drbxJF1+q4yZzNlt5BV2a6njDFfY0Aq8azTN/lxvPWGbl5e83t92nw7WWPT9ERETkK2z8EBERka+w8UNERES+Erc5P+pbe/zwgRE36XjL4KNW3rLek2M6xjlLr9Lxvk9ydZw3a6dVrt1yM2/E/bm1/pM924z7t5lt5z2MbjouPsd+QvrhxtmIRa3XF7rm5WNRTPuk5LD+0hpWOtZHWlBiZH1iPn+jTrnYylvzVHMdjzjhcxPXt2/DHrC6n443T2lr5TV4yfHYhuCSCp1rOiveaf/tavCiSUcu91Bt/U86zv+gEo5d8V0kNfb8EBERka+w8UNERES+Erdhr0jVPvhax20iuuj64+SY9lkTP5Yap3v3XbxlfPofK13LpRz515fX/TXiFW+r9g780QyVZK/dYeVxVefEKN6120q3udak30f9UuMQszpxA2wHUTJhzw8RERH5Chs/RERE5Cts/BAREZGvJGzODxGll9VTzBIIdQNfRylp7Anaj0pZ9U6+jo/bNL9yToyIKAJ7foiIiMhX2PghIiIiX+GwFxFVilZTzXepM/59p5V37jCzqvqfm5qVxHtNvdsq1/YxDnURUdVjzw8RERH5Chs/RERE5Cts/BAREZGvcM4PEVWK7NlmLk+TiLwlz5jY+fiatlgAIqJ4Y88PERER+QobP0REROQropRK9DkQERERxQ17foiIiMhX2PghIiIiX2Hjh4iIiHyFjR8iIiLyFTZ+iIiIyFfY+CEiIiJf+X8m3SJfmhtA0wAAAABJRU5ErkJggg==",
|
333 |
+
"text/plain": [
|
334 |
+
"<Figure size 720x720 with 5 Axes>"
|
335 |
+
]
|
336 |
+
},
|
337 |
+
"metadata": {
|
338 |
+
"needs_background": "light"
|
339 |
+
},
|
340 |
+
"output_type": "display_data"
|
341 |
+
}
|
342 |
+
],
|
343 |
+
"source": [
|
344 |
+
"xbv,ybv = next(iter(dls.train))\n",
|
345 |
+
"logits = model(xbv)\n",
|
346 |
+
"probs = F.softmax(logits, dim=1)\n",
|
347 |
+
"idx = 5\n",
|
348 |
+
"_,axs = plt.subplots(1, idx, figsize=(10, 10))\n",
|
349 |
+
"for actual, pred, im, ax in zip(ybv[:idx], probs[:idx],xbv.permute(0,2,3,1)[:idx], axs.flat):\n",
|
350 |
+
" ax.imshow(im)\n",
|
351 |
+
" ax.set_axis_off()\n",
|
352 |
+
" ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')\n",
|
353 |
+
" "
|
354 |
+
]
|
355 |
+
},
|
356 |
+
{
|
357 |
+
"cell_type": "code",
|
358 |
+
"execution_count": 158,
|
359 |
"metadata": {
|
360 |
"tags": [
|
361 |
"exclude"
|
|
|
368 |
},
|
369 |
{
|
370 |
"cell_type": "code",
|
371 |
+
"execution_count": 159,
|
372 |
"metadata": {},
|
373 |
"outputs": [],
|
374 |
"source": [
|
375 |
"loaded_model = cnn_classifier()\n",
|
376 |
+
"loaded_model.load_state_dict(torch.load('classifier.pth'));\n",
|
377 |
"loaded_model.eval();"
|
378 |
]
|
379 |
},
|
380 |
{
|
381 |
"cell_type": "code",
|
382 |
+
"execution_count": 160,
|
383 |
+
"metadata": {
|
384 |
+
"tags": [
|
385 |
+
"exclude"
|
386 |
+
]
|
387 |
+
},
|
388 |
+
"outputs": [
|
389 |
+
{
|
390 |
+
"data": {
|
391 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAAB+CAYAAADLN3DXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbn0lEQVR4nO3deZgU1dUG8PfMwLAj+y4gwiDiguICLoALEAEFiUYBURTihgtBjRENCp/RqEFAQZSI+xqIiEgMGo3ivqAIDmSGfd8d9m1m+n5/VFO3TjPdFE1PT0/X+3seHs+dU11VPXe6vF13KTHGgIiIiCgoMkr7BIiIiIiSiY0fIiIiChQ2foiIiChQ2PghIiKiQGHjh4iIiAKFjR8iIiIKlFJv/IjIChG5uLTPI1lE5CERea20z6OksD7TB+syvbA+0wfr8uiVeuPnSIjjMRHZGv73uIhICR9zkIh8UZLHiDheOxGZKyJ7wv9tl6xjJ1v4D7pARHZ5/rUo4WMmuz6NiOz2vL/nk3XsZBKRYSKyTER2iMg6ERkrIuVK+JhJq0sRqSMiX4avO9tE5GsROTcZxy4NAbnWZorIw+G/150i8pOI1EjW8ZOF19niJbTxU9IXOwA3AugD4FQApwDoBeCmEj5m0ohIFoAZAF4DUBPAywBmhH9eGudT0vUJAG8bY6p6/i1LwjGT7VTP+xtSGieQhLqcCeB0Y0x1ACfB+YzeUcLHTKZdAG4AUBfOZ/MxADOT9Bk5BK+1CTEKwDkAOgKoDmAggH3JPgleZxPmiK6zh238hG+v3SciC0UkX0ReFJGK4VwXEVkjIveKyAYAL4pIhoj8SUSWhr8x/ENEann2N1BEVoZz9x/hm7sOwBhjzBpjzFoAYwAM8vNCETleRD4JH3eLiLzubeWLyLEi8o6IbA5vM0FE2gB4FkDHcGtyW3jbT0VkiOe1qpUrIuNFZHX4W/BcETnf5/vrAqAcgHHGmP3GmKcACIALfb7+sFKsPo/mfZSF+ixRqVSXxpilxphtB3cFIASgpc/3kfJ1aYzZZ4zJNcaEwu+vCE4jqFbsV/qXSvWJNL/WikhNAMMA/N4Ys9I4fjHGJKTxk2J1eTTvI+XrMl5+7/wMANAdwPEAsgE84Mk1gHMBaAbn28IdcL4xdAbQCEA+gIkAICInApgEp4XdCEBtAE0O7khEzjv4i4qiLYCfPeWfwz/zQwA8Gj5uGwDHAngofNxMAO8DWAmgOYDGAN4yxiwCcDOAr8OtyRo+j/U9gHZwfi9vAJh68A//kJMSmS8i/cPFtgDmG/3Mkfnw/x79SpX6BIBLReRXEckRkVuO4D2Uhfo8aI6IbAhfJJr7PKZfKVOXItJfRHYA2ALnjsFzPt9DmalLEZkP5+7AewCeN8Zs8nlcv1KlPtP9WnsygEIAV4Q/m3kiMtTnMf1KlboEeJ09lDEm5j8AKwDc7Cn3ALA0HHcBcABARU9+EYCLPOWGAArg3NEYCeeXczBXJfz6iw93HuHtiwCc4Cm3AmAAiJ/XR+yrD4CfwnFHAJsBlCtmu0EAvoj42acAhsTaJmL7fDi35ADnD+e1KNv92fv7Cf/sdQAPHen7KyP1eSKcD1UmnNvP6wH0i/N9pVx9hvOdAGQBqAFgAoBfijuvsl6XEefVCsD/AWiQTnXpeU1FAP0AXJeIekzF+kT6X2v7h9/PFACV4HTtbQbQNQ3rktfZYv757Wtc7YlXhn+RB202+lZhMwDTRSTk+VkRgPrh17n7MsbsFpGtPs8BcPrdq3vK1QHsMuF3H4uI1APwFIDzAVSDc9crP5w+FsBKY0zhEZxLrGPdBWAInPdrwudZx8dLI98fwuWdiTgvj5SoT2PMQk/xKxEZD+AKAG8e7rVlpD5hjJkTDg+IyJ0AdsD5BrUgEeeGFKlLL2PMYhHJAfAMgL6H276s1OVB4d/pmyKySETmGWN+PuyL/EuV+kz3a+3e8H9HG2P2ApgvIm/BaaR8lIhzQ4rUJa+zxfPb7XWsJ24KYJ33uBHbrgZwiTGmhudfReP0G6/37ktEKsO5hedXDpzb6QedGv6ZH4+Gz/UU4wzKvAbOLb2D59xUih94VtyHfTeAyp5yg4NBuJ/yXgC/A1DTOLf8tnuOFUsOgFNE1KyKU+D/PfqVKvUZycDf7wkoG/VZnCN5j36kal2Wg3O734+yWpflASR61kyq1Ge6X2vnxzhmoqRKXUbidRb+Gz9DRaSJOAOwRgB4O8a2zwL4i4g0AwARqSsivcO5aQB6hfsoswCMPoJzAIBXAAwXkcYi0gjAXQBeOpgMD6h6KMprq8H5NrNNRBoDuMeT+w7OH9hfRaSKiFQUO411I4AmomdczQPQV0Qqi0hLAIMjjlOI8O1AERmJQ+/mRPMpnNb+HSJSQURuC//8E5+v9ysl6lNEeotITXGcBaffe4YnX6brU0TairN0QaaIVIUzaHQtnFvciZIqdTkk/C3x4BiF+wB87MmX9brscPB3IyKVROReON/Kv/Xz+iOQEvWJNL/WGmOWAvgcwP3ha20bAFfBGcOSKClRl7zOFs/vL/ANAB8CWBb+93CMbcfDGQz4oYjsBPANgLMBwBiTA2BoeH/r4dw+W+N5E+eLyK4Y+34OzpTaBXD69GZBD6o8FsCXUV47CsDpcFqTswC8czBhjCkCcCmc2Smrwud0VTj9CZxvPBtEZEv4Z2Ph9LluhDMd/XXPcWYD+ABAHpxbnfugb38q4gxAGxA+jwNw+lSvBbANztTaPuGfJ1Kq1OfVAJbA6dZ7BcBjxpiXPfkyXZ9w/uf4NpxbsMvgDArsZYwpiPb6OKRKXZ4LYIGI7Abwr/C/EZ58Wa/LCnAGoG6Fc2HtAaCnMWZdtNfHKVXqM62vtWH94HQ3bQ2f55+NMR8X++L4pEpd8jpb3D4O14UrIivgDFL6T8wNS5mINAEw1RjTsbTPJZWxPtMH6zK9sD7TB+sy9ZXKAl0lwRizBs7oc0oDrM/0wbpML6zP9BHkuixTj7cgIiIiOlqH7fYiIiIiSie880NERESBwsYPERERBYqvAc9dM65k31gp+yg0NSEL47EuS1+i6hJgfaYCfjbTBz+b6SVWffLODxEREQUKGz9EREQUKGz8EBERUaCw8UNERESBwsYPERERBQobP0RERBQobPwQERFRoLDxQ0RERIGSNk91JyIiTdq3dePcoZVUbsaFE9x4+PW3qlzmf38s2RMjKmW880NERESBwsYPERERBQobP0RERBQoHPNDRJSmNrev7sb/6/5URDbTjdZ2rqgyTf9bkmdFJc2c206Vl/e2470WXzNJ5cblN3fjlyf1cOP6z36n91lYmLgTTAG880NERESBwsYPERERBQq7vSjlZVa3t+6LduzQubato76uoFZlN85a86vKFS5fmaCzK5syKtvfTe6jp6jceWctdOPPc7JV7l8X266TnnNuU7nFFz3vxpmiv1cVmZCv83p1ZwM3/sflnfU+cpfZQqjI1/6CrsEs+3c+7vYTVW5YLVvPD/Z7U+VGmX5u3HTUVyV0dlRSMncfUOVhvT524wKjPztDayy18X1Pu/HlV/RS2+U/08yNa8xZrnJF+dvc2Ozfr3JSzjYzMuvUPtypuwo3brIFY3y/zi/e+SEiIqJAYeOHiIiIAoWNHyIiIgqUQI752XJjx4Tvs87krxO+z3STWbeuKq8a3MqNT+iZp3InV1/nxq0rLnDj3H0N1XYP1LFjFUKI3i/cNee3qlyhm48TTmPHf2b7/d9rNDH6hofMea7gRrkX/V1lvKN6Qia+MTkDqq238X/eUrlLBvzejTM/5eMX/Chcaz9Hq/fVirrd5VU3qfKUT/aW2DlRyQvNW6jKsy4/243H3Kcvfp9eON6NG2baKfHTW72vdzo2+vGyZ9/kxpWWZqncvnr2ypB7RYxrTYTe7X7jxkWbN/t+nV+880NERESBwsYPERERBUogu72+edA+zThy2l95yYwrd3qTO9242cjgdoFFTj1ffL+9jfpShxdUrkOF2W4cq8tKqbol4gfi62WNqmxX5a3+jpa21uypUdqncMSueMb+vczsfprKFa5ek+zTSWsZn/9U2qdACVSUu8SN29zfSOXWfW6v0Q0zEZe87s/F90KPnw5ELIdRwstZ8M4PERERBQobP0RERBQobPwQERFRoARizM/K0XpqewbsNFnvOB4nJ3HlcgbbKXwt69+kctk3fX+EZ1y2rLnvHDf+zy2Pq1wdz9TJSLuMXYK95y8DVG7rV/YxB1n6iRZKRoGNay/cp3JFWbZtX+m7pSBr21+buvGqZ/W05qUFNePa581fXOvGlXIrRN1uTwtbaXk9nvW9/8HHrHLjd+t00UmO+TmsDNFjKjL43TcwyrVo7sZN3tqocu2jfFRbT7816v5yL38mrvO4bsXFqvzt93aMaNMP9FPjs7b+ENcx/OJfPxEREQUKGz9EREQUKGWu28vbheV3SnnkdqHBdlp15JT1TqPslPVDZl97ZlWfdaOeCjq+0ZdunNNTr2J5+ug0mwYvenr5eZfb30W9zMoqt9fTtXXGC39QueNfsqv5Vlm2TOWqQJfj4e2Y5DPAtQof2K7Ym667Q+XiXT25FeZGzRVdcLrd/8Wbom4Xy/oi2z2XcUDfImf9Hl7I6O+6IYSibFm6Qp3tMgZLb9DXmnHn2FW/H5g0SOUaPsmnz0ezv5ld3XtC439G3W76brtd67t/VrnQAdtdfenLg+I6j8x1epGRlmu/iWs/icA7P0RERBQobPwQERFRoLDxQ0RERIGSkmN+9l9ypht/OkU/ObrA2HEF5YdkRuRsz3+vxu2j7j9Wrjb8jclZqk8Ld/3QwY3HNNT9mN162Cl7S147XuWK8sreFOxCz/gNAJjQeLIbR44iOHfMcDduNlb3yReCUkG8Y3wyKlZUZdPW/m0vvaq6yj3d1z7a5KJKe+I6Xp95g924bk5uXPug1JBZ244tye+erXIPjn7RjSP/VrzjlLJunaJyY77v78Z8PIe2tF/051a8sqOxG789uLsby76fi9scAGC+XxDXeaTSNZ93foiIiChQ2PghIiKiQEmJbq/8Wa1UeUT2m24cORXdW75z3bkq991zdoqk3+6rRPly8hluXDDyS5XzdoPd+bq+/bj0TJQ5686Pvnrv5O3NVbnRJNulciQTa8sd18yNDzStFWPL6LJW/erGhctXxrUPitDhFDc0j+hpqzNbv1Kihx7ZZpYb3/P4QJVr9Xe7am3R4qNfJiFoFkU+UbuE5Y63n++FFzwdY8vo38+n/6q734Pe1ZVRrZob/+8p3ZU4r+tTnlKWys3e2taN5avoXV3phnd+iIiIKFDY+CEiIqJAYeOHiIiIAiVpY34yW7dU5Y2d67rxd+304yBCnudKeJ+kDgB3r+/kxkvP1E/xTvY4H686k+2xu/TWTyj/ut3bbux9DAYA9EL0afep6rwe0fuFZ244RZXlhGPceNM5NVTu/BvsIxbOq56ncidk2XFSbcqXj3o8799HKOJ5JIsK7HLsr+d3ULnp/7aPSWnxzi6Vi3caZxAsH2bjnNbvJfXYPStvt/GACSo3vY8dF/ZyJ13XhRv0U6zJ4R3nc/cNt6hcJuJb/iAa7yMrAGBihzcSun8C0MxOWc/rNjkimYVolrxmxwfVLcX/hyYb7/wQERFRoLDxQ0RERIGStG6vFq+tVuVpDe1tzxCir9Ts7eYCgBUDGntKqbk6sjG6q87bHRM5db8surpO9CfxTsvWTwzeM9O+35oZekXgWF1WeQW2fMGCK6MeL0PsdjUq7lW5f7a0U6MfrqefOP7wtba865r9Knf32m5uvPb241SOXWJHb+I2uxL0i4s7qtyBH2u68Z1Xz1C5wcesirrPy6vYZQ3uHdlM5dpMrOHGRQFeGTpDQhFlE2XLxBs25U1VzvAsfHHSZ79XuYWd7crN5SXy/w02zkzi+acK76rqi6e0UblvOnm7gvW1NpbCKvY6vHWI/TzWeel7tZ0pTKX1mY8e7/wQERFRoLDxQ0RERIGStG6vCY2/VeUCY29nRs7o8t7qjJzRlapdXXl/t0s15532rMp531/kbdyyaMicQaq8pJt9ymsl0bMKdobsgwkHruiqcgveP8GNm7+qV2AuXLPWjavA34q9+yPK3pl0O6/SM4A2enpbvuj7N5WbfOwcW3h3jsq1eXWoGx93X0T3nwnebfhons7Xq7a/vORsN25yS74bN1i/KOo+Zr6sV/B9aZKttGknvaRy9TMruXFe70kql13+ZhvrHpZACRn9Xbd1eXstevLFZ1Tunub683K0Hv7zIFWuvMnOxGy1ZrvKhT6xXWIFER8p74NN5/2tncpVQ/Tu+HSRUcPOnl3U5fmIrP+uLq+5w4tfYfuiK69QZZlgZ2hXnPldXMdKJbzzQ0RERIHCxg8REREFChs/REREFChJG/MT6+nskeNgsmfa1UazUTb6FvN62HE+ke/V+/687w0oO+/Pq9k03Wae18VOgXxs7SUqt+EJO6250gz9XpvgKzcu6UmU1d7+JqJs48GTBqlcnZc2ufGUpv9VuZyBdjpph2W36ddNTv/VUVs8Yf+2L379lqjbVf1prSo3WGPH9vit68KVenmMY3rYuMsT96jcov56xWevCRe86sa3T7xO5Vrd/oMthMr+MhTx8o7/AYBVD57jxk1HfRW5+RGr/mb08TiLnj3L935+2m+vPVVX7Y2xJR2tj0+apsq7nrGjKvtgmMqVxTFAvPNDREREgcLGDxEREQVK0rq9Yk3xPnfe1aqcfXNq3ELbf8mZqryyr4293VxA7OnsJ0+x3SPZI8t+10iFWXrlz5Hz7S8mtE1PW63a1Jbz3mincsf3n5fwc4tHUe4SVd76uyZu3HXKb1Xuo7Z2BetXRjypcsMn69WK05GZm+PGFedG366kuzFbP6W7xFpXvNWNc/vqadvdKu22uT46d+lw271j9qd3t9en/4x4iPLtX0Tddl/jgqi5RCjodoYbj70o+kNOe/yvjyoXPtHAjbO+/h5BU5S/zY1bT79V5W7o9Jkbz36os+99bm1j/3817nq7bEnnSnvUdlUzKrjx/WNfVLnRQy914+rX6P8HFG3Z6vtckol3foiIiChQ2PghIiKiQGHjh4iIiAIlJaa6Rz4FvTTlz7LL8o/I1k8i7lnZ9mUe0XT2NBjnE0vh6jVRc3uya7jxos56vEWfJpfZfazRU6NLk/f9rM6JWOa/rQ1n7GiXnBOiQ0T+zTWc09iN8/voR+LUzIi+7P/qu+w4mCaPHP2U7lTW7B/rVPmRfu3ceESdeSV67FDn01T5j8/Y5QcuqLQr6uuWLG2gytn/Dt44Hy+z3043b3WbfmTU557HW1SGzsVS2ROP+Yu9wN088Wy13Z8unOnG11fXY+4uOHmqzb13kcrNn2bH1TV8MnU+Y7zzQ0RERIHCxg8REREFSkpMdf/mtLdUueVzN7lxrbmJP8Vfz7QTcQ+dsv6jG4dgInK2e+7u9Z1Uzvv0+bK4anNpWDGwmRs3eTR1ur28pP6+qLkXPrxAlY8PwFOlU1VhJfs9bnOR/k5XM8ZXvClD7BOtH3ykffQN00DhshWqPHWqnQ79wC3zVW55Lzvl+bShekp1vYn+ui4KL7S/z/W37le5i9Q06ugVlH1jsLu54rXuj+eo8jEXbrDxYH1NK1yru0MPajVUd51NP8n+P2/sFbVU7l/XP+7GLzb7WOW2DHvfjbtl/FHlGv2t9LrBeOeHiIiIAoWNHyIiIgoUNn6IiIgoUJI25ueMB/X077Nu/MmNxzf6UuVyek504/K99FihWE+Dj5aLNS39SKas155rc/U+2wRtKehQlafbfuMnR5+gcjm326nvfX7TXeW2P9rUjbM+tOOwEvXkbalgl2ovOrONyi25tryNOz8X+Uo3qrac3x2iKde4kSr/2tnWZ9VVdszB7iZ6Gvru+v5+pztO1I9fmNfTPmqksmT5Ps/2FQ6/Tbo67sUVbjy5f3OVu/EYm5tx7+Mq17XhPW5cdVX0/e/tttONf+qoH4cQ8sTrCvV4oL6P2XEh9ZA6U6PLkr119XjVHz1T0fu9pa+1ee96pqKPif77Dv3yPzdu+ovO9dpv62zebU+rXJ3MSm58at+FKpf/Wn03LtywMeqxSwKv3kRERBQobPwQERFRoIgx5rAbdc248vAbHYXWP5RX5TEN7ZRh7/RyQE8/95uLnLLufYp85OrStXrl+T3tpPooNDUhy2CXdF3GsqevXjH0zXFj3Li+59ZopLvW21WWfz1QReUyxN5ADxn/bflTPSuUDq+5WOUi/1682n1zrRsfe7V+nSk44OvYiapLoHTrM5bV005S5Tln2+7DjZ6p6HUzQ2q7WKsxl4Ts2XZZjewbfohrH+nw2cw8MVuVB7072417V9kS1z4zPN+tQ9D1vKzAdltePf5ulWswtvS6utLls5lZv57+wVt2hMuM7JkqtaVorxt3esN2aTaeU6i2q/KzXY4kcnq8lLddzXlP6tW8c/vqlf29enfr78ZFOblRt4tXrPrknR8iIiIKFDZ+iIiIKFDY+CEiIqJASdpU91iW9G+qyh07nWELkT123l5Uv7mIntfaz6f3U9ZTVeV39HLp/UPD3bjeXctUblyzd93YOwYsUqyxXf7pP6RXdtgnhD/yQR+Vyx4xzx7P5xifIKpTbbcqH+MZy3NMKX7lGrW5nSq3uXelGydmEYWyqWihHuv4Qr9ebvz46dVVbsgf3nPj6z1T4o/ElZPsOJ/GpTjGJ10VbdRLsZS7tokbD3izm8q9ftyHbrxw4ASbGKj32TvvUjdevPr0qMfuesL8qLlUwjs/REREFChs/BAREVGgpES3V1GeXh25dh5XSw6CSu9+58Y739W5G1vae65bOzZw41oLtqvtVvSucdTn0eBbvVpwxc/tKqQtd+suNz1hl8oCb1fXj10bqFzR5s1JPpuywczNcePac3VucrnL3Hj9zXPceESdeVH3d+Lbt6vyCVPtVOnCyI0p4QpXr3HjnZdUU7kzXr3GjUe1tdPge1bW11o1RV6vjODb5O3NVVl27y1+wyTgnR8iIiIKFDZ+iIiIKFDY+CEiIqJASYkxP0SRipYsd+ManjhyzE3TeYk/Nsf1HL2NP+ixNTip+O0S5bGtbd3468v0gITQJvt4htAejvE5WnUn2aVCvplkH010Gc6M+pqW0GPnOM6n9IR27lTlBn0WufH47vbRT6OOz1LbXTzE1vsj9eN7FMy4mb1UucWK0lt2hnd+iIiIKFDY+CEiIqJAYbcXESVc8wd0N8elo88pdru1d7RX5T2No3c6tp6w0Y1Dq/VTpRGyq3ubglV+T5OIPLJm2+6siOfCY77n4ey90B7xaIHUeboC7/wQERFRoLDxQ0RERIHCxg8REREFCsf8EFHiGaOL+/cXu1mjJ/w/0TvIT10nosTinR8iIiIKFDZ+iIiIKFDY+CEiIqJAYeOHiIiIAoWNHyIiIgoUNn6IiIgoUNj4ISIiokBh44eIiIgChY0fIiIiChQxESuxEhEREaUz3vkhIiKiQGHjh4iIiAKFjR8iIiIKFDZ+iIiIKFDY+CEiIqJAYeOHiIiIAuX/AW21zFHP4lndAAAAAElFTkSuQmCC",
|
392 |
+
"text/plain": [
|
393 |
+
"<Figure size 720x720 with 5 Axes>"
|
394 |
+
]
|
395 |
+
},
|
396 |
+
"metadata": {
|
397 |
+
"needs_background": "light"
|
398 |
+
},
|
399 |
+
"output_type": "display_data"
|
400 |
+
}
|
401 |
+
],
|
402 |
+
"source": [
|
403 |
+
"with torch.no_grad():\n",
|
404 |
+
" xbv,ybv = next(iter(dls.train))\n",
|
405 |
+
" logits = loaded_model(xbv)\n",
|
406 |
+
" probs = F.softmax(logits, dim=1)\n",
|
407 |
+
" idx = 5\n",
|
408 |
+
" _,axs = plt.subplots(1, idx, figsize=(10, 10))\n",
|
409 |
+
" for actual, pred, im, ax in zip(ybv[:idx], probs[:idx],xbv.permute(0,2,3,1)[:idx], axs.flat):\n",
|
410 |
+
" ax.imshow(im)\n",
|
411 |
+
" ax.set_axis_off()\n",
|
412 |
+
" ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')\n",
|
413 |
+
" "
|
414 |
+
]
|
415 |
+
},
|
416 |
+
{
|
417 |
+
"cell_type": "code",
|
418 |
+
"execution_count": 161,
|
419 |
+
"metadata": {
|
420 |
+
"tags": [
|
421 |
+
"exclude"
|
422 |
+
]
|
423 |
+
},
|
424 |
+
"outputs": [
|
425 |
+
{
|
426 |
+
"data": {
|
427 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAj8AAAB+CAYAAADLN3DXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbn0lEQVR4nO3deZgU1dUG8PfMwLAj+y4gwiDiguICLoALEAEFiUYBURTihgtBjRENCp/RqEFAQZSI+xqIiEgMGo3ivqAIDmSGfd8d9m1m+n5/VFO3TjPdFE1PT0/X+3seHs+dU11VPXe6vF13KTHGgIiIiCgoMkr7BIiIiIiSiY0fIiIiChQ2foiIiChQ2PghIiKiQGHjh4iIiAKFjR8iIiIKlFJv/IjIChG5uLTPI1lE5CERea20z6OksD7TB+syvbA+0wfr8uiVeuPnSIjjMRHZGv73uIhICR9zkIh8UZLHiDheOxGZKyJ7wv9tl6xjJ1v4D7pARHZ5/rUo4WMmuz6NiOz2vL/nk3XsZBKRYSKyTER2iMg6ERkrIuVK+JhJq0sRqSMiX4avO9tE5GsROTcZxy4NAbnWZorIw+G/150i8pOI1EjW8ZOF19niJbTxU9IXOwA3AugD4FQApwDoBeCmEj5m0ohIFoAZAF4DUBPAywBmhH9eGudT0vUJAG8bY6p6/i1LwjGT7VTP+xtSGieQhLqcCeB0Y0x1ACfB+YzeUcLHTKZdAG4AUBfOZ/MxADOT9Bk5BK+1CTEKwDkAOgKoDmAggH3JPgleZxPmiK6zh238hG+v3SciC0UkX0ReFJGK4VwXEVkjIveKyAYAL4pIhoj8SUSWhr8x/ENEann2N1BEVoZz9x/hm7sOwBhjzBpjzFoAYwAM8vNCETleRD4JH3eLiLzubeWLyLEi8o6IbA5vM0FE2gB4FkDHcGtyW3jbT0VkiOe1qpUrIuNFZHX4W/BcETnf5/vrAqAcgHHGmP3GmKcACIALfb7+sFKsPo/mfZSF+ixRqVSXxpilxphtB3cFIASgpc/3kfJ1aYzZZ4zJNcaEwu+vCE4jqFbsV/qXSvWJNL/WikhNAMMA/N4Ys9I4fjHGJKTxk2J1eTTvI+XrMl5+7/wMANAdwPEAsgE84Mk1gHMBaAbn28IdcL4xdAbQCEA+gIkAICInApgEp4XdCEBtAE0O7khEzjv4i4qiLYCfPeWfwz/zQwA8Gj5uGwDHAngofNxMAO8DWAmgOYDGAN4yxiwCcDOAr8OtyRo+j/U9gHZwfi9vAJh68A//kJMSmS8i/cPFtgDmG/3Mkfnw/x79SpX6BIBLReRXEckRkVuO4D2Uhfo8aI6IbAhfJJr7PKZfKVOXItJfRHYA2ALnjsFzPt9DmalLEZkP5+7AewCeN8Zs8nlcv1KlPtP9WnsygEIAV4Q/m3kiMtTnMf1KlboEeJ09lDEm5j8AKwDc7Cn3ALA0HHcBcABARU9+EYCLPOWGAArg3NEYCeeXczBXJfz6iw93HuHtiwCc4Cm3AmAAiJ/XR+yrD4CfwnFHAJsBlCtmu0EAvoj42acAhsTaJmL7fDi35ADnD+e1KNv92fv7Cf/sdQAPHen7KyP1eSKcD1UmnNvP6wH0i/N9pVx9hvOdAGQBqAFgAoBfijuvsl6XEefVCsD/AWiQTnXpeU1FAP0AXJeIekzF+kT6X2v7h9/PFACV4HTtbQbQNQ3rktfZYv757Wtc7YlXhn+RB202+lZhMwDTRSTk+VkRgPrh17n7MsbsFpGtPs8BcPrdq3vK1QHsMuF3H4uI1APwFIDzAVSDc9crP5w+FsBKY0zhEZxLrGPdBWAInPdrwudZx8dLI98fwuWdiTgvj5SoT2PMQk/xKxEZD+AKAG8e7rVlpD5hjJkTDg+IyJ0AdsD5BrUgEeeGFKlLL2PMYhHJAfAMgL6H276s1OVB4d/pmyKySETmGWN+PuyL/EuV+kz3a+3e8H9HG2P2ApgvIm/BaaR8lIhzQ4rUJa+zxfPb7XWsJ24KYJ33uBHbrgZwiTGmhudfReP0G6/37ktEKsO5hedXDpzb6QedGv6ZH4+Gz/UU4wzKvAbOLb2D59xUih94VtyHfTeAyp5yg4NBuJ/yXgC/A1DTOLf8tnuOFUsOgFNE1KyKU+D/PfqVKvUZycDf7wkoG/VZnCN5j36kal2Wg3O734+yWpflASR61kyq1Ge6X2vnxzhmoqRKXUbidRb+Gz9DRaSJOAOwRgB4O8a2zwL4i4g0AwARqSsivcO5aQB6hfsoswCMPoJzAIBXAAwXkcYi0gjAXQBeOpgMD6h6KMprq8H5NrNNRBoDuMeT+w7OH9hfRaSKiFQUO411I4AmomdczQPQV0Qqi0hLAIMjjlOI8O1AERmJQ+/mRPMpnNb+HSJSQURuC//8E5+v9ysl6lNEeotITXGcBaffe4YnX6brU0TairN0QaaIVIUzaHQtnFvciZIqdTkk/C3x4BiF+wB87MmX9brscPB3IyKVROReON/Kv/Xz+iOQEvWJNL/WGmOWAvgcwP3ha20bAFfBGcOSKClRl7zOFs/vL/ANAB8CWBb+93CMbcfDGQz4oYjsBPANgLMBwBiTA2BoeH/r4dw+W+N5E+eLyK4Y+34OzpTaBXD69GZBD6o8FsCXUV47CsDpcFqTswC8czBhjCkCcCmc2Smrwud0VTj9CZxvPBtEZEv4Z2Ph9LluhDMd/XXPcWYD+ABAHpxbnfugb38q4gxAGxA+jwNw+lSvBbANztTaPuGfJ1Kq1OfVAJbA6dZ7BcBjxpiXPfkyXZ9w/uf4NpxbsMvgDArsZYwpiPb6OKRKXZ4LYIGI7Abwr/C/EZ58Wa/LCnAGoG6Fc2HtAaCnMWZdtNfHKVXqM62vtWH94HQ3bQ2f55+NMR8X++L4pEpd8jpb3D4O14UrIivgDFL6T8wNS5mINAEw1RjTsbTPJZWxPtMH6zK9sD7TB+sy9ZXKAl0lwRizBs7oc0oDrM/0wbpML6zP9BHkuixTj7cgIiIiOlqH7fYiIiIiSie880NERESBwsYPERERBYqvAc9dM65k31gp+yg0NSEL47EuS1+i6hJgfaYCfjbTBz+b6SVWffLODxEREQUKGz9EREQUKGz8EBERUaCw8UNERESBwsYPERERBQobP0RERBQobPwQERFRoLDxQ0RERIGSNk91JyIiTdq3dePcoZVUbsaFE9x4+PW3qlzmf38s2RMjKmW880NERESBwsYPERERBQobP0RERBQoHPNDRJSmNrev7sb/6/5URDbTjdZ2rqgyTf9bkmdFJc2c206Vl/e2470WXzNJ5cblN3fjlyf1cOP6z36n91lYmLgTTAG880NERESBwsYPERERBQq7vSjlZVa3t+6LduzQubato76uoFZlN85a86vKFS5fmaCzK5syKtvfTe6jp6jceWctdOPPc7JV7l8X266TnnNuU7nFFz3vxpmiv1cVmZCv83p1ZwM3/sflnfU+cpfZQqjI1/6CrsEs+3c+7vYTVW5YLVvPD/Z7U+VGmX5u3HTUVyV0dlRSMncfUOVhvT524wKjPztDayy18X1Pu/HlV/RS2+U/08yNa8xZrnJF+dvc2Ozfr3JSzjYzMuvUPtypuwo3brIFY3y/zi/e+SEiIqJAYeOHiIiIAoWNHyIiIgqUQI752XJjx4Tvs87krxO+z3STWbeuKq8a3MqNT+iZp3InV1/nxq0rLnDj3H0N1XYP1LFjFUKI3i/cNee3qlyhm48TTmPHf2b7/d9rNDH6hofMea7gRrkX/V1lvKN6Qia+MTkDqq238X/eUrlLBvzejTM/5eMX/Chcaz9Hq/fVirrd5VU3qfKUT/aW2DlRyQvNW6jKsy4/243H3Kcvfp9eON6NG2baKfHTW72vdzo2+vGyZ9/kxpWWZqncvnr2ypB7RYxrTYTe7X7jxkWbN/t+nV+880NERESBwsYPERERBUogu72+edA+zThy2l95yYwrd3qTO9242cjgdoFFTj1ffL+9jfpShxdUrkOF2W4cq8tKqbol4gfi62WNqmxX5a3+jpa21uypUdqncMSueMb+vczsfprKFa5ek+zTSWsZn/9U2qdACVSUu8SN29zfSOXWfW6v0Q0zEZe87s/F90KPnw5ELIdRwstZ8M4PERERBQobP0RERBQobPwQERFRoARizM/K0XpqewbsNFnvOB4nJ3HlcgbbKXwt69+kctk3fX+EZ1y2rLnvHDf+zy2Pq1wdz9TJSLuMXYK95y8DVG7rV/YxB1n6iRZKRoGNay/cp3JFWbZtX+m7pSBr21+buvGqZ/W05qUFNePa581fXOvGlXIrRN1uTwtbaXk9nvW9/8HHrHLjd+t00UmO+TmsDNFjKjL43TcwyrVo7sZN3tqocu2jfFRbT7816v5yL38mrvO4bsXFqvzt93aMaNMP9FPjs7b+ENcx/OJfPxEREQUKGz9EREQUKGWu28vbheV3SnnkdqHBdlp15JT1TqPslPVDZl97ZlWfdaOeCjq+0ZdunNNTr2J5+ug0mwYvenr5eZfb30W9zMoqt9fTtXXGC39QueNfsqv5Vlm2TOWqQJfj4e2Y5DPAtQof2K7Ym667Q+XiXT25FeZGzRVdcLrd/8Wbom4Xy/oi2z2XcUDfImf9Hl7I6O+6IYSibFm6Qp3tMgZLb9DXmnHn2FW/H5g0SOUaPsmnz0ezv5ld3XtC439G3W76brtd67t/VrnQAdtdfenLg+I6j8x1epGRlmu/iWs/icA7P0RERBQobPwQERFRoLDxQ0RERIGSkmN+9l9ypht/OkU/ObrA2HEF5YdkRuRsz3+vxu2j7j9Wrjb8jclZqk8Ld/3QwY3HNNT9mN162Cl7S147XuWK8sreFOxCz/gNAJjQeLIbR44iOHfMcDduNlb3yReCUkG8Y3wyKlZUZdPW/m0vvaq6yj3d1z7a5KJKe+I6Xp95g924bk5uXPug1JBZ244tye+erXIPjn7RjSP/VrzjlLJunaJyY77v78Z8PIe2tF/051a8sqOxG789uLsby76fi9scAGC+XxDXeaTSNZ93foiIiChQ2PghIiKiQEmJbq/8Wa1UeUT2m24cORXdW75z3bkq991zdoqk3+6rRPly8hluXDDyS5XzdoPd+bq+/bj0TJQ5686Pvnrv5O3NVbnRJNulciQTa8sd18yNDzStFWPL6LJW/erGhctXxrUPitDhFDc0j+hpqzNbv1Kihx7ZZpYb3/P4QJVr9Xe7am3R4qNfJiFoFkU+UbuE5Y63n++FFzwdY8vo38+n/6q734Pe1ZVRrZob/+8p3ZU4r+tTnlKWys3e2taN5avoXV3phnd+iIiIKFDY+CEiIqJAYeOHiIiIAiVpY34yW7dU5Y2d67rxd+304yBCnudKeJ+kDgB3r+/kxkvP1E/xTvY4H686k+2xu/TWTyj/ut3bbux9DAYA9EL0afep6rwe0fuFZ244RZXlhGPceNM5NVTu/BvsIxbOq56ncidk2XFSbcqXj3o8799HKOJ5JIsK7HLsr+d3ULnp/7aPSWnxzi6Vi3caZxAsH2bjnNbvJfXYPStvt/GACSo3vY8dF/ZyJ13XhRv0U6zJ4R3nc/cNt6hcJuJb/iAa7yMrAGBihzcSun8C0MxOWc/rNjkimYVolrxmxwfVLcX/hyYb7/wQERFRoLDxQ0RERIGStG6vFq+tVuVpDe1tzxCir9Ts7eYCgBUDGntKqbk6sjG6q87bHRM5db8surpO9CfxTsvWTwzeM9O+35oZekXgWF1WeQW2fMGCK6MeL0PsdjUq7lW5f7a0U6MfrqefOP7wtba865r9Knf32m5uvPb241SOXWJHb+I2uxL0i4s7qtyBH2u68Z1Xz1C5wcesirrPy6vYZQ3uHdlM5dpMrOHGRQFeGTpDQhFlE2XLxBs25U1VzvAsfHHSZ79XuYWd7crN5SXy/w02zkzi+acK76rqi6e0UblvOnm7gvW1NpbCKvY6vHWI/TzWeel7tZ0pTKX1mY8e7/wQERFRoLDxQ0RERIGStG6vCY2/VeUCY29nRs7o8t7qjJzRlapdXXl/t0s15532rMp531/kbdyyaMicQaq8pJt9ymsl0bMKdobsgwkHruiqcgveP8GNm7+qV2AuXLPWjavA34q9+yPK3pl0O6/SM4A2enpbvuj7N5WbfOwcW3h3jsq1eXWoGx93X0T3nwnebfhons7Xq7a/vORsN25yS74bN1i/KOo+Zr6sV/B9aZKttGknvaRy9TMruXFe70kql13+ZhvrHpZACRn9Xbd1eXstevLFZ1Tunub683K0Hv7zIFWuvMnOxGy1ZrvKhT6xXWIFER8p74NN5/2tncpVQ/Tu+HSRUcPOnl3U5fmIrP+uLq+5w4tfYfuiK69QZZlgZ2hXnPldXMdKJbzzQ0RERIHCxg8REREFChs/REREFChJG/MT6+nskeNgsmfa1UazUTb6FvN62HE+ke/V+/687w0oO+/Pq9k03Wae18VOgXxs7SUqt+EJO6250gz9XpvgKzcu6UmU1d7+JqJs48GTBqlcnZc2ufGUpv9VuZyBdjpph2W36ddNTv/VUVs8Yf+2L379lqjbVf1prSo3WGPH9vit68KVenmMY3rYuMsT96jcov56xWevCRe86sa3T7xO5Vrd/oMthMr+MhTx8o7/AYBVD57jxk1HfRW5+RGr/mb08TiLnj3L935+2m+vPVVX7Y2xJR2tj0+apsq7nrGjKvtgmMqVxTFAvPNDREREgcLGDxEREQVK0rq9Yk3xPnfe1aqcfXNq3ELbf8mZqryyr4293VxA7OnsJ0+x3SPZI8t+10iFWXrlz5Hz7S8mtE1PW63a1Jbz3mincsf3n5fwc4tHUe4SVd76uyZu3HXKb1Xuo7Z2BetXRjypcsMn69WK05GZm+PGFedG366kuzFbP6W7xFpXvNWNc/vqadvdKu22uT46d+lw271j9qd3t9en/4x4iPLtX0Tddl/jgqi5RCjodoYbj70o+kNOe/yvjyoXPtHAjbO+/h5BU5S/zY1bT79V5W7o9Jkbz36os+99bm1j/3817nq7bEnnSnvUdlUzKrjx/WNfVLnRQy914+rX6P8HFG3Z6vtckol3foiIiChQ2PghIiKiQGHjh4iIiAIlJaa6Rz4FvTTlz7LL8o/I1k8i7lnZ9mUe0XT2NBjnE0vh6jVRc3uya7jxos56vEWfJpfZfazRU6NLk/f9rM6JWOa/rQ1n7GiXnBOiQ0T+zTWc09iN8/voR+LUzIi+7P/qu+w4mCaPHP2U7lTW7B/rVPmRfu3ceESdeSV67FDn01T5j8/Y5QcuqLQr6uuWLG2gytn/Dt44Hy+z3043b3WbfmTU557HW1SGzsVS2ROP+Yu9wN088Wy13Z8unOnG11fXY+4uOHmqzb13kcrNn2bH1TV8MnU+Y7zzQ0RERIHCxg8REREFSkpMdf/mtLdUueVzN7lxrbmJP8Vfz7QTcQ+dsv6jG4dgInK2e+7u9Z1Uzvv0+bK4anNpWDGwmRs3eTR1ur28pP6+qLkXPrxAlY8PwFOlU1VhJfs9bnOR/k5XM8ZXvClD7BOtH3ykffQN00DhshWqPHWqnQ79wC3zVW55Lzvl+bShekp1vYn+ui4KL7S/z/W37le5i9Q06ugVlH1jsLu54rXuj+eo8jEXbrDxYH1NK1yru0MPajVUd51NP8n+P2/sFbVU7l/XP+7GLzb7WOW2DHvfjbtl/FHlGv2t9LrBeOeHiIiIAoWNHyIiIgoUNn6IiIgoUJI25ueMB/X077Nu/MmNxzf6UuVyek504/K99FihWE+Dj5aLNS39SKas155rc/U+2wRtKehQlafbfuMnR5+gcjm326nvfX7TXeW2P9rUjbM+tOOwEvXkbalgl2ovOrONyi25tryNOz8X+Uo3qrac3x2iKde4kSr/2tnWZ9VVdszB7iZ6Gvru+v5+pztO1I9fmNfTPmqksmT5Ps/2FQ6/Tbo67sUVbjy5f3OVu/EYm5tx7+Mq17XhPW5cdVX0/e/tttONf+qoH4cQ8sTrCvV4oL6P2XEh9ZA6U6PLkr119XjVHz1T0fu9pa+1ee96pqKPif77Dv3yPzdu+ovO9dpv62zebU+rXJ3MSm58at+FKpf/Wn03LtywMeqxSwKv3kRERBQobPwQERFRoIgx5rAbdc248vAbHYXWP5RX5TEN7ZRh7/RyQE8/95uLnLLufYp85OrStXrl+T3tpPooNDUhy2CXdF3GsqevXjH0zXFj3Li+59ZopLvW21WWfz1QReUyxN5ADxn/bflTPSuUDq+5WOUi/1682n1zrRsfe7V+nSk44OvYiapLoHTrM5bV005S5Tln2+7DjZ6p6HUzQ2q7WKsxl4Ts2XZZjewbfohrH+nw2cw8MVuVB7072417V9kS1z4zPN+tQ9D1vKzAdltePf5ulWswtvS6utLls5lZv57+wVt2hMuM7JkqtaVorxt3esN2aTaeU6i2q/KzXY4kcnq8lLddzXlP6tW8c/vqlf29enfr78ZFOblRt4tXrPrknR8iIiIKFDZ+iIiIKFDY+CEiIqJASdpU91iW9G+qyh07nWELkT123l5Uv7mIntfaz6f3U9ZTVeV39HLp/UPD3bjeXctUblyzd93YOwYsUqyxXf7pP6RXdtgnhD/yQR+Vyx4xzx7P5xifIKpTbbcqH+MZy3NMKX7lGrW5nSq3uXelGydmEYWyqWihHuv4Qr9ebvz46dVVbsgf3nPj6z1T4o/ElZPsOJ/GpTjGJ10VbdRLsZS7tokbD3izm8q9ftyHbrxw4ASbGKj32TvvUjdevPr0qMfuesL8qLlUwjs/REREFChs/BAREVGgpES3V1GeXh25dh5XSw6CSu9+58Y739W5G1vae65bOzZw41oLtqvtVvSucdTn0eBbvVpwxc/tKqQtd+suNz1hl8oCb1fXj10bqFzR5s1JPpuywczNcePac3VucrnL3Hj9zXPceESdeVH3d+Lbt6vyCVPtVOnCyI0p4QpXr3HjnZdUU7kzXr3GjUe1tdPge1bW11o1RV6vjODb5O3NVVl27y1+wyTgnR8iIiIKFDZ+iIiIKFDY+CEiIqJASYkxP0SRipYsd+ManjhyzE3TeYk/Nsf1HL2NP+ixNTip+O0S5bGtbd3468v0gITQJvt4htAejvE5WnUn2aVCvplkH010Gc6M+pqW0GPnOM6n9IR27lTlBn0WufH47vbRT6OOz1LbXTzE1vsj9eN7FMy4mb1UucWK0lt2hnd+iIiIKFDY+CEiIqJAYbcXESVc8wd0N8elo88pdru1d7RX5T2No3c6tp6w0Y1Dq/VTpRGyq3ubglV+T5OIPLJm2+6siOfCY77n4ey90B7xaIHUeboC7/wQERFRoLDxQ0RERIHCxg8REREFCsf8EFHiGaOL+/cXu1mjJ/w/0TvIT10nosTinR8iIiIKFDZ+iIiIKFDY+CEiIqJAYeOHiIiIAoWNHyIiIgoUNn6IiIgoUNj4ISIiokBh44eIiIgChY0fIiIiChQxESuxEhEREaUz3vkhIiKiQGHjh4iIiAKFjR8iIiIKFDZ+iIiIKFDY+CEiIqJAYeOHiIiIAuX/AW21zFHP4lndAAAAAElFTkSuQmCC",
|
428 |
+
"text/plain": [
|
429 |
+
"<Figure size 720x720 with 5 Axes>"
|
430 |
+
]
|
431 |
+
},
|
432 |
+
"metadata": {
|
433 |
+
"needs_background": "light"
|
434 |
+
},
|
435 |
+
"output_type": "display_data"
|
436 |
+
}
|
437 |
+
],
|
438 |
+
"source": [
|
439 |
+
"logits = model(xbv)\n",
|
440 |
+
"probs = F.softmax(logits, dim=1)\n",
|
441 |
+
"idx = 5\n",
|
442 |
+
"_,axs = plt.subplots(1, idx, figsize=(10, 10))\n",
|
443 |
+
"for actual, pred, im, ax in zip(ybv[:idx], probs[:idx],xbv.permute(0,2,3,1)[:idx], axs.flat):\n",
|
444 |
+
" ax.imshow(im)\n",
|
445 |
+
" ax.set_axis_off()\n",
|
446 |
+
" ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')\n",
|
447 |
+
" "
|
448 |
+
]
|
449 |
+
},
|
450 |
+
{
|
451 |
+
"cell_type": "code",
|
452 |
+
"execution_count": 164,
|
453 |
"metadata": {},
|
454 |
"outputs": [],
|
455 |
"source": [
|
|
|
465 |
},
|
466 |
{
|
467 |
"cell_type": "code",
|
468 |
+
"execution_count": 167,
|
469 |
"metadata": {
|
470 |
"tags": [
|
471 |
"exclude"
|
|
|
476 |
"name": "stdout",
|
477 |
"output_type": "stream",
|
478 |
"text": [
|
479 |
+
"tensor(4)\n"
|
480 |
]
|
481 |
},
|
482 |
{
|
483 |
"data": {
|
484 |
"text/plain": [
|
485 |
+
"[{'digit': 0, 'prob': '0.12%', 'logits': tensor(-1.1319)},\n",
|
486 |
+
" {'digit': 1, 'prob': '0.00%', 'logits': tensor(-4.7852)},\n",
|
487 |
+
" {'digit': 2, 'prob': '2.15%', 'logits': tensor(1.7912)},\n",
|
488 |
+
" {'digit': 3, 'prob': '0.07%', 'logits': tensor(-1.6584)},\n",
|
489 |
+
" {'digit': 4, 'prob': '97.03%', 'logits': tensor(5.5990)},\n",
|
490 |
+
" {'digit': 5, 'prob': '0.01%', 'logits': tensor(-3.5289)},\n",
|
491 |
+
" {'digit': 6, 'prob': '0.00%', 'logits': tensor(-4.4016)},\n",
|
492 |
+
" {'digit': 7, 'prob': '0.09%', 'logits': tensor(-1.3343)},\n",
|
493 |
+
" {'digit': 8, 'prob': '0.07%', 'logits': tensor(-1.6577)},\n",
|
494 |
+
" {'digit': 9, 'prob': '0.45%', 'logits': tensor(0.2194)}]"
|
495 |
]
|
496 |
},
|
497 |
+
"execution_count": 167,
|
498 |
"metadata": {},
|
499 |
"output_type": "execute_result"
|
500 |
}
|
501 |
],
|
502 |
"source": [
|
503 |
+
"img = xb[1].reshape(1, 28, 28)\n",
|
504 |
+
"print(yb[1])\n",
|
505 |
"predict(img)"
|
506 |
]
|
507 |
},
|
|
|
518 |
},
|
519 |
{
|
520 |
"cell_type": "code",
|
521 |
+
"execution_count": 168,
|
522 |
"metadata": {
|
523 |
"tags": [
|
524 |
"exclude"
|
|
|
530 |
"output_type": "stream",
|
531 |
"text": [
|
532 |
"[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
|
533 |
+
"[NbConvertApp] Writing 3187 bytes to mnist_classifier.py\n"
|
534 |
]
|
535 |
}
|
536 |
],
|
mnist_classifier.py
CHANGED
@@ -11,7 +11,6 @@ import matplotlib as mpl
|
|
11 |
import torchvision.transforms.functional as TF
|
12 |
from torch.utils.data import default_collate, DataLoader
|
13 |
import torch.optim as optim
|
14 |
-
import pickle
|
15 |
|
16 |
|
17 |
def transform_ds(b):
|
@@ -23,7 +22,7 @@ bs = 1024
|
|
23 |
class DataLoaders:
|
24 |
def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):
|
25 |
self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)
|
26 |
-
self.valid = DataLoader(valid_ds, batch_size=bs
|
27 |
|
28 |
def collate_fn(b):
|
29 |
collate = default_collate(b)
|
@@ -65,13 +64,24 @@ class ResBlock(nn.Module):
|
|
65 |
return self.act(self.convs(x) + self.idconv(self.pool(x)))
|
66 |
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
def cnn_classifier():
|
69 |
return nn.Sequential(
|
70 |
-
ResBlock(1, 8,
|
71 |
-
ResBlock(8, 16,
|
72 |
-
ResBlock(16, 32,
|
73 |
-
ResBlock(32, 64,
|
74 |
-
ResBlock(64, 64,
|
75 |
conv(64, 10, act=False),
|
76 |
nn.Flatten(),
|
77 |
)
|
@@ -83,7 +93,7 @@ def kaiming_init(m):
|
|
83 |
|
84 |
|
85 |
loaded_model = cnn_classifier()
|
86 |
-
loaded_model.load_state_dict(torch.load('classifier.pth'))
|
87 |
loaded_model.eval();
|
88 |
|
89 |
|
|
|
11 |
import torchvision.transforms.functional as TF
|
12 |
from torch.utils.data import default_collate, DataLoader
|
13 |
import torch.optim as optim
|
|
|
14 |
|
15 |
|
16 |
def transform_ds(b):
|
|
|
22 |
class DataLoaders:
|
23 |
def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):
|
24 |
self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)
|
25 |
+
self.valid = DataLoader(valid_ds, batch_size=bs, shuffle=False, collate_fn=collate_fn, **kwargs)
|
26 |
|
27 |
def collate_fn(b):
|
28 |
collate = default_collate(b)
|
|
|
64 |
return self.act(self.convs(x) + self.idconv(self.pool(x)))
|
65 |
|
66 |
|
67 |
+
# def cnn_classifier():
|
68 |
+
# return nn.Sequential(
|
69 |
+
# ResBlock(1, 8, norm=nn.BatchNorm2d(8)),
|
70 |
+
# ResBlock(8, 16, norm=nn.BatchNorm2d(16)),
|
71 |
+
# ResBlock(16, 32, norm=nn.BatchNorm2d(32)),
|
72 |
+
# ResBlock(32, 64, norm=nn.BatchNorm2d(64)),
|
73 |
+
# ResBlock(64, 64, norm=nn.BatchNorm2d(64)),
|
74 |
+
# conv(64, 10, act=False),
|
75 |
+
# nn.Flatten(),
|
76 |
+
# )
|
77 |
+
|
78 |
def cnn_classifier():
|
79 |
return nn.Sequential(
|
80 |
+
ResBlock(1, 8,),
|
81 |
+
ResBlock(8, 16, ),
|
82 |
+
ResBlock(16, 32,),
|
83 |
+
ResBlock(32, 64, ),
|
84 |
+
ResBlock(64, 64,),
|
85 |
conv(64, 10, act=False),
|
86 |
nn.Flatten(),
|
87 |
)
|
|
|
93 |
|
94 |
|
95 |
loaded_model = cnn_classifier()
|
96 |
+
loaded_model.load_state_dict(torch.load('classifier.pth'));
|
97 |
loaded_model.eval();
|
98 |
|
99 |
|
server.py
CHANGED
@@ -6,6 +6,7 @@ from fastapi.staticfiles import StaticFiles
|
|
6 |
from pathlib import Path
|
7 |
import torchvision.transforms as transforms
|
8 |
import mnist_classifier
|
|
|
9 |
|
10 |
app = FastAPI()
|
11 |
|
|
|
6 |
from pathlib import Path
|
7 |
import torchvision.transforms as transforms
|
8 |
import mnist_classifier
|
9 |
+
import torch
|
10 |
|
11 |
app = FastAPI()
|
12 |
|