carlfeynman committed on
Commit
338bbe8
•
1 Parent(s): 6c5990e

removed batchnorm2d

Browse files
README.md CHANGED
@@ -1,7 +1,7 @@
1
  # <img src="/static/favicon.png" alt="Logo" style="float: left; margin-right: 10px; border-radius:100%;margin-top:5px" /> MNIST CLASSIFIER
2
  MNIST classifier from scratch
3
  * Model: CNN
4
- * Accuracy: 98%
5
 
6
  * Training Notebook: mnist_classifier.ipynb
7
  * Cleaned Python Inference Version: mnist_classifier.py
 
1
  # <img src="/static/favicon.png" alt="Logo" style="float: left; margin-right: 10px; border-radius:100%;margin-top:5px" /> MNIST CLASSIFIER
2
  MNIST classifier from scratch
3
  * Model: CNN
4
+ * Accuracy: 94%
5
 
6
  * Training Notebook: mnist_classifier.ipynb
7
  * Cleaned Python Inference Version: mnist_classifier.py
__pycache__/mnist_classifier.cpython-39.pyc CHANGED
Binary files a/__pycache__/mnist_classifier.cpython-39.pyc and b/__pycache__/mnist_classifier.cpython-39.pyc differ
 
__pycache__/server.cpython-39.pyc CHANGED
Binary files a/__pycache__/server.cpython-39.pyc and b/__pycache__/server.cpython-39.pyc differ
 
classifier.pth CHANGED
Binary files a/classifier.pth and b/classifier.pth differ
 
mnist_classifier.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 100,
6
  "metadata": {},
7
  "outputs": [],
8
  "source": [
@@ -15,13 +15,12 @@
15
  "import matplotlib as mpl\n",
16
  "import torchvision.transforms.functional as TF\n",
17
  "from torch.utils.data import default_collate, DataLoader\n",
18
- "import torch.optim as optim\n",
19
- "import pickle\n"
20
  ]
21
  },
22
  {
23
  "cell_type": "code",
24
- "execution_count": null,
25
  "metadata": {
26
  "tags": [
27
  "exclude"
@@ -35,7 +34,7 @@
35
  },
36
  {
37
  "cell_type": "code",
38
- "execution_count": 101,
39
  "metadata": {
40
  "tags": [
41
  "exclude"
@@ -47,7 +46,7 @@
47
  "output_type": "stream",
48
  "text": [
49
  "Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
50
- "100%|██████████| 2/2 [00:00<00:00, 35.54it/s]\n"
51
  ]
52
  }
53
  ],
@@ -59,7 +58,7 @@
59
  },
60
  {
61
  "cell_type": "code",
62
- "execution_count": 112,
63
  "metadata": {},
64
  "outputs": [],
65
  "source": [
@@ -70,40 +69,42 @@
70
  },
71
  {
72
  "cell_type": "code",
73
- "execution_count": null,
74
  "metadata": {
75
  "tags": [
76
  "exclude"
77
  ]
78
  },
79
- "outputs": [],
80
- "source": [
81
- "dst = ds.with_transform(transform_ds)\n",
82
- "plt.imshow(dst['train'][0]['image'].permute(1,2,0));"
83
- ]
84
- },
85
- {
86
- "cell_type": "code",
87
- "execution_count": 103,
88
- "metadata": {},
89
  "outputs": [
90
  {
91
  "data": {
 
92
  "text/plain": [
93
- "(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))"
94
  ]
95
  },
96
- "execution_count": 103,
97
- "metadata": {},
98
- "output_type": "execute_result"
 
99
  }
100
  ],
 
 
 
 
 
 
 
 
 
 
101
  "source": [
102
  "bs = 1024\n",
103
  "class DataLoaders:\n",
104
  " def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n",
105
  " self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n",
106
- " self.valid = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)\n",
107
  "\n",
108
  "def collate_fn(b):\n",
109
  " collate = default_collate(b)\n",
@@ -112,13 +113,24 @@
112
  },
113
  {
114
  "cell_type": "code",
115
- "execution_count": null,
116
  "metadata": {
117
  "tags": [
118
  "exclude"
119
  ]
120
  },
121
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
122
  "source": [
123
  "dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n",
124
  "xb,yb = next(iter(dls.train))\n",
@@ -127,7 +139,7 @@
127
  },
128
  {
129
  "cell_type": "code",
130
- "execution_count": 105,
131
  "metadata": {},
132
  "outputs": [],
133
  "source": [
@@ -142,7 +154,7 @@
142
  },
143
  {
144
  "cell_type": "code",
145
- "execution_count": 106,
146
  "metadata": {},
147
  "outputs": [],
148
  "source": [
@@ -174,17 +186,28 @@
174
  },
175
  {
176
  "cell_type": "code",
177
- "execution_count": 107,
178
  "metadata": {},
179
  "outputs": [],
180
  "source": [
 
 
 
 
 
 
 
 
 
 
 
181
  "def cnn_classifier():\n",
182
  " return nn.Sequential(\n",
183
- " ResBlock(1, 8, norm=nn.BatchNorm2d(8)),\n",
184
- " ResBlock(8, 16, norm=nn.BatchNorm2d(16)),\n",
185
- " ResBlock(16, 32, norm=nn.BatchNorm2d(32)),\n",
186
- " ResBlock(32, 64, norm=nn.BatchNorm2d(64)),\n",
187
- " ResBlock(64, 64, norm=nn.BatchNorm2d(64)),\n",
188
  " conv(64, 10, act=False),\n",
189
  " nn.Flatten(),\n",
190
  " )"
@@ -192,7 +215,7 @@
192
  },
193
  {
194
  "cell_type": "code",
195
- "execution_count": 108,
196
  "metadata": {},
197
  "outputs": [],
198
  "source": [
@@ -203,7 +226,7 @@
203
  },
204
  {
205
  "cell_type": "code",
206
- "execution_count": 195,
207
  "metadata": {
208
  "tags": [
209
  "exclude"
@@ -214,16 +237,16 @@
214
  "name": "stdout",
215
  "output_type": "stream",
216
  "text": [
217
- "train, epoch:1, loss: 0.0776, accuracy: 0.9172\n",
218
- "eval, epoch:1, loss: 0.0372, accuracy: 0.9818\n",
219
- "train, epoch:2, loss: 0.0571, accuracy: 0.9828\n",
220
- "eval, epoch:2, loss: 0.0287, accuracy: 0.9863\n",
221
- "train, epoch:3, loss: 0.0425, accuracy: 0.9847\n",
222
- "eval, epoch:3, loss: 0.0256, accuracy: 0.9865\n",
223
- "train, epoch:4, loss: 0.0271, accuracy: 0.9868\n",
224
- "eval, epoch:4, loss: 0.0378, accuracy: 0.9826\n",
225
- "train, epoch:5, loss: 0.0395, accuracy: 0.9844\n",
226
- "eval, epoch:5, loss: 0.0307, accuracy: 0.9873\n"
227
  ]
228
  }
229
  ],
@@ -238,6 +261,7 @@
238
  "for epoch in range(epochs):\n",
239
  " for train in (True, False):\n",
240
  " accuracy = 0\n",
 
241
  " dl = dls.train if train else dls.valid\n",
242
  " for xb,yb in dl:\n",
243
  " preds = model(xb)\n",
@@ -248,15 +272,90 @@
248
  " opt.zero_grad()\n",
249
  " with torch.no_grad():\n",
250
  " accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
 
251
  " if train:\n",
252
  " sched.step()\n",
253
  " accuracy /= len(dl)\n",
254
- " print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")"
 
255
  ]
256
  },
257
  {
258
  "cell_type": "code",
259
- "execution_count": 196,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  "metadata": {
261
  "tags": [
262
  "exclude"
@@ -269,18 +368,88 @@
269
  },
270
  {
271
  "cell_type": "code",
272
- "execution_count": 197,
273
  "metadata": {},
274
  "outputs": [],
275
  "source": [
276
  "loaded_model = cnn_classifier()\n",
277
- "loaded_model.load_state_dict(torch.load('classifier.pth'))\n",
278
  "loaded_model.eval();"
279
  ]
280
  },
281
  {
282
  "cell_type": "code",
283
- "execution_count": 206,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  "metadata": {},
285
  "outputs": [],
286
  "source": [
@@ -296,7 +465,7 @@
296
  },
297
  {
298
  "cell_type": "code",
299
- "execution_count": 204,
300
  "metadata": {
301
  "tags": [
302
  "exclude"
@@ -307,32 +476,32 @@
307
  "name": "stdout",
308
  "output_type": "stream",
309
  "text": [
310
- "tensor(5)\n"
311
  ]
312
  },
313
  {
314
  "data": {
315
  "text/plain": [
316
- "[{'digit': 0, 'prob': '21.42%', 'logits': tensor(0.0559)},\n",
317
- " {'digit': 8, 'prob': '19.44%', 'logits': tensor(-0.0408)},\n",
318
- " {'digit': 4, 'prob': '18.08%', 'logits': tensor(-0.1135)},\n",
319
- " {'digit': 9, 'prob': '16.41%', 'logits': tensor(-0.2104)},\n",
320
- " {'digit': 6, 'prob': '12.23%', 'logits': tensor(-0.5049)},\n",
321
- " {'digit': 1, 'prob': '6.87%', 'logits': tensor(-1.0806)},\n",
322
- " {'digit': 7, 'prob': '2.33%', 'logits': tensor(-2.1633)},\n",
323
- " {'digit': 5, 'prob': '1.19%', 'logits': tensor(-2.8386)},\n",
324
- " {'digit': 2, 'prob': '1.06%', 'logits': tensor(-2.9527)},\n",
325
- " {'digit': 3, 'prob': '0.97%', 'logits': tensor(-3.0359)}]"
326
  ]
327
  },
328
- "execution_count": 204,
329
  "metadata": {},
330
  "output_type": "execute_result"
331
  }
332
  ],
333
  "source": [
334
- "img = xb[0].reshape(1, 28, 28)\n",
335
- "print(yb[0])\n",
336
  "predict(img)"
337
  ]
338
  },
@@ -349,7 +518,7 @@
349
  },
350
  {
351
  "cell_type": "code",
352
- "execution_count": 207,
353
  "metadata": {
354
  "tags": [
355
  "exclude"
@@ -361,7 +530,7 @@
361
  "output_type": "stream",
362
  "text": [
363
  "[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
364
- "[NbConvertApp] Writing 2905 bytes to mnist_classifier.py\n"
365
  ]
366
  }
367
  ],
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "metadata": {},
7
  "outputs": [],
8
  "source": [
 
15
  "import matplotlib as mpl\n",
16
  "import torchvision.transforms.functional as TF\n",
17
  "from torch.utils.data import default_collate, DataLoader\n",
18
+ "import torch.optim as optim\n"
 
19
  ]
20
  },
21
  {
22
  "cell_type": "code",
23
+ "execution_count": 2,
24
  "metadata": {
25
  "tags": [
26
  "exclude"
 
34
  },
35
  {
36
  "cell_type": "code",
37
+ "execution_count": 3,
38
  "metadata": {
39
  "tags": [
40
  "exclude"
 
46
  "output_type": "stream",
47
  "text": [
48
  "Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
49
+ "100%|██████████| 2/2 [00:00<00:00, 112.23it/s]\n"
50
  ]
51
  }
52
  ],
 
58
  },
59
  {
60
  "cell_type": "code",
61
+ "execution_count": 4,
62
  "metadata": {},
63
  "outputs": [],
64
  "source": [
 
69
  },
70
  {
71
  "cell_type": "code",
72
+ "execution_count": 5,
73
  "metadata": {
74
  "tags": [
75
  "exclude"
76
  ]
77
  },
 
 
 
 
 
 
 
 
 
 
78
  "outputs": [
79
  {
80
  "data": {
81
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIsUlEQVR4nO3df2yU9R0H8PfHtrQroFJBVrGjHVRAweHWCASCJBuumiXOLAyYWTbjQiYy58Y2fmzZ5oILJgsJMjSRrCsmig7mAjFsZBIlLkNGdeBgrOWnWqnFwkDmUNrrZ3/0bPu59cfTz3P33NPr+5WQu89zd32+MW+/z/eeu+dzoqogGqgrsj0AGpwYHHJhcMiFwSEXBodcGBxyCRUcEakWkXoROSYiK9M1KIo/8Z7HEZE8AA0A5gNoBLAfwGJV/Wf6hkdxlR/itbcCOKaqJwBARJ4FcBeAXoMzTAq1CMND7JKidhH/blHVManbwwRnHIC3u9WNAGb09YIiDMcM+XyIXVLUXtRtb/a0PUxwpIdt/3fcE5ElAJYAQBGKQ+yO4iTM4rgRQFm3+noAp1OfpKpPqmqVqlYVoDDE7ihOwgRnP4BKEakQkWEAFgHYkZ5hUdy5D1Wq2iYiywDsApAHoEZVD6dtZBRrYdY4UNWdAHamaSw0iPDMMbkwOOTC4JALg0MuDA65MDjkwuCQC4NDLgwOuTA45MLgkAuDQy6hPuQcSiTf/qfKGzM68Gvrf1Bu6kRxu6nHTzhj6uKl9jty764bZurXq54zdUviA1PP2Lq88/7E778aeJwDwRmHXBgccmFwyGXIrHHyplSaWgsLTH36tqtNfWmmXTeUXGXrVz5j1xlh/PG/I0396K+rTb1v2jOmPtl6ydRrm+eb+rpXMt/ziDMOuTA45MLgkEvOrnES8z5r6nW1G019Q4E9NxKlVk2Y+qcbvmnq/A/sGmXW1mWmHvlOm6kLW+yap7huX8gR9o8zDrkwOOTC4JBLzq5xCuvtZeyvfVhm6hsKmtO2r+VNM0194j/2c6zaCdtMfaHdrmHGPvbXUPvPRqdqzjjkwuCQC4NDLjm7xmlretfUGx5dYOpHqu1nT3lvjDD1waUb+vz7a1pu7rx/7Au2YVTifJOpvzZrqalPPWj/VgUO9rmvOOKMQy79BkdEakTkjIgc6ratRET+LCJHk7ejMjtMipsgM04tgOqUbSsB7FbVSgC7kzUNIYH6HItIOYAXVHVqsq4HME9Vm0SkFMDLqjqpv79zpZRoXLqO5o2+xtSJs+dMffKZm019eG6NqW/95Xc671+7Mdx5mDh7Ube9pqpVqdu9a5yxqtoEAMnba8MMjgafjL+rYrva3OSdcZqThygkb8/09kS2q81N3hlnB4BvAFibvN2ethFFJNFyts/HW9/v+/s6N93T9csD7z2RZx9sTyDXBXk7vgXAXgCTRKRRRO5DR2Dmi8hRdPwIyNrMDpPipt8ZR1UX9/JQPN4eUVbwzDG55OxnVWFNWdFg6nun2Qn2t+N3d96/bcED5rGRz2Xmeu044YxDLgwOuTA45MI1Ti8S5y+Y+uz9U0z91o6ua5lWrnnKPLbqq3ebWv9+lanLHtlrd+b8XdRs4oxDLgwOufBQFVD7wSOmXvTwDzvvP/2zX5nHDsy0hy7Yq2dw03B7SW/lJvtV07YTp3yDjBBnHHJhcMiFwSGXQF8dTZc4fXU0nXT2dFNfubbR1Fs+vavP109+6VumnvSwPRWQOHrCP7iQ0v3VURriGBxyYXDIhWucDMgbay/6OL1woqn3rVhv6itS/v+95+Ttpr4wp++vuWYS1ziUVgwOuTA45MLPqjIg0WwvMxv7mK0//JFtN1ss9lKcTeUvmPpLdz9kn/+HzLej7Q9nHHJhcMiFwSEXrnHSoH3OdFMfX1Bk6qnTT5k6dU2TasO5W+zzt9e5x5YpnHHIhcEhFwaHXLjGCUiqppq64cGudcqm2ZvNY3OLLg/o
b3+kraZ+9VyFfUK7/U5yHHDGIZcg/XHKROQlETkiIodF5LvJ7WxZO4QFmXHaACxX1SnouNDjARG5EWxZO6QFaazUBODjDqMXReQIgHEA7gIwL/m0zQBeBrAiI6OMQH7FeFMfv/c6U/984bOm/sqIFve+Vjfbr7fsWW8vvBq1OeUS4Rga0Bon2e/4FgD7wJa1Q1rg4IjICAC/B/CQqr4/gNctEZE6EalrxUeeMVIMBQqOiBSgIzRPq+rzyc2BWtayXW1u6neNIyIC4DcAjqjqum4PDaqWtfnlnzL1hc+VmnrhL/5k6m9f/Ty8Un9qce/jdk1TUvs3U49qj/+aJlWQE4CzAXwdwD9E5EBy22p0BOZ3yfa1bwFY0PPLKRcFeVf1FwDSy8O5f8kC9YhnjsklZz6ryi/9pKnP1Qw39f0Ve0y9eGS4n49e9s6czvuvPzHdPDZ62yFTl1wcfGuY/nDGIRcGh1wYHHIZVGucy1/sOh9y+Xv2pxBXT9xp6ts/YX8eeqCaE5dMPXfHclNP/sm/Ou+XnLdrmPZQex4cOOOQC4NDLoPqUHXqy105b5i2dUCv3Xh+gqnX77GtRCRhz3FOXnPS1JXN9rLb3P8NvL5xxiEXBodcGBxyYSs36hNbuVFaMTjkwuCQC4NDLgwOuTA45MLgkAuDQy4MDrkwOOTC4JBLpJ9Vich7AN4EMBqAv09IZnFs1nhVHZO6MdLgdO5UpK6nD87igGMLhocqcmFwyCVbwXkyS/sNgmMLICtrHBr8eKgil0iDIyLVIlIvIsdEJKvtbUWkRkTOiMihbtti0bt5MPSWjiw4IpIHYCOAOwDcCGBxsl9yttQCqE7ZFpfezfHvLa2qkfwDMAvArm71KgCrotp/L2MqB3CoW10PoDR5vxRAfTbH121c2wHMj9P4ojxUjQPwdre6MbktTmLXuzmuvaWjDE5PfQT5lq4P3t7SUYgyOI0AyrrV1wM4HeH+gwjUuzkKYXpLRyHK4OwHUCkiFSIyDMAidPRKjpOPezcDWezdHKC3NJDt3tIRL/LuBNAA4DiAH2d5wbkFHT9u0oqO2fA+ANeg493K0eRtSZbGNgcdh/E3ABxI/rszLuNTVZ45Jh+eOSYXBodcGBxyYXDIhcEhFwaHXBgccmFwyOV/atVD7hyCzrEAAAAASUVORK5CYII=",
82
  "text/plain": [
83
+ "<Figure size 144x144 with 1 Axes>"
84
  ]
85
  },
86
+ "metadata": {
87
+ "needs_background": "light"
88
+ },
89
+ "output_type": "display_data"
90
  }
91
  ],
92
+ "source": [
93
+ "dst = ds.with_transform(transform_ds)\n",
94
+ "plt.imshow(dst['train'][0]['image'].permute(1,2,0));"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": 6,
100
+ "metadata": {},
101
+ "outputs": [],
102
  "source": [
103
  "bs = 1024\n",
104
  "class DataLoaders:\n",
105
  " def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n",
106
  " self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n",
107
+ " self.valid = DataLoader(valid_ds, batch_size=bs, shuffle=False, collate_fn=collate_fn, **kwargs)\n",
108
  "\n",
109
  "def collate_fn(b):\n",
110
  " collate = default_collate(b)\n",
 
113
  },
114
  {
115
  "cell_type": "code",
116
+ "execution_count": 7,
117
  "metadata": {
118
  "tags": [
119
  "exclude"
120
  ]
121
  },
122
+ "outputs": [
123
+ {
124
+ "data": {
125
+ "text/plain": [
126
+ "(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))"
127
+ ]
128
+ },
129
+ "execution_count": 7,
130
+ "metadata": {},
131
+ "output_type": "execute_result"
132
+ }
133
+ ],
134
  "source": [
135
  "dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n",
136
  "xb,yb = next(iter(dls.train))\n",
 
139
  },
140
  {
141
  "cell_type": "code",
142
+ "execution_count": 147,
143
  "metadata": {},
144
  "outputs": [],
145
  "source": [
 
154
  },
155
  {
156
  "cell_type": "code",
157
+ "execution_count": 148,
158
  "metadata": {},
159
  "outputs": [],
160
  "source": [
 
186
  },
187
  {
188
  "cell_type": "code",
189
+ "execution_count": 149,
190
  "metadata": {},
191
  "outputs": [],
192
  "source": [
193
+ "# def cnn_classifier():\n",
194
+ "# return nn.Sequential(\n",
195
+ "# ResBlock(1, 8, norm=nn.BatchNorm2d(8)),\n",
196
+ "# ResBlock(8, 16, norm=nn.BatchNorm2d(16)),\n",
197
+ "# ResBlock(16, 32, norm=nn.BatchNorm2d(32)),\n",
198
+ "# ResBlock(32, 64, norm=nn.BatchNorm2d(64)),\n",
199
+ "# ResBlock(64, 64, norm=nn.BatchNorm2d(64)),\n",
200
+ "# conv(64, 10, act=False),\n",
201
+ "# nn.Flatten(),\n",
202
+ "# )\n",
203
+ "\n",
204
  "def cnn_classifier():\n",
205
  " return nn.Sequential(\n",
206
+ " ResBlock(1, 8,),\n",
207
+ " ResBlock(8, 16, ),\n",
208
+ " ResBlock(16, 32,),\n",
209
+ " ResBlock(32, 64, ),\n",
210
+ " ResBlock(64, 64,),\n",
211
  " conv(64, 10, act=False),\n",
212
  " nn.Flatten(),\n",
213
  " )"
 
215
  },
216
  {
217
  "cell_type": "code",
218
+ "execution_count": 150,
219
  "metadata": {},
220
  "outputs": [],
221
  "source": [
 
226
  },
227
  {
228
  "cell_type": "code",
229
+ "execution_count": 151,
230
  "metadata": {
231
  "tags": [
232
  "exclude"
 
237
  "name": "stdout",
238
  "output_type": "stream",
239
  "text": [
240
+ "train, epoch:1, loss: 1.3684, accuracy: 0.5153\n",
241
+ "eval, epoch:1, loss: 0.4238, accuracy: 0.8648\n",
242
+ "train, epoch:2, loss: 0.2660, accuracy: 0.9162\n",
243
+ "eval, epoch:2, loss: 0.1468, accuracy: 0.9552\n",
244
+ "train, epoch:3, loss: 0.1479, accuracy: 0.9545\n",
245
+ "eval, epoch:3, loss: 0.1101, accuracy: 0.9647\n",
246
+ "train, epoch:4, loss: 0.1149, accuracy: 0.9650\n",
247
+ "eval, epoch:4, loss: 0.0997, accuracy: 0.9705\n",
248
+ "train, epoch:5, loss: 0.2118, accuracy: 0.9399\n",
249
+ "eval, epoch:5, loss: 0.1625, accuracy: 0.9478\n"
250
  ]
251
  }
252
  ],
 
261
  "for epoch in range(epochs):\n",
262
  " for train in (True, False):\n",
263
  " accuracy = 0\n",
264
+ " total_loss = 0\n",
265
  " dl = dls.train if train else dls.valid\n",
266
  " for xb,yb in dl:\n",
267
  " preds = model(xb)\n",
 
272
  " opt.zero_grad()\n",
273
  " with torch.no_grad():\n",
274
  " accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
275
+ " total_loss += loss.item()\n",
276
  " if train:\n",
277
  " sched.step()\n",
278
  " accuracy /= len(dl)\n",
279
+ " total_loss /= len(dl)\n",
280
+ " print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {total_loss:.4f}, accuracy: {accuracy:.4f}\")"
281
  ]
282
  },
283
  {
284
  "cell_type": "code",
285
+ "execution_count": 152,
286
+ "metadata": {
287
+ "tags": [
288
+ "exclude"
289
+ ]
290
+ },
291
+ "outputs": [
292
+ {
293
+ "name": "stdout",
294
+ "output_type": "stream",
295
+ "text": [
296
+ "eval, epoch:1, loss: 0.1625, accuracy: 0.9478\n",
297
+ "eval, epoch:2, loss: 0.1625, accuracy: 0.9478\n",
298
+ "eval, epoch:3, loss: 0.1625, accuracy: 0.9478\n",
299
+ "eval, epoch:4, loss: 0.1625, accuracy: 0.9478\n",
300
+ "eval, epoch:5, loss: 0.1625, accuracy: 0.9478\n"
301
+ ]
302
+ }
303
+ ],
304
+ "source": [
305
+ "for epoch in range(epochs):\n",
306
+ " train = False\n",
307
+ " accuracy = 0\n",
308
+ " total_loss = 0\n",
309
+ " dl = dls.valid\n",
310
+ " for xb,yb in dl:\n",
311
+ " preds = model(xb)\n",
312
+ " loss = F.cross_entropy(preds, yb)\n",
313
+ " with torch.no_grad():\n",
314
+ " accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
315
+ " total_loss += loss.item()\n",
316
+ " accuracy /= len(dl)\n",
317
+ " total_loss /= len(dl)\n",
318
+ " print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {total_loss:.4f}, accuracy: {accuracy:.4f}\")"
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "execution_count": 153,
324
+ "metadata": {
325
+ "tags": [
326
+ "exclude"
327
+ ]
328
+ },
329
+ "outputs": [
330
+ {
331
+ "data": {
332
+ "image/png": "",
333
+ "text/plain": [
334
+ "<Figure size 720x720 with 5 Axes>"
335
+ ]
336
+ },
337
+ "metadata": {
338
+ "needs_background": "light"
339
+ },
340
+ "output_type": "display_data"
341
+ }
342
+ ],
343
+ "source": [
344
+ "xbv,ybv = next(iter(dls.train))\n",
345
+ "logits = model(xbv)\n",
346
+ "probs = F.softmax(logits, dim=1)\n",
347
+ "idx = 5\n",
348
+ "_,axs = plt.subplots(1, idx, figsize=(10, 10))\n",
349
+ "for actual, pred, im, ax in zip(ybv[:idx], probs[:idx],xbv.permute(0,2,3,1)[:idx], axs.flat):\n",
350
+ " ax.imshow(im)\n",
351
+ " ax.set_axis_off()\n",
352
+ " ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')\n",
353
+ " "
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "execution_count": 158,
359
  "metadata": {
360
  "tags": [
361
  "exclude"
 
368
  },
369
  {
370
  "cell_type": "code",
371
+ "execution_count": 159,
372
  "metadata": {},
373
  "outputs": [],
374
  "source": [
375
  "loaded_model = cnn_classifier()\n",
376
+ "loaded_model.load_state_dict(torch.load('classifier.pth'));\n",
377
  "loaded_model.eval();"
378
  ]
379
  },
380
  {
381
  "cell_type": "code",
382
+ "execution_count": 160,
383
+ "metadata": {
384
+ "tags": [
385
+ "exclude"
386
+ ]
387
+ },
388
+ "outputs": [
389
+ {
390
+ "data": {
391
+ "image/png": "",
392
+ "text/plain": [
393
+ "<Figure size 720x720 with 5 Axes>"
394
+ ]
395
+ },
396
+ "metadata": {
397
+ "needs_background": "light"
398
+ },
399
+ "output_type": "display_data"
400
+ }
401
+ ],
402
+ "source": [
403
+ "with torch.no_grad():\n",
404
+ " xbv,ybv = next(iter(dls.train))\n",
405
+ " logits = loaded_model(xbv)\n",
406
+ " probs = F.softmax(logits, dim=1)\n",
407
+ " idx = 5\n",
408
+ " _,axs = plt.subplots(1, idx, figsize=(10, 10))\n",
409
+ " for actual, pred, im, ax in zip(ybv[:idx], probs[:idx],xbv.permute(0,2,3,1)[:idx], axs.flat):\n",
410
+ " ax.imshow(im)\n",
411
+ " ax.set_axis_off()\n",
412
+ " ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')\n",
413
+ " "
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "code",
418
+ "execution_count": 161,
419
+ "metadata": {
420
+ "tags": [
421
+ "exclude"
422
+ ]
423
+ },
424
+ "outputs": [
425
+ {
426
+ "data": {
427
+ "image/png": "",
428
+ "text/plain": [
429
+ "<Figure size 720x720 with 5 Axes>"
430
+ ]
431
+ },
432
+ "metadata": {
433
+ "needs_background": "light"
434
+ },
435
+ "output_type": "display_data"
436
+ }
437
+ ],
438
+ "source": [
439
+ "logits = model(xbv)\n",
440
+ "probs = F.softmax(logits, dim=1)\n",
441
+ "idx = 5\n",
442
+ "_,axs = plt.subplots(1, idx, figsize=(10, 10))\n",
443
+ "for actual, pred, im, ax in zip(ybv[:idx], probs[:idx],xbv.permute(0,2,3,1)[:idx], axs.flat):\n",
444
+ " ax.imshow(im)\n",
445
+ " ax.set_axis_off()\n",
446
+ " ax.set_title(f'pred: {pred.argmax(0).item()}, actual:{actual.item()}')\n",
447
+ " "
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "code",
452
+ "execution_count": 164,
453
  "metadata": {},
454
  "outputs": [],
455
  "source": [
 
465
  },
466
  {
467
  "cell_type": "code",
468
+ "execution_count": 167,
469
  "metadata": {
470
  "tags": [
471
  "exclude"
 
476
  "name": "stdout",
477
  "output_type": "stream",
478
  "text": [
479
+ "tensor(4)\n"
480
  ]
481
  },
482
  {
483
  "data": {
484
  "text/plain": [
485
+ "[{'digit': 0, 'prob': '0.12%', 'logits': tensor(-1.1319)},\n",
486
+ " {'digit': 1, 'prob': '0.00%', 'logits': tensor(-4.7852)},\n",
487
+ " {'digit': 2, 'prob': '2.15%', 'logits': tensor(1.7912)},\n",
488
+ " {'digit': 3, 'prob': '0.07%', 'logits': tensor(-1.6584)},\n",
489
+ " {'digit': 4, 'prob': '97.03%', 'logits': tensor(5.5990)},\n",
490
+ " {'digit': 5, 'prob': '0.01%', 'logits': tensor(-3.5289)},\n",
491
+ " {'digit': 6, 'prob': '0.00%', 'logits': tensor(-4.4016)},\n",
492
+ " {'digit': 7, 'prob': '0.09%', 'logits': tensor(-1.3343)},\n",
493
+ " {'digit': 8, 'prob': '0.07%', 'logits': tensor(-1.6577)},\n",
494
+ " {'digit': 9, 'prob': '0.45%', 'logits': tensor(0.2194)}]"
495
  ]
496
  },
497
+ "execution_count": 167,
498
  "metadata": {},
499
  "output_type": "execute_result"
500
  }
501
  ],
502
  "source": [
503
+ "img = xb[1].reshape(1, 28, 28)\n",
504
+ "print(yb[1])\n",
505
  "predict(img)"
506
  ]
507
  },
 
518
  },
519
  {
520
  "cell_type": "code",
521
+ "execution_count": 168,
522
  "metadata": {
523
  "tags": [
524
  "exclude"
 
530
  "output_type": "stream",
531
  "text": [
532
  "[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
533
+ "[NbConvertApp] Writing 3187 bytes to mnist_classifier.py\n"
534
  ]
535
  }
536
  ],
mnist_classifier.py CHANGED
@@ -11,7 +11,6 @@ import matplotlib as mpl
11
  import torchvision.transforms.functional as TF
12
  from torch.utils.data import default_collate, DataLoader
13
  import torch.optim as optim
14
- import pickle
15
 
16
 
17
  def transform_ds(b):
@@ -23,7 +22,7 @@ bs = 1024
23
  class DataLoaders:
24
  def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):
25
  self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)
26
- self.valid = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)
27
 
28
  def collate_fn(b):
29
  collate = default_collate(b)
@@ -65,13 +64,24 @@ class ResBlock(nn.Module):
65
  return self.act(self.convs(x) + self.idconv(self.pool(x)))
66
 
67
 
 
 
 
 
 
 
 
 
 
 
 
68
  def cnn_classifier():
69
  return nn.Sequential(
70
- ResBlock(1, 8, norm=nn.BatchNorm2d(8)),
71
- ResBlock(8, 16, norm=nn.BatchNorm2d(16)),
72
- ResBlock(16, 32, norm=nn.BatchNorm2d(32)),
73
- ResBlock(32, 64, norm=nn.BatchNorm2d(64)),
74
- ResBlock(64, 64, norm=nn.BatchNorm2d(64)),
75
  conv(64, 10, act=False),
76
  nn.Flatten(),
77
  )
@@ -83,7 +93,7 @@ def kaiming_init(m):
83
 
84
 
85
  loaded_model = cnn_classifier()
86
- loaded_model.load_state_dict(torch.load('classifier.pth'))
87
  loaded_model.eval();
88
 
89
 
 
11
  import torchvision.transforms.functional as TF
12
  from torch.utils.data import default_collate, DataLoader
13
  import torch.optim as optim
 
14
 
15
 
16
  def transform_ds(b):
 
22
  class DataLoaders:
23
  def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):
24
  self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)
25
+ self.valid = DataLoader(valid_ds, batch_size=bs, shuffle=False, collate_fn=collate_fn, **kwargs)
26
 
27
  def collate_fn(b):
28
  collate = default_collate(b)
 
64
  return self.act(self.convs(x) + self.idconv(self.pool(x)))
65
 
66
 
67
+ # def cnn_classifier():
68
+ # return nn.Sequential(
69
+ # ResBlock(1, 8, norm=nn.BatchNorm2d(8)),
70
+ # ResBlock(8, 16, norm=nn.BatchNorm2d(16)),
71
+ # ResBlock(16, 32, norm=nn.BatchNorm2d(32)),
72
+ # ResBlock(32, 64, norm=nn.BatchNorm2d(64)),
73
+ # ResBlock(64, 64, norm=nn.BatchNorm2d(64)),
74
+ # conv(64, 10, act=False),
75
+ # nn.Flatten(),
76
+ # )
77
+
78
  def cnn_classifier():
79
  return nn.Sequential(
80
+ ResBlock(1, 8,),
81
+ ResBlock(8, 16, ),
82
+ ResBlock(16, 32,),
83
+ ResBlock(32, 64, ),
84
+ ResBlock(64, 64,),
85
  conv(64, 10, act=False),
86
  nn.Flatten(),
87
  )
 
93
 
94
 
95
  loaded_model = cnn_classifier()
96
+ loaded_model.load_state_dict(torch.load('classifier.pth'));
97
  loaded_model.eval();
98
 
99
 
server.py CHANGED
@@ -6,6 +6,7 @@ from fastapi.staticfiles import StaticFiles
6
  from pathlib import Path
7
  import torchvision.transforms as transforms
8
  import mnist_classifier
 
9
 
10
  app = FastAPI()
11
 
 
6
  from pathlib import Path
7
  import torchvision.transforms as transforms
8
  import mnist_classifier
9
+ import torch
10
 
11
  app = FastAPI()
12